diff --git a/gui/NeuroUI/mnistloader.cpp b/gui/NeuroUI/mnistloader.cpp index 6c2c7dd..4eac8bb 100644 --- a/gui/NeuroUI/mnistloader.cpp +++ b/gui/NeuroUI/mnistloader.cpp @@ -29,7 +29,7 @@ void MnistLoader::loadDatabase(const std::string &fileName) } int32_t magicNumber = readInt32(databaseFile); - if (magicNumber != 2051) + if (magicNumber != DatabaseFileMagicNumber) { throw std::runtime_error("unexpected data reading MNIST database file"); } @@ -38,18 +38,19 @@ void MnistLoader::loadDatabase(const std::string &fileName) int32_t sampleWidth = readInt32(databaseFile); int32_t sampleHeight = readInt32(databaseFile); - std::list> samples; - - for (size_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex) + if (sampleWidth != SampleWidth || sampleHeight != SampleHeight) { - std::unique_ptr sampleData = std::make_unique(sampleWidth * sampleHeight); - - databaseFile.read(reinterpret_cast(sampleData.get()), sampleWidth * sampleHeight); - - samples.push_back(std::move(sampleData)); + throw std::runtime_error("unexpected sample size loading MNIST database"); } - return; + for (int32_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex) + { + std::unique_ptr sample = std::make_unique(); + + databaseFile.read(reinterpret_cast(sample->data), sampleWidth * sampleHeight); + + samples.push_back(std::move(sample)); + } } void MnistLoader::loadLabels(const std::string &fileName) @@ -63,21 +64,22 @@ void MnistLoader::loadLabels(const std::string &fileName) } int32_t magicNumber = readInt32(labelFile); - if (magicNumber != 2049) + if (magicNumber != LabelFileMagicNumber) { throw std::runtime_error("unexpected data reading MNIST label file"); } int32_t labelCount = readInt32(labelFile); - - std::list labels; - - for (size_t labelIndex = 0; labelIndex < labelCount; ++labelIndex) + if (labelCount != static_cast(samples.size())) { - labels.push_back(readInt8(labelFile)); + throw std::runtime_error("MNIST database and label files don't match in size"); } - return; + auto sampleIt = samples.begin(); + for (int32_t labelIndex = 0; labelIndex < labelCount; ++labelIndex) + { + (*sampleIt++)->label = readInt8(labelFile); + } } int8_t MnistLoader::readInt8(std::ifstream &file) diff --git a/gui/NeuroUI/mnistloader.h b/gui/NeuroUI/mnistloader.h index e15081c..482931c 100644 --- a/gui/NeuroUI/mnistloader.h +++ b/gui/NeuroUI/mnistloader.h @@ -2,10 +2,32 @@ #define MNISTLOADER_H #include +#include +#include #include class MnistLoader { +private: + static const uint32_t DatabaseFileMagicNumber = 2051; + static const uint32_t LabelFileMagicNumber = 2049; + static const size_t SampleWidth = 28; + static const size_t SampleHeight = 28; + +public: + template + class Sample + { + public: + uint8_t label; + uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT]; + }; + + using MnistSample = Sample; + +private: + std::list> samples; + public: MnistLoader(); diff --git a/gui/NeuroUI/netlearner.cpp b/gui/NeuroUI/netlearner.cpp index 51b5632..0b96baf 100644 --- a/gui/NeuroUI/netlearner.cpp +++ b/gui/NeuroUI/netlearner.cpp @@ -21,6 +21,8 @@ void NetLearner::run() emit logMessage("done"); emit progress(0.0); + return; + Net digitClassifier({32*32, 16*16, 32, 1}); timer.start();