/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous.hmc;

import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.TaxonEffectTraitDataModel;
import dr.evomodel.treedatalikelihood.preorder.TipGradientViaFullConditionalDelegate;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.util.StopWatch;
import dr.util.TaskPool;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.List;

/**
 * Gradient of the log density with respect to the taxon-effect parameter of one or more
 * {@link TaxonEffectTraitDataModel} partitions.
 *
 * <p>For each partition, a per-tip gradient vector is read from a full-conditional tip-gradient
 * {@link TreeTrait} registered on the partition's {@link TreeDataLikelihood}. For each tip, the
 * first {@code dimTrait} entries of that vector, multiplied by the sign reported by the model's
 * effect map, are subtracted from the block of the gradient that belongs to the tip's effect index.
 *
 * <p>The effect parameter and all dimensions are taken from the first partition.
 * NOTE(review): this presumes every partition shares the same effect parameter, tree size and
 * trait dimensions; that is not checked here.
 */
public class TaxonEffectGradient implements GradientWrtParameterProvider, Reportable {

    /** Number of stop watches allocated; only index 0 is used within this class. */
    private static final int NUM_STOP_WATCHES = 5;

    private final List<Partition> partitions;
    /** Lazily built compound of all partition likelihoods; see {@link #getLikelihood()}. */
    private Likelihood compoundLikelihood;
    private final Parameter effects;
    private final int dimTrait;
    private final int nTraits;
    private final int nTaxa;
    protected StopWatch[] stopWatches;
    protected static final boolean TIMING = true;

    /**
     * Creates the gradient over the given partitions.
     *
     * @param partitions the partitions contributing to the gradient; must contain at least one
     *     element, since the effect parameter and all dimensions are read from the first one
     * @param taskPool currently unused by this class
     * @param threadUseProvider currently unused by this class
     * @throws IllegalArgumentException if {@code partitions} is {@code null} or empty
     */
    public TaxonEffectGradient(List<Partition> partitions, TaskPool taskPool, ThreadUseProvider threadUseProvider) {
        if (partitions == null || partitions.isEmpty()) {
            throw new IllegalArgumentException("TaxonEffectGradient requires at least one partition");
        }
        this.partitions = partitions;

        Partition first = partitions.get(0);
        this.effects = first.model.getEffects();
        this.dimTrait = first.delegate.getTraitDim();
        this.nTraits = first.delegate.getTraitCount();
        this.nTaxa = first.tree.getExternalNodeCount();

        this.stopWatches = new StopWatch[NUM_STOP_WATCHES];
        for (int i = 0; i < stopWatches.length; ++i) {
            stopWatches[i] = new StopWatch();
        }
    }

    /**
     * Returns a {@link CompoundLikelihood} over every partition's likelihood. It is built on the
     * first call and cached afterwards.
     */
    @Override
    public Likelihood getLikelihood() {
        if (compoundLikelihood == null) {
            List<Likelihood> likelihoods = new ArrayList<>(partitions.size());
            for (Partition partition : partitions) {
                likelihoods.add(partition.likelihood);
            }
            compoundLikelihood = new CompoundLikelihood(likelihoods);
        }
        return compoundLikelihood;
    }

    /** Returns the taxon-effect parameter taken from the first partition's model. */
    @Override
    public Parameter getParameter() {
        return effects;
    }

    /** Returns {@code nTaxa * dimTrait * nTraits}, the length of the gradient vector. */
    @Override
    public int getDimension() {
        return nTaxa * dimTrait * nTraits;
    }

    /**
     * Assembles the gradient over all partitions.
     *
     * <p>Each effect index owns a contiguous block of {@code dimTrait * nTraits} entries. Only the
     * first {@code dimTrait} entries of each block are written.
     * NOTE(review): when {@code nTraits > 1}, the remaining entries stay zero. Confirm this is
     * intended.
     *
     * @return a freshly allocated array of length {@link #getDimension()}
     */
    @Override
    public double[] getGradientLogDensity() {
        if (TIMING) {
            stopWatches[0].start();
        }
        try {
            double[] gradient = new double[getDimension()];
            int blockSize = dimTrait * nTraits;

            for (Partition partition : partitions) {
                Tree tree = partition.tree;
                TreeTrait tipGradientTrait = partition.treeTraitProvider;
                int sign = partition.model.getMap().getSign();

                for (int taxon = 0; taxon < nTaxa; ++taxon) {
                    double[] tipGradient = (double[]) tipGradientTrait.getTrait(tree, tree.getExternalNode(taxon));
                    int offset = partition.model.getMap().getEffectIndex(taxon) * blockSize;
                    for (int j = 0; j < dimTrait; ++j) {
                        gradient[offset + j] -= (double) sign * tipGradient[j];
                    }
                }
            }
            return gradient;
        } finally {
            // Stop the timer even if the tip-gradient evaluation throws, so it is not left running.
            if (TIMING) {
                stopWatches[0].stop();
            }
        }
    }

    /**
     * Builds a report consisting of three parts:
     * <ol>
     *   <li>the timings accumulated so far;</li>
     *   <li>the output of {@link GradientWrtParameterProvider#getReportAndCheckForError}, with
     *       unbounded limits and tolerance {@code 0.005};</li>
     *   <li>the timings accumulated since the first dump.</li>
     * </ol>
     * Because {@link #timingInfo()} resets the stop watches, the second dump presumably reflects
     * only gradient evaluations made by the check.
     */
    @Override
    public String getReport() {
        String before = timingInfo();
        String check = GradientWrtParameterProvider.getReportAndCheckForError(this, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.005);
        String after = timingInfo();
        return before + check + after;
    }

    /** Formats every stop watch, one per line, and resets each one after it is printed. */
    private String timingInfo() {
        StringBuilder sb = new StringBuilder("\nTiming in TaxonEffectGradient\n");
        for (StopWatch stopWatch : stopWatches) {
            sb.append("\t").append(stopWatch.toString()).append("\n");
            stopWatch.reset();
        }
        return sb.toString();
    }

    /**
     * One data partition that contributes to the taxon-effect gradient.
     *
     * <p>On construction, the full-conditional tip-gradient trait for the delegate's model is
     * looked up on the likelihood. If the trait is absent, it is registered through the delegate
     * and then looked up again.
     */
    public static class Partition {
        public final Tree tree;
        public final TreeDataLikelihood likelihood;
        public final ContinuousDataLikelihoodDelegate delegate;
        public final TaxonEffectTraitDataModel model;
        public final TreeTrait treeTraitProvider;

        public Partition(TreeDataLikelihood likelihood, ContinuousDataLikelihoodDelegate delegate, TaxonEffectTraitDataModel model) {
            this.tree = likelihood.getTree();
            this.likelihood = likelihood;
            this.delegate = delegate;
            this.model = model;

            String modelName = delegate.getModelName();
            int traitDim = delegate.getTraitDim();
            String traitName = TipGradientViaFullConditionalDelegate.getName(modelName);
            if (likelihood.getTreeTrait(traitName) == null) {
                // Register a gradient trait covering dimensions [0, traitDim) of the model.
                delegate.addFullConditionalGradientTrait(modelName, 0, traitDim);
            }
            this.treeTraitProvider = likelihood.getTreeTrait(traitName);
        }
    }

    /**
     * Serial or parallel evaluation mode.
     * NOTE(review): the {@link TaxonEffectGradient} constructor currently ignores this value.
     */
    public enum ThreadUseProvider {
        PARALLEL {
            @Override
            boolean usePool() {
                return true;
            }
        },
        SERIAL {
            @Override
            boolean usePool() {
                return false;
            }
        };

        abstract boolean usePool();
    }
}

