/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.server.coordinator.balancer;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.druid.client.DruidServer;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.server.coordination.ServerType;
import org.apache.druid.server.coordinator.CreateDataSegments;
import org.apache.druid.server.coordinator.ServerHolder;
import org.apache.druid.server.coordinator.balancer.BalancerSegmentHolder;
import org.apache.druid.server.coordinator.balancer.ReservoirSegmentSampler;
import org.apache.druid.server.coordinator.loading.LoadQueuePeon;
import org.apache.druid.server.coordinator.loading.SegmentAction;
import org.apache.druid.server.coordinator.loading.TestLoadQueuePeon;
import org.apache.druid.timeline.DataSegment;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Tests for {@link ReservoirSegmentSampler}: which segments get picked, which servers and
 * datasources are excluded, and how picks are distributed across servers.
 */
public class ReservoirSegmentSamplerTest {
    /**
     * Segments of datasource "wiki", presumably 100 daily intervals x 10 partitions each, 100 MB apiece.
     * The distribution tests below split this list into 4 and 5 equal parts.
     */
    private final List<DataSegment> segments = CreateDataSegments.ofDatasource("wiki").forIntervals(100, Granularities.DAY).startingAt("2022-01-01").withNumPartitions(10).eachOfSizeInMb(100L);

    @Before
    public void setUp() {
        // Nothing to do: the segments field is initialized when JUnit constructs each test instance.
    }

    /**
     * With one segment per server and repeated single-segment samples, every segment should be seen.
     */
    @Test
    public void testEverySegmentGetsPickedAtleastOnce() {
        final int iterations = 50;
        final List<ServerHolder> servers = Arrays.asList(
            createHistorical("server1", segments.get(0)),
            createHistorical("server2", segments.get(1)),
            createHistorical("server3", segments.get(2)),
            createHistorical("server4", segments.get(3))
        );

        final HashMap<DataSegment, Integer> segmentCountMap = new HashMap<>();
        for (int i = 0; i < iterations; ++i) {
            final DataSegment pickedSegment = ReservoirSegmentSampler
                .pickMovableSegmentsFrom(servers, 1, ServerHolder::getServedSegments, Collections.emptySet())
                .get(0)
                .getSegment();
            segmentCountMap.merge(pickedSegment, 1, Integer::sum);
        }
        Assert.assertEquals(4, segmentCountMap.size());
    }

    /**
     * With a 75% limit on the segments considered, the segment on the last of four single-segment
     * servers must never be picked.
     */
    @Test
    public void getRandomBalancerSegmentHolderTestSegmentsToConsiderLimit() {
        final int iterations = 50;
        final DataSegment excludedSegment = segments.get(3);
        final List<ServerHolder> servers = Arrays.asList(
            createHistorical("server1", segments.get(0)),
            createHistorical("server2", segments.get(1)),
            createHistorical("server3", segments.get(2)),
            createHistorical("server4", excludedSegment)
        );

        final double percentOfSegmentsToConsider = 75.0;
        final HashMap<DataSegment, Integer> segmentCountMap = new HashMap<>();
        for (int i = 0; i < iterations; ++i) {
            final DataSegment pickedSegment = ReservoirSegmentSampler
                .getRandomBalancerSegmentHolder(servers, Collections.emptySet(), percentOfSegmentsToConsider)
                .getSegment();
            segmentCountMap.merge(pickedSegment, 1, Integer::sum);
        }
        Assert.assertFalse(segmentCountMap.containsKey(excludedSegment));
        Assert.assertEquals(3, segmentCountMap.size());
    }

    /**
     * The segment provider decides what is sampled: {@code getLoadingSegments} yields only the
     * segments being loaded, {@code getServedSegments} only the ones already loaded.
     */
    @Test
    public void testPickLoadingOrLoadedSegments() {
        final List<DataSegment> loadedSegments = Arrays.asList(segments.get(0), segments.get(1));
        final List<DataSegment> loadingSegments = Arrays.asList(segments.get(2), segments.get(3));

        final ServerHolder server1 = createHistorical("server1", loadedSegments.get(0));
        server1.startOperation(SegmentAction.LOAD, loadingSegments.get(0));

        final ServerHolder server2 = createHistorical("server2", loadedSegments.get(1));
        server2.startOperation(SegmentAction.LOAD, loadingSegments.get(1));

        final List<ServerHolder> servers = Arrays.asList(server1, server2);

        // Sample from segments that are currently loading
        Set<DataSegment> pickedSegments = ReservoirSegmentSampler
            .pickMovableSegmentsFrom(servers, 10, ServerHolder::getLoadingSegments, Collections.emptySet())
            .stream()
            .map(BalancerSegmentHolder::getSegment)
            .collect(Collectors.toSet());
        Assert.assertEquals(loadingSegments.size(), pickedSegments.size());
        Assert.assertTrue(pickedSegments.containsAll(loadingSegments));

        // Sample from segments that are already served
        final List<BalancerSegmentHolder> pickedHolders = ReservoirSegmentSampler
            .pickMovableSegmentsFrom(servers, 10, ServerHolder::getServedSegments, Collections.emptySet());
        pickedSegments = pickedHolders.stream()
            .map(BalancerSegmentHolder::getSegment)
            .collect(Collectors.toSet());
        Assert.assertEquals(loadedSegments.size(), pickedSegments.size());
        Assert.assertTrue(pickedSegments.containsAll(loadedSegments));
    }

    /**
     * Segments served by a broker must not be sampled; only the historical's two segments qualify.
     */
    @Test
    public void testSegmentsOnBrokersAreIgnored() {
        final ServerHolder historical = createHistorical("hist1", segments.get(0), segments.get(1));
        final ServerHolder broker = new ServerHolder(
            new DruidServer("broker1", "broker1", null, 1000L, ServerType.BROKER, null, 1)
                .addDataSegment(segments.get(2))
                .addDataSegment(segments.get(3))
                .toImmutableDruidServer(),
            new TestLoadQueuePeon()
        );

        final List<BalancerSegmentHolder> pickedSegments = ReservoirSegmentSampler.pickMovableSegmentsFrom(
            Arrays.asList(historical, broker),
            10,
            ServerHolder::getServedSegments,
            Collections.emptySet()
        );

        Assert.assertEquals(2, pickedSegments.size());
        for (BalancerSegmentHolder holder : pickedSegments) {
            Assert.assertEquals(historical, holder.getServer());
        }
    }

    /**
     * Segments of a datasource passed in the broadcast set must not be sampled.
     */
    @Test
    public void testBroadcastSegmentsAreIgnored() {
        final String broadcastDatasource = "ds_broadcast";
        final List<DataSegment> broadcastSegments = CreateDataSegments.ofDatasource(broadcastDatasource).forIntervals(4, Granularities.DAY).startingAt("2022-01-01").withNumPartitions(1).eachOfSizeInMb(100L);

        final List<ServerHolder> servers = Arrays.asList(
            createHistorical("server1", broadcastSegments),
            createHistorical("server2", segments.get(0), segments.get(1))
        );

        final List<BalancerSegmentHolder> pickedSegments = ReservoirSegmentSampler.pickMovableSegmentsFrom(
            servers,
            10,
            ServerHolder::getServedSegments,
            Collections.singleton(broadcastDatasource)
        );

        Assert.assertEquals(2, pickedSegments.size());
        for (BalancerSegmentHolder holder : pickedSegments) {
            Assert.assertNotEquals(broadcastDatasource, holder.getSegment().getDataSource());
        }
    }

    /**
     * When all servers hold the same number of segments, each server should receive about
     * 1/numServers of the picks, within a 2% tolerance of the total.
     */
    @Test
    public void testSegmentsFromAllServersAreEquallyLikelyToBePicked() {
        final int numServers = 4;
        final List<List<DataSegment>> subSegmentLists = Lists.partition(segments, segments.size() / numServers);
        final List<ServerHolder> servers = IntStream.range(0, numServers)
            .mapToObj(i -> createHistorical("server_" + i, subSegmentLists.get(i)))
            .collect(Collectors.toList());

        for (int samplePercentage : new int[]{50, 20, 10, 5}) {
            final int[] numSegmentsPickedFromServer = pickSegmentsAndGetPickedCountPerServer(servers, samplePercentage, 50);
            final int totalSegmentsPicked = Arrays.stream(numSegmentsPickedFromServer).sum();
            final double expectedPickedSegments = totalSegmentsPicked / (double) numServers;
            final double error = totalSegmentsPicked * 0.02;
            for (int pickedSegments : numSegmentsPickedFromServer) {
                Assert.assertEquals(expectedPickedSegments, pickedSegments, error);
            }
        }
    }

    /**
     * Picks should be proportional to server population. server_0 holds 2/5 of the segments and
     * servers 1-3 hold 1/5 each, so the expected shares are 40% and 20%.
     */
    @Test
    public void testSegmentsFromMorePopulousServerAreMoreLikelyToBePicked() {
        final List<List<DataSegment>> subSegmentLists = Lists.partition(segments, segments.size() / 5);

        final List<ServerHolder> servers = new ArrayList<>();
        final List<DataSegment> segmentsForServer0 = new ArrayList<>(subSegmentLists.get(0));
        segmentsForServer0.addAll(subSegmentLists.get(1));
        servers.add(createHistorical("server_0", segmentsForServer0));
        for (int i = 1; i < 4; ++i) {
            // server_i gets partition i + 1, because partitions 0 and 1 both went to server_0
            servers.add(createHistorical("server_" + i, subSegmentLists.get(i + 1)));
        }

        for (int samplePercentage : new int[]{50, 20, 10, 5}) {
            final int[] numSegmentsPickedFromServer = pickSegmentsAndGetPickedCountPerServer(servers, samplePercentage, 50);
            final int totalSegmentsPicked = Arrays.stream(numSegmentsPickedFromServer).sum();
            final double error = totalSegmentsPicked * 0.02;

            Assert.assertEquals(totalSegmentsPicked * 0.4, numSegmentsPickedFromServer[0], error);
            for (int serverId = 1; serverId < servers.size(); ++serverId) {
                Assert.assertEquals(totalSegmentsPicked * 0.2, numSegmentsPickedFromServer[serverId], error);
            }
        }
    }

    /**
     * For each sample percentage, the average number of sampling rounds needed to see every
     * segment (over 50 runs) must not exceed the corresponding bound.
     */
    @Test(timeout = 60000L)
    public void testNumberOfSamplingsRequiredToPickAllSegments() {
        final int numRuns = 50;
        final int[] samplePercentages = new int[]{100, 50, 10, 5, 1};
        final int[] expectedIterations = new int[]{1, 20, 100, 200, 1000};

        final int[] totalObservedIterations = new int[samplePercentages.length];
        for (int run = 0; run < numRuns; ++run) {
            for (int j = 0; j < samplePercentages.length; ++j) {
                totalObservedIterations[j] += countMinRunsToPickAllSegments(samplePercentages[j]);
            }
        }

        for (int j = 0; j < samplePercentages.length; ++j) {
            final double avgObservedIterations = (double) totalObservedIterations[j] / numRuns;
            Assert.assertTrue(
                "Sampling " + samplePercentages[j] + "% needed " + avgObservedIterations
                + " iterations on average, expected at most " + expectedIterations[j],
                avgObservedIterations <= expectedIterations[j]
            );
        }
    }

    /**
     * Repeatedly samples {@code samplePercentage}% of the segments from two historicals, which split
     * the segment list in half. Stops as soon as every segment has been picked at least once.
     *
     * @return the number of sampling rounds performed; the loop is capped, so this returns 10000
     *         if not every segment was seen before the cap
     */
    private int countMinRunsToPickAllSegments(int samplePercentage) {
        final int numSegments = segments.size();
        final List<ServerHolder> servers = Arrays.asList(
            createHistorical("server1", segments.subList(0, numSegments / 2)),
            createHistorical("server2", segments.subList(numSegments / 2, numSegments))
        );

        final Set<DataSegment> pickedSegments = new HashSet<>();
        final int sampleSize = (int) ((double) (numSegments * samplePercentage) / 100.0);

        int numIterations = 1;
        for (; numIterations < 10000; ++numIterations) {
            ReservoirSegmentSampler
                .pickMovableSegmentsFrom(servers, sampleSize, ServerHolder::getServedSegments, Collections.emptySet())
                .forEach(holder -> pickedSegments.add(holder.getSegment()));
            if (pickedSegments.size() >= numSegments) {
                break;
            }
        }
        return numIterations;
    }

    /**
     * Samples {@code samplePercentage}% of {@link #segments} (by total count) {@code numIterations}
     * times and tallies how many picks came from each server.
     *
     * @return an array indexed like {@code servers}, holding the pick count for each server
     */
    private int[] pickSegmentsAndGetPickedCountPerServer(List<ServerHolder> servers, int samplePercentage, int numIterations) {
        final int numSegmentsToPick = (int) ((double) (segments.size() * samplePercentage) / 100.0);
        final int[] numSegmentsPickedFromServer = new int[servers.size()];

        for (int i = 0; i < numIterations; ++i) {
            final List<BalancerSegmentHolder> pickedSegments = ReservoirSegmentSampler.pickMovableSegmentsFrom(
                servers,
                numSegmentsToPick,
                ServerHolder::getServedSegments,
                Collections.emptySet()
            );
            for (BalancerSegmentHolder pickedSegment : pickedSegments) {
                numSegmentsPickedFromServer[servers.indexOf(pickedSegment.getServer())]++;
            }
        }
        return numSegmentsPickedFromServer;
    }

    private ServerHolder createHistorical(String serverName, List<DataSegment> loadedSegments) {
        return createHistorical(serverName, loadedSegments.toArray(new DataSegment[0]));
    }

    /**
     * Builds a HISTORICAL server in tier "normal" (max size 100000) that serves the given segments.
     * It is wrapped in a {@link ServerHolder} with a fresh {@link TestLoadQueuePeon}.
     */
    private ServerHolder createHistorical(String serverName, DataSegment... loadedSegments) {
        final DruidServer server = new DruidServer(serverName, serverName, null, 100000L, ServerType.HISTORICAL, "normal", 1);
        for (DataSegment segment : loadedSegments) {
            server.addDataSegment(segment);
        }
        return new ServerHolder(server.toImmutableDruidServer(), new TestLoadQueuePeon());
    }
}

