// Neuro/Layer.cpp — implementation of the Layer class.

#include "Layer.h"

#include <stdexcept>
2015-10-15 20:37:13 +00:00
Layer::Layer(size_t numNeurons)
{
for (unsigned int i = 0; i < numNeurons; ++i)
{
push_back(Neuron());
}
}
void Layer::setOutputValues(const std::vector<double> & outputValues)
{
2015-03-24 12:45:38 +00:00
if (size() - 1 != outputValues.size())
{
throw std::exception("The number of output values has to match the layer size");
}
2015-03-24 12:45:38 +00:00
auto neuronIt = begin();
for (const double &value : outputValues)
{
(neuronIt++)->setOutputValue(value);
}
}
void Layer::feedForward(const Layer &inputLayer)
2015-10-22 14:02:27 +00:00
{
for (int neuronNumber = 0; neuronNumber < sizeWithoutBiasNeuron(); ++neuronNumber)
{
2015-10-22 14:02:27 +00:00
at(neuronNumber).feedForward(inputLayer.getWeightedSum(neuronNumber));
}
}
double Layer::getWeightedSum(int outputNeuron) const
{
double sum = 0.0;
for (const Neuron &neuron : *this)
{
sum += neuron.getWeightedOutputValue(outputNeuron);
}
return sum;
}
void Layer::connectTo(const Layer & nextLayer)
{
for (Neuron &neuron : *this)
{
2015-10-22 14:02:27 +00:00
neuron.createOutputWeights(nextLayer.sizeWithoutBiasNeuron(), 0.5);
}
}
/// Adjusts the output weights of `prevLayer`'s neurons that point into this layer.
///
/// For every non-bias neuron j of this layer and every neuron i of prevLayer
/// (bias included), i's output weight at index j is increased by
///     i.getOutputValue() * j.getGradient() * trainingRate
/// where trainingRate is fixed at 0.2.
///
/// @param prevLayer the preceding layer; its neurons' output weights are modified.
void Layer::updateInputWeights(Layer & prevLayer)
{
	static const double trainingRate = 0.2;

	for (size_t targetIndex = 0; targetIndex < sizeWithoutBiasNeuron(); ++targetIndex)
	{
		Neuron &target = at(targetIndex);
		for (size_t sourceIndex = 0; sourceIndex < prevLayer.size(); ++sourceIndex)
		{
			Neuron &source = prevLayer.at(sourceIndex);
			const double delta = source.getOutputValue() * target.getGradient() * trainingRate;
			const double updatedWeight = source.getOutputWeight(targetIndex) + delta;
			source.setOutputWeight(targetIndex, updatedWeight);
		}
	}
}
void Layer::addBiasNeuron()
{
push_back(Neuron(1.0));
hasBiasNeuron = true;
}
/// Number of neurons in this layer, not counting the bias neuron.
///
/// @return size() - 1 if addBiasNeuron() has set the bias flag, otherwise size().
size_t Layer::sizeWithoutBiasNeuron() const
{
	return hasBiasNeuron ? size() - 1 : size();
}