/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ratis.server.impl;

import java.util.Comparator;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.ratis.BaseTest;
import org.apache.ratis.RaftTestUtil;
import org.apache.ratis.client.RaftClient;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.proto.RaftProtos;
import org.apache.ratis.protocol.Message;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.impl.MiniRaftCluster;
import org.apache.ratis.server.impl.StateMachineUpdater;
import org.apache.ratis.statemachine.StateMachine;
import org.apache.ratis.statemachine.TransactionContext;
import org.apache.ratis.statemachine.impl.SimpleStateMachine4Testing;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Verifies that closing a server waits for in-flight {@code applyTransaction} calls of its
 * state machine, using {@link StateMachineWithConditionalWait} to hold individual log entries
 * until the test explicitly releases them.
 */
public abstract class StateMachineShutdownTests<CLUSTER extends MiniRaftCluster>
extends BaseTest
implements MiniRaftCluster.Factory.Get<CLUSTER> {
    public static Logger LOG = LoggerFactory.getLogger(StateMachineUpdater.class);
    private static MockedStatic<CompletableFuture> mocked;

    /**
     * Resets the shared static bookkeeping and installs a call-through static mock of
     * {@link CompletableFuture} for the duration of one test.
     */
    @BeforeEach
    public void setup() {
        // numTxns/futures are static, so without this they would carry state over from any
        // earlier test executed in the same JVM and skew the per-peer counts asserted below.
        StateMachineWithConditionalWait.numTxns.clear();
        StateMachineWithConditionalWait.futures.clear();
        mocked = Mockito.mockStatic(CompletableFuture.class, (Answer) Mockito.CALLS_REAL_METHODS);
    }

    /** Deregisters the static mock installed by {@link #setup()}. */
    @AfterEach
    public void tearDownClass() {
        if (mocked != null) {
            mocked.close();
            mocked = null;
        }
    }

    /**
     * Blocks every transaction on one follower, closes that follower asynchronously, and checks
     * that its last applied index only reaches the final log index once all blocked entries
     * have been released.
     */
    @Test
    public void testStateMachineShutdownWaitsForApplyTxn() throws Exception {
        final RaftProperties prop = getProperties();
        prop.setClass(MiniRaftCluster.STATEMACHINE_CLASS_KEY, StateMachineWithConditionalWait.class, StateMachine.class);
        final MiniRaftCluster cluster = newCluster(3);
        StateMachineWithConditionalWait blockedStateMachine = null;
        try {
            cluster.start();
            RaftTestUtil.waitForLeader(cluster);
            final RaftServer.Division leader = cluster.getLeader();
            final RaftPeerId leaderId = leader.getId();
            // Only the second follower keeps blocking; the leader and first follower apply freely.
            ((StateMachineWithConditionalWait) leader.getStateMachine()).unblockAllTxns();
            ((StateMachineWithConditionalWait) cluster.getFollowers().get(0).getStateMachine()).unblockAllTxns();
            cluster.getLeaderAndSendFirstMessage(true);

            try (RaftClient client = cluster.createClient(leaderId)) {
                client.io().send(new RaftTestUtil.SimpleMessage("message"));
                final RaftClientReply reply = client.io().send(new RaftTestUtil.SimpleMessage("message2"));
                final long logIndex = reply.getLogIndex();
                // Every peer must have committed (not necessarily applied) the last entry.
                final RaftClientReply watchReply = client.io().watch(logIndex, RaftProtos.ReplicationLevel.ALL_COMMITTED);
                watchReply.getCommitInfos().forEach(val -> Assertions.assertTrue(val.getCommitIndex() >= logIndex));

                final RaftServer.Division secondFollower = cluster.getFollowers().get(1);
                blockedStateMachine = (StateMachineWithConditionalWait) secondFollower.getStateMachine();
                Assertions.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < logIndex);

                // Close the blocked follower in the background; close() should wait for pending applies.
                final Thread t = new Thread(() -> secondFollower.close());
                t.start();
                final long minIndex = blockedStateMachine.blockTxns.stream()
                    .min(Comparator.naturalOrder())
                    .orElseThrow(() -> new IllegalStateException(
                        "No blocked transactions on " + secondFollower.getId()));
                // Presumably the leader and the unblocked follower: each has applied three transactions.
                Assertions.assertEquals(2L, countPeersWithAppliedTxns(3L));
                Assertions.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < minIndex);

                // Release everything except the lowest blocked index; the gap must keep lastApplied below it.
                for (long index : blockedStateMachine.blockTxns) {
                    if (index != minIndex) {
                        blockedStateMachine.unBlockApplyTxn(index);
                    }
                }
                Assertions.assertEquals(2L, countPeersWithAppliedTxns(3L));
                Assertions.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < minIndex);

                // Releasing the lowest index fills the gap and lets the close thread finish.
                blockedStateMachine.unBlockApplyTxn(minIndex);
                t.join(5000L);
                Assertions.assertEquals(logIndex, secondFollower.getInfo().getLastAppliedIndex());
                Assertions.assertEquals(3L, countPeersWithAppliedTxns(3L));
            }
        } finally {
            if (blockedStateMachine != null) {
                // After a failed assertion some applies may still be blocked; release them so
                // neither the close thread nor cluster shutdown waits on them forever.
                blockedStateMachine.unblockAllTxns();
            }
            cluster.shutdown();
        }
    }

    /**
     * Counts the peers whose {@link StateMachineWithConditionalWait#numTxns} counter equals
     * {@code expected}.
     *
     * @param expected the number of completed applyTransaction calls to match
     * @return how many peers have exactly that many completed calls
     */
    private static long countPeersWithAppliedTxns(long expected) {
        return StateMachineWithConditionalWait.numTxns.values().stream()
            .filter(count -> count.get() == expected)
            .count();
    }

    /**
     * State machine whose {@code applyTransaction} parks each entry (keyed by log index) in
     * {@link #blockTxns} until it is released via {@link #unBlockApplyTxn(long)} or
     * {@link #unblockAllTxns()}. The last applied index only advances over contiguous indices.
     */
    protected static class StateMachineWithConditionalWait
    extends SimpleStateMachine4Testing {
        /** When true, new entries are not blocked and waiting entries proceed; written under the lock, volatile for readers. */
        volatile boolean unblockAllTxns = false;
        /** Log indices currently held back; also the monitor used for wait/notify. */
        final Set<Long> blockTxns = ConcurrentHashMap.newKeySet();
        // NOTE(review): this pool is never shut down by this class.
        private final ExecutorService executor = Executors.newFixedThreadPool(10);
        /** Futures returned by applyTransaction, grouped by the id of the calling thread. */
        public static Map<Long, Set<CompletableFuture<Message>>> futures = new ConcurrentHashMap<>();
        /** Number of applyTransaction calls that got past the block, per peer. */
        public static Map<RaftPeerId, AtomicLong> numTxns = new ConcurrentHashMap<>();
        /** Index -> term of entries that finished but may not yet be contiguous with lastApplied. */
        private final Map<Long, Long> appliedTxns = new ConcurrentHashMap<>();

        protected StateMachineWithConditionalWait() {
        }

        /**
         * Advances the last applied term-index over every contiguous recorded index starting at
         * lastApplied + 1. Entries beyond a gap stay in {@link #appliedTxns} until the gap fills.
         */
        private synchronized void updateTxns() {
            long appliedIndex = getLastAppliedTermIndex().getIndex() + 1L;
            Long appliedTerm = null;
            while (appliedTxns.containsKey(appliedIndex)) {
                appliedTerm = appliedTxns.remove(appliedIndex);
                ++appliedIndex;
            }
            if (appliedTerm != null) {
                updateLastAppliedTermIndex(appliedTerm, appliedIndex - 1L);
            }
        }

        /**
         * Records {@code index}/{@code term} as done and advances the last applied index through
         * the same contiguous-gap logic as transactions.
         */
        public void notifyTermIndexUpdated(long term, long index) {
            appliedTxns.put(index, term);
            updateTxns();
        }

        /**
         * Completes asynchronously once the entry is unblocked. If the worker is interrupted or
         * fails, the returned future completes exceptionally instead of never completing.
         */
        @Override
        public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
            final RaftProtos.LogEntryProto entry = trx.getLogEntry();
            final CompletableFuture<Message> future = new CompletableFuture<>();
            futures.computeIfAbsent(Thread.currentThread().getId(), k -> new HashSet<>()).add(future);
            executor.submit(() -> {
                try {
                    awaitUnblocked(entry.getIndex());
                    numTxns.computeIfAbsent(getId(), k -> new AtomicLong()).incrementAndGet();
                    appliedTxns.put(entry.getIndex(), entry.getTerm());
                    updateTxns();
                    future.complete(new RaftTestUtil.SimpleMessage("done"));
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    future.completeExceptionally(e);
                } catch (RuntimeException e) {
                    future.completeExceptionally(e);
                }
            });
            return future;
        }

        /**
         * Registers {@code index} as blocked (unless everything is unblocked) and waits until it
         * is released. The 10-second timed wait only bounds each sleep; the loop re-checks the condition.
         */
        private void awaitUnblocked(long index) throws InterruptedException {
            synchronized (blockTxns) {
                if (!unblockAllTxns) {
                    blockTxns.add(index);
                }
                while (!unblockAllTxns && blockTxns.contains(index)) {
                    blockTxns.wait(10000L);
                }
            }
        }

        /** Releases the single blocked entry at {@code txnId} and wakes waiting workers. */
        public void unBlockApplyTxn(long txnId) {
            synchronized (blockTxns) {
                blockTxns.remove(txnId);
                blockTxns.notifyAll();
            }
        }

        /** Releases all blocked entries and stops blocking any future entries. */
        public void unblockAllTxns() {
            synchronized (blockTxns) {
                unblockAllTxns = true;
                blockTxns.clear();
                blockTxns.notifyAll();
            }
        }
    }
}

