Completed loading of MNIST Databases

This commit is contained in:
mandlm 2015-10-31 12:28:35 +01:00
parent 87e267d65f
commit d98ec63fbd
3 changed files with 43 additions and 17 deletions

View file

@ -29,7 +29,7 @@ void MnistLoader::loadDatabase(const std::string &fileName)
} }
int32_t magicNumber = readInt32(databaseFile); int32_t magicNumber = readInt32(databaseFile);
if (magicNumber != 2051) if (magicNumber != DatabaseFileMagicNumber)
{ {
throw std::runtime_error("unexpected data reading MNIST database file"); throw std::runtime_error("unexpected data reading MNIST database file");
} }
@ -38,18 +38,19 @@ void MnistLoader::loadDatabase(const std::string &fileName)
int32_t sampleWidth = readInt32(databaseFile); int32_t sampleWidth = readInt32(databaseFile);
int32_t sampleHeight = readInt32(databaseFile); int32_t sampleHeight = readInt32(databaseFile);
std::list<std::unique_ptr<int8_t[]>> samples; if (sampleWidth != SampleWidth || sampleHeight != SampleHeight)
for (size_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex)
{ {
std::unique_ptr<int8_t[]> sampleData = std::make_unique<int8_t[]>(sampleWidth * sampleHeight); throw std::runtime_error("unexpected sample size loading MNIST database");
databaseFile.read(reinterpret_cast<char *>(sampleData.get()), sampleWidth * sampleHeight);
samples.push_back(std::move(sampleData));
} }
return; for (int32_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex)
{
std::unique_ptr<MnistSample> sample = std::make_unique<MnistSample>();
databaseFile.read(reinterpret_cast<char *>(sample->data), sampleWidth * sampleHeight);
samples.push_back(std::move(sample));
}
} }
void MnistLoader::loadLabels(const std::string &fileName) void MnistLoader::loadLabels(const std::string &fileName)
@ -63,21 +64,22 @@ void MnistLoader::loadLabels(const std::string &fileName)
} }
int32_t magicNumber = readInt32(labelFile); int32_t magicNumber = readInt32(labelFile);
if (magicNumber != 2049) if (magicNumber != LabelFileMagicNumber)
{ {
throw std::runtime_error("unexpected data reading MNIST label file"); throw std::runtime_error("unexpected data reading MNIST label file");
} }
int32_t labelCount = readInt32(labelFile); int32_t labelCount = readInt32(labelFile);
if (labelCount != static_cast<int32_t>(samples.size()))
std::list<int8_t> labels;
for (size_t labelIndex = 0; labelIndex < labelCount; ++labelIndex)
{ {
labels.push_back(readInt8(labelFile)); throw std::runtime_error("MNIST database and label files don't match in size");
} }
return; auto sampleIt = samples.begin();
for (int32_t labelIndex = 0; labelIndex < labelCount; ++labelIndex)
{
(*sampleIt++)->label = readInt8(labelFile);
}
} }
int8_t MnistLoader::readInt8(std::ifstream &file) int8_t MnistLoader::readInt8(std::ifstream &file)

View file

@ -2,10 +2,32 @@
#define MNISTLOADER_H #define MNISTLOADER_H
#include <string> #include <string>
#include <list>
#include <memory>
#include <inttypes.h> #include <inttypes.h>
class MnistLoader class MnistLoader
{ {
private:
static const uint32_t DatabaseFileMagicNumber = 2051;
static const uint32_t LabelFileMagicNumber = 2049;
static const size_t SampleWidth = 28;
static const size_t SampleHeight = 28;
public:
template<size_t SAMPLE_WIDTH, size_t SAMPLE_HEIGHT>
class Sample
{
public:
uint8_t label;
uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT];
};
using MnistSample = Sample<SampleWidth, SampleHeight>;
private:
std::list<std::unique_ptr<MnistSample>> samples;
public: public:
MnistLoader(); MnistLoader();

View file

@ -21,6 +21,8 @@ void NetLearner::run()
emit logMessage("done"); emit logMessage("done");
emit progress(0.0); emit progress(0.0);
return;
Net digitClassifier({32*32, 16*16, 32, 1}); Net digitClassifier({32*32, 16*16, 32, 1});
timer.start(); timer.start();