package com.facebook.presto.operator.aggregation.noisyaggregation;

import com.facebook.presto.block.BlockAssertions;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.RealType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.BenchmarkHashAndSegmentedAggregationOperators;
import com.facebook.presto.operator.aggregation.AggregationTestUtils;
import com.facebook.presto.operator.scalar.AbstractTestFunctions;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.testing.LocalQueryRunner;
import com.facebook.presto.testing.MaterializedRow;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisySumGaussianRealAggregation.class */
public class TestNoisySumGaussianRealAggregation extends AbstractTestFunctions {
    private static final String FUNCTION_NAME = "noisy_sum_gaussian";
    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
    private static final double DEFAULT_TEST_STANDARD_DEVIATION = 1.0d;

    @Test
    public void testNoisySumGaussianRealDefinitions() {
        getFunction(RealType.REAL, DoubleType.DOUBLE);
        getFunction(RealType.REAL, DoubleType.DOUBLE, BigintType.BIGINT);
        getFunction(RealType.REAL, DoubleType.DOUBLE, DoubleType.DOUBLE, DoubleType.DOUBLE);
        getFunction(RealType.REAL, DoubleType.DOUBLE, DoubleType.DOUBLE, DoubleType.DOUBLE, BigintType.BIGINT);
    }

    @Test(expectedExceptions = {PrestoException.class})
    public void testNoisySumGaussianRealInvalidNoiseScale() {
        JavaAggregationFunctionImplementation function = getFunction(RealType.REAL, DoubleType.DOUBLE);
        List<Double> createTestValues = TestNoisyAggregationUtils.createTestValues(10, false, Double.valueOf(1.0d), true);
        AggregationTestUtils.assertAggregation(function, TestNoisyAggregationUtils.equalDoubleAssertion, "Test noisy_sum_gaussian(real, noiseScale) with noiseScale < 0 which means errors", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(createTestValues)), BlockAssertions.createRLEBlock(-123.0d, 10)}), Double.valueOf(TestNoisyAggregationUtils.sum(createTestValues)));
    }

    @Test
    public void testNoisySumGaussianRealZeroNoiseScale() {
        JavaAggregationFunctionImplementation function = getFunction(RealType.REAL, DoubleType.DOUBLE);
        List<Double> createTestValues = TestNoisyAggregationUtils.createTestValues(10, false, Double.valueOf(1.0d), true);
        AggregationTestUtils.assertAggregation(function, TestNoisyAggregationUtils.equalDoubleAssertion, "Test noisy_sum_gaussian(real, noiseScale) with noiseScale=0 which means no noise", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(createTestValues)), BlockAssertions.createRLEBlock(0.0d, 10)}), Double.valueOf(TestNoisyAggregationUtils.sum(createTestValues)));
    }

    @Test
    public void testNoisySumGaussianRealZeroNoiseScaleWithNull() {
        JavaAggregationFunctionImplementation function = getFunction(RealType.REAL, DoubleType.DOUBLE);
        List<Double> createTestValues = TestNoisyAggregationUtils.createTestValues(10, true, Double.valueOf(1.0d), true);
        AggregationTestUtils.assertAggregation(function, TestNoisyAggregationUtils.equalDoubleAssertion, "Test noisy_sum_gaussian(real, noiseScale) with noiseScale=0 and 1 null row which means no noise", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(createTestValues)), BlockAssertions.createRLEBlock(0.0d, 10)}), Double.valueOf(TestNoisyAggregationUtils.sum(createTestValues)));
    }

    @Test
    public void testNoisySumGaussianRealSomeNoiseScale() {
        JavaAggregationFunctionImplementation function = getFunction(RealType.REAL, DoubleType.DOUBLE);
        List<Double> createTestValues = TestNoisyAggregationUtils.createTestValues(10, false, Double.valueOf(1.0d), true);
        AggregationTestUtils.assertAggregation(function, TestNoisyAggregationUtils.notEqualDoubleAssertion, "Test noisy_sum_gaussian(real, noiseScale) with noiseScale > 0 which means some noise", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(createTestValues)), BlockAssertions.createRLEBlock(1.0d, 10)}), Double.valueOf(TestNoisyAggregationUtils.sum(createTestValues)));
    }

    @Test
    public void testNoisySumGaussianRealSomeNoiseScaleWithinSomeStd() {
        JavaAggregationFunctionImplementation function = getFunction(RealType.REAL, DoubleType.DOUBLE);
        BiFunction biFunction = (obj, obj2) -> {
            double doubleValue = new Double(obj.toString()).doubleValue();
            double doubleValue2 = new Double(obj2.toString()).doubleValue();
            return Boolean.valueOf(doubleValue2 - 50.0d <= doubleValue && doubleValue <= doubleValue2 + 50.0d);
        };
        List<Double> createTestValues = TestNoisyAggregationUtils.createTestValues(BenchmarkHashAndSegmentedAggregationOperators.Context.ROWS_PER_PAGE, false, Double.valueOf(1.0d), true);
        AggregationTestUtils.assertAggregation(function, biFunction, "Test noisy_sum_gaussian(real, noiseScale) within some std from mean", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(createTestValues)), BlockAssertions.createRLEBlock(1.0d, BenchmarkHashAndSegmentedAggregationOperators.Context.ROWS_PER_PAGE)}), Double.valueOf(TestNoisyAggregationUtils.sum(createTestValues)));
    }

    @Test
    public void testNoisySumGaussianRealNoiseScaleVsNormalSum() {
        String buildData = TestNoisyAggregationUtils.buildData(10, true, Arrays.asList("bigint", "double", "real", "decimal"));
        String buildColumnName = TestNoisyAggregationUtils.buildColumnName("real");
        String format = String.format("SELECT SUM(%s) FROM %s", buildColumnName, buildData);
        Assert.assertEquals(Double.parseDouble(runQuery(String.format("SELECT %s(%s, %f) FROM %s", FUNCTION_NAME, buildColumnName, Double.valueOf(0.0d), buildData)).get(0).getField(0).toString()), Double.parseDouble(runQuery(format).get(0).getField(0).toString()));
    }

    @Test
    public void testNoisySumGaussianRealClippingZeroNoiseScale() {
        AggregationTestUtils.assertAggregation(getFunction(RealType.REAL, DoubleType.DOUBLE, DoubleType.DOUBLE, DoubleType.DOUBLE), TestNoisyAggregationUtils.equalDoubleAssertion, "Test noisy_sum_gaussian(real, noiseScale, lower, upper) with noiseScale=0 which means no noise, and clipping", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(TestNoisyAggregationUtils.createTestValues(10, false, Double.valueOf(1.0d), false))), BlockAssertions.createRLEBlock(0.0d, 10), BlockAssertions.createRLEBlock(2.0d, 10), BlockAssertions.createRLEBlock(8.0d, 10)}), Double.valueOf(47.0d));
    }

    @Test(expectedExceptions = {PrestoException.class})
    public void testNoisySumGaussianRealClippingInvalidBound() {
        AggregationTestUtils.assertAggregation(getFunction(RealType.REAL, DoubleType.DOUBLE, DoubleType.DOUBLE, DoubleType.DOUBLE), TestNoisyAggregationUtils.equalDoubleAssertion, "Test noisy_sum_gaussian(real, noiseScale, lower, upper) with clipping lower > upper ", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(TestNoisyAggregationUtils.createTestValues(10, false, Double.valueOf(1.0d), false))), BlockAssertions.createRLEBlock(0.0d, 10), BlockAssertions.createRLEBlock(2.0d, 10), BlockAssertions.createRLEBlock(-8.0d, 10)}), Double.valueOf(45.0d));
    }

    @Test
    public void testNoisySumGaussianRealClippingZeroNoiseScaleWithNull() {
        AggregationTestUtils.assertAggregation(getFunction(RealType.REAL, DoubleType.DOUBLE, DoubleType.DOUBLE, DoubleType.DOUBLE), TestNoisyAggregationUtils.equalDoubleAssertion, "Test noisy_sum_gaussian(real, noiseScale, lower, upper) with noiseScale=0 which means no noise, and clipping, with null values", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(TestNoisyAggregationUtils.createTestValues(10, true, Double.valueOf(1.0d), false))), BlockAssertions.createRLEBlock(0.0d, 10), BlockAssertions.createRLEBlock(2.0d, 10), BlockAssertions.createRLEBlock(8.0d, 10)}), Double.valueOf(45.0d));
    }

    @Test
    public void testNoisySumGaussianRealClippingSomeNoiseScale() {
        AggregationTestUtils.assertAggregation(getFunction(RealType.REAL, DoubleType.DOUBLE, DoubleType.DOUBLE, DoubleType.DOUBLE), TestNoisyAggregationUtils.notEqualDoubleAssertion, "Test noisy_sum_gaussian(real, noiseScale, lower, upper) with noiseScale > 0 which means some noise", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(TestNoisyAggregationUtils.createTestValues(10, true, Double.valueOf(1.0d), false))), BlockAssertions.createRLEBlock(1.0d, 10), BlockAssertions.createRLEBlock(2.0d, 10), BlockAssertions.createRLEBlock(8.0d, 10)}), Double.valueOf(45.0d));
    }

    @Test
    public void testNoisySumGaussianRealClippingSomeNoiseScaleWithinSomeStd() {
        AggregationTestUtils.assertAggregation(getFunction(RealType.REAL, DoubleType.DOUBLE, DoubleType.DOUBLE, DoubleType.DOUBLE), (obj, obj2) -> {
            double doubleValue = new Double(obj.toString()).doubleValue();
            double doubleValue2 = new Double(obj2.toString()).doubleValue();
            return Boolean.valueOf(doubleValue2 - 5.0d <= doubleValue && doubleValue <= doubleValue2 + 5.0d);
        }, "Test noisy_sum_gaussian(real, noiseScale, lower, upper) within some std from mean", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(TestNoisyAggregationUtils.createTestValues(10, true, Double.valueOf(1.0d), false))), BlockAssertions.createRLEBlock(1.0d, 10), BlockAssertions.createRLEBlock(2.0d, 10), BlockAssertions.createRLEBlock(8.0d, 10)}), Double.valueOf(45.0d));
    }

    @Test
    public void testNoisySumGaussianRealClippingRandomSeed() {
        AggregationTestUtils.assertAggregation(getFunction(RealType.REAL, DoubleType.DOUBLE, DoubleType.DOUBLE, DoubleType.DOUBLE, BigintType.BIGINT), TestNoisyAggregationUtils.equalDoubleAssertion, "Test noisy_sum_gaussian(real, noiseScale, lower, upper, randomSeed)", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(TestNoisyAggregationUtils.createTestValues(10, false, Double.valueOf(1.0d), false))), BlockAssertions.createRLEBlock(12.0d, 10), BlockAssertions.createRLEBlock(2.0d, 10), BlockAssertions.createRLEBlock(5.0d, 10), BlockAssertions.createRLEBlock(10L, 10)}), Double.valueOf(48.4961467597545d));
    }

    @Test
    public void testNoisySumGaussianRealZeroNoiseScaleZeroRandomSeed() {
        JavaAggregationFunctionImplementation function = getFunction(RealType.REAL, DoubleType.DOUBLE, BigintType.BIGINT);
        List<Double> createTestValues = TestNoisyAggregationUtils.createTestValues(10, true, Double.valueOf(1.0d), false);
        AggregationTestUtils.assertAggregation(function, TestNoisyAggregationUtils.equalDoubleAssertion, "Test noisy_sum_gaussian(double, noiseScale, randomSeed) with noiseScale=0 which means no noise", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(createTestValues)), BlockAssertions.createRLEBlock(0.0d, 10), BlockAssertions.createRLEBlock(0L, 10)}), Double.valueOf(TestNoisyAggregationUtils.sum(createTestValues)));
    }

    @Test
    public void testNoisySumGaussianRealSomeNoiseScaleFixedRandomSeed() {
        AggregationTestUtils.assertAggregation(getFunction(RealType.REAL, DoubleType.DOUBLE, BigintType.BIGINT), TestNoisyAggregationUtils.equalDoubleAssertion, "Test noisy_sum_gaussian(real, noiseScale, randomSeed) with noiseScale=0 which means no noise", new Page(new Block[]{BlockAssertions.createBlockOfReals(doubleListToFloatList(TestNoisyAggregationUtils.createTestValues(10, true, Double.valueOf(1.0d), false))), BlockAssertions.createRLEBlock(12.0d, 10), BlockAssertions.createRLEBlock(10L, 10)}), Double.valueOf(55.496146759754d));
    }

    @Test
    public void testNoisySumGaussianRealNoInputRowsWithoutGroupBy() {
        List<MaterializedRow> runQuery = runQuery("SELECT noisy_sum_gaussian(" + TestNoisyAggregationUtils.buildColumnName("real") + ", 0) + 1 FROM " + TestNoisyAggregationUtils.buildData(100, true, Arrays.asList("bigint", "double", "real", "decimal")) + " WHERE false");
        Assert.assertEquals(runQuery.size(), 1);
        Assert.assertNull(runQuery.get(0).getField(0));
    }

    @Test
    public void testNoisySumGaussianRealNoInputRowsWithGroupBy() {
        String buildData = TestNoisyAggregationUtils.buildData(100, true, Arrays.asList("bigint", "double", "real", "decimal"));
        String buildColumnName = TestNoisyAggregationUtils.buildColumnName("real");
        Assert.assertEquals(runQuery("SELECT noisy_sum_gaussian(" + buildColumnName + ", 0) + 1 FROM " + buildData + " WHERE false GROUP BY " + buildColumnName).size(), 0);
    }

    private List<MaterializedRow> runQuery(String str) {
        return new LocalQueryRunner(this.session).execute(str).toTestTypes().getMaterializedRows();
    }

    private List<Float> doubleListToFloatList(List<Double> list) {
        return (List) list.stream().map(d -> {
            if (d == null) {
                return null;
            }
            return Float.valueOf(d.floatValue());
        }).collect(Collectors.toList());
    }

    private JavaAggregationFunctionImplementation getFunction(Type... typeArr) {
        return FUNCTION_AND_TYPE_MANAGER.getJavaAggregateFunctionImplementation(FUNCTION_AND_TYPE_MANAGER.lookupFunction(FUNCTION_NAME, TypeSignatureProvider.fromTypes(typeArr)));
    }
}
