/*
 * Decompiled with CFR 0.152.
 */
package sklearn.neural_network;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.neural_network.NeuralEntity;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.neural_network.NeuralNetworkUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.HasArray;

/**
 * Static helpers that translate the fitted weights of a multilayer perceptron
 * (per-layer coefficient matrices and intercept vectors) into the elements of a
 * PMML {@link NeuralNetwork}.
 *
 * <p>This class is not instantiable.</p>
 */
public class MultilayerPerceptronUtil {

    private MultilayerPerceptronUtil() {
    }

    /**
     * Reads the number of input features off the first coefficient matrix.
     *
     * @param coefs per-layer coefficient arrays; the first one must be two-dimensional
     * @return the first dimension of the first coefficient array
     * @throws IllegalArgumentException if {@code coefs} is empty or its first array is not two-dimensional
     */
    public static int getNumberOfFeatures(List<? extends HasArray> coefs) {
        if (coefs.isEmpty()) {
            throw new IllegalArgumentException("Expected at least one coefficient matrix");
        }
        int[] shape = coefs.get(0).getArrayShape();
        if (shape.length != 2) {
            throw new IllegalArgumentException("Expected a two-dimensional coefficient matrix, got " + shape.length + " dimension(s)");
        }
        return shape[0];
    }

    /**
     * Builds a complete {@link NeuralNetwork} from the schema and the layer weights.
     *
     * @param miningFunction {@code CLASSIFICATION} or {@code REGRESSION}
     * @param activation     activation function name, see {@link #parseActivationFunction(String)}
     * @param coefs          per-layer coefficient arrays
     * @param intercepts     per-layer intercept arrays, parallel to {@code coefs}
     * @param schema         supplies the label and the input features
     * @return the assembled network, with neural inputs, layers and outputs set
     * @throws IllegalArgumentException if the activation name, the mining function or the weights are not supported
     */
    public static NeuralNetwork encodeNeuralNetwork(MiningFunction miningFunction, String activation, List<? extends HasArray> coefs, List<? extends HasArray> intercepts, Schema schema) {
        NeuralNetwork.ActivationFunction activationFunction = parseActivationFunction(activation);
        Label label = schema.getLabel();
        NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(schema.getFeatures(), DataType.DOUBLE);
        List<NeuralLayer> neuralLayers = encodeNeuralLayers(neuralInputs, coefs, intercepts);
        // May append extra transformation layers to neuralLayers (binary classification).
        NeuralOutputs neuralOutputs = encodeNeuralOutputs(miningFunction, neuralLayers, label);
        return new NeuralNetwork(miningFunction, activationFunction, ModelUtil.createMiningSchema(label), neuralInputs, neuralLayers)
            .setNeuralOutputs(neuralOutputs);
    }

    /**
     * Encodes every layer described by {@code coefs} and {@code intercepts}.
     *
     * @param neuralInputs the network inputs that feed the first layer
     * @param coefs        per-layer coefficient arrays
     * @param intercepts   per-layer intercept arrays, parallel to {@code coefs}
     * @return a new mutable list with one {@link NeuralLayer} per coefficient array
     * @throws IllegalArgumentException if an array has an unexpected shape
     */
    public static List<NeuralLayer> encodeNeuralLayers(NeuralInputs neuralInputs, List<? extends HasArray> coefs, List<? extends HasArray> intercepts) {
        return encodeNeuralLayers(neuralInputs, coefs.size(), coefs, intercepts);
    }

    /**
     * Encodes the first {@code numberOfLayers} layers described by {@code coefs} and {@code intercepts}.
     *
     * <p>Each column of a coefficient array becomes one {@link Neuron} whose connections point at
     * the entities of the preceding layer (the neural inputs for the first layer). Neuron
     * identifiers have the form {@code <layer>/<column>}, both one-based.</p>
     *
     * @param neuralInputs   the network inputs that feed the first layer
     * @param numberOfLayers how many leading layers to encode; must not exceed {@code coefs.size()}
     * @param coefs          per-layer coefficient arrays, each two-dimensional
     * @param intercepts     per-layer intercept arrays, parallel to {@code coefs}
     * @return a new mutable list of the encoded layers, in order
     * @throws IllegalArgumentException if {@code numberOfLayers} exceeds the number of coefficient arrays,
     *         a coefficient array is not two-dimensional, or an intercept array does not hold exactly
     *         one value per column
     */
    public static List<NeuralLayer> encodeNeuralLayers(NeuralInputs neuralInputs, int numberOfLayers, List<? extends HasArray> coefs, List<? extends HasArray> intercepts) {
        ClassDictUtil.checkSize(coefs, intercepts);
        // Fail up front instead of half-way through the loop with an IndexOutOfBoundsException.
        if (numberOfLayers > coefs.size()) {
            throw new IllegalArgumentException("Requested " + numberOfLayers + " layer(s), but only " + coefs.size() + " are available");
        }
        List<? extends NeuralEntity> entities = neuralInputs.getNeuralInputs();
        List<NeuralLayer> result = new ArrayList<>();
        for (int layer = 0; layer < numberOfLayers; ++layer) {
            HasArray coef = coefs.get(layer);
            HasArray intercept = intercepts.get(layer);
            int[] shape = coef.getArrayShape();
            if (shape.length != 2) {
                throw new IllegalArgumentException("Expected a two-dimensional coefficient matrix for layer " + (layer + 1) + ", got " + shape.length + " dimension(s)");
            }
            int rows = shape[0];
            int columns = shape[1];
            List<?> coefMatrix = coef.getArrayContent();
            List<?> interceptVector = intercept.getArrayContent();
            if (interceptVector.size() != columns) {
                throw new IllegalArgumentException("Expected " + columns + " intercept(s) for layer " + (layer + 1) + ", got " + interceptVector.size());
            }
            NeuralLayer neuralLayer = new NeuralLayer();
            for (int column = 0; column < columns; ++column) {
                // The array content is untyped; elements are treated as numbers, exactly as before.
                @SuppressWarnings("unchecked")
                List<? extends Number> weights = (List<? extends Number>)CMatrixUtil.getColumn(coefMatrix, rows, columns, column);
                Number bias = (Number)interceptVector.get(column);
                Neuron neuron = NeuralNetworkUtil.createNeuron(entities, weights, bias)
                    .setId((layer + 1) + "/" + (column + 1));
                neuralLayer.addNeurons(neuron);
            }
            result.add(neuralLayer);
            // The neurons of this layer are the inputs of the next one.
            entities = neuralLayer.getNeurons();
        }
        return result;
    }

    /**
     * Creates the {@link NeuralOutputs} for the last layer of {@code neuralLayers}.
     *
     * <p>Side effects: the last layer's activation function is set to
     * {@link NeuralNetwork.ActivationFunction#IDENTITY}. For binary classification the layers
     * returned by {@code NeuralNetworkUtil.createBinaryLogisticTransformation} are appended to
     * {@code neuralLayers}; for classification with more than two categories the last layer's
     * normalization method is set to {@link NeuralNetwork.NormalizationMethod#SOFTMAX}.</p>
     *
     * @param miningFunction {@code CLASSIFICATION} or {@code REGRESSION}
     * @param neuralLayers   non-empty, mutable list of layers; may be extended by this call
     * @param label          the target label; must be a {@link CategoricalLabel} for classification
     * @return the neural outputs wired to the (possibly newly appended) last layer
     * @throws IllegalArgumentException if {@code neuralLayers} is empty, the mining function is not
     *         supported, the label is not categorical for classification, or the categorical label
     *         has fewer than two categories
     */
    public static NeuralOutputs encodeNeuralOutputs(MiningFunction miningFunction, List<NeuralLayer> neuralLayers, Label label) {
        if (neuralLayers.isEmpty()) {
            throw new IllegalArgumentException("Expected at least one neural layer");
        }
        NeuralLayer neuralLayer = Iterables.getLast(neuralLayers);
        neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);
        List<Neuron> entities = neuralLayer.getNeurons();
        switch (miningFunction) {
            case CLASSIFICATION: {
                if (!(label instanceof CategoricalLabel)) {
                    throw new IllegalArgumentException("Classification requires a categorical label");
                }
                CategoricalLabel categoricalLabel = (CategoricalLabel)label;
                if (categoricalLabel.size() == 2) {
                    // Binary case: the single output neuron is expanded by extra transformation layers.
                    List<NeuralLayer> transformationNeuralLayers = NeuralNetworkUtil.createBinaryLogisticTransformation(Iterables.getOnlyElement(entities));
                    neuralLayers.addAll(transformationNeuralLayers);
                    neuralLayer = Iterables.getLast(transformationNeuralLayers);
                    entities = neuralLayer.getNeurons();
                } else if (categoricalLabel.size() > 2) {
                    neuralLayer.setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
                } else {
                    throw new IllegalArgumentException("Expected two or more target categories, got " + categoricalLabel.size());
                }
                return NeuralNetworkUtil.createClassificationNeuralOutputs(entities, categoricalLabel);
            }
            case REGRESSION: {
                return NeuralNetworkUtil.createRegressionNeuralOutputs(entities, label);
            }
            default: {
                throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
            }
        }
    }

    /**
     * Maps an activation function name to the corresponding PMML activation function.
     *
     * <p>Recognized names: {@code "identity"}, {@code "logistic"}, {@code "relu"} and {@code "tanh"}.</p>
     *
     * @param activation the activation function name
     * @return the matching {@link NeuralNetwork.ActivationFunction}
     * @throws IllegalArgumentException if {@code activation} is {@code null} or not one of the recognized names
     */
    public static NeuralNetwork.ActivationFunction parseActivationFunction(String activation) {
        // A String switch throws NullPointerException on null; report it as a bad argument instead.
        if (activation == null) {
            throw new IllegalArgumentException("Activation function name must not be null");
        }
        switch (activation) {
            case "identity":
                return NeuralNetwork.ActivationFunction.IDENTITY;
            case "logistic":
                return NeuralNetwork.ActivationFunction.LOGISTIC;
            case "relu":
                return NeuralNetwork.ActivationFunction.RECTIFIER;
            case "tanh":
                return NeuralNetwork.ActivationFunction.TANH;
            default:
                throw new IllegalArgumentException(activation);
        }
    }
}

