/*
 * Decompiled with CFR 0.152.
 */
package sklearn.tree;

import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import net.razorvine.pickle.objects.ClassDict;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.HasExtensions;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.NodeTransformer;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.SimplifyingNodeTransformer;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ScoreDistributionManager;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.ThresholdFeatureUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.visitors.AbstractExtender;
import org.jpmml.model.UnsupportedElementException;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.python.ClassDictUtil;
import sklearn.Estimator;
import sklearn.HasEstimatorEnsemble;
import sklearn.tree.HasTree;
import sklearn.tree.Tree;
import sklearn.tree.visitors.TreeModelCompactor;
import sklearn.tree.visitors.TreeModelFlattener;

public class TreeUtil {

    private TreeUtil() {
    }

    /**
     * Applies option-driven post-processing to an already encoded tree model (or segmented ensemble of tree models).
     *
     * <p>Estimator options read: {@code winner_id}, {@code node_extensions}, {@code node_id}, {@code node_score},
     * {@code compact} and {@code flat}. Requesting {@code compact} or {@code flat} together with node extensions,
     * node ids or node scores is rejected.
     *
     * @param estimator the estimator whose options control the transformation
     * @param model the model to transform; it is modified in place
     * @return the same {@code model} instance
     * @throws IllegalArgumentException if {@code compact}/{@code flat} conflicts with the node-level options
     */
    public static <E extends Estimator, M extends Model> M transform(E estimator, M model) {
        Boolean winnerId = (Boolean)estimator.getOption("winner_id", Boolean.FALSE);
        @SuppressWarnings("unchecked")
        Map<String, Map<Integer, ?>> nodeExtensions = (Map<String, Map<Integer, ?>>)estimator.getOption("node_extensions", null);
        Boolean nodeId = (Boolean)estimator.getOption("node_id", winnerId);
        Boolean nodeScore = (Boolean)estimator.getOption("node_score", Boolean.TRUE.equals(winnerId) ? Boolean.TRUE : null);

        // Options that cannot be combined with compacting/flattening (see the conflict check below).
        // Boolean.TRUE.equals(..) avoids an unboxing NPE when an option is explicitly set to null.
        boolean fixed = nodeExtensions != null || Boolean.TRUE.equals(nodeId) || Boolean.TRUE.equals(nodeScore);

        Boolean compact = (Boolean)estimator.getOption("compact", fixed ? Boolean.FALSE : Boolean.TRUE);
        Boolean flat = (Boolean)estimator.getOption("flat", Boolean.FALSE);

        if (Boolean.TRUE.equals(compact) || Boolean.TRUE.equals(flat)) {
            if (fixed) {
                throw new IllegalArgumentException("Conflicting tree model options");
            }
            // Disable the node-level visitors below (including the FALSE-triggered cleaners)
            nodeExtensions = null;
            nodeId = null;
            nodeScore = null;
        }

        // Node ids are collected before any visitor runs, so the ids seen here are the original ones
        if (Boolean.TRUE.equals(winnerId)) {
            encodeNodeId(model);
        }

        List<Visitor> visitors = new ArrayList<>();

        if (Boolean.TRUE.equals(compact)) {
            visitors.add(new TreeModelCompactor());
        }

        if (Boolean.TRUE.equals(flat)) {
            visitors.add(new TreeModelFlattener());
        }

        if (nodeExtensions != null) {
            for (Map.Entry<String, Map<Integer, ?>> entry : nodeExtensions.entrySet()) {
                visitors.add(createNodeExtender(entry.getKey(), entry.getValue()));
            }
        }

        if (Boolean.FALSE.equals(nodeId)) {
            visitors.add(new AbstractVisitor() {

                @Override
                public VisitorAction visit(Node node) {
                    node.setId(null);

                    return super.visit(node);
                }
            });
        }

        if (Boolean.FALSE.equals(nodeScore)) {
            visitors.add(new AbstractVisitor() {

                @Override
                public VisitorAction visit(Node node) {
                    // Only nodes with children are cleared; leaf nodes keep their score and distributions
                    if (node.hasNodes()) {
                        node.setScore(null);

                        if (node.hasScoreDistributions()) {
                            node.getScoreDistributions().clear();
                        }
                    }

                    return super.visit(node);
                }
            });
        }

        for (Visitor visitor : visitors) {
            visitor.applyTo(model);
        }

        return model;
    }

    /**
     * Creates a visitor that attaches, under the given extension name, the value mapped to each node's integer id.
     *
     * <p>Nodes that do not implement {@code HasExtensions} but do have a mapped value are first replaced with a
     * complex node (via {@link SimplifyingNodeTransformer}) so that the extension can be added.
     *
     * @param name the extension name passed to {@link AbstractExtender}
     * @param values node id to extension value mapping; nodes without an entry are left untouched
     */
    private static AbstractExtender createNodeExtender(String name, Map<Integer, ?> values) {
        return new AbstractExtender(name) {

            private final NodeTransformer nodeTransformer = SimplifyingNodeTransformer.INSTANCE;

            @Override
            public VisitorAction visit(TreeModel treeModel) {
                treeModel.setNode(ensureExtensibility(treeModel.getNode()));

                return super.visit(treeModel);
            }

            @Override
            @SuppressWarnings({"rawtypes", "unchecked"})
            public VisitorAction visit(Node node) {

                if (node.hasNodes()) {
                    // Children are swapped in place before they are visited themselves
                    List<Node> children = node.getNodes();

                    for (ListIterator<Node> childIt = children.listIterator(); childIt.hasNext(); ) {
                        childIt.set(ensureExtensibility(childIt.next()));
                    }
                }

                Object value = getValue(node);
                if (value != null) {
                    value = ScalarUtil.decode(value);

                    // Any node with a value has already been made extensible by its parent (or by visit(TreeModel))
                    addExtension((Node & HasExtensions)node, ValueUtil.asString(value));
                }

                return super.visit(node);
            }

            private Node ensureExtensibility(Node node) {

                if (node instanceof HasExtensions) {
                    return node;
                }

                // Only nodes that will actually receive an extension are converted
                return getValue(node) != null ? this.nodeTransformer.toComplexNode(node) : node;
            }

            private Object getValue(Node node) {
                Integer id = ValueUtil.asInteger((Number)node.getId());

                return values.get(id);
            }
        };
    }

    /**
     * Encodes every member of an estimator ensemble as a {@link TreeModel}.
     *
     * <p>Reads the {@code numeric} option (default {@code true}). A single predicate manager and score distribution
     * manager are created here and shared by all members.
     */
    public static <E extends Estimator, T extends Estimator> List<TreeModel> encodeTreeModelEnsemble(E estimator, MiningFunction miningFunction, Schema schema) {
        Boolean numeric = (Boolean)estimator.getOption("numeric", Boolean.TRUE);

        PredicateManager predicateManager = new PredicateManager();
        ScoreDistributionManager scoreDistributionManager = new ScoreDistributionManager();

        return TreeUtil.encodeTreeModelEnsemble(estimator, miningFunction, numeric, predicateManager, scoreDistributionManager, schema);
    }

    /**
     * Encodes every member of an estimator ensemble as a {@link TreeModel}.
     *
     * <p>{@code estimator} must implement {@link HasEstimatorEnsemble}; its members are expected to implement
     * {@link HasTree}. Each member is encoded against an anonymous copy of {@code schema}; feature importances are
     * attached to a member's tree model when the member reports having them.
     *
     * @return one tree model per ensemble member, in ensemble order
     */
    public static <E extends Estimator, T extends Estimator> List<TreeModel> encodeTreeModelEnsemble(E estimator, MiningFunction miningFunction, Boolean numeric, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema) {
        @SuppressWarnings({"rawtypes", "unchecked"})
        List<? extends T> estimators = (List<? extends T>)((HasEstimatorEnsemble)estimator).getEstimators();

        Schema segmentSchema = schema.toAnonymousSchema();

        return estimators.stream()
            .map(treeEstimator -> {
                Schema treeModelSchema = TreeUtil.toTreeModelSchema(treeEstimator.getDataType(), numeric, segmentSchema);

                TreeModel treeModel = TreeUtil.encodeTreeModel(treeEstimator, miningFunction, numeric, predicateManager, scoreDistributionManager, treeModelSchema);

                if (treeEstimator.hasFeatureImportances()) {
                    Schema featureImportanceSchema = TreeUtil.toTreeModelFeatureImportanceSchema(numeric, treeModelSchema);

                    treeEstimator.addFeatureImportances((Model)treeModel, featureImportanceSchema);
                }

                return treeModel;
            })
            .collect(Collectors.toList());
    }

    /**
     * Encodes a single tree estimator as a {@link TreeModel}.
     *
     * <p>Reads the {@code numeric} option (default {@code true}) and uses fresh predicate and score distribution
     * managers.
     */
    public static <E extends Estimator> TreeModel encodeTreeModel(E estimator, MiningFunction miningFunction, Schema schema) {
        Boolean numeric = (Boolean)estimator.getOption("numeric", Boolean.TRUE);

        PredicateManager predicateManager = new PredicateManager();
        ScoreDistributionManager scoreDistributionManager = new ScoreDistributionManager();

        return TreeUtil.encodeTreeModel(estimator, miningFunction, numeric, predicateManager, scoreDistributionManager, schema);
    }

    /**
     * Encodes a single tree estimator (which must implement {@link HasTree}) as a binary-split {@link TreeModel}.
     *
     * <p>After encoding, the tree's content is cleared via {@code ClassDictUtil.clearContent} (presumably to release
     * the unpickled arrays early; as a consequence the same estimator should not be encoded twice — TODO confirm).
     *
     * @throws IllegalArgumentException if {@code miningFunction} is neither classification nor regression
     */
    public static <E extends Estimator> TreeModel encodeTreeModel(E estimator, MiningFunction miningFunction, Boolean numeric, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema) {
        Tree tree = ((HasTree)estimator).getTree();

        int[] leftChildren = tree.getChildrenLeft();
        int[] rightChildren = tree.getChildrenRight();
        int[] features = tree.getFeature();
        double[] thresholds = tree.getThreshold();
        double[] values = tree.getValues();

        Node root = TreeUtil.encodeNode(0, True.INSTANCE, miningFunction, numeric, leftChildren, rightChildren, features, thresholds, values, new CategoryManager(), predicateManager, scoreDistributionManager, schema);

        TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

        ClassDictUtil.clearContent((ClassDict)tree);

        return treeModel;
    }

    /**
     * Recursively encodes the tree node at {@code index} (and its subtree).
     *
     * <p>A negative {@code features[index]} marks a leaf. Split nodes send the "less than or equal to threshold"
     * (or "not equal to the binary value") branch to the left child and the complementary branch to the right child.
     */
    private static Node encodeNode(int index, Predicate predicate, MiningFunction miningFunction, boolean numeric, int[] leftChildren, int[] rightChildren, int[] features, double[] thresholds, double[] values, CategoryManager categoryManager, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema) {
        Integer id = index;

        int featureIndex = features[index];

        if (featureIndex >= 0) {
            Feature feature = schema.getFeature(featureIndex);

            double threshold = thresholds[index];

            CategoryManager leftCategoryManager = categoryManager;
            CategoryManager rightCategoryManager = categoryManager;

            Predicate leftPredicate;
            Predicate rightPredicate;

            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature)feature;

                if (threshold < 0.0 || threshold > 1.0) {
                    throw new IllegalArgumentException("Expected a binary feature split threshold in range [0, 1], got " + threshold);
                }

                Object binaryValue = binaryFeature.getValue();

                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, binaryValue);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.EQUAL, binaryValue);
            } else if (feature instanceof ThresholdFeature && !numeric) {
                ThresholdFeature thresholdFeature = (ThresholdFeature)feature;

                String name = thresholdFeature.getName();

                Object missingValue = thresholdFeature.getMissingValue();

                // Restrict candidate values to those still reachable along the current path
                java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

                if (!ValueUtil.isNaN(missingValue)) {
                    valueFilter = valueFilter.and(candidate -> !ValueUtil.isNaN(candidate));
                }

                List<Object> leftValues = thresholdFeature.getValues(candidate -> TreeUtil.toSplitValue(candidate) <= threshold).stream()
                    .filter(valueFilter)
                    .collect(Collectors.toList());
                List<Object> rightValues = thresholdFeature.getValues(candidate -> TreeUtil.toSplitValue(candidate) > threshold).stream()
                    .filter(valueFilter)
                    .collect(Collectors.toList());

                leftCategoryManager = categoryManager.fork(name, leftValues);
                rightCategoryManager = categoryManager.fork(name, rightValues);

                leftPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, leftValues, missingValue, predicateManager);
                rightPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, rightValues, missingValue, predicateManager);
            } else {
                ContinuousFeature continuousFeature = TreeUtil.toContinuousFeature(feature);

                Double thresholdValue = threshold;

                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, thresholdValue);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, thresholdValue);
            }

            int leftIndex = leftChildren[index];
            int rightIndex = rightChildren[index];

            Node leftChild = TreeUtil.encodeNode(leftIndex, leftPredicate, miningFunction, numeric, leftChildren, rightChildren, features, thresholds, values, leftCategoryManager, predicateManager, scoreDistributionManager, schema);
            Node rightChild = TreeUtil.encodeNode(rightIndex, rightPredicate, miningFunction, numeric, leftChildren, rightChildren, features, thresholds, values, rightCategoryManager, predicateManager, scoreDistributionManager, schema);

            Node result;

            if (miningFunction == MiningFunction.CLASSIFICATION) {
                result = new ClassifierNode(null, predicate);
            } else if (miningFunction == MiningFunction.REGRESSION) {
                double value = values[index];

                result = new BranchNode(value, predicate);
            } else {
                throw new IllegalArgumentException("Unsupported mining function " + miningFunction);
            }

            result.setId(id)
                .addNodes(leftChild, rightChild);

            return result;
        }

        Node result;

        if (miningFunction == MiningFunction.CLASSIFICATION) {
            CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

            double[] recordCounts = TreeUtil.getRow(values, leftChildren.length, categoricalLabel.size(), index);

            double totalRecordCount = 0.0;

            // The first category with the strictly highest record count wins ties
            Object score = null;
            double scoreRecordCount = -Double.MAX_VALUE;

            for (int i = 0; i < recordCounts.length; i++) {
                double recordCount = recordCounts[i];

                totalRecordCount += recordCount;

                if (recordCount > scoreRecordCount) {
                    score = categoricalLabel.getValue(i);
                    scoreRecordCount = recordCount;
                }
            }

            result = new ClassifierNode(score, predicate)
                .setId(id)
                .setRecordCount(ValueUtil.narrow(totalRecordCount));

            scoreDistributionManager.addScoreDistributions(result, categoricalLabel.getValues(), recordCounts);
        } else if (miningFunction == MiningFunction.REGRESSION) {
            double value = values[index];

            result = new LeafNode(value, predicate)
                .setId(id);
        } else {
            throw new IllegalArgumentException("Unsupported mining function " + miningFunction);
        }

        return result;
    }

    /**
     * Adds a "nodeId" output field (one per segment for a {@link MiningModel}) listing the ids of all leaf nodes.
     *
     * @throws UnsupportedElementException if a segment of a mining model has no id
     */
    private static void encodeNodeId(Model model) {
        Output output = ModelUtil.ensureOutput(model);

        if (model instanceof MiningModel) {
            MiningModel miningModel = (MiningModel)model;

            Segmentation segmentation = miningModel.requireSegmentation();

            List<Segment> segments = segmentation.getSegments();
            for (Segment segment : segments) {
                TreeModel treeModel = (TreeModel)segment.requireModel(TreeModel.class);

                String segmentId = segment.getId();
                if (segmentId == null) {
                    throw new UnsupportedElementException(segment);
                }

                TreeUtil.encodeNodeId(output, treeModel, segmentId);
            }
        } else {
            TreeModel treeModel = (TreeModel)model;

            TreeUtil.encodeNodeId(output, treeModel, null);
        }
    }

    /**
     * Collects the (Integer) ids of all leaf nodes of {@code treeModel} and registers them as an entity id output
     * field; the field name is suffixed with {@code segmentId} when one is given.
     */
    private static void encodeNodeId(Output output, TreeModel treeModel, String segmentId) {
        List<Object> values = new ArrayList<>();

        Visitor nodeIdCollector = new AbstractVisitor() {

            @Override
            public VisitorAction visit(Node node) {

                if (!node.hasNodes()) {
                    values.add((Integer)node.getId());
                }

                return super.visit(node);
            }
        };
        nodeIdCollector.applyTo(treeModel);

        OutputField nodeIdField = ModelUtil.createEntityIdField("nodeId", DataType.INTEGER, values);

        if (segmentId != null) {
            nodeIdField
                .setName(FieldNameUtil.create(nodeIdField.requireName(), segmentId))
                .setSegmentId(segmentId);
        }

        output.addOutputFields(nodeIdField);
    }

    /**
     * Maps a schema's features into the form used for tree splits; non-binary, non-threshold features become
     * continuous features of the given data type.
     */
    private static Schema toTreeModelSchema(DataType dataType, boolean numeric, Schema schema) {
        Function<Feature, Feature> function = feature -> TreeUtil.toTreeModelFeature(feature, numeric, continuous -> continuous.toContinuousFeature(dataType));

        return schema.toTransformedSchema(function);
    }

    /**
     * Maps a schema's features into the form used for feature importances; non-binary, non-threshold features are
     * converted with {@link #toContinuousFeature(Feature)}.
     */
    private static Schema toTreeModelFeatureImportanceSchema(boolean numeric, Schema schema) {
        Function<Feature, Feature> function = feature -> TreeUtil.toTreeModelFeature(feature, numeric, TreeUtil::toContinuousFeature);

        return schema.toTransformedSchema(function);
    }

    /**
     * Keeps binary features, and (when {@code numeric} is false) threshold features, unchanged; converts every other
     * feature with {@code converter}.
     */
    private static Feature toTreeModelFeature(Feature feature, boolean numeric, Function<Feature, ContinuousFeature> converter) {

        if (feature instanceof BinaryFeature) {
            return feature;
        }

        if (feature instanceof ThresholdFeature && !numeric) {
            return feature;
        }

        return converter.apply(feature);
    }

    /**
     * Converts a feature to a continuous DOUBLE feature, passing through FLOAT first (presumably to mirror the float32
     * inputs that the trees were trained on — TODO confirm).
     */
    private static ContinuousFeature toContinuousFeature(Feature feature) {
        return feature.toContinuousFeature(DataType.FLOAT).toContinuousFeature(DataType.DOUBLE);
    }

    /**
     * Rounds a candidate category value to float precision before comparing it with a split threshold.
     */
    private static double toSplitValue(Number value) {
        return value.floatValue();
    }

    /**
     * Copies row {@code row} out of a row-major {@code rows x columns} matrix stored in {@code values}.
     *
     * @throws IllegalArgumentException if {@code values} does not hold exactly {@code rows * columns} elements
     */
    private static double[] getRow(double[] values, int rows, int columns, int row) {

        if (values.length != rows * columns) {
            throw new IllegalArgumentException("Expected " + rows * columns + " element(s), got " + values.length + " element(s)");
        }

        double[] result = new double[columns];

        System.arraycopy(values, row * columns, result, 0, columns);

        return result;
    }
}

