Load and train handwritten digits

This commit is contained in:
Michael Mandl 2015-10-27 15:33:54 +01:00
parent 650b4be9fc
commit f9be5ca717
3 changed files with 88 additions and 56 deletions

View file

@ -10,70 +10,64 @@ void NetLearner::run()
{ {
QElapsedTimer timer; QElapsedTimer timer;
emit logMessage("Loading training data...");
emit progress(0.0);
TrainingDataLoader dataLoader; TrainingDataLoader dataLoader;
dataLoader.addSamples("../NeuroUI/training data/mnist_train0.jpg", 0); dataLoader.addSamples("../NeuroUI/training data/mnist_train0.jpg", 0);
emit progress(0.1);
dataLoader.addSamples("../NeuroUI/training data/mnist_train1.jpg", 1);
emit progress(0.2);
dataLoader.addSamples("../NeuroUI/training data/mnist_train2.jpg", 2);
emit progress(0.3);
dataLoader.addSamples("../NeuroUI/training data/mnist_train3.jpg", 3);
emit progress(0.4);
dataLoader.addSamples("../NeuroUI/training data/mnist_train4.jpg", 4);
emit progress(0.5);
dataLoader.addSamples("../NeuroUI/training data/mnist_train5.jpg", 5);
emit progress(0.6);
dataLoader.addSamples("../NeuroUI/training data/mnist_train6.jpg", 6);
emit progress(0.7);
dataLoader.addSamples("../NeuroUI/training data/mnist_train7.jpg", 7);
emit progress(0.8);
dataLoader.addSamples("../NeuroUI/training data/mnist_train8.jpg", 8);
emit progress(0.9);
dataLoader.addSamples("../NeuroUI/training data/mnist_train9.jpg", 9);
emit progress(1.0);
Net myNet; emit logMessage("done");
try emit progress(0.0);
{
myNet.load("mynet.nnet");
}
catch (...)
{
myNet.initialize({2, 3, 1});
}
size_t batchSize = 5000; Net digitClassifier({32*32, 16*16, 32, 1});
size_t batchIndex = 0;
double batchMaxError = 0.0;
double batchMeanError = 0.0;
timer.start(); timer.start();
size_t numIterations = 2000000; size_t numIterations = 10000;
for (size_t iteration = 0; iteration < numIterations; ++iteration) for (size_t iteration = 0; iteration < numIterations; ++iteration)
{ {
std::vector<double> inputValues = const TrainingDataLoader::Sample &trainingSample = dataLoader.getRandomSample();
{
std::rand() / (double)RAND_MAX,
std::rand() / (double)RAND_MAX
};
std::vector<double> targetValues = std::vector<double> targetValues =
{ {
(inputValues[0] + inputValues[1]) / 2.0 trainingSample.first / 10.0
}; };
myNet.feedForward(inputValues); digitClassifier.feedForward(trainingSample.second);
std::vector<double> outputValues = myNet.getOutput(); std::vector<double> outputValues = digitClassifier.getOutput();
double error = outputValues[0] - targetValues[0]; double error = outputValues[0] - targetValues[0];
batchMeanError += error;
batchMaxError = std::max<double>(batchMaxError, error);
if (batchIndex++ == batchSize)
{
QString logString; QString logString;
logString.append("Batch error ("); logString.append("Error: ");
logString.append(QString::number(batchSize)); logString.append(QString::number(std::abs(error)));
logString.append(" iterations, max/mean): ");
logString.append(QString::number(std::abs(batchMaxError)));
logString.append(" / ");
logString.append(QString::number(std::abs(batchMeanError / batchSize)));
emit logMessage(logString); emit logMessage(logString);
emit currentNetError(batchMaxError); emit currentNetError(error);
emit progress((double)iteration / (double)numIterations); emit progress((double)iteration / (double)numIterations);
batchIndex = 0; digitClassifier.backProp(targetValues);
batchMaxError = 0.0;
batchMeanError = 0.0;
}
myNet.backProp(targetValues);
} }
QString timerLogString; QString timerLogString;
@ -83,7 +77,7 @@ void NetLearner::run()
emit logMessage(timerLogString); emit logMessage(timerLogString);
myNet.save("mynet.nnet"); digitClassifier.save("DigitClassifier.nnet");
} }
catch (std::exception &ex) catch (std::exception &ex)
{ {

View file

@ -1,5 +1,7 @@
#include "trainingdataloader.h" #include "trainingdataloader.h"
#include <sstream>
#include <QImage> #include <QImage>
#include <QColor> #include <QColor>
@ -11,19 +13,53 @@ TrainingDataLoader::TrainingDataLoader()
void TrainingDataLoader::addSamples(const QString &sourceFile, TrainingDataLoader::SampleId sampleId) void TrainingDataLoader::addSamples(const QString &sourceFile, TrainingDataLoader::SampleId sampleId)
{ {
QImage sourceImage; QImage sourceImage;
sourceImage.load(sourceFile); if (sourceImage.load(sourceFile) == false)
{
std::ostringstream errorString;
errorString << "error loading " << sourceFile.toStdString();
throw std::runtime_error(errorString.str());
}
QSize scanWindow(32, 32);
QPoint scanPosition(0, 0);
while (scanPosition.y() + scanWindow.height() < sourceImage.height())
{
scanPosition.setX(0);
while (scanPosition.x() + scanWindow.width() < sourceImage.width())
{
Sample sample; Sample sample;
sample.first = sampleId; sample.first = sampleId;
for (unsigned int y = 0; y < 8; ++y) for (int y = 0; y < scanWindow.height(); ++y)
{ {
for (unsigned int x = 0; x < 8; ++x) for (int x = 0; x < scanWindow.width(); ++x)
{ {
sample.second[x + y * 8] = qGray(sourceImage.pixel(x, y)) / 255.0; QRgb color = sourceImage.pixel(scanPosition.x() + x, scanPosition.y() + y);
sample.second[x + y * scanWindow.height()] = qGray(color) / 255.0;
} }
} }
m_samples.push_back(sample); m_samples.push_back(sample);
scanPosition.rx() += scanWindow.width();
}
scanPosition.ry() += scanWindow.height();
}
}
const TrainingDataLoader::Sample &TrainingDataLoader::getRandomSample() const
{
size_t sampleIndex = (std::rand() * m_samples.size()) / RAND_MAX;
auto it = m_samples.cbegin();
for (size_t index = 0; index < sampleIndex; ++index)
{
it++;
}
return *it;
} }

View file

@ -10,7 +10,7 @@
class TrainingDataLoader class TrainingDataLoader
{ {
public: public:
using SampleData = double[64]; using SampleData = double[32*32];
using SampleId = unsigned int; using SampleId = unsigned int;
using Sample = std::pair<SampleId, SampleData>; using Sample = std::pair<SampleId, SampleData>;
@ -21,6 +21,8 @@ public:
TrainingDataLoader(); TrainingDataLoader();
void addSamples(const QString &sourceFile, SampleId sampleId); void addSamples(const QString &sourceFile, SampleId sampleId);
const Sample &getRandomSample() const;
}; };
#endif // TRAININGDATALOADER_H #endif // TRAININGDATALOADER_H