Reduced the progress update messages to take load from the UI,
implemented load-or-create in test funtion
This commit is contained in:
parent
9bb927d2d2
commit
1e716979a9
3 changed files with 44 additions and 22 deletions
50
Net.cpp
50
Net.cpp
|
@ -4,27 +4,14 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
|
Net::Net()
|
||||||
|
{
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
Net::Net(std::initializer_list<size_t> layerSizes)
|
Net::Net(std::initializer_list<size_t> layerSizes)
|
||||||
{
|
{
|
||||||
if (layerSizes.size() < 2)
|
initialize(layerSizes);
|
||||||
{
|
|
||||||
throw std::exception("A net needs at least 2 layers");
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t numNeurons : layerSizes)
|
|
||||||
{
|
|
||||||
push_back(Layer(numNeurons));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto layerIt = begin(); layerIt != end() - 1; ++layerIt)
|
|
||||||
{
|
|
||||||
Layer ¤tLayer = *layerIt;
|
|
||||||
const Layer &nextLayer = *(layerIt + 1);
|
|
||||||
|
|
||||||
currentLayer.addBiasNeuron();
|
|
||||||
|
|
||||||
currentLayer.connectTo(nextLayer);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Net::Net(const std::string &filename)
|
Net::Net(const std::string &filename)
|
||||||
|
@ -32,6 +19,31 @@ Net::Net(const std::string &filename)
|
||||||
load(filename);
|
load(filename);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Net::initialize(std::initializer_list<size_t> layerSizes)
|
||||||
|
{
|
||||||
|
clear();
|
||||||
|
|
||||||
|
if (layerSizes.size() < 2)
|
||||||
|
{
|
||||||
|
throw std::exception("A net needs at least 2 layers");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t numNeurons : layerSizes)
|
||||||
|
{
|
||||||
|
push_back(Layer(numNeurons));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto layerIt = begin(); layerIt != end() - 1; ++layerIt)
|
||||||
|
{
|
||||||
|
Layer ¤tLayer = *layerIt;
|
||||||
|
const Layer &nextLayer = *(layerIt + 1);
|
||||||
|
|
||||||
|
currentLayer.addBiasNeuron();
|
||||||
|
|
||||||
|
currentLayer.connectTo(nextLayer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Net::feedForward(const std::vector<double> &inputValues)
|
void Net::feedForward(const std::vector<double> &inputValues)
|
||||||
{
|
{
|
||||||
Layer &inputLayer = front();
|
Layer &inputLayer = front();
|
||||||
|
|
3
Net.h
3
Net.h
|
@ -7,9 +7,12 @@
|
||||||
class Net : public std::vector < Layer >
|
class Net : public std::vector < Layer >
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
Net();
|
||||||
Net(std::initializer_list<size_t> layerSizes);
|
Net(std::initializer_list<size_t> layerSizes);
|
||||||
Net(const std::string &filename);
|
Net(const std::string &filename);
|
||||||
|
|
||||||
|
void initialize(std::initializer_list<size_t> layerSizes);
|
||||||
|
|
||||||
void feedForward(const std::vector<double> &inputValues);
|
void feedForward(const std::vector<double> &inputValues);
|
||||||
std::vector<double> getOutput();
|
std::vector<double> getOutput();
|
||||||
void backProp(const std::vector<double> &targetValues);
|
void backProp(const std::vector<double> &targetValues);
|
||||||
|
|
|
@ -9,7 +9,15 @@ void NetLearner::run()
|
||||||
{
|
{
|
||||||
QElapsedTimer timer;
|
QElapsedTimer timer;
|
||||||
|
|
||||||
Net myNet({2, 3, 1});
|
Net myNet;
|
||||||
|
try
|
||||||
|
{
|
||||||
|
myNet.load("mynet.nnet");
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
myNet.initialize({2, 3, 1});
|
||||||
|
}
|
||||||
|
|
||||||
size_t batchSize = 5000;
|
size_t batchSize = 5000;
|
||||||
size_t batchIndex = 0;
|
size_t batchIndex = 0;
|
||||||
|
@ -54,6 +62,7 @@ void NetLearner::run()
|
||||||
|
|
||||||
emit logMessage(logString);
|
emit logMessage(logString);
|
||||||
emit currentNetError(batchMaxError);
|
emit currentNetError(batchMaxError);
|
||||||
|
emit progress((double)iteration / (double)numIterations);
|
||||||
|
|
||||||
batchIndex = 0;
|
batchIndex = 0;
|
||||||
batchMaxError = 0.0;
|
batchMaxError = 0.0;
|
||||||
|
@ -61,8 +70,6 @@ void NetLearner::run()
|
||||||
}
|
}
|
||||||
|
|
||||||
myNet.backProp(targetValues);
|
myNet.backProp(targetValues);
|
||||||
|
|
||||||
emit progress((double)iteration / (double)numIterations);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
QString timerLogString;
|
QString timerLogString;
|
||||||
|
|
Loading…
Reference in a new issue