diff --git a/gui/NeuroUI/mnistloader.cpp b/gui/NeuroUI/mnistloader.cpp index 67c4c1f..91ffa65 100644 --- a/gui/NeuroUI/mnistloader.cpp +++ b/gui/NeuroUI/mnistloader.cpp @@ -8,6 +8,21 @@ void MnistLoader::load(const std::string &databaseFileName, const std::string &l loadLabels(labelsFileName); } +size_t MnistLoader::getSamleCount() const +{ + return samples.size(); +} + +const MnistLoader::MnistSample &MnistLoader::getSample(size_t index) const +{ + if (index >= samples.size()) + { + throw std::runtime_error("MNIST sample index out of range"); + } + + return *(samples[index].get()); +} + const MnistLoader::MnistSample &MnistLoader::getRandomSample() const { size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX; diff --git a/gui/NeuroUI/mnistloader.h b/gui/NeuroUI/mnistloader.h index d2a1e2b..6b1b57d 100644 --- a/gui/NeuroUI/mnistloader.h +++ b/gui/NeuroUI/mnistloader.h @@ -31,6 +31,8 @@ private: public: void load(const std::string &databaseFileName, const std::string &labelsFileName); + size_t getSamleCount() const; + const MnistSample &getSample(size_t index) const; const MnistSample &getRandomSample() const; private: