/*
 * Decompiled with CFR 0.152.
 */
package aima.learning.reinforcement;

import aima.learning.reinforcement.MDPAgent;
import aima.learning.reinforcement.QTable;
import aima.probability.decision.MDP;
import aima.probability.decision.MDPPerception;
import aima.util.FrequencyCounter;
import aima.util.Pair;
import java.util.Hashtable;
import java.util.List;

/**
 * A reinforcement-learning agent for an {@link MDP} that keeps its action values in a
 * {@link QTable}.
 *
 * <p>On every perception after the first step of a trial, the agent records the
 * (state, action) pair it took on the previous step. It then asks the Q-table to update
 * that pair and takes the action the Q-table returns. A trial starts when the agent has no
 * previous step and perceives the MDP's initial state. It ends when the agent perceives a
 * terminal state; at that point the previous-step bookkeeping is cleared and
 * {@code null} is returned.
 *
 * @param <STATE_TYPE>  the MDP state type
 * @param <ACTION_TYPE> the MDP action type
 */
public class QLearningAgent<STATE_TYPE, ACTION_TYPE>
extends MDPAgent<STATE_TYPE, ACTION_TYPE> {

    /**
     * Value passed to {@link #updateQ(double)} on every step, which forwards it as the final
     * argument of {@code QTable.upDateQ}. NOTE(review): presumably the discount factor —
     * confirm against QTable.
     */
    private static final double DISCOUNT = 0.8;

    /** Exposed through {@link #getQ()}; this class never writes any entries into it. */
    private final Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> Q =
            new Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double>();

    /** Records every (state, action) pair the agent has completed a step from. */
    private final FrequencyCounter<Pair<STATE_TYPE, ACTION_TYPE>> stateActionCount;

    /** Reward perceived on the previous step; {@code null} between trials. */
    private Double previousReward;

    /** Holds the learned values and selects the next action. */
    private final QTable<STATE_TYPE, ACTION_TYPE> qTable;

    /**
     * Incremented on every Q update; read and reset only by {@link #learningFunction(Double)},
     * which currently has no call sites.
     */
    private int actionCounter;

    /**
     * Creates an agent for {@code mdp}, with a Q-table over all of the MDP's actions.
     *
     * @param mdp the decision process the agent acts in
     */
    public QLearningAgent(MDP<STATE_TYPE, ACTION_TYPE> mdp) {
        super(mdp);
        this.qTable = new QTable<STATE_TYPE, ACTION_TYPE>(mdp.getAllActions());
        this.stateActionCount = new FrequencyCounter<Pair<STATE_TYPE, ACTION_TYPE>>();
        this.actionCounter = 0;
    }

    /**
     * Processes one perception and returns the next action.
     *
     * @param perception the current state and the reward received on entering it
     * @return the action to take, or {@code null} when the current state is terminal
     */
    @Override
    public ACTION_TYPE decideAction(MDPPerception<STATE_TYPE> perception) {
        this.currentState = perception.getState();
        this.currentReward = perception.getReward();

        if (this.startingTrial()) {
            ACTION_TYPE firstAction = this.selectRandomAction();
            this.updateLearnerState(firstAction);
            return this.previousAction;
        }

        if (this.mdp.isTerminalState(this.currentState)) {
            this.incrementStateActionCount(this.previousState, this.previousAction);
            this.updateQ(DISCOUNT);
            // Clear the bookkeeping so that the next perception of the initial state is
            // recognised by startingTrial() as the start of a new trial.
            this.previousAction = null;
            this.previousState = null;
            this.previousReward = null;
            return null;
        }

        this.incrementStateActionCount(this.previousState, this.previousAction);
        ACTION_TYPE chosenAction = this.updateQ(DISCOUNT);
        this.updateLearnerState(chosenAction);
        return this.previousAction;
    }

    /** Remembers {@code action} and the current state and reward as the "previous" step. */
    private void updateLearnerState(ACTION_TYPE action) {
        this.previousAction = action;
        this.previousState = this.currentState;
        this.previousReward = this.currentReward;
    }

    /**
     * Updates the Q-table for the previous (state, action) pair given the current state and
     * reward, and returns the action the Q-table chooses next.
     *
     * @param discount forwarded as the final argument of {@code QTable.upDateQ}
     * @return the action returned by {@code QTable.upDateQ}
     */
    private ACTION_TYPE updateQ(double discount) {
        ++this.actionCounter;
        // NOTE(review): passed as upDateQ's fourth argument, presumably the learning rate.
        double alpha = this.calculateProbabilityOf(this.previousState, this.previousAction);
        return this.qTable.upDateQ(this.previousState, this.previousAction,
                this.currentState, alpha, this.currentReward, discount);
    }

    /**
     * Takes the pairs from {@code stateActionCount.getStates()} whose state equals
     * {@code state}, and returns the fraction of them whose action equals {@code action}.
     *
     * <p>NOTE(review): if {@code getStates()} yields distinct keys, this ignores visit counts.
     * In that case the result is 1 / (number of distinct actions tried in {@code state}) when
     * {@code action} has been tried. Confirm against FrequencyCounter.
     *
     * @return the fraction described above, or {@code 0.0} when no recorded pair has
     *         {@code state}. This avoids a 0/0 = NaN that would otherwise be fed into the
     *         Q-table.
     */
    private double calculateProbabilityOf(STATE_TYPE state, ACTION_TYPE action) {
        double pairsForState = 0.0;
        double pairsForStateAndAction = 0.0;
        for (Pair<STATE_TYPE, ACTION_TYPE> pair : this.stateActionCount.getStates()) {
            if (!sameValue(pair.getFirst(), state)) {
                continue;
            }
            pairsForState += 1.0;
            if (sameValue(pair.getSecond(), action)) {
                pairsForStateAndAction += 1.0;
            }
        }
        if (pairsForState == 0.0) {
            return 0.0;
        }
        return pairsForStateAndAction / pairsForState;
    }

    /** Null-safe equality: {@code true} if both are {@code null} or {@code a.equals(b)}. */
    private static boolean sameValue(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /**
     * Returns the action whose Q-value for the current state maximises
     * {@link #learningFunction(Double)}, or {@code null} if the MDP has no actions.
     *
     * <p>NOTE(review): this method has no call sites in this class.
     */
    private ACTION_TYPE actionMaximizingLearningFunction() {
        ACTION_TYPE bestAction = null;
        Double bestValue = Double.NEGATIVE_INFINITY;
        for (ACTION_TYPE candidate : this.mdp.getAllActions()) {
            Double qValue = this.qTable.getQValue(this.currentState, candidate);
            Double value = this.learningFunction(qValue);
            if (value > bestValue) {
                bestValue = value;
                bestAction = candidate;
            }
        }
        return bestAction;
    }

    /**
     * Returns {@code 1.0} and resets {@link #actionCounter} once more than three Q updates
     * have occurred since the last reset. Otherwise returns {@code qValue} unchanged.
     *
     * <p>NOTE(review): only called from {@link #actionMaximizingLearningFunction()}, which
     * itself has no call sites.
     */
    private Double learningFunction(Double qValue) {
        if (this.actionCounter > 3) {
            this.actionCounter = 0;
            return 1.0;
        }
        return qValue;
    }

    /**
     * Returns the first action from {@code mdp.getAllActions()}.
     *
     * <p>Despite the name, the choice is deterministic: it is always the first action in the
     * list.
     */
    private ACTION_TYPE selectRandomAction() {
        List<ACTION_TYPE> actions = this.mdp.getAllActions();
        return actions.get(0);
    }

    /**
     * Returns {@code true} when there is no previous step and the agent perceives the MDP's
     * initial state.
     */
    private boolean startingTrial() {
        return this.previousAction == null
                && this.previousState == null
                && this.previousReward == null
                && this.currentState.equals(this.mdp.getInitialState());
    }

    /** Records one more occurrence of the pair ({@code state}, {@code action}). */
    private void incrementStateActionCount(STATE_TYPE state, ACTION_TYPE action) {
        this.stateActionCount.incrementFor(new Pair<STATE_TYPE, ACTION_TYPE>(state, action));
    }

    /**
     * Returns the {@code Q} table. This class never adds entries to it, so it is always
     * empty. The learned values are available through {@link #getQTable()}.
     */
    public Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> getQ() {
        return this.Q;
    }

    /** Returns the Q-table the agent learns into and chooses actions from. */
    public QTable<STATE_TYPE, ACTION_TYPE> getQTable() {
        return this.qTable;
    }
}

