commit
1a0d2b9ea7
16 changed files with 225 additions and 55 deletions
|
@ -54,7 +54,7 @@ void Layer::connectTo(const Layer & nextLayer)
|
||||||
|
|
||||||
void Layer::updateInputWeights(Layer & prevLayer)
|
void Layer::updateInputWeights(Layer & prevLayer)
|
||||||
{
|
{
|
||||||
static const double trainingRate = 0.3;
|
static const double trainingRate = 0.2;
|
||||||
|
|
||||||
for (size_t targetLayerIndex = 0; targetLayerIndex < sizeWithoutBiasNeuron(); ++targetLayerIndex)
|
for (size_t targetLayerIndex = 0; targetLayerIndex < sizeWithoutBiasNeuron(); ++targetLayerIndex)
|
||||||
{
|
{
|
||||||
|
|
1
Layer.h
1
Layer.h
|
@ -13,6 +13,7 @@ public:
|
||||||
Layer(size_t numNeurons);
|
Layer(size_t numNeurons);
|
||||||
|
|
||||||
void setOutputValues(const std::vector<double> & outputValues);
|
void setOutputValues(const std::vector<double> & outputValues);
|
||||||
|
|
||||||
void feedForward(const Layer &inputLayer);
|
void feedForward(const Layer &inputLayer);
|
||||||
double getWeightedSum(size_t outputNeuron) const;
|
double getWeightedSum(size_t outputNeuron) const;
|
||||||
void connectTo(const Layer &nextLayer);
|
void connectTo(const Layer &nextLayer);
|
||||||
|
|
BIN
gui/NeuroUI/MNIST Database/t10k-images.idx3-ubyte
Normal file
BIN
gui/NeuroUI/MNIST Database/t10k-images.idx3-ubyte
Normal file
Binary file not shown.
BIN
gui/NeuroUI/MNIST Database/t10k-labels.idx1-ubyte
Normal file
BIN
gui/NeuroUI/MNIST Database/t10k-labels.idx1-ubyte
Normal file
Binary file not shown.
BIN
gui/NeuroUI/MNIST Database/train-images.idx3-ubyte
Normal file
BIN
gui/NeuroUI/MNIST Database/train-images.idx3-ubyte
Normal file
Binary file not shown.
BIN
gui/NeuroUI/MNIST Database/train-labels.idx1-ubyte
Normal file
BIN
gui/NeuroUI/MNIST Database/train-labels.idx1-ubyte
Normal file
Binary file not shown.
|
@ -18,14 +18,16 @@ SOURCES += main.cpp\
|
||||||
../../Net.cpp \
|
../../Net.cpp \
|
||||||
../../Neuron.cpp \
|
../../Neuron.cpp \
|
||||||
netlearner.cpp \
|
netlearner.cpp \
|
||||||
errorplotter.cpp
|
errorplotter.cpp \
|
||||||
|
mnistloader.cpp
|
||||||
|
|
||||||
HEADERS += neuroui.h \
|
HEADERS += neuroui.h \
|
||||||
../../Layer.h \
|
../../Layer.h \
|
||||||
../../Net.h \
|
../../Net.h \
|
||||||
../../Neuron.h \
|
../../Neuron.h \
|
||||||
netlearner.h \
|
netlearner.h \
|
||||||
errorplotter.h
|
errorplotter.h \
|
||||||
|
mnistloader.h
|
||||||
|
|
||||||
FORMS += neuroui.ui
|
FORMS += neuroui.ui
|
||||||
|
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 34 KiB |
97
gui/NeuroUI/mnistloader.cpp
Normal file
97
gui/NeuroUI/mnistloader.cpp
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
#include "mnistloader.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
void MnistLoader::load(const std::string &databaseFileName, const std::string &labelsFileName)
|
||||||
|
{
|
||||||
|
loadDatabase(databaseFileName);
|
||||||
|
loadLabels(labelsFileName);
|
||||||
|
}
|
||||||
|
|
||||||
|
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
|
||||||
|
{
|
||||||
|
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;
|
||||||
|
|
||||||
|
return *(samples[sampleIndex].get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void MnistLoader::loadDatabase(const std::string &fileName)
|
||||||
|
{
|
||||||
|
std::ifstream databaseFile;
|
||||||
|
databaseFile.open(fileName, std::ios::binary);
|
||||||
|
|
||||||
|
if (!databaseFile.is_open())
|
||||||
|
{
|
||||||
|
throw std::runtime_error("unable to open MNIST database file");
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t magicNumber = readInt32(databaseFile);
|
||||||
|
if (magicNumber != DatabaseFileMagicNumber)
|
||||||
|
{
|
||||||
|
throw std::runtime_error("unexpected data reading MNIST database file");
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t sampleCount = readInt32(databaseFile);
|
||||||
|
int32_t sampleWidth = readInt32(databaseFile);
|
||||||
|
int32_t sampleHeight = readInt32(databaseFile);
|
||||||
|
|
||||||
|
if (sampleWidth != SampleWidth || sampleHeight != SampleHeight)
|
||||||
|
{
|
||||||
|
throw std::runtime_error("unexpected sample size loading MNIST database");
|
||||||
|
}
|
||||||
|
|
||||||
|
samples.reserve(samples.size() + sampleCount);
|
||||||
|
|
||||||
|
for (int32_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex)
|
||||||
|
{
|
||||||
|
std::unique_ptr<MnistSample> sample = std::make_unique<MnistSample>();
|
||||||
|
|
||||||
|
databaseFile.read(reinterpret_cast<char *>(sample->data), sampleWidth * sampleHeight);
|
||||||
|
|
||||||
|
samples.push_back(std::move(sample));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MnistLoader::loadLabels(const std::string &fileName)
|
||||||
|
{
|
||||||
|
std::ifstream labelFile;
|
||||||
|
labelFile.open(fileName, std::ios::binary);
|
||||||
|
|
||||||
|
if (!labelFile.is_open())
|
||||||
|
{
|
||||||
|
throw std::runtime_error("unable to open MNIST label file");
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t magicNumber = readInt32(labelFile);
|
||||||
|
if (magicNumber != LabelFileMagicNumber)
|
||||||
|
{
|
||||||
|
throw std::runtime_error("unexpected data reading MNIST label file");
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t labelCount = readInt32(labelFile);
|
||||||
|
if (labelCount != static_cast<int32_t>(samples.size()))
|
||||||
|
{
|
||||||
|
throw std::runtime_error("MNIST database and label files don't match in size");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sampleIt = samples.begin();
|
||||||
|
for (int32_t labelIndex = 0; labelIndex < labelCount; ++labelIndex)
|
||||||
|
{
|
||||||
|
(*sampleIt++)->label = readInt8(labelFile);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int8_t MnistLoader::readInt8(std::ifstream &file)
|
||||||
|
{
|
||||||
|
int8_t buf8;
|
||||||
|
file.read(reinterpret_cast<char *>(&buf8), sizeof(buf8));
|
||||||
|
return buf8;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t MnistLoader::readInt32(std::ifstream &file)
|
||||||
|
{
|
||||||
|
int32_t buf32;
|
||||||
|
file.read(reinterpret_cast<char *>(&buf32), sizeof(buf32));
|
||||||
|
return _byteswap_ulong(buf32);
|
||||||
|
}
|
||||||
|
|
44
gui/NeuroUI/mnistloader.h
Normal file
44
gui/NeuroUI/mnistloader.h
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
#ifndef MNISTLOADER_H
|
||||||
|
#define MNISTLOADER_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <inttypes.h>
|
||||||
|
|
||||||
|
class MnistLoader
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
static const uint32_t DatabaseFileMagicNumber = 2051;
|
||||||
|
static const uint32_t LabelFileMagicNumber = 2049;
|
||||||
|
static const size_t SampleWidth = 28;
|
||||||
|
static const size_t SampleHeight = 28;
|
||||||
|
|
||||||
|
public:
|
||||||
|
template<size_t SAMPLE_WIDTH, size_t SAMPLE_HEIGHT>
|
||||||
|
class Sample
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
uint8_t label;
|
||||||
|
uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT];
|
||||||
|
};
|
||||||
|
|
||||||
|
using MnistSample = Sample<SampleWidth, SampleHeight>;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<std::unique_ptr<MnistSample>> samples;
|
||||||
|
|
||||||
|
public:
|
||||||
|
void load(const std::string &databaseFileName, const std::string &labelsFileName);
|
||||||
|
|
||||||
|
const MnistSample &getRandomSample() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void loadDatabase(const std::string &fileName);
|
||||||
|
void loadLabels(const std::string &fileName);
|
||||||
|
|
||||||
|
static int8_t readInt8(std::ifstream &file);
|
||||||
|
static int32_t readInt32(std::ifstream &file);
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // MNISTLOADER_H
|
|
@ -1,7 +1,9 @@
|
||||||
#include "netlearner.h"
|
#include "netlearner.h"
|
||||||
#include "../../Net.h"
|
#include "../../Net.h"
|
||||||
|
#include "mnistloader.h"
|
||||||
|
|
||||||
#include <QElapsedTimer>
|
#include <QElapsedTimer>
|
||||||
|
#include <QImage>
|
||||||
|
|
||||||
void NetLearner::run()
|
void NetLearner::run()
|
||||||
{
|
{
|
||||||
|
@ -9,67 +11,54 @@ void NetLearner::run()
|
||||||
{
|
{
|
||||||
QElapsedTimer timer;
|
QElapsedTimer timer;
|
||||||
|
|
||||||
Net myNet;
|
emit logMessage("Loading training data...");
|
||||||
try
|
|
||||||
{
|
|
||||||
myNet.load("mynet.nnet");
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
myNet.initialize({2, 3, 1});
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t batchSize = 5000;
|
MnistLoader mnistLoader;
|
||||||
size_t batchIndex = 0;
|
mnistLoader.load("../NeuroUI/MNIST Database/train-images.idx3-ubyte",
|
||||||
double batchMaxError = 0.0;
|
"../NeuroUI/MNIST Database/train-labels.idx1-ubyte");
|
||||||
double batchMeanError = 0.0;
|
|
||||||
|
emit logMessage("done");
|
||||||
|
|
||||||
|
Net digitClassifier({28*28, 256, 1});
|
||||||
|
|
||||||
timer.start();
|
timer.start();
|
||||||
|
|
||||||
size_t numIterations = 1000000;
|
size_t numIterations = 100000;
|
||||||
for (size_t iteration = 0; iteration < numIterations; ++iteration)
|
for (size_t iteration = 0; iteration < numIterations; ++iteration)
|
||||||
{
|
{
|
||||||
std::vector<double> inputValues =
|
auto trainingSample = mnistLoader.getRandomSample();
|
||||||
{
|
|
||||||
std::rand() / (double)RAND_MAX,
|
QImage trainingImage(trainingSample.data, 28, 28, QImage::Format_Grayscale8);
|
||||||
std::rand() / (double)RAND_MAX
|
emit sampleImageLoaded(trainingImage);
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<double> targetValues =
|
std::vector<double> targetValues =
|
||||||
{
|
{
|
||||||
(inputValues[0] + inputValues[1]) / 2.0
|
trainingSample.label / 10.0
|
||||||
};
|
};
|
||||||
|
|
||||||
myNet.feedForward(inputValues);
|
std::vector<double> trainingData;
|
||||||
|
trainingData.reserve(28*28);
|
||||||
|
for (const uint8_t &val : trainingSample.data)
|
||||||
|
{
|
||||||
|
trainingData.push_back(val / 255.0);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<double> outputValues = myNet.getOutput();
|
digitClassifier.feedForward(trainingData);
|
||||||
|
|
||||||
|
std::vector<double> outputValues = digitClassifier.getOutput();
|
||||||
|
|
||||||
double error = outputValues[0] - targetValues[0];
|
double error = outputValues[0] - targetValues[0];
|
||||||
|
|
||||||
batchMeanError += error;
|
|
||||||
batchMaxError = std::max<double>(batchMaxError, error);
|
|
||||||
|
|
||||||
if (batchIndex++ == batchSize)
|
|
||||||
{
|
|
||||||
QString logString;
|
QString logString;
|
||||||
|
|
||||||
logString.append("Batch error (");
|
logString.append("Error: ");
|
||||||
logString.append(QString::number(batchSize));
|
logString.append(QString::number(std::abs(error)));
|
||||||
logString.append(" iterations, max/mean): ");
|
|
||||||
logString.append(QString::number(std::abs(batchMaxError)));
|
|
||||||
logString.append(" / ");
|
|
||||||
logString.append(QString::number(std::abs(batchMeanError / batchSize)));
|
|
||||||
|
|
||||||
emit logMessage(logString);
|
emit logMessage(logString);
|
||||||
emit currentNetError(batchMaxError);
|
emit currentNetError(error);
|
||||||
emit progress((double)iteration / (double)numIterations);
|
emit progress((double)iteration / (double)numIterations);
|
||||||
|
|
||||||
batchIndex = 0;
|
digitClassifier.backProp(targetValues);
|
||||||
batchMaxError = 0.0;
|
|
||||||
batchMeanError = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
myNet.backProp(targetValues);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
QString timerLogString;
|
QString timerLogString;
|
||||||
|
@ -79,7 +68,7 @@ void NetLearner::run()
|
||||||
|
|
||||||
emit logMessage(timerLogString);
|
emit logMessage(timerLogString);
|
||||||
|
|
||||||
myNet.save("mynet.nnet");
|
digitClassifier.save("DigitClassifier.nnet");
|
||||||
}
|
}
|
||||||
catch (std::exception &ex)
|
catch (std::exception &ex)
|
||||||
{
|
{
|
||||||
|
|
|
@ -14,6 +14,7 @@ signals:
|
||||||
void logMessage(const QString &logMessage);
|
void logMessage(const QString &logMessage);
|
||||||
void progress(double progress);
|
void progress(double progress);
|
||||||
void currentNetError(double error);
|
void currentNetError(double error);
|
||||||
|
void sampleImageLoaded(const QImage &image);
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // NETLEARNER_H
|
#endif // NETLEARNER_H
|
||||||
|
|
|
@ -31,6 +31,8 @@ void NeuroUI::on_runButton_clicked()
|
||||||
connect(m_netLearner.get(), &NetLearner::finished, this, &NeuroUI::netLearnerFinished);
|
connect(m_netLearner.get(), &NetLearner::finished, this, &NeuroUI::netLearnerFinished);
|
||||||
|
|
||||||
connect(m_netLearner.get(), &NetLearner::currentNetError, ui->errorPlotter, &ErrorPlotter::addErrorValue);
|
connect(m_netLearner.get(), &NetLearner::currentNetError, ui->errorPlotter, &ErrorPlotter::addErrorValue);
|
||||||
|
|
||||||
|
connect(m_netLearner.get(), &NetLearner::sampleImageLoaded, this, &NeuroUI::setImage);
|
||||||
}
|
}
|
||||||
|
|
||||||
m_netLearner->start();
|
m_netLearner->start();
|
||||||
|
@ -61,3 +63,10 @@ void NeuroUI::progress(double progress)
|
||||||
|
|
||||||
ui->progressBar->setValue(value);
|
ui->progressBar->setValue(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void NeuroUI::setImage(const QImage &image)
|
||||||
|
{
|
||||||
|
QPixmap pixmap;
|
||||||
|
pixmap.convertFromImage(image);
|
||||||
|
ui->label->setPixmap(pixmap);
|
||||||
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ private slots:
|
||||||
void netLearnerStarted();
|
void netLearnerStarted();
|
||||||
void netLearnerFinished();
|
void netLearnerFinished();
|
||||||
void progress(double progress);
|
void progress(double progress);
|
||||||
|
void setImage(const QImage &image);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Ui::NeuroUI *ui;
|
Ui::NeuroUI *ui;
|
||||||
|
|
|
@ -19,6 +19,8 @@
|
||||||
</property>
|
</property>
|
||||||
<widget class="QWidget" name="centralWidget">
|
<widget class="QWidget" name="centralWidget">
|
||||||
<layout class="QVBoxLayout" name="verticalLayout_2">
|
<layout class="QVBoxLayout" name="verticalLayout_2">
|
||||||
|
<item>
|
||||||
|
<layout class="QHBoxLayout" name="horizontalLayout_2">
|
||||||
<item>
|
<item>
|
||||||
<widget class="QListWidget" name="logView">
|
<widget class="QListWidget" name="logView">
|
||||||
<property name="uniformItemSizes">
|
<property name="uniformItemSizes">
|
||||||
|
@ -26,6 +28,30 @@
|
||||||
</property>
|
</property>
|
||||||
</widget>
|
</widget>
|
||||||
</item>
|
</item>
|
||||||
|
<item>
|
||||||
|
<widget class="QLabel" name="label">
|
||||||
|
<property name="sizePolicy">
|
||||||
|
<sizepolicy hsizetype="Fixed" vsizetype="Preferred">
|
||||||
|
<horstretch>0</horstretch>
|
||||||
|
<verstretch>0</verstretch>
|
||||||
|
</sizepolicy>
|
||||||
|
</property>
|
||||||
|
<property name="minimumSize">
|
||||||
|
<size>
|
||||||
|
<width>128</width>
|
||||||
|
<height>0</height>
|
||||||
|
</size>
|
||||||
|
</property>
|
||||||
|
<property name="text">
|
||||||
|
<string/>
|
||||||
|
</property>
|
||||||
|
<property name="alignment">
|
||||||
|
<set>Qt::AlignCenter</set>
|
||||||
|
</property>
|
||||||
|
</widget>
|
||||||
|
</item>
|
||||||
|
</layout>
|
||||||
|
</item>
|
||||||
<item>
|
<item>
|
||||||
<widget class="ErrorPlotter" name="errorPlotter" native="true">
|
<widget class="ErrorPlotter" name="errorPlotter" native="true">
|
||||||
<property name="sizePolicy">
|
<property name="sizePolicy">
|
||||||
|
|
Loading…
Reference in a new issue