/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.translator.tree;

import com.sun.codemodel.JArray;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JType;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.tree.Node;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.ValueUtil;
import org.jpmml.translator.ArrayManager;
import org.jpmml.translator.TranslationContext;
import org.jpmml.translator.tree.Scorer;
import org.jpmml.translator.tree.ScorerUtil;

/**
 * Array manager and scorer that turns the {@link ScoreDistribution} elements of a tree
 * {@link Node} into a per-category probability vector.
 *
 * <p>Each distinct probability vector is registered with the underlying {@link ArrayManager};
 * generated code then refers to it by its integer index (see {@link #createIndexExpression}).
 *
 * @param <V> the numeric type produced by {@link #getValueFactory()}
 */
public abstract class NodeScoreDistributionManager<V extends Number>
extends ArrayManager<List<Number>>
implements Scorer<List<Number>> {
    /**
     * Index literal emitted when a node carries no score distributions.
     */
    public static final JExpression RESULT_MISSING = JExpr.lit(-1);

    private Object[] categories = null;

    /**
     * @param componentType the array component type handed to {@link ArrayManager}
     * @param name the array name handed to {@link ArrayManager}
     * @param categories the target categories, in output order; must not be {@code null}
     * @throws NullPointerException if {@code categories} is {@code null}
     */
    public NodeScoreDistributionManager(JType componentType, String name, Object[] categories) {
        super(componentType, name);
        this.setCategories(categories);
    }

    /**
     * Returns the factory used to create the {@link Value} instances for record counts.
     */
    public abstract ValueFactory<V> getValueFactory();

    /**
     * Builds the probability vector for a node.
     *
     * <p>The record count of each score distribution is wrapped in a {@link Value}, keyed by its
     * category, and normalised in place via {@link ValueUtil#normalizeSimpleMax}. The result
     * holds one entry per element of {@link #getCategories()}, in that order. Categories that the
     * node does not mention get {@code 0.0}.
     *
     * @param node the tree node
     * @return the probability vector, or {@code null} if the node has no score distributions
     */
    @Override
    public List<Number> prepare(Node node) {
        if (!node.hasScoreDistributions()) {
            return null;
        }

        ValueFactory<V> valueFactory = this.getValueFactory();

        ValueMap<Object, V> probabilityMap = new ValueMap<>();

        List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();
        for (ScoreDistribution scoreDistribution : scoreDistributions) {
            Number recordCount = scoreDistribution.requireRecordCount();
            Object category = scoreDistribution.requireValue();

            // A later entry for the same category replaces an earlier one (plain Map#put semantics)
            probabilityMap.put(category, valueFactory.newValue(recordCount));
        }

        ValueUtil.normalizeSimpleMax(probabilityMap.values());

        Object[] categories = this.getCategories();

        List<Number> result = new ArrayList<>(categories.length);
        for (Object category : categories) {
            Value<V> value = probabilityMap.get(category);

            // Categories absent from this node's distributions contribute a zero probability
            if (value == null) {
                value = valueFactory.newValue(0.0);
            }

            result.add(value.getValue());
        }

        return result;
    }

    /**
     * Emits an unconditional {@code return} of the index expression for {@code probabilities}.
     */
    @Override
    public void yield(List<Number> probabilities, TranslationContext context) {
        context._return(this.createIndexExpression(probabilities));
    }

    /**
     * Emits a conditional {@code return} of the index expression for {@code probabilities},
     * guarded by {@code expression}.
     */
    @Override
    public void yieldIf(JExpression expression, List<Number> probabilities, TranslationContext context) {
        context._returnIf(expression, this.createIndexExpression(probabilities));
    }

    /**
     * Creates an array-initialiser expression whose elements are the given probabilities,
     * each rendered through {@link ScorerUtil#format}.
     *
     * <p>The array element type is the element type of {@link #getComponentType()}.
     */
    @Override
    public JExpression createExpression(List<Number> probabilities) {
        JType componentType = this.getComponentType();

        JArray array = JExpr.newArray(componentType.elementType());
        for (Number probability : probabilities) {
            array = array.add(ScorerUtil.format(probability));
        }

        return array;
    }

    /**
     * Returns an integer literal that identifies {@code probabilities} within this manager.
     *
     * @param probabilities a vector produced by {@link #prepare(Node)}, possibly {@code null}
     * @return {@link #RESULT_MISSING} for {@code null}; otherwise the literal index obtained from
     *         {@link ArrayManager#getOrInsert}, which registers the vector if it is new
     */
    public JExpression createIndexExpression(List<Number> probabilities) {
        if (probabilities == null) {
            return RESULT_MISSING;
        }

        return JExpr.lit(this.getOrInsert(probabilities));
    }

    /**
     * Returns the registered probability vectors as a jagged array, one row per element of
     * {@link ArrayManager#getElements()}, in that collection's iteration order.
     */
    public Number[][] getValues() {
        // NOTE(review): assumes ArrayManager#getElements() is typed as a Collection of E
        // (here List<Number>) — confirm against ArrayManager
        Collection<List<Number>> elements = this.getElements();

        return elements.stream()
            .map(element -> element.toArray(new Number[element.size()]))
            .toArray(Number[][]::new);
    }

    /**
     * Returns the target categories, in output order.
     *
     * <p>The internal array is returned directly, not copied.
     */
    public Object[] getCategories() {
        return this.categories;
    }

    private void setCategories(Object[] categories) {
        this.categories = Objects.requireNonNull(categories);
    }
}

