Load and train handwritten digits
This commit is contained in:
parent
650b4be9fc
commit
f9be5ca717
3 changed files with 88 additions and 56 deletions
|
@ -10,70 +10,64 @@ void NetLearner::run()
|
|||
{
|
||||
QElapsedTimer timer;
|
||||
|
||||
emit logMessage("Loading training data...");
|
||||
emit progress(0.0);
|
||||
|
||||
TrainingDataLoader dataLoader;
|
||||
dataLoader.addSamples("../NeuroUI/training data/mnist_train0.jpg", 0);
|
||||
emit progress(0.1);
|
||||
dataLoader.addSamples("../NeuroUI/training data/mnist_train1.jpg", 1);
|
||||
emit progress(0.2);
|
||||
dataLoader.addSamples("../NeuroUI/training data/mnist_train2.jpg", 2);
|
||||
emit progress(0.3);
|
||||
dataLoader.addSamples("../NeuroUI/training data/mnist_train3.jpg", 3);
|
||||
emit progress(0.4);
|
||||
dataLoader.addSamples("../NeuroUI/training data/mnist_train4.jpg", 4);
|
||||
emit progress(0.5);
|
||||
dataLoader.addSamples("../NeuroUI/training data/mnist_train5.jpg", 5);
|
||||
emit progress(0.6);
|
||||
dataLoader.addSamples("../NeuroUI/training data/mnist_train6.jpg", 6);
|
||||
emit progress(0.7);
|
||||
dataLoader.addSamples("../NeuroUI/training data/mnist_train7.jpg", 7);
|
||||
emit progress(0.8);
|
||||
dataLoader.addSamples("../NeuroUI/training data/mnist_train8.jpg", 8);
|
||||
emit progress(0.9);
|
||||
dataLoader.addSamples("../NeuroUI/training data/mnist_train9.jpg", 9);
|
||||
emit progress(1.0);
|
||||
|
||||
Net myNet;
|
||||
try
|
||||
{
|
||||
myNet.load("mynet.nnet");
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
myNet.initialize({2, 3, 1});
|
||||
}
|
||||
emit logMessage("done");
|
||||
emit progress(0.0);
|
||||
|
||||
size_t batchSize = 5000;
|
||||
size_t batchIndex = 0;
|
||||
double batchMaxError = 0.0;
|
||||
double batchMeanError = 0.0;
|
||||
Net digitClassifier({32*32, 16*16, 32, 1});
|
||||
|
||||
timer.start();
|
||||
|
||||
size_t numIterations = 2000000;
|
||||
size_t numIterations = 10000;
|
||||
for (size_t iteration = 0; iteration < numIterations; ++iteration)
|
||||
{
|
||||
std::vector<double> inputValues =
|
||||
{
|
||||
std::rand() / (double)RAND_MAX,
|
||||
std::rand() / (double)RAND_MAX
|
||||
};
|
||||
const TrainingDataLoader::Sample &trainingSample = dataLoader.getRandomSample();
|
||||
|
||||
std::vector<double> targetValues =
|
||||
{
|
||||
(inputValues[0] + inputValues[1]) / 2.0
|
||||
trainingSample.first / 10.0
|
||||
};
|
||||
|
||||
myNet.feedForward(inputValues);
|
||||
digitClassifier.feedForward(trainingSample.second);
|
||||
|
||||
std::vector<double> outputValues = myNet.getOutput();
|
||||
std::vector<double> outputValues = digitClassifier.getOutput();
|
||||
|
||||
double error = outputValues[0] - targetValues[0];
|
||||
|
||||
batchMeanError += error;
|
||||
batchMaxError = std::max<double>(batchMaxError, error);
|
||||
QString logString;
|
||||
|
||||
if (batchIndex++ == batchSize)
|
||||
{
|
||||
QString logString;
|
||||
logString.append("Error: ");
|
||||
logString.append(QString::number(std::abs(error)));
|
||||
|
||||
logString.append("Batch error (");
|
||||
logString.append(QString::number(batchSize));
|
||||
logString.append(" iterations, max/mean): ");
|
||||
logString.append(QString::number(std::abs(batchMaxError)));
|
||||
logString.append(" / ");
|
||||
logString.append(QString::number(std::abs(batchMeanError / batchSize)));
|
||||
emit logMessage(logString);
|
||||
emit currentNetError(error);
|
||||
emit progress((double)iteration / (double)numIterations);
|
||||
|
||||
emit logMessage(logString);
|
||||
emit currentNetError(batchMaxError);
|
||||
emit progress((double)iteration / (double)numIterations);
|
||||
|
||||
batchIndex = 0;
|
||||
batchMaxError = 0.0;
|
||||
batchMeanError = 0.0;
|
||||
}
|
||||
|
||||
myNet.backProp(targetValues);
|
||||
digitClassifier.backProp(targetValues);
|
||||
}
|
||||
|
||||
QString timerLogString;
|
||||
|
@ -83,7 +77,7 @@ void NetLearner::run()
|
|||
|
||||
emit logMessage(timerLogString);
|
||||
|
||||
myNet.save("mynet.nnet");
|
||||
digitClassifier.save("DigitClassifier.nnet");
|
||||
}
|
||||
catch (std::exception &ex)
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue