/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.RescaleMappings;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.util.CollectionUtil;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class TaskStateAssignment {
    private static final Logger LOG = LoggerFactory.getLogger(TaskStateAssignment.class);
    final ExecutionJobVertex executionJobVertex;
    final Map<OperatorID, OperatorState> oldState;
    final boolean hasNonFinishedState;
    final boolean isFullyFinished;
    final int newParallelism;
    final OperatorID inputOperatorID;
    final OperatorID outputOperatorID;
    private final Set<Integer> inputStateGates;
    private final Set<Integer> outputStatePartitions;
    final Map<OperatorInstanceID, List<OperatorStateHandle>> subManagedOperatorState;
    final Map<OperatorInstanceID, List<OperatorStateHandle>> subRawOperatorState;
    final Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState;
    final Map<OperatorInstanceID, List<KeyedStateHandle>> subRawKeyedState;
    final Map<OperatorInstanceID, List<InputChannelStateHandle>> inputChannelStates;
    final Map<OperatorInstanceID, List<ResultSubpartitionStateHandle>> resultSubpartitionStates;
    private final Map<Integer, SubtasksRescaleMapping> outputSubtaskMappings = new HashMap<Integer, SubtasksRescaleMapping>();
    private final Map<Integer, SubtasksRescaleMapping> inputSubtaskMappings = new HashMap<Integer, SubtasksRescaleMapping>();
    @Nullable
    private TaskStateAssignment[] downstreamAssignments;
    @Nullable
    private TaskStateAssignment[] upstreamAssignments;
    @Nullable
    private Boolean hasUpstreamOutputStates;
    @Nullable
    private Boolean hasDownstreamInputStates;
    private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment;
    private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;

    public TaskStateAssignment(ExecutionJobVertex executionJobVertex, Map<OperatorID, OperatorState> oldState, Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment, Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments) {
        this.executionJobVertex = executionJobVertex;
        this.oldState = oldState;
        this.hasNonFinishedState = oldState.values().stream().anyMatch(operatorState -> operatorState.getNumberCollectedStates() > 0);
        this.isFullyFinished = oldState.values().stream().anyMatch(OperatorState::isFullyFinished);
        if (this.isFullyFinished) {
            Preconditions.checkState((boolean)oldState.values().stream().allMatch(OperatorState::isFullyFinished), (Object)"JobVertex could not have mixed finished and unfinished operators");
        }
        this.newParallelism = executionJobVertex.getParallelism();
        this.consumerAssignment = (Map)Preconditions.checkNotNull(consumerAssignment);
        this.vertexAssignments = (Map)Preconditions.checkNotNull(vertexAssignments);
        int expectedNumberOfSubtasks = this.newParallelism * oldState.size();
        this.subManagedOperatorState = CollectionUtil.newHashMapWithExpectedSize((int)expectedNumberOfSubtasks);
        this.subRawOperatorState = CollectionUtil.newHashMapWithExpectedSize((int)expectedNumberOfSubtasks);
        this.inputChannelStates = CollectionUtil.newHashMapWithExpectedSize((int)expectedNumberOfSubtasks);
        this.resultSubpartitionStates = CollectionUtil.newHashMapWithExpectedSize((int)expectedNumberOfSubtasks);
        this.subManagedKeyedState = CollectionUtil.newHashMapWithExpectedSize((int)expectedNumberOfSubtasks);
        this.subRawKeyedState = CollectionUtil.newHashMapWithExpectedSize((int)expectedNumberOfSubtasks);
        List<OperatorIDPair> operatorIDs = executionJobVertex.getOperatorIDs();
        this.outputOperatorID = operatorIDs.get(0).getGeneratedOperatorID();
        this.inputOperatorID = operatorIDs.get(operatorIDs.size() - 1).getGeneratedOperatorID();
        this.inputStateGates = TaskStateAssignment.extractInputStateGates(oldState.get((Object)this.inputOperatorID));
        this.outputStatePartitions = TaskStateAssignment.extractOutputStatePartitions(oldState.get((Object)this.outputOperatorID));
    }

    private static Set<Integer> extractInputStateGates(OperatorState operatorState) {
        return operatorState.getStates().stream().map(OperatorSubtaskState::getInputChannelState).flatMap(Collection::stream).map(handle -> ((InputChannelInfo)handle.getInfo()).getGateIdx()).collect(Collectors.toSet());
    }

    private static Set<Integer> extractOutputStatePartitions(OperatorState operatorState) {
        return operatorState.getStates().stream().map(OperatorSubtaskState::getResultSubpartitionState).flatMap(Collection::stream).map(handle -> ((ResultSubpartitionInfo)handle.getInfo()).getPartitionIdx()).collect(Collectors.toSet());
    }

    public boolean hasInputState() {
        return !this.inputStateGates.isEmpty();
    }

    public boolean hasOutputState() {
        return !this.outputStatePartitions.isEmpty();
    }

    public TaskStateAssignment[] getDownstreamAssignments() {
        if (this.downstreamAssignments == null) {
            this.downstreamAssignments = (TaskStateAssignment[])Arrays.stream(this.executionJobVertex.getProducedDataSets()).map(result -> this.consumerAssignment.get(result.getId())).toArray(TaskStateAssignment[]::new);
        }
        return this.downstreamAssignments;
    }

    private static int getAssignmentIndex(TaskStateAssignment[] assignments, TaskStateAssignment assignment) {
        return Arrays.asList(assignments).indexOf(assignment);
    }

    public TaskStateAssignment[] getUpstreamAssignments() {
        if (this.upstreamAssignments == null) {
            this.upstreamAssignments = (TaskStateAssignment[])this.executionJobVertex.getInputs().stream().map(result -> this.vertexAssignments.get(result.getProducer())).toArray(TaskStateAssignment[]::new);
        }
        return this.upstreamAssignments;
    }

    public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
        Preconditions.checkState((this.subManagedKeyedState.containsKey(instanceID) || !this.subRawKeyedState.containsKey(instanceID) ? 1 : 0) != 0, (Object)"If an operator has no managed key state, it should also not have a raw keyed state.");
        StateObjectCollection<InputChannelStateHandle> inputState = this.getState(instanceID, this.inputChannelStates);
        StateObjectCollection<ResultSubpartitionStateHandle> outputState = this.getState(instanceID, this.resultSubpartitionStates);
        return OperatorSubtaskState.builder().setManagedOperatorState(this.getState(instanceID, this.subManagedOperatorState)).setRawOperatorState(this.getState(instanceID, this.subRawOperatorState)).setManagedKeyedState(this.getState(instanceID, this.subManagedKeyedState)).setRawKeyedState(this.getState(instanceID, this.subRawKeyedState)).setInputChannelState(inputState).setResultSubpartitionState(outputState).setInputRescalingDescriptor(this.createRescalingDescriptor(instanceID, this.inputOperatorID, this.getUpstreamAssignments(), (assignment, recompute) -> {
            int assignmentIndex = TaskStateAssignment.getAssignmentIndex(assignment.getDownstreamAssignments(), this);
            return assignment.getOutputMapping(assignmentIndex, (boolean)recompute);
        }, this.inputSubtaskMappings, this::getInputMapping, true)).setOutputRescalingDescriptor(this.createRescalingDescriptor(instanceID, this.outputOperatorID, this.getDownstreamAssignments(), (assignment, recompute) -> {
            int assignmentIndex = TaskStateAssignment.getAssignmentIndex(assignment.getUpstreamAssignments(), this);
            return assignment.getInputMapping(assignmentIndex, (boolean)recompute);
        }, this.outputSubtaskMappings, this::getOutputMapping, false)).build();
    }

    public boolean hasUpstreamOutputStates() {
        if (this.hasUpstreamOutputStates == null) {
            this.hasUpstreamOutputStates = Arrays.stream(this.getUpstreamAssignments()).anyMatch(TaskStateAssignment::hasOutputState);
        }
        return this.hasUpstreamOutputStates;
    }

    public boolean hasDownstreamInputStates() {
        if (this.hasDownstreamInputStates == null) {
            this.hasDownstreamInputStates = Arrays.stream(this.getDownstreamAssignments()).anyMatch(TaskStateAssignment::hasInputState);
        }
        return this.hasDownstreamInputStates;
    }

    private InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor log(InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor descriptor, int subtask, int partition) {
        LOG.debug("created {} for task={} subtask={} partition={}", new Object[]{descriptor, this.executionJobVertex.getName(), subtask, partition});
        return descriptor;
    }

    private InflightDataRescalingDescriptor log(InflightDataRescalingDescriptor descriptor, int subtask) {
        LOG.debug("created {} for task={} subtask={}", new Object[]{descriptor, this.executionJobVertex.getName(), subtask});
        return descriptor;
    }

    private InflightDataRescalingDescriptor createRescalingDescriptor(OperatorInstanceID instanceID, OperatorID expectedOperatorID, TaskStateAssignment[] connectedAssignments, BiFunction<TaskStateAssignment, Boolean, SubtasksRescaleMapping> mappingRetriever, Map<Integer, SubtasksRescaleMapping> subtaskGateOrPartitionMappings, Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator, boolean isInput) {
        if (!expectedOperatorID.equals((Object)instanceID.getOperatorId())) {
            return InflightDataRescalingDescriptor.NO_RESCALE;
        }
        SubtasksRescaleMapping[] rescaledChannelsMappings = (SubtasksRescaleMapping[])Arrays.stream(connectedAssignments).map(assignment -> (SubtasksRescaleMapping)mappingRetriever.apply((TaskStateAssignment)assignment, false)).toArray(SubtasksRescaleMapping[]::new);
        if (subtaskGateOrPartitionMappings.isEmpty() && Arrays.stream(rescaledChannelsMappings).allMatch(Objects::isNull)) {
            return InflightDataRescalingDescriptor.NO_RESCALE;
        }
        InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor[] gateOrPartitionDescriptors = this.createGateOrPartitionRescalingDescriptors(instanceID, connectedAssignments, assignment -> (SubtasksRescaleMapping)mappingRetriever.apply((TaskStateAssignment)assignment, true), subtaskGateOrPartitionMappings, subtaskMappingCalculator, rescaledChannelsMappings, isInput);
        if (Arrays.stream(gateOrPartitionDescriptors).allMatch(InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor::isIdentity)) {
            return this.log(InflightDataRescalingDescriptor.NO_RESCALE, instanceID.getSubtaskId());
        }
        return this.log(new InflightDataRescalingDescriptor(gateOrPartitionDescriptors), instanceID.getSubtaskId());
    }

    private InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor[] createGateOrPartitionRescalingDescriptors(OperatorInstanceID instanceID, TaskStateAssignment[] connectedAssignments, Function<TaskStateAssignment, SubtasksRescaleMapping> mappingCalculator, Map<Integer, SubtasksRescaleMapping> subtaskGateOrPartitionMappings, Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator, SubtasksRescaleMapping[] rescaledChannelsMappings, boolean isInput) {
        return (InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor[])IntStream.range(0, rescaledChannelsMappings.length).mapToObj(partition -> {
            if (!this.hasInFlightData(isInput, partition)) {
                return InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
            }
            TaskStateAssignment connectedAssignment = connectedAssignments[partition];
            SubtasksRescaleMapping rescaleMapping = Optional.ofNullable(rescaledChannelsMappings[partition]).orElseGet(() -> (SubtasksRescaleMapping)mappingCalculator.apply(connectedAssignment));
            SubtasksRescaleMapping subtaskMapping = Optional.ofNullable(subtaskGateOrPartitionMappings.get(partition)).orElseGet(() -> (SubtasksRescaleMapping)subtaskMappingCalculator.apply(partition));
            return this.getInflightDataGateOrPartitionRescalingDescriptor(instanceID, partition, rescaleMapping, subtaskMapping);
        }).toArray(InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor[]::new);
    }

    private boolean hasInFlightData(boolean isInput, int gateOrPartitionIndex) {
        if (isInput) {
            return this.hasInFlightDataForInputGate(gateOrPartitionIndex);
        }
        return this.hasInFlightDataForResultPartition(gateOrPartitionIndex);
    }

    private InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor getInflightDataGateOrPartitionRescalingDescriptor(OperatorInstanceID instanceID, int partition, SubtasksRescaleMapping rescaleMapping, SubtasksRescaleMapping subtaskMapping) {
        int[] oldSubtaskInstances = subtaskMapping.rescaleMappings.getMappedIndexes(instanceID.getSubtaskId());
        boolean isIdentity = subtaskMapping.rescaleMappings.isIdentity() && rescaleMapping.getRescaleMappings().isIdentity() || oldSubtaskInstances.length == 0;
        Set<Integer> ambiguousSubtasks = subtaskMapping.mayHaveAmbiguousSubtasks ? subtaskMapping.rescaleMappings.getAmbiguousTargets() : Collections.emptySet();
        return this.log(new InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor(oldSubtaskInstances, rescaleMapping.getRescaleMappings(), ambiguousSubtasks, isIdentity ? InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType.IDENTITY : InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType.RESCALING), instanceID.getSubtaskId(), partition);
    }

    private <T extends StateObject> StateObjectCollection<T> getState(OperatorInstanceID instanceID, Map<OperatorInstanceID, List<T>> subManagedOperatorState) {
        List<T> value = subManagedOperatorState.get(instanceID);
        return value != null ? new StateObjectCollection<T>(value) : StateObjectCollection.empty();
    }

    private SubtasksRescaleMapping getOutputMapping(int assignmentIndex, boolean recompute) {
        SubtasksRescaleMapping mapping = this.outputSubtaskMappings.get(assignmentIndex);
        if (recompute && mapping == null) {
            return this.getOutputMapping(assignmentIndex);
        }
        return mapping;
    }

    private SubtasksRescaleMapping getInputMapping(int assignmentIndex, boolean recompute) {
        SubtasksRescaleMapping mapping = this.inputSubtaskMappings.get(assignmentIndex);
        if (recompute && mapping == null) {
            return this.getInputMapping(assignmentIndex);
        }
        return mapping;
    }

    public SubtasksRescaleMapping getOutputMapping(int partitionIndex) {
        TaskStateAssignment downstreamAssignment = this.getDownstreamAssignments()[partitionIndex];
        IntermediateResult output = this.executionJobVertex.getProducedDataSets()[partitionIndex];
        int gateIndex = downstreamAssignment.executionJobVertex.getInputs().indexOf(output);
        SubtaskStateMapper mapper = (SubtaskStateMapper)((Object)Preconditions.checkNotNull((Object)((Object)downstreamAssignment.executionJobVertex.getJobVertex().getInputs().get(gateIndex).getUpstreamSubtaskStateMapper()), (String)"No channel rescaler found during rescaling of channel state"));
        RescaleMappings mapping = mapper.getNewToOldSubtasksMapping(this.oldState.get((Object)this.outputOperatorID).getParallelism(), this.newParallelism);
        return this.outputSubtaskMappings.compute(partitionIndex, (idx, oldMapping) -> TaskStateAssignment.checkSubtaskMapping(oldMapping, mapping, mapper.isAmbiguous()));
    }

    public SubtasksRescaleMapping getInputMapping(int gateIndex) {
        SubtaskStateMapper mapper = (SubtaskStateMapper)((Object)Preconditions.checkNotNull((Object)((Object)this.executionJobVertex.getJobVertex().getInputs().get(gateIndex).getDownstreamSubtaskStateMapper()), (String)"No channel rescaler found during rescaling of channel state"));
        RescaleMappings mapping = mapper.getNewToOldSubtasksMapping(this.oldState.get((Object)this.inputOperatorID).getParallelism(), this.newParallelism);
        return this.inputSubtaskMappings.compute(gateIndex, (idx, oldMapping) -> TaskStateAssignment.checkSubtaskMapping(oldMapping, mapping, mapper.isAmbiguous()));
    }

    public boolean hasInFlightDataForInputGate(int gateIndex) {
        if (this.inputStateGates.contains(gateIndex)) {
            return true;
        }
        TaskStateAssignment upstreamAssignment = this.getUpstreamAssignments()[gateIndex];
        if (upstreamAssignment != null && upstreamAssignment.hasOutputState()) {
            IntermediateResult inputResult = this.executionJobVertex.getInputs().get(gateIndex);
            IntermediateDataSetID resultId = inputResult.getId();
            IntermediateResult[] producedDataSets = inputResult.getProducer().getProducedDataSets();
            for (int i = 0; i < producedDataSets.length; ++i) {
                if (!producedDataSets[i].getId().equals(resultId)) continue;
                return upstreamAssignment.outputStatePartitions.contains(i);
            }
        }
        return false;
    }

    public boolean hasInFlightDataForResultPartition(int partitionIndex) {
        if (this.outputStatePartitions.contains(partitionIndex)) {
            return true;
        }
        TaskStateAssignment downstreamAssignment = this.getDownstreamAssignments()[partitionIndex];
        if (downstreamAssignment != null && downstreamAssignment.hasInputState()) {
            IntermediateResult producedResult = this.executionJobVertex.getProducedDataSets()[partitionIndex];
            IntermediateDataSetID resultId = producedResult.getId();
            List<IntermediateResult> inputs = downstreamAssignment.executionJobVertex.getInputs();
            for (int i = 0; i < inputs.size(); ++i) {
                if (!inputs.get(i).getId().equals(resultId)) continue;
                return downstreamAssignment.inputStateGates.contains(i);
            }
        }
        return false;
    }

    public String toString() {
        return "TaskStateAssignment for " + this.executionJobVertex.getName();
    }

    @Nonnull
    private static SubtasksRescaleMapping checkSubtaskMapping(@Nullable SubtasksRescaleMapping oldMapping, RescaleMappings mapping, boolean mayHaveAmbiguousSubtasks) {
        if (oldMapping == null) {
            return new SubtasksRescaleMapping(mapping, mayHaveAmbiguousSubtasks);
        }
        if (!oldMapping.rescaleMappings.equals(mapping)) {
            throw new IllegalStateException("Incompatible subtask mappings: are multiple operators ingesting/producing intermediate results with varying degrees of parallelism?Found " + oldMapping + " and " + mapping + ".");
        }
        return new SubtasksRescaleMapping(mapping, oldMapping.mayHaveAmbiguousSubtasks || mayHaveAmbiguousSubtasks);
    }

    static class SubtasksRescaleMapping {
        private final RescaleMappings rescaleMappings;
        private final boolean mayHaveAmbiguousSubtasks;

        private SubtasksRescaleMapping(RescaleMappings rescaleMappings, boolean mayHaveAmbiguousSubtasks) {
            this.rescaleMappings = rescaleMappings;
            this.mayHaveAmbiguousSubtasks = mayHaveAmbiguousSubtasks;
        }

        public RescaleMappings getRescaleMappings() {
            return this.rescaleMappings;
        }

        public boolean isMayHaveAmbiguousSubtasks() {
            return this.mayHaveAmbiguousSubtasks;
        }
    }
}

