package tpot.builtins;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Estimator;
import sklearn.HasClasses;
import sklearn.HasEstimator;
import sklearn.Transformer;

/**
 * Transformer that wraps another {@link Estimator} and stacks its output onto the
 * incoming features.
 *
 * <p>When features are encoded, the wrapped estimator is encoded as a model, registered
 * with the encoder, and its exported prediction is placed in front of the original
 * features. For classification estimators that report a probability distribution, one
 * exported probability feature per class is inserted between the prediction and the
 * original features.
 */
public class StackingEstimator extends Transformer implements HasEstimator<Estimator> {

    /**
     * Creates the transformer; both arguments are forwarded unchanged to the
     * {@link Transformer} constructor.
     *
     * @param str first constructor argument of the superclass
     * @param str2 second constructor argument of the superclass
     */
    public StackingEstimator(String str, String str2) {
        super(str, str2);
    }

    /**
     * Returns the number of features reported by the wrapped estimator.
     *
     * @return the value of {@code getEstimator().getNumberOfFeatures()}
     */
    public int getNumberOfFeatures() {
        return getEstimator().getNumberOfFeatures();
    }

    /**
     * Encodes the wrapped estimator and returns the stacked feature list.
     *
     * <p>The result is ordered as: exported prediction, then (classification with a
     * probability distribution only) one exported probability per class, then every
     * element of {@code list} in its original order.
     *
     * @param list the incoming features; also used as the wrapped estimator's inputs
     * @param skLearnEncoder the encoder that receives the model and exports its outputs
     * @return a new list containing the stacked features followed by {@code list}
     * @throws IllegalArgumentException if the wrapped estimator's mining function is
     *     neither classification nor regression
     */
    public List<Feature> encodeFeatures(List<Feature> list, SkLearnEncoder skLearnEncoder) {
        Estimator estimator = getEstimator();

        // The label is requested with a single null name, exactly as the estimator is asked
        // to encode it; the resulting label is reused when exporting the prediction below.
        ScalarLabel label = estimator.encodeLabel(Collections.singletonList(null), skLearnEncoder);

        Model model = estimator.encode(new Schema(skLearnEncoder, label, list));
        skLearnEncoder.addTransformer(model);

        String name = createFieldName("stack", list);

        List<Feature> result = new ArrayList<>();
        result.add(skLearnEncoder.exportPrediction(model, name, label));

        MiningFunction miningFunction = estimator.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION:
                // Class information is only reachable through HasClasses; a classification
                // estimator that does not implement it fails here with ClassCastException.
                HasClasses hasClasses = (HasClasses) estimator;
                if (hasClasses.hasProbabilityDistribution()) {
                    for (Object category : hasClasses.getClasses()) {
                        String probabilityName = FieldNameUtil.create("probability", new Object[]{name, category});
                        result.add(skLearnEncoder.exportProbability(model, probabilityName, category));
                    }
                }
                break;
            case REGRESSION:
                // Regression contributes the prediction only.
                break;
            default:
                throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
        }

        // Original features always follow the stacked outputs.
        result.addAll(list);
        return result;
    }

    /**
     * Returns the wrapped estimator, read from the {@code "estimator"} attribute.
     *
     * @return the attribute value as an {@link Estimator}
     */
    public Estimator getEstimator() {
        return (Estimator) get("estimator", Estimator.class);
    }
}
