/*
 * Decompiled with CFR 0.152.
 */
package aima.core.probability.mdp.search;

import aima.core.agent.Action;
import aima.core.probability.mdp.MarkovDecisionProcess;
import aima.core.probability.mdp.Policy;
import aima.core.probability.mdp.PolicyEvaluation;
import aima.core.probability.mdp.impl.LookupPolicy;
import aima.core.util.Util;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Policy iteration over a {@link MarkovDecisionProcess}.
 *
 * <p>Starting from a randomly chosen policy, the algorithm alternates between
 * evaluating the current policy (delegated to the supplied
 * {@link PolicyEvaluation}) and greedily improving it state by state, and
 * stops after the first full pass in which no state's action was changed.
 *
 * @param <S> the state type.
 * @param <A> the action type.
 */
public class PolicyIteration<S, A extends Action> {
    /** Strategy used to compute the utilities of the current policy. */
    private final PolicyEvaluation<S, A> policyEvaluation;

    /**
     * Creates a policy iteration instance.
     *
     * @param policyEvaluation
     *            the evaluation step invoked once per iteration of
     *            {@link #policyIteration(MarkovDecisionProcess)}.
     */
    public PolicyIteration(PolicyEvaluation<S, A> policyEvaluation) {
        this.policyEvaluation = policyEvaluation;
    }

    /**
     * Runs policy iteration on the given MDP.
     *
     * @param mdp
     *            the Markov decision process supplying states, the actions
     *            available in each state, and the transition probabilities.
     * @return a {@link LookupPolicy} wrapping the final state-to-action map.
     */
    public Policy<S, A> policyIteration(MarkovDecisionProcess<S, A> mdp) {
        // Utilities start at zero for every state.
        Map<S, Double> U = Util.create(mdp.states(), Double.valueOf(0.0));
        // The starting policy picks one action at random per state.
        Map<S, A> pi = PolicyIteration.initialPolicyVector(mdp);
        boolean unchanged;
        do {
            U = this.policyEvaluation.evaluate(pi, U, mdp);
            unchanged = true;
            for (S s : mdp.states()) {
                // null when the state had no actions at initialisation time.
                A current = pi.get(s);
                double aMax = Double.NEGATIVE_INFINITY;
                // Stays 0.0 unless the current action is among mdp.actions(s).
                double piVal = 0.0;
                A aArgmax = current;
                for (A a : mdp.actions(s)) {
                    double aSum = expectedUtility(mdp, s, a, U);
                    if (aSum > aMax) {
                        aMax = aSum;
                        aArgmax = a;
                    }
                    if (a.equals(current)) {
                        piVal = aSum;
                    }
                }
                // Switch only on a strict improvement; ties keep the current action.
                if (aMax > piVal) {
                    pi.put(s, aArgmax);
                    unchanged = false;
                }
            }
        } while (!unchanged);
        return new LookupPolicy<S, A>(pi);
    }

    /**
     * Builds the starting policy by choosing, for every state that has at
     * least one action, one of its actions via
     * {@code Util.selectRandomlyFromList}. States without actions receive no
     * entry in the returned map.
     *
     * @param mdp
     *            the Markov decision process to build a policy for.
     * @param <S> the state type.
     * @param <A> the action type.
     * @return a map, in the iteration order of {@code mdp.states()}, from
     *         state to the selected action.
     */
    public static <S, A extends Action> Map<S, A> initialPolicyVector(MarkovDecisionProcess<S, A> mdp) {
        Map<S, A> pi = new LinkedHashMap<S, A>();
        // One scratch list reused across states to hold the candidate actions.
        ArrayList<A> actions = new ArrayList<A>();
        for (S s : mdp.states()) {
            actions.clear();
            actions.addAll(mdp.actions(s));
            if (!actions.isEmpty()) {
                pi.put(s, Util.selectRandomlyFromList(actions));
            }
        }
        return pi;
    }

    /**
     * Sums {@code transitionProbability(sDelta, s, a) * U(sDelta)} over every
     * state {@code sDelta} of the MDP.
     */
    private double expectedUtility(MarkovDecisionProcess<S, A> mdp, S s, A a, Map<S, Double> U) {
        double sum = 0.0;
        for (S sDelta : mdp.states()) {
            sum += mdp.transitionProbability(sDelta, s, a) * U.get(sDelta);
        }
        return sum;
    }
}

