Neuro/gui/NeuroUI/netlearner.cpp

89 lines
2.8 KiB
C++
Raw Normal View History

#include "netlearner.h"
#include "../../Net.h"
#include "trainingdataloader.h"
#include <QElapsedTimer>

#include <cmath>
#include <exception>
#include <string>
#include <vector>
void NetLearner::run()
{
try
{
2015-10-26 08:19:48 +00:00
QElapsedTimer timer;
2015-10-27 14:33:54 +00:00
emit logMessage("Loading training data...");
emit progress(0.0);
TrainingDataLoader dataLoader;
dataLoader.addSamples("../NeuroUI/training data/mnist_train0.jpg", 0);
2015-10-27 14:33:54 +00:00
emit progress(0.1);
dataLoader.addSamples("../NeuroUI/training data/mnist_train1.jpg", 1);
emit progress(0.2);
dataLoader.addSamples("../NeuroUI/training data/mnist_train2.jpg", 2);
emit progress(0.3);
dataLoader.addSamples("../NeuroUI/training data/mnist_train3.jpg", 3);
emit progress(0.4);
dataLoader.addSamples("../NeuroUI/training data/mnist_train4.jpg", 4);
emit progress(0.5);
dataLoader.addSamples("../NeuroUI/training data/mnist_train5.jpg", 5);
emit progress(0.6);
dataLoader.addSamples("../NeuroUI/training data/mnist_train6.jpg", 6);
emit progress(0.7);
dataLoader.addSamples("../NeuroUI/training data/mnist_train7.jpg", 7);
emit progress(0.8);
dataLoader.addSamples("../NeuroUI/training data/mnist_train8.jpg", 8);
emit progress(0.9);
dataLoader.addSamples("../NeuroUI/training data/mnist_train9.jpg", 9);
emit progress(1.0);
emit logMessage("done");
emit progress(0.0);
Net digitClassifier({32*32, 16*16, 32, 1});
2015-10-26 08:19:48 +00:00
timer.start();
2015-10-27 14:33:54 +00:00
size_t numIterations = 10000;
for (size_t iteration = 0; iteration < numIterations; ++iteration)
{
2015-10-27 14:33:54 +00:00
const TrainingDataLoader::Sample &trainingSample = dataLoader.getRandomSample();
std::vector<double> targetValues =
{
2015-10-27 14:33:54 +00:00
trainingSample.first / 10.0
};
2015-10-27 14:33:54 +00:00
digitClassifier.feedForward(trainingSample.second);
2015-10-27 14:33:54 +00:00
std::vector<double> outputValues = digitClassifier.getOutput();
double error = outputValues[0] - targetValues[0];
2015-10-27 14:33:54 +00:00
QString logString;
2015-10-27 14:33:54 +00:00
logString.append("Error: ");
logString.append(QString::number(std::abs(error)));
2015-10-27 14:33:54 +00:00
emit logMessage(logString);
emit currentNetError(error);
emit progress((double)iteration / (double)numIterations);
2015-10-27 14:33:54 +00:00
digitClassifier.backProp(targetValues);
}
2015-10-26 08:19:48 +00:00
QString timerLogString;
timerLogString.append("Elapsed time: ");
timerLogString.append(QString::number(timer.elapsed() / 1000.0));
timerLogString.append(" seconds");
emit logMessage(timerLogString);
2015-10-27 14:33:54 +00:00
digitClassifier.save("DigitClassifier.nnet");
}
catch (std::exception &ex)
{
QString logString("Error: ");
logString.append(ex.what());
emit logMessage(logString);
}
}