/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.multilabel.meta;

import moa.classifiers.Classifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.classifiers.meta.OzaBagAdwin;
import moa.core.InstancesHeader;
import moa.core.MiscUtils;
import weka.core.Instance;

public class MLOzaBagAdwin
extends OzaBagAdwin {
    protected int m_L = -1;

    public void setModelContext(InstancesHeader raw_header) {
        this.modelContext = raw_header;
        this.m_L = raw_header.classIndex() + 1;
        this.resetLearningImpl();
        for (int i = 0; i < this.ensemble.length; ++i) {
            this.ensemble[i].setModelContext(raw_header);
            this.ensemble[i].resetLearning();
        }
    }

    public void trainOnInstanceImpl(Instance inst) {
        boolean Change = false;
        for (int i = 0; i < this.ensemble.length; ++i) {
            int k = MiscUtils.poisson(1.0, this.classifierRandom);
            if (k > 0) {
                Instance weightedInst = (Instance)inst.copy();
                weightedInst.setWeight(inst.weight() * (double)k);
                this.ensemble[i].trainOnInstance(weightedInst);
            }
            double[] prediction = this.ensemble[i].getVotesForInstance(inst);
            double[] actual = new double[prediction.length];
            for (int j = 0; j < prediction.length; j = (int)((short)(j + 1))) {
                actual[j] = inst.value(j);
            }
            int p_sum = 0;
            boolean r_sum = false;
            int set_union = 0;
            int set_inter = 0;
            double t = 0.01;
            for (int j = 0; j < prediction.length; ++j) {
                boolean p = prediction[j] >= t;
                int R = (int)actual[j];
                if (p) {
                    ++p_sum;
                    if (R == 1) {
                        ++set_inter;
                        ++set_union;
                        continue;
                    }
                    ++set_union;
                    continue;
                }
                if (R != 1) continue;
                ++set_union;
            }
            double accuracy = 0.0;
            if (set_union > 0) {
                accuracy = (double)set_inter / (double)set_union;
            }
            double ErrEstim = this.ADError[i].getEstimation();
            if (!this.ADError[i].setInput(1.0 - accuracy) || !(this.ADError[i].getEstimation() > ErrEstim)) continue;
            Change = true;
        }
        if (Change) {
            System.err.println("change!");
            double max = 0.0;
            int imax = -1;
            for (int i = 0; i < this.ensemble.length; ++i) {
                if (!(max < this.ADError[i].getEstimation())) continue;
                max = this.ADError[i].getEstimation();
                imax = i;
            }
            if (imax != -1) {
                this.ensemble[imax] = null;
                this.ensemble[imax] = (Classifier)this.getPreparedClassOption(this.baseLearnerOption);
                this.ensemble[imax].setModelContext(this.modelContext);
                this.ensemble[imax].trainOnInstance(inst);
                this.ADError[imax] = new ADWIN();
            }
        }
    }

    public double[] getVotesForInstance(Instance x) {
        int L = x.classIndex() + 1;
        if (this.m_L != L) {
            this.m_L = L;
        }
        double[] y = new double[this.m_L];
        for (int i = 0; i < this.ensemble.length; ++i) {
            double[] w = this.ensemble[i].getVotesForInstance(x);
            for (int j = 0; j < w.length; ++j) {
                int n = j;
                y[n] = y[n] + w[j];
            }
        }
        return y;
    }
}

