/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.hashtable;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
import org.apache.flink.runtime.memory.MemoryAllocationException;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.memory.MemoryManagerBuilder;
import org.apache.flink.runtime.operators.testutils.UnionIterator;
import org.apache.flink.table.api.config.ExecutionConfigOptions;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.runtime.hashtable.BinaryHashTableTest;
import org.apache.flink.table.runtime.hashtable.LongHashPartition;
import org.apache.flink.table.runtime.hashtable.LongHybridHashTable;
import org.apache.flink.table.runtime.hashtable.ProbeIterator;
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
import org.apache.flink.table.runtime.util.RowIterator;
import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
import org.apache.flink.util.MutableObjectIterator;
import org.assertj.core.api.AbstractIntegerAssert;
import org.assertj.core.api.AbstractLongAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.Fail;
import org.assertj.core.api.MapAssert;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;

/**
 * Tests for {@link LongHybridHashTable}, executed once with spill compression enabled and once with
 * it disabled (see {@link #getVarSeg()}).
 *
 * <p>Every test builds a {@link MyHashTable} from generated two-INT-column rows (column 0 is the
 * join key), probes it with a second generated input and checks the join result. The memory budget
 * handed to the table is varied so that different tests run fully in memory, spill, or leave
 * partitions pending for a sort-merge join.
 */
@ExtendWith(value = {ParameterizedTestExtension.class})
public class LongHashTableTest {
    private static final int PAGE_SIZE = 32768;

    private final boolean useCompress;

    /** Created per test instance and shut down in {@link #tearDown()}. */
    private final MemoryManager memManager =
            MemoryManagerBuilder.newBuilder().setMemorySize(896L * PAGE_SIZE).build();

    private IOManager ioManager;
    private BinaryRowDataSerializer buildSideSerializer;
    private BinaryRowDataSerializer probeSideSerializer;
    private Configuration conf;

    public LongHashTableTest(boolean useCompress) {
        this.useCompress = useCompress;
    }

    /** Parameter values for {@code useCompress}: spill compression on, then off. */
    @Parameters(name = "useCompress-{0}")
    public static List<Boolean> getVarSeg() {
        return Arrays.asList(true, false);
    }

    /** Creates the serializers, a fresh I/O manager and the table configuration for one test. */
    @BeforeEach
    public void init() {
        // Rows on both sides have two INT columns; only the arity is needed by the serializer.
        TypeInformation<?>[] types = new TypeInformation<?>[] {Types.INT, Types.INT};
        this.buildSideSerializer = new BinaryRowDataSerializer(types.length);
        this.probeSideSerializer = new BinaryRowDataSerializer(types.length);
        this.ioManager = new IOManagerAsync();
        this.conf = new Configuration();
        this.conf.set(ExecutionConfigOptions.TABLE_EXEC_SPILL_COMPRESSION_ENABLED, this.useCompress);
    }

    /**
     * Releases the per-test I/O manager and memory manager. This runs even when a test fails before
     * reaching its own {@code close()/free()} calls, so spill files and memory do not leak into
     * subsequent parameterized runs.
     */
    @AfterEach
    public void tearDown() throws Exception {
        try {
            if (ioManager != null) {
                ioManager.close();
            }
        } finally {
            memManager.shutdown();
        }
    }

    // ------------------------------------------------------------------------------------------
    // Join-result-count tests
    // ------------------------------------------------------------------------------------------

    @TestTemplate
    void testInMemory() throws IOException {
        assertJoinResultCount(100000, 3, 100000, 10, 500L * PAGE_SIZE);
    }

    @TestTemplate
    void testSpillingHashJoinOneRecursion() throws IOException {
        assertJoinResultCount(100000, 3, 100000, 10, 300L * PAGE_SIZE);
    }

    @TestTemplate
    void testSpillingHashJoinOneRecursionPerformance() throws IOException {
        assertJoinResultCount(1000000, 3, 1000000, 10, 100L * PAGE_SIZE);
    }

    @TestTemplate
    void testSparseProbeSpilling() throws IOException, MemoryAllocationException {
        assertJoinResultCount(1000000, 1, 20, 1, 100L * PAGE_SIZE);
    }

    @TestTemplate
    void validateSpillingDuringInsertion() throws IOException, MemoryAllocationException {
        assertJoinResultCount(500000, 1, 10, 1, 85L * PAGE_SIZE);
    }

    @TestTemplate
    void testBucketsNotFulfillSegment() throws Exception {
        assertJoinResultCount(10000, 3, 10000, 10, 35L * PAGE_SIZE);
    }

    // ------------------------------------------------------------------------------------------
    // Per-key validity tests
    // ------------------------------------------------------------------------------------------

    @TestTemplate
    void testSpillingHashJoinOneRecursionValidity() throws IOException {
        final int numKeys = 1000000;
        final int buildValsPerKey = 3;
        final int probeValsPerKey = 10;

        MyHashTable table = new MyHashTable(100L * PAGE_SIZE);
        Map<Integer, Long> matchesPerKey =
                joinAndCollect(
                        table,
                        new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false),
                        new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true));
        table.close();
        table.free();

        Assertions.assertThat(matchesPerKey).as("Wrong number of keys").hasSize(numKeys);
        long expectedPerKey = (long) buildValsPerKey * probeValsPerKey;
        for (Map.Entry<Integer, Long> entry : matchesPerKey.entrySet()) {
            long val = entry.getValue();
            int key = entry.getKey();
            Assertions.assertThat(val)
                    .as("Wrong number of values in per-key cross product for key " + key)
                    .isEqualTo(expectedPerKey);
        }
    }

    @TestTemplate
    void testSpillingHashJoinWithMassiveCollisions() throws IOException {
        assertRepeatedKeyJoin(200000, 5, 896L * PAGE_SIZE);
    }

    @TestTemplate
    void testSpillingHashJoinWithTwoRecursions() throws IOException {
        assertRepeatedKeyJoin(200000, 5, 896L * PAGE_SIZE);
    }

    @TestTemplate
    void testSpillingHashJoinWithTooManyRecursions() throws IOException {
        final int repeatedValue1 = 40559;
        final int repeatedValue2 = 92882;
        final int repeatedValueCount = 3000000;
        final int numKeys = 1000000;
        final int buildValsPerKey = 3;
        final int probeValsPerKey = 10;

        MutableObjectIterator<BinaryRowData> buildInput =
                union(
                        new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false),
                        new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
                                repeatedValue1, 17, repeatedValueCount),
                        new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
                                repeatedValue2, 23, repeatedValueCount));
        MutableObjectIterator<BinaryRowData> probeInput =
                union(
                        new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true),
                        new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
                                repeatedValue1, 17, repeatedValueCount),
                        new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
                                repeatedValue2, 23, repeatedValueCount));

        MyHashTable table = new MyHashTable(896L * PAGE_SIZE);
        Map<Integer, Long> matchesPerKey = joinAndCollect(table, buildInput, probeInput);

        Assertions.assertThat(matchesPerKey.size())
                .as("Wrong number of records in join result.")
                .isLessThan(numKeys);
        Assertions.assertThat(table.getPartitionsPendingForSMJ().size())
                .as("Wrong number of spilled partition.")
                .isEqualTo(2);

        // Count, per key, the rows left in the partitions pending for SMJ; this must happen
        // before the table is closed.
        Map<Integer, Integer> spilledPartitionBuildSideKeys = new HashMap<>();
        Map<Integer, Integer> spilledPartitionProbeSideKeys = new HashMap<>();
        for (LongHashPartition p : table.getPartitionsPendingForSMJ()) {
            RowIterator<BinaryRowData> buildIter = table.getSpilledPartitionBuildSideIter(p);
            while (buildIter.advanceNext()) {
                spilledPartitionBuildSideKeys.merge(buildIter.getRow().getInt(0), 1, Integer::sum);
            }
            ProbeIterator probeIter = table.getSpilledPartitionProbeSideIter(p);
            BinaryRowData rowData;
            while ((rowData = probeIter.next()) != null) {
                spilledPartitionProbeSideKeys.merge(rowData.getInt(0), 1, Integer::sum);
            }
        }

        int buildKeyCnt = repeatedValueCount + buildValsPerKey;
        Assertions.assertThat(spilledPartitionBuildSideKeys).containsEntry(repeatedValue1, buildKeyCnt);
        Assertions.assertThat(spilledPartitionBuildSideKeys).containsEntry(repeatedValue2, buildKeyCnt);
        int probeKeyCnt = repeatedValueCount + probeValsPerKey;
        Assertions.assertThat(spilledPartitionProbeSideKeys).containsEntry(repeatedValue1, probeKeyCnt);
        Assertions.assertThat(spilledPartitionProbeSideKeys).containsEntry(repeatedValue2, probeKeyCnt);

        table.close();
        table.free();
    }

    // ------------------------------------------------------------------------------------------
    // Helpers
    // ------------------------------------------------------------------------------------------

    /**
     * Joins two {@link UniformBinaryRowGenerator} inputs inside a table of {@code memorySize} bytes
     * and checks the number of results.
     *
     * <p>The expected count is {@code min(numBuildKeys, numProbeKeys) * buildValsPerKey *
     * probeValsPerKey}, i.e. it presumes both generators emit overlapping key ranges so that only
     * the smaller key set produces matches.
     */
    private void assertJoinResultCount(
            int numBuildKeys,
            int buildValsPerKey,
            int numProbeKeys,
            int probeValsPerKey,
            long memorySize)
            throws IOException {
        MyHashTable table = new MyHashTable(memorySize);
        int numRecordsInJoinResult =
                join(
                        table,
                        new UniformBinaryRowGenerator(numBuildKeys, buildValsPerKey, false),
                        new UniformBinaryRowGenerator(numProbeKeys, probeValsPerKey, true));
        int expectedNumResults =
                Math.min(numBuildKeys, numProbeKeys) * buildValsPerKey * probeValsPerKey;
        Assertions.assertThat(numRecordsInJoinResult)
                .as("Wrong number of records in join result.")
                .isEqualTo(expectedNumResults);
        table.close();
        table.free();
    }

    /**
     * Joins 1,000,000 uniform keys plus two heavily repeated keys and checks the per-key cross
     * product sizes.
     *
     * @param repeatedValueCountBuild extra build rows emitted for each repeated key
     * @param repeatedValueCountProbe extra probe rows emitted for each repeated key
     * @param memorySize memory budget passed to the hash table
     */
    private void assertRepeatedKeyJoin(
            int repeatedValueCountBuild, int repeatedValueCountProbe, long memorySize)
            throws IOException {
        final int repeatedValue1 = 40559;
        final int repeatedValue2 = 92882;
        final int numKeys = 1000000;
        final int buildValsPerKey = 3;
        final int probeValsPerKey = 10;

        MutableObjectIterator<BinaryRowData> buildInput =
                union(
                        new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false),
                        new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
                                repeatedValue1, 17, repeatedValueCountBuild),
                        new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
                                repeatedValue2, 23, repeatedValueCountBuild));
        MutableObjectIterator<BinaryRowData> probeInput =
                union(
                        new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true),
                        new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
                                repeatedValue1, 17, repeatedValueCountProbe),
                        new BinaryHashTableTest.ConstantsKeyValuePairsIterator(
                                repeatedValue2, 23, repeatedValueCountProbe));

        MyHashTable table = new MyHashTable(memorySize);
        Map<Integer, Long> matchesPerKey = joinAndCollect(table, buildInput, probeInput);
        table.close();
        table.free();

        // A repeated key appears (uniform + repeated) times on each side, so its cross product is
        // the product of those two totals; every other key only has the uniform rows.
        long expectedRepeated =
                (long) (buildValsPerKey + repeatedValueCountBuild)
                        * (probeValsPerKey + repeatedValueCountProbe);
        long expectedRegular = (long) buildValsPerKey * probeValsPerKey;

        Assertions.assertThat(matchesPerKey).as("Wrong number of keys").hasSize(numKeys);
        for (Map.Entry<Integer, Long> entry : matchesPerKey.entrySet()) {
            long val = entry.getValue();
            int key = entry.getKey();
            long expected =
                    key == repeatedValue1 || key == repeatedValue2
                            ? expectedRepeated
                            : expectedRegular;
            Assertions.assertThat(val)
                    .as("Wrong number of values in per-key cross product for key " + key)
                    .isEqualTo(expected);
        }
    }

    /** Concatenates three row iterators with a {@link UnionIterator}. */
    private static MutableObjectIterator<BinaryRowData> union(
            MutableObjectIterator<BinaryRowData> first,
            MutableObjectIterator<BinaryRowData> second,
            MutableObjectIterator<BinaryRowData> third) {
        // A fresh, mutable ArrayList is passed, as before; how UnionIterator uses the list is not
        // visible here, so do not switch to a fixed-size/immutable list without checking.
        List<MutableObjectIterator<BinaryRowData>> sources = new ArrayList<>();
        sources.add(first);
        sources.add(second);
        sources.add(third);
        return new UnionIterator<>(sources);
    }

    /** Inserts every row of {@code buildInput} into {@code table} and ends the build phase. */
    private void buildTable(MyHashTable table, MutableObjectIterator<BinaryRowData> buildInput)
            throws IOException {
        BinaryRowData reuseBuildSideRow = this.buildSideSerializer.createInstance();
        BinaryRowData buildRow;
        while ((buildRow = buildInput.next(reuseBuildSideRow)) != null) {
            table.putBuildRow(buildRow);
        }
        table.endBuild();
    }

    /**
     * Builds {@code table}, probes it with all of {@code probeInput}, then keeps calling {@code
     * nextMatching()} until it reports no more matches, validating every match via {@link
     * #testJoin}.
     *
     * @return number of build-side matches observed per join key
     */
    private Map<Integer, Long> joinAndCollect(
            MyHashTable table,
            MutableObjectIterator<BinaryRowData> buildInput,
            MutableObjectIterator<BinaryRowData> probeInput)
            throws IOException {
        buildTable(table, buildInput);
        Map<Integer, Long> matchesPerKey = new HashMap<>();
        BinaryRowData probeRow = this.probeSideSerializer.createInstance();
        while ((probeRow = probeInput.next(probeRow)) != null) {
            if (table.tryProbe(probeRow)) {
                testJoin(table, matchesPerKey);
            }
        }
        while (table.nextMatching()) {
            testJoin(table, matchesPerKey);
        }
        return matchesPerKey;
    }

    /**
     * Checks that every build row matched for the current probe row has the probe row's key, fails
     * if there is none, and adds the number of matches to {@code map} under that key.
     */
    private void testJoin(MyHashTable table, Map<Integer, Long> map) throws IOException {
        int key = table.getCurrentProbeRow().getInt(0);
        LongHashPartition.MatchIterator buildSide = table.getBuildSideIterator();
        long numBuildValues = 0;
        while (buildSide.advanceNext()) {
            numBuildValues++;
            Assertions.assertThat(buildSide.getRow().getInt(0))
                    .as("Probe-side key was different than build-side key.")
                    .isEqualTo(key);
        }
        if (numBuildValues == 0) {
            Fail.fail("No build side values found for a probe key.");
        }
        map.merge(key, numBuildValues, Long::sum);
    }

    /**
     * Builds {@code table}, probes it with all of {@code probeInput} and drains {@code
     * nextMatching()}.
     *
     * @return total number of (probe row, build row) pairs produced
     */
    private int join(
            MyHashTable table,
            MutableObjectIterator<BinaryRowData> buildInput,
            MutableObjectIterator<BinaryRowData> probeInput)
            throws IOException {
        buildTable(table, buildInput);
        int count = 0;
        BinaryRowData probeRow = this.probeSideSerializer.createInstance();
        while ((probeRow = probeInput.next(probeRow)) != null) {
            if (table.tryProbe(probeRow)) {
                count += joinWithNextKey(table);
            }
        }
        while (table.nextMatching()) {
            count += joinWithNextKey(table);
        }
        return count;
    }

    /**
     * Counts the build rows matched for the current probe row; returns 0 when there is no current
     * probe row.
     */
    private int joinWithNextKey(MyHashTable table) throws IOException {
        LongHashPartition.MatchIterator buildIterator = table.getBuildSideIterator();
        if (table.getCurrentProbeRow() == null) {
            return 0;
        }
        int count = 0;
        while (buildIterator.advanceNext()) {
            count++;
        }
        return count;
    }

    /**
     * {@link LongHybridHashTable} over two-INT-column rows that uses column 0, widened to {@code
     * long}, as the join key on both sides.
     */
    private class MyHashTable extends LongHybridHashTable {

        public MyHashTable(long memorySize) {
            // Outer fields are qualified explicitly: the superclass may declare fields with the
            // same names, and those cannot be referenced before super() has run.
            // NOTE(review): 24 and 200000L presumably are the average record length and the
            // expected build row count — confirm against LongHybridHashTable's constructor.
            super(
                    LongHashTableTest.this.conf,
                    LongHashTableTest.this,
                    LongHashTableTest.this.buildSideSerializer,
                    LongHashTableTest.this.probeSideSerializer,
                    LongHashTableTest.this.memManager,
                    memorySize,
                    LongHashTableTest.this.ioManager,
                    24,
                    200000L);
        }

        @Override
        public long getBuildLongKey(RowData row) {
            return row.getInt(0);
        }

        @Override
        public long getProbeLongKey(RowData row) {
            return row.getInt(0);
        }

        @Override
        public BinaryRowData probeToBinary(RowData row) {
            // Probe rows in these tests are always produced as BinaryRowData.
            return (BinaryRowData) row;
        }
    }
}

