package com.facebook.presto.operator.aggregation.noisyaggregation;

import com.facebook.presto.block.BlockAssertions;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.BenchmarkHashAndSegmentedAggregationOperators;
import com.facebook.presto.operator.aggregation.AggregationTestUtils;
import com.facebook.presto.operator.scalar.AbstractTestFunctions;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.testing.LocalQueryRunner;
import com.facebook.presto.testing.MaterializedRow;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisyCountIfGaussianAggregation.class */
/**
 * Tests for the {@code noisy_count_if_gaussian} aggregation function.
 *
 * <p>Covers:
 * <ul>
 *   <li>resolution of the {@code (boolean, double)} and {@code (boolean, double, bigint)} overloads;</li>
 *   <li>rejection of a negative noise scale;</li>
 *   <li>exact results with a zero noise scale;</li>
 *   <li>a statistical bound with a non-zero noise scale;</li>
 *   <li>a fixed random seed;</li>
 *   <li>SQL behavior on empty input with and without {@code GROUP BY}.</li>
 * </ul>
 */
public class TestNoisyCountIfGaussianAggregation extends AbstractTestFunctions {
    private static final String FUNCTION_NAME = "noisy_count_if_gaussian";
    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();

    /** Row count used by the small, exact-result tests. */
    private static final int NUM_ROWS = 10;

    /** Both overloads must resolve: (boolean, noiseScale) and (boolean, noiseScale, randomSeed). */
    @Test
    public void testNoisyCountIfGaussianDefinitions() {
        getFunction(BooleanType.BOOLEAN, DoubleType.DOUBLE);
        getFunction(BooleanType.BOOLEAN, DoubleType.DOUBLE, BigintType.BIGINT);
    }

    /** A negative noise scale is expected to raise a {@link PrestoException}. */
    @Test(expectedExceptions = {PrestoException.class})
    public void testNoisyCountIfGaussianInvalidNoiseScale() {
        JavaAggregationFunctionImplementation function = getFunction(BooleanType.BOOLEAN, DoubleType.DOUBLE);
        // NOTE(review): argument meaning of createTestValues (numRows, includeNull?, value?, ?) is
        // inferred from usage and message text only -- confirm against TestNoisyAggregationUtils.
        List values = TestNoisyAggregationUtils.createTestValues(NUM_ROWS, false, true, false);
        Page page = new Page(new Block[]{
                BlockAssertions.createBooleansBlock(values),
                BlockAssertions.createRLEBlock(-123.0d, NUM_ROWS)});
        AggregationTestUtils.assertAggregation(
                function,
                TestNoisyAggregationUtils.equalDoubleAssertion,
                "Test noisy_count_if_gaussian(boolean, noiseScale) with noiseScale < 0 which means errors",
                page,
                Double.valueOf(TestNoisyAggregationUtils.countTrue(values)));
    }

    /** With noiseScale = 0 the result must equal the plain count of true values. */
    @Test
    public void testNoisyCountIfGaussianZeroNoiseScale() {
        JavaAggregationFunctionImplementation function = getFunction(BooleanType.BOOLEAN, DoubleType.DOUBLE);
        List values = TestNoisyAggregationUtils.createTestValues(NUM_ROWS, false, true, false);
        Page page = new Page(new Block[]{
                BlockAssertions.createBooleansBlock(values),
                BlockAssertions.createRLEBlock(0.0d, NUM_ROWS)});
        AggregationTestUtils.assertAggregation(
                function,
                TestNoisyAggregationUtils.equalDoubleAssertion,
                "Test noisy_count_if_gaussian(boolean, noiseScale) with noiseScale=0 which means no noise",
                page,
                Double.valueOf(TestNoisyAggregationUtils.countTrue(values)));
    }

    /** Same as the zero-noise test, but the input includes a null row. */
    @Test
    public void testNoisyCountIfGaussianZeroNoiseScaleWithNull() {
        JavaAggregationFunctionImplementation function = getFunction(BooleanType.BOOLEAN, DoubleType.DOUBLE);
        List values = TestNoisyAggregationUtils.createTestValues(NUM_ROWS, true, true, false);
        Page page = new Page(new Block[]{
                BlockAssertions.createBooleansBlock(values),
                BlockAssertions.createRLEBlock(0.0d, NUM_ROWS)});
        AggregationTestUtils.assertAggregation(
                function,
                TestNoisyAggregationUtils.equalDoubleAssertion,
                "Test noisy_count_if_gaussian(boolean, noiseScale) with noiseScale=0 and 1 null row which means no noise",
                page,
                Double.valueOf(TestNoisyAggregationUtils.countTrue(values)));
    }

    /** With noiseScale = 1 the noisy result should fall within a few standard deviations of the true count. */
    @Test
    public void testNoisyCountIfGaussianSomeNoiseScaleWithinSomeStd() {
        JavaAggregationFunctionImplementation function = getFunction(BooleanType.BOOLEAN, DoubleType.DOUBLE);
        // NOTE(review): the row count is borrowed from a benchmark class's constant; a dedicated
        // constant in this test would decouple it -- confirm the intended value.
        int numRows = BenchmarkHashAndSegmentedAggregationOperators.Context.ROWS_PER_PAGE;
        List values = TestNoisyAggregationUtils.createTestValues(numRows, true, true, false);
        Page page = new Page(new Block[]{
                BlockAssertions.createBooleansBlock(values),
                BlockAssertions.createRLEBlock(1.0d, numRows)});
        AggregationTestUtils.assertAggregation(
                function,
                TestNoisyAggregationUtils.withinSomeStdAssertion,
                "Test noisy_count_if_gaussian(boolean, noiseScale) within some std from mean",
                page,
                Double.valueOf(TestNoisyAggregationUtils.countTrue(values)));
    }

    /** At the SQL level, noiseScale = 0 must agree with the built-in COUNT_IF. */
    @Test
    public void testNoisyCountIfGaussianNoiseScaleVsNormalCountIf() {
        String data = TestNoisyAggregationUtils.buildData(NUM_ROWS, true, Arrays.asList("boolean", "double"));
        String column = TestNoisyAggregationUtils.buildColumnName("boolean");
        // Locale.ROOT keeps %f's decimal separator a '.', so the SQL literal parses in any default locale.
        String noisyQuery = String.format(Locale.ROOT, "SELECT %s(%s, %f) FROM %s", FUNCTION_NAME, column, 0.0d, data);
        String countIfQuery = String.format("SELECT COUNT_IF(%s) FROM %s", column, data);

        double noisyResult = Double.parseDouble(runQuery(noisyQuery).get(0).getField(0).toString());
        double countIfResult = Double.parseDouble(runQuery(countIfQuery).get(0).getField(0).toString());
        Assert.assertEquals(noisyResult, countIfResult);
    }

    /** With noiseScale = 0 the random seed must not matter: the result is the exact count. */
    @Test
    public void testNoisyCountIfGaussianZeroNoiseScaleZeroRandomSeed() {
        JavaAggregationFunctionImplementation function = getFunction(BooleanType.BOOLEAN, DoubleType.DOUBLE, BigintType.BIGINT);
        List values = TestNoisyAggregationUtils.createTestValues(NUM_ROWS, true, true, false);
        Page page = new Page(new Block[]{
                BlockAssertions.createBooleansBlock(values),
                BlockAssertions.createRLEBlock(0.0d, NUM_ROWS),
                BlockAssertions.createRLEBlock(0L, NUM_ROWS)});
        AggregationTestUtils.assertAggregation(
                function,
                TestNoisyAggregationUtils.equalDoubleAssertion,
                "Test noisy_count_if_gaussian(boolean, noiseScale, randomSeed) with noiseScale=0 which means no noise",
                page,
                Double.valueOf(TestNoisyAggregationUtils.countTrue(values)));
    }

    /** With a fixed random seed the noisy result is deterministic. */
    @Test
    public void testNoisyCountIfGaussianSomeNoiseScaleFixedRandomSeed() {
        JavaAggregationFunctionImplementation function = getFunction(BooleanType.BOOLEAN, DoubleType.DOUBLE, BigintType.BIGINT);
        List values = TestNoisyAggregationUtils.createTestValues(NUM_ROWS, true, true, false);
        Page page = new Page(new Block[]{
                BlockAssertions.createBooleansBlock(values),
                BlockAssertions.createRLEBlock(12.0d, NUM_ROWS),
                BlockAssertions.createRLEBlock(10L, NUM_ROWS)});
        // NOTE(review): the +10 offset is presumably the (rounded) noise produced by seed 10 at
        // noiseScale 12 -- confirm against the implementation.
        AggregationTestUtils.assertAggregation(
                function,
                TestNoisyAggregationUtils.equalDoubleAssertion,
                "Test noisy_count_if_gaussian(boolean, noiseScale, randomSeed) with noiseScale=12 and a fixed randomSeed=10 which means deterministic noise",
                page,
                Double.valueOf(TestNoisyAggregationUtils.countTrue(values) + 10.0d));
    }

    /** An ungrouped aggregation over no rows yields one row; NULL + 1 stays NULL. */
    @Test
    public void testNoisyCountIfGaussianNoInputRowsWithoutGroupBy() {
        String data = TestNoisyAggregationUtils.buildData(100, true, Arrays.asList("boolean", "double", "real", "decimal"));
        String column = TestNoisyAggregationUtils.buildColumnName("boolean");
        List<MaterializedRow> rows = runQuery("SELECT " + FUNCTION_NAME + "(" + column + ", 0) + 1 FROM " + data + " WHERE false");
        Assert.assertEquals(rows.size(), 1);
        Assert.assertNull(rows.get(0).getField(0));
    }

    /** A grouped aggregation over no rows yields no rows at all. */
    @Test
    public void testNoisyCountIfGaussianNoInputRowsWithGroupBy() {
        String data = TestNoisyAggregationUtils.buildData(100, true, Arrays.asList("boolean", "double", "real", "decimal"));
        String column = TestNoisyAggregationUtils.buildColumnName("boolean");
        List<MaterializedRow> rows = runQuery("SELECT " + FUNCTION_NAME + "(" + column + ", 0) + 1 FROM " + data + " WHERE false GROUP BY " + column);
        Assert.assertEquals(rows.size(), 0);
    }

    /**
     * Executes {@code sql} on a fresh {@link LocalQueryRunner} and returns the result rows converted to test types.
     *
     * @param sql the query to run
     * @return the materialized result rows
     */
    private List<MaterializedRow> runQuery(String sql) {
        // The runner is closeable; close it after extracting the materialized rows so that
        // repeated calls do not leak its resources.
        try (LocalQueryRunner runner = new LocalQueryRunner(this.session)) {
            return runner.execute(sql).toTestTypes().getMaterializedRows();
        }
    }

    /**
     * Resolves the {@code noisy_count_if_gaussian} implementation for the given argument types.
     *
     * @param typeArr argument types used for overload resolution
     * @return the resolved aggregation implementation
     */
    private JavaAggregationFunctionImplementation getFunction(Type... typeArr) {
        return FUNCTION_AND_TYPE_MANAGER.getJavaAggregateFunctionImplementation(
                FUNCTION_AND_TYPE_MANAGER.lookupFunction(FUNCTION_NAME, TypeSignatureProvider.fromTypes(typeArr)));
    }
}
