#include "mnistloader.h"

#include <cstddef>
#include <cstdint>
#include <fstream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace {

// Magic numbers that open the MNIST IDX image and label files respectively.
constexpr std::int32_t kImageFileMagic = 2051;
constexpr std::int32_t kLabelFileMagic = 2049;

// Converts a count/dimension read from an IDX header into a size_t.
// Negative values are rejected: converted directly they would wrap to a huge
// unsigned loop bound or allocation size.
std::size_t toSize(std::int32_t value, const char *what)
{
    if (value < 0) {
        throw std::runtime_error(std::string("negative ") + what + " in MNIST header");
    }
    return static_cast<std::size_t>(value);
}

} // namespace

MnistLoader::MnistLoader()
{
}

/// @brief Loads an MNIST image file followed by its matching label file.
/// @param databaseFileName path to the IDX image file (magic 2051)
/// @param labelsFileName   path to the IDX label file (magic 2049)
/// @throws std::runtime_error if either file cannot be opened, has an
///         unexpected header, or is truncated.
void MnistLoader::load(const std::string &databaseFileName, const std::string &labelsFileName)
{
    loadDatabase(databaseFileName);
    loadLabels(labelsFileName);
}

/// @brief Parses an MNIST IDX image file.
/// @param fileName path to the image file
/// @throws std::runtime_error on open failure, bad magic number, negative or
///         oversized dimensions, or a truncated file.
void MnistLoader::loadDatabase(const std::string &fileName)
{
    std::ifstream databaseFile(fileName, std::ios::binary);
    if (!databaseFile.is_open()) {
        throw std::runtime_error("unable to open MNIST database file");
    }

    if (readInt32(databaseFile) != kImageFileMagic) {
        throw std::runtime_error("unexpected data reading MNIST database file");
    }

    // IDX image header order after the magic number: item count, rows, columns.
    const std::size_t sampleCount = toSize(readInt32(databaseFile), "sample count");
    const std::size_t rowCount = toSize(readInt32(databaseFile), "row count");
    const std::size_t columnCount = toSize(readInt32(databaseFile), "column count");

    // Each dimension is below 2^31, so the product always fits in 64 bits.
    // It must also fit size_t (for the allocation) and std::streamsize
    // (for read()), which matters on 32-bit targets.
    const std::uint64_t pixelCount =
        static_cast<std::uint64_t>(rowCount) * static_cast<std::uint64_t>(columnCount);
    if (pixelCount > std::numeric_limits<std::size_t>::max() ||
        pixelCount > static_cast<std::uint64_t>(std::numeric_limits<std::streamsize>::max())) {
        throw std::runtime_error("MNIST sample dimensions too large");
    }
    const auto sampleSize = static_cast<std::size_t>(pixelCount);

    // Deliberately no reserve(sampleCount): a corrupt header could request an
    // enormous up-front allocation; a truncated file is caught by the read check.
    std::vector<std::vector<std::uint8_t>> samples;
    for (std::size_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex) {
        std::vector<std::uint8_t> pixels(sampleSize);
        if (!databaseFile.read(reinterpret_cast<char *>(pixels.data()),
                               static_cast<std::streamsize>(sampleSize))) {
            throw std::runtime_error("MNIST database file is truncated");
        }
        samples.push_back(std::move(pixels));
    }

    // NOTE(review): `samples` is a local and is destroyed when this function
    // returns, so the parsed pixels never reach callers. Presumably it should be
    // moved into a member declared in mnistloader.h — TODO confirm and wire up.
}

/// @brief Parses an MNIST IDX label file.
/// @param fileName path to the label file
/// @throws std::runtime_error on open failure, bad magic number, negative
///         label count, or a truncated file.
void MnistLoader::loadLabels(const std::string &fileName)
{
    std::ifstream labelFile(fileName, std::ios::binary);
    if (!labelFile.is_open()) {
        throw std::runtime_error("unable to open MNIST label file");
    }

    if (readInt32(labelFile) != kLabelFileMagic) {
        throw std::runtime_error("unexpected data reading MNIST label file");
    }

    const std::size_t labelCount = toSize(readInt32(labelFile), "label count");

    std::vector<int8_t> labels;
    for (std::size_t labelIndex = 0; labelIndex < labelCount; ++labelIndex) {
        labels.push_back(readInt8(labelFile));
    }

    // NOTE(review): `labels` is a local and is discarded on return; presumably
    // it should be stored in a member declared in mnistloader.h — TODO confirm.
}

/// @brief Reads one raw byte from @p file.
/// @return the byte reinterpreted as int8_t
/// @throws std::runtime_error if the stream has no byte left to read.
int8_t MnistLoader::readInt8(std::ifstream &file)
{
    int8_t value = 0;
    if (!file.read(reinterpret_cast<char *>(&value), sizeof(value))) {
        throw std::runtime_error("unexpected end of MNIST file");
    }
    return value;
}

/// @brief Reads a 32-bit big-endian integer (IDX header byte order) from @p file.
/// Assembles the value byte by byte, so the result is correct on any host
/// endianness and needs no compiler-specific byte-swap intrinsic.
/// @throws std::runtime_error if fewer than four bytes remain.
int32_t MnistLoader::readInt32(std::ifstream &file)
{
    unsigned char bytes[4] = {};
    if (!file.read(reinterpret_cast<char *>(bytes), sizeof(bytes))) {
        throw std::runtime_error("unexpected end of MNIST file");
    }
    const std::uint32_t value = (static_cast<std::uint32_t>(bytes[0]) << 24) |
                                (static_cast<std::uint32_t>(bytes[1]) << 16) |
                                (static_cast<std::uint32_t>(bytes[2]) << 8) |
                                static_cast<std::uint32_t>(bytes[3]);
    return static_cast<int32_t>(value);
}