2015-10-24 16:03:07 +00:00
|
|
|
#include "netlearner.h"
|
|
|
|
#include "../../Net.h"
|
2015-10-29 12:06:30 +00:00
|
|
|
#include "mnistloader.h"
|
2015-10-24 16:03:07 +00:00
|
|
|
|
2015-10-26 08:19:48 +00:00
|
|
|
#include <QElapsedTimer>
|
2015-10-27 17:20:50 +00:00
|
|
|
#include <QImage>
|
2015-10-26 08:19:48 +00:00
|
|
|
|
2015-10-24 16:03:07 +00:00
|
|
|
void NetLearner::run()
|
|
|
|
{
|
|
|
|
try
|
|
|
|
{
|
2015-10-26 08:19:48 +00:00
|
|
|
QElapsedTimer timer;
|
|
|
|
|
2015-10-27 14:33:54 +00:00
|
|
|
emit logMessage("Loading training data...");
|
|
|
|
|
2015-10-29 12:06:30 +00:00
|
|
|
MnistLoader mnistLoader;
|
2015-10-29 15:00:58 +00:00
|
|
|
mnistLoader.load("../NeuroUI/MNIST Database/train-images.idx3-ubyte",
|
|
|
|
"../NeuroUI/MNIST Database/train-labels.idx1-ubyte");
|
2015-10-27 14:33:54 +00:00
|
|
|
|
|
|
|
emit logMessage("done");
|
|
|
|
|
2015-10-31 13:58:49 +00:00
|
|
|
Net digitClassifier({28*28, 256, 1});
|
2015-10-24 16:03:07 +00:00
|
|
|
|
2015-10-26 08:19:48 +00:00
|
|
|
timer.start();
|
|
|
|
|
2015-10-31 13:58:49 +00:00
|
|
|
size_t numIterations = 100000;
|
2015-11-01 11:36:23 +00:00
|
|
|
for (size_t iteration = 0; iteration < numIterations && cancel == false; ++iteration)
|
2015-10-24 16:03:07 +00:00
|
|
|
{
|
2015-10-31 13:58:49 +00:00
|
|
|
auto trainingSample = mnistLoader.getRandomSample();
|
|
|
|
|
|
|
|
QImage trainingImage(trainingSample.data, 28, 28, QImage::Format_Grayscale8);
|
|
|
|
emit sampleImageLoaded(trainingImage);
|
|
|
|
|
2015-10-24 16:03:07 +00:00
|
|
|
std::vector<double> targetValues =
|
|
|
|
{
|
2015-10-31 13:58:49 +00:00
|
|
|
trainingSample.label / 10.0
|
2015-10-24 16:03:07 +00:00
|
|
|
};
|
|
|
|
|
2015-10-31 13:58:49 +00:00
|
|
|
std::vector<double> trainingData;
|
|
|
|
trainingData.reserve(28*28);
|
|
|
|
for (const uint8_t &val : trainingSample.data)
|
|
|
|
{
|
|
|
|
trainingData.push_back(val / 255.0);
|
|
|
|
}
|
|
|
|
|
|
|
|
digitClassifier.feedForward(trainingData);
|
2015-10-24 16:03:07 +00:00
|
|
|
|
2015-10-27 14:33:54 +00:00
|
|
|
std::vector<double> outputValues = digitClassifier.getOutput();
|
2015-10-24 16:03:07 +00:00
|
|
|
|
|
|
|
double error = outputValues[0] - targetValues[0];
|
|
|
|
|
2015-10-27 14:33:54 +00:00
|
|
|
QString logString;
|
2015-10-24 16:03:07 +00:00
|
|
|
|
2015-10-27 14:33:54 +00:00
|
|
|
logString.append("Error: ");
|
|
|
|
logString.append(QString::number(std::abs(error)));
|
2015-10-24 16:03:07 +00:00
|
|
|
|
2015-10-27 14:33:54 +00:00
|
|
|
emit logMessage(logString);
|
|
|
|
emit currentNetError(error);
|
|
|
|
emit progress((double)iteration / (double)numIterations);
|
2015-10-24 16:03:07 +00:00
|
|
|
|
2015-10-27 14:33:54 +00:00
|
|
|
digitClassifier.backProp(targetValues);
|
2015-10-24 16:03:07 +00:00
|
|
|
}
|
2015-10-25 16:40:22 +00:00
|
|
|
|
2015-10-26 08:19:48 +00:00
|
|
|
QString timerLogString;
|
|
|
|
timerLogString.append("Elapsed time: ");
|
|
|
|
timerLogString.append(QString::number(timer.elapsed() / 1000.0));
|
|
|
|
timerLogString.append(" seconds");
|
|
|
|
|
|
|
|
emit logMessage(timerLogString);
|
|
|
|
|
2015-10-27 14:33:54 +00:00
|
|
|
digitClassifier.save("DigitClassifier.nnet");
|
2015-10-24 16:03:07 +00:00
|
|
|
}
|
|
|
|
catch (std::exception &ex)
|
|
|
|
{
|
|
|
|
QString logString("Error: ");
|
|
|
|
logString.append(ex.what());
|
|
|
|
emit logMessage(logString);
|
|
|
|
}
|
2015-11-01 11:36:23 +00:00
|
|
|
|
|
|
|
cancel = false;
|
|
|
|
}
|
|
|
|
|
|
|
|
void NetLearner::cancelLearning()
|
|
|
|
{
|
|
|
|
cancel = true;
|
2015-10-24 16:03:07 +00:00
|
|
|
}
|