/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptive.allocator;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
import org.apache.flink.runtime.instance.SlotSharingGroupId;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
import org.apache.flink.runtime.jobmaster.LogicalSlot;
import org.apache.flink.runtime.jobmaster.SlotInfo;
import org.apache.flink.runtime.jobmaster.SlotRequestId;
import org.apache.flink.runtime.jobmaster.slotpool.PhysicalSlot;
import org.apache.flink.runtime.scheduler.adaptive.JobSchedulingPlan;
import org.apache.flink.runtime.scheduler.adaptive.allocator.DefaultSlotAssigner;
import org.apache.flink.runtime.scheduler.adaptive.allocator.FreeSlotFunction;
import org.apache.flink.runtime.scheduler.adaptive.allocator.IsSlotAvailableAndFreeFunction;
import org.apache.flink.runtime.scheduler.adaptive.allocator.JobAllocationsInformation;
import org.apache.flink.runtime.scheduler.adaptive.allocator.JobInformation;
import org.apache.flink.runtime.scheduler.adaptive.allocator.ReserveSlotFunction;
import org.apache.flink.runtime.scheduler.adaptive.allocator.ReservedSlots;
import org.apache.flink.runtime.scheduler.adaptive.allocator.SharedSlot;
import org.apache.flink.runtime.scheduler.adaptive.allocator.SlotAllocator;
import org.apache.flink.runtime.scheduler.adaptive.allocator.SlotAssigner;
import org.apache.flink.runtime.scheduler.adaptive.allocator.StateLocalitySlotAssigner;
import org.apache.flink.runtime.scheduler.adaptive.allocator.VertexParallelism;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.util.ResourceCounter;

public class SlotSharingSlotAllocator
implements SlotAllocator {
    private final ReserveSlotFunction reserveSlotFunction;
    private final FreeSlotFunction freeSlotFunction;
    private final IsSlotAvailableAndFreeFunction isSlotAvailableAndFreeFunction;

    private SlotSharingSlotAllocator(ReserveSlotFunction reserveSlot, FreeSlotFunction freeSlotFunction, IsSlotAvailableAndFreeFunction isSlotAvailableAndFreeFunction) {
        this.reserveSlotFunction = reserveSlot;
        this.freeSlotFunction = freeSlotFunction;
        this.isSlotAvailableAndFreeFunction = isSlotAvailableAndFreeFunction;
    }

    public static SlotSharingSlotAllocator createSlotSharingSlotAllocator(ReserveSlotFunction reserveSlot, FreeSlotFunction freeSlotFunction, IsSlotAvailableAndFreeFunction isSlotAvailableAndFreeFunction) {
        return new SlotSharingSlotAllocator(reserveSlot, freeSlotFunction, isSlotAvailableAndFreeFunction);
    }

    @Override
    public ResourceCounter calculateRequiredSlots(Iterable<JobInformation.VertexInformation> vertices) {
        int numTotalRequiredSlots = 0;
        for (SlotSharingGroupMetaInfo slotSharingGroupMetaInfo : SlotSharingGroupMetaInfo.from(vertices).values()) {
            numTotalRequiredSlots += slotSharingGroupMetaInfo.getMaxUpperBound();
        }
        return ResourceCounter.withResource(ResourceProfile.UNKNOWN, numTotalRequiredSlots);
    }

    @Override
    public Optional<VertexParallelism> determineParallelism(JobInformation jobInformation, Collection<? extends SlotInfo> freeSlots) {
        Map<SlotSharingGroupId, SlotSharingGroupMetaInfo> slotSharingGroupMetaInfo = SlotSharingGroupMetaInfo.from(jobInformation.getVertices());
        int minimumRequiredSlots = slotSharingGroupMetaInfo.values().stream().map(SlotSharingGroupMetaInfo::getMinLowerBound).reduce(0, Integer::sum);
        if (minimumRequiredSlots > freeSlots.size()) {
            return Optional.empty();
        }
        Map<SlotSharingGroupId, Integer> slotSharingGroupParallelism = SlotSharingSlotAllocator.determineSlotsPerSharingGroup(jobInformation, freeSlots.size(), minimumRequiredSlots, slotSharingGroupMetaInfo);
        HashMap<JobVertexID, Integer> allVertexParallelism = new HashMap<JobVertexID, Integer>();
        for (SlotSharingGroup slotSharingGroup : jobInformation.getSlotSharingGroups()) {
            List<JobInformation.VertexInformation> containedJobVertices = slotSharingGroup.getJobVertexIds().stream().map(jobInformation::getVertexInformation).collect(Collectors.toList());
            Map<JobVertexID, Integer> vertexParallelism = SlotSharingSlotAllocator.determineVertexParallelism(containedJobVertices, slotSharingGroupParallelism.get(slotSharingGroup.getSlotSharingGroupId()));
            allVertexParallelism.putAll(vertexParallelism);
        }
        return Optional.of(new VertexParallelism(allVertexParallelism));
    }

    @Override
    public Optional<JobSchedulingPlan> determineParallelismAndCalculateAssignment(JobInformation jobInformation, Collection<? extends SlotInfo> slots, JobAllocationsInformation jobAllocationsInformation) {
        return this.determineParallelism(jobInformation, slots).map(parallelism -> {
            SlotAssigner slotAssigner = jobAllocationsInformation.isEmpty() ? new DefaultSlotAssigner() : new StateLocalitySlotAssigner();
            return new JobSchedulingPlan((VertexParallelism)parallelism, slotAssigner.assignSlots(jobInformation, slots, (VertexParallelism)parallelism, jobAllocationsInformation));
        });
    }

    private static Map<SlotSharingGroupId, Integer> determineSlotsPerSharingGroup(JobInformation jobInformation, int freeSlots, int minRequiredSlots, Map<SlotSharingGroupId, SlotSharingGroupMetaInfo> slotSharingGroupMetaInfo) {
        int numUnassignedSlots = freeSlots;
        int numUnassignedSlotSharingGroups = jobInformation.getSlotSharingGroups().size();
        int numMinSlotsRequiredByRemainingGroups = minRequiredSlots;
        HashMap<SlotSharingGroupId, Integer> slotSharingGroupParallelism = new HashMap<SlotSharingGroupId, Integer>();
        for (SlotSharingGroupId slotSharingGroup : SlotSharingSlotAllocator.sortSlotSharingGroupsByHighestParallelismRange(slotSharingGroupMetaInfo)) {
            int minParallelism = slotSharingGroupMetaInfo.get(slotSharingGroup).getMinLowerBound();
            int maxOptionalSlots = slotSharingGroupMetaInfo.get(slotSharingGroup).getMaxUpperBound() - minParallelism;
            int freeOptionalSlots = numUnassignedSlots - numMinSlotsRequiredByRemainingGroups;
            int optionalSlotShare = freeOptionalSlots / numUnassignedSlotSharingGroups;
            int groupParallelism = minParallelism + Math.min(maxOptionalSlots, optionalSlotShare);
            slotSharingGroupParallelism.put(slotSharingGroup, groupParallelism);
            numMinSlotsRequiredByRemainingGroups -= minParallelism;
            numUnassignedSlots -= groupParallelism;
            --numUnassignedSlotSharingGroups;
        }
        return slotSharingGroupParallelism;
    }

    private static List<SlotSharingGroupId> sortSlotSharingGroupsByHighestParallelismRange(Map<SlotSharingGroupId, SlotSharingGroupMetaInfo> slotSharingGroupMetaInfo) {
        return slotSharingGroupMetaInfo.entrySet().stream().sorted(Comparator.comparingInt(entry -> ((SlotSharingGroupMetaInfo)entry.getValue()).getMaxLowerUpperBoundRange())).map(Map.Entry::getKey).collect(Collectors.toList());
    }

    private static Map<JobVertexID, Integer> determineVertexParallelism(Collection<JobInformation.VertexInformation> containedJobVertices, int availableSlots) {
        HashMap<JobVertexID, Integer> vertexParallelism = new HashMap<JobVertexID, Integer>();
        for (JobInformation.VertexInformation jobVertex : containedJobVertices) {
            int parallelism = Math.min(jobVertex.getParallelism(), availableSlots);
            vertexParallelism.put(jobVertex.getJobVertexID(), parallelism);
        }
        return vertexParallelism;
    }

    @Override
    public Optional<ReservedSlots> tryReserveResources(JobSchedulingPlan jobSchedulingPlan) {
        Collection<AllocationID> expectedSlots = this.calculateExpectedSlots(jobSchedulingPlan.getSlotAssignments());
        if (this.areAllExpectedSlotsAvailableAndFree(expectedSlots)) {
            HashMap<ExecutionVertexID, LogicalSlot> assignedSlots = new HashMap<ExecutionVertexID, LogicalSlot>();
            for (JobSchedulingPlan.SlotAssignment assignment : jobSchedulingPlan.getSlotAssignments()) {
                SharedSlot sharedSlot = this.reserveSharedSlot(assignment.getSlotInfo());
                for (ExecutionVertexID executionVertexId : assignment.getTargetAs(ExecutionSlotSharingGroup.class).getContainedExecutionVertices()) {
                    assignedSlots.put(executionVertexId, sharedSlot.allocateLogicalSlot());
                }
            }
            return Optional.of(ReservedSlots.create(assignedSlots));
        }
        return Optional.empty();
    }

    @Nonnull
    private Collection<AllocationID> calculateExpectedSlots(Iterable<JobSchedulingPlan.SlotAssignment> assignments) {
        ArrayList<AllocationID> requiredSlots = new ArrayList<AllocationID>();
        for (JobSchedulingPlan.SlotAssignment assignment : assignments) {
            requiredSlots.add(assignment.getSlotInfo().getAllocationId());
        }
        return requiredSlots;
    }

    private boolean areAllExpectedSlotsAvailableAndFree(Iterable<? extends AllocationID> requiredSlots) {
        for (AllocationID allocationID : requiredSlots) {
            if (this.isSlotAvailableAndFreeFunction.isSlotAvailableAndFree(allocationID)) continue;
            return false;
        }
        return true;
    }

    private SharedSlot reserveSharedSlot(SlotInfo slotInfo) {
        PhysicalSlot physicalSlot = this.reserveSlotFunction.reserveSlot(slotInfo.getAllocationId(), ResourceProfile.UNKNOWN);
        return new SharedSlot(new SlotRequestId(), physicalSlot, slotInfo.willBeOccupiedIndefinitely(), () -> this.freeSlotFunction.freeSlot(slotInfo.getAllocationId(), null, System.currentTimeMillis()));
    }

    private static class SlotSharingGroupMetaInfo {
        private final int minLowerBound;
        private final int maxUpperBound;
        private final int maxLowerUpperBoundRange;

        private SlotSharingGroupMetaInfo(int minLowerBound, int maxUpperBound, int maxLowerUpperBoundRange) {
            this.minLowerBound = minLowerBound;
            this.maxUpperBound = maxUpperBound;
            this.maxLowerUpperBoundRange = maxLowerUpperBoundRange;
        }

        public int getMinLowerBound() {
            return this.minLowerBound;
        }

        public int getMaxUpperBound() {
            return this.maxUpperBound;
        }

        public int getMaxLowerUpperBoundRange() {
            return this.maxLowerUpperBoundRange;
        }

        public static Map<SlotSharingGroupId, SlotSharingGroupMetaInfo> from(Iterable<JobInformation.VertexInformation> vertices) {
            return SlotSharingGroupMetaInfo.getPerSlotSharingGroups(vertices, vertexInformation -> new SlotSharingGroupMetaInfo(vertexInformation.getMinParallelism(), vertexInformation.getParallelism(), vertexInformation.getParallelism() - vertexInformation.getMinParallelism()), (metaInfo1, metaInfo2) -> new SlotSharingGroupMetaInfo(Math.min(metaInfo1.getMinLowerBound(), metaInfo2.minLowerBound), Math.max(metaInfo1.getMaxUpperBound(), metaInfo2.getMaxUpperBound()), Math.max(metaInfo1.getMaxLowerUpperBoundRange(), metaInfo2.getMaxLowerUpperBoundRange())));
        }

        private static <T> Map<SlotSharingGroupId, T> getPerSlotSharingGroups(Iterable<JobInformation.VertexInformation> vertices, Function<JobInformation.VertexInformation, T> mapper, BiFunction<T, T, T> reducer) {
            HashMap<SlotSharingGroupId, Object> extractedPerSlotSharingGroups = new HashMap<SlotSharingGroupId, Object>();
            for (JobInformation.VertexInformation vertex : vertices) {
                extractedPerSlotSharingGroups.compute(vertex.getSlotSharingGroup().getSlotSharingGroupId(), (slotSharingGroupId, currentData) -> currentData == null ? mapper.apply(vertex) : reducer.apply(currentData, mapper.apply(vertex)));
            }
            return extractedPerSlotSharingGroups;
        }
    }

    static class ExecutionSlotSharingGroup {
        private final String id;
        private final Set<ExecutionVertexID> containedExecutionVertices;

        public ExecutionSlotSharingGroup(Set<ExecutionVertexID> containedExecutionVertices) {
            this(containedExecutionVertices, UUID.randomUUID().toString());
        }

        public ExecutionSlotSharingGroup(Set<ExecutionVertexID> containedExecutionVertices, String id) {
            this.containedExecutionVertices = containedExecutionVertices;
            this.id = id;
        }

        public String getId() {
            return this.id;
        }

        public Collection<ExecutionVertexID> getContainedExecutionVertices() {
            return this.containedExecutionVertices;
        }
    }
}

