package com.github.psambit9791.jdsp.transform;

import com.github.psambit9791.jdsp.misc.Random;
import com.github.psambit9791.jdsp.misc.UtilMethods;
import java.util.Arrays;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.MathArrays;

/**
 * Independent Component Analysis (ICA) based on the deflationary FastICA algorithm.
 *
 * <p>The input signal is laid out as {@code [n_samples][n_channels]}: every row is one observation
 * and every column one mixed channel. The number of estimated components equals the number of
 * channels. Components are extracted one at a time; after each fixed-point update the candidate
 * vector is decorrelated (Gram-Schmidt style) from the components already extracted.
 *
 * <p>Typical use: construct, call {@link #fit()}, then read the estimated sources with
 * {@link #transform()} or project new data with {@link #transform(double[][])}. The class holds
 * mutable state and performs no synchronization.
 */
public class ICA {
    private static final double DEFAULT_ALPHA = 1.0d;
    private static final int DEFAULT_MAX_ITER = 200;
    private static final double DEFAULT_TOL = 1.0E-4d;
    private static final String DEFAULT_WHITEN = "unit-variance";
    private static final String DEFAULT_FUNC = "logcosh";
    private static final long DEFAULT_SEED = 42L;

    private static final String NOT_FITTED_MESSAGE = "Execute fit() before calling this function";
    private static final String W_INIT_SHAPE_MESSAGE =
            "w_init should be a square matrix and the shape should be same as the number of components in signal";

    /** Input signal, {@code [n_samples][n_channels]}. */
    private double[][] signal;

    /**
     * Transposed input, {@code [n_channels][n_samples]}, populated by {@link #fit()}. Each channel is
     * zero-centred when whitening is enabled; otherwise it is the raw transposed signal.
     */
    public double[][] zm_signal;

    /** Scaling factor of the logcosh contrast ({@code tanh(alpha * x)}); validated to lie in [1, 2] for "logcosh". */
    private double alpha;

    /** Contrast function g evaluated at every projected sample; overwritten on each fixed-point iteration. */
    public double[] gx;

    /** Mean of the contrast derivative g' over the projected samples; overwritten on each fixed-point iteration. */
    public double g_x;

    /** Initial un-mixing matrix, {@code [n_channels][n_channels]}; row j seeds component j. */
    public double[][] w_init;

    /** Maximum number of fixed-point iterations per component. */
    private int max_iter;

    /** Convergence tolerance on {@code | |w_new . w_old| - 1 |}. */
    private double tol;

    /** Whitening strategy: "unit-variance", "arbitrary-variance" or "" (no whitening). */
    private String whiten;

    /** Contrast function name: "logcosh", "exp" or "cube". */
    private String func;

    /** Seed used to draw {@link #w_init} when it is not supplied by the caller. */
    private long seed;

    /** Per-channel means subtracted during {@link #fit()}; only set when whitening is enabled. */
    private double[] mean_;

    /** Number of channels (= number of estimated components). */
    private int components;

    /** Largest number of fixed-point iterations used by any component in the last fit; -1 before fitting. */
    private int n_iter = -1;

    /** Pseudo-inverse of the component matrix; set by {@link #fit()}. */
    public double[][] mixingMatrix = null;

    /** Un-mixing matrix W found by FastICA (row-rescaled for "unit-variance"); set by {@link #fit()}. */
    public double[][] unmixingMatrix = null;

    /** Whitening matrix K; set by {@link #fit()} only when whitening is enabled. */
    public double[][] whiteningMatrix = null;

    /** Full projection applied by {@link #transform(double[][])}: {@code W.K} when whitening, else {@code W}. */
    private double[][] componentMatrix = null;

    /** Estimated sources of the fitted signal, {@code [n_samples][n_components]}. */
    private double[][] sources = null;

    /**
     * Canonical constructor: validates and stores every option. All public constructors delegate here.
     *
     * @param signal input signal, {@code [n_samples][n_channels]}
     * @param func contrast function name
     * @param whiten whitening strategy
     * @param wInit initial un-mixing matrix, or {@code null} to draw one from a normal distribution using {@code seed}
     * @param maxIter maximum fixed-point iterations per component
     * @param tol convergence tolerance
     * @param alpha logcosh scaling factor
     * @param seed random seed used only when {@code wInit} is {@code null}
     * @throws IllegalArgumentException if any argument is invalid
     */
    private ICA(double[][] signal, String func, String whiten, double[][] wInit, int maxIter,
                double tol, double alpha, long seed) {
        if (signal == null || signal.length == 0 || signal[0] == null || signal[0].length == 0) {
            throw new IllegalArgumentException("signal must be a non-empty 2D array");
        }
        this.signal = signal;
        this.components = signal[0].length;
        if (wInit != null && !isSquare(wInit, this.components)) {
            throw new IllegalArgumentException(W_INIT_SHAPE_MESSAGE);
        }
        if (!func.equals("logcosh") && !func.equals("exp") && !func.equals("cube")) {
            throw new IllegalArgumentException("func should be one of logcosh, exp or cube");
        }
        // Written as a negated range test so that NaN is rejected as well.
        if (func.equals("logcosh") && !(alpha >= 1.0d && alpha <= 2.0d)) {
            throw new IllegalArgumentException("alpha should be between 1 and 2");
        }
        if (!whiten.equals("unit-variance") && !whiten.equals("arbitrary-variance") && !whiten.isEmpty()) {
            throw new IllegalArgumentException("whiten must be one of \"unit-variance\", \"arbitrary-variance\" or an empty string. ");
        }
        this.func = func;
        this.whiten = whiten;
        this.max_iter = maxIter;
        this.tol = tol;
        this.alpha = alpha;
        this.seed = seed;
        this.w_init = (wInit != null)
                ? wInit
                : new Random(this.seed).randomNormal2D(new int[] {this.components, this.components});
    }

    /**
     * Creates an ICA instance with every option specified and a caller-supplied initial un-mixing matrix.
     *
     * @param signal input signal, {@code [n_samples][n_channels]}
     * @param func contrast function: "logcosh", "exp" or "cube"
     * @param whiten "unit-variance", "arbitrary-variance" or "" (no whitening)
     * @param wInit initial un-mixing matrix, {@code [n_channels][n_channels]}
     * @param maxIter maximum fixed-point iterations per component
     * @param tol convergence tolerance
     * @param alpha logcosh scaling factor; must be in [1, 2] when {@code func} is "logcosh"
     * @throws IllegalArgumentException if any argument is invalid
     */
    public ICA(double[][] signal, String func, String whiten, double[][] wInit, int maxIter, double tol, double alpha) {
        this(signal, func, whiten, requireWInit(wInit), maxIter, tol, alpha, DEFAULT_SEED);
    }

    /**
     * Creates an ICA instance with every option specified and a randomly drawn initial un-mixing matrix.
     *
     * @param signal input signal, {@code [n_samples][n_channels]}
     * @param func contrast function: "logcosh", "exp" or "cube"
     * @param whiten "unit-variance", "arbitrary-variance" or "" (no whitening)
     * @param maxIter maximum fixed-point iterations per component
     * @param tol convergence tolerance
     * @param alpha logcosh scaling factor; must be in [1, 2] when {@code func} is "logcosh"
     * @param seed seed for drawing the initial un-mixing matrix
     * @throws IllegalArgumentException if any argument is invalid
     */
    public ICA(double[][] signal, String func, String whiten, int maxIter, double tol, double alpha, long seed) {
        this(signal, func, whiten, null, maxIter, tol, alpha, seed);
    }

    /**
     * Creates an ICA instance with "unit-variance" whitening and the default tolerance.
     *
     * @param signal input signal, {@code [n_samples][n_channels]}
     * @param func contrast function: "logcosh", "exp" or "cube"
     * @param maxIter maximum fixed-point iterations per component
     * @param alpha logcosh scaling factor; must be in [1, 2] when {@code func} is "logcosh"
     * @param seed seed for drawing the initial un-mixing matrix
     * @throws IllegalArgumentException if any argument is invalid
     */
    public ICA(double[][] signal, String func, int maxIter, double alpha, long seed) {
        this(signal, func, DEFAULT_WHITEN, null, maxIter, DEFAULT_TOL, alpha, seed);
    }

    /**
     * Creates an ICA instance with "unit-variance" whitening and a caller-supplied initial un-mixing matrix.
     *
     * @param signal input signal, {@code [n_samples][n_channels]}
     * @param func contrast function: "logcosh", "exp" or "cube"
     * @param wInit initial un-mixing matrix, {@code [n_channels][n_channels]}
     * @param maxIter maximum fixed-point iterations per component
     * @param tol convergence tolerance
     * @param alpha logcosh scaling factor; must be in [1, 2] when {@code func} is "logcosh"
     * @throws IllegalArgumentException if any argument is invalid
     */
    public ICA(double[][] signal, String func, double[][] wInit, int maxIter, double tol, double alpha) {
        this(signal, func, DEFAULT_WHITEN, requireWInit(wInit), maxIter, tol, alpha, DEFAULT_SEED);
    }

    /**
     * Creates an ICA instance with default iteration limit, tolerance and whitening.
     *
     * @param signal input signal, {@code [n_samples][n_channels]}
     * @param func contrast function: "logcosh", "exp" or "cube"
     * @param alpha logcosh scaling factor; must be in [1, 2] when {@code func} is "logcosh"
     * @param seed seed for drawing the initial un-mixing matrix
     * @throws IllegalArgumentException if any argument is invalid
     */
    public ICA(double[][] signal, String func, double alpha, long seed) {
        this(signal, func, DEFAULT_WHITEN, null, DEFAULT_MAX_ITER, DEFAULT_TOL, alpha, seed);
    }

    /**
     * Creates an ICA instance with default iteration limit, tolerance, whitening and seed.
     *
     * @param signal input signal, {@code [n_samples][n_channels]}
     * @param func contrast function: "logcosh", "exp" or "cube"
     * @param alpha logcosh scaling factor; must be in [1, 2] when {@code func} is "logcosh"
     * @throws IllegalArgumentException if any argument is invalid
     */
    public ICA(double[][] signal, String func, double alpha) {
        this(signal, func, DEFAULT_WHITEN, null, DEFAULT_MAX_ITER, DEFAULT_TOL, alpha, DEFAULT_SEED);
    }

    /**
     * Creates an ICA instance with the given contrast function and all other options at their defaults.
     *
     * @param signal input signal, {@code [n_samples][n_channels]}
     * @param func contrast function: "logcosh", "exp" or "cube"
     * @throws IllegalArgumentException if any argument is invalid
     */
    public ICA(double[][] signal, String func) {
        this(signal, func, DEFAULT_WHITEN, null, DEFAULT_MAX_ITER, DEFAULT_TOL, DEFAULT_ALPHA, DEFAULT_SEED);
        // Kept for compatibility: this constructor historically pre-allocated gx before fit().
        this.gx = new double[signal.length];
        this.g_x = 0.0d;
    }

    /**
     * Creates an ICA instance with the default ("logcosh") contrast and the given seed.
     *
     * @param signal input signal, {@code [n_samples][n_channels]}
     * @param seed seed for drawing the initial un-mixing matrix
     * @throws IllegalArgumentException if the signal is empty
     */
    public ICA(double[][] signal, long seed) {
        this(signal, DEFAULT_FUNC, DEFAULT_WHITEN, null, DEFAULT_MAX_ITER, DEFAULT_TOL, DEFAULT_ALPHA, seed);
    }

    /**
     * Creates an ICA instance with every option at its default.
     *
     * @param signal input signal, {@code [n_samples][n_channels]}
     * @throws IllegalArgumentException if the signal is empty
     */
    public ICA(double[][] signal) {
        this(signal, DEFAULT_FUNC, DEFAULT_WHITEN, null, DEFAULT_MAX_ITER, DEFAULT_TOL, DEFAULT_ALPHA, DEFAULT_SEED);
    }

    /** Rejects a null caller-supplied initial un-mixing matrix (null is reserved for "draw randomly"). */
    private static double[][] requireWInit(double[][] wInit) {
        if (wInit == null) {
            throw new IllegalArgumentException("w_init must not be null");
        }
        return wInit;
    }

    /** Returns true when {@code m} has exactly {@code n} rows, each of length {@code n}. */
    private static boolean isSquare(double[][] m, int n) {
        if (m.length != n) {
            return false;
        }
        for (double[] row : m) {
            if (row == null || row.length != n) {
                return false;
            }
        }
        return true;
    }

    /** Returns a new array holding {@code v[i] * factor}. */
    private static double[] scaled(double[] v, double factor) {
        double[] out = new double[v.length];
        for (int i = 0; i < v.length; i++) {
            out[i] = v[i] * factor;
        }
        return out;
    }

    /** Returns {@code v} divided by its Euclidean norm (a zero vector yields NaNs). */
    private static double[] unitNorm(double[] v) {
        double norm = Math.sqrt(StatUtils.sum(UtilMethods.scalarArithmetic(v, 2.0d, "pow")));
        return UtilMethods.scalarArithmetic(v, norm, "div");
    }

    /** logcosh contrast: gx = tanh(alpha*x), g_x = mean(alpha * (1 - gx^2)). */
    private void logcoshContrast(double[] x) {
        if (!(this.alpha >= 1.0d && this.alpha <= 2.0d)) {
            throw new IllegalArgumentException("alpha should be between 1.0 and 2.0");
        }
        this.gx = new double[x.length];
        double derivativeSum = 0.0d;
        for (int i = 0; i < x.length; i++) {
            this.gx[i] = FastMath.tanh(x[i] * this.alpha);
            derivativeSum += this.alpha * (1.0d - Math.pow(this.gx[i], 2.0d));
        }
        this.g_x = derivativeSum / x.length;
    }

    /** exp contrast: gx = x * exp(-x^2/2), g_x = mean((1 - x^2) * exp(-x^2/2)). */
    private void expContrast(double[] x) {
        this.gx = new double[x.length];
        double derivativeSum = 0.0d;
        for (int i = 0; i < x.length; i++) {
            double e = FastMath.exp((0.0d - Math.pow(x[i], 2.0d)) / 2.0d);
            this.gx[i] = x[i] * e;
            derivativeSum += (1.0d - Math.pow(x[i], 2.0d)) * e;
        }
        this.g_x = derivativeSum / x.length;
    }

    /** cube contrast: gx = x^3, g_x = mean(3 * x^2). */
    private void cubeContrast(double[] x) {
        this.gx = new double[x.length];
        double derivativeSum = 0.0d;
        for (int i = 0; i < x.length; i++) {
            this.gx[i] = Math.pow(x[i], 3.0d);
            derivativeSum += 3.0d * Math.pow(x[i], 2.0d);
        }
        this.g_x = derivativeSum / x.length;
    }

    /** Evaluates the named contrast function on {@code x}, storing results in {@link #gx} and {@link #g_x}. */
    private void applyContrast(String fun, double[] x) {
        switch (fun) {
            case "logcosh":
                logcoshContrast(x);
                break;
            case "cube":
                cubeContrast(x);
                break;
            default:
                // Constructor validation leaves "exp" as the only remaining option.
                expContrast(x);
                break;
        }
    }

    /**
     * Gram-Schmidt-like decorrelation: removes from {@code w} its projection onto the first {@code j}
     * rows of {@code W}, returning {@code w - w.(W[:j]^T W[:j])} as a new array.
     */
    private double[] gsDecorrelation(double[] w, double[][] W, int j) {
        if (j == 0) {
            // No component extracted yet, so there is nothing to project out.
            return w.clone();
        }
        // NOTE(review): assumes UtilMethods.subarray(W, j, W.length) yields the first j rows (the
        // components already extracted); rows >= j are still zero here. TODO confirm.
        double[][] extracted = UtilMethods.subarray(W, j, W.length);
        // extracted^T . extracted is symmetric, so row r doubles as column r of the projector.
        double[][] projector = UtilMethods.matrixMultiply(UtilMethods.transpose(extracted), extracted);
        double[] projection = new double[w.length];
        for (int r = 0; r < projection.length; r++) {
            projection[r] = StatUtils.sum(MathArrays.ebeMultiply(projector[r], w));
        }
        return MathArrays.ebeSubtract(w, projection);
    }

    /**
     * Deflationary FastICA: estimates the rows of the un-mixing matrix one at a time.
     *
     * @param x (whitened) data, {@code [n_rows][n_samples]}
     * @param fun contrast function name
     * @param maxIter maximum fixed-point iterations per component
     * @param wInit initial un-mixing matrix; row j seeds component j
     * @return the un-mixing matrix W, {@code [components][components]}
     */
    private double[][] icaDef(double[][] x, String fun, int maxIter, double[][] wInit) {
        double[][] w = new double[this.components][this.components];
        int mostIterations = 0;
        for (int j = 0; j < this.components; j++) {
            double[] wj = unitNorm(wInit[j]);
            int used = 0;
            for (int it = 0; it < maxIter; it++) {
                used = it + 1;
                double[] projected = UtilMethods.flattenMatrix(UtilMethods.matrixMultiply(new double[][] {wj}, x));
                applyContrast(fun, projected);
                // Fixed-point update: E[x * g(w.x)] - E[g'(w.x)] * w
                double[] update = new double[x.length];
                for (int r = 0; r < update.length; r++) {
                    update[r] = StatUtils.mean(MathArrays.ebeMultiply(x[r], this.gx));
                }
                update = MathArrays.ebeSubtract(update, UtilMethods.scalarArithmetic(wj, this.g_x, "mul"));
                double[] next = unitNorm(gsDecorrelation(update, w, j));
                // Converged when the new direction is (anti-)parallel to the previous one.
                double lim = Math.abs(Math.abs(StatUtils.sum(MathArrays.ebeMultiply(next, wj))) - 1.0d);
                wj = next;
                if (lim < this.tol) {
                    break;
                }
            }
            mostIterations = Math.max(mostIterations, used);
            w[j] = wj;
        }
        // Reset per fit; counts iterations actually performed (max_iter when not converged).
        this.n_iter = mostIterations;
        return w;
    }

    /**
     * Builds the whitening matrix K from the SVD of the centred data: the left singular vectors are
     * sign-flipped by the signs of their first row, divided column-wise by the singular values and
     * transposed.
     *
     * @param centred zero-centred data, {@code [n_channels][n_samples]}
     * @return K, {@code [n_singular_values][n_channels]}
     */
    private static double[][] computeWhiteningMatrix(double[][] centred) {
        SingularValueDecomposition svd = new SingularValueDecomposition(MatrixUtils.createRealMatrix(centred));
        double[][] u = svd.getU().getData();
        double[] singular = svd.getSingularValues();
        double[] sign = UtilMethods.sign(u[0]);
        for (int r = 0; r < u.length; r++) {
            u[r] = MathArrays.ebeMultiply(u[r], sign);
        }
        double[][] k = new double[singular.length][u.length];
        for (int r = 0; r < u.length; r++) {
            for (int c = 0; c < singular.length; c++) {
                k[c][r] = u[r][c] / singular[c];
            }
        }
        return k;
    }

    /**
     * Estimates the un-mixing matrix and the independent sources of the signal.
     *
     * <p>Steps:
     * <ol>
     *   <li>Transpose the signal to {@code [n_channels][n_samples]}.</li>
     *   <li>If whitening is enabled, zero-centre every channel, build K from the SVD of the centred
     *       data, and feed {@code sqrt(n_samples) * K.X} to FastICA.</li>
     *   <li>Otherwise feed the raw transposed signal to FastICA.</li>
     *   <li>Compute the sources {@code W.K.X} (or {@code W.X}).</li>
     *   <li>For "unit-variance", rescale every source and the matching row of W to unit population
     *       standard deviation.</li>
     * </ol>
     */
    public void fit() {
        final double nSamples = this.signal.length;
        final double[][] channels = UtilMethods.transpose(this.signal);
        final boolean whitening = !this.whiten.isEmpty();

        double[][] k = null;
        double[][] projected; // K.X when whitening, otherwise X; [rows][n_samples]
        double[][] icaInput;  // data handed to the fixed-point iterations
        if (whitening) {
            this.mean_ = new double[channels.length];
            this.zm_signal = new double[channels.length][];
            for (int c = 0; c < channels.length; c++) {
                this.mean_[c] = StatUtils.mean(channels[c]);
                this.zm_signal[c] = UtilMethods.zeroCenter(channels[c]);
            }
            k = computeWhiteningMatrix(this.zm_signal);
            projected = UtilMethods.matrixMultiply(k, this.zm_signal);
            double rootN = Math.sqrt(nSamples);
            icaInput = new double[projected.length][];
            for (int r = 0; r < projected.length; r++) {
                icaInput[r] = scaled(projected[r], rootN);
            }
        } else {
            this.zm_signal = channels;
            projected = channels;
            icaInput = channels;
        }

        double[][] w = icaDef(icaInput, this.func, this.max_iter, this.w_init);
        double[][] s = UtilMethods.matrixMultiply(w, projected);

        if (whitening && this.whiten.equals("unit-variance")) {
            for (int i = 0; i < s.length; i++) {
                double n = s[i].length;
                // StatUtils.variance is bias-corrected; convert to the population variance.
                double inverseStd = 1.0d / Math.sqrt((StatUtils.variance(s[i]) * (n - 1.0d)) / n);
                s[i] = scaled(s[i], inverseStd);
                w[i] = scaled(w[i], inverseStd);
            }
        }

        if (whitening) {
            this.whiteningMatrix = k;
            this.componentMatrix = UtilMethods.matrixMultiply(w, k);
        } else {
            this.componentMatrix = w;
        }
        this.mixingMatrix = UtilMethods.pseudoInverse(this.componentMatrix);
        this.unmixingMatrix = w;
        this.sources = UtilMethods.transpose(s);
    }

    /**
     * Returns the sources estimated from the fitted signal.
     *
     * @return sources, {@code [n_samples][n_components]}
     * @throws ExceptionInInitializerError if {@link #fit()} has not completed (kept for compatibility)
     */
    public double[][] transform() throws ExceptionInInitializerError {
        if (this.unmixingMatrix == null) {
            throw new ExceptionInInitializerError(NOT_FITTED_MESSAGE);
        }
        return this.sources;
    }

    /**
     * Projects new data onto the fitted components.
     *
     * <p>When whitening is enabled, the per-channel means from {@link #fit()} are subtracted first.
     * The input array is not modified.
     *
     * @param data signal, {@code [n_samples][n_channels]}
     * @return projected data, {@code [n_samples][n_components]}
     * @throws ExceptionInInitializerError if {@link #fit()} has not completed (kept for compatibility)
     * @throws IllegalArgumentException if {@code data} is null or empty
     * @throws ArithmeticException if the channel count differs from the fitted signal
     */
    public double[][] transform(double[][] data) throws ExceptionInInitializerError, ArithmeticException {
        if (this.unmixingMatrix == null) {
            throw new ExceptionInInitializerError(NOT_FITTED_MESSAGE);
        }
        if (data == null || data.length == 0) {
            throw new IllegalArgumentException("signal must be a non-empty 2D array");
        }
        if (data[0].length != this.components) {
            throw new ArithmeticException("Number of components has to be same as original signal");
        }
        double[][] input = data;
        if (!this.whiten.isEmpty()) {
            input = new double[data.length][];
            for (int r = 0; r < data.length; r++) {
                input[r] = new double[data[r].length];
                for (int c = 0; c < data[r].length; c++) {
                    input[r][c] = data[r][c] - this.mean_[c];
                }
            }
        }
        return UtilMethods.matrixMultiply(input, UtilMethods.transpose(this.componentMatrix));
    }
}
