/*
 * Decompiled with CFR 0.152.
 */
package aima.core.learning.neural;

import aima.core.learning.neural.FeedForwardNeuralNetwork;
import aima.core.learning.neural.FunctionApproximator;
import aima.core.learning.neural.Layer;
import aima.core.learning.neural.LayerSensitivity;
import aima.core.learning.neural.NNTrainingScheme;
import aima.core.util.math.Matrix;
import aima.core.util.math.Vector;

/**
 * Training scheme that drives the hidden layer and the output layer of a
 * {@link FeedForwardNeuralNetwork}: inputs are fed forward through both
 * layers, and errors are turned into weight and bias updates that are
 * blended with each layer's previous update through a momentum factor.
 *
 * <p>{@link #setNeuralNetwork(FunctionApproximator)} must be called before
 * {@link #processInput} or {@link #processError}; until then the layer and
 * sensitivity fields are unset.
 */
public class BackPropLearning implements NNTrainingScheme {
    /** Step size passed as {@code alpha} to the update calculations. */
    private final double learningRate;
    /** Share of the previous update that is carried into the next one. */
    private final double momentum;
    private Layer hiddenLayer;
    private Layer outputLayer;
    private LayerSensitivity hiddenSensitivity;
    private LayerSensitivity outputSensitivity;

    /**
     * Creates a training scheme with fixed hyper-parameters.
     *
     * @param learningRate step size applied to every weight and bias update
     * @param momentum     weight given to the previous update; the new step
     *                     is weighted by {@code 1 - momentum}
     */
    public BackPropLearning(double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    /**
     * Binds this scheme to the layers of the given network and creates a
     * fresh {@link LayerSensitivity} for each of them.
     *
     * @param fapp the network to train; it is cast to
     *             {@link FeedForwardNeuralNetwork}
     * @throws ClassCastException if {@code fapp} is not a
     *                            {@link FeedForwardNeuralNetwork}
     */
    @Override
    public void setNeuralNetwork(FunctionApproximator fapp) {
        FeedForwardNeuralNetwork net = (FeedForwardNeuralNetwork) fapp;
        hiddenLayer = net.getHiddenLayer();
        outputLayer = net.getOutputLayer();
        hiddenSensitivity = new LayerSensitivity(hiddenLayer);
        outputSensitivity = new LayerSensitivity(outputLayer);
    }

    /**
     * Feeds {@code input} through the hidden layer and then through the
     * output layer.
     *
     * @param network not used; the layers captured by
     *                {@link #setNeuralNetwork(FunctionApproximator)} are used
     * @param input   the input vector handed to the hidden layer
     * @return the output layer's last activation values
     */
    @Override
    public Vector processInput(FeedForwardNeuralNetwork network, Vector input) {
        hiddenLayer.feedForward(input);
        Vector hiddenOutput = hiddenLayer.getLastActivationValues();
        outputLayer.feedForward(hiddenOutput);
        return outputLayer.getLastActivationValues();
    }

    /**
     * Turns {@code error} into sensitivities for both layers, computes the
     * momentum-blended weight and bias updates, and then asks each layer to
     * apply them.
     *
     * @param network not used; the layers captured by
     *                {@link #setNeuralNetwork(FunctionApproximator)} are used
     * @param error   the error vector handed to the output sensitivity
     */
    @Override
    public void processError(FeedForwardNeuralNetwork network, Vector error) {
        // Sensitivities: output layer first, because the hidden layer's
        // sensitivity is derived from the output layer's.
        outputSensitivity.sensitivityMatrixFromErrorMatrix(error);
        hiddenSensitivity.sensitivityMatrixFromSucceedingLayer(outputSensitivity);

        // Weight updates: each layer is paired with the vector that was fed into it.
        Vector outputLayerInput = hiddenLayer.getLastActivationValues();
        calculateWeightUpdates(outputSensitivity, outputLayerInput, learningRate, momentum);
        Vector hiddenLayerInput = hiddenLayer.getLastInputValues();
        calculateWeightUpdates(hiddenSensitivity, hiddenLayerInput, learningRate, momentum);

        // Bias updates.
        calculateBiasUpdates(outputSensitivity, learningRate, momentum);
        calculateBiasUpdates(hiddenSensitivity, learningRate, momentum);

        // Apply everything that was just recorded on the layers.
        outputLayer.updateWeights();
        outputLayer.updateBiases();
        hiddenLayer.updateWeights();
        hiddenLayer.updateBiases();
    }

    /**
     * Computes a weight update blended with the layer's previous weight
     * update and records a copy of it on the layer.
     *
     * <p>The result is
     * {@code momentum * lastUpdate + (1 - momentum) * (-alpha * sensitivity * input^T)}.
     *
     * @param layerSensitivity               sensitivity holder of the layer being updated
     * @param previousLayerActivationOrInput vector that was fed into that layer
     * @param alpha                          learning rate
     * @param momentum                       weight of the layer's previous update
     * @return the blended update matrix (the layer receives a separate copy)
     */
    public Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity, Vector previousLayerActivationOrInput, double alpha, double momentum) {
        Layer target = layerSensitivity.getLayer();
        Matrix plainStep = plainWeightStep(layerSensitivity, previousLayerActivationOrInput, alpha);
        Matrix blended = applyMomentum(target.getLastWeightUpdateMatrix(), plainStep, momentum);
        target.acceptNewWeightUpdate(blended.copy());
        return blended;
    }

    /**
     * Computes a weight update without momentum,
     * {@code -alpha * sensitivity * input^T}, and records a copy of it on
     * the layer.
     *
     * @param layerSensitivity               sensitivity holder of the layer being updated
     * @param previousLayerActivationOrInput vector that was fed into that layer
     * @param alpha                          learning rate
     * @return the update matrix (the layer receives a separate copy)
     */
    public static Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity, Vector previousLayerActivationOrInput, double alpha) {
        Layer target = layerSensitivity.getLayer();
        Matrix plainStep = plainWeightStep(layerSensitivity, previousLayerActivationOrInput, alpha);
        target.acceptNewWeightUpdate(plainStep.copy());
        return plainStep;
    }

    /**
     * Computes a bias update blended with the layer's previous bias update
     * and records a copy of it on the layer.
     *
     * <p>The result is
     * {@code momentum * lastUpdate + (1 - momentum) * (-alpha * sensitivity)};
     * only column 0 of that matrix is carried into the returned vector.
     *
     * @param layerSensitivity sensitivity holder of the layer being updated
     * @param alpha            learning rate
     * @param momentum         weight of the layer's previous update
     * @return the blended bias update (the layer receives a separate copy)
     */
    public Vector calculateBiasUpdates(LayerSensitivity layerSensitivity, double alpha, double momentum) {
        Layer target = layerSensitivity.getLayer();
        Matrix plainStep = plainBiasStep(layerSensitivity, alpha);
        Matrix blended = applyMomentum(target.getLastBiasUpdateVector(), plainStep, momentum);
        Vector biasUpdate = firstColumnOf(blended);
        target.acceptNewBiasUpdate(biasUpdate.copyVector());
        return biasUpdate;
    }

    /**
     * Computes a bias update without momentum, {@code -alpha * sensitivity},
     * and records a copy of it on the layer. Only column 0 of the scaled
     * sensitivity matrix is carried into the returned vector.
     *
     * @param layerSensitivity sensitivity holder of the layer being updated
     * @param alpha            learning rate
     * @return the bias update (the layer receives a separate copy)
     */
    public static Vector calculateBiasUpdates(LayerSensitivity layerSensitivity, double alpha) {
        Layer target = layerSensitivity.getLayer();
        Vector biasUpdate = firstColumnOf(plainBiasStep(layerSensitivity, alpha));
        target.acceptNewBiasUpdate(biasUpdate.copyVector());
        return biasUpdate;
    }

    /** Returns {@code sensitivity * input^T}, scaled by {@code alpha} and then negated. */
    private static Matrix plainWeightStep(LayerSensitivity layerSensitivity, Vector previousLayerActivationOrInput, double alpha) {
        Matrix inputTransposed = previousLayerActivationOrInput.transpose();
        return layerSensitivity.getSensitivityMatrix().times(inputTransposed).times(alpha).times(-1.0);
    }

    /** Returns the sensitivity matrix scaled by {@code alpha} and then negated. */
    private static Matrix plainBiasStep(LayerSensitivity layerSensitivity, double alpha) {
        return layerSensitivity.getSensitivityMatrix().times(alpha).times(-1.0);
    }

    /** Returns {@code momentum * previousUpdate + (1 - momentum) * plainStep}. */
    private static Matrix applyMomentum(Matrix previousUpdate, Matrix plainStep, double momentum) {
        Matrix carriedOver = previousUpdate.times(momentum);
        Matrix freshShare = plainStep.times(1.0 - momentum);
        return carriedOver.plus(freshShare);
    }

    /** Copies column 0 of {@code source} into a new vector with one entry per row. */
    private static Vector firstColumnOf(Matrix source) {
        int rowCount = source.getRowDimension();
        Vector column = new Vector(rowCount);
        for (int row = 0; row < rowCount; row++) {
            column.setValue(row, source.get(row, 0));
        }
        return column;
    }
}

