Compare commits

...

9 Commits
digits ... main

9 changed files with 61 additions and 10 deletions

View File

@ -41,7 +41,7 @@ double Layer::getWeightedSum(size_t outputNeuron) const
sum += neuron.getWeightedOutputValue(outputNeuron);
}
return sum;
return sum / size();
}
void Layer::connectTo(const Layer & nextLayer)

View File

@ -28,6 +28,11 @@ void ErrorPlotter::clear()
void ErrorPlotter::addErrorValue(double errorValue)
{
if (m_errorValues.size() == m_bufferSize)
{
m_errorValues.pop_front();
}
m_errorValues.push_back(errorValue);
m_maxErrorValue = std::max<double>(m_maxErrorValue, errorValue);

View File

@ -11,6 +11,8 @@ private:
std::list<double> m_errorValues;
double m_maxErrorValue;
size_t m_bufferSize = 10000;
public:
explicit ErrorPlotter(QWidget *parent = 0);

View File

@ -8,6 +8,21 @@ void MnistLoader::load(const std::string &databaseFileName, const std::string &l
loadLabels(labelsFileName);
}
size_t MnistLoader::getSamleCount() const
{
return samples.size();
}
const MnistLoader::MnistSample &MnistLoader::getSample(size_t index) const
{
if (index >= samples.size())
{
throw std::runtime_error("MNIST sample index out of range");
}
return *(samples[index].get());
}
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
{
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;

View File

@ -6,6 +6,8 @@
#include <memory>
#include <inttypes.h>
#include <QImage>
class MnistLoader
{
private:
@ -21,6 +23,11 @@ public:
public:
uint8_t label;
uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT];
QImage toQImage() const
{
return QImage(data, SAMPLE_WIDTH, SAMPLE_HEIGHT, QImage::Format_Grayscale8);
}
};
using MnistSample = Sample<SampleWidth, SampleHeight>;
@ -31,6 +38,8 @@ private:
public:
void load(const std::string &databaseFileName, const std::string &labelsFileName);
size_t getSamleCount() const;
const MnistSample &getSample(size_t index) const;
const MnistSample &getRandomSample() const;
private:

View File

@ -24,12 +24,12 @@ void NetLearner::run()
timer.start();
size_t numIterations = 100000;
for (size_t iteration = 0; iteration < numIterations; ++iteration)
for (size_t iteration = 0; iteration < numIterations && cancel == false; ++iteration)
{
auto trainingSample = mnistLoader.getRandomSample();
QImage trainingImage(trainingSample.data, 28, 28, QImage::Format_Grayscale8);
emit sampleImageLoaded(trainingImage);
// emit logMessage(QString("training sample ") + QString::number(trainingSample.label));
emit sampleImageLoaded(trainingSample.toQImage());
std::vector<double> targetValues =
{
@ -49,12 +49,7 @@ void NetLearner::run()
double error = outputValues[0] - targetValues[0];
QString logString;
logString.append("Error: ");
logString.append(QString::number(std::abs(error)));
emit logMessage(logString);
emit logMessage(QString("Error: ") + QString::number(std::abs(error)));
emit currentNetError(error);
emit progress((double)iteration / (double)numIterations);
@ -76,4 +71,11 @@ void NetLearner::run()
logString.append(ex.what());
emit logMessage(logString);
}
cancel = false;
}
void NetLearner::cancelLearning()
{
cancel = true;
}

View File

@ -7,6 +7,9 @@ class NetLearner : public QThread
{
Q_OBJECT
private:
bool cancel = false;
private:
void run() Q_DECL_OVERRIDE;
@ -15,6 +18,9 @@ signals:
void progress(double progress);
void currentNetError(double error);
void sampleImageLoaded(const QImage &image);
public slots:
void cancelLearning();
};
#endif // NETLEARNER_H

View File

@ -12,6 +12,12 @@ NeuroUI::NeuroUI(QWidget *parent) :
NeuroUI::~NeuroUI()
{
if (m_netLearner != nullptr)
{
m_netLearner->cancelLearning();
m_netLearner->wait();
}
delete ui;
}
@ -40,6 +46,11 @@ void NeuroUI::on_runButton_clicked()
void NeuroUI::logMessage(const QString &logMessage)
{
if (ui->logView->count() == static_cast<int>(m_logSize))
{
delete ui->logView->item(0);
}
ui->logView->addItem(logMessage);
ui->logView->scrollToBottom();
}

View File

@ -17,6 +17,7 @@ class NeuroUI : public QMainWindow
private:
std::unique_ptr<NetLearner> m_netLearner;
size_t m_logSize = 128;
public:
explicit NeuroUI(QWidget *parent = 0);