/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.tree;

import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.EmbeddedModel;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.PMMLAttributes;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.evaluator.EntityUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.PredicateUtil;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UndefinedResultException;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.UnsupportedElementException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.tree.HasNodeRegistry;
import org.jpmml.evaluator.tree.NodeScore;
import org.jpmml.evaluator.tree.NodeScoreDistribution;
import org.jpmml.evaluator.tree.NodeVote;
import org.jpmml.evaluator.tree.PathFinder;
import org.jpmml.evaluator.tree.TreeModelEvaluator;
import org.jpmml.model.InvalidAttributeException;
import org.jpmml.model.MisplacedAttributeException;
import org.jpmml.model.MissingAttributeException;
import org.jpmml.model.visitors.AbstractVisitor;

/**
 * Evaluator for {@link TreeModel} elements that walks the tree node by node, honouring the
 * model's {@code missingValueStrategy}, {@code noTrueChildStrategy} and
 * {@code missingValuePenalty} attributes.
 *
 * <p>Every {@link Node} of the model is collected into an immutable id-to-node registry (built by
 * {@code EntityUtil.buildBiMap}). The registry backs the id-based path lookups and the decision
 * paths exposed by the returned results. Traversal state for each evaluation lives in a fresh
 * {@link Trail} instance that is created per call.
 */
public class ComplexTreeModelEvaluator
extends TreeModelEvaluator
implements HasNodeRegistry {

    /** Immutable id-to-node registry; left empty by the private no-argument constructor. */
    private BiMap<String, Node> entityRegistry = ImmutableBiMap.of();

    /**
     * Creates an evaluator with an empty node registry.
     * NOTE(review): presumably retained for serialization/reflection frameworks — confirm.
     */
    private ComplexTreeModelEvaluator() {
    }

    /**
     * Creates an evaluator for the {@link TreeModel} located in {@code pmml} by
     * {@code PMMLUtil.findModel}.
     *
     * @param pmml the PMML document containing the tree model
     */
    public ComplexTreeModelEvaluator(PMML pmml) {
        this(pmml, PMMLUtil.findModel(pmml, TreeModel.class));
    }

    /**
     * Creates an evaluator for the given tree model and registers all of its nodes.
     *
     * @param pmml the enclosing PMML document
     * @param treeModel the tree model to evaluate
     */
    public ComplexTreeModelEvaluator(PMML pmml, TreeModel treeModel) {
        super(pmml, treeModel);
        List<Node> nodes = collectNodes(treeModel);
        this.entityRegistry = ImmutableBiMap.copyOf(EntityUtil.buildBiMap(nodes));
    }

    @Override
    public BiMap<String, Node> getEntityRegistry() {
        return this.entityRegistry;
    }

    /**
     * Returns the path from the root node to the node registered under {@code id}.
     *
     * @throws IllegalArgumentException if no node is registered under {@code id}
     */
    @Override
    public List<Node> getPath(String id) {
        return getPath(resolveNode(id));
    }

    /**
     * Returns the path between the nodes registered under {@code parentId} and {@code childId}.
     *
     * @throws IllegalArgumentException if either id is not registered
     */
    @Override
    public List<Node> getPathBetween(String parentId, String childId) {
        return getPathBetween(resolveNode(parentId), resolveNode(childId));
    }

    /**
     * Scores a regression tree: the winning node's {@code score} becomes the predicted value,
     * or the target's default result is returned when no node is selected.
     */
    @Override
    protected <V extends Number> Map<String, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext context) {
        TargetField targetField = getTargetField();

        Node node = evaluateTree(new Trail(), context);
        if (node == null) {
            return TargetUtil.evaluateRegressionDefault(valueFactory, targetField);
        }

        NodeScore<V> result = createNodeScore(valueFactory, targetField, node);
        return TargetUtil.evaluateRegression(targetField, result);
    }

    /**
     * Scores a classification tree.
     *
     * <p>If the winning node has no {@code ScoreDistribution} children, its {@code score} is
     * returned as a vote. Otherwise a probability distribution is built from the node's score
     * distributions. Confidences in that distribution are scaled by the missing value penalty
     * accumulated during traversal.
     */
    @Override
    protected <V extends Number> Map<String, ?> evaluateClassification(ValueFactory<V> valueFactory, EvaluationContext context) {
        TreeModel treeModel = (TreeModel)getModel();
        TargetField targetField = getTargetField();

        Trail trail = new Trail();
        Node node = evaluateTree(trail, context);
        if (node == null) {
            return TargetUtil.evaluateClassificationDefault(valueFactory, targetField);
        }

        if (!node.hasScoreDistributions()) {
            NodeVote result = createNodeVote(node);
            return TargetUtil.evaluateVote(targetField, result);
        }

        double missingValuePenalty = computeMissingValuePenalty(treeModel, trail.getMissingLevels());

        NodeScoreDistribution<V> result = createNodeScoreDistribution(valueFactory, node, missingValuePenalty);
        return TargetUtil.evaluateClassification(targetField, result);
    }

    /**
     * Returns the confidence multiplier after {@code missingLevels} missing-value interventions.
     *
     * <p>The result is 1.0 when there were none. Otherwise it is the model's
     * {@code missingValuePenalty} raised to the power {@code missingLevels}. The attribute is
     * only read when at least one intervention occurred.
     */
    private static double computeMissingValuePenalty(TreeModel treeModel, int missingLevels) {
        if (missingLevels <= 0) {
            return 1.0;
        }

        double penalty = treeModel.getMissingValuePenalty().doubleValue();

        // Math.pow(x, 1.0) returns x exactly, so the single-level case needs no special branch
        return Math.pow(penalty, missingLevels);
    }

    /**
     * Walks the tree from its root and returns the selected node.
     *
     * @return the selected node, or {@code null} when the root predicate is not {@code true}
     *         or the active strategy selects a null prediction
     * @throws MissingAttributeException if the selected node has no {@code score} attribute
     */
    private Node evaluateTree(Trail trail, EvaluationContext context) {
        TreeModel treeModel = (TreeModel)getModel();

        Node root = treeModel.requireNode();

        Boolean status = evaluateNode(trail, root, context);
        if (!Boolean.TRUE.equals(status)) {
            return null;
        }

        Node node = handleTrue(trail, root, context).getResult();
        if (node != null && !node.hasScore()) {
            throw new MissingAttributeException(node, PMMLAttributes.COMPLEXNODE_SCORE);
        }

        return node;
    }

    /**
     * Evaluates the predicate of a single node.
     *
     * <p>For compound predicates, a missing level is recorded on the trail whenever
     * {@code PredicateUtil} reports the result as an "alternative".
     * NOTE(review): "alternative" presumably means a surrogate fallback — confirm.
     *
     * @return {@code TRUE}/{@code FALSE}, or {@code null} when the predicate is unknown
     * @throws UnsupportedElementException if the node carries an {@code EmbeddedModel}
     */
    private Boolean evaluateNode(Trail trail, Node node, EvaluationContext context) {
        EmbeddedModel embeddedModel = node.getEmbeddedModel();
        if (embeddedModel != null) {
            throw new UnsupportedElementException(embeddedModel);
        }

        Predicate predicate = node.requirePredicate();
        if (predicate instanceof CompoundPredicate) {
            CompoundPredicate compoundPredicate = (CompoundPredicate)predicate;

            PredicateUtil.CompoundPredicateResult result = PredicateUtil.evaluateCompoundPredicateInternal(compoundPredicate, context);
            if (result.isAlternative()) {
                // NOTE(review): this is counted for every evaluated sibling, including ones that
                // do not end up on the selected path — confirm that is intended.
                trail.addMissingLevel();
            }

            return result.getResult();
        }

        return PredicateUtil.evaluate(predicate, context);
    }

    /**
     * Continues traversal below a node whose predicate evaluated to {@code true}.
     *
     * <p>A leaf node is selected directly. For a non-leaf node, the children are tried in
     * document order. The first child evaluating to {@code true} is descended into. Unknown
     * children are handed to the missing value strategy. If no child is selected, the
     * no-true-child strategy applies.
     */
    private Trail handleTrue(Trail trail, Node node, EvaluationContext context) {
        if (!node.hasNodes()) {
            return trail.selectNode(node);
        }

        trail.push(node);

        List<Node> children = node.getNodes();
        for (int i = 0, max = children.size(); i < max; i++) {
            Node child = children.get(i);

            Boolean status = evaluateNode(trail, child, context);
            if (status == null) {
                Trail destination = handleMissingValue(trail, node, context);

                // A null destination (strategy "none") means: keep trying the remaining siblings
                if (destination != null) {
                    return destination;
                }
            } else if (status) {
                return handleTrue(trail, child, context);
            }
        }

        return handleNoTrueChild(trail);
    }

    /**
     * Follows the default child of {@code node} without evaluating its predicate, recording one
     * missing level.
     */
    private Trail handleDefaultChild(Trail trail, Node node, EvaluationContext context) {
        Node defaultChild = findDefaultChild(node);

        trail.addMissingLevel();

        return handleTrue(trail, defaultChild, context);
    }

    /**
     * Applies the model's {@code noTrueChildStrategy} after every child of the current node
     * failed to match.
     *
     * @throws UnsupportedAttributeException for an unhandled strategy value
     */
    private Trail handleNoTrueChild(Trail trail) {
        TreeModel treeModel = (TreeModel)getModel();

        TreeModel.NoTrueChildStrategy noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
        switch (noTrueChildStrategy) {
            case RETURN_NULL_PREDICTION:
                return trail.selectNull();
            case RETURN_LAST_PREDICTION:
                // Fall back to the enclosing node only if it carries a score of its own
                if (trail.getLastPrediction().hasScore()) {
                    return trail.selectLastPrediction();
                }
                return trail.selectNull();
            default:
                throw new UnsupportedAttributeException(treeModel, noTrueChildStrategy);
        }
    }

    /**
     * Applies the model's {@code missingValueStrategy} when a child predicate of {@code parent}
     * evaluated to unknown.
     *
     * @return the trail with a selected result, or {@code null} for strategy {@code NONE}
     *         (meaning the unknown child is treated as not matching)
     * @throws UnsupportedAttributeException for an unhandled strategy value
     */
    private Trail handleMissingValue(Trail trail, Node parent, EvaluationContext context) {
        TreeModel treeModel = (TreeModel)getModel();

        TreeModel.MissingValueStrategy missingValueStrategy = treeModel.getMissingValueStrategy();
        switch (missingValueStrategy) {
            case NULL_PREDICTION:
                return trail.selectNull();
            case LAST_PREDICTION:
                return trail.selectLastPrediction();
            case DEFAULT_CHILD:
                return handleDefaultChild(trail, parent, context);
            case NONE:
                return null;
            default:
                throw new UnsupportedAttributeException(treeModel, missingValueStrategy);
        }
    }

    /**
     * Wraps the node's {@code score} (numeric or string) as a regression result.
     *
     * <p>The score is post-processed by {@code TargetUtil.evaluateRegressionInternal}.
     */
    private <V extends Number> NodeScore<V> createNodeScore(ValueFactory<V> valueFactory, TargetField targetField, Node node) {
        Object score = node.getScore();

        Value<V> value = (score instanceof Number)
            ? valueFactory.newValue((Number)score)
            : valueFactory.newValue((String)score);

        value = TargetUtil.evaluateRegressionInternal(targetField, value);

        return new NodeScore<V>(value, node) {

            @Override
            public BiMap<String, Node> getEntityRegistry() {
                return ComplexTreeModelEvaluator.this.getEntityRegistry();
            }

            @Override
            public List<Node> getDecisionPath() {
                return ComplexTreeModelEvaluator.this.getPath(this.getNode());
            }
        };
    }

    /** Wraps the node as a vote result whose decision path is resolved through this evaluator. */
    private NodeVote createNodeVote(Node node) {
        return new NodeVote(node) {

            @Override
            public BiMap<String, Node> getEntityRegistry() {
                return ComplexTreeModelEvaluator.this.getEntityRegistry();
            }

            @Override
            public List<Node> getDecisionPath() {
                return ComplexTreeModelEvaluator.this.getPath(this.getNode());
            }
        };
    }

    /**
     * Builds a probability distribution from the node's {@code ScoreDistribution} children.
     *
     * <p>The first element decides the mode for the whole list. Either every element carries a
     * {@code probability} in [0, 1], or none does and {@code recordCount} is used instead. The
     * collected values are divided by their sum unless that sum is already one. Confidences,
     * where present, are multiplied by {@code missingValuePenalty}.
     *
     * @throws MissingAttributeException if a probability is missing in probability mode
     * @throws InvalidAttributeException if a probability lies outside [0, 1]
     * @throws MisplacedAttributeException if a probability appears in record-count mode
     * @throws UndefinedResultException if the values sum to zero
     */
    private <V extends Number> NodeScoreDistribution<V> createNodeScoreDistribution(ValueFactory<V> valueFactory, Node node, double missingValuePenalty) {
        List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();

        NodeScoreDistribution<V> result = new NodeScoreDistribution<V>(new ValueMap<>(2 * scoreDistributions.size()), node) {

            @Override
            public BiMap<String, Node> getEntityRegistry() {
                return ComplexTreeModelEvaluator.this.getEntityRegistry();
            }

            @Override
            public List<Node> getDecisionPath() {
                return ComplexTreeModelEvaluator.this.getPath(this.getNode());
            }
        };

        Value<V> sum = valueFactory.newValue();

        boolean hasProbabilities = false;

        for (int i = 0, max = scoreDistributions.size(); i < max; i++) {
            ScoreDistribution scoreDistribution = scoreDistributions.get(i);

            Number probability = scoreDistribution.getProbability();
            if (i == 0) {
                hasProbabilities = (probability != null);
            }

            Value<V> value;
            if (hasProbabilities) {
                if (probability == null) {
                    throw new MissingAttributeException(scoreDistribution, org.dmg.pmml.PMMLAttributes.SCOREDISTRIBUTION_PROBABILITY);
                }

                if (probability.doubleValue() < 0.0 || probability.doubleValue() > 1.0) {
                    throw new InvalidAttributeException(scoreDistribution, org.dmg.pmml.PMMLAttributes.SCOREDISTRIBUTION_PROBABILITY, probability);
                }

                sum.add(probability);
                value = valueFactory.newValue(probability);
            } else {
                if (probability != null) {
                    throw new MisplacedAttributeException(scoreDistribution, org.dmg.pmml.PMMLAttributes.SCOREDISTRIBUTION_PROBABILITY, probability);
                }

                Number recordCount = scoreDistribution.requireRecordCount();

                sum.add(recordCount);
                value = valueFactory.newValue(recordCount);
            }

            Object targetCategory = scoreDistribution.requireValue();
            result.put(targetCategory, value);

            Number confidence = scoreDistribution.getConfidence();
            if (confidence != null) {
                Value<V> penalizedConfidence = valueFactory.newValue(confidence).multiply(missingValuePenalty);
                result.putConfidence(targetCategory, penalizedConfidence);
            }
        }

        if (!sum.isOne()) {
            if (sum.isZero()) {
                throw new UndefinedResultException();
            }

            for (Value<V> value : result.getValues()) {
                value.divide(sum);
            }
        }

        return result;
    }

    /** Returns the path from the model's root node to {@code node}. */
    private List<Node> getPath(Node node) {
        TreeModel treeModel = (TreeModel)getModel();

        Node root = treeModel.requireNode();

        return getPathBetween(root, node);
    }

    /**
     * Returns the path found by a {@code PathFinder} that visits {@code parent} and matches
     * {@code child} via {@link Objects#equals}.
     */
    private List<Node> getPathBetween(Node parent, final Node child) {
        PathFinder pathFinder = new PathFinder() {

            @Override
            public boolean test(Node node) {
                return Objects.equals(child, node);
            }
        };
        pathFinder.applyTo(parent);

        return pathFinder.getPath();
    }

    /**
     * Looks up a node in the entity registry.
     *
     * @throws IllegalArgumentException if no node is registered under {@code id}
     */
    private Node resolveNode(String id) {
        Node node = getEntityRegistry().get(id);
        if (node == null) {
            throw new IllegalArgumentException(id);
        }

        return node;
    }

    /** Collects every {@link Node} reachable from {@code treeModel}, in visitor order. */
    private static List<Node> collectNodes(TreeModel treeModel) {
        final List<Node> result = new ArrayList<>();

        AbstractVisitor visitor = new AbstractVisitor() {

            @Override
            public VisitorAction visit(Node node) {
                result.add(node);

                return super.visit(node);
            }
        };
        visitor.applyTo(treeModel);

        return result;
    }

    /**
     * Mutable traversal state for a single tree evaluation: the last entered non-leaf node, the
     * selected result, and the number of missing-value interventions.
     */
    private static class Trail {

        /** Most recently entered non-leaf node; null until {@link #push} is first called. */
        private Node lastPrediction = null;

        /** Node selected as the evaluation result; null means "no prediction". */
        private Node result = null;

        /** Number of levels at which missing-value handling was recorded. */
        private int missingLevels = 0;

        /** Records {@code node} as the most recently entered non-leaf node. */
        public void push(Node node) {
            this.lastPrediction = node;
        }

        public Trail selectNull() {
            this.result = null;
            return this;
        }

        public Trail selectNode(Node node) {
            this.result = node;
            return this;
        }

        /** Selects the last entered non-leaf node as the result. */
        public Trail selectLastPrediction() {
            this.result = getLastPrediction();
            return this;
        }

        public Node getResult() {
            return this.result;
        }

        /**
         * Returns the most recently entered non-leaf node.
         *
         * @throws EvaluationException if no node has been entered yet
         */
        public Node getLastPrediction() {
            if (this.lastPrediction == null) {
                throw new EvaluationException("Empty trail");
            }

            return this.lastPrediction;
        }

        public void addMissingLevel() {
            this.missingLevels++;
        }

        public int getMissingLevels() {
            return this.missingLevels;
        }
    }
}

