Completed loading of MNIST Databases
This commit is contained in:
parent
87e267d65f
commit
d98ec63fbd
3 changed files with 43 additions and 17 deletions
|
@ -29,7 +29,7 @@ void MnistLoader::loadDatabase(const std::string &fileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t magicNumber = readInt32(databaseFile);
|
int32_t magicNumber = readInt32(databaseFile);
|
||||||
if (magicNumber != 2051)
|
if (magicNumber != DatabaseFileMagicNumber)
|
||||||
{
|
{
|
||||||
throw std::runtime_error("unexpected data reading MNIST database file");
|
throw std::runtime_error("unexpected data reading MNIST database file");
|
||||||
}
|
}
|
||||||
|
@ -38,18 +38,19 @@ void MnistLoader::loadDatabase(const std::string &fileName)
|
||||||
int32_t sampleWidth = readInt32(databaseFile);
|
int32_t sampleWidth = readInt32(databaseFile);
|
||||||
int32_t sampleHeight = readInt32(databaseFile);
|
int32_t sampleHeight = readInt32(databaseFile);
|
||||||
|
|
||||||
std::list<std::unique_ptr<int8_t[]>> samples;
|
if (sampleWidth != SampleWidth || sampleHeight != SampleHeight)
|
||||||
|
|
||||||
for (size_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex)
|
|
||||||
{
|
{
|
||||||
std::unique_ptr<int8_t[]> sampleData = std::make_unique<int8_t[]>(sampleWidth * sampleHeight);
|
throw std::runtime_error("unexpected sample size loading MNIST database");
|
||||||
|
|
||||||
databaseFile.read(reinterpret_cast<char *>(sampleData.get()), sampleWidth * sampleHeight);
|
|
||||||
|
|
||||||
samples.push_back(std::move(sampleData));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
for (int32_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex)
|
||||||
|
{
|
||||||
|
std::unique_ptr<MnistSample> sample = std::make_unique<MnistSample>();
|
||||||
|
|
||||||
|
databaseFile.read(reinterpret_cast<char *>(sample->data), sampleWidth * sampleHeight);
|
||||||
|
|
||||||
|
samples.push_back(std::move(sample));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void MnistLoader::loadLabels(const std::string &fileName)
|
void MnistLoader::loadLabels(const std::string &fileName)
|
||||||
|
@ -63,21 +64,22 @@ void MnistLoader::loadLabels(const std::string &fileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t magicNumber = readInt32(labelFile);
|
int32_t magicNumber = readInt32(labelFile);
|
||||||
if (magicNumber != 2049)
|
if (magicNumber != LabelFileMagicNumber)
|
||||||
{
|
{
|
||||||
throw std::runtime_error("unexpected data reading MNIST label file");
|
throw std::runtime_error("unexpected data reading MNIST label file");
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t labelCount = readInt32(labelFile);
|
int32_t labelCount = readInt32(labelFile);
|
||||||
|
if (labelCount != static_cast<int32_t>(samples.size()))
|
||||||
std::list<int8_t> labels;
|
|
||||||
|
|
||||||
for (size_t labelIndex = 0; labelIndex < labelCount; ++labelIndex)
|
|
||||||
{
|
{
|
||||||
labels.push_back(readInt8(labelFile));
|
throw std::runtime_error("MNIST database and label files don't match in size");
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
auto sampleIt = samples.begin();
|
||||||
|
for (int32_t labelIndex = 0; labelIndex < labelCount; ++labelIndex)
|
||||||
|
{
|
||||||
|
(*sampleIt++)->label = readInt8(labelFile);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int8_t MnistLoader::readInt8(std::ifstream &file)
|
int8_t MnistLoader::readInt8(std::ifstream &file)
|
||||||
|
|
|
@ -2,10 +2,32 @@
|
||||||
#define MNISTLOADER_H
|
#define MNISTLOADER_H
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <list>
|
||||||
|
#include <memory>
|
||||||
#include <inttypes.h>
|
#include <inttypes.h>
|
||||||
|
|
||||||
class MnistLoader
|
class MnistLoader
|
||||||
{
|
{
|
||||||
|
private:
|
||||||
|
static const uint32_t DatabaseFileMagicNumber = 2051;
|
||||||
|
static const uint32_t LabelFileMagicNumber = 2049;
|
||||||
|
static const size_t SampleWidth = 28;
|
||||||
|
static const size_t SampleHeight = 28;
|
||||||
|
|
||||||
|
public:
|
||||||
|
template<size_t SAMPLE_WIDTH, size_t SAMPLE_HEIGHT>
|
||||||
|
class Sample
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
uint8_t label;
|
||||||
|
uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT];
|
||||||
|
};
|
||||||
|
|
||||||
|
using MnistSample = Sample<SampleWidth, SampleHeight>;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::list<std::unique_ptr<MnistSample>> samples;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MnistLoader();
|
MnistLoader();
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,8 @@ void NetLearner::run()
|
||||||
emit logMessage("done");
|
emit logMessage("done");
|
||||||
emit progress(0.0);
|
emit progress(0.0);
|
||||||
|
|
||||||
|
return;
|
||||||
|
|
||||||
Net digitClassifier({32*32, 16*16, 32, 1});
|
Net digitClassifier({32*32, 16*16, 32, 1});
|
||||||
|
|
||||||
timer.start();
|
timer.start();
|
||||||
|
|
Loading…
Reference in a new issue