Added a error-plotter widget to NeuroUI

This commit is contained in:
Michael Mandl 2015-10-26 08:41:59 +01:00
parent b899c6f55e
commit 24b6969ab7
7 changed files with 128 additions and 3 deletions

View file

@ -17,13 +17,15 @@ SOURCES += main.cpp\
../../Layer.cpp \ ../../Layer.cpp \
../../Net.cpp \ ../../Net.cpp \
../../Neuron.cpp \ ../../Neuron.cpp \
netlearner.cpp netlearner.cpp \
errorplotter.cpp
HEADERS += neuroui.h \ HEADERS += neuroui.h \
../../Layer.h \ ../../Layer.h \
../../Net.h \ ../../Net.h \
../../Neuron.h \ ../../Neuron.h \
netlearner.h netlearner.h \
errorplotter.h
FORMS += neuroui.ui FORMS += neuroui.ui

View file

@ -0,0 +1,69 @@
#include "errorplotter.h"
#include <QPainter>
ErrorPlotter::ErrorPlotter(QWidget *parent)
: QWidget(parent)
, m_maxErrorValue(0.0)
{
}
QSize ErrorPlotter::minimumSizeHint() const
{
return QSize(100, 100);
}
QSize ErrorPlotter::sizeHint() const
{
return QSize(400, 200);
}
void ErrorPlotter::clear()
{
m_errorValues.clear();
m_maxErrorValue = 0.0;
update();
}
void ErrorPlotter::addErrorValue(double errorValue)
{
m_errorValues.push_back(errorValue);
m_maxErrorValue = std::max<double>(m_maxErrorValue, errorValue);
update();
}
void ErrorPlotter::paintEvent(QPaintEvent *)
{
if (m_errorValues.empty() || m_maxErrorValue == 0)
{
return;
}
QPainter painter(this);
painter.setRenderHint(QPainter::Antialiasing, true);
painter.translate(0.0, height());
painter.scale(1.0, -1.0);
double errorValueScale = height() / m_maxErrorValue;
auto errorValueIt = m_errorValues.crbegin();
double prevErrorValue = *errorValueIt;
errorValueIt++;
int xPos = width() - 2;
while(errorValueIt != m_errorValues.crend() && xPos >= 0)
{
double currentErrorValue = *errorValueIt;
painter.drawLine(xPos + 1, prevErrorValue * errorValueScale,
xPos, currentErrorValue * errorValueScale);
prevErrorValue = currentErrorValue;
errorValueIt++;
xPos--;
}
}

View file

@ -0,0 +1,31 @@
#ifndef ERRORPLOTTER_H
#define ERRORPLOTTER_H
#include <QWidget>
class ErrorPlotter : public QWidget
{
Q_OBJECT
private:
std::list<double> m_errorValues;
double m_maxErrorValue;
public:
explicit ErrorPlotter(QWidget *parent = 0);
QSize minimumSizeHint() const Q_DECL_OVERRIDE;
QSize sizeHint() const Q_DECL_OVERRIDE;
void clear();
protected:
void paintEvent(QPaintEvent *) Q_DECL_OVERRIDE;
signals:
public slots:
void addErrorValue(double errorValue);
};
#endif // ERRORPLOTTER_H

View file

@ -12,7 +12,7 @@ void NetLearner::run()
double batchMaxError = 0.0; double batchMaxError = 0.0;
double batchMeanError = 0.0; double batchMeanError = 0.0;
size_t numIterations = 100000; size_t numIterations = 1000000;
for (size_t iteration = 0; iteration < numIterations; ++iteration) for (size_t iteration = 0; iteration < numIterations; ++iteration)
{ {
std::vector<double> inputValues = std::vector<double> inputValues =
@ -47,6 +47,7 @@ void NetLearner::run()
logString.append(QString::number(std::abs(batchMeanError / batchSize))); logString.append(QString::number(std::abs(batchMeanError / batchSize)));
emit logMessage(logString); emit logMessage(logString);
emit currentNetError(batchMaxError);
batchIndex = 0; batchIndex = 0;
batchMaxError = 0.0; batchMaxError = 0.0;

View file

@ -13,6 +13,7 @@ private:
signals: signals:
void logMessage(const QString &logMessage); void logMessage(const QString &logMessage);
void progress(double progress); void progress(double progress);
void currentNetError(double error);
}; };
#endif // NETLEARNER_H #endif // NETLEARNER_H

View file

@ -18,6 +18,7 @@ NeuroUI::~NeuroUI()
void NeuroUI::on_runButton_clicked() void NeuroUI::on_runButton_clicked()
{ {
ui->logView->clear(); ui->logView->clear();
ui->errorPlotter->clear();
if (m_netLearner == nullptr) if (m_netLearner == nullptr)
{ {
@ -30,6 +31,8 @@ void NeuroUI::on_runButton_clicked()
connect(m_netLearner.get(), &NetLearner::started, this, &NeuroUI::netLearnerStarted); connect(m_netLearner.get(), &NetLearner::started, this, &NeuroUI::netLearnerStarted);
connect(m_netLearner.get(), &NetLearner::finished, this, &NeuroUI::netLearnerFinished); connect(m_netLearner.get(), &NetLearner::finished, this, &NeuroUI::netLearnerFinished);
connect(m_netLearner.get(), &NetLearner::currentNetError, ui->errorPlotter, &ErrorPlotter::addErrorValue);
m_netLearner->start(); m_netLearner->start();
} }

View file

@ -22,6 +22,16 @@
<item> <item>
<widget class="QListWidget" name="logView"/> <widget class="QListWidget" name="logView"/>
</item> </item>
<item>
<widget class="ErrorPlotter" name="errorPlotter" native="true">
<property name="sizePolicy">
<sizepolicy hsizetype="Expanding" vsizetype="Expanding">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
</widget>
</item>
<item> <item>
<layout class="QHBoxLayout" name="horizontalLayout"> <layout class="QHBoxLayout" name="horizontalLayout">
<item> <item>
@ -50,6 +60,14 @@
</widget> </widget>
</widget> </widget>
<layoutdefault spacing="6" margin="11"/> <layoutdefault spacing="6" margin="11"/>
<customwidgets>
<customwidget>
<class>ErrorPlotter</class>
<extends>QWidget</extends>
<header>errorplotter.h</header>
<container>1</container>
</customwidget>
</customwidgets>
<resources> <resources>
<include location="icons.qrc"/> <include location="icons.qrc"/>
</resources> </resources>