Threaded learning and signaling in the Qt UI

main
mandlm 2015-10-24 18:03:07 +02:00
parent c9fb9b9fa8
commit 6943fc0116
8 changed files with 154 additions and 16 deletions

View File

@ -24,7 +24,7 @@ void Layer::setOutputValues(const std::vector<double> & outputValues)
void Layer::feedForward(const Layer &inputLayer)
{
for (int neuronNumber = 0; neuronNumber < sizeWithoutBiasNeuron(); ++neuronNumber)
for (size_t neuronNumber = 0; neuronNumber < sizeWithoutBiasNeuron(); ++neuronNumber)
{
at(neuronNumber).feedForward(inputLayer.getWeightedSum(neuronNumber));
}

View File

@ -13,8 +13,19 @@ TEMPLATE = app
SOURCES += main.cpp\
neuroui.cpp
neuroui.cpp \
../../Layer.cpp \
../../Net.cpp \
../../Neuron.cpp \
netlearner.cpp
HEADERS += neuroui.h
HEADERS += neuroui.h \
../../Layer.h \
../../Net.h \
../../Neuron.h \
netlearner.h
FORMS += neuroui.ui
RESOURCES += \
icons.qrc

2
gui/NeuroUI/icons.qrc Normal file
View File

@ -0,0 +1,2 @@
<RCC/>

View File

@ -0,0 +1,65 @@
#include "netlearner.h"
#include "../../Net.h"
void NetLearner::run()
{
try
{
Net myNet({2, 3, 1});
size_t batchSize = 5000;
size_t batchIndex = 0;
double batchMaxError = 0.0;
double batchMeanError = 0.0;
size_t numIterations = 1000000;
for (size_t iteration = 0; iteration < numIterations; ++iteration)
{
std::vector<double> inputValues =
{
std::rand() / (double)RAND_MAX,
std::rand() / (double)RAND_MAX
};
std::vector<double> targetValues =
{
(inputValues[0] + inputValues[1]) / 2.0
};
myNet.feedForward(inputValues);
std::vector<double> outputValues = myNet.getOutput();
double error = outputValues[0] - targetValues[0];
batchMeanError += error;
batchMaxError = std::max<double>(batchMaxError, error);
if (batchIndex++ == batchSize)
{
QString logString;
logString.append("Batch error (");
logString.append(QString::number(batchSize));
logString.append(" iterations, max/mean): ");
logString.append(QString::number(std::abs(batchMaxError)));
logString.append(" / ");
logString.append(QString::number(std::abs(batchMeanError / batchSize)));
emit logMessage(logString);
batchIndex = 0;
batchMaxError = 0.0;
batchMeanError = 0.0;
}
myNet.backProp(targetValues);
}
}
catch (std::exception &ex)
{
QString logString("Error: ");
logString.append(ex.what());
emit logMessage(logString);
}
}

17
gui/NeuroUI/netlearner.h Normal file
View File

@ -0,0 +1,17 @@
#ifndef NETLEARNER_H
#define NETLEARNER_H
#include <QThread>
class NetLearner : public QThread
{
Q_OBJECT
private:
void run() Q_DECL_OVERRIDE;
signals:
void logMessage(const QString &logMessage);
};
#endif // NETLEARNER_H

View File

@ -12,3 +12,36 @@ NeuroUI::~NeuroUI()
{
delete ui;
}
void NeuroUI::on_runButton_clicked()
{
ui->logView->clear();
if (m_netLearner == nullptr)
{
m_netLearner.reset(new NetLearner);
}
connect(m_netLearner.get(), &NetLearner::logMessage, this, &NeuroUI::logMessage);
connect(m_netLearner.get(), &NetLearner::started, this, &NeuroUI::netLearnerStarted);
connect(m_netLearner.get(), &NetLearner::finished, this, &NeuroUI::netLearnerFinished);
m_netLearner->start();
}
void NeuroUI::logMessage(const QString &logMessage)
{
ui->logView->addItem(logMessage);
ui->logView->scrollToBottom();
}
void NeuroUI::netLearnerStarted()
{
ui->runButton->setEnabled(false);
}
void NeuroUI::netLearnerFinished()
{
ui->runButton->setEnabled(true);
}

View File

@ -3,18 +3,31 @@
#include <QMainWindow>
#include <memory>
#include "netlearner.h"
namespace Ui {
class NeuroUI;
class NeuroUI;
}
class NeuroUI : public QMainWindow
{
Q_OBJECT
private:
std::unique_ptr<NetLearner> m_netLearner;
public:
explicit NeuroUI(QWidget *parent = 0);
~NeuroUI();
private slots:
void on_runButton_clicked();
void logMessage(const QString &logMessage);
void netLearnerStarted();
void netLearnerFinished();
private:
Ui::NeuroUI *ui;
};

View File

@ -6,8 +6,8 @@
<rect>
<x>0</x>
<y>0</y>
<width>400</width>
<height>300</height>
<width>597</width>
<height>389</height>
</rect>
</property>
<property name="windowTitle">
@ -16,22 +16,19 @@
<widget class="QWidget" name="centralWidget">
<layout class="QVBoxLayout" name="verticalLayout_2">
<item>
<widget class="QListView" name="logView"/>
<widget class="QListWidget" name="logView"/>
</item>
<item>
<layout class="QHBoxLayout" name="horizontalLayout">
<item>
<spacer name="horizontalSpacer">
<property name="orientation">
<enum>Qt::Horizontal</enum>
<widget class="QProgressBar" name="progressBar">
<property name="value">
<number>0</number>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>40</width>
<height>20</height>
</size>
<property name="textVisible">
<bool>false</bool>
</property>
</spacer>
</widget>
</item>
<item>
<widget class="QPushButton" name="runButton">