/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.scoring.util;

import org.deeplearning4j.arbiter.scoring.RegressionValue;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIteratorFactory;

public class ScoreUtil {

    /**
     * Converts the given object into a {@link MultiDataSetIterator}.
     *
     * <p>A {@link MultiDataSetIterator} is returned as-is and a {@link MultiDataSetIteratorFactory}
     * is asked to {@code create()} one. A {@link DataSetIterator} is wrapped in a
     * {@link MultiDataSetIteratorAdapter}, and so is the result of a {@link DataSetIteratorFactory}.
     *
     * @param o the iterator or iterator factory
     * @return a multi-data-set iterator over the supplied data
     * @throws IllegalArgumentException if {@code o} is null or of an unsupported type
     */
    public static MultiDataSetIterator getMultiIterator(Object o) {
        if (o instanceof MultiDataSetIterator) {
            return (MultiDataSetIterator) o;
        }
        if (o instanceof MultiDataSetIteratorFactory) {
            return ((MultiDataSetIteratorFactory) o).create();
        }
        if (o instanceof DataSetIterator) {
            return new MultiDataSetIteratorAdapter((DataSetIterator) o);
        }
        if (o instanceof DataSetIteratorFactory) {
            return new MultiDataSetIteratorAdapter(((DataSetIteratorFactory) o).create());
        }
        // The message lists every accepted type; the old text omitted the MultiDataSet variants.
        throw new IllegalArgumentException("Type must be one of MultiDataSetIterator, MultiDataSetIteratorFactory, "
                + "DataSetIterator or DataSetIteratorFactory; got " + typeName(o));
    }

    /**
     * Converts the given object into a {@link DataSetIterator}.
     *
     * @param o a {@link DataSetIterator} (returned as-is) or a {@link DataSetIteratorFactory}
     *          (its {@code create()} result is returned)
     * @return a data-set iterator over the supplied data
     * @throws IllegalArgumentException if {@code o} is null or of an unsupported type
     */
    public static DataSetIterator getIterator(Object o) {
        if (o instanceof DataSetIterator) {
            return (DataSetIterator) o;
        }
        if (o instanceof DataSetIteratorFactory) {
            return ((DataSetIteratorFactory) o).create();
        }
        throw new IllegalArgumentException(
                "Type must either be DataSetIterator or DataSetIteratorFactory; got " + typeName(o));
    }

    /**
     * Runs classification evaluation of a {@link MultiLayerNetwork} on the given data.
     *
     * @param model    the network to evaluate
     * @param testData the evaluation data
     * @return the result of {@code model.evaluate(testData)}
     */
    public static Evaluation getEvaluation(MultiLayerNetwork model, DataSetIterator testData) {
        return (Evaluation) model.evaluate(testData);
    }

    /**
     * Runs classification evaluation of a single-output {@link ComputationGraph} on multi-data-set input.
     *
     * @param model    the graph to evaluate; must have exactly one output array
     * @param testData the evaluation data
     * @return the result of {@code model.evaluate(testData)}
     * @throws IllegalStateException if the graph does not have exactly one output
     */
    public static Evaluation getEvaluation(ComputationGraph model, MultiDataSetIterator testData) {
        if (model.getNumOutputArrays() != 1) {
            throw new IllegalStateException("GraphSetSetAccuracyScoreFunction cannot be applied to ComputationGraphs with more than one output. NumOutputs = " + model.getNumOutputArrays());
        }
        return (Evaluation) model.evaluate(testData);
    }

    /**
     * Runs classification evaluation of a single-output {@link ComputationGraph} on data-set input.
     *
     * @param model    the graph to evaluate; must have exactly one output array
     * @param testData the evaluation data
     * @return the result of {@code model.evaluate(testData)}
     * @throws IllegalStateException if the graph does not have exactly one output
     */
    public static Evaluation getEvaluation(ComputationGraph model, DataSetIterator testData) {
        if (model.getNumOutputArrays() != 1) {
            throw new IllegalStateException("GraphSetSetAccuracyScoreFunctionDataSet cannot be applied to ComputationGraphs with more than one output. NumOutputs = " + model.getNumOutputArrays());
        }
        return (Evaluation) model.evaluate(testData);
    }

    /**
     * Computes the example-weighted loss score of a {@link ComputationGraph} over all remaining
     * mini-batches of {@code testData}.
     *
     * <p>The example count of each batch is taken from {@code getFeatures(0).size(0)}.
     *
     * @param model    the graph to score
     * @param testData the data; consumed from its current position (no reset is performed)
     * @param average  if true, return the per-example mean; otherwise the example-weighted sum
     * @return the summed or averaged score (NaN when averaging over zero examples)
     */
    public static double score(ComputationGraph model, MultiDataSetIterator testData, boolean average) {
        double sumScore = 0.0;
        long totalExamples = 0L;
        while (testData.hasNext()) {
            MultiDataSet ds = (MultiDataSet) testData.next();
            long numExamples = ds.getFeatures(0).size(0);
            // model.score(...) is treated as a per-example value, so weight it by the batch size.
            sumScore += (double) numExamples * model.score(ds);
            totalExamples += numExamples;
        }
        return sumOrMean(sumScore, totalExamples, average);
    }

    /**
     * Computes the example-weighted loss score of a {@link ComputationGraph} over all remaining
     * mini-batches of {@code testData}.
     *
     * @param model    the graph to score
     * @param testData the data; consumed from its current position (no reset is performed)
     * @param average  if true, return the per-example mean; otherwise the example-weighted sum
     * @return the summed or averaged score (NaN when averaging over zero examples)
     */
    public static double score(ComputationGraph model, DataSetIterator testData, boolean average) {
        double sumScore = 0.0;
        long totalExamples = 0L;
        while (testData.hasNext()) {
            org.nd4j.linalg.dataset.DataSet ds = (org.nd4j.linalg.dataset.DataSet) testData.next();
            int numExamples = ds.numExamples();
            sumScore += (double) numExamples * model.score((DataSet) ds);
            totalExamples += numExamples;
        }
        return sumOrMean(sumScore, totalExamples, average);
    }

    /**
     * Evaluates every output of a {@link ComputationGraph} with a {@link RegressionEvaluation} and
     * combines the results into one score.
     *
     * <p>For every metric except {@link RegressionValue#CorrCoeff}, the result is the sum of the
     * per-column metric over all columns of all outputs. For {@code CorrCoeff}, the result is the
     * mean R^2 over all columns of all outputs.
     *
     * @param model           the graph to evaluate
     * @param testSet         the data; consumed from its current position (no reset is performed)
     * @param regressionValue the metric to report
     * @return the combined regression score
     */
    public static double score(ComputationGraph model, MultiDataSetIterator testSet, RegressionValue regressionValue) {
        int nOutputs = model.getNumOutputArrays();
        RegressionEvaluation[] evaluations = new RegressionEvaluation[nOutputs];
        for (int i = 0; i < evaluations.length; ++i) {
            evaluations[i] = new RegressionEvaluation();
        }
        while (testSet.hasNext()) {
            evalRegressionBatch(model, (MultiDataSet) testSet.next(), evaluations);
        }

        double sum = 0.0;
        int totalColumns = 0;
        for (RegressionEvaluation evaluation : evaluations) {
            int nColumns = evaluation.numColumns();
            totalColumns += nColumns;
            double outputScore = ScoreUtil.getScoreFromRegressionEval(evaluation, regressionValue);
            if (regressionValue == RegressionValue.CorrCoeff) {
                // getScoreFromRegressionEval already averages R^2 over this output's columns.
                // Re-weight by the column count so the division below yields the mean over all
                // columns. Previously the mean was divided by the column count a second time.
                sum += outputScore * nColumns;
            } else {
                sum += outputScore;
            }
        }
        if (regressionValue == RegressionValue.CorrCoeff) {
            sum /= (double) totalColumns;
        }
        return sum;
    }

    /**
     * Computes a regression score of a {@link ComputationGraph} via {@code model.evaluateRegression}.
     *
     * @param model           the graph to evaluate
     * @param testSet         the evaluation data
     * @param regressionValue the metric to report
     * @return the score produced by {@link #getScoreFromRegressionEval}
     */
    public static double score(ComputationGraph model, DataSetIterator testSet, RegressionValue regressionValue) {
        RegressionEvaluation evaluation = (RegressionEvaluation) model.evaluateRegression(testSet);
        return ScoreUtil.getScoreFromRegressionEval(evaluation, regressionValue);
    }

    /**
     * Computes the example-weighted loss score of a {@link MultiLayerNetwork} over all remaining
     * mini-batches of {@code testData}.
     *
     * @param model    the network to score
     * @param testData the data; consumed from its current position (no reset is performed)
     * @param average  if true, return the per-example mean; otherwise the example-weighted sum
     * @return the summed or averaged score (NaN when averaging over zero examples)
     */
    public static double score(MultiLayerNetwork model, DataSetIterator testData, boolean average) {
        double sumScore = 0.0;
        long totalExamples = 0L;
        while (testData.hasNext()) {
            org.nd4j.linalg.dataset.DataSet ds = (org.nd4j.linalg.dataset.DataSet) testData.next();
            int numExamples = ds.numExamples();
            sumScore += (double) numExamples * model.score(ds);
            totalExamples += numExamples;
        }
        return sumOrMean(sumScore, totalExamples, average);
    }

    /**
     * Computes a regression score of a {@link MultiLayerNetwork} via {@code model.evaluateRegression}.
     *
     * @param model           the network to evaluate
     * @param testSet         the evaluation data
     * @param regressionValue the metric to report
     * @return the score produced by {@link #getScoreFromRegressionEval}
     */
    public static double score(MultiLayerNetwork model, DataSetIterator testSet, RegressionValue regressionValue) {
        RegressionEvaluation eval = (RegressionEvaluation) model.evaluateRegression(testSet);
        return ScoreUtil.getScoreFromRegressionEval(eval, regressionValue);
    }

    /**
     * Reduces a {@link RegressionEvaluation} to a single number.
     *
     * <p>MSE, MAE, RMSE and RSE are summed over all columns. CorrCoeff returns the mean
     * {@code correlationR2} over all columns (NaN when there are no columns). Any other value
     * yields 0.0.
     *
     * @param eval            the populated evaluation
     * @param regressionValue the metric to extract
     * @return the reduced score
     */
    @Deprecated
    public static double getScoreFromRegressionEval(RegressionEvaluation eval, RegressionValue regressionValue) {
        double sum = 0.0;
        int nColumns = eval.numColumns();
        switch (regressionValue) {
            case MSE: {
                for (int i = 0; i < nColumns; ++i) {
                    sum += eval.meanSquaredError(i);
                }
                break;
            }
            case MAE: {
                for (int i = 0; i < nColumns; ++i) {
                    sum += eval.meanAbsoluteError(i);
                }
                break;
            }
            case RMSE: {
                for (int i = 0; i < nColumns; ++i) {
                    sum += eval.rootMeanSquaredError(i);
                }
                break;
            }
            case RSE: {
                for (int i = 0; i < nColumns; ++i) {
                    sum += eval.relativeSquaredError(i);
                }
                break;
            }
            case CorrCoeff: {
                for (int i = 0; i < nColumns; ++i) {
                    sum += eval.correlationR2(i);
                }
                sum /= (double) nColumns;
                break;
            }
        }
        return sum;
    }

    /** Feeds one multi-data-set batch through the graph and accumulates it into the per-output evaluations. */
    private static void evalRegressionBatch(ComputationGraph model, MultiDataSet next, RegressionEvaluation[] evaluations) {
        INDArray[] labels = next.getLabels();
        if (next.hasMaskArrays()) {
            INDArray[] fMasks = next.getFeaturesMaskArrays();
            INDArray[] lMasks = next.getLabelsMaskArrays();
            model.setLayerMaskArrays(fMasks, lMasks);
            try {
                INDArray[] outputs = model.output(false, next.getFeatures());
                for (int i = 0; i < evaluations.length; ++i) {
                    if (lMasks != null && lMasks[i] != null) {
                        evaluations[i].evalTimeSeries(labels[i], outputs[i], lMasks[i]);
                    } else {
                        evaluations[i].evalTimeSeries(labels[i], outputs[i]);
                    }
                }
            } finally {
                // Always clear the masks so a failed forward pass does not leave them set on the model.
                model.clearLayerMaskArrays();
            }
            return;
        }
        INDArray[] outputs = model.output(false, next.getFeatures());
        for (int i = 0; i < evaluations.length; ++i) {
            if (labels[i].rank() == 3) {
                evaluations[i].evalTimeSeries(labels[i], outputs[i]);
            } else {
                evaluations[i].eval(labels[i], outputs[i]);
            }
        }
    }

    /** Returns {@code sumScore}, or its per-example mean when {@code average} is true. */
    private static double sumOrMean(double sumScore, long totalExamples, boolean average) {
        return average ? sumScore / (double) totalExamples : sumScore;
    }

    /** Returns the runtime class name of {@code o} for error messages, or "null". */
    private static String typeName(Object o) {
        return o == null ? "null" : o.getClass().getName();
    }
}

