Finished the max-value net (2/3/1 neurons) with 10k learning iterations. No good.

This commit is contained in:
Michael Mandl 2015-10-22 22:09:35 +02:00
parent d4a22ecae7
commit 6ed30e56c4
3 changed files with 27 additions and 21 deletions

View file

@ -54,16 +54,16 @@ void Layer::updateInputWeights(Layer & prevLayer)
{
static const double trainingRate = 0.2;
for (size_t currentLayerIndex = 0; currentLayerIndex < sizeWithoutBiasNeuron(); ++currentLayerIndex)
for (size_t targetLayerIndex = 0; targetLayerIndex < sizeWithoutBiasNeuron(); ++targetLayerIndex)
{
Neuron &targetNeuron = at(currentLayerIndex);
const Neuron &targetNeuron = at(targetLayerIndex);
for (size_t prevLayerIndex = 0; prevLayerIndex < prevLayer.size(); ++prevLayerIndex)
for (size_t sourceLayerIndex = 0; sourceLayerIndex < prevLayer.size(); ++sourceLayerIndex)
{
Neuron &sourceNeuron = prevLayer.at(prevLayerIndex);
Neuron &sourceNeuron = prevLayer.at(sourceLayerIndex);
sourceNeuron.setOutputWeight(currentLayerIndex,
sourceNeuron.getOutputWeight(currentLayerIndex) +
sourceNeuron.setOutputWeight(targetLayerIndex,
sourceNeuron.getOutputWeight(targetLayerIndex) +
sourceNeuron.getOutputValue() * targetNeuron.getGradient() * trainingRate);
}
}

View file

@ -73,11 +73,10 @@ void Net::backProp(const std::vector<double> &targetValues)
for (unsigned int i = 0; i < numResultValues; ++i)
{
double delta = resultValues[i] - targetValues[i];
rmsError += delta * delta;
rmsError += std::pow(resultValues[i] - targetValues[i], 2);
}
rmsError = sqrt(rmsError / numResultValues);
rmsError = std::sqrt(rmsError / numResultValues);
// calculate output neuron gradients
for (unsigned int i = 0; i < numResultValues; ++i)

View file

@ -12,13 +12,19 @@ int main()
Net myNet({ 2, 3, 1 });
size_t numIterations = 10000;
for (size_t iteration = 0; iteration < numIterations; ++iteration)
{
std::vector<double> inputValues =
{
0.1,
0.7,
std::rand() / (double)RAND_MAX,
std::rand() / (double)RAND_MAX
};
std::vector<double> targetValues = { 0.7 };
std::vector<double> targetValues =
{
*std::max_element(inputValues.begin(), inputValues.end())
};
myNet.feedForward(inputValues);
@ -32,6 +38,7 @@ int main()
myNet.backProp(targetValues);
}
}
catch (std::exception &ex)
{
std::cerr << "Error: " << ex.what() << std::endl;