Learning from and displaying of digit samples

This commit is contained in:
Michael Mandl 2015-10-31 14:58:49 +01:00
parent d98ec63fbd
commit cd1101dfe2
3 changed files with 29 additions and 22 deletions

View file

@ -1,16 +1,6 @@
#include "mnistloader.h" #include "mnistloader.h"
#include <fstream> #include <fstream>
#include <functional>
#include <memory>
#include <list>
#include <intrin.h>
MnistLoader::MnistLoader()
{
}
void MnistLoader::load(const std::string &databaseFileName, const std::string &labelsFileName) void MnistLoader::load(const std::string &databaseFileName, const std::string &labelsFileName)
{ {
@ -18,6 +8,13 @@ void MnistLoader::load(const std::string &databaseFileName, const std::string &l
loadLabels(labelsFileName); loadLabels(labelsFileName);
} }
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
{
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;
return *(samples[sampleIndex].get());
}
void MnistLoader::loadDatabase(const std::string &fileName) void MnistLoader::loadDatabase(const std::string &fileName)
{ {
std::ifstream databaseFile; std::ifstream databaseFile;
@ -43,6 +40,8 @@ void MnistLoader::loadDatabase(const std::string &fileName)
throw std::runtime_error("unexpected sample size loading MNIST database"); throw std::runtime_error("unexpected sample size loading MNIST database");
} }
samples.reserve(samples.size() + sampleCount);
for (int32_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex) for (int32_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex)
{ {
std::unique_ptr<MnistSample> sample = std::make_unique<MnistSample>(); std::unique_ptr<MnistSample> sample = std::make_unique<MnistSample>();

View file

@ -2,7 +2,7 @@
#define MNISTLOADER_H #define MNISTLOADER_H
#include <string> #include <string>
#include <list> #include <vector>
#include <memory> #include <memory>
#include <inttypes.h> #include <inttypes.h>
@ -26,13 +26,13 @@ public:
using MnistSample = Sample<SampleWidth, SampleHeight>; using MnistSample = Sample<SampleWidth, SampleHeight>;
private: private:
std::list<std::unique_ptr<MnistSample>> samples; std::vector<std::unique_ptr<MnistSample>> samples;
public: public:
MnistLoader();
void load(const std::string &databaseFileName, const std::string &labelsFileName); void load(const std::string &databaseFileName, const std::string &labelsFileName);
const MnistSample &getRandomSample() const;
private: private:
void loadDatabase(const std::string &fileName); void loadDatabase(const std::string &fileName);
void loadLabels(const std::string &fileName); void loadLabels(const std::string &fileName);

View file

@ -12,30 +12,38 @@ void NetLearner::run()
QElapsedTimer timer; QElapsedTimer timer;
emit logMessage("Loading training data..."); emit logMessage("Loading training data...");
emit progress(0.0);
MnistLoader mnistLoader; MnistLoader mnistLoader;
mnistLoader.load("../NeuroUI/MNIST Database/train-images.idx3-ubyte", mnistLoader.load("../NeuroUI/MNIST Database/train-images.idx3-ubyte",
"../NeuroUI/MNIST Database/train-labels.idx1-ubyte"); "../NeuroUI/MNIST Database/train-labels.idx1-ubyte");
emit logMessage("done"); emit logMessage("done");
emit progress(0.0);
return; Net digitClassifier({28*28, 256, 1});
Net digitClassifier({32*32, 16*16, 32, 1});
timer.start(); timer.start();
size_t numIterations = 10000; size_t numIterations = 100000;
for (size_t iteration = 0; iteration < numIterations; ++iteration) for (size_t iteration = 0; iteration < numIterations; ++iteration)
{ {
auto trainingSample = mnistLoader.getRandomSample();
QImage trainingImage(trainingSample.data, 28, 28, QImage::Format_Grayscale8);
emit sampleImageLoaded(trainingImage);
std::vector<double> targetValues = std::vector<double> targetValues =
{ {
//trainingSample.first / 10.0 trainingSample.label / 10.0
}; };
//digitClassifier.feedForward(trainingSample.second); std::vector<double> trainingData;
trainingData.reserve(28*28);
for (const uint8_t &val : trainingSample.data)
{
trainingData.push_back(val / 255.0);
}
digitClassifier.feedForward(trainingData);
std::vector<double> outputValues = digitClassifier.getOutput(); std::vector<double> outputValues = digitClassifier.getOutput();