package com.facebook.presto.type.khyperloglog;

import com.facebook.presto.block.BlockAssertions;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.SqlVarbinary;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.aggregation.AggregationTestUtils;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.slice.XxHash64;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/type/khyperloglog/TestKHyperLogLogAggregationFunction.class */
/**
 * Tests the aggregation function registered under
 * {@code KHyperLogLogWithLimitAggregationFunction.getFunctionName()}.
 *
 * <p>Every test feeds two equally sized columns (a "key" column and a "value" column) to the
 * aggregation and compares the result with a {@link KHyperLogLog} built by hand from the same
 * rows. Rows in which either column is null are left out of the expected sketch, and the expected
 * result is {@code null} when no row contributes.
 */
public class TestKHyperLogLogAggregationFunction {
    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
    private static final String NAME = KHyperLogLogWithLimitAggregationFunction.getFunctionName();

    /** Runs all key/value type combinations over 100 generated rows. */
    @Test
    public void testSimpleKHyperLogLog() {
        assertAllTypeCombinations(generateLongs(100), generateStringSlices(100), generateDoubles(100));
    }

    /** Runs all key/value type combinations over 100,000 generated rows. */
    @Test
    public void testBigKHyperLogLog() {
        assertAllTypeCombinations(generateLongs(100000), generateStringSlices(100000), generateDoubles(100000));
    }

    /**
     * Runs all key/value type combinations over columns that went through {@link #includeNulls}.
     *
     * <p>NOTE(review): each element is nulled independently with probability 0.2 from a fixed
     * seed, so with only 3 rows the fixture may end up containing no nulls at all - confirm that
     * this test really exercises the partially-null case, and consider a larger row count.
     */
    @Test
    public void testKHyperLogLogWithSomeNulls() {
        List<Long> longs = includeNulls(generateLongs(3));
        List<Slice> slices = includeNulls(generateStringSlices(3));
        List<Double> doubles = includeNulls(generateDoubles(3));
        assertAllTypeCombinations(longs, slices, doubles);
    }

    /** Checks every type combination with either the key column or the value column entirely null. */
    @Test
    public void testKHyperLogLogWithNullColumn() {
        List<Long> longs = generateLongs(3);
        List<Slice> slices = generateStringSlices(3);
        List<Double> doubles = generateDoubles(3);
        List<Object> nulls = generateNulls(3);
        testAggregation(BigintType.BIGINT, nulls, BigintType.BIGINT, longs);
        testAggregation(BigintType.BIGINT, longs, BigintType.BIGINT, nulls);
        testAggregation(BigintType.BIGINT, nulls, VarcharType.VARCHAR, slices);
        testAggregation(BigintType.BIGINT, longs, VarcharType.VARCHAR, nulls);
        testAggregation(VarcharType.VARCHAR, nulls, BigintType.BIGINT, longs);
        testAggregation(VarcharType.VARCHAR, slices, BigintType.BIGINT, nulls);
        testAggregation(VarcharType.VARCHAR, nulls, VarcharType.VARCHAR, slices);
        testAggregation(VarcharType.VARCHAR, slices, VarcharType.VARCHAR, nulls);
        testAggregation(DoubleType.DOUBLE, nulls, BigintType.BIGINT, longs);
        testAggregation(DoubleType.DOUBLE, doubles, BigintType.BIGINT, nulls);
        testAggregation(DoubleType.DOUBLE, nulls, VarcharType.VARCHAR, slices);
        testAggregation(DoubleType.DOUBLE, doubles, VarcharType.VARCHAR, nulls);
    }

    /**
     * Runs {@link #testAggregation} for each key type (BIGINT, VARCHAR, DOUBLE) paired with each
     * value type (BIGINT, VARCHAR), in the same order the individual tests originally listed them.
     *
     * @param longs values used for BIGINT columns
     * @param slices values used for VARCHAR columns
     * @param doubles values used for DOUBLE columns
     */
    private void assertAllTypeCombinations(List<Long> longs, List<Slice> slices, List<Double> doubles) {
        testAggregation(BigintType.BIGINT, longs, BigintType.BIGINT, longs);
        testAggregation(BigintType.BIGINT, longs, VarcharType.VARCHAR, slices);
        testAggregation(VarcharType.VARCHAR, slices, BigintType.BIGINT, longs);
        testAggregation(VarcharType.VARCHAR, slices, VarcharType.VARCHAR, slices);
        testAggregation(DoubleType.DOUBLE, doubles, BigintType.BIGINT, longs);
        testAggregation(DoubleType.DOUBLE, doubles, VarcharType.VARCHAR, slices);
    }

    /**
     * Asserts that the aggregation over the two columns equals the serialized form of a
     * {@link KHyperLogLog} populated row by row from the same data.
     *
     * @param keyType type of the first (key) column
     * @param keys values of the first column; elements may be null
     * @param valueType type of the second (value) column
     * @param values values of the second column; elements may be null
     * @throws IllegalArgumentException if the two columns differ in length
     */
    private void testAggregation(Type keyType, List<?> keys, Type valueType, List<?> values) {
        if (keys.size() != values.size()) {
            throw new IllegalArgumentException("column sizes differ: " + keys.size() + " vs " + values.size());
        }
        JavaAggregationFunctionImplementation aggregation = getAggregation(keyType, valueType);

        // Stays null until at least one row has both columns non-null.
        KHyperLogLog expectedSketch = null;
        for (int row = 0; row < keys.size(); row++) {
            Object key = keys.get(row);
            Object value = values.get(row);
            if (key == null || value == null) {
                continue;
            }
            if (expectedSketch == null) {
                expectedSketch = new KHyperLogLog();
            }
            long valueAsLong = toLong(value, valueType);
            if (keyType == VarcharType.VARCHAR) {
                // VARCHAR keys are passed as slices, so there is no need to convert them to a long here.
                expectedSketch.add((Slice) key, valueAsLong);
            } else {
                expectedSketch.add(toLong(key, keyType), valueAsLong);
            }
        }

        SqlVarbinary expected = expectedSketch == null ? null : new SqlVarbinary(expectedSketch.serialize().getBytes());
        AggregationTestUtils.assertAggregation(aggregation, expected, buildBlock(keys, keyType), buildBlock(values, valueType));
    }

    /**
     * Converts a non-null column value to the {@code long} handed to {@link KHyperLogLog}:
     * the raw bit pattern for DOUBLE, the XxHash64 of the slice for VARCHAR, and the value itself
     * otherwise.
     */
    private static long toLong(Object value, Type type) {
        if (type == DoubleType.DOUBLE) {
            return Double.doubleToLongBits((Double) value);
        }
        if (type == VarcharType.VARCHAR) {
            return XxHash64.hash((Slice) value);
        }
        return (Long) value;
    }

    /**
     * Builds the input block for a column, choosing the {@link BlockAssertions} factory that
     * matches {@code type} (DOUBLE, VARCHAR, or longs for anything else).
     */
    private static Block buildBlock(List<?> values, Type type) {
        if (type == DoubleType.DOUBLE) {
            return BlockAssertions.createDoublesBlock(castElements(values, Double.class));
        }
        if (type == VarcharType.VARCHAR) {
            return BlockAssertions.createSlicesBlock(castElements(values, Slice.class));
        }
        return BlockAssertions.createLongsBlock(castElements(values, Long.class));
    }

    /** Returns a copy of {@code values} with every element checked-cast to {@code elementClass}; nulls pass through. */
    private static <T> List<T> castElements(List<?> values, Class<T> elementClass) {
        return values.stream().map(elementClass::cast).collect(Collectors.toList());
    }

    /** Looks up the aggregation under test for the given argument types. */
    private static JavaAggregationFunctionImplementation getAggregation(Type... typeArr) {
        return FUNCTION_AND_TYPE_MANAGER.getJavaAggregateFunctionImplementation(FUNCTION_AND_TYPE_MANAGER.lookupFunction(NAME, TypeSignatureProvider.fromTypes(typeArr)));
    }

    /** Returns {@code count} pseudo-random longs from a fixed seed, so every call yields the same sequence. */
    private static List<Long> generateLongs(int count) {
        return new Random(13L).longs(count).boxed().collect(Collectors.toList());
    }

    /** Returns {@code count} UTF-8 slices holding the hex form of pseudo-random longs from a fixed seed. */
    private static List<Slice> generateStringSlices(int count) {
        return new Random(123L).longs(count)
                .mapToObj(Long::toHexString)
                .map(Slices::utf8Slice)
                .collect(Collectors.toList());
    }

    /** Returns {@code count} pseudo-random doubles from a fixed seed. */
    private static List<Double> generateDoubles(int count) {
        return new Random(123L).doubles(count).boxed().collect(Collectors.toList());
    }

    /** Returns a list of {@code count} null elements. */
    private static List<Object> generateNulls(int count) {
        return Arrays.asList(new Object[count]);
    }

    /**
     * Returns a copy of {@code list} in which each element has independently been replaced by
     * {@code null} with probability 0.2, driven by a fixed seed.
     *
     * <p>A copy is returned instead of mutating the argument because the generated lists come
     * from {@code Collectors.toList()}, which does not promise a mutable result.
     */
    private static <K> List<K> includeNulls(List<K> list) {
        Random random = new Random(123L);
        List<K> withNulls = new ArrayList<>(list);
        for (int i = 0; i < withNulls.size(); i++) {
            if (random.nextDouble() < 0.2d) {
                withNulls.set(i, null);
            }
        }
        return withNulls;
    }
}
