/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Normalizer that linearly rescales image pixel values from {@code [0, maxPixelVal]} into
 * {@code [minRange, maxRange]}, where {@code maxPixelVal = 2^maxBits - 1}.
 *
 * <p>The mapping depends only on the configured constants, so {@link #fit(DataSet)} and
 * {@link #fit(DataSetIterator)} are no-ops. Features are always transformed; labels are
 * transformed by {@link #preProcess(DataSet)} / {@link #revert(DataSet)} only when label fitting
 * has been enabled via {@link #fitLabel(boolean)}. The label-specific entry points are meant for
 * segmentation use cases and require rank-4 labels.
 *
 * <p>All transforms modify the supplied arrays in place and return nothing.
 */
public class ImagePreProcessingScaler
implements DataNormalization {
    private static final Logger log = LoggerFactory.getLogger(ImagePreProcessingScaler.class);
    /** Lower bound of the output range. */
    private double minRange;
    /** Upper bound of the output range. */
    private double maxRange;
    /** Largest expected input value; inputs are divided by this before being rescaled. */
    private double maxPixelVal;
    /** Bit depth the constructor used to derive {@link #maxPixelVal}. */
    private int maxBits;
    /** Whether the DataSet-level preProcess/revert also process labels. */
    private boolean fitLabels = false;

    /** Creates a scaler that maps 8-bit pixel values ({@code [0, 255]}) to {@code [0, 1]}. */
    public ImagePreProcessingScaler() {
        this(0.0, 1.0, 8);
    }

    /**
     * Creates a scaler that maps 8-bit pixel values ({@code [0, 255]}) to {@code [a, b]}.
     *
     * @param a lower bound of the output range
     * @param b upper bound of the output range
     */
    public ImagePreProcessingScaler(double a, double b) {
        this(a, b, 8);
    }

    /**
     * Creates a scaler that maps pixel values in {@code [0, 2^maxBits - 1]} to {@code [a, b]}.
     *
     * @param a       lower bound of the output range
     * @param b       upper bound of the output range
     * @param maxBits bit depth of the input pixels; pass 1 when inputs are already in {@code [0, 1]}
     *                (giving a divisor of 1)
     */
    public ImagePreProcessingScaler(double a, double b, int maxBits) {
        this.maxPixelVal = Math.pow(2.0, maxBits) - 1.0;
        this.minRange = a;
        this.maxRange = b;
        // Record the bit depth so getMaxBits(), equals() and hashCode() reflect the configuration.
        this.maxBits = maxBits;
    }

    /** No-op: the scaling uses only the configured constants, never statistics from the data. */
    @Override
    public void fit(DataSet dataSet) {
    }

    /** No-op: the scaling uses only the configured constants, never statistics from the data. */
    @Override
    public void fit(DataSetIterator iterator) {
    }

    /**
     * Scales the features of {@code toPreProcess} in place, and also its labels when label fitting is
     * enabled and labels are present.
     *
     * @param toPreProcess the data set to modify
     */
    @Override
    public void preProcess(DataSet toPreProcess) {
        INDArray features = toPreProcess.getFeatures();
        this.preProcess(features);
        if (this.fitLabels && toPreProcess.getLabels() != null) {
            this.preProcess(toPreProcess.getLabels());
        }
    }

    /**
     * Applies {@code x -> x / maxPixelVal * (maxRange - minRange) + minRange} in place.
     *
     * @param features array to scale; modified in place
     */
    public void preProcess(INDArray features) {
        features.divi(this.maxPixelVal);
        // Skip the multiply/add steps when they would be identity operations.
        if (this.maxRange - this.minRange != 1.0) {
            features.muli(this.maxRange - this.minRange);
        }
        if (this.minRange != 0.0) {
            features.addi(this.minRange);
        }
    }

    /** Same as {@link #preProcess(DataSet)}. */
    @Override
    public void transform(DataSet toPreProcess) {
        this.preProcess(toPreProcess);
    }

    /** Same as {@link #preProcess(INDArray)}. */
    @Override
    public void transform(INDArray features) {
        this.preProcess(features);
    }

    /** Same as {@link #transform(INDArray)}; the mask is ignored. */
    @Override
    public void transform(INDArray features, INDArray featuresMask) {
        this.transform(features);
    }

    /**
     * Scales segmentation labels in place using the same mapping as the features.
     *
     * @param label rank-4 label array
     * @throws IllegalStateException if {@code label} is null or not rank 4
     */
    @Override
    public void transformLabel(INDArray label) {
        Preconditions.checkState(label != null && label.rank() == 4, "Labels can only be transformed for segmentation use cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", (Object)label);
        this.transform(label);
    }

    /** Same as {@link #transformLabel(INDArray)}; the mask is ignored. */
    @Override
    public void transformLabel(INDArray labels, INDArray labelsMask) {
        this.transformLabel(labels);
    }

    /**
     * Undoes {@link #preProcess(DataSet)} in place.
     *
     * <p>Features are always reverted. Labels are reverted only when label fitting is enabled and
     * labels are present, exactly mirroring the forward transform. Reverting labels that were never
     * scaled would corrupt them. Routing every label array through the rank-checked
     * {@link #revertLabels(INDArray)} would also throw for null or non-rank-4 (e.g. classification)
     * labels.
     *
     * @param toRevert the data set to restore
     */
    @Override
    public void revert(DataSet toRevert) {
        this.revertFeatures(toRevert.getFeatures());
        INDArray labels = toRevert.getLabels();
        if (this.fitLabels && labels != null) {
            // Raw inverse of the preProcess(INDArray) call that preProcess(DataSet) applied to labels.
            this.revertFeatures(labels);
        }
    }

    @Override
    public NormalizerType getType() {
        return NormalizerType.IMAGE_MIN_MAX;
    }

    /**
     * Applies the inverse mapping {@code x -> (x - minRange) / (maxRange - minRange) * maxPixelVal}
     * in place.
     *
     * @param features array to restore; modified in place
     */
    @Override
    public void revertFeatures(INDArray features) {
        // Undo the forward steps in reverse order, skipping identity operations.
        if (this.minRange != 0.0) {
            features.subi(this.minRange);
        }
        if (this.maxRange - this.minRange != 1.0) {
            features.divi(this.maxRange - this.minRange);
        }
        features.muli(this.maxPixelVal);
    }

    /** Same as {@link #revertFeatures(INDArray)}; the mask is ignored. */
    @Override
    public void revertFeatures(INDArray features, INDArray featuresMask) {
        this.revertFeatures(features);
    }

    /**
     * Restores segmentation labels in place using the inverse feature mapping.
     *
     * @param labels rank-4 label array
     * @throws IllegalStateException if {@code labels} is null or not rank 4
     */
    @Override
    public void revertLabels(INDArray labels) {
        Preconditions.checkState(labels != null && labels.rank() == 4, "Labels can only be transformed for segmentation use cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", (Object)labels);
        this.revertFeatures(labels);
    }

    /** Same as {@link #revertLabels(INDArray)}; the mask is ignored. */
    @Override
    public void revertLabels(INDArray labels, INDArray labelsMask) {
        this.revertLabels(labels);
    }

    /** Enables or disables scaling of labels in the DataSet-level preProcess/revert. */
    @Override
    public void fitLabel(boolean fitLabels) {
        this.fitLabels = fitLabels;
    }

    /** Returns whether labels are scaled by the DataSet-level preProcess/revert. */
    @Override
    public boolean isFitLabel() {
        return this.fitLabels;
    }

    public double getMinRange() {
        return this.minRange;
    }

    public double getMaxRange() {
        return this.maxRange;
    }

    public double getMaxPixelVal() {
        return this.maxPixelVal;
    }

    public int getMaxBits() {
        return this.maxBits;
    }

    public boolean isFitLabels() {
        return this.fitLabels;
    }

    public void setMinRange(double minRange) {
        this.minRange = minRange;
    }

    public void setMaxRange(double maxRange) {
        this.maxRange = maxRange;
    }

    /** Sets the divisor used by the transforms directly; does not update {@link #getMaxBits()}. */
    public void setMaxPixelVal(double maxPixelVal) {
        this.maxPixelVal = maxPixelVal;
    }

    /** Sets the recorded bit depth only; does not recompute {@link #getMaxPixelVal()}. */
    public void setMaxBits(int maxBits) {
        this.maxBits = maxBits;
    }

    public void setFitLabels(boolean fitLabels) {
        this.fitLabels = fitLabels;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ImagePreProcessingScaler)) {
            return false;
        }
        ImagePreProcessingScaler other = (ImagePreProcessingScaler)o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Double.compare(this.getMinRange(), other.getMinRange()) == 0
                && Double.compare(this.getMaxRange(), other.getMaxRange()) == 0
                && Double.compare(this.getMaxPixelVal(), other.getMaxPixelVal()) == 0
                && this.getMaxBits() == other.getMaxBits()
                && this.isFitLabels() == other.isFitLabels();
    }

    /** Symmetry hook for subclasses that override {@link #equals(Object)}. */
    protected boolean canEqual(Object other) {
        return other instanceof ImagePreProcessingScaler;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        long minBits = Double.doubleToLongBits(this.getMinRange());
        result = result * prime + (int)(minBits ^ (minBits >>> 32));
        long maxBitsOfRange = Double.doubleToLongBits(this.getMaxRange());
        result = result * prime + (int)(maxBitsOfRange ^ (maxBitsOfRange >>> 32));
        long pixelBits = Double.doubleToLongBits(this.getMaxPixelVal());
        result = result * prime + (int)(pixelBits ^ (pixelBits >>> 32));
        result = result * prime + this.getMaxBits();
        result = result * prime + (this.isFitLabels() ? 79 : 97);
        return result;
    }
}

