/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.utils;

import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KerasLossUtils {
    // Not referenced anywhere in this class at present.
    private static final Logger log = LoggerFactory.getLogger(KerasLossUtils.class);

    /**
     * Translates a Keras loss-function name into the matching DL4J {@link LossFunctions.LossFunction}.
     *
     * <p>{@code kerasLoss} is compared with {@link String#equals} (an exact, case-sensitive match)
     * against the loss-name constants supplied by {@code conf}, in a fixed order; the first match wins.
     * Where {@code conf} exposes both a long name and an abbreviation for the same loss (for example
     * mean squared error and MSE), either name is accepted.
     *
     * <p>Mapping performed:
     * <ul>
     *   <li>mean squared error / MSE &rarr; {@code SQUARED_LOSS}</li>
     *   <li>mean absolute error / MAE &rarr; {@code MEAN_ABSOLUTE_ERROR}</li>
     *   <li>mean absolute percentage error / MAPE &rarr; {@code MEAN_ABSOLUTE_PERCENTAGE_ERROR}</li>
     *   <li>mean squared logarithmic error / MSLE &rarr; {@code MEAN_SQUARED_LOGARITHMIC_ERROR}</li>
     *   <li>squared hinge &rarr; {@code SQUARED_HINGE}</li>
     *   <li>hinge &rarr; {@code HINGE}</li>
     *   <li>binary cross-entropy &rarr; {@code XENT}</li>
     *   <li>categorical cross-entropy &rarr; {@code MCXENT}</li>
     *   <li>Kullback-Leibler divergence / KLD &rarr; {@code KL_DIVERGENCE}</li>
     *   <li>Poisson &rarr; {@code POISSON}</li>
     *   <li>cosine proximity &rarr; {@code COSINE_PROXIMITY}</li>
     * </ul>
     *
     * @param kerasLoss the Keras loss name to translate; a {@code null} value causes a
     *                  {@link NullPointerException} because it is used as the receiver of {@code equals}
     * @param conf      configuration object whose getters supply the recognised Keras loss names
     * @return the DL4J loss function that corresponds to {@code kerasLoss}
     * @throws UnsupportedKerasConfigurationException if {@code kerasLoss} is the sparse categorical
     *         cross-entropy name (recognised but not mapped here), or if it matches none of the names above
     */
    public static LossFunctions.LossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
            throws UnsupportedKerasConfigurationException {
        // Each check short-circuits, so the alias getter is only consulted when the long name fails.
        if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_SQUARED_ERROR())
                || kerasLoss.equals(conf.getKERAS_LOSS_MSE())) {
            return LossFunctions.LossFunction.SQUARED_LOSS;
        }
        if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_ABSOLUTE_ERROR())
                || kerasLoss.equals(conf.getKERAS_LOSS_MAE())) {
            return LossFunctions.LossFunction.MEAN_ABSOLUTE_ERROR;
        }
        if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR())
                || kerasLoss.equals(conf.getKERAS_LOSS_MAPE())) {
            return LossFunctions.LossFunction.MEAN_ABSOLUTE_PERCENTAGE_ERROR;
        }
        if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR())
                || kerasLoss.equals(conf.getKERAS_LOSS_MSLE())) {
            return LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR;
        }
        if (kerasLoss.equals(conf.getKERAS_LOSS_SQUARED_HINGE())) {
            return LossFunctions.LossFunction.SQUARED_HINGE;
        }
        if (kerasLoss.equals(conf.getKERAS_LOSS_HINGE())) {
            return LossFunctions.LossFunction.HINGE;
        }

        // Recognised name with no DL4J counterpart in this method: reject explicitly rather than
        // falling through to the generic "unknown" error.
        if (kerasLoss.equals(conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY())) {
            throw new UnsupportedKerasConfigurationException("Loss function " + kerasLoss + " not supported yet.");
        }

        if (kerasLoss.equals(conf.getKERAS_LOSS_BINARY_CROSSENTROPY())) {
            return LossFunctions.LossFunction.XENT;
        }
        if (kerasLoss.equals(conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) {
            return LossFunctions.LossFunction.MCXENT;
        }
        if (kerasLoss.equals(conf.getKERAS_LOSS_KULLBACK_LEIBLER_DIVERGENCE())
                || kerasLoss.equals(conf.getKERAS_LOSS_KLD())) {
            return LossFunctions.LossFunction.KL_DIVERGENCE;
        }
        if (kerasLoss.equals(conf.getKERAS_LOSS_POISSON())) {
            return LossFunctions.LossFunction.POISSON;
        }
        if (kerasLoss.equals(conf.getKERAS_LOSS_COSINE_PROXIMITY())) {
            return LossFunctions.LossFunction.COSINE_PROXIMITY;
        }

        // No known name matched.
        throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
    }
}

