diff --git a/Layer.cpp b/Layer.cpp index f4b7319..5c7b660 100644 --- a/Layer.cpp +++ b/Layer.cpp @@ -41,7 +41,7 @@ double Layer::getWeightedSum(size_t outputNeuron) const sum += neuron.getWeightedOutputValue(outputNeuron); } - return sum; + return sum / size(); } void Layer::connectTo(const Layer & nextLayer) diff --git a/gui/NeuroUI/errorplotter.cpp b/gui/NeuroUI/errorplotter.cpp index 4e6f0b3..0bd05ed 100644 --- a/gui/NeuroUI/errorplotter.cpp +++ b/gui/NeuroUI/errorplotter.cpp @@ -28,6 +28,11 @@ void ErrorPlotter::clear() void ErrorPlotter::addErrorValue(double errorValue) { + if (m_errorValues.size() == m_bufferSize) + { + m_errorValues.pop_front(); + } + m_errorValues.push_back(errorValue); m_maxErrorValue = std::max(m_maxErrorValue, errorValue); diff --git a/gui/NeuroUI/errorplotter.h b/gui/NeuroUI/errorplotter.h index 47c5093..ac4488f 100644 --- a/gui/NeuroUI/errorplotter.h +++ b/gui/NeuroUI/errorplotter.h @@ -11,6 +11,8 @@ private: std::list m_errorValues; double m_maxErrorValue; + size_t m_bufferSize = 10000; + public: explicit ErrorPlotter(QWidget *parent = 0); diff --git a/gui/NeuroUI/mnistloader.cpp b/gui/NeuroUI/mnistloader.cpp index 67c4c1f..91ffa65 100644 --- a/gui/NeuroUI/mnistloader.cpp +++ b/gui/NeuroUI/mnistloader.cpp @@ -8,6 +8,21 @@ void MnistLoader::load(const std::string &databaseFileName, const std::string &l loadLabels(labelsFileName); } +size_t MnistLoader::getSamleCount() const +{ + return samples.size(); +} + +const MnistLoader::MnistSample &MnistLoader::getSample(size_t index) const +{ + if (index >= samples.size()) + { + throw std::runtime_error("MNIST sample index out of range"); + } + + return *(samples[index].get()); +} + const MnistLoader::MnistSample &MnistLoader::getRandomSample() const { size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX; diff --git a/gui/NeuroUI/mnistloader.h b/gui/NeuroUI/mnistloader.h index d2a1e2b..3e0f46f 100644 --- a/gui/NeuroUI/mnistloader.h +++ b/gui/NeuroUI/mnistloader.h @@ -6,6 +6,8 @@ #include #include +#include + class MnistLoader { private: @@ -21,6 +23,11 @@ public: public: uint8_t label; uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT]; + + QImage toQImage() const + { + return QImage(data, SAMPLE_WIDTH, SAMPLE_HEIGHT, QImage::Format_Grayscale8); + } }; using MnistSample = Sample; @@ -31,6 +38,8 @@ private: public: void load(const std::string &databaseFileName, const std::string &labelsFileName); + size_t getSamleCount() const; + const MnistSample &getSample(size_t index) const; const MnistSample &getRandomSample() const; private: diff --git a/gui/NeuroUI/netlearner.cpp b/gui/NeuroUI/netlearner.cpp index b0956e1..65acf0c 100644 --- a/gui/NeuroUI/netlearner.cpp +++ b/gui/NeuroUI/netlearner.cpp @@ -24,12 +24,12 @@ void NetLearner::run() timer.start(); size_t numIterations = 100000; - for (size_t iteration = 0; iteration < numIterations; ++iteration) + for (size_t iteration = 0; iteration < numIterations && cancel == false; ++iteration) { auto trainingSample = mnistLoader.getRandomSample(); - QImage trainingImage(trainingSample.data, 28, 28, QImage::Format_Grayscale8); - emit sampleImageLoaded(trainingImage); +// emit logMessage(QString("training sample ") + QString::number(trainingSample.label)); + emit sampleImageLoaded(trainingSample.toQImage()); std::vector targetValues = { @@ -49,12 +49,7 @@ void NetLearner::run() double error = outputValues[0] - targetValues[0]; - QString logString; - - logString.append("Error: "); - logString.append(QString::number(std::abs(error))); - - emit logMessage(logString); + emit logMessage(QString("Error: ") + QString::number(std::abs(error))); emit currentNetError(error); emit progress((double)iteration / (double)numIterations); @@ -76,4 +71,11 @@ void NetLearner::run() logString.append(ex.what()); emit logMessage(logString); } + + cancel = false; +} + +void NetLearner::cancelLearning() +{ + cancel = true; } diff --git a/gui/NeuroUI/netlearner.h b/gui/NeuroUI/netlearner.h index 70378a4..018c3bf 100644 --- a/gui/NeuroUI/netlearner.h +++ b/gui/NeuroUI/netlearner.h @@ -7,6 +7,9 @@ class NetLearner : public QThread { Q_OBJECT +private: + bool cancel = false; + private: void run() Q_DECL_OVERRIDE; @@ -15,6 +18,9 @@ signals: void progress(double progress); void currentNetError(double error); void sampleImageLoaded(const QImage &image); + +public slots: + void cancelLearning(); }; #endif // NETLEARNER_H diff --git a/gui/NeuroUI/neuroui.cpp b/gui/NeuroUI/neuroui.cpp index c057984..54b00ea 100644 --- a/gui/NeuroUI/neuroui.cpp +++ b/gui/NeuroUI/neuroui.cpp @@ -12,6 +12,12 @@ NeuroUI::NeuroUI(QWidget *parent) : NeuroUI::~NeuroUI() { + if (m_netLearner != nullptr) + { + m_netLearner->cancelLearning(); + m_netLearner->wait(); + } + delete ui; } @@ -40,6 +46,11 @@ void NeuroUI::on_runButton_clicked() void NeuroUI::logMessage(const QString &logMessage) { + if (ui->logView->count() == static_cast(m_logSize)) + { + delete ui->logView->item(0); + } + ui->logView->addItem(logMessage); ui->logView->scrollToBottom(); } diff --git a/gui/NeuroUI/neuroui.h b/gui/NeuroUI/neuroui.h index 0799c53..92861ac 100644 --- a/gui/NeuroUI/neuroui.h +++ b/gui/NeuroUI/neuroui.h @@ -17,6 +17,7 @@ class NeuroUI : public QMainWindow private: std::unique_ptr m_netLearner; + size_t m_logSize = 128; public: explicit NeuroUI(QWidget *parent = 0);