diff --git a/gui/NeuroUI/mnistloader.cpp b/gui/NeuroUI/mnistloader.cpp index 853c979..6c2c7dd 100644 --- a/gui/NeuroUI/mnistloader.cpp +++ b/gui/NeuroUI/mnistloader.cpp @@ -1,5 +1,12 @@ #include "mnistloader.h" +#include +#include +#include +#include + +#include + MnistLoader::MnistLoader() { @@ -7,6 +14,83 @@ MnistLoader::MnistLoader() void MnistLoader::load(const std::string &databaseFileName, const std::string &labelsFileName) { - + loadDatabase(databaseFileName); + loadLabels(labelsFileName); +} + +void MnistLoader::loadDatabase(const std::string &fileName) +{ + std::ifstream databaseFile; + databaseFile.open(fileName, std::ios::binary); + + if (!databaseFile.is_open()) + { + throw std::runtime_error("unable to open MNIST database file"); + } + + int32_t magicNumber = readInt32(databaseFile); + if (magicNumber != 2051) + { + throw std::runtime_error("unexpected data reading MNIST database file"); + } + + int32_t sampleCount = readInt32(databaseFile); + int32_t sampleWidth = readInt32(databaseFile); + int32_t sampleHeight = readInt32(databaseFile); + + std::list> samples; + + for (size_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex) + { + std::unique_ptr sampleData = std::make_unique(sampleWidth * sampleHeight); + + databaseFile.read(reinterpret_cast(sampleData.get()), sampleWidth * sampleHeight); + + samples.push_back(std::move(sampleData)); + } + + return; +} + +void MnistLoader::loadLabels(const std::string &fileName) +{ + std::ifstream labelFile; + labelFile.open(fileName, std::ios::binary); + + if (!labelFile.is_open()) + { + throw std::runtime_error("unable to open MNIST label file"); + } + + int32_t magicNumber = readInt32(labelFile); + if (magicNumber != 2049) + { + throw std::runtime_error("unexpected data reading MNIST label file"); + } + + int32_t labelCount = readInt32(labelFile); + + std::list labels; + + for (size_t labelIndex = 0; labelIndex < labelCount; ++labelIndex) + { + labels.push_back(readInt8(labelFile)); + } + + return; +} + +int8_t MnistLoader::readInt8(std::ifstream &file) +{ + int8_t buf8; + file.read(reinterpret_cast(&buf8), sizeof(buf8)); + return buf8; +} + +int32_t MnistLoader::readInt32(std::ifstream &file) +{ + int32_t buf32; + file.read(reinterpret_cast(&buf32), sizeof(buf32)); + return _byteswap_ulong(buf32); } diff --git a/gui/NeuroUI/mnistloader.h b/gui/NeuroUI/mnistloader.h index f277eba..e15081c 100644 --- a/gui/NeuroUI/mnistloader.h +++ b/gui/NeuroUI/mnistloader.h @@ -2,6 +2,7 @@ #define MNISTLOADER_H #include +#include class MnistLoader { @@ -9,6 +10,13 @@ public: MnistLoader(); void load(const std::string &databaseFileName, const std::string &labelsFileName); + +private: + void loadDatabase(const std::string &fileName); + void loadLabels(const std::string &fileName); + + static int8_t readInt8(std::ifstream &file); + static int32_t readInt32(std::ifstream &file); }; #endif // MNISTLOADER_H diff --git a/gui/NeuroUI/netlearner.cpp b/gui/NeuroUI/netlearner.cpp index 59147c9..51b5632 100644 --- a/gui/NeuroUI/netlearner.cpp +++ b/gui/NeuroUI/netlearner.cpp @@ -15,8 +15,8 @@ void NetLearner::run() emit progress(0.0); MnistLoader mnistLoader; - mnistLoader.load("../NeuroUI/MNIST Aatabase/train-images.idx3-ubyte", - "../NeuroUI/MNIST Aatabase/train-labels.idx1-ubyte"); + mnistLoader.load("../NeuroUI/MNIST Database/train-images.idx3-ubyte", + "../NeuroUI/MNIST Database/train-labels.idx1-ubyte"); emit logMessage("done"); emit progress(0.0);