Threaded learning and signaling in the Qt UI
This commit is contained in:
parent
c9fb9b9fa8
commit
6943fc0116
8 changed files with 154 additions and 16 deletions
|
@ -24,7 +24,7 @@ void Layer::setOutputValues(const std::vector<double> & outputValues)
|
||||||
|
|
||||||
void Layer::feedForward(const Layer &inputLayer)
|
void Layer::feedForward(const Layer &inputLayer)
|
||||||
{
|
{
|
||||||
for (int neuronNumber = 0; neuronNumber < sizeWithoutBiasNeuron(); ++neuronNumber)
|
for (size_t neuronNumber = 0; neuronNumber < sizeWithoutBiasNeuron(); ++neuronNumber)
|
||||||
{
|
{
|
||||||
at(neuronNumber).feedForward(inputLayer.getWeightedSum(neuronNumber));
|
at(neuronNumber).feedForward(inputLayer.getWeightedSum(neuronNumber));
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,8 +13,19 @@ TEMPLATE = app
|
||||||
|
|
||||||
|
|
||||||
SOURCES += main.cpp\
|
SOURCES += main.cpp\
|
||||||
neuroui.cpp
|
neuroui.cpp \
|
||||||
|
../../Layer.cpp \
|
||||||
|
../../Net.cpp \
|
||||||
|
../../Neuron.cpp \
|
||||||
|
netlearner.cpp
|
||||||
|
|
||||||
HEADERS += neuroui.h
|
HEADERS += neuroui.h \
|
||||||
|
../../Layer.h \
|
||||||
|
../../Net.h \
|
||||||
|
../../Neuron.h \
|
||||||
|
netlearner.h
|
||||||
|
|
||||||
FORMS += neuroui.ui
|
FORMS += neuroui.ui
|
||||||
|
|
||||||
|
RESOURCES += \
|
||||||
|
icons.qrc
|
||||||
|
|
2
gui/NeuroUI/icons.qrc
Normal file
2
gui/NeuroUI/icons.qrc
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
<RCC/>
|
||||||
|
|
65
gui/NeuroUI/netlearner.cpp
Normal file
65
gui/NeuroUI/netlearner.cpp
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
#include "netlearner.h"
|
||||||
|
#include "../../Net.h"
|
||||||
|
|
||||||
|
void NetLearner::run()
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
Net myNet({2, 3, 1});
|
||||||
|
|
||||||
|
size_t batchSize = 5000;
|
||||||
|
size_t batchIndex = 0;
|
||||||
|
double batchMaxError = 0.0;
|
||||||
|
double batchMeanError = 0.0;
|
||||||
|
|
||||||
|
size_t numIterations = 1000000;
|
||||||
|
for (size_t iteration = 0; iteration < numIterations; ++iteration)
|
||||||
|
{
|
||||||
|
std::vector<double> inputValues =
|
||||||
|
{
|
||||||
|
std::rand() / (double)RAND_MAX,
|
||||||
|
std::rand() / (double)RAND_MAX
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<double> targetValues =
|
||||||
|
{
|
||||||
|
(inputValues[0] + inputValues[1]) / 2.0
|
||||||
|
};
|
||||||
|
|
||||||
|
myNet.feedForward(inputValues);
|
||||||
|
|
||||||
|
std::vector<double> outputValues = myNet.getOutput();
|
||||||
|
|
||||||
|
double error = outputValues[0] - targetValues[0];
|
||||||
|
|
||||||
|
batchMeanError += error;
|
||||||
|
batchMaxError = std::max<double>(batchMaxError, error);
|
||||||
|
|
||||||
|
if (batchIndex++ == batchSize)
|
||||||
|
{
|
||||||
|
QString logString;
|
||||||
|
|
||||||
|
logString.append("Batch error (");
|
||||||
|
logString.append(QString::number(batchSize));
|
||||||
|
logString.append(" iterations, max/mean): ");
|
||||||
|
logString.append(QString::number(std::abs(batchMaxError)));
|
||||||
|
logString.append(" / ");
|
||||||
|
logString.append(QString::number(std::abs(batchMeanError / batchSize)));
|
||||||
|
|
||||||
|
emit logMessage(logString);
|
||||||
|
|
||||||
|
batchIndex = 0;
|
||||||
|
batchMaxError = 0.0;
|
||||||
|
batchMeanError = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
myNet.backProp(targetValues);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (std::exception &ex)
|
||||||
|
{
|
||||||
|
QString logString("Error: ");
|
||||||
|
logString.append(ex.what());
|
||||||
|
emit logMessage(logString);
|
||||||
|
}
|
||||||
|
}
|
17
gui/NeuroUI/netlearner.h
Normal file
17
gui/NeuroUI/netlearner.h
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
#ifndef NETLEARNER_H
|
||||||
|
#define NETLEARNER_H
|
||||||
|
|
||||||
|
#include <QThread>
|
||||||
|
|
||||||
|
class NetLearner : public QThread
|
||||||
|
{
|
||||||
|
Q_OBJECT
|
||||||
|
|
||||||
|
private:
|
||||||
|
void run() Q_DECL_OVERRIDE;
|
||||||
|
|
||||||
|
signals:
|
||||||
|
void logMessage(const QString &logMessage);
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // NETLEARNER_H
|
|
@ -12,3 +12,36 @@ NeuroUI::~NeuroUI()
|
||||||
{
|
{
|
||||||
delete ui;
|
delete ui;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void NeuroUI::on_runButton_clicked()
|
||||||
|
{
|
||||||
|
ui->logView->clear();
|
||||||
|
|
||||||
|
if (m_netLearner == nullptr)
|
||||||
|
{
|
||||||
|
m_netLearner.reset(new NetLearner);
|
||||||
|
}
|
||||||
|
|
||||||
|
connect(m_netLearner.get(), &NetLearner::logMessage, this, &NeuroUI::logMessage);
|
||||||
|
|
||||||
|
connect(m_netLearner.get(), &NetLearner::started, this, &NeuroUI::netLearnerStarted);
|
||||||
|
connect(m_netLearner.get(), &NetLearner::finished, this, &NeuroUI::netLearnerFinished);
|
||||||
|
|
||||||
|
m_netLearner->start();
|
||||||
|
}
|
||||||
|
|
||||||
|
void NeuroUI::logMessage(const QString &logMessage)
|
||||||
|
{
|
||||||
|
ui->logView->addItem(logMessage);
|
||||||
|
ui->logView->scrollToBottom();
|
||||||
|
}
|
||||||
|
|
||||||
|
void NeuroUI::netLearnerStarted()
|
||||||
|
{
|
||||||
|
ui->runButton->setEnabled(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void NeuroUI::netLearnerFinished()
|
||||||
|
{
|
||||||
|
ui->runButton->setEnabled(true);
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,10 @@
|
||||||
|
|
||||||
#include <QMainWindow>
|
#include <QMainWindow>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "netlearner.h"
|
||||||
|
|
||||||
namespace Ui {
|
namespace Ui {
|
||||||
class NeuroUI;
|
class NeuroUI;
|
||||||
}
|
}
|
||||||
|
@ -11,10 +15,19 @@ class NeuroUI : public QMainWindow
|
||||||
{
|
{
|
||||||
Q_OBJECT
|
Q_OBJECT
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<NetLearner> m_netLearner;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit NeuroUI(QWidget *parent = 0);
|
explicit NeuroUI(QWidget *parent = 0);
|
||||||
~NeuroUI();
|
~NeuroUI();
|
||||||
|
|
||||||
|
private slots:
|
||||||
|
void on_runButton_clicked();
|
||||||
|
void logMessage(const QString &logMessage);
|
||||||
|
void netLearnerStarted();
|
||||||
|
void netLearnerFinished();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Ui::NeuroUI *ui;
|
Ui::NeuroUI *ui;
|
||||||
};
|
};
|
||||||
|
|
|
@ -6,8 +6,8 @@
|
||||||
<rect>
|
<rect>
|
||||||
<x>0</x>
|
<x>0</x>
|
||||||
<y>0</y>
|
<y>0</y>
|
||||||
<width>400</width>
|
<width>597</width>
|
||||||
<height>300</height>
|
<height>389</height>
|
||||||
</rect>
|
</rect>
|
||||||
</property>
|
</property>
|
||||||
<property name="windowTitle">
|
<property name="windowTitle">
|
||||||
|
@ -16,22 +16,19 @@
|
||||||
<widget class="QWidget" name="centralWidget">
|
<widget class="QWidget" name="centralWidget">
|
||||||
<layout class="QVBoxLayout" name="verticalLayout_2">
|
<layout class="QVBoxLayout" name="verticalLayout_2">
|
||||||
<item>
|
<item>
|
||||||
<widget class="QListView" name="logView"/>
|
<widget class="QListWidget" name="logView"/>
|
||||||
</item>
|
</item>
|
||||||
<item>
|
<item>
|
||||||
<layout class="QHBoxLayout" name="horizontalLayout">
|
<layout class="QHBoxLayout" name="horizontalLayout">
|
||||||
<item>
|
<item>
|
||||||
<spacer name="horizontalSpacer">
|
<widget class="QProgressBar" name="progressBar">
|
||||||
<property name="orientation">
|
<property name="value">
|
||||||
<enum>Qt::Horizontal</enum>
|
<number>0</number>
|
||||||
</property>
|
</property>
|
||||||
<property name="sizeHint" stdset="0">
|
<property name="textVisible">
|
||||||
<size>
|
<bool>false</bool>
|
||||||
<width>40</width>
|
|
||||||
<height>20</height>
|
|
||||||
</size>
|
|
||||||
</property>
|
</property>
|
||||||
</spacer>
|
</widget>
|
||||||
</item>
|
</item>
|
||||||
<item>
|
<item>
|
||||||
<widget class="QPushButton" name="runButton">
|
<widget class="QPushButton" name="runButton">
|
||||||
|
|
Loading…
Reference in a new issue