package org.jpmml.model.visitors;

import java.util.Comparator;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.HasDerivedFields;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.mining.MiningModel;

/* loaded from: input_file:org/jpmml/model/visitors/DerivedFieldRelocator.class */
/**
 * A {@link DeepFieldResolver} that moves {@link DerivedField}s closer to the models that use them.
 *
 * <p>While the traversal unwinds, every popped {@link Model} records the derived fields that are
 * active in it, together with the chain "this model, then its enclosing models" (in the order
 * reported by {@code getParents()} — presumably innermost first; verify against FieldResolver).
 * Repeated sightings of the same field intersect these chains, so only models common to every
 * usage survive. The first surviving model becomes the field's target.
 *
 * <p>When the {@link PMML} root is popped, three follow-up passes mutate the document in place:
 * <ol>
 *   <li>record the original document position of every derived field, and forget fields that
 *       already live in their target model's {@link LocalTransformations};</li>
 *   <li>detach the remaining fields from their current container (the
 *       {@link TransformationDictionary} or another model's {@link LocalTransformations}) and append
 *       them to the target model, creating its {@link LocalTransformations} if necessary;</li>
 *   <li>re-sort every container by the positions recorded in the first pass.</li>
 * </ol>
 *
 * <p>Derived fields that are not active in any model, or whose usages share no common model,
 * are left where they are.
 */
public class DerivedFieldRelocator extends DeepFieldResolver {

    /**
     * For each derived field seen so far, the models from which it is visible in every usage,
     * in the order of the first recorded model chain.
     */
    private final Map<DerivedField, Set<Model>> derivedFieldModels = new IdentityHashMap<>();

    /**
     * Clears any state left over from a previous run, then traverses {@code visitable}.
     *
     * <p>Relocation only happens if the traversal pops a {@link PMML} element.
     */
    @Override
    public void applyTo(Visitable visitable) {
        this.derivedFieldModels.clear();

        super.applyTo(visitable);
    }

    /**
     * Pops the current parent and, depending on its type, records model usage or performs the
     * relocation.
     *
     * @return the popped parent, unchanged
     */
    @Override
    public PMMLObject popParent() {
        PMMLObject parent = super.popParent();

        if (parent instanceof Model) {
            processModel((Model) parent);
        } else if (parent instanceof PMML) {
            processPMML((PMML) parent);
        }

        return parent;
    }

    /**
     * Records that the derived fields active in {@code model} must be visible from {@code model}
     * and its enclosing models.
     *
     * <p>The first time a field is seen it is mapped to a copy of the current model chain. On
     * later sightings the stored chain is intersected in place with the current one.
     * {@code retainAll} keeps the original order of the surviving models.
     */
    private void processModel(Model model) {
        Set<Model> modelChain = new LinkedHashSet<>();
        modelChain.add(model);

        for (PMMLObject parent : getParents()) {
            if (parent instanceof Model) {
                modelChain.add((Model) parent);
            }
        }

        for (DerivedField derivedField : getActiveDerivedFields(model)) {
            Set<Model> candidateModels = this.derivedFieldModels.get(derivedField);

            if (candidateModels == null) {
                // Copy: the stored set is narrowed in place on later sightings
                this.derivedFieldModels.put(derivedField, new LinkedHashSet<>(modelChain));
            } else {
                candidateModels.retainAll(modelChain);
            }
        }
    }

    /**
     * Relocates derived fields within {@code pmml} according to the usage collected so far.
     *
     * @param pmml the document root; it is modified in place
     */
    private void processPMML(PMML pmml) {
        Map<DerivedField, Model> targetModels = selectTargetModels();
        Map<DerivedField, Integer> fieldIndexes = new IdentityHashMap<>();

        createIndexingVisitor(targetModels, fieldIndexes).applyTo(pmml);
        createRelocatingVisitor(targetModels).applyTo(pmml);
        createSortingVisitor(fieldIndexes).applyTo(pmml);

        // The usage map references the whole model tree; release it once relocation is done
        this.derivedFieldModels.clear();
    }

    /**
     * Picks, for every derived field with a non-empty common model chain, the first model of
     * that chain.
     */
    private Map<DerivedField, Model> selectTargetModels() {
        Map<DerivedField, Model> result = new IdentityHashMap<>();

        for (Map.Entry<DerivedField, Set<Model>> entry : this.derivedFieldModels.entrySet()) {
            Set<Model> candidateModels = entry.getValue();

            if (!candidateModels.isEmpty()) {
                result.put(entry.getKey(), candidateModels.iterator().next());
            }
        }

        return result;
    }

    /**
     * Pass 1: numbers every derived field in document order into {@code fieldIndexes}. It also
     * removes from {@code targetModels} the fields that already sit in their target model's
     * local transformations.
     */
    private AbstractVisitor createIndexingVisitor(final Map<DerivedField, Model> targetModels, final Map<DerivedField, Integer> fieldIndexes) {
        return new AbstractVisitor() {

            @Override
            public VisitorAction visit(DerivedField derivedField) {
                fieldIndexes.put(derivedField, fieldIndexes.size());

                return super.visit(derivedField);
            }

            @Override
            public VisitorAction visit(LocalTransformations localTransformations) {
                Model model = (Model) getParent();

                if (localTransformations.hasDerivedFields()) {
                    for (DerivedField derivedField : localTransformations.getDerivedFields()) {
                        Model targetModel = targetModels.get(derivedField);

                        // Already in place, nothing to move
                        if (targetModel != null && Objects.equals(targetModel, model)) {
                            targetModels.remove(derivedField);
                        }
                    }
                }

                return super.visit(localTransformations);
            }
        };
    }

    /**
     * Pass 2: moves every field in {@code targetModels} out of its current container and into
     * its target model's local transformations.
     */
    private AbstractVisitor createRelocatingVisitor(final Map<DerivedField, Model> targetModels) {
        return new AbstractVisitor() {

            @Override
            public VisitorAction visit(LocalTransformations localTransformations) {
                final Model model = (Model) getParent();

                if (localTransformations.hasDerivedFields()) {
                    // Drop fields that belong to some other model; fields targeting this model stay
                    localTransformations.getDerivedFields().removeIf(derivedField -> {
                        Model targetModel = targetModels.get(derivedField);

                        return targetModel != null && !Objects.equals(targetModel, model);
                    });
                }

                return super.visit(localTransformations);
            }

            @Override
            public VisitorAction visit(Model model) {
                LocalTransformations localTransformations = model.getLocalTransformations();

                for (Map.Entry<DerivedField, Model> entry : targetModels.entrySet()) {
                    if (!Objects.equals(entry.getValue(), model)) {
                        continue;
                    }

                    // Lazily create the container only when something actually moves here
                    if (localTransformations == null) {
                        localTransformations = new LocalTransformations();

                        model.setLocalTransformations(localTransformations);
                    }

                    localTransformations.addDerivedFields(entry.getKey());
                }

                return super.visit(model);
            }

            @Override
            public VisitorAction visit(TransformationDictionary transformationDictionary) {
                if (transformationDictionary.hasDerivedFields()) {
                    // Every field with a target model is moved out of the global scope
                    transformationDictionary.getDerivedFields().removeIf(derivedField -> targetModels.get(derivedField) != null);
                }

                return super.visit(transformationDictionary);
            }
        };
    }

    /**
     * Pass 3: restores the relative document order, as recorded in pass 1, within every
     * derived-field container.
     */
    private AbstractVisitor createSortingVisitor(Map<DerivedField, Integer> fieldIndexes) {
        final Comparator<DerivedField> byOriginalPosition = Comparator.comparing(fieldIndexes::get);

        return new AbstractVisitor() {

            @Override
            public VisitorAction visit(LocalTransformations localTransformations) {
                sort(localTransformations);

                return super.visit(localTransformations);
            }

            @Override
            public VisitorAction visit(TransformationDictionary transformationDictionary) {
                sort(transformationDictionary);

                return super.visit(transformationDictionary);
            }

            private void sort(HasDerivedFields<?> hasDerivedFields) {
                if (hasDerivedFields.hasDerivedFields()) {
                    hasDerivedFields.getDerivedFields().sort(byOriginalPosition);
                }
            }
        };
    }

    /**
     * Collects the local and global derived fields that {@code model}'s active fields expand to,
     * via the {@link FieldDependencyResolver}.
     */
    // The raw set mirrors expand(...)'s result type; it is assumed to yield only members of the
    // supplied derived-field sets (NOTE(review): confirm against FieldDependencyResolver).
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Set<DerivedField> getActiveDerivedFields(Model model) {
        FieldDependencyResolver fieldDependencyResolver = getFieldDependencyResolver();

        Set<Field<?>> activeFields;

        if (model instanceof MiningModel) {
            activeFields = DeepFieldResolverUtil.getActiveFields((DeepFieldResolver) this, (MiningModel) model);
        } else {
            activeFields = DeepFieldResolverUtil.getActiveFields(this, model);
        }

        Set result = new HashSet();
        result.addAll(fieldDependencyResolver.expand(activeFields, fieldDependencyResolver.getLocalDerivedFields()));
        result.addAll(fieldDependencyResolver.expand(activeFields, fieldDependencyResolver.getGlobalDerivedFields()));

        return result;
    }
}
