From eecd7a0fe6dce7e2374074345c54d827824ba955 Mon Sep 17 00:00:00 2001 From: Michael Mandl Date: Sun, 1 Nov 2015 14:49:02 +0100 Subject: [PATCH] Querying of a single MNIST sample --- gui/NeuroUI/mnistloader.cpp | 15 +++++++++++++++ gui/NeuroUI/mnistloader.h | 2 ++ 2 files changed, 17 insertions(+) diff --git a/gui/NeuroUI/mnistloader.cpp b/gui/NeuroUI/mnistloader.cpp index 67c4c1f..91ffa65 100644 --- a/gui/NeuroUI/mnistloader.cpp +++ b/gui/NeuroUI/mnistloader.cpp @@ -8,6 +8,21 @@ void MnistLoader::load(const std::string &databaseFileName, const std::string &l loadLabels(labelsFileName); } +size_t MnistLoader::getSamleCount() const +{ + return samples.size(); +} + +const MnistLoader::MnistSample &MnistLoader::getSample(size_t index) const +{ + if (index >= samples.size()) + { + throw std::runtime_error("MNIST sample index out of range"); + } + + return *(samples[index].get()); +} + const MnistLoader::MnistSample &MnistLoader::getRandomSample() const { size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX; diff --git a/gui/NeuroUI/mnistloader.h b/gui/NeuroUI/mnistloader.h index d2a1e2b..6b1b57d 100644 --- a/gui/NeuroUI/mnistloader.h +++ b/gui/NeuroUI/mnistloader.h @@ -31,6 +31,8 @@ private: public: void load(const std::string &databaseFileName, const std::string &labelsFileName); + size_t getSamleCount() const; + const MnistSample &getSample(size_t index) const; const MnistSample &getRandomSample() const; private: