Compare commits
9 Commits
Author | SHA1 | Date |
---|---|---|
mandlm | 6a23b6a4ae | |
mandlm | c5fbe764a6 | |
mandlm | 0f7c617f9c | |
mandlm | 5e32724b1a | |
mandlm | 44760d0402 | |
mandlm | 1d343a079a | |
mandlm | eecd7a0fe6 | |
mandlm | ab9dcfbd35 | |
mandlm | 1a0d2b9ea7 |
|
@ -41,7 +41,7 @@ double Layer::getWeightedSum(size_t outputNeuron) const
|
||||||
sum += neuron.getWeightedOutputValue(outputNeuron);
|
sum += neuron.getWeightedOutputValue(outputNeuron);
|
||||||
}
|
}
|
||||||
|
|
||||||
return sum;
|
return sum / size();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layer::connectTo(const Layer & nextLayer)
|
void Layer::connectTo(const Layer & nextLayer)
|
||||||
|
|
|
@ -28,6 +28,11 @@ void ErrorPlotter::clear()
|
||||||
|
|
||||||
void ErrorPlotter::addErrorValue(double errorValue)
|
void ErrorPlotter::addErrorValue(double errorValue)
|
||||||
{
|
{
|
||||||
|
if (m_errorValues.size() == m_bufferSize)
|
||||||
|
{
|
||||||
|
m_errorValues.pop_front();
|
||||||
|
}
|
||||||
|
|
||||||
m_errorValues.push_back(errorValue);
|
m_errorValues.push_back(errorValue);
|
||||||
m_maxErrorValue = std::max<double>(m_maxErrorValue, errorValue);
|
m_maxErrorValue = std::max<double>(m_maxErrorValue, errorValue);
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,8 @@ private:
|
||||||
std::list<double> m_errorValues;
|
std::list<double> m_errorValues;
|
||||||
double m_maxErrorValue;
|
double m_maxErrorValue;
|
||||||
|
|
||||||
|
size_t m_bufferSize = 10000;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit ErrorPlotter(QWidget *parent = 0);
|
explicit ErrorPlotter(QWidget *parent = 0);
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,21 @@ void MnistLoader::load(const std::string &databaseFileName, const std::string &l
|
||||||
loadLabels(labelsFileName);
|
loadLabels(labelsFileName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t MnistLoader::getSamleCount() const
|
||||||
|
{
|
||||||
|
return samples.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
const MnistLoader::MnistSample &MnistLoader::getSample(size_t index) const
|
||||||
|
{
|
||||||
|
if (index >= samples.size())
|
||||||
|
{
|
||||||
|
throw std::runtime_error("MNIST sample index out of range");
|
||||||
|
}
|
||||||
|
|
||||||
|
return *(samples[index].get());
|
||||||
|
}
|
||||||
|
|
||||||
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
|
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
|
||||||
{
|
{
|
||||||
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;
|
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;
|
||||||
|
|
|
@ -6,6 +6,8 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <inttypes.h>
|
#include <inttypes.h>
|
||||||
|
|
||||||
|
#include <QImage>
|
||||||
|
|
||||||
class MnistLoader
|
class MnistLoader
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
|
@ -21,6 +23,11 @@ public:
|
||||||
public:
|
public:
|
||||||
uint8_t label;
|
uint8_t label;
|
||||||
uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT];
|
uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT];
|
||||||
|
|
||||||
|
QImage toQImage() const
|
||||||
|
{
|
||||||
|
return QImage(data, SAMPLE_WIDTH, SAMPLE_HEIGHT, QImage::Format_Grayscale8);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
using MnistSample = Sample<SampleWidth, SampleHeight>;
|
using MnistSample = Sample<SampleWidth, SampleHeight>;
|
||||||
|
@ -31,6 +38,8 @@ private:
|
||||||
public:
|
public:
|
||||||
void load(const std::string &databaseFileName, const std::string &labelsFileName);
|
void load(const std::string &databaseFileName, const std::string &labelsFileName);
|
||||||
|
|
||||||
|
size_t getSamleCount() const;
|
||||||
|
const MnistSample &getSample(size_t index) const;
|
||||||
const MnistSample &getRandomSample() const;
|
const MnistSample &getRandomSample() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -24,12 +24,12 @@ void NetLearner::run()
|
||||||
timer.start();
|
timer.start();
|
||||||
|
|
||||||
size_t numIterations = 100000;
|
size_t numIterations = 100000;
|
||||||
for (size_t iteration = 0; iteration < numIterations; ++iteration)
|
for (size_t iteration = 0; iteration < numIterations && cancel == false; ++iteration)
|
||||||
{
|
{
|
||||||
auto trainingSample = mnistLoader.getRandomSample();
|
auto trainingSample = mnistLoader.getRandomSample();
|
||||||
|
|
||||||
QImage trainingImage(trainingSample.data, 28, 28, QImage::Format_Grayscale8);
|
// emit logMessage(QString("training sample ") + QString::number(trainingSample.label));
|
||||||
emit sampleImageLoaded(trainingImage);
|
emit sampleImageLoaded(trainingSample.toQImage());
|
||||||
|
|
||||||
std::vector<double> targetValues =
|
std::vector<double> targetValues =
|
||||||
{
|
{
|
||||||
|
@ -49,12 +49,7 @@ void NetLearner::run()
|
||||||
|
|
||||||
double error = outputValues[0] - targetValues[0];
|
double error = outputValues[0] - targetValues[0];
|
||||||
|
|
||||||
QString logString;
|
emit logMessage(QString("Error: ") + QString::number(std::abs(error)));
|
||||||
|
|
||||||
logString.append("Error: ");
|
|
||||||
logString.append(QString::number(std::abs(error)));
|
|
||||||
|
|
||||||
emit logMessage(logString);
|
|
||||||
emit currentNetError(error);
|
emit currentNetError(error);
|
||||||
emit progress((double)iteration / (double)numIterations);
|
emit progress((double)iteration / (double)numIterations);
|
||||||
|
|
||||||
|
@ -76,4 +71,11 @@ void NetLearner::run()
|
||||||
logString.append(ex.what());
|
logString.append(ex.what());
|
||||||
emit logMessage(logString);
|
emit logMessage(logString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cancel = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void NetLearner::cancelLearning()
|
||||||
|
{
|
||||||
|
cancel = true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,9 @@ class NetLearner : public QThread
|
||||||
{
|
{
|
||||||
Q_OBJECT
|
Q_OBJECT
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool cancel = false;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void run() Q_DECL_OVERRIDE;
|
void run() Q_DECL_OVERRIDE;
|
||||||
|
|
||||||
|
@ -15,6 +18,9 @@ signals:
|
||||||
void progress(double progress);
|
void progress(double progress);
|
||||||
void currentNetError(double error);
|
void currentNetError(double error);
|
||||||
void sampleImageLoaded(const QImage &image);
|
void sampleImageLoaded(const QImage &image);
|
||||||
|
|
||||||
|
public slots:
|
||||||
|
void cancelLearning();
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // NETLEARNER_H
|
#endif // NETLEARNER_H
|
||||||
|
|
|
@ -12,6 +12,12 @@ NeuroUI::NeuroUI(QWidget *parent) :
|
||||||
|
|
||||||
NeuroUI::~NeuroUI()
|
NeuroUI::~NeuroUI()
|
||||||
{
|
{
|
||||||
|
if (m_netLearner != nullptr)
|
||||||
|
{
|
||||||
|
m_netLearner->cancelLearning();
|
||||||
|
m_netLearner->wait();
|
||||||
|
}
|
||||||
|
|
||||||
delete ui;
|
delete ui;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,6 +46,11 @@ void NeuroUI::on_runButton_clicked()
|
||||||
|
|
||||||
void NeuroUI::logMessage(const QString &logMessage)
|
void NeuroUI::logMessage(const QString &logMessage)
|
||||||
{
|
{
|
||||||
|
if (ui->logView->count() == static_cast<int>(m_logSize))
|
||||||
|
{
|
||||||
|
delete ui->logView->item(0);
|
||||||
|
}
|
||||||
|
|
||||||
ui->logView->addItem(logMessage);
|
ui->logView->addItem(logMessage);
|
||||||
ui->logView->scrollToBottom();
|
ui->logView->scrollToBottom();
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ class NeuroUI : public QMainWindow
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<NetLearner> m_netLearner;
|
std::unique_ptr<NetLearner> m_netLearner;
|
||||||
|
size_t m_logSize = 128;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit NeuroUI(QWidget *parent = 0);
|
explicit NeuroUI(QWidget *parent = 0);
|
||||||
|
|
Loading…
Reference in New Issue