/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.reduce.bp;

import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * Base class for the backprop (gradient) ops of reduction operations.
 *
 * <p>The inputs are the original reduction input(s) followed by the gradient at the
 * reduction's output. The {@code keepDims} flag and the reduction {@code dimensions} are
 * forwarded to the native op as a T argument and as I arguments respectively (see
 * {@link #addArgs()}).
 */
public abstract class BaseReductionBp
extends DynamicCustomOp {
    /** Whether the forward reduction kept reduced dimensions as size-1 axes; sent as T-arg 1.0/0.0. */
    protected boolean keepDims;
    /** Reduction dimensions; stored as passed in (no defensive copy) and sent as I-args by {@link #addArgs()}. */
    protected int[] dimensions;

    /** No-arg constructor; leaves {@code keepDims} false and {@code dimensions} null. */
    public BaseReductionBp() {
    }

    /**
     * SameDiff constructor for a reduction with a single original input.
     *
     * @param sameDiff     owning SameDiff instance
     * @param origInput    input of the forward reduction
     * @param gradAtOutput gradient with respect to the forward reduction's output
     * @param keepDims     keepDims setting of the forward reduction
     * @param dimensions   reduction dimensions
     */
    public BaseReductionBp(SameDiff sameDiff, SDVariable origInput, SDVariable gradAtOutput, boolean keepDims, int ... dimensions) {
        super(null, sameDiff, new SDVariable[]{origInput, gradAtOutput}, false);
        this.keepDims = keepDims;
        this.dimensions = dimensions;
        this.addArgs();
    }

    /**
     * SameDiff constructor for a reduction with two original inputs.
     *
     * @param sameDiff     owning SameDiff instance
     * @param origInput1   first input of the forward reduction
     * @param origInput2   second input of the forward reduction
     * @param gradAtOutput gradient with respect to the forward reduction's output
     * @param keepDims     keepDims setting of the forward reduction
     * @param dimensions   reduction dimensions
     */
    public BaseReductionBp(SameDiff sameDiff, SDVariable origInput1, SDVariable origInput2, SDVariable gradAtOutput, boolean keepDims, int ... dimensions) {
        super(null, sameDiff, new SDVariable[]{origInput1, origInput2, gradAtOutput}, false);
        this.keepDims = keepDims;
        this.dimensions = dimensions;
        this.addArgs();
    }

    /**
     * Array constructor for a reduction with a single original input.
     *
     * @param origInput    input of the forward reduction
     * @param gradAtOutput gradient with respect to the forward reduction's output
     * @param output       optional output array; when {@code null}, no output arrays are supplied
     * @param keepDims     keepDims setting of the forward reduction
     * @param dimensions   reduction dimensions
     */
    public BaseReductionBp(INDArray origInput, INDArray gradAtOutput, INDArray output, boolean keepDims, int ... dimensions) {
        // The explicit super(...) call must be the first statement, so the optional output
        // array is built inside the argument list rather than in preceding statements.
        super(null, new INDArray[]{origInput, gradAtOutput}, singleOutput(output));
        this.keepDims = keepDims;
        this.dimensions = dimensions;
        this.addArgs();
    }

    /**
     * Array constructor for a reduction with two original inputs and one output.
     *
     * @param origInput1   first input of the forward reduction
     * @param origInput2   second input of the forward reduction
     * @param gradAtOutput gradient with respect to the forward reduction's output
     * @param output       optional output array; when {@code null}, no output arrays are supplied
     * @param keepDims     keepDims setting of the forward reduction
     * @param dimensions   reduction dimensions
     */
    public BaseReductionBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput, INDArray output, boolean keepDims, int ... dimensions) {
        super(null, new INDArray[]{origInput1, origInput2, gradAtOutput}, singleOutput(output));
        this.keepDims = keepDims;
        this.dimensions = dimensions;
        this.addArgs();
    }

    /**
     * Array constructor for a reduction with two original inputs and two outputs.
     *
     * <p>Both output slots are always passed, even if one or both are {@code null}.
     *
     * @param origInput1   first input of the forward reduction
     * @param origInput2   second input of the forward reduction
     * @param gradAtOutput gradient with respect to the forward reduction's output
     * @param output1      first output array
     * @param output2      second output array
     * @param keepDims     keepDims setting of the forward reduction
     * @param dimensions   reduction dimensions
     */
    public BaseReductionBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput, INDArray output1, INDArray output2, boolean keepDims, int ... dimensions) {
        super(null, new INDArray[]{origInput1, origInput2, gradAtOutput}, new INDArray[]{output1, output2});
        this.keepDims = keepDims;
        this.dimensions = dimensions;
        this.addArgs();
    }

    /** Wraps a single optional output array; returns {@code null} when {@code output} is {@code null}. */
    private static INDArray[] singleOutput(INDArray output) {
        return output == null ? null : new INDArray[]{output};
    }

    /**
     * Pushes this op's arguments onto the op.
     *
     * <p>Always adds {@code keepDims} as a T argument (1.0 for true, 0.0 for false). Adds
     * {@code dimensions} as I arguments unless the array is null, empty, or exactly
     * {@code {Integer.MAX_VALUE}}. That value is presumably a sentinel for "all dimensions";
     * confirm against the native op.
     *
     * <p>NOTE(review): this method is invoked from every non-trivial constructor, so any
     * subclass override runs before the subclass's own fields are initialized.
     */
    protected void addArgs() {
        this.addTArgument(this.keepDims ? 1.0 : 0.0);
        boolean hasDims = this.dimensions != null && this.dimensions.length > 0;
        boolean isAllDimsSentinel = hasDims && this.dimensions.length == 1 && this.dimensions[0] == Integer.MAX_VALUE;
        if (hasDims && !isAllDimsSentinel) {
            this.addIArgument(this.dimensions);
        }
    }

    @Override
    public abstract String opName();

    /**
     * Validates the two input datatypes and returns the output datatype.
     *
     * <p>Expects exactly two floating-point inputs (the original input and the gradient at the
     * reduction output). The single output has the datatype of the first input. Subclasses
     * constructed with three inputs are expected to override this method.
     *
     * @param dataTypes input datatypes
     * @return a singleton list containing the first input's datatype
     * @throws IllegalStateException if the count or types do not match (via {@code Preconditions.checkState})
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatype for %s, got input %s", this.getClass(), dataTypes);
        Preconditions.checkState(dataTypes.get(0).isFPType(), "First input must be a floating point type, got %s", dataTypes.get(0));
        Preconditions.checkState(dataTypes.get(1).isFPType(), "Second input (gradient at reduction output) must be a floating point type, got %s", dataTypes.get(1));
        return Collections.singletonList(dataTypes.get(0));
    }
}

