From f9be5ca71791522619addd9b9e5040ecd2a219f2 Mon Sep 17 00:00:00 2001 From: Michael Mandl Date: Tue, 27 Oct 2015 15:33:54 +0100 Subject: [PATCH] Load and train handwritten digits --- gui/NeuroUI/netlearner.cpp | 82 ++++++++++++++---------------- gui/NeuroUI/trainingdataloader.cpp | 58 +++++++++++++++++---- gui/NeuroUI/trainingdataloader.h | 4 +- 3 files changed, 88 insertions(+), 56 deletions(-) diff --git a/gui/NeuroUI/netlearner.cpp b/gui/NeuroUI/netlearner.cpp index 8005ef3..8baffa5 100644 --- a/gui/NeuroUI/netlearner.cpp +++ b/gui/NeuroUI/netlearner.cpp @@ -10,70 +10,64 @@ void NetLearner::run() { QElapsedTimer timer; + emit logMessage("Loading training data..."); + emit progress(0.0); + TrainingDataLoader dataLoader; dataLoader.addSamples("../NeuroUI/training data/mnist_train0.jpg", 0); + emit progress(0.1); + dataLoader.addSamples("../NeuroUI/training data/mnist_train1.jpg", 1); + emit progress(0.2); + dataLoader.addSamples("../NeuroUI/training data/mnist_train2.jpg", 2); + emit progress(0.3); + dataLoader.addSamples("../NeuroUI/training data/mnist_train3.jpg", 3); + emit progress(0.4); + dataLoader.addSamples("../NeuroUI/training data/mnist_train4.jpg", 4); + emit progress(0.5); + dataLoader.addSamples("../NeuroUI/training data/mnist_train5.jpg", 5); + emit progress(0.6); + dataLoader.addSamples("../NeuroUI/training data/mnist_train6.jpg", 6); + emit progress(0.7); + dataLoader.addSamples("../NeuroUI/training data/mnist_train7.jpg", 7); + emit progress(0.8); + dataLoader.addSamples("../NeuroUI/training data/mnist_train8.jpg", 8); + emit progress(0.9); + dataLoader.addSamples("../NeuroUI/training data/mnist_train9.jpg", 9); + emit progress(1.0); - Net myNet; - try - { - myNet.load("mynet.nnet"); - } - catch (...) - { - myNet.initialize({2, 3, 1}); - } + emit logMessage("done"); + emit progress(0.0); - size_t batchSize = 5000; - size_t batchIndex = 0; - double batchMaxError = 0.0; - double batchMeanError = 0.0; + Net digitClassifier({32*32, 16*16, 32, 1}); timer.start(); - size_t numIterations = 2000000; + size_t numIterations = 10000; for (size_t iteration = 0; iteration < numIterations; ++iteration) { - std::vector inputValues = - { - std::rand() / (double)RAND_MAX, - std::rand() / (double)RAND_MAX - }; + const TrainingDataLoader::Sample &trainingSample = dataLoader.getRandomSample(); std::vector targetValues = { - (inputValues[0] + inputValues[1]) / 2.0 + trainingSample.first / 10.0 }; - myNet.feedForward(inputValues); + digitClassifier.feedForward(trainingSample.second); - std::vector outputValues = myNet.getOutput(); + std::vector outputValues = digitClassifier.getOutput(); double error = outputValues[0] - targetValues[0]; - batchMeanError += error; - batchMaxError = std::max(batchMaxError, error); + QString logString; - if (batchIndex++ == batchSize) - { - QString logString; + logString.append("Error: "); + logString.append(QString::number(std::abs(error))); - logString.append("Batch error ("); - logString.append(QString::number(batchSize)); - logString.append(" iterations, max/mean): "); - logString.append(QString::number(std::abs(batchMaxError))); - logString.append(" / "); - logString.append(QString::number(std::abs(batchMeanError / batchSize))); + emit logMessage(logString); + emit currentNetError(error); + emit progress((double)iteration / (double)numIterations); - emit logMessage(logString); - emit currentNetError(batchMaxError); - emit progress((double)iteration / (double)numIterations); - - batchIndex = 0; - batchMaxError = 0.0; - batchMeanError = 0.0; - } - - myNet.backProp(targetValues); + digitClassifier.backProp(targetValues); } QString timerLogString; @@ -83,7 +77,7 @@ void NetLearner::run() emit logMessage(timerLogString); - myNet.save("mynet.nnet"); + digitClassifier.save("DigitClassifier.nnet"); } catch (std::exception &ex) { diff --git a/gui/NeuroUI/trainingdataloader.cpp b/gui/NeuroUI/trainingdataloader.cpp index 8c99d41..cfde6d3 100644 --- a/gui/NeuroUI/trainingdataloader.cpp +++ b/gui/NeuroUI/trainingdataloader.cpp @@ -1,5 +1,7 @@ #include "trainingdataloader.h" +#include + #include #include @@ -11,19 +13,53 @@ TrainingDataLoader::TrainingDataLoader() void TrainingDataLoader::addSamples(const QString &sourceFile, TrainingDataLoader::SampleId sampleId) { QImage sourceImage; - sourceImage.load(sourceFile); - - Sample sample; - sample.first = sampleId; - - for (unsigned int y = 0; y < 8; ++y) + if (sourceImage.load(sourceFile) == false) { - for (unsigned int x = 0; x < 8; ++x) - { - sample.second[x + y * 8] = qGray(sourceImage.pixel(x, y)) / 255.0; - } + std::ostringstream errorString; + errorString << "error loading " << sourceFile.toStdString(); + + throw std::runtime_error(errorString.str()); } - m_samples.push_back(sample); + QSize scanWindow(32, 32); + QPoint scanPosition(0, 0); + + while (scanPosition.y() + scanWindow.height() < sourceImage.height()) + { + scanPosition.setX(0); + + while (scanPosition.x() + scanWindow.width() < sourceImage.width()) + { + Sample sample; + sample.first = sampleId; + + for (int y = 0; y < scanWindow.height(); ++y) + { + for (int x = 0; x < scanWindow.width(); ++x) + { + QRgb color = sourceImage.pixel(scanPosition.x() + x, scanPosition.y() + y); + sample.second[x + y * scanWindow.height()] = qGray(color) / 255.0; + } + } + + m_samples.push_back(sample); + + scanPosition.rx() += scanWindow.width(); + } + + scanPosition.ry() += scanWindow.height(); + } +} + +const TrainingDataLoader::Sample &TrainingDataLoader::getRandomSample() const +{ + size_t sampleIndex = (std::rand() * m_samples.size()) / RAND_MAX; + + auto it = m_samples.cbegin(); + for (size_t index = 0; index < sampleIndex; ++index) + { + it++; + } + return *it; } diff --git a/gui/NeuroUI/trainingdataloader.h b/gui/NeuroUI/trainingdataloader.h index 912b426..6472255 100644 --- a/gui/NeuroUI/trainingdataloader.h +++ b/gui/NeuroUI/trainingdataloader.h @@ -10,7 +10,7 @@ class TrainingDataLoader { public: - using SampleData = double[64]; + using SampleData = double[32*32]; using SampleId = unsigned int; using Sample = std::pair; @@ -21,6 +21,8 @@ public: TrainingDataLoader(); void addSamples(const QString &sourceFile, SampleId sampleId); + + const Sample &getRandomSample() const; }; #endif // TRAININGDATALOADER_H