From cd1101dfe252e1c4701c0804f4c00c4f1ed6be6c Mon Sep 17 00:00:00 2001 From: Michael Mandl Date: Sat, 31 Oct 2015 14:58:49 +0100 Subject: [PATCH] Learning from and displaying of digit samples --- gui/NeuroUI/mnistloader.cpp | 19 +++++++++---------- gui/NeuroUI/mnistloader.h | 8 ++++---- gui/NeuroUI/netlearner.cpp | 24 ++++++++++++++++-------- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/gui/NeuroUI/mnistloader.cpp b/gui/NeuroUI/mnistloader.cpp index 4eac8bb..67c4c1f 100644 --- a/gui/NeuroUI/mnistloader.cpp +++ b/gui/NeuroUI/mnistloader.cpp @@ -1,16 +1,6 @@ #include "mnistloader.h" #include -#include -#include -#include - -#include - -MnistLoader::MnistLoader() -{ - -} void MnistLoader::load(const std::string &databaseFileName, const std::string &labelsFileName) { @@ -18,6 +8,13 @@ void MnistLoader::load(const std::string &databaseFileName, const std::string &l loadLabels(labelsFileName); } +const MnistLoader::MnistSample &MnistLoader::getRandomSample() const +{ + size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX; + + return *(samples[sampleIndex].get()); +} + void MnistLoader::loadDatabase(const std::string &fileName) { std::ifstream databaseFile; @@ -43,6 +40,8 @@ void MnistLoader::loadDatabase(const std::string &fileName) throw std::runtime_error("unexpected sample size loading MNIST database"); } + samples.reserve(samples.size() + sampleCount); + for (int32_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex) { std::unique_ptr sample = std::make_unique(); diff --git a/gui/NeuroUI/mnistloader.h b/gui/NeuroUI/mnistloader.h index 482931c..d2a1e2b 100644 --- a/gui/NeuroUI/mnistloader.h +++ b/gui/NeuroUI/mnistloader.h @@ -2,7 +2,7 @@ #define MNISTLOADER_H #include -#include +#include #include #include @@ -26,13 +26,13 @@ public: using MnistSample = Sample; private: - std::list> samples; + std::vector> samples; public: - MnistLoader(); - void load(const std::string &databaseFileName, const std::string &labelsFileName); + const MnistSample &getRandomSample() const; + private: void loadDatabase(const std::string &fileName); void loadLabels(const std::string &fileName); diff --git a/gui/NeuroUI/netlearner.cpp b/gui/NeuroUI/netlearner.cpp index 0b96baf..b0956e1 100644 --- a/gui/NeuroUI/netlearner.cpp +++ b/gui/NeuroUI/netlearner.cpp @@ -12,30 +12,38 @@ void NetLearner::run() QElapsedTimer timer; emit logMessage("Loading training data..."); - emit progress(0.0); MnistLoader mnistLoader; mnistLoader.load("../NeuroUI/MNIST Database/train-images.idx3-ubyte", "../NeuroUI/MNIST Database/train-labels.idx1-ubyte"); emit logMessage("done"); - emit progress(0.0); - return; - - Net digitClassifier({32*32, 16*16, 32, 1}); + Net digitClassifier({28*28, 256, 1}); timer.start(); - size_t numIterations = 10000; + size_t numIterations = 100000; for (size_t iteration = 0; iteration < numIterations; ++iteration) { + auto trainingSample = mnistLoader.getRandomSample(); + + QImage trainingImage(trainingSample.data, 28, 28, QImage::Format_Grayscale8); + emit sampleImageLoaded(trainingImage); + std::vector targetValues = { - //trainingSample.first / 10.0 + trainingSample.label / 10.0 }; - //digitClassifier.feedForward(trainingSample.second); + std::vector trainingData; + trainingData.reserve(28*28); + for (const uint8_t &val : trainingSample.data) + { + trainingData.push_back(val / 255.0); + } + + digitClassifier.feedForward(trainingData); std::vector outputValues = digitClassifier.getOutput();