/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.server;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.SetMultimap;
import com.google.common.util.concurrent.ListenableFuture;
import io.github.resilience4j.bulkhead.Bulkhead;
import io.github.resilience4j.bulkhead.BulkheadConfig;
import io.github.resilience4j.bulkhead.BulkheadRegistry;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import org.apache.druid.client.SegmentServerSelector;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.guava.LazySequence;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryWatcher;
import org.apache.druid.server.QueryCapacityExceededException;
import org.apache.druid.server.QueryLaningStrategy;
import org.apache.druid.server.QueryPrioritizationStrategy;
import org.apache.druid.server.initialization.ServerConfig;

public class QueryScheduler
implements QueryWatcher {
    public static final int UNAVAILABLE = -1;
    public static final String TOTAL = "total";
    private final int totalCapacity;
    private final QueryPrioritizationStrategy prioritizationStrategy;
    private final QueryLaningStrategy laningStrategy;
    private final BulkheadRegistry laneRegistry;
    private final SetMultimap<String, ListenableFuture<?>> queryFutures;
    private final SetMultimap<String, String> queryDatasources;

    public QueryScheduler(int totalNumThreads, QueryPrioritizationStrategy prioritizationStrategy, QueryLaningStrategy laningStrategy, ServerConfig serverConfig) {
        boolean limitTotal;
        this.prioritizationStrategy = prioritizationStrategy;
        this.laningStrategy = laningStrategy;
        this.queryFutures = Multimaps.synchronizedSetMultimap((SetMultimap)HashMultimap.create());
        this.queryDatasources = Multimaps.synchronizedSetMultimap((SetMultimap)HashMultimap.create());
        if (totalNumThreads > 0 && totalNumThreads < serverConfig.getNumThreads()) {
            limitTotal = true;
            this.totalCapacity = totalNumThreads;
        } else {
            limitTotal = false;
            this.totalCapacity = serverConfig.getNumThreads();
        }
        this.laneRegistry = BulkheadRegistry.of(this.getLaneConfigs(limitTotal));
    }

    public void registerQueryFuture(Query<?> query, ListenableFuture<?> future) {
        String id = query.getId();
        Set datasources = query.getDataSource().getTableNames();
        this.queryFutures.put((Object)id, future);
        this.queryDatasources.putAll((Object)id, (Iterable)datasources);
        future.addListener(() -> {
            this.queryFutures.remove((Object)id, (Object)future);
            for (String datasource : datasources) {
                this.queryDatasources.remove((Object)id, (Object)datasource);
            }
        }, (Executor)Execs.directExecutor());
    }

    public <T> Query<T> prioritizeAndLaneQuery(QueryPlus<T> queryPlus, Set<SegmentServerSelector> segments) {
        Query query = queryPlus.getQuery();
        Optional<Integer> priority = this.prioritizationStrategy.computePriority(queryPlus, segments);
        query = priority.map(arg_0 -> ((Query)query).withPriority(arg_0)).orElse(query);
        Optional<String> lane = this.laningStrategy.computeLane(queryPlus.withQuery(query), segments);
        return lane.map(arg_0 -> ((Query)query).withLane(arg_0)).orElse(query);
    }

    public <T> Sequence<T> run(Query<?> query, Sequence<T> resultSequence) {
        List<Bulkhead> bulkheads = this.acquireLanes(query);
        return resultSequence.withBaggage(() -> this.finishLanes(bulkheads));
    }

    public <T> QueryRunner<T> wrapQueryRunner(QueryRunner<T> baseRunner) {
        return (queryPlus, responseContext) -> this.run((Query<?>)queryPlus.getQuery(), (Sequence)new LazySequence(() -> baseRunner.run(queryPlus, responseContext)));
    }

    public boolean cancelQuery(String id) {
        this.queryDatasources.removeAll((Object)id);
        Set futures = this.queryFutures.removeAll((Object)id);
        boolean success = true;
        for (ListenableFuture future : futures) {
            success = success && future.cancel(true);
        }
        return success;
    }

    public Set<String> getQueryDatasources(String queryId) {
        return this.queryDatasources.get((Object)queryId);
    }

    @VisibleForTesting
    int getTotalAvailableCapacity() {
        return this.laneRegistry.getConfiguration(TOTAL).map(config -> this.laneRegistry.bulkhead(TOTAL, config).getMetrics().getAvailableConcurrentCalls()).orElse(-1);
    }

    @VisibleForTesting
    int getLaneAvailableCapacity(String lane) {
        return this.laneRegistry.getConfiguration(lane).map(config -> this.laneRegistry.bulkhead(lane, config).getMetrics().getAvailableConcurrentCalls()).orElse(-1);
    }

    @VisibleForTesting
    List<Bulkhead> acquireLanes(Query<?> query) {
        String lane = QueryContexts.getLane(query);
        Optional laneConfig = lane == null ? Optional.empty() : this.laneRegistry.getConfiguration(lane);
        Optional totalConfig = this.laneRegistry.getConfiguration(TOTAL);
        ArrayList<Bulkhead> hallPasses = new ArrayList<Bulkhead>(2);
        try {
            laneConfig.ifPresent(config -> {
                Bulkhead laneLimiter = this.laneRegistry.bulkhead(lane, config);
                if (!laneLimiter.tryAcquirePermission()) {
                    throw new QueryCapacityExceededException(lane, config.getMaxConcurrentCalls());
                }
                hallPasses.add(laneLimiter);
            });
            totalConfig.ifPresent(config -> {
                Bulkhead totalLimiter = this.laneRegistry.bulkhead(TOTAL, config);
                if (!totalLimiter.tryAcquirePermission()) {
                    throw new QueryCapacityExceededException(config.getMaxConcurrentCalls());
                }
                hallPasses.add(totalLimiter);
            });
            return hallPasses;
        }
        catch (Exception ex) {
            this.releaseLanes(hallPasses);
            throw ex;
        }
    }

    @VisibleForTesting
    void releaseLanes(List<Bulkhead> bulkheads) {
        bulkheads.forEach(Bulkhead::releasePermission);
    }

    @VisibleForTesting
    void finishLanes(List<Bulkhead> bulkheads) {
        bulkheads.forEach(Bulkhead::onComplete);
    }

    private Map<String, BulkheadConfig> getLaneConfigs(boolean hasTotalLimit) {
        HashMap<String, BulkheadConfig> configs = new HashMap<String, BulkheadConfig>();
        if (hasTotalLimit) {
            configs.put(TOTAL, BulkheadConfig.custom().maxConcurrentCalls(this.totalCapacity).maxWaitDuration(Duration.ZERO).build());
        }
        for (Object2IntMap.Entry entry : this.laningStrategy.getLaneLimits(this.totalCapacity).object2IntEntrySet()) {
            configs.put((String)entry.getKey(), BulkheadConfig.custom().maxConcurrentCalls(entry.getIntValue()).maxWaitDuration(Duration.ZERO).build());
        }
        return configs;
    }
}

