/*
 * Decompiled with CFR 0.152.
 */
package aima.core.learning.neural;

import aima.core.learning.neural.ActivationFunction;
import aima.core.util.Util;
import aima.core.util.math.Matrix;
import aima.core.util.math.Vector;

/**
 * A single fully-connected layer of a feed-forward neural network.
 * <p>
 * The layer holds a weight matrix (one row per neuron, one column per input), a bias vector (one
 * entry per neuron) and an {@link ActivationFunction}. {@link #feedForward(Vector)} caches the
 * input, induced field and activation values of the most recent pass so that a learning algorithm
 * can read them back. The layer also keeps the two most recently accepted weight and bias updates;
 * how the older ("penultimate") update is used is up to the caller.
 * <p>
 * Instances are mutable and not synchronized.
 */
public class Layer {
    /** Weights indexed [neuron][input]; modified in place by {@link #updateWeights()}. */
    private final Matrix weightMatrix;
    /** One bias per neuron; modified in place by {@link #updateBiases()}. */
    Vector biasVector;
    /** Most recently accepted bias update; added to the biases by {@link #updateBiases()}. */
    Vector lastBiasUpdateVector;
    /** Applied to each neuron's induced field to produce its output. */
    private final ActivationFunction activationFunction;
    /** Outputs of the most recent {@link #feedForward(Vector)} call, or null before the first call. */
    private Vector lastActivationValues;
    /** Induced fields (W*x + b) of the most recent forward pass, or null before the first call. */
    private Vector lastInducedField;
    /** Most recently accepted weight update; added to the weights by {@link #updateWeights()}. */
    private Matrix lastWeightUpdateMatrix;
    /** Weight update that preceded {@link #lastWeightUpdateMatrix}. */
    private Matrix penultimateWeightUpdateMatrix;
    /** Bias update that preceded {@link #lastBiasUpdateVector}. */
    private Vector penultimateBiasUpdateVector;
    /** Copy of the input of the most recent forward pass, or null before the first call. */
    private Vector lastInput;

    /**
     * Creates a layer from an existing weight matrix and bias vector. Both objects are used
     * directly (not copied), so later in-place updates are visible through the caller's references.
     * The last/penultimate update buffers start as freshly allocated matrices/vectors of matching
     * size.
     *
     * @param weightMatrix weights with one row per neuron and one column per input
     * @param biasVector   one bias per neuron
     * @param af           activation function applied to each neuron's induced field
     * @throws IllegalArgumentException if the bias vector length differs from the number of rows
     *                                  (neurons) of the weight matrix
     */
    public Layer(Matrix weightMatrix, Vector biasVector, ActivationFunction af) {
        if (biasVector.getRowDimension() != weightMatrix.getRowDimension()) {
            // Fail fast instead of on the first feedForward() call
            throw new IllegalArgumentException("bias vector has " + biasVector.getRowDimension()
                    + " entries but weight matrix has " + weightMatrix.getRowDimension()
                    + " rows (neurons)");
        }
        this.activationFunction = af;
        this.weightMatrix = weightMatrix;
        this.lastWeightUpdateMatrix = new Matrix(weightMatrix.getRowDimension(), weightMatrix.getColumnDimension());
        this.penultimateWeightUpdateMatrix = new Matrix(weightMatrix.getRowDimension(), weightMatrix.getColumnDimension());
        this.biasVector = biasVector;
        this.lastBiasUpdateVector = new Vector(biasVector.getRowDimension());
        this.penultimateBiasUpdateVector = new Vector(biasVector.getRowDimension());
    }

    /**
     * Creates a layer whose weights and biases are drawn with
     * {@code Util.generateRandomDoubleBetween(lowerLimitForWeights, upperLimitForWeights)}.
     * All weights are drawn first, row by row, and then the biases.
     *
     * @param numberOfNeurons      number of neurons (rows of the weight matrix)
     * @param numberOfInputs       number of inputs (columns of the weight matrix)
     * @param lowerLimitForWeights lower limit passed to the random generator
     * @param upperLimitForWeights upper limit passed to the random generator
     * @param af                   activation function applied to each neuron's induced field
     */
    public Layer(int numberOfNeurons, int numberOfInputs, double lowerLimitForWeights, double upperLimitForWeights, ActivationFunction af) {
        this(new Matrix(numberOfNeurons, numberOfInputs), new Vector(numberOfNeurons), af);
        Layer.initializeMatrix(this.weightMatrix, lowerLimitForWeights, upperLimitForWeights);
        Layer.initializeVector(this.biasVector, lowerLimitForWeights, upperLimitForWeights);
    }

    /**
     * Propagates an input through the layer. It computes the induced field
     * {@code weights * input + bias} and applies the activation function to each entry. The
     * input, induced field and outputs are cached for later retrieval.
     *
     * @param inputVector input with one entry per column of the weight matrix
     * @return a new vector holding one activation value per neuron
     */
    public Vector feedForward(Vector inputVector) {
        // Copy so that later mutation of the caller's vector cannot corrupt the cached input
        this.lastInput = inputVector.copyVector();
        Matrix inducedField = this.weightMatrix.times(inputVector).plus(this.biasVector);
        int neurons = this.numberOfNeurons();
        Vector inducedFieldVector = new Vector(neurons);
        Vector resultVector = new Vector(neurons);
        for (int i = 0; i < neurons; ++i) {
            double field = inducedField.get(i, 0);
            inducedFieldVector.setValue(i, field);
            resultVector.setValue(i, this.activationFunction.activation(field));
        }
        // inducedFieldVector never escapes this method, so it is cached without copying;
        // resultVector is handed to the caller, so the cache keeps its own copy.
        this.lastInducedField = inducedFieldVector;
        this.lastActivationValues = resultVector.copyVector();
        return resultVector;
    }

    /** Returns the live weight matrix (not a copy). */
    public Matrix getWeightMatrix() {
        return this.weightMatrix;
    }

    /** Returns the live bias vector (not a copy); it stays current across {@link #updateBiases()}. */
    public Vector getBiasVector() {
        return this.biasVector;
    }

    /** Returns the number of neurons, i.e. the number of rows of the weight matrix. */
    public int numberOfNeurons() {
        return this.weightMatrix.getRowDimension();
    }

    /** Returns the number of inputs, i.e. the number of columns of the weight matrix. */
    public int numberOfInputs() {
        return this.weightMatrix.getColumnDimension();
    }

    /** Returns the cached outputs of the last forward pass, or null if none has run yet. */
    public Vector getLastActivationValues() {
        return this.lastActivationValues;
    }

    /** Returns the cached induced fields of the last forward pass, or null if none has run yet. */
    public Vector getLastInducedField() {
        return this.lastInducedField;
    }

    /** Returns the most recently accepted weight update (not a copy). */
    public Matrix getLastWeightUpdateMatrix() {
        return this.lastWeightUpdateMatrix;
    }

    /** Replaces the most recent weight update with {@code m} (stored by reference). */
    public void setLastWeightUpdateMatrix(Matrix m) {
        this.lastWeightUpdateMatrix = m;
    }

    /** Returns the weight update that preceded the most recent one (not a copy). */
    public Matrix getPenultimateWeightUpdateMatrix() {
        return this.penultimateWeightUpdateMatrix;
    }

    /** Replaces the penultimate weight update with {@code m} (stored by reference). */
    public void setPenultimateWeightUpdateMatrix(Matrix m) {
        this.penultimateWeightUpdateMatrix = m;
    }

    /** Returns the most recently accepted bias update (not a copy). */
    public Vector getLastBiasUpdateVector() {
        return this.lastBiasUpdateVector;
    }

    /** Replaces the most recent bias update with {@code v} (stored by reference). */
    public void setLastBiasUpdateVector(Vector v) {
        this.lastBiasUpdateVector = v;
    }

    /** Returns the bias update that preceded the most recent one (not a copy). */
    public Vector getPenultimateBiasUpdateVector() {
        return this.penultimateBiasUpdateVector;
    }

    /** Replaces the penultimate bias update with {@code v} (stored by reference). */
    public void setPenultimateBiasUpdateVector(Vector v) {
        this.penultimateBiasUpdateVector = v;
    }

    /** Adds the most recent weight update to the weight matrix in place. */
    public void updateWeights() {
        this.weightMatrix.plusEquals(this.lastWeightUpdateMatrix);
    }

    /**
     * Adds the most recent bias update to the bias vector in place, mirroring
     * {@link #updateWeights()}. The bias vector object is kept, so references previously obtained
     * via {@link #getBiasVector()} or passed to the constructor remain current.
     */
    public void updateBiases() {
        // plusEquals already mutates biasVector; the former copy-into-a-new-Vector step was
        // redundant and made earlier references go stale after the first update.
        this.biasVector.plusEquals(this.lastBiasUpdateVector);
    }

    /** Returns the cached copy of the last forward pass's input, or null if none has run yet. */
    public Vector getLastInputValues() {
        return this.lastInput;
    }

    /** Returns the activation function applied by this layer. */
    public ActivationFunction getActivationFunction() {
        return this.activationFunction;
    }

    /**
     * Records a new weight update. The current "last" update becomes the "penultimate" one and
     * {@code weightUpdate} becomes the "last" one. The weights themselves are not changed until
     * {@link #updateWeights()} is called.
     *
     * @param weightUpdate the new weight update (stored by reference)
     */
    public void acceptNewWeightUpdate(Matrix weightUpdate) {
        this.setPenultimateWeightUpdateMatrix(this.getLastWeightUpdateMatrix());
        this.setLastWeightUpdateMatrix(weightUpdate);
    }

    /**
     * Records a new bias update. The current "last" update becomes the "penultimate" one and
     * {@code biasUpdate} becomes the "last" one. The biases themselves are not changed until
     * {@link #updateBiases()} is called.
     *
     * @param biasUpdate the new bias update (stored by reference)
     */
    public void acceptNewBiasUpdate(Vector biasUpdate) {
        this.setPenultimateBiasUpdateVector(this.getLastBiasUpdateVector());
        this.setLastBiasUpdateVector(biasUpdate);
    }

    /**
     * Computes {@code target - lastActivationValues}.
     *
     * @param target desired outputs, one per neuron
     * @return the element-wise difference between target and the last outputs
     * @throws IllegalStateException if {@link #feedForward(Vector)} has not been called yet
     */
    public Vector errorVectorFrom(Vector target) {
        Vector actual = this.getLastActivationValues();
        if (actual == null) {
            throw new IllegalStateException("feedForward must be called before errorVectorFrom");
        }
        return target.minus(actual);
    }

    /** Fills every entry of {@code aMatrix}, row by row, with a random value between the limits. */
    private static void initializeMatrix(Matrix aMatrix, double lowerLimit, double upperLimit) {
        for (int i = 0; i < aMatrix.getRowDimension(); ++i) {
            for (int j = 0; j < aMatrix.getColumnDimension(); ++j) {
                double random = Util.generateRandomDoubleBetween(lowerLimit, upperLimit);
                aMatrix.set(i, j, random);
            }
        }
    }

    /** Fills every entry of {@code aVector} with a random value between the limits. */
    private static void initializeVector(Vector aVector, double lowerLimit, double upperLimit) {
        for (int i = 0; i < aVector.size(); ++i) {
            double random = Util.generateRandomDoubleBetween(lowerLimit, upperLimit);
            aVector.setValue(i, random);
        }
    }
}

