diff --git a/Net.cpp b/Net.cpp
index c1447a6..2a540b3 100644
--- a/Net.cpp
+++ b/Net.cpp
@@ -4,27 +4,14 @@
 #include
 #include
 
+Net::Net()
+{
+
+}
+
 Net::Net(std::initializer_list<size_t> layerSizes)
 {
-    if (layerSizes.size() < 2)
-    {
-        throw std::exception("A net needs at least 2 layers");
-    }
-
-    for (size_t numNeurons : layerSizes)
-    {
-        push_back(Layer(numNeurons));
-    }
-
-    for (auto layerIt = begin(); layerIt != end() - 1; ++layerIt)
-    {
-        Layer &currentLayer = *layerIt;
-        const Layer &nextLayer = *(layerIt + 1);
-
-        currentLayer.addBiasNeuron();
-
-        currentLayer.connectTo(nextLayer);
-    }
+    initialize(layerSizes);
 }
 
 Net::Net(const std::string &filename)
@@ -32,6 +19,31 @@ Net::Net(const std::string &filename)
     load(filename);
 }
 
+void Net::initialize(std::initializer_list<size_t> layerSizes)
+{
+    clear();
+
+    if (layerSizes.size() < 2)
+    {
+        throw std::exception("A net needs at least 2 layers");
+    }
+
+    for (size_t numNeurons : layerSizes)
+    {
+        push_back(Layer(numNeurons));
+    }
+
+    for (auto layerIt = begin(); layerIt != end() - 1; ++layerIt)
+    {
+        Layer &currentLayer = *layerIt;
+        const Layer &nextLayer = *(layerIt + 1);
+
+        currentLayer.addBiasNeuron();
+
+        currentLayer.connectTo(nextLayer);
+    }
+}
+
 void Net::feedForward(const std::vector<double> &inputValues)
 {
     Layer &inputLayer = front();
diff --git a/Net.h b/Net.h
index d85b5f8..ad2e53c 100644
--- a/Net.h
+++ b/Net.h
@@ -7,9 +7,12 @@
 class Net : public std::vector < Layer >
 {
 public:
+    Net();
     Net(std::initializer_list<size_t> layerSizes);
     Net(const std::string &filename);
 
+    void initialize(std::initializer_list<size_t> layerSizes);
+
     void feedForward(const std::vector<double> &inputValues);
     std::vector<double> getOutput();
     void backProp(const std::vector<double> &targetValues);
diff --git a/gui/NeuroUI/netlearner.cpp b/gui/NeuroUI/netlearner.cpp
index 4db9939..af849db 100644
--- a/gui/NeuroUI/netlearner.cpp
+++ b/gui/NeuroUI/netlearner.cpp
@@ -9,7 +9,15 @@ void NetLearner::run()
 {
     QElapsedTimer timer;
 
-    Net myNet({2, 3, 1});
+    Net myNet;
+    try
+    {
+        myNet.load("mynet.nnet");
+    }
+    catch (...)
+    {
+        myNet.initialize({2, 3, 1});
+    }
 
     size_t batchSize = 5000;
     size_t batchIndex = 0;
@@ -54,6 +62,7 @@ void NetLearner::run()
 
             emit logMessage(logString);
             emit currentNetError(batchMaxError);
+            emit progress((double)iteration / (double)numIterations);
 
             batchIndex = 0;
             batchMaxError = 0.0;
@@ -61,8 +70,6 @@ void NetLearner::run()
         }
 
         myNet.backProp(targetValues);
-
-        emit progress((double)iteration / (double)numIterations);
     }
 
     QString timerLogString;