// Neuro/gui/NeuroUI/mnistloader.h

#ifndef MNISTLOADER_H
#define MNISTLOADER_H
#include <cstddef>
#include <cstdint>
#include <inttypes.h>
#include <iosfwd>
#include <memory>
#include <string>
#include <vector>
/// Holds labelled MNIST samples read from an image ("database") file and a
/// matching label file, and hands them out via getRandomSample().
///
/// Expected use (the definitions live in the .cpp, not shown here): call
/// load() with both file names, then draw samples.
class MnistLoader
{
private:
    // Header magic numbers for the two input files. 2051 (0x00000803) and
    // 2049 (0x00000801) are the values used by the MNIST IDX image and label
    // files respectively.
    static const std::uint32_t DatabaseFileMagicNumber = 2051;
    static const std::uint32_t LabelFileMagicNumber = 2049;

    // Dimensions of a single MNIST image, in pixels.
    static const std::size_t SampleWidth = 28;
    static const std::size_t SampleHeight = 28;

public:
    /// One labelled image of SAMPLE_WIDTH x SAMPLE_HEIGHT 8-bit pixels.
    ///
    /// Plain aggregate (public members, no constructors), so `Sample<W, H> s{};`
    /// zero-initialises both the label and every pixel.
    template<std::size_t SAMPLE_WIDTH, std::size_t SAMPLE_HEIGHT>
    class Sample
    {
    public:
        std::uint8_t label;
        // All pixels in one flat array of SAMPLE_WIDTH * SAMPLE_HEIGHT bytes.
        std::uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT];
    };

    /// Sample type sized for the 28x28 MNIST images.
    using MnistSample = Sample<SampleWidth, SampleHeight>;

private:
    // Loaded samples; each one is owned through its own heap allocation.
    std::vector<std::unique_ptr<MnistSample>> samples;

public:
    /// Reads the image file and the label file into `samples`.
    /// NOTE(review): error reporting (exceptions vs. silent failure) is
    /// defined in the .cpp - confirm there.
    void load(const std::string &databaseFileName, const std::string &labelsFileName);

    /// Returns one of the loaded samples.
    /// NOTE(review): presumably requires getSampleCount() > 0 - confirm
    /// against the implementation before calling on an empty loader.
    const MnistSample &getRandomSample() const;

    /// Number of samples currently held; 0 before a successful load().
    std::size_t getSampleCount() const { return samples.size(); }

private:
    // Per-file loading steps used by load().
    void loadDatabase(const std::string &fileName);
    void loadLabels(const std::string &fileName);

    // Low-level readers for the binary files.
    // NOTE(review): the byte order handled by readInt32 is defined in the
    // .cpp; MNIST IDX headers are big-endian - verify it converts.
    static std::int8_t readInt8(std::ifstream &file);
    static std::int32_t readInt32(std::ifstream &file);
};
#endif // MNISTLOADER_H