/*
 * Decompiled with CFR 0.152.
 */
package edu.iu.dsc.tws.tsched.batch.batchscheduler;

import edu.iu.dsc.tws.api.compute.exceptions.TaskSchedulerException;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.Edge;
import edu.iu.dsc.tws.api.compute.graph.Vertex;
import edu.iu.dsc.tws.api.compute.modifiers.Collector;
import edu.iu.dsc.tws.api.compute.modifiers.Receptor;
import edu.iu.dsc.tws.api.compute.nodes.INode;
import edu.iu.dsc.tws.api.compute.schedule.ITaskScheduler;
import edu.iu.dsc.tws.api.compute.schedule.elements.Resource;
import edu.iu.dsc.tws.api.compute.schedule.elements.TaskInstanceId;
import edu.iu.dsc.tws.api.compute.schedule.elements.TaskInstancePlan;
import edu.iu.dsc.tws.api.compute.schedule.elements.TaskSchedulePlan;
import edu.iu.dsc.tws.api.compute.schedule.elements.Worker;
import edu.iu.dsc.tws.api.compute.schedule.elements.WorkerPlan;
import edu.iu.dsc.tws.api.compute.schedule.elements.WorkerSchedulePlan;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.exceptions.Twister2RuntimeException;
import edu.iu.dsc.tws.tsched.spi.common.TaskSchedulerContext;
import edu.iu.dsc.tws.tsched.spi.taskschedule.TaskInstanceMapCalculation;
import edu.iu.dsc.tws.tsched.utils.TaskAttributes;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.logging.Logger;
import java.util.stream.IntStream;

/**
 * Batch task scheduler that assigns the instances of every task vertex to workers in
 * round-robin order (instance {@code i} of a vertex goes to worker {@code i % workerCount}).
 *
 * <p>When more than one graph is passed to {@link #schedule(WorkerPlan, ComputeGraph...)}, the
 * graphs are treated as dependent graphs: receptor/collector parallelism is validated across
 * them, and the number of workers used by the first graph is reused for the later graphs.
 */
public class BatchTaskScheduler
implements ITaskScheduler {
    private static final Logger LOG = Logger.getLogger(BatchTaskScheduler.class.getName());
    // Default per-instance resources read from the configuration in initialize(Config).
    private Double instanceRAM;
    private Double instanceDisk;
    private Double instanceCPU;
    private Config config;
    private TaskAttributes taskAttributes;
    // Id of the worker running this scheduler; the resulting plan is only logged on worker 0.
    private int workerId;
    // Number of graphs scheduled so far through schedule(ComputeGraph, WorkerPlan).
    private int index;
    // Set once several graphs have been scheduled together; never reset afterwards.
    private boolean dependentGraphs = false;
    // Worker ids used by the first dependent graph; later dependent graphs use the same count.
    private List<Integer> workerIdList = new ArrayList<Integer>();
    // Scratch allocation (worker index -> task instances), rebuilt for every graph.
    private Map<Integer, List<TaskInstanceId>> batchTaskAllocation;
    // Plans produced so far, keyed by graph name.
    private Map<String, TaskSchedulePlan> taskSchedulePlanMap = new LinkedHashMap<String, TaskSchedulePlan>();
    // Receptor input names / collector output names mapped to the parallelism of the vertex that
    // declared them. These maps are static, so every instance of this class shares them, and
    // nothing in this class clears them.
    private static Map<String, Integer> receivableNameMap = new LinkedHashMap<String, Integer>();
    private static Map<String, Integer> collectibleNameMap = new LinkedHashMap<String, Integer>();

    /**
     * Reads the default per-instance RAM, disk and CPU from {@code cfg} and resets the scratch
     * allocation map.
     *
     * @param cfg scheduler configuration
     */
    public void initialize(Config cfg) {
        this.config = cfg;
        this.instanceRAM = TaskSchedulerContext.taskInstanceRam(this.config);
        this.instanceDisk = TaskSchedulerContext.taskInstanceDisk(this.config);
        this.instanceCPU = TaskSchedulerContext.taskInstanceCpu(this.config);
        this.batchTaskAllocation = new LinkedHashMap<Integer, List<TaskInstanceId>>();
        this.taskAttributes = new TaskAttributes();
    }

    /**
     * Same as {@link #initialize(Config)}, and also records the id of the worker running this
     * scheduler.
     *
     * @param cfg scheduler configuration
     * @param workerid id of the current worker (only worker 0 logs the computed plans)
     */
    public void initialize(Config cfg, int workerid) {
        this.initialize(cfg);
        this.workerId = workerid;
    }

    /**
     * Schedules one or more graphs on the workers of {@code workerPlan}.
     *
     * <p>With more than one graph, the receptor/collector names of all graphs are recorded
     * first, their parallelism is validated, and the graphs are scheduled as dependent graphs.
     *
     * @param workerPlan available workers
     * @param computeGraphs graphs to schedule; at least one is required
     * @return the plans accumulated by this scheduler, keyed by graph name
     * @throws TaskSchedulerException if no graph is supplied or a graph cannot be scheduled
     */
    public Map<String, TaskSchedulePlan> schedule(WorkerPlan workerPlan, ComputeGraph ... computeGraphs) {
        if (computeGraphs == null || computeGraphs.length == 0) {
            throw new TaskSchedulerException("At least one compute graph is required for scheduling");
        }
        if (computeGraphs.length > 1) {
            this.addReceptorsCollectors(computeGraphs);
            this.validateParallelism();
            this.dependentGraphs = true;
        }
        for (ComputeGraph computeGraph : computeGraphs) {
            TaskSchedulePlan taskSchedulePlan = this.schedule(computeGraph, workerPlan);
            this.taskSchedulePlanMap.put(computeGraph.getGraphName(), taskSchedulePlan);
        }
        return this.taskSchedulePlanMap;
    }

    /**
     * Records, for every vertex of every graph, the receivable names of receptors and the
     * collectible names of collectors together with the vertex parallelism. A vertex that is a
     * receptor is recorded only as a receptor.
     */
    private void addReceptorsCollectors(ComputeGraph ... computeGraphs) {
        for (ComputeGraph computeGraph : computeGraphs) {
            Set<Vertex> vertices = new LinkedHashSet<>(computeGraph.getTaskVertexSet());
            for (Vertex vertex : vertices) {
                INode iNode = vertex.getTask();
                int parallelism = vertex.getParallelism();
                if (iNode instanceof Receptor) {
                    if (((Receptor) iNode).getReceivableNames() != null) {
                        ((Receptor) iNode).getReceivableNames()
                            .forEach(key -> receivableNameMap.put((String) key, parallelism));
                    }
                } else if (iNode instanceof Collector && ((Collector) iNode).getCollectibleNames() != null) {
                    ((Collector) iNode).getCollectibleNames()
                        .forEach(key -> collectibleNameMap.put((String) key, parallelism));
                }
            }
        }
    }

    /**
     * Computes the schedule of a single graph: assigns its task instances to workers and
     * derives per-instance and per-worker resource requirements.
     *
     * <p>A worker's resource is taken from the {@link WorkerPlan} when that worker reports
     * positive CPU, disk and RAM. Otherwise it is the configured container padding plus the
     * sum of its instances' requirements.
     *
     * @param computeGraph graph to schedule
     * @param workerPlan available workers
     * @return the schedule plan for {@code computeGraph}
     */
    public TaskSchedulePlan schedule(ComputeGraph computeGraph, WorkerPlan workerPlan) {
        LinkedHashSet<WorkerSchedulePlan> workerSchedulePlans = new LinkedHashSet<WorkerSchedulePlan>();
        LinkedHashSet<Vertex> taskVertexSet = new LinkedHashSet<Vertex>(computeGraph.getTaskVertexSet());
        Map<Integer, List<TaskInstanceId>> batchContainerInstanceMap =
            this.batchSchedulingAlgorithm(computeGraph, workerPlan.getNumberOfWorkers());
        TaskInstanceMapCalculation instanceMapCalculation =
            new TaskInstanceMapCalculation(this.instanceRAM, this.instanceCPU, this.instanceDisk);
        Map<Integer, Map<TaskInstanceId, Double>> instancesRamMap =
            instanceMapCalculation.getInstancesRamMapInContainer(batchContainerInstanceMap, taskVertexSet);
        Map<Integer, Map<TaskInstanceId, Double>> instancesDiskMap =
            instanceMapCalculation.getInstancesDiskMapInContainer(batchContainerInstanceMap, taskVertexSet);
        Map<Integer, Map<TaskInstanceId, Double>> instancesCPUMap =
            instanceMapCalculation.getInstancesCPUMapInContainer(batchContainerInstanceMap, taskVertexSet);

        for (int containerId : batchContainerInstanceMap.keySet()) {
            double containerRAMValue = TaskSchedulerContext.containerRamPadding(this.config);
            double containerDiskValue = TaskSchedulerContext.containerDiskPadding(this.config);
            double containerCpuValue = TaskSchedulerContext.containerCpuPadding(this.config);
            Map<TaskInstanceId, TaskInstancePlan> taskInstancePlanMap = new HashMap<TaskInstanceId, TaskInstancePlan>();
            for (TaskInstanceId id : batchContainerInstanceMap.get(containerId)) {
                double instanceRAMValue = instancesRamMap.get(containerId).get(id);
                double instanceDiskValue = instancesDiskMap.get(containerId).get(id);
                double instanceCPUValue = instancesCPUMap.get(containerId).get(id);
                Resource instanceResource = new Resource(instanceRAMValue, instanceDiskValue, instanceCPUValue);
                taskInstancePlanMap.put(id,
                    new TaskInstancePlan(id.getTaskName(), id.getTaskId(), id.getTaskIndex(), instanceResource));
                containerRAMValue += instanceRAMValue;
                containerDiskValue += instanceDiskValue;
                // Bug fix: the CPU total previously accumulated the disk value.
                containerCpuValue += instanceCPUValue;
            }
            Worker worker = workerPlan.getWorker(containerId);
            Resource containerResource;
            if (worker != null && worker.getCpu() > 0 && worker.getDisk() > 0 && worker.getRam() > 0) {
                containerResource = new Resource(worker.getRam(), worker.getDisk(), worker.getCpu());
            } else {
                containerResource = new Resource(containerRAMValue, containerDiskValue, containerCpuValue);
            }
            workerSchedulePlans.add(new WorkerSchedulePlan(containerId,
                new LinkedHashSet<>(taskInstancePlanMap.values()), containerResource));
            // The first dependent graph fixes the set of workers reused by the later graphs.
            if (this.dependentGraphs && this.index == 0) {
                this.workerIdList.add(containerId);
            }
        }
        ++this.index;

        TaskSchedulePlan taskSchedulePlan = new TaskSchedulePlan(0, workerSchedulePlans);
        if (this.workerId == 0) {
            for (Map.Entry<Integer, WorkerSchedulePlan> entry : taskSchedulePlan.getContainersMap().entrySet()) {
                Integer containerId = entry.getKey();
                LOG.fine(() -> "Graph Name:" + computeGraph.getGraphName() + "\tcontainer id:" + containerId);
                for (TaskInstancePlan ip : entry.getValue().getTaskInstances()) {
                    LOG.fine(() -> "Task Id:" + ip.getTaskId() + "\tIndex:" + ip.getTaskIndex()
                        + "\tName:" + ip.getTaskName());
                }
            }
        }
        return taskSchedulePlan;
    }

    /**
     * Builds the worker index to task instance allocation for {@code graph}.
     *
     * @param graph graph whose vertices are allocated
     * @param numberOfContainers number of workers in the worker plan
     * @return allocation keyed by worker index (the internal scratch map, rebuilt per call)
     * @throws TaskSchedulerException if validation fails or instances cannot be placed
     */
    private Map<Integer, List<TaskInstanceId>> batchSchedulingAlgorithm(ComputeGraph graph, int numberOfContainers) throws TaskSchedulerException {
        Set<Vertex> taskVertexSet = new LinkedHashSet<>(graph.getTaskVertexSet());
        // Start from a clean map so that containers used by an earlier graph do not leak into
        // this graph's plan.
        this.batchTaskAllocation.clear();
        for (int containerIndex = 0; containerIndex < numberOfContainers; ++containerIndex) {
            this.batchTaskAllocation.put(containerIndex, new ArrayList<TaskInstanceId>());
        }
        // Task id of each vertex: its position in the graph's vertex iteration order.
        int globalTaskIndex = 0;
        for (Vertex vertex : taskVertexSet) {
            INode iNode = vertex.getTask();
            if (this.dependentGraphs) {
                if (iNode instanceof Receptor) {
                    this.validateReceptor(graph, vertex);
                }
                this.dependentTaskWorkerAllocation(graph, vertex, numberOfContainers, globalTaskIndex);
            } else {
                if (iNode instanceof Collector) {
                    if (((Collector) iNode).getCollectibleNames() != null) {
                        int parallelism = vertex.getParallelism();
                        ((Collector) iNode).getCollectibleNames()
                            .forEach(key -> collectibleNameMap.put((String) key, parallelism));
                    }
                } else if (iNode instanceof Receptor) {
                    if (((Receptor) iNode).getReceivableNames() != null) {
                        int parallelism = vertex.getParallelism();
                        ((Receptor) iNode).getReceivableNames()
                            .forEach(key -> receivableNameMap.put((String) key, parallelism));
                    }
                    this.validateParallelism();
                }
                this.independentTaskWorkerAllocation(graph, vertex, numberOfContainers, globalTaskIndex);
            }
            ++globalTaskIndex;
        }
        return this.batchTaskAllocation;
    }

    /**
     * Fails if any name is both a receivable and a collectible name but was recorded with
     * different parallelism values.
     *
     * @throws Twister2RuntimeException on a parallelism mismatch
     */
    private void validateParallelism() {
        for (Map.Entry<String, Integer> receivable : receivableNameMap.entrySet()) {
            Integer collectorParallelism = collectibleNameMap.get(receivable.getKey());
            if (collectorParallelism != null && !collectorParallelism.equals(receivable.getValue())) {
                throw new Twister2RuntimeException("Please verify the dependent collector(s) and receptor(s) parallelism values which are not equal");
            }
        }
    }

    /**
     * Allocates a vertex of a dependent graph. Once the first dependent graph has fixed the
     * worker list, its size is used as the worker count instead of {@code numberOfContainers}.
     */
    private void dependentTaskWorkerAllocation(ComputeGraph graph, Vertex vertex, int numberOfContainers, int globalTaskIndex) {
        int containerCount = this.workerIdList.isEmpty() ? numberOfContainers : this.workerIdList.size();
        this.allocateRoundRobin(graph, vertex, containerCount, globalTaskIndex);
    }

    /** Allocates a vertex of an independent graph across all {@code numberOfContainers} workers. */
    private void independentTaskWorkerAllocation(ComputeGraph graph, Vertex vertex, int numberOfContainers, int globalTaskIndex) {
        this.allocateRoundRobin(graph, vertex, numberOfContainers, globalTaskIndex);
    }

    /**
     * Places instance {@code i} of {@code vertex} on worker {@code i % containerCount}.
     *
     * <p>When the graph has node constraints, the instance count comes from those constraints,
     * and no worker may receive more than the graph's "instances per worker" limit of this
     * vertex's instances.
     *
     * @throws TaskSchedulerException if there are no workers for a vertex with instances, or
     *     the per-worker limit would be exceeded
     */
    private void allocateRoundRobin(ComputeGraph graph, Vertex vertex, int containerCount, int globalTaskIndex) {
        Map<String, Map<String, String>> nodeConstraints = graph.getNodeConstraints();
        boolean constrained = !nodeConstraints.isEmpty();
        int totalTaskInstances = constrained
            ? this.taskAttributes.getTotalNumberOfInstances(vertex, nodeConstraints)
            : this.taskAttributes.getTotalNumberOfInstances(vertex);
        int instancesPerWorker = constrained
            ? this.taskAttributes.getInstancesPerWorker(graph.getGraphConstraints())
            : Integer.MAX_VALUE;
        if (totalTaskInstances > 0 && containerCount <= 0) {
            throw new TaskSchedulerException("Task Scheduling couldn't be possible for the present configuration: no workers available for task " + vertex.getName());
        }
        String task = vertex.getName();
        // Count per worker (not in total) so that adding workers can satisfy the limit.
        int[] placedPerContainer = new int[Math.max(containerCount, 0)];
        for (int i = 0; i < totalTaskInstances; ++i) {
            int containerIndex = i % containerCount;
            if (placedPerContainer[containerIndex] >= instancesPerWorker) {
                throw new TaskSchedulerException("Task Scheduling couldn't be possible for the present configuration: task "
                    + task + " needs " + totalTaskInstances + " instances but at most " + instancesPerWorker
                    + " are allowed per worker on " + containerCount
                    + " worker(s); please check the number of workers, maximum instances per worker");
            }
            this.batchTaskAllocation.computeIfAbsent(containerIndex, k -> new ArrayList<TaskInstanceId>())
                .add(new TaskInstanceId(task, globalTaskIndex, i));
            ++placedPerContainer[containerIndex];
        }
    }

    /**
     * Fails if {@code vertex} has a child task that is a {@link Collector} with a different
     * parallelism.
     *
     * @throws TaskSchedulerException on a parallelism mismatch
     */
    private void validateReceptor(ComputeGraph graph, Vertex vertex) {
        for (Edge edge : graph.outEdges(vertex)) {
            Vertex child = graph.childOfTask(vertex, edge.getName());
            if (child != null && child.getTask() instanceof Collector
                && child.getParallelism() != vertex.getParallelism()) {
                throw new TaskSchedulerException("Specify the same parallelism for parent and child tasks which depends on the input from the parent in " + graph.getGraphName() + " graph");
            }
        }
    }
}

