package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import io.trino.block.BlockAssertions;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.AggregationMetrics;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BlockBuilderStatus;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.testing.assertions.TrinoExceptionAssert;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.OptionalInt;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
import org.apache.commons.math3.util.Precision;
import org.assertj.core.api.AbstractIntegerAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.Fail;
import org.assertj.core.api.ObjectAssert;

/* loaded from: input_file:io/trino/operator/aggregation/AggregationTestUtils.class */
/**
 * Static helpers for exercising aggregation functions in tests.
 *
 * <p>The {@code assertAggregation} entry points evaluate a function through several
 * execution shapes (single step, partial + final, grouped, grouped partial + final and
 * masked) and fail when any of them disagrees with the expected value. Each shape is
 * additionally re-run with reversed argument channels (for multi-argument functions) and
 * with the argument channels shifted behind extra all-null channels.
 */
public final class AggregationTestUtils {
    /** Number of all-null channels prepended when checking that argument channel indexes are honoured. */
    private static final int CHANNEL_OFFSET = 3;

    /** Group id used to check that a grouped aggregator gives the same answer for a large group id. */
    private static final int LARGE_GROUP_ID = 4000;

    /** Expected-entries hint handed to the block builders that receive aggregator output. */
    private static final int EXPECTED_ENTRIES = 1000;

    /** Absolute tolerance used when comparing {@code Double} results. */
    private static final double DOUBLE_TOLERANCE = 1.0E-10d;

    /** Absolute tolerance used when comparing {@code Float} results. */
    private static final float FLOAT_TOLERANCE = 1.0E-10f;

    private AggregationTestUtils() {
    }

    /**
     * Asserts that the named aggregation produces {@code obj} for the given input blocks.
     *
     * @param testingFunctionResolution resolver used to look up the aggregation function
     * @param str name of the aggregation function
     * @param list parameter types used to resolve the function
     * @param obj expected result
     * @param blockArr input channels; wrapped into a single {@link Page}
     */
    public static void assertAggregation(TestingFunctionResolution testingFunctionResolution, String str, List<TypeSignatureProvider> list, Object obj, Block... blockArr) {
        assertAggregation(testingFunctionResolution, str, list, obj, new Page(blockArr));
    }

    /**
     * Asserts that the named aggregation produces {@code obj} for the given page, comparing
     * with the predicate returned by {@link #makeValidityAssertion(Object)} and no test label.
     *
     * @param testingFunctionResolution resolver used to look up the aggregation function
     * @param str name of the aggregation function
     * @param list parameter types used to resolve the function
     * @param obj expected result
     * @param page input page
     */
    public static void assertAggregation(TestingFunctionResolution testingFunctionResolution, String str, List<TypeSignatureProvider> list, Object obj, Page page) {
        assertAggregation(testingFunctionResolution, str, list, makeValidityAssertion(obj), null, page, obj);
    }

    /**
     * Runs {@code assertAggregation} with a {@code null} expected value inside
     * {@link TrinoExceptionAssert#assertTrinoExceptionThrownBy} and returns the resulting
     * assertion object so callers can chain further checks on the failure.
     *
     * @param testingFunctionResolution resolver used to look up the aggregation function
     * @param str name of the aggregation function
     * @param list parameter types used to resolve the function
     * @param blockArr input channels
     * @return assertion object for the exception raised while aggregating
     */
    public static TrinoExceptionAssert assertAggregationFails(TestingFunctionResolution testingFunctionResolution, String str, List<TypeSignatureProvider> list, Block... blockArr) {
        return TrinoExceptionAssert.assertTrinoExceptionThrownBy(() -> assertAggregation(testingFunctionResolution, str, list, null, blockArr));
    }

    /**
     * Chooses the equality predicate suited to the expected value.
     *
     * <ul>
     *   <li>a non-NaN {@code Double}: both values must be {@code Double} and equal according to
     *       {@code Precision.equals} with a tolerance of {@code 1e-10};</li>
     *   <li>a non-NaN {@code Float}: the same, on {@code Float} values;</li>
     *   <li>anything else (including NaN and {@code null}): {@link Objects#equals(Object, Object)}.</li>
     * </ul>
     *
     * <p>The tolerant predicates return {@code false} when either side is {@code null} or is not
     * of the expected boxed type, so a type mismatch surfaces as a regular assertion failure
     * rather than a {@link ClassCastException}.
     *
     * @param obj expected value the predicate is tailored to
     * @return predicate deciding whether two results count as equal
     */
    public static BiFunction<Object, Object, Boolean> makeValidityAssertion(Object obj) {
        if (obj instanceof Double && !obj.equals(Double.NaN)) {
            return (left, right) -> left instanceof Double
                    && right instanceof Double
                    && Precision.equals(((Double) left).doubleValue(), ((Double) right).doubleValue(), DOUBLE_TOLERANCE);
        }
        if (obj instanceof Float && !obj.equals(Float.NaN)) {
            return (left, right) -> left instanceof Float
                    && right instanceof Float
                    && Precision.equals(((Float) left).floatValue(), ((Float) right).floatValue(), FLOAT_TOLERANCE);
        }
        return Objects::equals;
    }

    /**
     * Asserts that the named aggregation produces {@code obj} for {@code page} under every
     * execution shape covered by this class.
     *
     * <p>First verifies that every channel has the same position count as the page. The input is
     * then presented as no pages (empty input), as the single page (one row), or as two pages
     * split at the midpoint (two or more rows).
     *
     * @param testingFunctionResolution resolver used to look up the aggregation function
     * @param str name of the aggregation function
     * @param list parameter types used to resolve the function
     * @param biFunction equality predicate, invoked as {@code (actual, expected)}
     * @param str2 optional test label prepended to the failure message; may be {@code null}
     * @param page input page
     * @param obj expected result
     */
    public static void assertAggregation(TestingFunctionResolution testingFunctionResolution, String str, List<TypeSignatureProvider> list, BiFunction<Object, Object, Boolean> biFunction, String str2, Page page, Object obj) {
        TestingAggregationFunction function = testingFunctionResolution.getAggregateFunction(str, list);
        int positionCount = page.getPositionCount();
        for (int channel = 1; channel < page.getChannelCount(); channel++) {
            Assertions.assertThat(positionCount)
                    .describedAs("input blocks provided are not equal in position count")
                    .isEqualTo(page.getBlock(channel).getPositionCount());
        }

        if (positionCount == 0) {
            assertAggregationInternal(function, biFunction, str2, obj);
            return;
        }
        if (positionCount == 1) {
            assertAggregationInternal(function, biFunction, str2, obj, page);
            return;
        }
        // Split into two pages so that state is carried across page boundaries.
        int firstHalf = positionCount / 2;
        assertAggregationInternal(function, biFunction, str2, obj, page.getRegion(0, firstHalf), page.getRegion(firstHalf, positionCount - firstHalf));
    }

    /**
     * Evaluates {@code aggregator} into a fresh block builder of the given type.
     *
     * @param type type used to create the block builder
     * @param aggregator aggregator whose current state is written out
     * @return the built block
     */
    public static Block getIntermediateBlock(Type type, Aggregator aggregator) {
        BlockBuilder output = type.createBlockBuilder((BlockBuilderStatus) null, EXPECTED_ENTRIES);
        aggregator.evaluate(output);
        return output.build();
    }

    /**
     * Evaluates group {@code 0} of {@code groupedAggregator} into a fresh block builder of the given type.
     *
     * @param type type used to create the block builder
     * @param groupedAggregator grouped aggregator whose group 0 is written out
     * @return the built block
     */
    public static Block getIntermediateBlock(Type type, GroupedAggregator groupedAggregator) {
        BlockBuilder output = type.createBlockBuilder((BlockBuilderStatus) null, EXPECTED_ENTRIES);
        groupedAggregator.evaluate(0, output);
        return output.build();
    }

    /**
     * Evaluates {@code aggregator} into a fresh block builder of the given type.
     *
     * @param type type used to create the block builder
     * @param aggregator aggregator whose result is written out
     * @return the built block
     */
    public static Block getFinalBlock(Type type, Aggregator aggregator) {
        BlockBuilder output = type.createBlockBuilder((BlockBuilderStatus) null, EXPECTED_ENTRIES);
        aggregator.evaluate(output);
        return output.build();
    }

    /**
     * Runs every execution shape and compares each result with {@code expected}. The grouped and
     * masked shapes are skipped when there are no input pages.
     */
    private static void assertAggregationInternal(TestingAggregationFunction function, BiFunction<Object, Object, Boolean> isEqual, String testLabel, Object expected, Page... pages) {
        assertFunctionEquals(isEqual, testLabel, aggregation(function, pages), expected);
        assertFunctionEquals(isEqual, testLabel, partialAggregation(function, pages), expected);
        if (pages.length == 0) {
            return;
        }
        assertFunctionEquals(isEqual, testLabel, groupedAggregation(isEqual, function, pages), expected);
        assertFunctionEquals(isEqual, testLabel, groupedPartialAggregation(isEqual, function, pages), expected);
        assertFunctionEquals(isEqual, testLabel, distinctAggregation(function, pages), expected);
    }

    /**
     * Fails with an "Expected/actual" message, optionally prefixed by the label, unless
     * {@code isEqual.apply(actual, expected)} is true.
     */
    private static void assertFunctionEquals(BiFunction<Object, Object, Boolean> isEqual, String testLabel, Object actual, Object expected) {
        if (isEqual.apply(actual, expected)) {
            return;
        }
        String prefix = (testLabel == null) ? "" : String.format("Test: %s, ", testLabel);
        Fail.fail(prefix + String.format("Expected: %s, actual: %s", expected, actual));
    }

    /**
     * Evaluates the function with a boolean mask channel appended after the data channels.
     *
     * <p>The baseline uses an all-true mask. The input is then doubled with an all-false copy,
     * once with plain boolean blocks and once with run-length-encoded ones; both must reproduce
     * the baseline.
     *
     * @return the result obtained with the all-true mask
     */
    private static Object distinctAggregation(TestingAggregationFunction function, Page... pages) {
        int[] argChannels = createArgs(function.getParameterCount());
        // The mask is appended as the last channel, so its index equals the original channel count.
        OptionalInt maskChannel = OptionalInt.of(pages[0].getChannelCount());

        Page[] selected = maskPages(true, pages);
        Object baseline = aggregation(function, argChannels, maskChannel, selected);

        Page[] withRejectedCopy = concat(selected, maskPages(false, pages));
        Assertions.assertThat(aggregation(function, argChannels, maskChannel, withRejectedCopy))
                .describedAs("Inconsistent results with mask")
                .isEqualTo(baseline);

        Page[] withRleMasks = concat(maskPagesWithRle(true, pages), maskPagesWithRle(false, pages));
        Assertions.assertThat(aggregation(function, argChannels, maskChannel, withRleMasks))
                .describedAs("Inconsistent results with RLE mask")
                .isEqualTo(baseline);

        return baseline;
    }

    /** Returns a new array holding the elements of {@code first} followed by those of {@code second}. */
    private static Page[] concat(Page[] first, Page[] second) {
        Page[] joined = Arrays.copyOf(first, first.length + second.length);
        System.arraycopy(second, 0, joined, first.length, second.length);
        return joined;
    }

    /** Appends a run-length-encoded boolean channel holding {@code maskValue} to every page. */
    private static Page[] maskPagesWithRle(boolean maskValue, Page... pages) {
        Page[] masked = new Page[pages.length];
        for (int index = 0; index < pages.length; index++) {
            Page page = pages[index];
            Block mask = RunLengthEncodedBlock.create(BooleanType.createBlockForSingleNonNullValue(maskValue), page.getPositionCount());
            masked[index] = page.appendColumn(mask);
        }
        return masked;
    }

    /** Appends a boolean channel with {@code maskValue} written at every position to every page. */
    private static Page[] maskPages(boolean maskValue, Page... pages) {
        Page[] masked = new Page[pages.length];
        for (int index = 0; index < pages.length; index++) {
            Page page = pages[index];
            int positions = page.getPositionCount();
            BlockBuilder mask = BooleanType.BOOLEAN.createFixedSizeBlockBuilder(positions);
            for (int position = 0; position < positions; position++) {
                BooleanType.BOOLEAN.writeBoolean(mask, maskValue);
            }
            masked[index] = page.appendColumn(mask.build());
        }
        return masked;
    }

    /**
     * Evaluates the function in a single step over the given pages.
     *
     * <p>Also asserts that the same value is produced with reversed argument channels (only for
     * functions with more than one parameter) and with the argument channels shifted behind
     * three all-null channels.
     *
     * @param testingAggregationFunction function under test
     * @param pageArr input pages
     * @return the value produced with the arguments in their natural channel order
     */
    public static Object aggregation(TestingAggregationFunction testingAggregationFunction, Page... pageArr) {
        int parameterCount = testingAggregationFunction.getParameterCount();
        Object baseline = aggregation(testingAggregationFunction, createArgs(parameterCount), OptionalInt.empty(), pageArr);

        if (parameterCount > 1) {
            Object reversed = aggregation(testingAggregationFunction, reverseArgs(parameterCount), OptionalInt.empty(), reverseColumns(pageArr));
            Assertions.assertThat(reversed)
                    .describedAs("Inconsistent results with reversed channels")
                    .isEqualTo(baseline);
        }

        Object shifted = aggregation(testingAggregationFunction, offsetArgs(parameterCount, CHANNEL_OFFSET), OptionalInt.empty(), offsetColumns(pageArr, CHANNEL_OFFSET));
        Assertions.assertThat(shifted)
                .describedAs("Inconsistent results with channel offset")
                .isEqualTo(baseline);

        return baseline;
    }

    /**
     * Runs a SINGLE-step aggregator over the non-empty pages and returns the only value of its
     * final block.
     */
    private static Object aggregation(TestingAggregationFunction function, int[] argChannels, OptionalInt maskChannel, Page... pages) {
        Aggregator aggregator = function
                .createAggregatorFactory(AggregationNode.Step.SINGLE, Ints.asList(argChannels), maskChannel)
                .createAggregator(new AggregationMetrics());
        for (Page page : pages) {
            if (page.getPositionCount() > 0) {
                aggregator.processPage(page);
            }
        }
        Type finalType = function.getFinalType();
        return BlockAssertions.getOnlyValue(finalType, getFinalBlock(finalType, aggregator));
    }

    /**
     * Evaluates the function as one PARTIAL aggregation per page feeding a FINAL aggregation.
     *
     * <p>Also asserts that the same value is produced with reversed argument channels (only for
     * functions with more than one parameter) and with the argument channels shifted behind
     * three all-null channels.
     *
     * @param testingAggregationFunction function under test
     * @param pageArr input pages
     * @return the value produced with the arguments in their natural channel order
     */
    public static Object partialAggregation(TestingAggregationFunction testingAggregationFunction, Page... pageArr) {
        int parameterCount = testingAggregationFunction.getParameterCount();
        Object baseline = partialAggregation(testingAggregationFunction, createArgs(parameterCount), pageArr);

        if (parameterCount > 1) {
            Object reversed = partialAggregation(testingAggregationFunction, reverseArgs(parameterCount), reverseColumns(pageArr));
            Assertions.assertThat(reversed)
                    .describedAs("Inconsistent results with reversed channels")
                    .isEqualTo(baseline);
        }

        Object shifted = partialAggregation(testingAggregationFunction, offsetArgs(parameterCount, CHANNEL_OFFSET), offsetColumns(pageArr, CHANNEL_OFFSET));
        Assertions.assertThat(shifted)
                .describedAs("Inconsistent results with channel offset")
                .isEqualTo(baseline);

        return baseline;
    }

    /**
     * Feeds a FINAL aggregator (reading channel 0) with the intermediate state of one PARTIAL
     * aggregator per page. The intermediate state of a PARTIAL aggregator that saw no input is
     * fed both before and after the real states.
     */
    private static Object partialAggregation(TestingAggregationFunction function, int[] argChannels, Page... pages) {
        Aggregator finalAggregator = function
                .createAggregatorFactory(AggregationNode.Step.FINAL, Ints.asList(0), OptionalInt.empty())
                .createAggregator(new AggregationMetrics());
        AggregatorFactory partialFactory = function.createAggregatorFactory(AggregationNode.Step.PARTIAL, Ints.asList(argChannels), OptionalInt.empty());
        Type intermediateType = function.getIntermediateType();

        // Intermediate state of a partial aggregator that never received a page.
        Block emptyState = getIntermediateBlock(intermediateType, partialFactory.createAggregator(new AggregationMetrics()));

        finalAggregator.processPage(new Page(emptyState));
        for (Page page : pages) {
            Aggregator partialAggregator = partialFactory.createAggregator(new AggregationMetrics());
            if (page.getPositionCount() > 0) {
                partialAggregator.processPage(page);
            }
            finalAggregator.processPage(new Page(getIntermediateBlock(intermediateType, partialAggregator)));
        }
        finalAggregator.processPage(new Page(emptyState));

        Type finalType = function.getFinalType();
        return BlockAssertions.getOnlyValue(finalType, getFinalBlock(finalType, finalAggregator));
    }

    /**
     * Evaluates the function through a grouped aggregator for a single page, comparing the
     * channel-order variants with {@link Objects#equals(Object, Object)}.
     *
     * @param testingAggregationFunction function under test
     * @param page input page
     * @return the value of group 0
     */
    public static Object groupedAggregation(TestingAggregationFunction testingAggregationFunction, Page page) {
        BiFunction<Object, Object, Boolean> exactEquality = Objects::equals;
        return groupedAggregation(exactEquality, testingAggregationFunction, page);
    }

    /**
     * Evaluates the grouped SINGLE-step shape and checks, using {@code isEqual}, that reversed
     * argument channels (multi-parameter functions only) and shifted argument channels give the
     * same value.
     */
    private static Object groupedAggregation(BiFunction<Object, Object, Boolean> isEqual, TestingAggregationFunction function, Page... pages) {
        int parameterCount = function.getParameterCount();
        Object baseline = groupedAggregation(function, createArgs(parameterCount), pages);

        if (parameterCount > 1) {
            Object reversed = groupedAggregation(function, reverseArgs(parameterCount), reverseColumns(pages));
            assertFunctionEquals(isEqual, "Inconsistent results with reversed channels", reversed, baseline);
        }

        Object shifted = groupedAggregation(function, offsetArgs(parameterCount, CHANNEL_OFFSET), offsetColumns(pages, CHANNEL_OFFSET));
        assertFunctionEquals(isEqual, "Inconsistent results with channel offset", shifted, baseline);

        return baseline;
    }

    /**
     * Runs a grouped SINGLE-step aggregator with every row assigned to group 0, then feeds the
     * same pages again with every row assigned to group 4000 and asserts that both groups yield
     * the same value.
     *
     * @param testingAggregationFunction function under test
     * @param iArr argument channel indexes
     * @param pageArr input pages
     * @return the value of group 0
     */
    public static Object groupedAggregation(TestingAggregationFunction testingAggregationFunction, int[] iArr, Page... pageArr) {
        GroupedAggregator aggregator = testingAggregationFunction
                .createAggregatorFactory(AggregationNode.Step.SINGLE, Ints.asList(iArr), OptionalInt.empty())
                .createGroupedAggregator(new AggregationMetrics());
        Type finalType = testingAggregationFunction.getFinalType();

        for (Page page : pageArr) {
            aggregator.processPage(0, createGroupByIdBlock(0, page.getPositionCount()), page);
        }
        Object groupZeroValue = getGroupValue(finalType, aggregator, 0);

        for (Page page : pageArr) {
            aggregator.processPage(LARGE_GROUP_ID, createGroupByIdBlock(LARGE_GROUP_ID, page.getPositionCount()), page);
        }
        Assertions.assertThat(getGroupValue(finalType, aggregator, LARGE_GROUP_ID))
                .describedAs("Inconsistent results with large group id")
                .isEqualTo(groupZeroValue);

        return groupZeroValue;
    }

    /**
     * Evaluates the grouped PARTIAL + FINAL shape and checks, using {@code isEqual}, that
     * reversed argument channels (multi-parameter functions only) and shifted argument channels
     * give the same value.
     */
    private static Object groupedPartialAggregation(BiFunction<Object, Object, Boolean> isEqual, TestingAggregationFunction function, Page... pages) {
        int parameterCount = function.getParameterCount();
        Object baseline = groupedPartialAggregation(function, createArgs(parameterCount), pages);

        if (parameterCount > 1) {
            Object reversed = groupedPartialAggregation(function, reverseArgs(parameterCount), reverseColumns(pages));
            assertFunctionEquals(isEqual, "Inconsistent results with reversed channels", reversed, baseline);
        }

        Object shifted = groupedPartialAggregation(function, offsetArgs(parameterCount, CHANNEL_OFFSET), offsetColumns(pages, CHANNEL_OFFSET));
        assertFunctionEquals(isEqual, "Inconsistent results with channel offset", shifted, baseline);

        return baseline;
    }

    /**
     * Feeds a grouped FINAL aggregator (reading channel 0) with the group-0 intermediate state of
     * one grouped PARTIAL aggregator per page. The intermediate state of a PARTIAL aggregator
     * that saw no input is fed both before and after the real states.
     */
    private static Object groupedPartialAggregation(TestingAggregationFunction function, int[] argChannels, Page... pages) {
        GroupedAggregator finalAggregator = function
                .createAggregatorFactory(AggregationNode.Step.FINAL, ImmutableList.of(0), OptionalInt.empty())
                .createGroupedAggregator(new AggregationMetrics());
        AggregatorFactory partialFactory = function.createAggregatorFactory(AggregationNode.Step.PARTIAL, Ints.asList(argChannels), OptionalInt.empty());
        Type intermediateType = function.getIntermediateType();

        // Intermediate state of a grouped partial aggregator that never received a page.
        Block emptyState = getIntermediateBlock(intermediateType, partialFactory.createGroupedAggregator(new AggregationMetrics()));

        addIntermediateToGroupZero(finalAggregator, emptyState);
        for (Page page : pages) {
            GroupedAggregator partialAggregator = partialFactory.createGroupedAggregator(new AggregationMetrics());
            partialAggregator.processPage(0, createGroupByIdBlock(0, page.getPositionCount()), page);
            addIntermediateToGroupZero(finalAggregator, getIntermediateBlock(intermediateType, partialAggregator));
        }
        addIntermediateToGroupZero(finalAggregator, emptyState);

        return getGroupValue(function.getFinalType(), finalAggregator, 0);
    }

    /** Passes {@code intermediate} to {@code target} as a one-channel page with every row in group 0. */
    private static void addIntermediateToGroupZero(GroupedAggregator target, Block intermediate) {
        target.processPage(0, createGroupByIdBlock(0, intermediate.getPositionCount()), new Page(intermediate));
    }

    /**
     * Builds a group-id array in which every position holds the same group id.
     *
     * @param i group id stored at every position
     * @param i2 number of positions
     * @return array of length {@code i2} filled with {@code i}
     */
    public static int[] createGroupByIdBlock(int i, int i2) {
        int[] groupIds = new int[i2];
        Arrays.fill(groupIds, i);
        return groupIds;
    }

    /**
     * Builds the natural argument channel list {@code 0, 1, ..., i - 1}.
     *
     * @param i number of arguments
     * @return channel indexes in ascending order
     */
    static int[] createArgs(int i) {
        int[] channels = new int[i];
        for (int channel = 0; channel < channels.length; channel++) {
            channels[channel] = channel;
        }
        return channels;
    }

    /** Builds the argument channel list {@code count - 1, ..., 1, 0}. */
    private static int[] reverseArgs(int count) {
        int[] channels = createArgs(count);
        // Ints.asList is a view over the array, so reversing the list reverses the array in place.
        Collections.reverse(Ints.asList(channels));
        return channels;
    }

    /** Builds the argument channel list {@code offset, offset + 1, ..., offset + count - 1}. */
    private static int[] offsetArgs(int count, int offset) {
        int[] channels = createArgs(count);
        for (int index = 0; index < channels.length; index++) {
            channels[index] += offset;
        }
        return channels;
    }

    /**
     * Returns pages whose channels are in reverse order. Pages without positions are passed
     * through unchanged.
     */
    private static Page[] reverseColumns(Page[] pages) {
        Page[] reversedPages = new Page[pages.length];
        for (int index = 0; index < pages.length; index++) {
            Page page = pages[index];
            if (page.getPositionCount() == 0) {
                reversedPages[index] = page;
                continue;
            }
            int channelCount = page.getChannelCount();
            Block[] reversedBlocks = new Block[channelCount];
            for (int channel = 0; channel < channelCount; channel++) {
                reversedBlocks[channel] = page.getBlock(channelCount - channel - 1);
            }
            reversedPages[index] = new Page(page.getPositionCount(), reversedBlocks);
        }
        return reversedPages;
    }

    /**
     * Returns pages in which {@code i} all-null channels precede the original channels, so that
     * every original channel index is shifted by {@code i}.
     *
     * @param pageArr input pages
     * @param i number of all-null channels to prepend
     * @return new pages with the shifted channel layout
     */
    public static Page[] offsetColumns(Page[] pageArr, int i) {
        Page[] shiftedPages = new Page[pageArr.length];
        for (int index = 0; index < pageArr.length; index++) {
            Page page = pageArr[index];
            int channelCount = page.getChannelCount();
            Block[] shiftedBlocks = new Block[channelCount + i];
            for (int channel = 0; channel < i; channel++) {
                shiftedBlocks[channel] = createAllNullBlock(page.getPositionCount());
            }
            for (int channel = 0; channel < channelCount; channel++) {
                shiftedBlocks[channel + i] = page.getBlock(channel);
            }
            shiftedPages[index] = new Page(page.getPositionCount(), shiftedBlocks);
        }
        return shiftedPages;
    }

    /** Creates a run-length-encoded boolean block of {@code positionCount} null values. */
    private static Block createAllNullBlock(int positionCount) {
        return RunLengthEncodedBlock.create(BooleanType.BOOLEAN, (Object) null, positionCount);
    }

    /**
     * Evaluates one group of a grouped aggregator and returns the only value of the resulting block.
     *
     * @param type type used to create the block builder and to read the value back
     * @param groupedAggregator aggregator to evaluate
     * @param i group id to evaluate
     * @return the value produced for the group
     */
    public static Object getGroupValue(Type type, GroupedAggregator groupedAggregator, int i) {
        BlockBuilder output = type.createBlockBuilder((BlockBuilderStatus) null, 1);
        groupedAggregator.evaluate(i, output);
        return BlockAssertions.getOnlyValue(type, output.build());
    }

    /**
     * Builds the consecutive values {@code i, i + 1, ..., i + i2 - 1} as doubles.
     *
     * @param i first value
     * @param i2 number of values
     * @return array of consecutive values converted to {@code double}
     */
    public static double[] constructDoublePrimitiveArray(int i, int i2) {
        return IntStream.range(i, i + i2).asDoubleStream().toArray();
    }
}
