/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.checkpointing;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nonnull;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.client.ClientUtils;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.HighAvailabilityOptions;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.checkpoint.CheckpointIDCounter;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter;
import org.apache.flink.runtime.checkpoint.StandaloneCompletedCheckpointStore;
import org.apache.flink.runtime.checkpoint.TestingCheckpointRecoveryFactory;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.highavailability.HighAvailabilityServices;
import org.apache.flink.runtime.highavailability.HighAvailabilityServicesFactory;
import org.apache.flink.runtime.highavailability.nonha.embedded.EmbeddedHaServices;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
import org.apache.flink.runtime.state.AbstractSnapshotStrategy;
import org.apache.flink.runtime.state.BackendBuildingException;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
import org.apache.flink.runtime.state.DefaultOperatorStateBackendBuilder;
import org.apache.flink.runtime.state.DoneFuture;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.OperatorStateBackend;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.SnapshotResult;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.runtime.state.filesystem.FsStateBackend;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.api.operators.StreamSink;
import org.apache.flink.streaming.runtime.tasks.ExceptionallyDoneFuture;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.util.TestLogger;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Integration test verifying that operators receive {@code notifyCheckpointAborted} callbacks.
 *
 * <p>The job graph is {@code NormalSource -> keyBy -> NormalMap -> DeclineSink}. The test runs once
 * with unaligned checkpoints enabled and once with them disabled. It provokes two checkpoint
 * failures:
 *
 * <ol>
 *   <li>Checkpoints handed to {@link TestingCompletedCheckpointStore} before the test triggers its
 *       abort latch are rejected with an {@link ExpectedTestException}.
 *   <li>Checkpoint {@link #DECLINE_CHECKPOINT_ID} is held in {@code DeclineSink#snapshotState}
 *       until the test releases it. The sink's operator-state snapshot then completes
 *       exceptionally, so the checkpoint is declined.
 * </ol>
 *
 * <p>After each failure, both {@link NormalMap} and {@link DeclineSink} must have been notified
 * exactly one more time.
 */
@RunWith(Parameterized.class)
public class NotifyCheckpointAbortedITCase extends TestLogger {

    /**
     * Id of the checkpoint that {@link DeclineSink} holds back and whose operator-state snapshot
     * {@link DeclineSinkFailingSnapshotStrategy} fails.
     */
    private static final long DECLINE_CHECKPOINT_ID = 2L;

    /** Timeout for the whole test case, in milliseconds. */
    private static final long TEST_TIMEOUT = 60000L;

    /**
     * Name of the sink transformation. {@link DeclineSinkFailingStateBackend} also searches for this
     * string in the operator identifier to decide which operator gets the failing state backend.
     */
    private static final String DECLINE_SINK_NAME = "DeclineSink";

    @ClassRule
    public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder();

    /** Mini cluster created in {@link #setup()} and torn down in {@link #shutdown()}. */
    private static MiniClusterWithClientResource cluster;

    /** Checkpoint directory for the current run, created under {@link #TEMPORARY_FOLDER}. */
    private static Path checkpointPath;

    /** Whether unaligned checkpoints are enabled for this parameterized run. */
    @Parameterized.Parameter
    public boolean unalignedCheckpointEnabled;

    @Parameterized.Parameters(name = "unalignedCheckpointEnabled ={0}")
    public static Collection<Boolean> parameter() {
        return Arrays.asList(true, false);
    }

    /**
     * Starts the mini cluster for this run and resets the test's shared static state.
     *
     * <p>The cluster has a single TaskManager with a single slot and uses {@link TestingHAFactory}
     * for high availability. All static latches and counters shared with the job's operators are
     * reset.
     */
    @Before
    public void setup() throws Exception {
        Configuration configuration = new Configuration();
        configuration.setBoolean(CheckpointingOptions.LOCAL_RECOVERY, true);
        // Point HA_MODE at the testing factory so the cluster uses TestingCompletedCheckpointStore.
        configuration.setString(HighAvailabilityOptions.HA_MODE, TestingHAFactory.class.getName());

        checkpointPath = new Path(TEMPORARY_FOLDER.newFolder().toURI());
        cluster =
                new MiniClusterWithClientResource(
                        new MiniClusterResourceConfiguration.Builder()
                                .setConfiguration(configuration)
                                .setNumberTaskManagers(1)
                                .setNumberSlotsPerTaskManager(1)
                                .build());
        cluster.before();

        NormalMap.reset();
        DeclineSink.reset();
        TestingCompletedCheckpointStore.reset();
    }

    /** Stops the mini cluster started by {@link #setup()}, if one is running. */
    @After
    public void shutdown() {
        if (cluster != null) {
            cluster.after();
            cluster = null;
        }
    }

    /**
     * Checks that each provoked checkpoint failure produces one abort notification per operator.
     *
     * <p>It first lets the completed-checkpoint store reject a checkpoint, then lets the sink decline
     * checkpoint {@link #DECLINE_CHECKPOINT_ID}.
     */
    @Test(timeout = TEST_TIMEOUT)
    public void testNotifyCheckpointAborted() throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.enableCheckpointing(200L, CheckpointingMode.EXACTLY_ONCE);
        env.getCheckpointConfig().enableUnalignedCheckpoints(unalignedCheckpointEnabled);
        // Intended to keep the deliberately provoked checkpoint failures from failing the job.
        env.getCheckpointConfig().setTolerableCheckpointFailureNumber(1);
        env.disableOperatorChaining();
        env.setParallelism(1);

        // Typed as StateBackend to select the setStateBackend(StateBackend) overload, as the
        // original explicit cast did.
        StateBackend failingStateBackend = new DeclineSinkFailingStateBackend(checkpointPath);
        env.setStateBackend(failingStateBackend);

        env.addSource(new NormalSource())
                .name("NormalSource")
                .keyBy((KeySelector<Tuple2<Integer, Integer>, Integer>) value -> value.f0)
                .transform("NormalMap", TypeInformation.of(Integer.class), new NormalMap())
                .transform(DECLINE_SINK_NAME, TypeInformation.of(Object.class), new DeclineSink());

        ClusterClient<?> clusterClient = cluster.getClusterClient();
        JobGraph jobGraph = env.getStreamGraph().getJobGraph();
        JobID jobID = jobGraph.getJobID();
        ClientUtils.submitJob(clusterClient, jobGraph);

        // Failure 1: wait until a checkpoint reaches the store, then let the store reject it.
        TestingCompletedCheckpointStore.addCheckpointLatch.await();
        TestingCompletedCheckpointStore.abortCheckpointLatch.trigger();

        verifyAllOperatorsNotifyAborted();
        resetAllOperatorsNotifyAbortedLatches();
        verifyAllOperatorsNotifyAbortedTimes(1);

        // Failure 2: release the sink held at DECLINE_CHECKPOINT_ID. Its operator-state snapshot
        // then completes exceptionally, so the checkpoint is declined.
        DeclineSink.waitLatch.trigger();
        verifyAllOperatorsNotifyAborted();
        verifyAllOperatorsNotifyAbortedTimes(2);

        clusterClient.cancel(jobID).get();
    }

    /** Blocks until both non-source operators have received an abort notification. */
    private void verifyAllOperatorsNotifyAborted() throws InterruptedException {
        NormalMap.notifiedAbortedLatch.await();
        DeclineSink.notifiedAbortedLatch.await();
    }

    /** Re-arms the per-operator "notified aborted" latches so the next failure can be awaited. */
    private void resetAllOperatorsNotifyAbortedLatches() {
        NormalMap.notifiedAbortedLatch.reset();
        DeclineSink.notifiedAbortedLatch.reset();
    }

    /**
     * Asserts the cumulative number of abort notifications seen by each non-source operator.
     *
     * @param expectedTimes expected number of {@code notifyCheckpointAborted} calls per operator
     */
    private void verifyAllOperatorsNotifyAbortedTimes(int expectedTimes) {
        Assert.assertEquals(expectedTimes, NormalMap.notifiedAbortedTimes.get());
        Assert.assertEquals(expectedTimes, DeclineSink.notifiedAbortedTimes.get());
    }

    /**
     * High-availability services factory that installs {@link TestingCompletedCheckpointStore}.
     *
     * <p>{@code setup()} references it by class name via {@code HA_MODE}, so it must stay public
     * with a no-arg constructor.
     */
    public static class TestingHAFactory implements HighAvailabilityServicesFactory {

        @Override
        public HighAvailabilityServices createHAServices(
                Configuration configuration, Executor executor) {
            CheckpointRecoveryFactory recoveryFactory =
                    new TestingCheckpointRecoveryFactory(
                            new TestingCompletedCheckpointStore(),
                            new StandaloneCheckpointIDCounter());
            return new TestingHaServices(recoveryFactory, executor);
        }
    }

    /**
     * Completed-checkpoint store that rejects checkpoints until the test allows them.
     *
     * <p>While {@link #abortCheckpointLatch} is not triggered, {@link #addCheckpoint} does three
     * things in order:
     *
     * <ol>
     *   <li>Signals {@link #addCheckpointLatch}.
     *   <li>Blocks until the test triggers {@link #abortCheckpointLatch}.
     *   <li>Throws {@link ExpectedTestException}.
     * </ol>
     *
     * <p>Afterwards, checkpoints are stored normally.
     */
    private static class TestingCompletedCheckpointStore extends StandaloneCompletedCheckpointStore {
        private static final OneShotLatch addCheckpointLatch = new OneShotLatch();
        private static final OneShotLatch abortCheckpointLatch = new OneShotLatch();

        TestingCompletedCheckpointStore() {
            super(1);
        }

        @Override
        public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception {
            if (!abortCheckpointLatch.isTriggered()) {
                addCheckpointLatch.trigger();
                abortCheckpointLatch.await();
                throw new ExpectedTestException();
            }
            super.addCheckpoint(checkpoint);
        }

        /** Re-arms both latches; called from {@code setup()} before every run. */
        static void reset() {
            addCheckpointLatch.reset();
            abortCheckpointLatch.reset();
        }
    }

    /** Embedded HA services that hand out a caller-supplied {@link CheckpointRecoveryFactory}. */
    private static class TestingHaServices extends EmbeddedHaServices {
        private final CheckpointRecoveryFactory checkpointRecoveryFactory;

        TestingHaServices(CheckpointRecoveryFactory checkpointRecoveryFactory, Executor executor) {
            super(executor);
            this.checkpointRecoveryFactory = checkpointRecoveryFactory;
        }

        @Override
        public CheckpointRecoveryFactory getCheckpointRecoveryFactory() {
            return checkpointRecoveryFactory;
        }
    }

    /**
     * File-system state backend that gives the decline sink a failing operator state backend.
     *
     * <p>Operators whose identifier contains {@code DECLINE_SINK_NAME} get a backend whose snapshot
     * fails at {@code DECLINE_CHECKPOINT_ID}.
     */
    private static class DeclineSinkFailingStateBackend extends FsStateBackend {
        private static final long serialVersionUID = 1L;

        public DeclineSinkFailingStateBackend(Path checkpointDataUri) {
            super(checkpointDataUri);
        }

        /**
         * Returns a fresh {@code DeclineSinkFailingStateBackend} so the test behaviour survives
         * reconfiguration.
         *
         * <p>The given configuration is deliberately ignored. The copy reuses this instance's
         * checkpoint path rather than reading the mutable static {@code checkpointPath}.
         */
        @Override
        public DeclineSinkFailingStateBackend configure(
                ReadableConfig config, ClassLoader classLoader) {
            return new DeclineSinkFailingStateBackend(getCheckpointPath());
        }

        /**
         * Creates the operator state backend for one operator.
         *
         * <p>The decline sink gets a {@link DeclineSinkFailingOperatorStateBackend}; all other
         * operators get a regular backend built from {@code stateHandles}.
         *
         * <p>NOTE(review): the failing backend ignores {@code stateHandles}, so the sink's operator
         * state would not be restored. Presumably acceptable because this test does not check
         * restored sink state — confirm.
         */
        @Override
        public OperatorStateBackend createOperatorStateBackend(
                Environment env,
                String operatorIdentifier,
                @Nonnull Collection<OperatorStateHandle> stateHandles,
                CloseableRegistry cancelStreamRegistry)
                throws BackendBuildingException {
            if (operatorIdentifier.contains(DECLINE_SINK_NAME)) {
                return new DeclineSinkFailingOperatorStateBackend(
                        env.getExecutionConfig(),
                        cancelStreamRegistry,
                        new DeclineSinkFailingSnapshotStrategy());
            }
            return new DefaultOperatorStateBackendBuilder(
                            env.getUserClassLoader(),
                            env.getExecutionConfig(),
                            false,
                            stateHandles,
                            cancelStreamRegistry)
                    .build();
        }
    }

    /** Empty {@link DefaultOperatorStateBackend} that snapshots via a caller-supplied strategy. */
    private static class DeclineSinkFailingOperatorStateBackend extends DefaultOperatorStateBackend {
        public DeclineSinkFailingOperatorStateBackend(
                ExecutionConfig executionConfig,
                CloseableRegistry closeStreamOnCancelRegistry,
                AbstractSnapshotStrategy<OperatorStateHandle> snapshotStrategy) {
            super(
                    executionConfig,
                    closeStreamOnCancelRegistry,
                    new HashMap<>(),
                    new HashMap<>(),
                    new HashMap<>(),
                    new HashMap<>(),
                    snapshotStrategy);
        }
    }

    /**
     * Snapshot strategy that fails checkpoint {@code DECLINE_CHECKPOINT_ID} and returns an empty
     * result for every other checkpoint.
     */
    private static class DeclineSinkFailingSnapshotStrategy
            extends AbstractSnapshotStrategy<OperatorStateHandle> {

        protected DeclineSinkFailingSnapshotStrategy() {
            // NOTE(review): the description looks copy-pasted from another test; kept as-is.
            super("StuckAsyncSnapshotStrategy");
        }

        @Override
        public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
                long checkpointId,
                long timestamp,
                @Nonnull CheckpointStreamFactory streamFactory,
                @Nonnull CheckpointOptions checkpointOptions) {
            if (checkpointId == DECLINE_CHECKPOINT_ID) {
                return ExceptionallyDoneFuture.of(new ExpectedTestException());
            }
            return DoneFuture.of(SnapshotResult.empty());
        }
    }

    /**
     * No-op sink that holds checkpoint {@code DECLINE_CHECKPOINT_ID} and counts abort notifications.
     *
     * <p>It holds that checkpoint in its synchronous snapshot phase until the test triggers
     * {@link #waitLatch}.
     */
    private static class DeclineSink extends StreamSink<Integer> {
        private static final long serialVersionUID = 1L;
        private static final OneShotLatch notifiedAbortedLatch = new OneShotLatch();
        private static final OneShotLatch waitLatch = new OneShotLatch();
        private static final AtomicInteger notifiedAbortedTimes = new AtomicInteger(0);

        public DeclineSink() {
            super(
                    new SinkFunction<Integer>() {
                        private static final long serialVersionUID = 1L;
                    });
        }

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            if (context.getCheckpointId() == DECLINE_CHECKPOINT_ID) {
                waitLatch.await();
            }
            super.snapshotState(context);
        }

        @Override
        public void notifyCheckpointAborted(long checkpointId) {
            notifiedAbortedTimes.incrementAndGet();
            notifiedAbortedLatch.trigger();
        }

        /** Re-arms the latches and zeroes the counter; called from {@code setup()}. */
        static void reset() {
            notifiedAbortedLatch.reset();
            waitLatch.reset();
            notifiedAbortedTimes.set(0);
        }
    }

    /** Map function that stores each record's second field in keyed state and emits it. */
    private static class NormalMapFunction
            implements MapFunction<Tuple2<Integer, Integer>, Integer>, CheckpointedFunction {
        private static final long serialVersionUID = 1L;
        private ValueState<Integer> valueState;

        private NormalMapFunction() {}

        @Override
        public Integer map(Tuple2<Integer, Integer> value) throws Exception {
            valueState.update(value.f1);
            return value.f1;
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) {}

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            valueState =
                    context.getKeyedStateStore()
                            .getState(new ValueStateDescriptor<>("value", Integer.class));
        }
    }

    /** Map operator wrapping {@link NormalMapFunction} that counts abort notifications. */
    private static class NormalMap extends StreamMap<Tuple2<Integer, Integer>, Integer> {
        private static final long serialVersionUID = 1L;
        private static final OneShotLatch notifiedAbortedLatch = new OneShotLatch();
        private static final AtomicInteger notifiedAbortedTimes = new AtomicInteger(0);

        public NormalMap() {
            super(new NormalMapFunction());
        }

        @Override
        public void notifyCheckpointAborted(long checkpointId) {
            notifiedAbortedTimes.incrementAndGet();
            notifiedAbortedLatch.trigger();
        }

        /** Re-arms the latch and zeroes the counter; called from {@code setup()}. */
        static void reset() {
            notifiedAbortedLatch.reset();
            notifiedAbortedTimes.set(0);
        }
    }

    /**
     * Source that emits random {@code (int, int)} pairs until cancelled.
     *
     * <p>It emits under the checkpoint lock, sleeping 10 ms between records.
     */
    private static class NormalSource implements SourceFunction<Tuple2<Integer, Integer>> {
        private static final long serialVersionUID = 1L;
        protected volatile boolean running = true;

        NormalSource() {}

        @Override
        public void run(SourceFunction.SourceContext<Tuple2<Integer, Integer>> ctx)
                throws Exception {
            while (running) {
                synchronized (ctx.getCheckpointLock()) {
                    ctx.collect(
                            Tuple2.of(
                                    ThreadLocalRandom.current().nextInt(),
                                    ThreadLocalRandom.current().nextInt()));
                }
                Thread.sleep(10L);
            }
        }

        @Override
        public void cancel() {
            running = false;
        }
    }
}

