Querying of a single MNIST sample

This commit is contained in:
Michael Mandl 2015-11-01 14:49:02 +01:00
parent ab9dcfbd35
commit eecd7a0fe6
2 changed files with 17 additions and 0 deletions

View file

@ -8,6 +8,21 @@ void MnistLoader::load(const std::string &databaseFileName, const std::string &l
loadLabels(labelsFileName); loadLabels(labelsFileName);
} }
size_t MnistLoader::getSamleCount() const
{
return samples.size();
}
const MnistLoader::MnistSample &MnistLoader::getSample(size_t index) const
{
if (index >= samples.size())
{
throw std::runtime_error("MNIST sample index out of range");
}
return *(samples[index].get());
}
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
{ {
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX; size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;

View file

@ -31,6 +31,8 @@ private:
public: public:
void load(const std::string &databaseFileName, const std::string &labelsFileName); void load(const std::string &databaseFileName, const std::string &labelsFileName);
size_t getSamleCount() const;
const MnistSample &getSample(size_t index) const;
const MnistSample &getRandomSample() const; const MnistSample &getRandomSample() const;
private: private: