Compare commits

...

9 Commits
digits ... main

9 changed files with 61 additions and 10 deletions

View File

@ -41,7 +41,7 @@ double Layer::getWeightedSum(size_t outputNeuron) const
sum += neuron.getWeightedOutputValue(outputNeuron); sum += neuron.getWeightedOutputValue(outputNeuron);
} }
return sum; return sum / size();
} }
void Layer::connectTo(const Layer & nextLayer) void Layer::connectTo(const Layer & nextLayer)

View File

@ -28,6 +28,11 @@ void ErrorPlotter::clear()
void ErrorPlotter::addErrorValue(double errorValue) void ErrorPlotter::addErrorValue(double errorValue)
{ {
if (m_errorValues.size() == m_bufferSize)
{
m_errorValues.pop_front();
}
m_errorValues.push_back(errorValue); m_errorValues.push_back(errorValue);
m_maxErrorValue = std::max<double>(m_maxErrorValue, errorValue); m_maxErrorValue = std::max<double>(m_maxErrorValue, errorValue);

View File

@ -11,6 +11,8 @@ private:
std::list<double> m_errorValues; std::list<double> m_errorValues;
double m_maxErrorValue; double m_maxErrorValue;
size_t m_bufferSize = 10000;
public: public:
explicit ErrorPlotter(QWidget *parent = 0); explicit ErrorPlotter(QWidget *parent = 0);

View File

@ -8,6 +8,21 @@ void MnistLoader::load(const std::string &databaseFileName, const std::string &l
loadLabels(labelsFileName); loadLabels(labelsFileName);
} }
size_t MnistLoader::getSamleCount() const
{
return samples.size();
}
const MnistLoader::MnistSample &MnistLoader::getSample(size_t index) const
{
if (index >= samples.size())
{
throw std::runtime_error("MNIST sample index out of range");
}
return *(samples[index].get());
}
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
{ {
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX; size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;

View File

@ -6,6 +6,8 @@
#include <memory> #include <memory>
#include <inttypes.h> #include <inttypes.h>
#include <QImage>
class MnistLoader class MnistLoader
{ {
private: private:
@ -21,6 +23,11 @@ public:
public: public:
uint8_t label; uint8_t label;
uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT]; uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT];
QImage toQImage() const
{
return QImage(data, SAMPLE_WIDTH, SAMPLE_HEIGHT, QImage::Format_Grayscale8);
}
}; };
using MnistSample = Sample<SampleWidth, SampleHeight>; using MnistSample = Sample<SampleWidth, SampleHeight>;
@ -31,6 +38,8 @@ private:
public: public:
void load(const std::string &databaseFileName, const std::string &labelsFileName); void load(const std::string &databaseFileName, const std::string &labelsFileName);
size_t getSamleCount() const;
const MnistSample &getSample(size_t index) const;
const MnistSample &getRandomSample() const; const MnistSample &getRandomSample() const;
private: private:

View File

@ -24,12 +24,12 @@ void NetLearner::run()
timer.start(); timer.start();
size_t numIterations = 100000; size_t numIterations = 100000;
for (size_t iteration = 0; iteration < numIterations; ++iteration) for (size_t iteration = 0; iteration < numIterations && cancel == false; ++iteration)
{ {
auto trainingSample = mnistLoader.getRandomSample(); auto trainingSample = mnistLoader.getRandomSample();
QImage trainingImage(trainingSample.data, 28, 28, QImage::Format_Grayscale8); // emit logMessage(QString("training sample ") + QString::number(trainingSample.label));
emit sampleImageLoaded(trainingImage); emit sampleImageLoaded(trainingSample.toQImage());
std::vector<double> targetValues = std::vector<double> targetValues =
{ {
@ -49,12 +49,7 @@ void NetLearner::run()
double error = outputValues[0] - targetValues[0]; double error = outputValues[0] - targetValues[0];
QString logString; emit logMessage(QString("Error: ") + QString::number(std::abs(error)));
logString.append("Error: ");
logString.append(QString::number(std::abs(error)));
emit logMessage(logString);
emit currentNetError(error); emit currentNetError(error);
emit progress((double)iteration / (double)numIterations); emit progress((double)iteration / (double)numIterations);
@ -76,4 +71,11 @@ void NetLearner::run()
logString.append(ex.what()); logString.append(ex.what());
emit logMessage(logString); emit logMessage(logString);
} }
cancel = false;
}
void NetLearner::cancelLearning()
{
cancel = true;
} }

View File

@ -7,6 +7,9 @@ class NetLearner : public QThread
{ {
Q_OBJECT Q_OBJECT
private:
bool cancel = false;
private: private:
void run() Q_DECL_OVERRIDE; void run() Q_DECL_OVERRIDE;
@ -15,6 +18,9 @@ signals:
void progress(double progress); void progress(double progress);
void currentNetError(double error); void currentNetError(double error);
void sampleImageLoaded(const QImage &image); void sampleImageLoaded(const QImage &image);
public slots:
void cancelLearning();
}; };
#endif // NETLEARNER_H #endif // NETLEARNER_H

View File

@ -12,6 +12,12 @@ NeuroUI::NeuroUI(QWidget *parent) :
NeuroUI::~NeuroUI() NeuroUI::~NeuroUI()
{ {
if (m_netLearner != nullptr)
{
m_netLearner->cancelLearning();
m_netLearner->wait();
}
delete ui; delete ui;
} }
@ -40,6 +46,11 @@ void NeuroUI::on_runButton_clicked()
void NeuroUI::logMessage(const QString &logMessage) void NeuroUI::logMessage(const QString &logMessage)
{ {
if (ui->logView->count() == static_cast<int>(m_logSize))
{
delete ui->logView->item(0);
}
ui->logView->addItem(logMessage); ui->logView->addItem(logMessage);
ui->logView->scrollToBottom(); ui->logView->scrollToBottom();
} }

View File

@ -17,6 +17,7 @@ class NeuroUI : public QMainWindow
private: private:
std::unique_ptr<NetLearner> m_netLearner; std::unique_ptr<NetLearner> m_netLearner;
size_t m_logSize = 128;
public: public:
explicit NeuroUI(QWidget *parent = 0); explicit NeuroUI(QWidget *parent = 0);