/*
 * Decompiled with CFR 0.152.
 */
package aima.core.probability.bayes.impl;

import aima.core.probability.CategoricalDistribution;
import aima.core.probability.Factor;
import aima.core.probability.RandomVariable;
import aima.core.probability.bayes.ConditionalProbabilityTable;
import aima.core.probability.domain.FiniteDomain;
import aima.core.probability.proposition.AssignmentProposition;
import aima.core.probability.util.ProbUtil;
import aima.core.probability.util.ProbabilityTable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Conditional probability table (CPT) for a single random variable, conditioned on zero or
 * more parent variables.
 * <p>
 * The entries are stored in a {@link ProbabilityTable} built over the parent variables, in the
 * order supplied to the constructor, followed by the variable the table is "on".
 */
public class CPT
implements ConditionalProbabilityTable {
    /** Maximum allowed deviation from 1.0 when checking that a row of the table sums to one. */
    private static final double ROW_SUM_TOLERANCE = 1.0E-8;

    /** The variable whose (conditional) distribution this table specifies. */
    private final RandomVariable on;
    /**
     * Parent variables in constructor order; this order defines how positional parent values
     * passed to {@link #getConditioningCase(Object...)} are matched to variables.
     */
    private final LinkedHashSet<RandomVariable> parents = new LinkedHashSet<RandomVariable>();
    /** Read-only view of {@link #parents} handed out to callers so they cannot corrupt it. */
    private final Set<RandomVariable> parentsView = Collections.unmodifiableSet(this.parents);
    /** Underlying table over (parents..., on). */
    private final ProbabilityTable table;
    /** Possible values of {@code on}; only its size (the length of one row) is used here. */
    private final List<Object> onDomain = new ArrayList<Object>();

    /**
     * Creates a CPT for {@code on} conditioned on {@code conditionedOn}.
     *
     * @param on the variable this table is on. Its domain is cast to {@link FiniteDomain}, so a
     *        non-finite domain results in a ClassCastException.
     * @param values the table entries, handed to {@link ProbabilityTable} together with the
     *        variable order {@code conditionedOn..., on}
     * @param conditionedOn the parent variables; {@code null} is treated as "no parents"
     * @throws IllegalArgumentException if any row (each consecutive group of
     *         {@code |domain(on)|} entries in table iteration order) does not sum to 1.0
     */
    public CPT(RandomVariable on, double[] values, RandomVariable ... conditionedOn) {
        this.on = on;
        if (null == conditionedOn) {
            conditionedOn = new RandomVariable[0];
        }
        // The table's variable order is: all parents (as given), then the 'on' variable last.
        RandomVariable[] tableVars = new RandomVariable[conditionedOn.length + 1];
        System.arraycopy(conditionedOn, 0, tableVars, 0, conditionedOn.length);
        tableVars[conditionedOn.length] = on;
        Collections.addAll(this.parents, conditionedOn);
        this.table = new ProbabilityTable(values, tableVars);
        this.onDomain.addAll(((FiniteDomain) on.getDomain()).getPossibleValues());
        this.checkEachRowTotalsOne();
    }

    /**
     * Looks up a single entry of the table.
     *
     * @param values one value per table variable (parents in constructor order, then
     *        {@code on}), as accepted by {@link ProbabilityTable#getValue(Object...)}
     * @return the stored probability for that assignment
     */
    public double probabilityFor(Object ... values) {
        return this.table.getValue(values);
    }

    /** @return the variable this table is on */
    @Override
    public RandomVariable getOn() {
        return this.on;
    }

    /**
     * @return an unmodifiable view of the parent variables, iterating in the order they were
     *         passed to the constructor
     */
    @Override
    public Set<RandomVariable> getParents() {
        return this.parentsView;
    }

    /** @return the variables of the underlying table (the parents and {@code on}) */
    @Override
    public Set<RandomVariable> getFor() {
        return this.table.getFor();
    }

    /** @return whether {@code rv} is one of the underlying table's variables */
    @Override
    public boolean contains(RandomVariable rv) {
        return this.table.contains(rv);
    }

    /** Delegates to {@link ProbabilityTable#getValue(Object...)} on the underlying table. */
    @Override
    public double getValue(Object ... eventValues) {
        return this.table.getValue(eventValues);
    }

    /**
     * Delegates to {@link ProbabilityTable#getValue(AssignmentProposition...)} on the
     * underlying table.
     */
    @Override
    public double getValue(AssignmentProposition ... eventValues) {
        return this.table.getValue(eventValues);
    }

    /**
     * Samples a value of {@code on} from the distribution selected by the given parent values.
     *
     * @param probabilityChoice the random choice passed through to {@link ProbUtil#sample}
     * @param parentValues one value per parent, in constructor order
     * @return the sampled value of {@code on}
     */
    @Override
    public Object getSample(double probabilityChoice, Object ... parentValues) {
        return ProbUtil.sample(probabilityChoice, this.on, this.getConditioningCase(parentValues).getValues());
    }

    /**
     * Samples a value of {@code on} from the distribution selected by the given parent
     * assignments.
     *
     * @param probabilityChoice the random choice passed through to {@link ProbUtil#sample}
     * @param parentValues exactly one assignment per parent
     * @return the sampled value of {@code on}
     */
    @Override
    public Object getSample(double probabilityChoice, AssignmentProposition ... parentValues) {
        return ProbUtil.sample(probabilityChoice, this.on, this.getConditioningCase(parentValues).getValues());
    }

    /**
     * Returns the distribution over {@code on} for one combination of parent values.
     *
     * @param parentValues one value per parent, matched positionally to the parents in
     *        constructor order
     * @return a table over {@code on} holding the selected row
     * @throws IllegalArgumentException if the number of values differs from the number of
     *         parents
     */
    @Override
    public CategoricalDistribution getConditioningCase(Object ... parentValues) {
        if (parentValues.length != this.parents.size()) {
            throw new IllegalArgumentException("The number of parent value arguments [" + parentValues.length + "] is not equal to the number of parents [" + this.parents.size() + "] for this CPT.");
        }
        AssignmentProposition[] aps = new AssignmentProposition[parentValues.length];
        int i = 0;
        for (RandomVariable parentRV : this.parents) {
            aps[i] = new AssignmentProposition(parentRV, parentValues[i]);
            i++;
        }
        return this.getConditioningCase(aps);
    }

    /**
     * Returns the distribution over {@code on} for one combination of parent values.
     *
     * @param parentValues exactly one assignment for each parent (in any order)
     * @return a table over {@code on} holding the selected row
     * @throws IllegalArgumentException if the number of assignments differs from the number of
     *         parents, or if an assignment is for a non-parent variable or repeats a parent
     */
    @Override
    public CategoricalDistribution getConditioningCase(AssignmentProposition ... parentValues) {
        if (parentValues.length != this.parents.size()) {
            throw new IllegalArgumentException("The number of parent value arguments [" + parentValues.length + "] is not equal to the number of parents [" + this.parents.size() + "] for this CPT.");
        }
        // Every parent must be fixed exactly once. Otherwise the table iteration below visits
        // more (or different) entries than 'on' has values, which either overruns cc's
        // backing array or silently produces a wrong distribution.
        Set<RandomVariable> assigned = new LinkedHashSet<RandomVariable>();
        for (AssignmentProposition ap : parentValues) {
            RandomVariable rv = ap.getTermVariable();
            if (!this.parents.contains(rv) || !assigned.add(rv)) {
                throw new IllegalArgumentException("The assignment for [" + rv + "] does not correspond to a distinct parent of this CPT.");
            }
        }
        final ProbabilityTable cc = new ProbabilityTable(this.getOn());
        // Writes go straight into cc's value array, filled in table iteration order.
        final double[] ccValues = cc.getValues();
        ProbabilityTable.Iterator fillRow = new ProbabilityTable.Iterator() {
            private int idx = 0;

            @Override
            public void iterate(Map<RandomVariable, Object> possibleAssignment, double probability) {
                ccValues[idx++] = probability;
            }
        };
        this.table.iterateOverTable(fillRow, parentValues);
        return cc;
    }

    /**
     * Builds a factor over this table's variables minus those fixed by {@code evidence}.
     * <p>
     * The factor accumulates (sums) the probabilities of the table entries that
     * {@link ProbabilityTable#iterateOverTable} visits for the given evidence. When every
     * variable is fixed, the factor has no variables and a single entry at index 0.
     *
     * @param evidence assignments that fix some of the table's variables
     * @return the resulting factor
     */
    @Override
    public Factor getFactorFor(AssignmentProposition ... evidence) {
        LinkedHashSet<RandomVariable> fofVars = new LinkedHashSet<RandomVariable>(this.table.getFor());
        for (AssignmentProposition ap : evidence) {
            fofVars.remove(ap.getTermVariable());
        }
        final ProbabilityTable fof = new ProbabilityTable(fofVars);
        final double[] fofValues = fof.getValues();
        // Scratch buffer reused across callbacks for the free variables' values.
        final Object[] termValues = new Object[fofVars.size()];
        ProbabilityTable.Iterator accumulate = new ProbabilityTable.Iterator() {

            @Override
            public void iterate(Map<RandomVariable, Object> possibleWorld, double probability) {
                int index = 0;
                if (termValues.length > 0) {
                    int i = 0;
                    for (RandomVariable rv : fof.getFor()) {
                        termValues[i++] = possibleWorld.get(rv);
                    }
                    index = fof.getIndex(termValues);
                }
                fofValues[index] += probability;
            }
        };
        this.table.iterateOverTable(accumulate, evidence);
        return fof;
    }

    /**
     * Verifies that each row of the table sums to 1.0 (within {@link #ROW_SUM_TOLERANCE}).
     * <p>
     * A row is each consecutive run of {@code |domain(on)|} entries in table iteration order.
     * This assumes the iteration varies {@code on} (the last table variable) fastest.
     *
     * @throws IllegalArgumentException naming the first (1-based) row that fails the check
     */
    private void checkEachRowTotalsOne() {
        final int rowSize = this.onDomain.size();
        ProbabilityTable.Iterator rowChecker = new ProbabilityTable.Iterator() {
            private int iterateCnt = 0;
            private double rowProb = 0.0;

            @Override
            public void iterate(Map<RandomVariable, Object> possibleWorld, double probability) {
                ++iterateCnt;
                rowProb += probability;
                if (iterateCnt % rowSize == 0) {
                    if (Math.abs(1.0 - rowProb) > ROW_SUM_TOLERANCE) {
                        throw new IllegalArgumentException("Row " + iterateCnt / rowSize + " of CPT does not sum to 1.0.");
                    }
                    rowProb = 0.0;
                }
            }
        };
        this.table.iterateOverTable(rowChecker);
    }
}

