package org.jpmml.evaluator.support_vector_machine;

import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Floats;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.Array;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasValue;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.RealSparseArray;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.support_vector_machine.Coefficient;
import org.dmg.pmml.support_vector_machine.Coefficients;
import org.dmg.pmml.support_vector_machine.Kernel;
import org.dmg.pmml.support_vector_machine.PMMLAttributes;
import org.dmg.pmml.support_vector_machine.SupportVector;
import org.dmg.pmml.support_vector_machine.SupportVectorMachine;
import org.dmg.pmml.support_vector_machine.SupportVectorMachineModel;
import org.dmg.pmml.support_vector_machine.VectorDictionary;
import org.dmg.pmml.support_vector_machine.VectorInstance;
import org.jpmml.evaluator.ArrayUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.MissingFieldValueException;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.Numbers;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.SparseArrayUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.model.InvalidAttributeException;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.InvalidElementListException;
import org.jpmml.model.MisplacedElementException;
import org.jpmml.model.UnsupportedAttributeException;

/* loaded from: input_file:org/jpmml/evaluator/support_vector_machine/SupportVectorMachineModelEvaluator.class */
/**
 * Evaluator for PMML {@code SupportVectorMachineModel} elements.
 *
 * <p>Only the {@code SupportVectors} representation is supported, and the {@code maxWins}
 * attribute must be {@code false}. The model's {@code VectorDictionary} is parsed once, at
 * construction time, into primitive {@code float[]} or {@code double[]} arrays (depending on
 * the model's math context) keyed by vector instance id.
 */
public class SupportVectorMachineModelEvaluator extends ModelEvaluator<SupportVectorMachineModel> {

    /** Vector instance id mapped to its parsed {@code float[]} or {@code double[]} values. */
    private Map<String, Object> vectorMap = Collections.emptyMap();

    /**
     * Private no-argument constructor.
     * NOTE(review): presumably retained for serialization frameworks — confirm before removing.
     */
    private SupportVectorMachineModelEvaluator() {
    }

    /**
     * Creates an evaluator for the {@code SupportVectorMachineModel} found in the given document.
     *
     * @param pmml the PMML document containing the model
     */
    public SupportVectorMachineModelEvaluator(PMML pmml) {
        this(pmml, PMMLUtil.findModel(pmml, SupportVectorMachineModel.class));
    }

    /**
     * Creates an evaluator for the given model.
     *
     * @param pmml the PMML document containing the model
     * @param supportVectorMachineModel the model to evaluate
     * @throws UnsupportedAttributeException if {@code maxWins} is {@code true}, or if the
     *     representation is anything other than {@code SupportVectors}
     */
    public SupportVectorMachineModelEvaluator(PMML pmml, SupportVectorMachineModel supportVectorMachineModel) {
        super(pmml, supportVectorMachineModel);

        boolean maxWins = supportVectorMachineModel.isMaxWins();
        if (maxWins) {
            throw new UnsupportedAttributeException(supportVectorMachineModel, PMMLAttributes.SUPPORTVECTORMACHINEMODEL_MAXWINS, Boolean.valueOf(maxWins));
        }

        SupportVectorMachineModel.Representation representation = supportVectorMachineModel.getRepresentation();
        switch (representation) {
            case SUPPORT_VECTORS:
                // Called for its validation side effect; the result is not needed here.
                supportVectorMachineModel.requireSupportVectorMachines();
                this.vectorMap = ImmutableMap.copyOf(parseVectorDictionary(supportVectorMachineModel));
                break;
            default:
                throw new UnsupportedAttributeException(supportVectorMachineModel, representation);
        }
    }

    @Override
    public String getSummary() {
        return "Support vector machine";
    }

    /**
     * Evaluates a regression model, which must consist of exactly one {@code SupportVectorMachine}.
     *
     * @throws InvalidElementListException if the model does not contain exactly one machine
     */
    @Override
    protected <V extends Number> Map<String, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        List<SupportVectorMachine> supportVectorMachines = getModel().requireSupportVectorMachines();
        if (supportVectorMachines.size() != 1) {
            throw new InvalidElementListException(supportVectorMachines);
        }

        return TargetUtil.evaluateRegression(getTargetField(), evaluateSupportVectorMachine(valueFactory, supportVectorMachines.get(0), createInput(evaluationContext)));
    }

    /**
     * Evaluates a classification model using its declared classification method.
     *
     * <p>{@code OneAgainstAll} yields a {@link DistanceDistribution} of the per-machine decision
     * values; {@code OneAgainstOne} yields a {@link VoteProbabilityDistribution} of per-category votes.
     *
     * @throws UnsupportedAttributeException if the classification method is not supported
     */
    @Override
    protected <V extends Number> Map<String, ? extends Classification<?, V>> evaluateClassification(final ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        SupportVectorMachineModel supportVectorMachineModel = getModel();
        List<SupportVectorMachine> supportVectorMachines = supportVectorMachineModel.requireSupportVectorMachines();
        SupportVectorMachineModel.ClassificationMethod classificationMethod = supportVectorMachineModel.getClassificationMethod();

        Classification<?, V> result;
        switch (classificationMethod) {
            case ONE_AGAINST_ALL:
                result = classifyOneAgainstAll(valueFactory, supportVectorMachines, createInput(evaluationContext));
                break;
            case ONE_AGAINST_ONE:
                result = classifyOneAgainstOne(valueFactory, supportVectorMachines, createInput(evaluationContext));
                break;
            default:
                throw new UnsupportedAttributeException(supportVectorMachineModel, classificationMethod);
        }

        return TargetUtil.evaluateClassification(getTargetField(), result);
    }

    /** Collects each machine's decision value under its target category into a distance distribution. */
    private <V extends Number> Classification<?, V> classifyOneAgainstAll(ValueFactory<V> valueFactory, List<SupportVectorMachine> supportVectorMachines, Object input) {
        ValueMap<Object, V> values = new ValueMap<>(2 * supportVectorMachines.size());

        for (SupportVectorMachine supportVectorMachine : supportVectorMachines) {
            Value<V> value = evaluateSupportVectorMachine(valueFactory, supportVectorMachine, input);

            values.put(supportVectorMachine.requireTargetCategory(), value);
        }

        return new DistanceDistribution<>(values);
    }

    /** Lets every pairwise machine cast one vote and turns the tally into a vote probability distribution. */
    private <V extends Number> Classification<?, V> classifyOneAgainstOne(final ValueFactory<V> valueFactory, List<SupportVectorMachine> supportVectorMachines, Object input) {
        SupportVectorMachineModel supportVectorMachineModel = getModel();
        Object alternateBinaryTargetCategory = supportVectorMachineModel.getAlternateBinaryTargetCategory();

        VoteMap<Object, V> votes = new VoteMap<Object, V>(2 * supportVectorMachines.size()) {

            @Override
            public ValueFactory<V> getValueFactory() {
                return valueFactory;
            }
        };

        for (SupportVectorMachine supportVectorMachine : supportVectorMachines) {
            Value<V> value = evaluateSupportVectorMachine(valueFactory, supportVectorMachine, input);

            Object targetCategory;
            if (alternateBinaryTargetCategory != null) {
                // Binary-output mode: after round2() the value must be exactly 0 (vote for the
                // model-level alternate category) or 1 (vote for this machine's target category).
                value.round2();

                if (value.isZero()) {
                    targetCategory = alternateBinaryTargetCategory;
                } else if (value.isOne()) {
                    targetCategory = supportVectorMachine.requireTargetCategory();
                } else {
                    throw new EvaluationException("Expected " + EvaluationException.formatValue(Numbers.DOUBLE_ZERO) + " or " + EvaluationException.formatValue(Numbers.DOUBLE_ONE) + ", got " + EvaluationException.formatValue(value.getValue()));
                }
            } else {
                // A machine-level threshold overrides the model-level one.
                Number threshold = supportVectorMachine.getThreshold();
                if (threshold == null) {
                    threshold = supportVectorMachineModel.getThreshold();
                }

                // Values strictly below the threshold vote for the target category.
                targetCategory = value.compareTo(threshold) < 0 ? supportVectorMachine.requireTargetCategory() : supportVectorMachine.requireAlternateTargetCategory();
            }

            votes.increment(targetCategory);
        }

        return new VoteProbabilityDistribution<>(votes);
    }

    /**
     * Computes the decision value of a single machine: each coefficient is combined via
     * {@code add2} with the kernel value between the input and its support vector, and then the
     * coefficients' absolute value term is added.
     * NOTE(review): {@code add2(a, b)} presumably accumulates the product {@code a * b} — confirm
     * against {@code Value}.
     *
     * @throws InvalidElementException if the machine lacks {@code Coefficients} or
     *     {@code SupportVectors}, or if their element counts differ
     * @throws InvalidAttributeException if a support vector references an unknown vector id
     */
    private <V extends Number> Value<V> evaluateSupportVectorMachine(ValueFactory<V> valueFactory, SupportVectorMachine supportVectorMachine, Object input) {
        SupportVectorMachineModel supportVectorMachineModel = getModel();

        Value<V> result = valueFactory.newValue();

        Kernel kernel = supportVectorMachineModel.requireKernel();

        Coefficients coefficients = supportVectorMachine.getCoefficients();
        // Report structurally incomplete machines as invalid instead of failing with an NPE.
        if (coefficients == null || supportVectorMachine.getSupportVectors() == null) {
            throw new InvalidElementException(supportVectorMachine);
        }

        Iterator<Coefficient> coefficientIt = coefficients.iterator();
        Iterator<SupportVector> supportVectorIt = supportVectorMachine.getSupportVectors().iterator();

        Map<String, Object> vectorMap = getVectorMap();

        while (coefficientIt.hasNext() && supportVectorIt.hasNext()) {
            Coefficient coefficient = coefficientIt.next();
            SupportVector supportVector = supportVectorIt.next();

            String vectorId = supportVector.requireVectorId();
            Object vector = vectorMap.get(vectorId);
            if (vector == null) {
                throw new InvalidAttributeException(supportVector, PMMLAttributes.SUPPORTVECTOR_VECTORID, vectorId);
            }

            result.add2(coefficient.getValue(), KernelUtil.evaluate(kernel, valueFactory, input, vector).getValue());
        }

        // Coefficients and SupportVectors must pair up one-to-one.
        if (coefficientIt.hasNext() || supportVectorIt.hasNext()) {
            throw new InvalidElementException(supportVectorMachine);
        }

        result.add2(coefficients.getAbsoluteValue());

        return result;
    }

    /**
     * Builds the input vector from the {@code VectorFields} of the model's vector dictionary.
     *
     * <p>A {@code FieldRef} contributes the numeric value of the referenced field. A
     * {@code CategoricalPredictor} contributes {@code 1.0} when the field value equals the
     * predictor's value and {@code 0.0} otherwise; its coefficient, if present, must be {@code 1}.
     *
     * @return a {@code float[]} or {@code double[]}, depending on the model's math context
     * @throws MissingFieldValueException if any referenced field value is missing
     * @throws MisplacedElementException if a vector field is of an unsupported element type
     */
    private Object createInput(EvaluationContext evaluationContext) {
        SupportVectorMachineModel supportVectorMachineModel = getModel();

        List<? extends PMMLObject> vectorFields = supportVectorMachineModel.requireVectorDictionary().requireVectorFields().getContent();

        List<Number> values = new ArrayList<>(vectorFields.size());

        for (PMMLObject vectorField : vectorFields) {
            if (vectorField instanceof FieldRef) {
                FieldRef fieldRef = (FieldRef) vectorField;

                FieldValue value = ExpressionUtil.evaluate((Expression) fieldRef, evaluationContext);
                if (FieldValueUtil.isMissing(value)) {
                    throw new MissingFieldValueException(fieldRef);
                }

                values.add(value.asNumber());
            } else if (vectorField instanceof CategoricalPredictor) {
                CategoricalPredictor categoricalPredictor = (CategoricalPredictor) vectorField;

                FieldValue value = evaluationContext.evaluate(categoricalPredictor.requireField());
                if (FieldValueUtil.isMissing(value)) {
                    throw new MissingFieldValueException(categoricalPredictor);
                }

                // Only plain indicator (0/1) encoding is supported.
                Number coefficient = categoricalPredictor.getCoefficient();
                if (coefficient != null && coefficient.doubleValue() != 1.0d) {
                    throw new InvalidAttributeException(categoricalPredictor, org.dmg.pmml.regression.PMMLAttributes.CATEGORICALPREDICTOR_COEFFICIENT, coefficient);
                }

                // The HasValue cast selects the FieldValue.equals(HasValue) overload.
                values.add(value.equals((HasValue<?>) categoricalPredictor) ? Numbers.DOUBLE_ONE : Numbers.DOUBLE_ZERO);
            } else {
                throw new MisplacedElementException(vectorField);
            }
        }

        return toArray(supportVectorMachineModel, values);
    }

    private Map<String, Object> getVectorMap() {
        return this.vectorMap;
    }

    /**
     * Parses every {@code VectorInstance} of the model's vector dictionary.
     *
     * @return vector instance id mapped to a {@code float[]} or {@code double[]}, in document order
     * @throws InvalidElementException if an instance has neither or both of {@code Array} and
     *     {@code REAL-SparseArray}, or if its length differs from the number of vector fields
     */
    private static Map<String, Object> parseVectorDictionary(SupportVectorMachineModel supportVectorMachineModel) {
        VectorDictionary vectorDictionary = supportVectorMachineModel.requireVectorDictionary();

        int fieldCount = vectorDictionary.requireVectorFields().getContent().size();

        Map<String, Object> result = new LinkedHashMap<>();

        for (VectorInstance vectorInstance : vectorDictionary.getVectorInstances()) {
            String id = vectorInstance.requireId();

            Array array = vectorInstance.getArray();
            RealSparseArray sparseArray = vectorInstance.getRealSparseArray();

            // Exactly one of the dense and sparse encodings must be present.
            List<? extends Number> values;
            if (array != null && sparseArray == null) {
                values = ArrayUtil.asNumberList(array);
            } else if (array == null && sparseArray != null) {
                values = SparseArrayUtil.asNumberList(sparseArray);
            } else {
                throw new InvalidElementException(vectorInstance);
            }

            if (values.size() != fieldCount) {
                throw new InvalidElementException(vectorInstance);
            }

            result.put(id, toArray(supportVectorMachineModel, values));
        }

        return result;
    }

    /**
     * Converts a list of numbers to a primitive array matching the model's math context.
     *
     * @return a {@code float[]} for {@code float} math, a {@code double[]} for {@code double} math
     * @throws UnsupportedAttributeException for any other math context
     */
    private static Object toArray(SupportVectorMachineModel supportVectorMachineModel, List<? extends Number> list) {
        MathContext mathContext = supportVectorMachineModel.getMathContext();

        switch (mathContext) {
            case FLOAT:
                return Floats.toArray(list);
            case DOUBLE:
                return Doubles.toArray(list);
            default:
                throw new UnsupportedAttributeException(supportVectorMachineModel, mathContext);
        }
    }
}
