/*
 * Decompiled with CFR 0.152.
 */
package aima.core.learning.reinforcement.agent;

import aima.core.agent.Action;
import aima.core.learning.reinforcement.PerceptStateReward;
import aima.core.learning.reinforcement.agent.ReinforcementAgent;
import aima.core.probability.mdp.ActionsFunction;
import aima.core.util.FrequencyCounter;
import aima.core.util.datastructure.Pair;
import java.util.HashMap;
import java.util.Map;

/**
 * Reinforcement-learning agent that keeps a table of action values
 * {@code Q[s, a]} together with a visit counter {@code Nsa[s, a]} and updates
 * the table from each (state, reward) percept it receives.
 *
 * <p>On every call to {@link #execute(PerceptStateReward)} the agent:
 * <ol>
 * <li>records {@code Q[s', noneAction] = r'} when {@code s'} has no available
 * actions (which this class treats as "terminal");</li>
 * <li>if a previous state is remembered, increments {@code Nsa[s, a]} and
 * applies {@code Q[s,a] += alpha * (r + gamma * max_a' Q[s',a'] - Q[s,a])};</li>
 * <li>forgets the previous step when {@code s'} is terminal, otherwise
 * remembers {@code s'}, {@code r'} and the action maximising
 * {@code f(Q[s',a'], Nsa[s',a'])}.</li>
 * </ol>
 *
 * @param <S> the state type
 * @param <A> the action type
 */
public class QLearningAgent<S, A extends Action>
extends ReinforcementAgent<S, A> {
    /** Action values indexed by (state, action) pairs. */
    Map<Pair<S, A>, Double> Q = new HashMap<Pair<S, A>, Double>();
    /** How many times each (state, action) pair has been updated. */
    private final FrequencyCounter<Pair<S, A>> Nsa = new FrequencyCounter<Pair<S, A>>();
    /** Previous state, or {@code null} at the start of a trial. */
    private S s = null;
    /** Previous action, or {@code null} at the start of a trial. */
    private A a = null;
    /** Previous reward, or {@code null} at the start of a trial. */
    private Double r = null;
    private final ActionsFunction<S, A> actionsFunction;
    /** Placeholder action used as the Q-table key for terminal states. */
    private final A noneAction;
    private final double alpha;
    private final double gamma;
    private final int Ne;
    private final double Rplus;

    /**
     * Creates an agent with an empty Q table.
     *
     * @param actionsFunction supplies the actions available in a state; a
     *            state with no actions is treated as terminal
     * @param noneAction the action stored against terminal states in the Q
     *            table
     * @param alpha the value returned by the default
     *            {@link #alpha(FrequencyCounter, Object, Action)} step-size
     *            hook
     * @param gamma the factor applied to the best next-state Q value in the
     *            update
     * @param Ne visit-count threshold used by {@link #f(Double, int)}
     * @param Rplus value returned by {@link #f(Double, int)} for pairs that
     *            are unknown or visited fewer than {@code Ne} times
     */
    public QLearningAgent(ActionsFunction<S, A> actionsFunction, A noneAction, double alpha, double gamma, int Ne, double Rplus) {
        this.actionsFunction = actionsFunction;
        this.noneAction = noneAction;
        this.alpha = alpha;
        this.gamma = gamma;
        this.Ne = Ne;
        this.Rplus = Rplus;
    }

    /**
     * Processes one percept: updates the Q table for the previously taken
     * (state, action) pair and selects the next action.
     *
     * @param percept the current state and the reward observed in it
     * @return the selected action, or {@code null} when the percept's state is
     *         terminal (or when no action could be selected)
     */
    @Override
    public A execute(PerceptStateReward<S> percept) {
        S sPrime = percept.state();
        double rPrime = percept.reward();
        boolean terminal = this.isTerminal(sPrime);

        // A terminal state's value is simply the reward observed there.
        if (terminal) {
            this.Q.put(new Pair<S, A>(sPrime, this.noneAction), rPrime);
        }

        // Update the estimate for the step that led here, if there was one.
        if (null != this.s) {
            Pair<S, A> sa = new Pair<S, A>(this.s, this.a);
            this.Nsa.incrementFor(sa);
            Double Q_sa = this.Q.get(sa);
            if (null == Q_sa) {
                Q_sa = 0.0;
            }
            double step = this.alpha(this.Nsa, this.s, this.a);
            this.Q.put(sa, Q_sa + step * (this.r + this.gamma * this.maxAPrime(sPrime) - Q_sa));
        }

        if (terminal) {
            // End of trial: there is no "previous step" for the next percept.
            this.s = null;
            this.a = null;
            this.r = null;
        } else {
            this.s = sPrime;
            this.a = this.argmaxAPrime(sPrime);
            this.r = rPrime;
        }
        return this.a;
    }

    /** Clears the Q table and visit counts and forgets the previous step. */
    @Override
    public void reset() {
        this.Q.clear();
        this.Nsa.clear();
        this.s = null;
        this.a = null;
        this.r = null;
    }

    /**
     * Derives a per-state utility from the Q table.
     *
     * @return a new map from each state that appears in the Q table to the
     *         largest Q value recorded for that state
     */
    @Override
    public Map<S, Double> getUtility() {
        HashMap<S, Double> U = new HashMap<S, Double>();
        for (Map.Entry<Pair<S, A>, Double> entry : this.Q.entrySet()) {
            S state = entry.getKey().getFirst();
            Double q = entry.getValue();
            Double u = U.get(state);
            if (null == u || u < q) {
                U.put(state, q);
            }
        }
        return U;
    }

    /**
     * Step-size hook used by the Q update. This implementation ignores its
     * arguments and returns the constant supplied to the constructor;
     * subclasses may override it.
     *
     * @param Nsa the visit counter
     * @param s the state being updated
     * @param a the action being updated
     * @return the step size to apply
     */
    protected double alpha(FrequencyCounter<Pair<S, A>> Nsa, S s, A a) {
        return this.alpha;
    }

    /**
     * Exploration function used when choosing the next action.
     *
     * @param u the current Q value for a pair, or {@code null} if none is
     *            recorded
     * @param n the number of times the pair has been visited
     * @return {@code Rplus} when {@code u} is {@code null} or {@code n} is
     *         below {@code Ne}; otherwise {@code u}
     */
    protected double f(Double u, int n) {
        if (null == u || n < this.Ne) {
            return this.Rplus;
        }
        return u;
    }

    /** A non-null state is terminal when the actions function offers no actions for it. */
    private boolean isTerminal(S s) {
        return null != s && this.actionsFunction.actions(s).size() == 0;
    }

    /**
     * Returns the largest recorded Q value over the actions available in
     * {@code sPrime} (or the value stored against {@code noneAction} when no
     * actions are available), falling back to 0.0 when nothing is recorded.
     */
    private double maxAPrime(S sPrime) {
        double max = Double.NEGATIVE_INFINITY;
        if (this.actionsFunction.actions(sPrime).size() == 0) {
            // Guard the lookup: unboxing a missing entry would throw a
            // NullPointerException; a missing value falls through to 0.0.
            Double terminalQ = this.Q.get(new Pair<S, A>(sPrime, this.noneAction));
            if (null != terminalQ) {
                max = terminalQ;
            }
        } else {
            for (A aPrime : this.actionsFunction.actions(sPrime)) {
                Double Q_sPrimeAPrime = this.Q.get(new Pair<S, A>(sPrime, aPrime));
                if (null != Q_sPrimeAPrime && Q_sPrimeAPrime > max) {
                    max = Q_sPrimeAPrime;
                }
            }
        }
        if (max == Double.NEGATIVE_INFINITY) {
            max = 0.0;
        }
        return max;
    }

    /**
     * Returns the action available in {@code sPrime} with the highest
     * exploration value {@code f(Q[s',a'], Nsa[s',a'])}; the first such action
     * in iteration order wins ties. Returns {@code null} when no action has an
     * exploration value above negative infinity.
     */
    private A argmaxAPrime(S sPrime) {
        A best = null;
        double max = Double.NEGATIVE_INFINITY;
        for (A aPrime : this.actionsFunction.actions(sPrime)) {
            Pair<S, A> sPrimeAPrime = new Pair<S, A>(sPrime, aPrime);
            double explorationValue = this.f(this.Q.get(sPrimeAPrime), this.Nsa.getCount(sPrimeAPrime));
            if (explorationValue > max) {
                max = explorationValue;
                best = aPrime;
            }
        }
        return best;
    }
}

