package io.trino.operator.aggregation;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.trino.block.BlockAssertions;
import io.trino.spi.block.Block;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import java.util.List;
import java.util.stream.DoubleStream;
import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;
import org.junit.jupiter.api.Test;

/**
 * Tests the {@code corr} aggregation over two {@code double} columns, using Apache Commons Math's
 * {@link PearsonsCorrelation} to compute the expected values.
 */
public class TestDoubleCorrelationAggregation extends AbstractTestAggregationFunction {
    /**
     * Builds two double sequence blocks that each span {@code length} values; the second sequence
     * starts 2 higher than the first.
     */
    @Override
    protected Block[] getSequenceBlocks(int start, int length) {
        return new Block[] {
                BlockAssertions.createDoubleSequenceBlock(start, start + length),
                BlockAssertions.createDoubleSequenceBlock(start + 2, start + 2 + length)
        };
    }

    @Override
    protected String getFunctionName() {
        return "corr";
    }

    /** The aggregation takes two {@code DOUBLE} arguments. */
    @Override
    protected List<Type> getFunctionParameterTypes() {
        return ImmutableList.of(DoubleType.DOUBLE, DoubleType.DOUBLE);
    }

    /**
     * Returns the expected result for the blocks produced by {@link #getSequenceBlocks}.
     *
     * @return {@code null} when there are fewer than two rows; otherwise the Pearson correlation
     *     of the two sequences (boxed as a {@code Double})
     */
    @Override
    protected Object getExpectedValue(int start, int length) {
        if (length <= 1) {
            // With fewer than two rows the aggregation is expected to produce null.
            return null;
        }
        // Pearson correlation is symmetric, so the argument order does not affect the result.
        double[] shifted = AggregationTestUtils.constructDoublePrimitiveArray(start + 2, length);
        double[] base = AggregationTestUtils.constructDoublePrimitiveArray(start, length);
        return new PearsonsCorrelation().correlation(shifted, base);
    }

    /**
     * A constant column has zero variance, which zeroes the correlation's denominator; the
     * aggregation must return null in that case, whichever argument is constant.
     */
    @Test
    public void testDivisionByZero() {
        testAggregation(
                null,
                BlockAssertions.createDoublesBlock(box(2.0, 2.0, 2.0, 2.0, 2.0)),
                BlockAssertions.createDoublesBlock(box(1.0, 4.0, 9.0, 16.0, 25.0)));
        testAggregation(
                null,
                BlockAssertions.createDoublesBlock(box(1.0, 4.0, 9.0, 16.0, 25.0)),
                BlockAssertions.createDoublesBlock(box(2.0, 2.0, 2.0, 2.0, 2.0)));
    }

    /** A monotonic but non-linear relationship (y = x^2) gives a correlation strictly between 0 and 1. */
    @Test
    public void testNonTrivialResult() {
        testNonTrivialAggregation(
                new double[] {1.0, 2.0, 3.0, 4.0, 5.0},
                new double[] {1.0, 4.0, 9.0, 16.0, 25.0});
    }

    /** Perfectly inversely related columns give a correlation of -1. */
    @Test
    public void testInverseCorrelation() {
        testNonTrivialAggregation(
                new double[] {1.0, 2.0, 3.0, 4.0, 5.0},
                new double[] {5.0, 4.0, 3.0, 2.0, 1.0});
    }

    /**
     * Runs the aggregation over {@code (y, x)} and compares it with the Commons Math result.
     *
     * @throws IllegalArgumentException if the expected value is non-finite, exactly 0, or exactly 1,
     *     since such fixtures would not meaningfully exercise the aggregation
     */
    private void testNonTrivialAggregation(double[] y, double[] x) {
        double expected = new PearsonsCorrelation().correlation(x, y);
        // NOTE(review): -1.0 is deliberately not rejected; testInverseCorrelation relies on it.
        Preconditions.checkArgument(
                Double.isFinite(expected) && expected != 0.0 && expected != 1.0,
                "Expected result is trivial");
        testAggregation(
                expected,
                BlockAssertions.createDoublesBlock(box(y)),
                BlockAssertions.createDoublesBlock(box(x)));
    }

    /** Converts primitive doubles to a boxed array accepted by {@code createDoublesBlock}. */
    private static Double[] box(double... values) {
        return DoubleStream.of(values).boxed().toArray(Double[]::new);
    }
}
