Neuro/gui/NeuroUI/netlearner.cpp

82 lines
2.2 KiB
C++
Raw Normal View History

#include "netlearner.h"
#include "../../Net.h"
#include "mnistloader.h"
2015-10-26 08:19:48 +00:00
#include <QElapsedTimer>
#include <QImage>
2015-10-26 08:19:48 +00:00
void NetLearner::run()
{
try
{
2015-10-26 08:19:48 +00:00
QElapsedTimer timer;
2015-10-27 14:33:54 +00:00
emit logMessage("Loading training data...");
MnistLoader mnistLoader;
2015-10-29 15:00:58 +00:00
mnistLoader.load("../NeuroUI/MNIST Database/train-images.idx3-ubyte",
"../NeuroUI/MNIST Database/train-labels.idx1-ubyte");
2015-10-27 14:33:54 +00:00
emit logMessage("done");
Net digitClassifier({28*28, 256, 1});
2015-10-26 08:19:48 +00:00
timer.start();
size_t numIterations = 100000;
for (size_t iteration = 0; iteration < numIterations && cancel == false; ++iteration)
{
2015-11-15 15:09:25 +00:00
auto trainingSample = mnistLoader.getRandomSample();
2015-11-15 15:09:25 +00:00
// emit logMessage(QString("training sample ") + QString::number(trainingSample.label));
emit sampleImageLoaded(trainingSample.toQImage());
std::vector<double> targetValues =
{
trainingSample.label / 10.0
};
std::vector<double> trainingData;
trainingData.reserve(28*28);
for (const uint8_t &val : trainingSample.data)
{
trainingData.push_back(val / 255.0);
}
digitClassifier.feedForward(trainingData);
2015-10-27 14:33:54 +00:00
std::vector<double> outputValues = digitClassifier.getOutput();
double error = outputValues[0] - targetValues[0];
2015-11-15 15:09:25 +00:00
emit logMessage(QString("Error: ") + QString::number(std::abs(error)));
2015-10-27 14:33:54 +00:00
emit currentNetError(error);
emit progress((double)iteration / (double)numIterations);
2015-10-27 14:33:54 +00:00
digitClassifier.backProp(targetValues);
}
2015-10-26 08:19:48 +00:00
QString timerLogString;
timerLogString.append("Elapsed time: ");
timerLogString.append(QString::number(timer.elapsed() / 1000.0));
timerLogString.append(" seconds");
emit logMessage(timerLogString);
2015-10-27 14:33:54 +00:00
digitClassifier.save("DigitClassifier.nnet");
}
catch (std::exception &ex)
{
QString logString("Error: ");
logString.append(ex.what());
emit logMessage(logString);
}
cancel = false;
}
void NetLearner::cancelLearning()
{
cancel = true;
}