/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.translator;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.sun.codemodel.JAssignmentTarget;
import com.sun.codemodel.JBlock;
import com.sun.codemodel.JClass;
import com.sun.codemodel.JDefinedClass;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JFieldVar;
import com.sun.codemodel.JInvocation;
import com.sun.codemodel.JMethod;
import com.sun.codemodel.JMod;
import com.sun.codemodel.JType;
import com.sun.codemodel.JTypeVar;
import com.sun.codemodel.JVar;
import java.lang.reflect.Field;
import java.lang.reflect.TypeVariable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLAttributes;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Target;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.VisitorAction;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.IndexableUtil;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelManager;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueFactoryFactory;
import org.jpmml.evaluator.java.JavaModel;
import org.jpmml.model.visitors.ActiveFieldFinder;
import org.jpmml.model.visitors.FieldResolver;
import org.jpmml.translator.ArrayInfo;
import org.jpmml.translator.ArrayInfoMap;
import org.jpmml.translator.ClassificationBuilder;
import org.jpmml.translator.FieldInfo;
import org.jpmml.translator.FieldInfoMap;
import org.jpmml.translator.FunctionInvocation;
import org.jpmml.translator.FunctionInvocationContext;
import org.jpmml.translator.FunctionInvocationUtil;
import org.jpmml.translator.IdentifierUtil;
import org.jpmml.translator.JWrappedExpression;
import org.jpmml.translator.MethodScope;
import org.jpmml.translator.PMMLObjectUtil;
import org.jpmml.translator.TranslatedModel;
import org.jpmml.translator.TranslationContext;
import org.jpmml.translator.ValueBuilder;

/**
 * Base class for translators that turn a PMML {@link Model} into generated Java source code
 * (a member class extending {@link JavaModel}) using the CodeModel API.
 *
 * <p>Subclasses override {@link #translateRegressor(TranslationContext)} and/or
 * {@link #translateClassifier(TranslationContext)} for the mining functions they support;
 * the defaults throw {@link UnsupportedOperationException}.
 *
 * @param <M> the PMML model type handled by this translator
 */
public abstract class ModelTranslator<M extends Model>
extends ModelManager<M> {

    /**
     * Matches input field names ending in an underscore followed by digits, e.g. {@code x_12}.
     * Group 1 is the array name (everything before the LAST such suffix, since {@code .+} is
     * greedy) and group 2 is the element index.
     */
    private static final Pattern ARRAY_ELEMENT_PATTERN = Pattern.compile("^(.+)\\_(\\d+)");

    /**
     * @param pmml the enclosing PMML document
     * @param model the model to translate
     * @throws UnsupportedAttributeException if the model's math context is neither
     *     {@code FLOAT} nor {@code DOUBLE}
     */
    public ModelTranslator(PMML pmml, M model) {
        super(pmml, model);
        requireSupportedMathContext(model);
    }

    /**
     * Generates a {@code public static final} member class extending {@link JavaModel} for this
     * translator's model, registers the result with {@code context}, and returns an expression
     * that instantiates the generated class.
     *
     * <p>The context's active field name set is cleared before generation. The names present
     * after generation are copied into the registered {@link TranslatedModel}.
     *
     * @param context the translation context
     * @return an expression creating a new instance of the generated model class
     */
    public JExpression translate(TranslationContext context) {
        Model model = getModel();
        JDefinedClass javaModelClazz = PMMLObjectUtil.createMemberClass(JMod.PUBLIC | JMod.STATIC | JMod.FINAL, IdentifierUtil.create(JavaModel.class.getSimpleName(), model), context);
        javaModelClazz._extends(JavaModel.class);

        Set<String> activeFieldNames = context.getActiveFieldNames();
        activeFieldNames.clear();

        // Push before entering the try block: if the push itself fails, there is nothing to pop
        context.pushOwner(javaModelClazz);
        try {
            createEvaluateMethod(context);
        } finally {
            context.popOwner();
        }

        JWrappedExpression expression = new JWrappedExpression(context._new(javaModelClazz, new Object[0]));
        TranslatedModel translatedModel = new TranslatedModel(model)
            .setExpression(expression)
            .setActiveFields(new LinkedHashSet<String>(activeFieldNames));
        context.addTranslation(model, translatedModel);
        return expression;
    }

    /**
     * Dispatches on the model's mining function and generates the matching
     * {@code evaluateRegression} or {@code evaluateClassification} method.
     *
     * @throws UnsupportedAttributeException for mining functions other than regression and classification
     */
    public void createEvaluateMethod(TranslationContext context) {
        Model model = getModel();
        MiningFunction miningFunction = model.requireMiningFunction();
        switch (miningFunction) {
            case REGRESSION: {
                JMethod regressorMethod = translateRegressor(context);
                createEvaluateRegressionMethod(regressorMethod, context);
                break;
            }
            case CLASSIFICATION: {
                JMethod classifierMethod = translateClassifier(context);
                createEvaluateClassificationMethod(classifierMethod, context);
                break;
            }
            default: {
                throw new UnsupportedAttributeException(model, miningFunction);
            }
        }
    }

    /**
     * Generates the method that computes the regression result.
     * Subclasses supporting regression must override this.
     *
     * @throws UnsupportedOperationException always, in this base implementation
     */
    public JMethod translateRegressor(TranslationContext context) {
        throw new UnsupportedOperationException();
    }

    /**
     * Generates the method that computes the classification result.
     * Subclasses supporting classification must override this.
     *
     * @throws UnsupportedOperationException always, in this base implementation
     */
    public JMethod translateClassifier(TranslationContext context) {
        throw new UnsupportedOperationException();
    }

    /**
     * Generates {@code evaluateRegression}, which calls {@code evaluateMethod} and returns a
     * singleton map from the target field name to the (optionally rescaled) numeric value.
     *
     * <p>Side effect: if the target field carries a {@link Target}, its rescaling is compiled into
     * the generated code and the model's targets are cleared via {@code model.setTargets(null)}.
     *
     * @param evaluateMethod the generated regressor method to call
     * @param context the translation context
     * @return the generated {@code evaluateRegression} method
     */
    public JMethod createEvaluateRegressionMethod(JMethod evaluateMethod, TranslationContext context) {
        Model model = getModel();
        TargetField targetField = getTargetField();
        JMethod evaluateRegressionMethod = createEvaluatorMethod("evaluateRegression", context);

        context.pushScope(new MethodScope(evaluateRegressionMethod));
        try {
            JInvocation methodInvocation = createEvaluatorMethodInvocation(evaluateMethod, context);

            // Wrap results that are not already a Value so the ValueBuilder can operate on them
            JClass valueClazz = context.ref(Value.class);
            if (!evaluateMethod.type().erasure().equals(valueClazz)) {
                methodInvocation = context.getValueFactoryVariable().newValue((JExpression)methodInvocation);
            }

            ValueBuilder valueBuilder = new ValueBuilder(context).declare("value", methodInvocation);

            Target target = targetField.getTarget();
            if (target != null) {
                translateRegressorTarget(target, valueBuilder);
                // The target transformation now lives in the generated code; drop it from the model
                // (presumably so it is not applied a second time — confirm against the evaluator)
                model.setTargets(null);
            }

            JVar valueVar = valueBuilder.getVariable();
            context._return(context.staticInvoke(Collections.class, "singletonMap", targetField.getName(), valueVar.invoke("getValue")));
        } finally {
            context.popScope();
        }
        return evaluateRegressionMethod;
    }

    /**
     * Generates {@code evaluateClassification}, which calls {@code evaluateMethod}, computes the
     * classification result for the target field's data type, and returns it in a singleton map
     * keyed by the target field name.
     *
     * @param evaluateMethod the generated classifier method to call
     * @param context the translation context
     * @return the generated {@code evaluateClassification} method
     */
    public JMethod createEvaluateClassificationMethod(JMethod evaluateMethod, TranslationContext context) {
        TargetField targetField = getTargetField();
        JMethod evaluateClassificationMethod = createEvaluatorMethod("evaluateClassification", context);

        context.pushScope(new MethodScope(evaluateClassificationMethod));
        try {
            ClassificationBuilder classificationBuilder = new ClassificationBuilder(context)
                .declare("classification", createEvaluatorMethodInvocation(evaluateMethod, context))
                .computeResult(targetField.getDataType());
            context._return(context.staticInvoke(Collections.class, "singletonMap", context.constantFieldName(targetField.getName()), classificationBuilder));
        } finally {
            context.popScope();
        }
        return evaluateClassificationMethod;
    }

    /**
     * Builds {@link FieldInfo}s for every field that is active within {@code bodyObjects}.
     *
     * <p>Fields visible at each body object are collected with a {@link FieldResolver}. A body
     * object is not descended into further. Derived fields are then enhanced with function
     * invocation / reference information (see {@code enhanceFieldInfo}).
     *
     * @param bodyObjects the PMML objects whose active fields are of interest
     * @return the field info map
     * @throws IllegalArgumentException if two different fields with the same name are visible
     */
    public FieldInfoMap getFieldInfos(final Set<? extends PMMLObject> bodyObjects) {
        PMML pmml = getPMML();
        Model model = getModel();
        MiningSchema miningSchema = model.requireMiningSchema();

        final Map<String, org.dmg.pmml.Field<?>> bodyFields = new HashMap<>();
        FieldResolver fieldResolver = new FieldResolver() {

            @Override
            public VisitorAction visit(PMMLObject object) {
                if (bodyObjects.contains(object)) {
                    for (org.dmg.pmml.Field<?> field : this.getFields()) {
                        String name = field.requireName();
                        org.dmg.pmml.Field<?> previousField = bodyFields.put(name, field);
                        // The same field seen again is fine; a different field under the same name is not
                        if (previousField != null && previousField != field) {
                            throw new IllegalArgumentException(name);
                        }
                    }
                    return VisitorAction.SKIP;
                }
                return super.visit(object);
            }
        };
        fieldResolver.applyTo(pmml);

        FieldInfoMap result = new FieldInfoMap();
        for (String name : ActiveFieldFinder.getFieldNames(bodyObjects.toArray(new PMMLObject[0]))) {
            // NOTE(review): a name not resolved above yields create(null) — confirm this cannot happen
            result.create(bodyFields.get(name));
        }

        FunctionInvocationContext invocationContext = new FunctionInvocationContext() {

            @Override
            public DefineFunction getDefineFunction(String name) {
                // Qualified to reach the translator's lookup, not this anonymous override
                return ModelTranslator.this.getDefineFunction(name);
            }
        };

        // Iterate over a snapshot: enhanceFieldInfo() may add referenced fields to 'result'
        List<FieldInfo> fieldInfos = new ArrayList<>(result.values());
        for (FieldInfo fieldInfo : fieldInfos) {
            enhanceFieldInfo(fieldInfo, miningSchema, bodyFields, result, invocationContext);
        }
        return result;
    }

    /**
     * Groups input fields named like {@code <name>_<index>} into {@link ArrayInfo}s keyed by
     * {@code <name>}. Each element is the corresponding {@link DataField}. Input fields whose names
     * do not match the pattern are ignored.
     */
    public ArrayInfoMap getArrayInfos() {
        ArrayInfoMap result = new ArrayInfoMap();
        for (InputField inputField : getInputFields()) {
            String name = inputField.getFieldName();
            Matcher matcher = ARRAY_ELEMENT_PATTERN.matcher(name);
            if (!matcher.matches()) {
                continue;
            }
            String arrayName = matcher.group(1);
            Integer arrayIndex = Integer.parseInt(matcher.group(2));
            DataField dataField = getDataField(name);

            ArrayInfo arrayInfo = (ArrayInfo)result.get(arrayName);
            if (arrayInfo == null) {
                arrayInfo = result.create(arrayName);
            }
            arrayInfo.setElement(arrayIndex, dataField);
        }
        return result;
    }

    /**
     * @return the target field's categories as a new array
     */
    public Object[] getTargetCategories() {
        TargetField targetField = getTargetField();
        List<?> categories = targetField.getCategories();
        return categories.toArray(new Object[categories.size()]);
    }

    /**
     * Adds one {@link DataField} per {@link ArrayInfo} to the data dictionary. It then rebuilds
     * {@link ModelManager}'s non-public {@code dataFields} map (data fields keyed by name) via
     * reflection, so the new fields can be looked up. This depends on {@code ModelManager}
     * internals.
     *
     * @throws RuntimeException wrapping any reflection failure
     */
    protected void declareArrayFields(Collection<ArrayInfo> arrayInfos) {
        PMML pmml = getPMML();
        DataDictionary dataDictionary = pmml.requireDataDictionary();
        for (ArrayInfo arrayInfo : arrayInfos) {
            DataField dataField = new DataField(arrayInfo.getName(), arrayInfo.getOpType(), arrayInfo.getDataType());
            dataDictionary.addDataFields(new DataField[]{dataField});
        }

        try {
            Field dataFieldsField = ModelManager.class.getDeclaredField("dataFields");
            // Unconditional: Field.isAccessible() is deprecated, and setAccessible(true) is idempotent
            dataFieldsField.setAccessible(true);
            dataFieldsField.set(this, ImmutableMap.copyOf(IndexableUtil.buildMap(dataDictionary.getDataFields(), PMMLAttributes.DATAFIELD_NAME)));
        } catch (ReflectiveOperationException roe) {
            throw new RuntimeException(roe);
        }
    }

    /**
     * Creates a value factory matching the model's math context.
     *
     * @throws UnsupportedAttributeException if the math context is neither {@code FLOAT} nor {@code DOUBLE}
     */
    public static <V extends Number> ValueFactory<V> getValueFactory(Model model) {
        MathContext mathContext = requireSupportedMathContext(model);
        ValueFactoryFactory valueFactoryFactory = ValueFactoryFactory.newInstance();
        // Unchecked: the caller chooses V; the factory's number type is determined by mathContext
        @SuppressWarnings("unchecked")
        ValueFactory<V> valueFactory = (ValueFactory<V>)valueFactoryFactory.newValueFactory(mathContext);
        return valueFactory;
    }

    /**
     * Returns the model's math context if it is supported ({@code FLOAT} or {@code DOUBLE}).
     *
     * @throws UnsupportedAttributeException otherwise
     */
    private static MathContext requireSupportedMathContext(Model model) {
        MathContext mathContext = model.getMathContext();
        switch (mathContext) {
            case FLOAT:
            case DOUBLE:
                return mathContext;
            default:
                throw new UnsupportedAttributeException(model, mathContext);
        }
    }

    /**
     * Emits code applying the target's rescale factor (skipped when 1) and rescale constant
     * (skipped when 0) to the value held by {@code valueBuilder}.
     *
     * @throws UnsupportedAttributeException if the target specifies {@code castInteger}
     */
    private static void translateRegressorTarget(Target target, ValueBuilder valueBuilder) {
        // Fail fast, before emitting any code, on the unsupported attribute
        Target.CastInteger castInteger = target.getCastInteger();
        if (castInteger != null) {
            throw new UnsupportedAttributeException(target, castInteger);
        }

        Number rescaleFactor = target.getRescaleFactor();
        if (rescaleFactor != null && rescaleFactor.doubleValue() != 1.0) {
            valueBuilder.update("multiply", rescaleFactor);
        }

        Number rescaleConstant = target.getRescaleConstant();
        if (rescaleConstant != null && rescaleConstant.doubleValue() != 0.0) {
            valueBuilder.update("add", rescaleConstant);
        }
    }

    /**
     * Declares a {@code public final <V extends Number> Map<String, ?> name(ValueFactory<V>
     * valueFactory, EvaluationContext context)} method on the current owner. The method is
     * annotated with {@code @Override}.
     */
    public static JMethod createEvaluatorMethod(String name, TranslationContext context) {
        JDefinedClass owner = context.getOwner();
        JClass returnType = context.ref(Map.class).narrow(Arrays.asList(context.ref(String.class), context.ref(Object.class).wildcard()));
        JMethod method = owner.method(JMod.PUBLIC | JMod.FINAL, returnType, name);
        method.annotate(Override.class);
        JTypeVar numberTypeVar = method.generify("V", Number.class);
        method.param(context.ref(ValueFactory.class).narrow(numberTypeVar), "valueFactory");
        method.param(EvaluationContext.class, "context");
        return method;
    }

    /**
     * Declares a private static evaluator method named after {@code object}'s class,
     * e.g. {@code evaluateNode...}.
     */
    public static JMethod createEvaluatorMethod(Class<?> type, PMMLObject object, boolean withValueFactory, TranslationContext context) {
        return createEvaluatorMethod(type, IdentifierUtil.create("evaluate" + object.getClass().getSimpleName(), object), withValueFactory, context);
    }

    /**
     * Declares a private static evaluator method for a list of objects, named after the first
     * element's class with a {@code List} suffix.
     *
     * @throws IllegalArgumentException if {@code objects} is empty or starts with {@code null}
     */
    public static JMethod createEvaluatorMethod(Class<?> type, List<? extends PMMLObject> objects, boolean withValueFactory, TranslationContext context) {
        PMMLObject object = Iterables.getFirst(objects, null);
        if (object == null) {
            throw new IllegalArgumentException("Expected a non-empty list of PMML objects");
        }
        return createEvaluatorMethod(type, IdentifierUtil.create("evaluate" + object.getClass().getSimpleName() + "List", object), withValueFactory, context);
    }

    /**
     * Declares {@code private static final type name([ValueFactory<V> valueFactory,] Arguments arguments)}
     * on the current owner.
     *
     * <p>With {@code withValueFactory}, the method is made generic in {@code <V extends Number>}.
     * If {@code type} has one type parameter, the return type becomes {@code type<V>}; with two,
     * it becomes {@code type<Object, V>}. Otherwise the return type stays as {@code type}.
     */
    private static JMethod createEvaluatorMethod(Class<?> type, String name, boolean withValueFactory, TranslationContext context) {
        JDefinedClass owner = context.getOwner();
        JMethod method = owner.method(JMod.PRIVATE | JMod.STATIC | JMod.FINAL, type, name);
        if (withValueFactory) {
            JTypeVar numberTypeVar = method.generify("V", Number.class);
            int typeParameterCount = type.getTypeParameters().length;
            if (typeParameterCount == 1) {
                method.type(context.ref(type).narrow(numberTypeVar));
            } else if (typeParameterCount == 2) {
                method.type(context.ref(type).narrow(context.ref(Object.class), numberTypeVar));
            }
            method.param(context.ref(ValueFactory.class).narrow(numberTypeVar), "valueFactory");
        }
        method.param(ensureArgumentsType(context), "arguments");
        return method;
    }

    /**
     * Builds an invocation of {@code method}, supplying each parameter from the current
     * translation context by name: {@code arguments}, {@code context} or {@code valueFactory}.
     *
     * @throws IllegalArgumentException for any other parameter name
     */
    public static JInvocation createEvaluatorMethodInvocation(JMethod method, TranslationContext context) {
        JInvocation invocation = JExpr.invoke(method);
        for (JVar param : method.params()) {
            String name = param.name();
            JExpression arg;
            switch (name) {
                case "arguments": {
                    try {
                        arg = context.getArgumentsVariable().getExpression();
                    } catch (IllegalArgumentException iae) {
                        // Assumed: thrown when no Arguments variable is in scope; construct a new instance instead
                        arg = context._new(ensureArgumentsType(context), context.getContextVariable().getExpression());
                    }
                    break;
                }
                case "context": {
                    arg = context.getContextVariable().getExpression();
                    break;
                }
                case "valueFactory": {
                    arg = context.getValueFactoryVariable().getExpression();
                    break;
                }
                default: {
                    throw new IllegalArgumentException(name);
                }
            }
            invocation = invocation.arg(arg);
        }
        return invocation;
    }

    /**
     * Returns the {@code Arguments} member class of the enclosing {@link JavaModel} owner,
     * creating it if absent.
     *
     * <p>The created class is {@code public static final} and has a private
     * {@code EvaluationContext context} field, assigned by a public one-argument constructor.
     * NOTE(review): the lookup searches the JavaModel owner; confirm that createMemberClass
     * declares into that same owner.
     */
    public static JDefinedClass ensureArgumentsType(TranslationContext context) {
        JDefinedClass owner = context.getOwner(JavaModel.class);
        for (Iterator<JDefinedClass> it = owner.classes(); it.hasNext(); ) {
            JDefinedClass clazz = it.next();
            if ("Arguments".equals(clazz.name())) {
                return clazz;
            }
        }

        JDefinedClass argumentsClazz = PMMLObjectUtil.createMemberClass(JMod.PUBLIC | JMod.STATIC | JMod.FINAL, "Arguments", context);
        JFieldVar contextVar = argumentsClazz.field(JMod.PRIVATE, EvaluationContext.class, "context");
        JMethod constructor = argumentsClazz.constructor(JMod.PUBLIC);
        JVar contextParam = constructor.param(EvaluationContext.class, "context");
        JBlock block = constructor.body();
        block.assign(JExpr.refthis(contextVar.name()), contextParam);
        return argumentsClazz;
    }

    /**
     * For a {@link DerivedField}, matches its expression against known function invocations.
     *
     * <p>If the expression is a plain field reference, the referenced field's info is linked via
     * {@code setRef}, and the function invocation is cleared. The referenced field's info is
     * created and enhanced recursively if needed. Otherwise the matched invocation (possibly
     * {@code null}) is stored. Non-derived fields are left untouched.
     */
    private static void enhanceFieldInfo(FieldInfo fieldInfo, MiningSchema miningSchema, Map<String, org.dmg.pmml.Field<?>> bodyFields, FieldInfoMap fieldInfos, FunctionInvocationContext context) {
        org.dmg.pmml.Field<?> field = fieldInfo.getField();
        if (!(field instanceof DerivedField)) {
            return;
        }

        DerivedField derivedField = (DerivedField)field;
        Expression expression = derivedField.requireExpression();
        FunctionInvocation functionInvocation = FunctionInvocationUtil.match(expression, context);
        if (functionInvocation instanceof FunctionInvocation.Ref) {
            FunctionInvocation.Ref ref = (FunctionInvocation.Ref)functionInvocation;
            String fieldName = ref.getField();
            FieldInfo refFieldInfo = (FieldInfo)fieldInfos.get(fieldName);
            if (refFieldInfo == null) {
                org.dmg.pmml.Field<?> refField = bodyFields.get(fieldName);
                refFieldInfo = fieldInfos.create(refField);
                enhanceFieldInfo(refFieldInfo, miningSchema, bodyFields, fieldInfos, context);
            }
            fieldInfo.setRef(refFieldInfo);
            // A plain reference is expressed through setRef, not as a function invocation
            functionInvocation = null;
        }
        fieldInfo.setFunctionInvocation(functionInvocation);
    }
}

