/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.scheduler.adaptivebatch.AbstractBlockingResultInfo;
import org.apache.flink.util.Preconditions;

/**
 * Blocking-result information for an all-to-all result ({@link #isPointwise()} is always
 * {@code false}).
 *
 * <p>Per-partition subpartition byte counts are kept in the inherited
 * {@code subpartitionBytesByPartitionIndex} map. Once every partition has reported, they can be
 * collapsed into one total per subpartition index ({@link #aggregatedSubpartitionBytes}). After
 * that aggregate exists, {@link #recordPartitionInfo} and {@link #resetPartitionInfo} are
 * ignored.
 */
public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo {
    /** When {@code true}, {@link #getNumBytesProduced()} reports only the bytes of subpartition 0. */
    private final boolean singleSubpartitionContainsAllData;

    /** Broadcast flag; mutable through {@link #setBroadcast(boolean)}. */
    private boolean isBroadcast;

    /**
     * Per-subpartition byte totals summed over all partitions, or {@code null} while the
     * aggregate has not been computed yet.
     */
    @Nullable
    protected List<Long> aggregatedSubpartitionBytes;

    /**
     * Test-only constructor that starts with an empty per-partition map and sets the broadcast
     * flag explicitly.
     *
     * @param resultId id of the intermediate data set
     * @param numOfPartitions number of partitions of the result
     * @param numOfSubpartitions number of subpartitions of each partition
     * @param isBroadcast initial value of the broadcast flag
     * @param singleSubpartitionContainsAllData whether a single subpartition holds all data
     */
    @VisibleForTesting
    AllToAllBlockingResultInfo(IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions, boolean isBroadcast, boolean singleSubpartitionContainsAllData) {
        this(resultId, numOfPartitions, numOfSubpartitions, singleSubpartitionContainsAllData, new HashMap<Integer, long[]>());
        this.isBroadcast = isBroadcast;
    }

    /**
     * Creates the result info on top of an existing per-partition byte map. The broadcast flag
     * is left at its default ({@code false}).
     *
     * @param resultId id of the intermediate data set
     * @param numOfPartitions number of partitions of the result
     * @param numOfSubpartitions number of subpartitions of each partition
     * @param singleSubpartitionContainsAllData whether a single subpartition holds all data
     * @param subpartitionBytesByPartitionIndex per-partition subpartition byte counts, handed to
     *     the superclass
     */
    AllToAllBlockingResultInfo(IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions, boolean singleSubpartitionContainsAllData, Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
        super(resultId, numOfPartitions, numOfSubpartitions, subpartitionBytesByPartitionIndex);
        this.singleSubpartitionContainsAllData = singleSubpartitionContainsAllData;
    }

    /** Returns the current broadcast flag. */
    @Override
    public boolean isBroadcast() {
        return this.isBroadcast;
    }

    /** Returns whether a single subpartition contains all the data of this result. */
    @Override
    public boolean isSingleSubpartitionContainsAllData() {
        return this.singleSubpartitionContainsAllData;
    }

    /** Overwrites the broadcast flag. */
    void setBroadcast(boolean isBroadcast) {
        this.isBroadcast = isBroadcast;
    }

    /** Always {@code false}: this result info never describes a pointwise result. */
    @Override
    public boolean isPointwise() {
        return false;
    }

    /** Returns the number of partitions of this result. */
    @Override
    public int getNumPartitions() {
        return this.numOfPartitions;
    }

    /**
     * Returns the number of subpartitions; the value is the same for every partition, so
     * {@code partitionIndex} is not consulted.
     */
    @Override
    public int getNumSubpartitions(int partitionIndex) {
        return this.numOfSubpartitions;
    }

    /**
     * Returns the total number of bytes produced by this result.
     *
     * <p>If a single subpartition contains all data, only the bytes of subpartition 0 are
     * returned; otherwise the per-subpartition totals are summed. This method does not cache
     * the aggregate.
     *
     * @throws IllegalStateException if no aggregate exists and not all partitions have reported
     */
    @Override
    public long getNumBytesProduced() {
        Preconditions.checkState(
                this.aggregatedSubpartitionBytes != null
                        || this.subpartitionBytesByPartitionIndex.size() == this.numOfPartitions,
                "Not all partition infos are ready");
        // orElseGet (not orElse): aggregate the per-partition map only when no cached
        // aggregate is available, instead of recomputing it on every call.
        List<Long> bytes =
                Optional.ofNullable(this.aggregatedSubpartitionBytes)
                        .orElseGet(this::getAggregatedSubpartitionBytesInternal);
        if (this.singleSubpartitionContainsAllData) {
            return bytes.get(0);
        }
        return bytes.stream().mapToLong(Long::longValue).sum();
    }

    /**
     * Records the bytes of one partition, unless the aggregate has already been computed, in
     * which case the call is ignored.
     */
    @Override
    public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {
        if (this.aggregatedSubpartitionBytes == null) {
            super.recordPartitionInfo(partitionIndex, partitionBytes);
        }
    }

    /**
     * Drops the per-partition byte counts once all partitions have reported, making sure the
     * aggregate is computed first so that no information needed for the totals is lost. Does
     * nothing while partitions are still missing.
     */
    protected void onFineGrainedSubpartitionBytesNotNeeded() {
        if (this.subpartitionBytesByPartitionIndex.size() == this.numOfPartitions) {
            if (this.aggregatedSubpartitionBytes == null) {
                this.aggregatedSubpartitionBytes = this.getAggregatedSubpartitionBytesInternal();
            }
            this.subpartitionBytesByPartitionIndex.clear();
        }
    }

    /**
     * Sums the currently recorded per-partition byte arrays element-wise.
     *
     * @return a new list with one total per subpartition index
     * @throws IllegalStateException if a recorded array does not have {@code numOfSubpartitions}
     *     entries
     */
    private List<Long> getAggregatedSubpartitionBytesInternal() {
        long[] aggregatedBytes = new long[this.numOfSubpartitions];
        for (long[] subpartitionBytes : this.subpartitionBytesByPartitionIndex.values()) {
            Preconditions.checkState(subpartitionBytes.length == this.numOfSubpartitions);
            for (int i = 0; i < subpartitionBytes.length; ++i) {
                aggregatedBytes[i] += subpartitionBytes[i];
            }
        }
        return Arrays.stream(aggregatedBytes).boxed().collect(Collectors.toList());
    }

    /**
     * Resets the recorded bytes of one partition, unless the aggregate has already been
     * computed, in which case the call is ignored.
     */
    @Override
    public void resetPartitionInfo(int partitionIndex) {
        if (this.aggregatedSubpartitionBytes == null) {
            super.resetPartitionInfo(partitionIndex);
        }
    }

    /**
     * Returns the bytes produced for the given partition and subpartition ranges.
     *
     * <p>When the partition range spans all partitions ({@code 0} to
     * {@code getNumPartitions() - 1}), the answer is taken from the aggregated per-subpartition
     * totals (computing and caching them if necessary); otherwise the superclass implementation
     * is used.
     *
     * @param partitionIndexRange range of partition indexes
     * @param subpartitionIndexRange range of subpartition indexes; both ends are included
     * @throws IllegalStateException on the aggregated path, if no aggregate exists and not all
     *     partitions have reported
     */
    @Override
    public long getNumBytesProduced(IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) {
        boolean coversAllPartitions =
                partitionIndexRange.getStartIndex() == 0
                        && partitionIndexRange.getEndIndex() == this.getNumPartitions() - 1;
        if (!coversAllPartitions) {
            return super.getNumBytesProduced(partitionIndexRange, subpartitionIndexRange);
        }
        // Fetched once up front rather than once per subpartition index.
        List<Long> aggregated = this.getAggregatedSubpartitionBytes();
        return IntStream.rangeClosed(subpartitionIndexRange.getStartIndex(), subpartitionIndexRange.getEndIndex())
                .mapToLong(aggregated::get)
                .sum();
    }

    /**
     * Returns the per-subpartition byte totals, computing and caching them on first use.
     *
     * @return an unmodifiable view of the aggregated bytes, one entry per subpartition index
     * @throws IllegalStateException if no aggregate exists and not all partitions have reported
     */
    public List<Long> getAggregatedSubpartitionBytes() {
        Preconditions.checkState(
                this.aggregatedSubpartitionBytes != null
                        || this.subpartitionBytesByPartitionIndex.size() == this.numOfPartitions,
                "Not all partition infos are ready");
        if (this.aggregatedSubpartitionBytes == null) {
            this.aggregatedSubpartitionBytes = this.getAggregatedSubpartitionBytesInternal();
        }
        return Collections.unmodifiableList(this.aggregatedSubpartitionBytes);
    }
}

