commit
1a0d2b9ea7
16 changed files with 225 additions and 55 deletions
|
@ -21,7 +21,7 @@ void Layer::setOutputValues(const std::vector<double> & outputValues)
|
|||
for (const double &value : outputValues)
|
||||
{
|
||||
(neuronIt++)->setOutputValue(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Layer::feedForward(const Layer &inputLayer)
|
||||
|
@ -54,7 +54,7 @@ void Layer::connectTo(const Layer & nextLayer)
|
|||
|
||||
void Layer::updateInputWeights(Layer & prevLayer)
|
||||
{
|
||||
static const double trainingRate = 0.3;
|
||||
static const double trainingRate = 0.2;
|
||||
|
||||
for (size_t targetLayerIndex = 0; targetLayerIndex < sizeWithoutBiasNeuron(); ++targetLayerIndex)
|
||||
{
|
||||
|
|
3
Layer.h
3
Layer.h
|
@ -13,9 +13,10 @@ public:
|
|||
Layer(size_t numNeurons);
|
||||
|
||||
void setOutputValues(const std::vector<double> & outputValues);
|
||||
|
||||
void feedForward(const Layer &inputLayer);
|
||||
double getWeightedSum(size_t outputNeuron) const;
|
||||
void connectTo(const Layer & nextLayer);
|
||||
void connectTo(const Layer &nextLayer);
|
||||
|
||||
void updateInputWeights(Layer &prevLayer);
|
||||
|
||||
|
|
2
Net.cpp
2
Net.cpp
|
@ -63,7 +63,7 @@ void Net::feedForward(const std::vector<double> &inputValues)
|
|||
Layer &nextLayer = *(layerIt + 1);
|
||||
|
||||
nextLayer.feedForward(currentLayer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<double> Net::getOutput()
|
||||
|
|
BIN
gui/NeuroUI/MNIST Database/t10k-images.idx3-ubyte
Normal file
BIN
gui/NeuroUI/MNIST Database/t10k-images.idx3-ubyte
Normal file
Binary file not shown.
BIN
gui/NeuroUI/MNIST Database/t10k-labels.idx1-ubyte
Normal file
BIN
gui/NeuroUI/MNIST Database/t10k-labels.idx1-ubyte
Normal file
Binary file not shown.
BIN
gui/NeuroUI/MNIST Database/train-images.idx3-ubyte
Normal file
BIN
gui/NeuroUI/MNIST Database/train-images.idx3-ubyte
Normal file
Binary file not shown.
BIN
gui/NeuroUI/MNIST Database/train-labels.idx1-ubyte
Normal file
BIN
gui/NeuroUI/MNIST Database/train-labels.idx1-ubyte
Normal file
Binary file not shown.
|
@ -18,14 +18,16 @@ SOURCES += main.cpp\
|
|||
../../Net.cpp \
|
||||
../../Neuron.cpp \
|
||||
netlearner.cpp \
|
||||
errorplotter.cpp
|
||||
errorplotter.cpp \
|
||||
mnistloader.cpp
|
||||
|
||||
HEADERS += neuroui.h \
|
||||
../../Layer.h \
|
||||
../../Net.h \
|
||||
../../Neuron.h \
|
||||
netlearner.h \
|
||||
errorplotter.h
|
||||
errorplotter.h \
|
||||
mnistloader.h
|
||||
|
||||
FORMS += neuroui.ui
|
||||
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 34 KiB |
97
gui/NeuroUI/mnistloader.cpp
Normal file
97
gui/NeuroUI/mnistloader.cpp
Normal file
|
@ -0,0 +1,97 @@
|
|||
#include "mnistloader.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
void MnistLoader::load(const std::string &databaseFileName, const std::string &labelsFileName)
|
||||
{
|
||||
loadDatabase(databaseFileName);
|
||||
loadLabels(labelsFileName);
|
||||
}
|
||||
|
||||
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
|
||||
{
|
||||
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;
|
||||
|
||||
return *(samples[sampleIndex].get());
|
||||
}
|
||||
|
||||
void MnistLoader::loadDatabase(const std::string &fileName)
|
||||
{
|
||||
std::ifstream databaseFile;
|
||||
databaseFile.open(fileName, std::ios::binary);
|
||||
|
||||
if (!databaseFile.is_open())
|
||||
{
|
||||
throw std::runtime_error("unable to open MNIST database file");
|
||||
}
|
||||
|
||||
int32_t magicNumber = readInt32(databaseFile);
|
||||
if (magicNumber != DatabaseFileMagicNumber)
|
||||
{
|
||||
throw std::runtime_error("unexpected data reading MNIST database file");
|
||||
}
|
||||
|
||||
int32_t sampleCount = readInt32(databaseFile);
|
||||
int32_t sampleWidth = readInt32(databaseFile);
|
||||
int32_t sampleHeight = readInt32(databaseFile);
|
||||
|
||||
if (sampleWidth != SampleWidth || sampleHeight != SampleHeight)
|
||||
{
|
||||
throw std::runtime_error("unexpected sample size loading MNIST database");
|
||||
}
|
||||
|
||||
samples.reserve(samples.size() + sampleCount);
|
||||
|
||||
for (int32_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex)
|
||||
{
|
||||
std::unique_ptr<MnistSample> sample = std::make_unique<MnistSample>();
|
||||
|
||||
databaseFile.read(reinterpret_cast<char *>(sample->data), sampleWidth * sampleHeight);
|
||||
|
||||
samples.push_back(std::move(sample));
|
||||
}
|
||||
}
|
||||
|
||||
void MnistLoader::loadLabels(const std::string &fileName)
|
||||
{
|
||||
std::ifstream labelFile;
|
||||
labelFile.open(fileName, std::ios::binary);
|
||||
|
||||
if (!labelFile.is_open())
|
||||
{
|
||||
throw std::runtime_error("unable to open MNIST label file");
|
||||
}
|
||||
|
||||
int32_t magicNumber = readInt32(labelFile);
|
||||
if (magicNumber != LabelFileMagicNumber)
|
||||
{
|
||||
throw std::runtime_error("unexpected data reading MNIST label file");
|
||||
}
|
||||
|
||||
int32_t labelCount = readInt32(labelFile);
|
||||
if (labelCount != static_cast<int32_t>(samples.size()))
|
||||
{
|
||||
throw std::runtime_error("MNIST database and label files don't match in size");
|
||||
}
|
||||
|
||||
auto sampleIt = samples.begin();
|
||||
for (int32_t labelIndex = 0; labelIndex < labelCount; ++labelIndex)
|
||||
{
|
||||
(*sampleIt++)->label = readInt8(labelFile);
|
||||
}
|
||||
}
|
||||
|
||||
int8_t MnistLoader::readInt8(std::ifstream &file)
|
||||
{
|
||||
int8_t buf8;
|
||||
file.read(reinterpret_cast<char *>(&buf8), sizeof(buf8));
|
||||
return buf8;
|
||||
}
|
||||
|
||||
int32_t MnistLoader::readInt32(std::ifstream &file)
|
||||
{
|
||||
int32_t buf32;
|
||||
file.read(reinterpret_cast<char *>(&buf32), sizeof(buf32));
|
||||
return _byteswap_ulong(buf32);
|
||||
}
|
||||
|
44
gui/NeuroUI/mnistloader.h
Normal file
44
gui/NeuroUI/mnistloader.h
Normal file
|
@ -0,0 +1,44 @@
|
|||
#ifndef MNISTLOADER_H
|
||||
#define MNISTLOADER_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <inttypes.h>
|
||||
|
||||
class MnistLoader
|
||||
{
|
||||
private:
|
||||
static const uint32_t DatabaseFileMagicNumber = 2051;
|
||||
static const uint32_t LabelFileMagicNumber = 2049;
|
||||
static const size_t SampleWidth = 28;
|
||||
static const size_t SampleHeight = 28;
|
||||
|
||||
public:
|
||||
template<size_t SAMPLE_WIDTH, size_t SAMPLE_HEIGHT>
|
||||
class Sample
|
||||
{
|
||||
public:
|
||||
uint8_t label;
|
||||
uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT];
|
||||
};
|
||||
|
||||
using MnistSample = Sample<SampleWidth, SampleHeight>;
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<MnistSample>> samples;
|
||||
|
||||
public:
|
||||
void load(const std::string &databaseFileName, const std::string &labelsFileName);
|
||||
|
||||
const MnistSample &getRandomSample() const;
|
||||
|
||||
private:
|
||||
void loadDatabase(const std::string &fileName);
|
||||
void loadLabels(const std::string &fileName);
|
||||
|
||||
static int8_t readInt8(std::ifstream &file);
|
||||
static int32_t readInt32(std::ifstream &file);
|
||||
};
|
||||
|
||||
#endif // MNISTLOADER_H
|
|
@ -1,7 +1,9 @@
|
|||
#include "netlearner.h"
|
||||
#include "../../Net.h"
|
||||
#include "mnistloader.h"
|
||||
|
||||
#include <QElapsedTimer>
|
||||
#include <QImage>
|
||||
|
||||
void NetLearner::run()
|
||||
{
|
||||
|
@ -9,67 +11,54 @@ void NetLearner::run()
|
|||
{
|
||||
QElapsedTimer timer;
|
||||
|
||||
Net myNet;
|
||||
try
|
||||
{
|
||||
myNet.load("mynet.nnet");
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
myNet.initialize({2, 3, 1});
|
||||
}
|
||||
emit logMessage("Loading training data...");
|
||||
|
||||
size_t batchSize = 5000;
|
||||
size_t batchIndex = 0;
|
||||
double batchMaxError = 0.0;
|
||||
double batchMeanError = 0.0;
|
||||
MnistLoader mnistLoader;
|
||||
mnistLoader.load("../NeuroUI/MNIST Database/train-images.idx3-ubyte",
|
||||
"../NeuroUI/MNIST Database/train-labels.idx1-ubyte");
|
||||
|
||||
emit logMessage("done");
|
||||
|
||||
Net digitClassifier({28*28, 256, 1});
|
||||
|
||||
timer.start();
|
||||
|
||||
size_t numIterations = 1000000;
|
||||
size_t numIterations = 100000;
|
||||
for (size_t iteration = 0; iteration < numIterations; ++iteration)
|
||||
{
|
||||
std::vector<double> inputValues =
|
||||
{
|
||||
std::rand() / (double)RAND_MAX,
|
||||
std::rand() / (double)RAND_MAX
|
||||
};
|
||||
auto trainingSample = mnistLoader.getRandomSample();
|
||||
|
||||
QImage trainingImage(trainingSample.data, 28, 28, QImage::Format_Grayscale8);
|
||||
emit sampleImageLoaded(trainingImage);
|
||||
|
||||
std::vector<double> targetValues =
|
||||
{
|
||||
(inputValues[0] + inputValues[1]) / 2.0
|
||||
trainingSample.label / 10.0
|
||||
};
|
||||
|
||||
myNet.feedForward(inputValues);
|
||||
std::vector<double> trainingData;
|
||||
trainingData.reserve(28*28);
|
||||
for (const uint8_t &val : trainingSample.data)
|
||||
{
|
||||
trainingData.push_back(val / 255.0);
|
||||
}
|
||||
|
||||
std::vector<double> outputValues = myNet.getOutput();
|
||||
digitClassifier.feedForward(trainingData);
|
||||
|
||||
std::vector<double> outputValues = digitClassifier.getOutput();
|
||||
|
||||
double error = outputValues[0] - targetValues[0];
|
||||
|
||||
batchMeanError += error;
|
||||
batchMaxError = std::max<double>(batchMaxError, error);
|
||||
QString logString;
|
||||
|
||||
if (batchIndex++ == batchSize)
|
||||
{
|
||||
QString logString;
|
||||
logString.append("Error: ");
|
||||
logString.append(QString::number(std::abs(error)));
|
||||
|
||||
logString.append("Batch error (");
|
||||
logString.append(QString::number(batchSize));
|
||||
logString.append(" iterations, max/mean): ");
|
||||
logString.append(QString::number(std::abs(batchMaxError)));
|
||||
logString.append(" / ");
|
||||
logString.append(QString::number(std::abs(batchMeanError / batchSize)));
|
||||
emit logMessage(logString);
|
||||
emit currentNetError(error);
|
||||
emit progress((double)iteration / (double)numIterations);
|
||||
|
||||
emit logMessage(logString);
|
||||
emit currentNetError(batchMaxError);
|
||||
emit progress((double)iteration / (double)numIterations);
|
||||
|
||||
batchIndex = 0;
|
||||
batchMaxError = 0.0;
|
||||
batchMeanError = 0.0;
|
||||
}
|
||||
|
||||
myNet.backProp(targetValues);
|
||||
digitClassifier.backProp(targetValues);
|
||||
}
|
||||
|
||||
QString timerLogString;
|
||||
|
@ -79,7 +68,7 @@ void NetLearner::run()
|
|||
|
||||
emit logMessage(timerLogString);
|
||||
|
||||
myNet.save("mynet.nnet");
|
||||
digitClassifier.save("DigitClassifier.nnet");
|
||||
}
|
||||
catch (std::exception &ex)
|
||||
{
|
||||
|
|
|
@ -14,6 +14,7 @@ signals:
|
|||
void logMessage(const QString &logMessage);
|
||||
void progress(double progress);
|
||||
void currentNetError(double error);
|
||||
void sampleImageLoaded(const QImage &image);
|
||||
};
|
||||
|
||||
#endif // NETLEARNER_H
|
||||
|
|
|
@ -31,6 +31,8 @@ void NeuroUI::on_runButton_clicked()
|
|||
connect(m_netLearner.get(), &NetLearner::finished, this, &NeuroUI::netLearnerFinished);
|
||||
|
||||
connect(m_netLearner.get(), &NetLearner::currentNetError, ui->errorPlotter, &ErrorPlotter::addErrorValue);
|
||||
|
||||
connect(m_netLearner.get(), &NetLearner::sampleImageLoaded, this, &NeuroUI::setImage);
|
||||
}
|
||||
|
||||
m_netLearner->start();
|
||||
|
@ -61,3 +63,10 @@ void NeuroUI::progress(double progress)
|
|||
|
||||
ui->progressBar->setValue(value);
|
||||
}
|
||||
|
||||
void NeuroUI::setImage(const QImage &image)
|
||||
{
|
||||
QPixmap pixmap;
|
||||
pixmap.convertFromImage(image);
|
||||
ui->label->setPixmap(pixmap);
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ private slots:
|
|||
void netLearnerStarted();
|
||||
void netLearnerFinished();
|
||||
void progress(double progress);
|
||||
void setImage(const QImage &image);
|
||||
|
||||
private:
|
||||
Ui::NeuroUI *ui;
|
||||
|
|
|
@ -20,11 +20,37 @@
|
|||
<widget class="QWidget" name="centralWidget">
|
||||
<layout class="QVBoxLayout" name="verticalLayout_2">
|
||||
<item>
|
||||
<widget class="QListWidget" name="logView">
|
||||
<property name="uniformItemSizes">
|
||||
<bool>true</bool>
|
||||
</property>
|
||||
</widget>
|
||||
<layout class="QHBoxLayout" name="horizontalLayout_2">
|
||||
<item>
|
||||
<widget class="QListWidget" name="logView">
|
||||
<property name="uniformItemSizes">
|
||||
<bool>true</bool>
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
<item>
|
||||
<widget class="QLabel" name="label">
|
||||
<property name="sizePolicy">
|
||||
<sizepolicy hsizetype="Fixed" vsizetype="Preferred">
|
||||
<horstretch>0</horstretch>
|
||||
<verstretch>0</verstretch>
|
||||
</sizepolicy>
|
||||
</property>
|
||||
<property name="minimumSize">
|
||||
<size>
|
||||
<width>128</width>
|
||||
<height>0</height>
|
||||
</size>
|
||||
</property>
|
||||
<property name="text">
|
||||
<string/>
|
||||
</property>
|
||||
<property name="alignment">
|
||||
<set>Qt::AlignCenter</set>
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
</layout>
|
||||
</item>
|
||||
<item>
|
||||
<widget class="ErrorPlotter" name="errorPlotter" native="true">
|
||||
|
|
Loading…
Reference in a new issue