Reduced the progress update messages to take load from the UI,

implemented load-or-create in test funtion
main
mandlm 2015-10-26 19:50:01 +01:00
parent 9bb927d2d2
commit 1e716979a9
3 changed files with 44 additions and 22 deletions

50
Net.cpp
View File

@ -4,27 +4,14 @@
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
Net::Net()
{
}
Net::Net(std::initializer_list<size_t> layerSizes) Net::Net(std::initializer_list<size_t> layerSizes)
{ {
if (layerSizes.size() < 2) initialize(layerSizes);
{
throw std::exception("A net needs at least 2 layers");
}
for (size_t numNeurons : layerSizes)
{
push_back(Layer(numNeurons));
}
for (auto layerIt = begin(); layerIt != end() - 1; ++layerIt)
{
Layer &currentLayer = *layerIt;
const Layer &nextLayer = *(layerIt + 1);
currentLayer.addBiasNeuron();
currentLayer.connectTo(nextLayer);
}
} }
Net::Net(const std::string &filename) Net::Net(const std::string &filename)
@ -32,6 +19,31 @@ Net::Net(const std::string &filename)
load(filename); load(filename);
} }
void Net::initialize(std::initializer_list<size_t> layerSizes)
{
clear();
if (layerSizes.size() < 2)
{
throw std::exception("A net needs at least 2 layers");
}
for (size_t numNeurons : layerSizes)
{
push_back(Layer(numNeurons));
}
for (auto layerIt = begin(); layerIt != end() - 1; ++layerIt)
{
Layer &currentLayer = *layerIt;
const Layer &nextLayer = *(layerIt + 1);
currentLayer.addBiasNeuron();
currentLayer.connectTo(nextLayer);
}
}
void Net::feedForward(const std::vector<double> &inputValues) void Net::feedForward(const std::vector<double> &inputValues)
{ {
Layer &inputLayer = front(); Layer &inputLayer = front();

3
Net.h
View File

@ -7,9 +7,12 @@
class Net : public std::vector < Layer > class Net : public std::vector < Layer >
{ {
public: public:
Net();
Net(std::initializer_list<size_t> layerSizes); Net(std::initializer_list<size_t> layerSizes);
Net(const std::string &filename); Net(const std::string &filename);
void initialize(std::initializer_list<size_t> layerSizes);
void feedForward(const std::vector<double> &inputValues); void feedForward(const std::vector<double> &inputValues);
std::vector<double> getOutput(); std::vector<double> getOutput();
void backProp(const std::vector<double> &targetValues); void backProp(const std::vector<double> &targetValues);

View File

@ -9,7 +9,15 @@ void NetLearner::run()
{ {
QElapsedTimer timer; QElapsedTimer timer;
Net myNet({2, 3, 1}); Net myNet;
try
{
myNet.load("mynet.nnet");
}
catch (...)
{
myNet.initialize({2, 3, 1});
}
size_t batchSize = 5000; size_t batchSize = 5000;
size_t batchIndex = 0; size_t batchIndex = 0;
@ -54,6 +62,7 @@ void NetLearner::run()
emit logMessage(logString); emit logMessage(logString);
emit currentNetError(batchMaxError); emit currentNetError(batchMaxError);
emit progress((double)iteration / (double)numIterations);
batchIndex = 0; batchIndex = 0;
batchMaxError = 0.0; batchMaxError = 0.0;
@ -61,8 +70,6 @@ void NetLearner::run()
} }
myNet.backProp(targetValues); myNet.backProp(targetValues);
emit progress((double)iteration / (double)numIterations);
} }
QString timerLogString; QString timerLogString;