/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.checkpointing;

import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ExternalizedCheckpointRetention;
import org.apache.flink.configuration.StateBackendOptions;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.testutils.CommonTestUtils;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.testutils.junit.SharedObjects;
import org.apache.flink.testutils.junit.SharedReference;
import org.apache.flink.util.Collector;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.TestLogger;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

public class RescaleCheckpointManuallyITCase
extends TestLogger {
    /** Number of TaskManagers in the MiniCluster started for each test. */
    private static final int NUM_TASK_MANAGERS = 2;
    /** Number of slots offered by each TaskManager. */
    private static final int SLOTS_PER_TASK_MANAGER = 2;
    /** Cluster resource created in {@code setup()} and released in {@code shutDownExistingCluster()}. */
    private static MiniClusterWithClientResource cluster;
    /**
     * Shared-object registry; {@code createJobGraphWithKeyedState} registers the JobID and the
     * MiniCluster here so the source function can reach them through references.
     */
    @Rule
    public final SharedObjects sharedObjects = SharedObjects.create();
    /**
     * Class-level temporary folder; a new sub-folder of it is used as checkpoint storage for every
     * job graph that is built. The instance is assigned in this class's static initializer.
     */
    @ClassRule
    public static TemporaryFolder temporaryFolder;

    /**
     * Starts a fresh MiniCluster before every test, configured with the RocksDB state backend and
     * incremental checkpoints enabled.
     *
     * @throws Exception if the cluster resource fails to start
     */
    @Before
    public void setup() throws Exception {
        Configuration config = new Configuration();
        // Pass correctly typed values: the value type is fixed by the option's type parameter,
        // so an Object-typed argument does not type-check.
        config.set(StateBackendOptions.STATE_BACKEND, "rocksdb");
        config.set(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true);

        cluster =
                new MiniClusterWithClientResource(
                        new MiniClusterResourceConfiguration.Builder()
                                .setConfiguration(config)
                                .setNumberTaskManagers(NUM_TASK_MANAGERS)
                                .setNumberSlotsPerTaskManager(SLOTS_PER_TASK_MANAGER)
                                .build());
        cluster.before();
    }

    /** Tears down the cluster created in {@code setup()}, if there is one, and forgets it. */
    @After
    public void shutDownExistingCluster() {
        if (cluster == null) {
            // Nothing was started (or it was already released).
            return;
        }
        cluster.after();
        cluster = null;
    }

    /** Runs the keyed-state rescaling scenario in the scale-in direction. */
    @Test
    public void testCheckpointRescalingInKeyedState() throws Exception {
        final boolean scaleOut = false;
        testCheckpointRescalingKeyedState(scaleOut);
    }

    /** Runs the keyed-state rescaling scenario in the scale-out direction. */
    @Test
    public void testCheckpointRescalingOutKeyedState() throws Exception {
        final boolean scaleOut = true;
        testCheckpointRescalingKeyedState(scaleOut);
    }

    /**
     * Runs a keyed-state job until a checkpoint is retained, then restores a second job from that
     * checkpoint with a different parallelism and verifies the per-key results.
     *
     * @param scaleOut {@code true} to go from parallelism 3 to 4, {@code false} to go from 4 to 3
     * @throws Exception if either job cannot be run or the assertion fails
     */
    public void testCheckpointRescalingKeyedState(boolean scaleOut) throws Exception {
        final int numberKeys = 42;
        // Rounds emitted per key by the first job and by the restored job, respectively.
        final int numberElements = 1000;
        final int numberElements2 = 500;
        final int parallelism = scaleOut ? 3 : 4;
        final int parallelism2 = scaleOut ? 4 : 3;
        final int maxParallelism = 13;

        MiniCluster miniCluster = cluster.getMiniCluster();
        String checkpointPath =
                runJobAndGetCheckpoint(
                        numberKeys, numberElements, parallelism, maxParallelism, miniCluster);
        Assert.assertNotNull(checkpointPath);

        // The restored job continues counting from the checkpointed state, so it must reach
        // the combined number of elements from both runs.
        restoreAndAssert(
                parallelism2,
                maxParallelism,
                numberKeys,
                numberElements2,
                numberElements + numberElements2,
                miniCluster,
                checkpointPath);
    }

    /**
     * Submits the keyed-state job configured to fail after emission, waits for its result and
     * returns the path of the latest completed checkpoint of that job. The static sink contents
     * are cleared afterwards in all cases.
     *
     * @param numberKeys number of distinct keys the source emits
     * @param numberElements number of emission rounds; also used as the mapper's emit threshold
     * @param parallelism job parallelism
     * @param maxParallelism maximum parallelism of the job
     * @param miniCluster cluster to run the job on
     * @return path of the latest completed checkpoint
     * @throws IllegalStateException if no completed checkpoint can be found for the job
     */
    private String runJobAndGetCheckpoint(int numberKeys, int numberElements, int parallelism, int maxParallelism, MiniCluster miniCluster) throws Exception {
        try {
            final JobGraph graph =
                    createJobGraphWithKeyedState(
                            parallelism,
                            maxParallelism,
                            numberKeys,
                            numberElements,
                            numberElements,
                            true,
                            100,
                            miniCluster);
            miniCluster.submitJob(graph).get();
            miniCluster.requestJobResult(graph.getJobID()).get();

            final Optional<String> latestCheckpoint =
                    CommonTestUtils.getLatestCompletedCheckpointPath(graph.getJobID(), miniCluster);
            return latestCheckpoint.orElseThrow(
                    () ->
                            new IllegalStateException(
                                    "Cannot get completed checkpoint, job failed before completing checkpoint"));
        } finally {
            CollectionSink.clearElementsSet();
        }
    }

    /**
     * Restores the keyed-state job from {@code restorePath} with a new parallelism, runs it to
     * completion and checks what the sink collected.
     *
     * <p>For every key the expected tuple is (index of the operator subtask that owns the key's
     * key group at the restore parallelism, {@code key * numberElementsExpect}); the second
     * component is the per-key sum once the key has been seen {@code numberElementsExpect} times.
     * The static sink contents are cleared afterwards in all cases.
     *
     * @param restoreParallelism parallelism of the restored job
     * @param maxParallelism maximum parallelism (must match the checkpointed job)
     * @param numberKeys number of distinct keys
     * @param numberElements number of emission rounds of the restored job's source
     * @param numberElementsExpect total per-key count at which the mapper emits its result
     * @param miniCluster cluster to run the job on
     * @param restorePath checkpoint path to restore from
     */
    private void restoreAndAssert(int restoreParallelism, int maxParallelism, int numberKeys, int numberElements, int numberElementsExpect, MiniCluster miniCluster, String restorePath) throws Exception {
        try {
            JobGraph scaledJobGraph =
                    createJobGraphWithKeyedState(
                            restoreParallelism,
                            maxParallelism,
                            numberKeys,
                            numberElements,
                            numberElementsExpect,
                            false,
                            100,
                            miniCluster);
            scaledJobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(restorePath));

            miniCluster.submitJob(scaledJobGraph).get();
            miniCluster.requestJobResult(scaledJobGraph.getJobID()).get();

            // Fully typed sets so that the comparison is checked by the compiler as well.
            Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
            Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
            for (int key = 0; key < numberKeys; key++) {
                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
                int operatorIndex =
                        KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
                                maxParallelism, restoreParallelism, keyGroupIndex);
                expectedResult.add(Tuple2.of(operatorIndex, key * numberElementsExpect));
            }
            Assert.assertEquals(expectedResult, actualResult);
        } finally {
            CollectionSink.clearElementsSet();
        }
    }

    /**
     * Builds the job graph used by both phases of the test: a parallel source emitting integer
     * keys, keyed by identity, followed by a stateful flat map and a collecting sink.
     *
     * <p>Checkpoints are written to a new folder under the class-level temporary folder and are
     * retained on cancellation; restarts are disabled and snapshot compression is enabled.
     *
     * @param parallelism job parallelism
     * @param maxParallelism maximum parallelism; only applied when greater than zero
     * @param numberKeys number of distinct keys the source emits
     * @param numberElements number of emission rounds of the source
     * @param numberElementsExpect per-key count at which the flat map emits its result
     * @param failAfterEmission whether the source throws after emitting everything and after a
     *     checkpoint has completed
     * @param checkpointingInterval interval passed to {@code enableCheckpointing}
     * @param miniCluster cluster the source polls for a completed checkpoint
     * @return the job graph, created with a freshly generated JobID
     * @throws IOException if the checkpoint folder cannot be created
     */
    private JobGraph createJobGraphWithKeyedState(int parallelism, int maxParallelism, int numberKeys, int numberElements, int numberElementsExpect, boolean failAfterEmission, int checkpointingInterval, MiniCluster miniCluster) throws IOException {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(parallelism);
        if (maxParallelism > 0) {
            env.getConfig().setMaxParallelism(maxParallelism);
        }
        env.enableCheckpointing(checkpointingInterval);
        env.getCheckpointConfig().setCheckpointStorage(temporaryFolder.newFolder().toURI());
        env.getCheckpointConfig().setExternalizedCheckpointRetention(ExternalizedCheckpointRetention.RETAIN_ON_CANCELLATION);
        env.setRestartStrategy(RestartStrategies.noRestart());
        env.getConfig().setUseSnapshotCompression(true);

        // The source function reaches the JobID and the MiniCluster through shared references
        // instead of capturing the objects themselves.
        final SharedReference<JobID> jobID = sharedObjects.add(new JobID());
        final SharedReference<MiniCluster> miniClusterRef = sharedObjects.add(miniCluster);

        KeyedStream<Integer, Integer> input =
                env.addSource(
                                new NotifyingDefiniteKeySource(numberKeys, numberElements, failAfterEmission) {

                                    @Override
                                    public void waitCheckpointCompleted() throws Exception {
                                        // Poll until the job reports at least one completed checkpoint.
                                        Optional<String> mostRecentCompletedCheckpointPath =
                                                CommonTestUtils.getLatestCompletedCheckpointPath(
                                                        jobID.get(), miniClusterRef.get());
                                        while (!mostRecentCompletedCheckpointPath.isPresent()) {
                                            Thread.sleep(50L);
                                            mostRecentCompletedCheckpointPath =
                                                    CommonTestUtils.getLatestCompletedCheckpointPath(
                                                            jobID.get(), miniClusterRef.get());
                                        }
                                    }
                                })
                        .keyBy(
                                new KeySelector<Integer, Integer>() {
                                    private static final long serialVersionUID = 1L;

                                    @Override
                                    public Integer getKey(Integer value) {
                                        return value;
                                    }
                                });

        SingleOutputStreamOperator<Tuple2<Integer, Integer>> result =
                input.flatMap(new SubtaskIndexFlatMapper(numberElementsExpect));
        result.addSink(new CollectionSink<>());

        return env.getStreamGraph().getJobGraph(env.getClass().getClassLoader(), jobID.get());
    }

    // Creates the TemporaryFolder instance for the static temporaryFolder field of this class,
    // which is declared without an initializer.
    static {
        temporaryFolder = new TemporaryFolder();
    }

    private static class CollectionSink<IN>
    implements SinkFunction<IN> {
        private static final Set<Object> elements = Collections.newSetFromMap(new ConcurrentHashMap());
        private static final long serialVersionUID = 1L;

        private CollectionSink() {
        }

        public static <IN> Set<IN> getElementsSet() {
            return elements;
        }

        public static void clearElementsSet() {
            elements.clear();
        }

        public void invoke(IN value) throws Exception {
            elements.add(value);
        }
    }

    private static class SubtaskIndexFlatMapper
    extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>>
    implements CheckpointedFunction {
        private static final long serialVersionUID = 1L;
        private transient ValueState<Integer> counter;
        private transient ValueState<Integer> sum;
        private final int numberElements;

        public SubtaskIndexFlatMapper(int numberElements) {
            this.numberElements = numberElements;
        }

        public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out) throws Exception {
            Integer counterValue = (Integer)this.counter.value();
            int count = counterValue == null ? 1 : counterValue + 1;
            this.counter.update((Object)count);
            Integer sumValue = (Integer)this.sum.value();
            int s = sumValue == null ? value : sumValue + value;
            this.sum.update((Object)s);
            if (count == this.numberElements) {
                out.collect((Object)Tuple2.of((Object)this.getRuntimeContext().getTaskInfo().getIndexOfThisSubtask(), (Object)s));
            }
        }

        public void snapshotState(FunctionSnapshotContext context) throws Exception {
        }

        public void initializeState(FunctionInitializationContext context) throws Exception {
            this.counter = context.getKeyedStateStore().getState(new ValueStateDescriptor("counter", Integer.class));
            this.sum = context.getKeyedStateStore().getState(new ValueStateDescriptor("sum", Integer.class));
        }
    }

    private static class NotifyingDefiniteKeySource
    extends RichParallelSourceFunction<Integer> {
        private static final long serialVersionUID = 1L;
        private final int numberKeys;
        protected final int numberElements;
        private final boolean failAfterEmission;
        protected int counter = 0;
        private boolean running = true;

        public NotifyingDefiniteKeySource(int numberKeys, int numberElements, boolean failAfterEmission) {
            Preconditions.checkState((numberElements > 0 ? 1 : 0) != 0);
            this.numberKeys = numberKeys;
            this.numberElements = numberElements;
            this.failAfterEmission = failAfterEmission;
        }

        public void waitCheckpointCompleted() throws Exception {
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void run(SourceFunction.SourceContext<Integer> ctx) throws Exception {
            int subtaskIndex = this.getRuntimeContext().getTaskInfo().getIndexOfThisSubtask();
            while (this.running) {
                if (this.counter < this.numberElements) {
                    Object object = ctx.getCheckpointLock();
                    synchronized (object) {
                        for (int value = subtaskIndex; value < this.numberKeys; value += this.getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks()) {
                            ctx.collect((Object)value);
                        }
                        ++this.counter;
                        continue;
                    }
                }
                this.waitCheckpointCompleted();
                if (this.failAfterEmission) {
                    throw new FlinkRuntimeException("Make job fail artificially, to retain completed checkpoint.");
                }
                this.running = false;
            }
        }

        public void cancel() {
            this.running = false;
        }
    }
}

