/*
 * Decompiled with CFR 0.152.
 */
package aima.core.learning.learners;

import aima.core.learning.framework.DataSet;
import aima.core.learning.framework.Example;
import aima.core.learning.framework.Learner;
import aima.core.util.Util;
import aima.core.util.datastructure.Table;
import java.util.Hashtable;
import java.util.List;

/**
 * Ensemble learner implementing AdaBoost over a fixed list of component learners.
 * <p>
 * Each component learner is trained in turn. Its weighted training error is used to
 * re-weight the examples (correctly classified examples are scaled down by
 * {@code error / (1 - error)} and the weights renormalised). The same error gives the learner
 * a voting weight of {@code log((1 - error) / error)}. Predictions are a weighted majority
 * vote over all component learners.
 * <p>
 * Note: {@code Learner.train(DataSet)} receives the unweighted data set. The example weights
 * therefore only influence the error estimates (and hence the voting weights), not how each
 * component learner is fitted.
 */
public class AdaBoostLearner
implements Learner {
    /**
     * Training errors are clamped to {@code [MIN_ERROR, 1 - MIN_ERROR]}. This gives a perfect
     * (or perfectly wrong) hypothesis a large but finite voting weight instead of an infinite
     * one, and avoids dividing by zero when re-weighting examples.
     */
    private static final double MIN_ERROR = 1.0E-4;

    /** Component hypotheses, trained and voted in list order. */
    private final List<Learner> learners;
    /** Data set whose target attribute supplies the candidate values when voting. */
    private final DataSet dataSet;
    /** Current weight of each training example, indexed like the training data set. */
    private double[] exampleWeights;
    /** Voting weight of each component learner. */
    private Hashtable<Learner, Double> learnerWeights;

    /**
     * Creates an AdaBoost ensemble.
     *
     * @param learners the component learners; each starts with a voting weight of 1.0
     * @param ds data set used to size the initial example weights and to look up the possible
     *           values of the target attribute when predicting
     * @throws IllegalArgumentException if {@code ds} has no examples or {@code learners} is empty
     */
    public AdaBoostLearner(List<Learner> learners, DataSet ds) {
        this.learners = learners;
        this.dataSet = ds;
        this.initializeExampleWeights(ds.examples.size());
        this.initializeHypothesisWeights(learners.size());
    }

    /**
     * Trains every component learner on {@code ds} and assigns it a voting weight.
     * <p>
     * Example weights and learner weights are reset first. Calling this method repeatedly
     * therefore starts from scratch instead of compounding the weights of earlier calls.
     *
     * @param ds the training data
     * @throws IllegalArgumentException if {@code ds} has no examples
     */
    @Override
    public void train(DataSet ds) {
        // Validates ds (non-empty) before any state is reset.
        this.initializeExampleWeights(ds.examples.size());
        this.initializeHypothesisWeights(this.learners.size());
        for (Learner learner : this.learners) {
            learner.train(ds);
            boolean[] correct = this.classifyExamples(ds, learner);
            // Clamp instead of stopping early. Stopping used to leave the remaining learners
            // untrained, yet still voting with their initial weight of 1.0.
            double error = Math.min(
                    Math.max(this.calculateError(correct), MIN_ERROR), 1.0 - MIN_ERROR);
            this.adjustExampleWeights(correct, error);
            this.learnerWeights.put(learner, Math.log((1.0 - error) / error));
        }
    }

    /**
     * Predicts the target value of {@code e} by weighted majority vote of the component learners.
     *
     * @param e the example to classify
     * @return the target value with the highest total voting weight
     */
    @Override
    public String predict(Example e) {
        return this.weightedMajority(e);
    }

    /**
     * Counts correct and incorrect predictions on {@code ds}.
     *
     * @param ds the data to evaluate on
     * @return a two-element array: {@code [correct, incorrect]}
     */
    @Override
    public int[] test(DataSet ds) {
        int[] results = new int[]{0, 0};
        for (Example e : ds.examples) {
            if (e.targetValue().equals(this.predict(e))) {
                results[0]++;
            } else {
                results[1]++;
            }
        }
        return results;
    }

    /** Returns the target value that receives the most weighted votes for {@code e}. */
    private String weightedMajority(Example e) {
        // Candidate values come from the data set passed to the constructor.
        List<String> targetValues = this.dataSet.getPossibleAttributeValues(this.dataSet.getTargetAttributeName());
        Table<String, Learner, Double> table = this.createTargetValueLearnerTable(targetValues, e);
        return this.getTargetValueWithTheMaximumVotes(targetValues, table);
    }

    /**
     * Builds a (target value x learner) vote table for {@code e}. Each learner places its voting
     * weight in the cell of the value it predicts and 0.0 everywhere else. A prediction that is
     * not among {@code targetValues} (including null) contributes nothing.
     */
    private Table<String, Learner, Double> createTargetValueLearnerTable(List<String> targetValues, Example e) {
        Table<String, Learner, Double> table = new Table<String, Learner, Double>(targetValues, this.learners);
        for (Learner learner : this.learners) {
            String predictedValue = learner.predict(e);
            double weight = this.learnerWeights.get(learner);
            for (String v : targetValues) {
                table.set(v, learner, v.equals(predictedValue) ? weight : 0.0);
            }
        }
        return table;
    }

    /**
     * Returns the value whose summed votes are strictly highest. Ties go to the value that
     * appears first in {@code targetValues}.
     */
    private String getTargetValueWithTheMaximumVotes(List<String> targetValues, Table<String, Learner, Double> table) {
        String best = targetValues.get(0);
        double bestScore = this.scoreOfValue(best, table, this.learners);
        for (int i = 1; i < targetValues.size(); ++i) {
            String value = targetValues.get(i);
            double score = this.scoreOfValue(value, table, this.learners);
            if (score > bestScore) {
                best = value;
                bestScore = score;
            }
        }
        return best;
    }

    /**
     * Resets every example weight to the uniform value {@code 1 / size}.
     *
     * @throws IllegalArgumentException if {@code size} is 0
     */
    private void initializeExampleWeights(int size) {
        if (size == 0) {
            throw new IllegalArgumentException("cannot initialize Ensemble learning with Empty Dataset");
        }
        double value = 1.0 / size;
        this.exampleWeights = new double[size];
        for (int i = 0; i < size; ++i) {
            this.exampleWeights[i] = value;
        }
    }

    /**
     * Resets every learner's voting weight to 1.0.
     *
     * @param size number of learners; only used to reject an empty ensemble
     * @throws IllegalArgumentException if {@code size} is 0
     */
    private void initializeHypothesisWeights(int size) {
        if (size == 0) {
            throw new IllegalArgumentException("cannot initialize Ensemble learning with Zero Learners");
        }
        this.learnerWeights = new Hashtable<Learner, Double>();
        for (Learner le : this.learners) {
            this.learnerWeights.put(le, 1.0);
        }
    }

    /**
     * Records, for each example of {@code ds} in index order, whether {@code l} predicts its
     * target value. The result is shared by the error computation and the re-weighting step,
     * so each example is predicted only once.
     */
    private boolean[] classifyExamples(DataSet ds, Learner l) {
        int n = ds.examples.size();
        boolean[] correct = new boolean[n];
        for (int i = 0; i < n; ++i) {
            Example e = ds.getExample(i);
            correct[i] = l.predict(e).equals(e.targetValue());
        }
        return correct;
    }

    /** Returns the total weight of the misclassified examples. */
    private double calculateError(boolean[] correct) {
        double error = 0.0;
        for (int i = 0; i < correct.length; ++i) {
            if (!correct[i]) {
                error += this.exampleWeights[i];
            }
        }
        return error;
    }

    /**
     * Scales the weights of correctly classified examples by {@code error / (1 - error)}, then
     * renormalises. {@code error} must lie strictly between 0 and 1.
     */
    private void adjustExampleWeights(boolean[] correct, double error) {
        double epsilon = error / (1.0 - error);
        for (int j = 0; j < correct.length; ++j) {
            if (correct[j]) {
                this.exampleWeights[j] *= epsilon;
            }
        }
        this.exampleWeights = Util.normalize(this.exampleWeights);
    }

    /** Sums the votes every learner gave to {@code targetValue}. */
    private double scoreOfValue(String targetValue, Table<String, Learner, Double> table, List<Learner> learners) {
        double score = 0.0;
        for (Learner l : learners) {
            score += table.get(targetValue, l).doubleValue();
        }
        return score;
    }
}

