/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.executiongraph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.flink.runtime.executiongraph.EdgeManager;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.util.Preconditions;

/**
 * Utilities for connecting {@link ExecutionVertex}es to the {@link IntermediateResultPartition}s
 * they consume, registering the resulting {@link ConsumedPartitionGroup}s with the graph's {@link
 * EdgeManager}.
 */
public class EdgeManagerBuildUtil {

    /**
     * Connects the task vertices of {@code vertex} to the partitions of {@code intermediateResult},
     * according to the result's consuming {@link DistributionPattern}.
     *
     * @param vertex the job vertex whose task vertices consume the result
     * @param intermediateResult the intermediate result being consumed
     * @throws IllegalArgumentException if the distribution pattern is neither {@code POINTWISE} nor
     *     {@code ALL_TO_ALL}
     */
    static void connectVertexToResult(ExecutionJobVertex vertex, IntermediateResult intermediateResult) {
        final DistributionPattern distributionPattern = intermediateResult.getConsumingDistributionPattern();
        final JobVertexInputInfo jobVertexInputInfo =
                vertex.getGraph().getJobVertexInputInfo(vertex.getJobVertexId(), intermediateResult.getId());
        switch (distributionPattern) {
            case POINTWISE:
                connectPointwise(vertex, intermediateResult, jobVertexInputInfo);
                break;
            case ALL_TO_ALL:
                connectAllToAll(vertex, intermediateResult, jobVertexInputInfo);
                break;
            default:
                throw new IllegalArgumentException("Unrecognized distribution pattern.");
        }
    }

    /**
     * Computes the maximum number of edges that may point to a single target execution vertex.
     *
     * <p>For {@code POINTWISE} this is {@code ceil(sourceParallelism / targetParallelism)}; for
     * {@code ALL_TO_ALL} every source subtask is connected, so it is {@code sourceParallelism}.
     *
     * @param targetParallelism parallelism of the consuming vertex; the {@code POINTWISE} branch
     *     divides by it, so it is expected to be positive
     * @param sourceParallelism parallelism of the producing vertex
     * @param distributionPattern how the source is distributed to the target
     * @return the upper bound on incoming edges per target execution vertex
     * @throws IllegalArgumentException if the distribution pattern is not recognized
     */
    public static int computeMaxEdgesToTargetExecutionVertex(int targetParallelism, int sourceParallelism, DistributionPattern distributionPattern) {
        switch (distributionPattern) {
            case POINTWISE:
                // Integer ceiling division.
                return (sourceParallelism + targetParallelism - 1) / targetParallelism;
            case ALL_TO_ALL:
                return sourceParallelism;
            default:
                throw new IllegalArgumentException("Unrecognized distribution pattern.");
        }
    }

    /**
     * Connects every task vertex of {@code jobVertex} to every partition of {@code result} using a
     * single shared consumed-partition group.
     *
     * @throws IllegalArgumentException if any execution vertex input info does not cover the full
     *     partition range {@code [0, numberOfAssignedPartitions - 1]}
     */
    private static void connectAllToAll(ExecutionJobVertex jobVertex, IntermediateResult result, JobVertexInputInfo jobVertexInputInfo) {
        final int lastPartitionIndex = result.getNumberOfAssignedPartitions() - 1;
        // All-to-all consumption requires every subtask to read the complete partition range.
        jobVertexInputInfo.getExecutionVertexInputInfos().forEach(executionVertexInputInfo -> {
            IndexRange partitionRange = executionVertexInputInfo.getPartitionIndexRange();
            Preconditions.checkArgument(
                    partitionRange.getStartIndex() == 0,
                    "ALL_TO_ALL input of subtask %s must start at partition 0 but range is %s.",
                    executionVertexInputInfo.getSubtaskIndex(),
                    partitionRange);
            Preconditions.checkArgument(
                    partitionRange.getEndIndex() == lastPartitionIndex,
                    "ALL_TO_ALL input of subtask %s must end at partition %s but range is %s.",
                    executionVertexInputInfo.getSubtaskIndex(),
                    lastPartitionIndex,
                    partitionRange);
        });
        connectInternal(
                Arrays.asList(jobVertex.getTaskVertices()),
                Arrays.asList(result.getPartitions()),
                result.getResultType(),
                jobVertex.getGraph().getEdgeManager());
    }

    /**
     * Connects task vertices of {@code jobVertex} to partitions of {@code result} pointwise: subtasks
     * that read the identical partition index range are grouped and connected together to exactly
     * those partitions.
     */
    private static void connectPointwise(ExecutionJobVertex jobVertex, IntermediateResult result, JobVertexInputInfo jobVertexInputInfo) {
        // LinkedHashMap keeps groups in the order their ranges first appear in the input infos.
        final LinkedHashMap<IndexRange, List<Integer>> consumersByPartition = new LinkedHashMap<>();
        for (ExecutionVertexInputInfo executionVertexInputInfo : jobVertexInputInfo.getExecutionVertexInputInfos()) {
            consumersByPartition
                    .computeIfAbsent(executionVertexInputInfo.getPartitionIndexRange(), ignored -> new ArrayList<>())
                    .add(executionVertexInputInfo.getSubtaskIndex());
        }

        // Loop-invariant lookups, hoisted out of the per-group work.
        final ExecutionVertex[] allTaskVertices = jobVertex.getTaskVertices();
        final IntermediateResultPartition[] allPartitions = result.getPartitions();
        final ResultPartitionType resultPartitionType = result.getResultType();
        final EdgeManager edgeManager = jobVertex.getGraph().getEdgeManager();

        consumersByPartition.forEach((range, subtasks) -> {
            final List<ExecutionVertex> taskVertices = new ArrayList<>(subtasks.size());
            for (int subtaskIndex : subtasks) {
                taskVertices.add(allTaskVertices[subtaskIndex]);
            }
            // The partition range is inclusive on both ends.
            final List<IntermediateResultPartition> partitions = new ArrayList<>();
            for (int i = range.getStartIndex(); i <= range.getEndIndex(); ++i) {
                partitions.add(allPartitions[i]);
            }
            connectInternal(taskVertices, partitions, resultPartitionType, edgeManager);
        });
    }

    /**
     * Connects the given task vertices to the given partitions as one consumer/consumed group pair.
     *
     * <p>A {@link ConsumedPartitionGroup} is created and registered with {@code edgeManager}, added
     * to every task vertex, and linked bidirectionally with a {@link ConsumerVertexGroup} that is
     * added to every partition.
     *
     * @throws IllegalStateException if {@code taskVertices} or {@code partitions} is empty
     */
    private static void connectInternal(List<ExecutionVertex> taskVertices, List<IntermediateResultPartition> partitions, ResultPartitionType resultPartitionType, EdgeManager edgeManager) {
        Preconditions.checkState(!taskVertices.isEmpty(), "Cannot connect an empty set of consumer vertices.");
        Preconditions.checkState(!partitions.isEmpty(), "Cannot connect an empty set of partitions.");

        final ConsumedPartitionGroup consumedPartitionGroup =
                createAndRegisterConsumedPartitionGroupToEdgeManager(taskVertices.size(), partitions, resultPartitionType, edgeManager);
        for (ExecutionVertex ev : taskVertices) {
            ev.addConsumedPartitionGroup(consumedPartitionGroup);
        }

        final List<ExecutionVertexID> consumerVertices =
                taskVertices.stream().map(ExecutionVertex::getID).collect(Collectors.toList());
        final ConsumerVertexGroup consumerVertexGroup =
                ConsumerVertexGroup.fromMultipleVertices(consumerVertices, resultPartitionType);
        for (IntermediateResultPartition partition : partitions) {
            partition.addConsumers(consumerVertexGroup);
        }

        // Link the two groups to each other.
        consumedPartitionGroup.setConsumerVertexGroup(consumerVertexGroup);
        consumerVertexGroup.setConsumedPartitionGroup(consumedPartitionGroup);
    }

    /**
     * Creates a {@link ConsumedPartitionGroup} over {@code partitions} for {@code numConsumers}
     * consumers, records partitions already reporting all data produced, and registers the group
     * with {@code edgeManager}.
     *
     * @return the newly created and registered group
     */
    private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(int numConsumers, List<IntermediateResultPartition> partitions, ResultPartitionType resultPartitionType, EdgeManager edgeManager) {
        final List<IntermediateResultPartitionID> partitionIds =
                partitions.stream().map(IntermediateResultPartition::getPartitionId).collect(Collectors.toList());
        final ConsumedPartitionGroup consumedPartitionGroup =
                ConsumedPartitionGroup.fromMultiplePartitions(numConsumers, partitionIds, resultPartitionType);
        finishAllDataProducedPartitions(partitions, consumedPartitionGroup);
        edgeManager.registerConsumedPartitionGroup(consumedPartitionGroup);
        return consumedPartitionGroup;
    }

    /**
     * Calls {@link ConsumedPartitionGroup#partitionFinished()} once for each partition in {@code
     * partitions} whose {@code hasDataAllProduced()} is already true.
     */
    private static void finishAllDataProducedPartitions(List<IntermediateResultPartition> partitions, ConsumedPartitionGroup consumedPartitionGroup) {
        for (IntermediateResultPartition partition : partitions) {
            if (partition.hasDataAllProduced()) {
                consumedPartitionGroup.partitionFinished();
            }
        }
    }
}

