Threaded learning and signaling in the Qt UI
This commit is contained in:
parent
c9fb9b9fa8
commit
6943fc0116
8 changed files with 154 additions and 16 deletions
|
@ -24,7 +24,7 @@ void Layer::setOutputValues(const std::vector<double> & outputValues)
|
|||
|
||||
void Layer::feedForward(const Layer &inputLayer)
|
||||
{
|
||||
for (int neuronNumber = 0; neuronNumber < sizeWithoutBiasNeuron(); ++neuronNumber)
|
||||
for (size_t neuronNumber = 0; neuronNumber < sizeWithoutBiasNeuron(); ++neuronNumber)
|
||||
{
|
||||
at(neuronNumber).feedForward(inputLayer.getWeightedSum(neuronNumber));
|
||||
}
|
||||
|
|
|
@ -13,8 +13,19 @@ TEMPLATE = app
|
|||
|
||||
|
||||
SOURCES += main.cpp\
|
||||
neuroui.cpp
|
||||
neuroui.cpp \
|
||||
../../Layer.cpp \
|
||||
../../Net.cpp \
|
||||
../../Neuron.cpp \
|
||||
netlearner.cpp
|
||||
|
||||
HEADERS += neuroui.h
|
||||
HEADERS += neuroui.h \
|
||||
../../Layer.h \
|
||||
../../Net.h \
|
||||
../../Neuron.h \
|
||||
netlearner.h
|
||||
|
||||
FORMS += neuroui.ui
|
||||
|
||||
RESOURCES += \
|
||||
icons.qrc
|
||||
|
|
2
gui/NeuroUI/icons.qrc
Normal file
2
gui/NeuroUI/icons.qrc
Normal file
|
@ -0,0 +1,2 @@
|
|||
<RCC/>
|
||||
|
65
gui/NeuroUI/netlearner.cpp
Normal file
65
gui/NeuroUI/netlearner.cpp
Normal file
|
@ -0,0 +1,65 @@
|
|||
#include "netlearner.h"
|
||||
#include "../../Net.h"
|
||||
|
||||
void NetLearner::run()
|
||||
{
|
||||
try
|
||||
{
|
||||
Net myNet({2, 3, 1});
|
||||
|
||||
size_t batchSize = 5000;
|
||||
size_t batchIndex = 0;
|
||||
double batchMaxError = 0.0;
|
||||
double batchMeanError = 0.0;
|
||||
|
||||
size_t numIterations = 1000000;
|
||||
for (size_t iteration = 0; iteration < numIterations; ++iteration)
|
||||
{
|
||||
std::vector<double> inputValues =
|
||||
{
|
||||
std::rand() / (double)RAND_MAX,
|
||||
std::rand() / (double)RAND_MAX
|
||||
};
|
||||
|
||||
std::vector<double> targetValues =
|
||||
{
|
||||
(inputValues[0] + inputValues[1]) / 2.0
|
||||
};
|
||||
|
||||
myNet.feedForward(inputValues);
|
||||
|
||||
std::vector<double> outputValues = myNet.getOutput();
|
||||
|
||||
double error = outputValues[0] - targetValues[0];
|
||||
|
||||
batchMeanError += error;
|
||||
batchMaxError = std::max<double>(batchMaxError, error);
|
||||
|
||||
if (batchIndex++ == batchSize)
|
||||
{
|
||||
QString logString;
|
||||
|
||||
logString.append("Batch error (");
|
||||
logString.append(QString::number(batchSize));
|
||||
logString.append(" iterations, max/mean): ");
|
||||
logString.append(QString::number(std::abs(batchMaxError)));
|
||||
logString.append(" / ");
|
||||
logString.append(QString::number(std::abs(batchMeanError / batchSize)));
|
||||
|
||||
emit logMessage(logString);
|
||||
|
||||
batchIndex = 0;
|
||||
batchMaxError = 0.0;
|
||||
batchMeanError = 0.0;
|
||||
}
|
||||
|
||||
myNet.backProp(targetValues);
|
||||
}
|
||||
}
|
||||
catch (std::exception &ex)
|
||||
{
|
||||
QString logString("Error: ");
|
||||
logString.append(ex.what());
|
||||
emit logMessage(logString);
|
||||
}
|
||||
}
|
17
gui/NeuroUI/netlearner.h
Normal file
17
gui/NeuroUI/netlearner.h
Normal file
|
@ -0,0 +1,17 @@
|
|||
#ifndef NETLEARNER_H
|
||||
#define NETLEARNER_H
|
||||
|
||||
#include <QThread>
|
||||
|
||||
class NetLearner : public QThread
|
||||
{
|
||||
Q_OBJECT
|
||||
|
||||
private:
|
||||
void run() Q_DECL_OVERRIDE;
|
||||
|
||||
signals:
|
||||
void logMessage(const QString &logMessage);
|
||||
};
|
||||
|
||||
#endif // NETLEARNER_H
|
|
@ -12,3 +12,36 @@ NeuroUI::~NeuroUI()
|
|||
{
|
||||
delete ui;
|
||||
}
|
||||
|
||||
void NeuroUI::on_runButton_clicked()
|
||||
{
|
||||
ui->logView->clear();
|
||||
|
||||
if (m_netLearner == nullptr)
|
||||
{
|
||||
m_netLearner.reset(new NetLearner);
|
||||
}
|
||||
|
||||
connect(m_netLearner.get(), &NetLearner::logMessage, this, &NeuroUI::logMessage);
|
||||
|
||||
connect(m_netLearner.get(), &NetLearner::started, this, &NeuroUI::netLearnerStarted);
|
||||
connect(m_netLearner.get(), &NetLearner::finished, this, &NeuroUI::netLearnerFinished);
|
||||
|
||||
m_netLearner->start();
|
||||
}
|
||||
|
||||
void NeuroUI::logMessage(const QString &logMessage)
|
||||
{
|
||||
ui->logView->addItem(logMessage);
|
||||
ui->logView->scrollToBottom();
|
||||
}
|
||||
|
||||
void NeuroUI::netLearnerStarted()
|
||||
{
|
||||
ui->runButton->setEnabled(false);
|
||||
}
|
||||
|
||||
void NeuroUI::netLearnerFinished()
|
||||
{
|
||||
ui->runButton->setEnabled(true);
|
||||
}
|
||||
|
|
|
@ -3,6 +3,10 @@
|
|||
|
||||
#include <QMainWindow>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "netlearner.h"
|
||||
|
||||
namespace Ui {
|
||||
class NeuroUI;
|
||||
}
|
||||
|
@ -11,10 +15,19 @@ class NeuroUI : public QMainWindow
|
|||
{
|
||||
Q_OBJECT
|
||||
|
||||
private:
|
||||
std::unique_ptr<NetLearner> m_netLearner;
|
||||
|
||||
public:
|
||||
explicit NeuroUI(QWidget *parent = 0);
|
||||
~NeuroUI();
|
||||
|
||||
private slots:
|
||||
void on_runButton_clicked();
|
||||
void logMessage(const QString &logMessage);
|
||||
void netLearnerStarted();
|
||||
void netLearnerFinished();
|
||||
|
||||
private:
|
||||
Ui::NeuroUI *ui;
|
||||
};
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
<rect>
|
||||
<x>0</x>
|
||||
<y>0</y>
|
||||
<width>400</width>
|
||||
<height>300</height>
|
||||
<width>597</width>
|
||||
<height>389</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="windowTitle">
|
||||
|
@ -16,22 +16,19 @@
|
|||
<widget class="QWidget" name="centralWidget">
|
||||
<layout class="QVBoxLayout" name="verticalLayout_2">
|
||||
<item>
|
||||
<widget class="QListView" name="logView"/>
|
||||
<widget class="QListWidget" name="logView"/>
|
||||
</item>
|
||||
<item>
|
||||
<layout class="QHBoxLayout" name="horizontalLayout">
|
||||
<item>
|
||||
<spacer name="horizontalSpacer">
|
||||
<property name="orientation">
|
||||
<enum>Qt::Horizontal</enum>
|
||||
<widget class="QProgressBar" name="progressBar">
|
||||
<property name="value">
|
||||
<number>0</number>
|
||||
</property>
|
||||
<property name="sizeHint" stdset="0">
|
||||
<size>
|
||||
<width>40</width>
|
||||
<height>20</height>
|
||||
</size>
|
||||
<property name="textVisible">
|
||||
<bool>false</bool>
|
||||
</property>
|
||||
</spacer>
|
||||
</widget>
|
||||
</item>
|
||||
<item>
|
||||
<widget class="QPushButton" name="runButton">
|
||||
|
|
Loading…
Reference in a new issue