/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.executiongraph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.flink.runtime.JobException;
import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
import org.apache.flink.runtime.scheduler.SchedulerBase;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.util.JobVertexConnectionUtils;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Tests how {@link DistributionPattern#POINTWISE} edges are expanded into consumed partitions
 * when the upstream and downstream vertices run with various parallelisms.
 */
class PointwisePatternTest {
    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();

    @Test
    void testNToN() throws Exception {
        final int n = 23;
        // With equal parallelism every subtask consumes exactly its own partition.
        testEvenFanIn(1, n);
    }

    @Test
    void test2NToN() throws Exception {
        final int n = 17;
        testEvenFanIn(2, n);
    }

    @Test
    void test3NToN() throws Exception {
        final int n = 17;
        testEvenFanIn(3, n);
    }

    @Test
    void testNTo2N() throws Exception {
        final int n = 41;
        testEvenFanOut(2, n);
    }

    @Test
    void testNTo7N() throws Exception {
        final int n = 11;
        testEvenFanOut(7, n);
    }

    @Test
    void testLowHighIrregular() throws Exception {
        testLowToHigh(3, 16);
        testLowToHigh(19, 21);
        testLowToHigh(15, 20);
        testLowToHigh(11, 31);
        testLowToHigh(11, 29);
    }

    @Test
    void testHighLowIrregular() throws Exception {
        testHighToLow(16, 3);
        testHighToLow(21, 19);
        testHighToLow(20, 15);
        testHighToLow(31, 11);
    }

    @Test
    void testPointwiseConnectionSequence() throws Exception {
        testConnections(3, 5, new int[][] {{0}, {0}, {1}, {1}, {2}});
        testConnections(3, 10, new int[][] {{0}, {0}, {0}, {0}, {1}, {1}, {1}, {2}, {2}, {2}});
        testConnections(4, 6, new int[][] {{0}, {0}, {1}, {2}, {2}, {3}});
        testConnections(6, 10, new int[][] {{0}, {0}, {1}, {1}, {2}, {3}, {3}, {4}, {4}, {5}});
        testConnections(5, 3, new int[][] {{0}, {1, 2}, {3, 4}});
        testConnections(10, 3, new int[][] {{0, 1, 2}, {3, 4, 5}, {6, 7, 8, 9}});
        testConnections(6, 4, new int[][] {{0}, {1, 2}, {3}, {4, 5}});
        testConnections(10, 6, new int[][] {{0}, {1, 2}, {3, 4}, {5}, {6, 7}, {8, 9}});
    }

    /**
     * Connects {@code factor * n} upstream subtasks pointwise to {@code n} downstream subtasks and
     * checks that downstream subtask {@code i} consumes exactly the partitions
     * {@code i * factor, ..., i * factor + factor - 1}, in that order.
     */
    private void testEvenFanIn(int factor, int n) throws Exception {
        ExecutionJobVertex target = setUpExecutionGraphAndGetDownstreamVertex(factor * n, n);
        Assertions.assertThat(target.getTaskVertices()).hasSize(n);

        for (ExecutionVertex ev : target.getTaskVertices()) {
            Assertions.assertThat(ev.getNumberOfInputs()).isOne();

            int firstPartition = ev.getParallelSubtaskIndex() * factor;
            int[] expected = new int[factor];
            for (int i = 0; i < factor; i++) {
                expected[i] = firstPartition + i;
            }

            Assertions.assertThat(partitionNumbersOf(ev.getConsumedPartitionGroup(0)))
                    .as("partitions consumed by subtask %d", ev.getParallelSubtaskIndex())
                    .containsExactly(expected);
        }
    }

    /**
     * Connects {@code n} upstream subtasks pointwise to {@code factor * n} downstream subtasks and
     * checks that downstream subtask {@code i} consumes exactly partition {@code i / factor}.
     */
    private void testEvenFanOut(int factor, int n) throws Exception {
        ExecutionJobVertex target = setUpExecutionGraphAndGetDownstreamVertex(n, factor * n);
        Assertions.assertThat(target.getTaskVertices()).hasSize(factor * n);

        for (ExecutionVertex ev : target.getTaskVertices()) {
            Assertions.assertThat(ev.getNumberOfInputs()).isOne();
            Assertions.assertThat(partitionNumbersOf(ev.getConsumedPartitionGroup(0)))
                    .as("partitions consumed by subtask %d", ev.getParallelSubtaskIndex())
                    .containsExactly(ev.getParallelSubtaskIndex() / factor);
        }
    }

    /**
     * Checks a low-to-high pointwise connection: every downstream subtask consumes exactly one
     * partition, and every upstream partition is consumed by between {@code highDop / lowDop} and
     * {@code highDop / lowDop + 1} subtasks (the upper bound applies only when the division is
     * uneven).
     */
    private void testLowToHigh(int lowDop, int highDop) throws Exception {
        checkParallelisms(lowDop, highDop);
        int factor = highDop / lowDop;
        int delta = highDop % lowDop == 0 ? 0 : 1;

        ExecutionJobVertex target = setUpExecutionGraphAndGetDownstreamVertex(lowDop, highDop);
        int[] timesUsed = new int[lowDop];
        for (ExecutionVertex ev : target.getTaskVertices()) {
            Assertions.assertThat(ev.getNumberOfInputs()).isOne();
            ConsumedPartitionGroup consumedPartitionGroup = ev.getConsumedPartitionGroup(0);
            Assertions.assertThat(consumedPartitionGroup).hasSize(1);
            timesUsed[consumedPartitionGroup.getFirst().getPartitionNumber()]++;
        }

        for (int partition = 0; partition < timesUsed.length; partition++) {
            Assertions.assertThat(timesUsed[partition])
                    .as("number of consumers of upstream partition %d", partition)
                    .isBetween(factor, factor + delta);
        }
    }

    /**
     * Checks a high-to-low pointwise connection: every downstream subtask consumes between
     * {@code highDop / lowDop} and {@code highDop / lowDop + 1} partitions (the upper bound
     * applies only when the division is uneven), and every upstream partition is consumed
     * exactly once.
     */
    private void testHighToLow(int highDop, int lowDop) throws Exception {
        checkParallelisms(lowDop, highDop);
        int factor = highDop / lowDop;
        int delta = highDop % lowDop == 0 ? 0 : 1;

        ExecutionJobVertex target = setUpExecutionGraphAndGetDownstreamVertex(highDop, lowDop);
        int[] timesUsed = new int[highDop];
        for (ExecutionVertex ev : target.getTaskVertices()) {
            Assertions.assertThat(ev.getNumberOfInputs()).isOne();

            int consumedCount = 0;
            for (ConsumedPartitionGroup partitionGroup : ev.getAllConsumedPartitionGroups()) {
                for (IntermediateResultPartitionID partitionId : partitionGroup) {
                    timesUsed[partitionId.getPartitionNumber()]++;
                    consumedCount++;
                }
            }
            Assertions.assertThat(consumedCount)
                    .as("number of partitions consumed by subtask %d", ev.getParallelSubtaskIndex())
                    .isBetween(factor, factor + delta);
        }

        for (int partition = 0; partition < timesUsed.length; partition++) {
            Assertions.assertThat(timesUsed[partition])
                    .as("number of consumers of upstream partition %d", partition)
                    .isOne();
        }
    }

    /** Rejects parallelism pairs that the fan-in/fan-out ratio checks cannot reason about. */
    private static void checkParallelisms(int lowDop, int highDop) {
        // lowDop == 0 would otherwise surface as an ArithmeticException in the ratio computation.
        if (lowDop <= 0 || highDop < lowDop) {
            throw new IllegalArgumentException(
                    "Expected 0 < lowDop <= highDop, but got lowDop="
                            + lowDop
                            + ", highDop="
                            + highDop);
        }
    }

    /**
     * Builds a two-vertex job graph whose vertices are joined by a pointwise, pipelined edge,
     * attaches it to a fresh execution graph, and returns the downstream execution vertex.
     *
     * @param upstream parallelism of the producing vertex
     * @param downstream parallelism of the consuming vertex
     * @return the {@link ExecutionJobVertex} of the consuming vertex
     */
    private ExecutionJobVertex setUpExecutionGraphAndGetDownstreamVertex(
            int upstream, int downstream) throws Exception {
        JobVertex v1 = new JobVertex("vertex1");
        JobVertex v2 = new JobVertex("vertex2");
        v1.setParallelism(upstream);
        v2.setParallelism(downstream);
        v1.setInvokableClass(AbstractInvokable.class);
        v2.setInvokableClass(AbstractInvokable.class);
        JobVertexConnectionUtils.connectNewDataSetAsInput(
                v2, v1, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED);

        ArrayList<JobVertex> ordered = new ArrayList<>(Arrays.asList(v1, v2));
        DefaultExecutionGraph eg =
                TestingDefaultExecutionGraphBuilder.newBuilder()
                        .setVertexParallelismStore(
                                SchedulerBase.computeVertexParallelismStore(ordered))
                        .build(EXECUTOR_RESOURCE.getExecutor());
        try {
            eg.attachJobGraph(
                    ordered, UnregisteredMetricGroups.createUnregisteredJobManagerJobMetricGroup());
        } catch (JobException e) {
            // Keep the original exception as the cause so the test report shows where attaching
            // failed, instead of printing it to stderr and failing with only its message.
            Assertions.fail("Job failed with exception: " + e.getMessage(), e);
        }
        return eg.getAllVertices().get(v2.getID());
    }

    /**
     * Checks that, for a pointwise connection between the given parallelisms, downstream subtask
     * {@code i} consumes exactly the partitions listed in {@code expectedConsumedPartitionNumber[i]},
     * in that order.
     */
    private void testConnections(
            int sourceParallelism, int targetParallelism, int[][] expectedConsumedPartitionNumber)
            throws Exception {
        ExecutionJobVertex target =
                setUpExecutionGraphAndGetDownstreamVertex(sourceParallelism, targetParallelism);
        ExecutionVertex[] taskVertices = target.getTaskVertices();
        // Guards against a fixture whose row count disagrees with the target parallelism, which
        // would otherwise surface as an ArrayIndexOutOfBoundsException or go silently unchecked.
        Assertions.assertThat(taskVertices).hasSize(expectedConsumedPartitionNumber.length);

        for (int vertexIndex = 0; vertexIndex < taskVertices.length; vertexIndex++) {
            ExecutionVertex ev = taskVertices[vertexIndex];
            Assertions.assertThat(ev.getNumberOfInputs()).isOne();
            Assertions.assertThat(partitionNumbersOf(ev.getConsumedPartitionGroup(0)))
                    .as("partitions consumed by subtask %d", vertexIndex)
                    .containsExactly(expectedConsumedPartitionNumber[vertexIndex]);
        }
    }

    /** Returns the partition numbers of {@code group} in iteration order. */
    private static int[] partitionNumbersOf(ConsumedPartitionGroup group) {
        int[] numbers = new int[group.size()];
        int next = 0;
        for (IntermediateResultPartitionID partitionId : group) {
            numbers[next++] = partitionId.getPartitionNumber();
        }
        return numbers;
    }
}

