package com.github.psambit9791.jdsp.signal;

import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.apache.commons.math3.stat.regression.SimpleRegression;

/* loaded from: input_file:com/github/psambit9791/jdsp/signal/Detrend.class */
public class Detrend {
    /** Detrending mode: "linear", "constant" or "poly" (validated in {@link #detrendSignal()}). */
    private final String mode;
    /** Polynomial degree; only meaningful when {@code mode} is "poly". */
    private final int power;
    /** Input signal (stored by reference, not copied). */
    private final double[] originalSignal;
    /** Result of the most recent {@link #detrendSignal()} call. */
    private double[] detrendedSignal;
    /** Trend removed by the most recent {@link #detrendSignal()} call; same length as the input. */
    private final double[] trendLine;

    /**
     * Creates a detrender with an explicit mode.
     *
     * @param dArr the signal to detrend
     * @param str  one of "linear", "constant" or "poly"; "poly" defaults to degree 2.
     *             Invalid values are rejected by {@link #detrendSignal()}.
     */
    public Detrend(double[] dArr, String str) {
        this.originalSignal = dArr;
        this.mode = str;
        // Null-safe comparison so an invalid/null mode surfaces as the documented
        // IllegalArgumentException in detrendSignal() rather than an NPE here.
        this.power = "poly".equals(str) ? 2 : 0;
        this.trendLine = new double[dArr.length];
    }

    /**
     * Creates a detrender that removes a least-squares straight line.
     *
     * @param dArr the signal to detrend
     */
    public Detrend(double[] dArr) {
        this.originalSignal = dArr;
        this.mode = "linear";
        this.power = 0;
        this.trendLine = new double[dArr.length];
    }

    /**
     * Creates a detrender that removes a least-squares polynomial of the given degree.
     *
     * @param dArr the signal to detrend
     * @param i    polynomial degree; must be non-negative (checked in {@link #detrendSignal()})
     */
    public Detrend(double[] dArr, int i) {
        this.originalSignal = dArr;
        this.mode = "poly";
        this.power = i;
        this.trendLine = new double[dArr.length];
    }

    /**
     * Removes the trend selected by the mode and returns the residual signal.
     * The removed trend is afterwards available from {@link #getTrendLine()}.
     * Calling this repeatedly yields the same result.
     *
     * @return the detrended signal (input minus fitted trend)
     * @throws IllegalArgumentException if the mode is not "linear", "constant" or "poly",
     *                                  or if the polynomial degree is negative
     */
    public double[] detrendSignal() throws IllegalArgumentException {
        if ("constant".equals(this.mode)) {
            this.detrendedSignal = constantDetrend(this.originalSignal);
        } else if ("poly".equals(this.mode)) {
            this.detrendedSignal = polyDetrend(this.originalSignal, this.power);
        } else if ("linear".equals(this.mode)) {
            this.detrendedSignal = linearDetrend(this.originalSignal);
        } else {
            throw new IllegalArgumentException("Mode can only be linear, constant or poly.");
        }
        return this.detrendedSignal;
    }

    /**
     * Returns the trend removed by the last {@link #detrendSignal()} call.
     * Before the first call the array is all zeros. The internal array is returned, not a copy.
     *
     * @return the fitted trend, one value per input sample
     */
    public double[] getTrendLine() {
        return this.trendLine;
    }

    /** Fits y = slope * x + intercept over x = (k+1)/n and subtracts it. */
    private double[] linearDetrend(double[] signal) {
        double[] residual = new double[signal.length];
        double[] x = generateX(signal);
        SimpleRegression regression = new SimpleRegression();
        for (int k = 0; k < signal.length; k++) {
            regression.addData(x[k], signal[k]);
        }
        double slope = regression.getSlope();
        double intercept = regression.getIntercept();
        for (int k = 0; k < signal.length; k++) {
            this.trendLine[k] = (x[k] * slope) + intercept;
            residual[k] = signal[k] - this.trendLine[k];
        }
        return residual;
    }

    /** Fits a polynomial of the given degree by ordinary least squares and subtracts it. */
    private double[] polyDetrend(double[] signal, int degree) {
        if (degree < 0) {
            throw new IllegalArgumentException("Polynomial degree must be non-negative, got " + degree);
        }
        double[] residual = new double[signal.length];
        double[][] design = generateX(signal, degree);
        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        // Column 0 of the design matrix is already all ones, so the solver must not add another intercept.
        regression.setNoIntercept(true);
        regression.newSampleData(signal, design);
        double[] coefficients = regression.estimateRegressionParameters();
        for (int row = 0; row < signal.length; row++) {
            // Accumulate in a local and assign, so repeated calls do not add onto a stale trend line.
            double fitted = 0.0d;
            for (int col = 0; col <= degree; col++) {
                fitted += design[row][col] * coefficients[col];
            }
            this.trendLine[row] = fitted;
            residual[row] = signal[row] - fitted;
        }
        return residual;
    }

    /** Subtracts the arithmetic mean; the trend line is that constant mean. */
    private double[] constantDetrend(double[] signal) {
        double[] residual = new double[signal.length];
        double mean = findMean(signal);
        for (int k = 0; k < signal.length; k++) {
            // Record the removed trend so getTrendLine() is meaningful in this mode too.
            this.trendLine[k] = mean;
            residual[k] = signal[k] - mean;
        }
        return residual;
    }

    /** Arithmetic mean of the samples (NaN for an empty array). */
    private double findMean(double[] signal) {
        double sum = 0.0d;
        for (double value : signal) {
            sum += value;
        }
        return sum / signal.length;
    }

    /** Abscissae x[k] = (k+1)/n, i.e. sample positions normalized into (0, 1]. */
    private double[] generateX(double[] signal) {
        double[] x = new double[signal.length];
        double n = signal.length;
        for (int k = 0; k < x.length; k++) {
            x[k] = (k + 1) / n;
        }
        return x;
    }

    /**
     * Vandermonde-style design matrix: row k is [1, x, x^2, ..., x^degree] with x = (k+1)/n.
     * Column 1 is filled before higher columns, which are powers of it.
     */
    private double[][] generateX(double[] signal, int degree) {
        double[][] design = new double[signal.length][degree + 1];
        double n = signal.length;
        for (int row = 0; row < design.length; row++) {
            for (int col = 0; col <= degree; col++) {
                if (col > 1) {
                    design[row][col] = Math.pow(design[row][1], col);
                } else if (col == 1) {
                    design[row][col] = (row + 1) / n;
                } else {
                    design[row][col] = 1.0d;
                }
            }
        }
        return design;
    }
}
