Load and train handwritten digits
This commit is contained in:
parent
650b4be9fc
commit
f9be5ca717
3 changed files with 88 additions and 56 deletions
|
@ -10,70 +10,64 @@ void NetLearner::run()
|
||||||
{
|
{
|
||||||
QElapsedTimer timer;
|
QElapsedTimer timer;
|
||||||
|
|
||||||
|
emit logMessage("Loading training data...");
|
||||||
|
emit progress(0.0);
|
||||||
|
|
||||||
TrainingDataLoader dataLoader;
|
TrainingDataLoader dataLoader;
|
||||||
dataLoader.addSamples("../NeuroUI/training data/mnist_train0.jpg", 0);
|
dataLoader.addSamples("../NeuroUI/training data/mnist_train0.jpg", 0);
|
||||||
|
emit progress(0.1);
|
||||||
|
dataLoader.addSamples("../NeuroUI/training data/mnist_train1.jpg", 1);
|
||||||
|
emit progress(0.2);
|
||||||
|
dataLoader.addSamples("../NeuroUI/training data/mnist_train2.jpg", 2);
|
||||||
|
emit progress(0.3);
|
||||||
|
dataLoader.addSamples("../NeuroUI/training data/mnist_train3.jpg", 3);
|
||||||
|
emit progress(0.4);
|
||||||
|
dataLoader.addSamples("../NeuroUI/training data/mnist_train4.jpg", 4);
|
||||||
|
emit progress(0.5);
|
||||||
|
dataLoader.addSamples("../NeuroUI/training data/mnist_train5.jpg", 5);
|
||||||
|
emit progress(0.6);
|
||||||
|
dataLoader.addSamples("../NeuroUI/training data/mnist_train6.jpg", 6);
|
||||||
|
emit progress(0.7);
|
||||||
|
dataLoader.addSamples("../NeuroUI/training data/mnist_train7.jpg", 7);
|
||||||
|
emit progress(0.8);
|
||||||
|
dataLoader.addSamples("../NeuroUI/training data/mnist_train8.jpg", 8);
|
||||||
|
emit progress(0.9);
|
||||||
|
dataLoader.addSamples("../NeuroUI/training data/mnist_train9.jpg", 9);
|
||||||
|
emit progress(1.0);
|
||||||
|
|
||||||
Net myNet;
|
emit logMessage("done");
|
||||||
try
|
emit progress(0.0);
|
||||||
{
|
|
||||||
myNet.load("mynet.nnet");
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
myNet.initialize({2, 3, 1});
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t batchSize = 5000;
|
Net digitClassifier({32*32, 16*16, 32, 1});
|
||||||
size_t batchIndex = 0;
|
|
||||||
double batchMaxError = 0.0;
|
|
||||||
double batchMeanError = 0.0;
|
|
||||||
|
|
||||||
timer.start();
|
timer.start();
|
||||||
|
|
||||||
size_t numIterations = 2000000;
|
size_t numIterations = 10000;
|
||||||
for (size_t iteration = 0; iteration < numIterations; ++iteration)
|
for (size_t iteration = 0; iteration < numIterations; ++iteration)
|
||||||
{
|
{
|
||||||
std::vector<double> inputValues =
|
const TrainingDataLoader::Sample &trainingSample = dataLoader.getRandomSample();
|
||||||
{
|
|
||||||
std::rand() / (double)RAND_MAX,
|
|
||||||
std::rand() / (double)RAND_MAX
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<double> targetValues =
|
std::vector<double> targetValues =
|
||||||
{
|
{
|
||||||
(inputValues[0] + inputValues[1]) / 2.0
|
trainingSample.first / 10.0
|
||||||
};
|
};
|
||||||
|
|
||||||
myNet.feedForward(inputValues);
|
digitClassifier.feedForward(trainingSample.second);
|
||||||
|
|
||||||
std::vector<double> outputValues = myNet.getOutput();
|
std::vector<double> outputValues = digitClassifier.getOutput();
|
||||||
|
|
||||||
double error = outputValues[0] - targetValues[0];
|
double error = outputValues[0] - targetValues[0];
|
||||||
|
|
||||||
batchMeanError += error;
|
|
||||||
batchMaxError = std::max<double>(batchMaxError, error);
|
|
||||||
|
|
||||||
if (batchIndex++ == batchSize)
|
|
||||||
{
|
|
||||||
QString logString;
|
QString logString;
|
||||||
|
|
||||||
logString.append("Batch error (");
|
logString.append("Error: ");
|
||||||
logString.append(QString::number(batchSize));
|
logString.append(QString::number(std::abs(error)));
|
||||||
logString.append(" iterations, max/mean): ");
|
|
||||||
logString.append(QString::number(std::abs(batchMaxError)));
|
|
||||||
logString.append(" / ");
|
|
||||||
logString.append(QString::number(std::abs(batchMeanError / batchSize)));
|
|
||||||
|
|
||||||
emit logMessage(logString);
|
emit logMessage(logString);
|
||||||
emit currentNetError(batchMaxError);
|
emit currentNetError(error);
|
||||||
emit progress((double)iteration / (double)numIterations);
|
emit progress((double)iteration / (double)numIterations);
|
||||||
|
|
||||||
batchIndex = 0;
|
digitClassifier.backProp(targetValues);
|
||||||
batchMaxError = 0.0;
|
|
||||||
batchMeanError = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
myNet.backProp(targetValues);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
QString timerLogString;
|
QString timerLogString;
|
||||||
|
@ -83,7 +77,7 @@ void NetLearner::run()
|
||||||
|
|
||||||
emit logMessage(timerLogString);
|
emit logMessage(timerLogString);
|
||||||
|
|
||||||
myNet.save("mynet.nnet");
|
digitClassifier.save("DigitClassifier.nnet");
|
||||||
}
|
}
|
||||||
catch (std::exception &ex)
|
catch (std::exception &ex)
|
||||||
{
|
{
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
#include "trainingdataloader.h"
|
#include "trainingdataloader.h"
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
#include <QImage>
|
#include <QImage>
|
||||||
#include <QColor>
|
#include <QColor>
|
||||||
|
|
||||||
|
@ -11,19 +13,53 @@ TrainingDataLoader::TrainingDataLoader()
|
||||||
void TrainingDataLoader::addSamples(const QString &sourceFile, TrainingDataLoader::SampleId sampleId)
|
void TrainingDataLoader::addSamples(const QString &sourceFile, TrainingDataLoader::SampleId sampleId)
|
||||||
{
|
{
|
||||||
QImage sourceImage;
|
QImage sourceImage;
|
||||||
sourceImage.load(sourceFile);
|
if (sourceImage.load(sourceFile) == false)
|
||||||
|
{
|
||||||
|
std::ostringstream errorString;
|
||||||
|
errorString << "error loading " << sourceFile.toStdString();
|
||||||
|
|
||||||
|
throw std::runtime_error(errorString.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
QSize scanWindow(32, 32);
|
||||||
|
QPoint scanPosition(0, 0);
|
||||||
|
|
||||||
|
while (scanPosition.y() + scanWindow.height() < sourceImage.height())
|
||||||
|
{
|
||||||
|
scanPosition.setX(0);
|
||||||
|
|
||||||
|
while (scanPosition.x() + scanWindow.width() < sourceImage.width())
|
||||||
|
{
|
||||||
Sample sample;
|
Sample sample;
|
||||||
sample.first = sampleId;
|
sample.first = sampleId;
|
||||||
|
|
||||||
for (unsigned int y = 0; y < 8; ++y)
|
for (int y = 0; y < scanWindow.height(); ++y)
|
||||||
{
|
{
|
||||||
for (unsigned int x = 0; x < 8; ++x)
|
for (int x = 0; x < scanWindow.width(); ++x)
|
||||||
{
|
{
|
||||||
sample.second[x + y * 8] = qGray(sourceImage.pixel(x, y)) / 255.0;
|
QRgb color = sourceImage.pixel(scanPosition.x() + x, scanPosition.y() + y);
|
||||||
|
sample.second[x + y * scanWindow.height()] = qGray(color) / 255.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m_samples.push_back(sample);
|
m_samples.push_back(sample);
|
||||||
|
|
||||||
|
scanPosition.rx() += scanWindow.width();
|
||||||
|
}
|
||||||
|
|
||||||
|
scanPosition.ry() += scanWindow.height();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const TrainingDataLoader::Sample &TrainingDataLoader::getRandomSample() const
|
||||||
|
{
|
||||||
|
size_t sampleIndex = (std::rand() * m_samples.size()) / RAND_MAX;
|
||||||
|
|
||||||
|
auto it = m_samples.cbegin();
|
||||||
|
for (size_t index = 0; index < sampleIndex; ++index)
|
||||||
|
{
|
||||||
|
it++;
|
||||||
|
}
|
||||||
|
return *it;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
class TrainingDataLoader
|
class TrainingDataLoader
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
using SampleData = double[64];
|
using SampleData = double[32*32];
|
||||||
using SampleId = unsigned int;
|
using SampleId = unsigned int;
|
||||||
using Sample = std::pair<SampleId, SampleData>;
|
using Sample = std::pair<SampleId, SampleData>;
|
||||||
|
|
||||||
|
@ -21,6 +21,8 @@ public:
|
||||||
TrainingDataLoader();
|
TrainingDataLoader();
|
||||||
|
|
||||||
void addSamples(const QString &sourceFile, SampleId sampleId);
|
void addSamples(const QString &sourceFile, SampleId sampleId);
|
||||||
|
|
||||||
|
const Sample &getRandomSample() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // TRAININGDATALOADER_H
|
#endif // TRAININGDATALOADER_H
|
||||||
|
|
Loading…
Reference in a new issue