/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.scoring.impl;

import lombok.NonNull;
import org.deeplearning4j.arbiter.scoring.impl.BaseNetScoreFunction;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

public class RegressionScoreFunction
extends BaseNetScoreFunction {
    protected RegressionEvaluation.Metric metric;

    public RegressionScoreFunction(@NonNull RegressionEvaluation.Metric metric) {
        this(metric.toNd4j());
        if (metric == null) {
            throw new NullPointerException("metric is marked non-null but is null");
        }
    }

    public RegressionScoreFunction(@NonNull RegressionEvaluation.Metric metric) {
        if (metric == null) {
            throw new NullPointerException("metric is marked non-null but is null");
        }
        this.metric = metric;
    }

    public boolean minimize() {
        switch (this.metric) {
            case MSE: 
            case MAE: 
            case RMSE: 
            case RSE: {
                return true;
            }
            case PC: 
            case R2: {
                return false;
            }
        }
        throw new IllegalStateException("Unknown metric: " + this.metric);
    }

    public String toString() {
        return "RegressionScoreFunction(metric=" + this.metric + ")";
    }

    @Override
    public double score(MultiLayerNetwork net, DataSetIterator iterator) {
        RegressionEvaluation e = net.evaluateRegression(iterator);
        return e.scoreForMetric(this.metric);
    }

    @Override
    public double score(MultiLayerNetwork net, MultiDataSetIterator iterator) {
        return this.score(net, (DataSetIterator)new MultiDataSetWrapperIterator(iterator));
    }

    @Override
    public double score(ComputationGraph graph, DataSetIterator iterator) {
        RegressionEvaluation e = graph.evaluateRegression(iterator);
        return e.scoreForMetric(this.metric);
    }

    @Override
    public double score(ComputationGraph graph, MultiDataSetIterator iterator) {
        RegressionEvaluation e = graph.evaluateRegression(iterator);
        return e.scoreForMetric(this.metric);
    }

    public RegressionEvaluation.Metric getMetric() {
        return this.metric;
    }

    public void setMetric(RegressionEvaluation.Metric metric) {
        this.metric = metric;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof RegressionScoreFunction)) {
            return false;
        }
        RegressionScoreFunction other = (RegressionScoreFunction)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        RegressionEvaluation.Metric this$metric = this.getMetric();
        RegressionEvaluation.Metric other$metric = other.getMetric();
        return !(this$metric == null ? other$metric != null : !this$metric.equals(other$metric));
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof RegressionScoreFunction;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = super.hashCode();
        RegressionEvaluation.Metric $metric = this.getMetric();
        result = result * 59 + ($metric == null ? 43 : $metric.hashCode());
        return result;
    }

    protected RegressionScoreFunction() {
    }
}

