/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.checkpointing;

import java.io.Serializable;
import java.util.Iterator;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.execution.JobClient;
import org.apache.flink.core.execution.SavepointFormatType;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.jobgraph.SavepointConfigOptions;
import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.testutils.CommonTestUtils;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.environment.CheckpointConfig;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.testutils.junit.SharedObjects;
import org.apache.flink.testutils.junit.SharedReference;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.TestLogger;
import org.hamcrest.Matcher;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.jetbrains.annotations.NotNull;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(value=Parameterized.class)
public class RestoreUpgradedJobITCase
extends TestLogger {
    /** Parallelism of the test jobs; the same value is passed to {@code env.setParallelism(...)}. */
    private static final int PARALLELISM = 4;

    /** Number of records a test source emits before it triggers its latch. */
    private static final int TOTAL_RECORDS = 100;

    /** Directory used both as checkpoint storage and as the savepoint target directory. */
    @ClassRule
    public static TemporaryFolder temporaryFolder = new TemporaryFolder();

    /** Kind of snapshot the original job is stopped with; set by the parameterized runner. */
    @Parameterized.Parameter
    public TestCheckpointType checkpointType;

    /** Mini cluster used by the test: 2 task managers with 4 slots each. */
    @ClassRule
    public static final MiniClusterWithClientResource CLUSTER =
            new MiniClusterWithClientResource(
                    new MiniClusterResourceConfiguration.Builder()
                            .setConfiguration(new Configuration())
                            .setNumberTaskManagers(2)
                            .setNumberSlotsPerTaskManager(4)
                            .build());

    /** Creates the {@link SharedReference}s that are handed to the sources and sinks. */
    @Rule
    public final SharedObjects sharedObjects = SharedObjects.create();

    /** Latch the source triggers once it has emitted all of its records. */
    private SharedReference<OneShotLatch> allDataEmittedLatch;

    /** Sum of all values that reached the sink. */
    private SharedReference<AtomicLong> result;

    /**
     * Registers a fresh {@link OneShotLatch} and a fresh {@link AtomicLong} with {@link
     * #sharedObjects} and stores the returned references in {@link #allDataEmittedLatch} and
     * {@link #result}.
     */
    public void setupSharedObjects() {
        // The arguments are passed with their concrete types (no widening cast to Object) so the
        // returned references match the declared SharedReference<OneShotLatch> /
        // SharedReference<AtomicLong> field types.
        this.allDataEmittedLatch = this.sharedObjects.add(new OneShotLatch());
        this.result = this.sharedObjects.add(new AtomicLong());
    }

    /**
     * Supplies the test parameters: one single-element row per {@link TestCheckpointType}, in the
     * order aligned checkpoint, canonical savepoint, native savepoint.
     *
     * @return the parameter rows consumed by the parameterized runner
     */
    @Parameterized.Parameters(name="Savepoint type[{0}]")
    public static Object[][] parameters() {
        Object[] alignedCheckpoint = {TestCheckpointType.ALIGNED_CHECKPOINT};
        Object[] canonicalSavepoint = {TestCheckpointType.CANONICAL_SAVEPOINT};
        Object[] nativeSavepoint = {TestCheckpointType.NATIVE_SAVEPOINT};
        return new Object[][] {alignedCheckpoint, canonicalSavepoint, nativeSavepoint};
    }

    /**
     * Runs the original job, stops it with a snapshot of the parameterized type, then starts the
     * upgraded job from that snapshot and checks the sink result of both runs.
     *
     * @throws Exception if either job fails or a wait is interrupted
     */
    @Test
    public void testRestoreUpgradedJob() throws Exception {
        this.setupSharedObjects();

        // when: run the original job and stop it with a snapshot
        String snapshotPath = this.runOriginalJob();

        // then: every emitted record reached the sink
        MatcherAssert.assertThat(
                this.result.get().longValue(),
                Matchers.is(this.calculateExpectedResultBeforeSavepoint()));

        // when: reset the accumulator and run the upgraded job from the snapshot
        this.result.get().set(0L);
        this.runUpgradedJob(snapshotPath);

        // then: the maps forward values unchanged (AbstractMap#calculate returns its argument),
        // so the upgraded job is expected to produce the same sum as the original one
        MatcherAssert.assertThat(
                this.result.get().longValue(),
                Matchers.is(this.calculateExpectedResultBeforeSavepoint()));
    }

    /**
     * Computes {@code PARALLELISM * sum(i + totalStates)} for {@code i} in {@code [0,
     * TOTAL_RECORDS)}, where {@code totalStates} is the sum of squares {@code 1^2 + ... + n^2}
     * and {@code n} is the number of {@link MapName} constants.
     *
     * <p>NOTE(review): nothing in this file calls this method; the test asserts against {@link
     * #calculateExpectedResultBeforeSavepoint()} for both job runs.
     *
     * @return the computed value
     */
    private long calculateExpectedResultAfterSavepoint() {
        // Sum of squares over 1..n, n being the number of named maps.
        int mapCount = MapName.values().length;
        long totalStates = 0L;
        for (int mapId = 1; mapId <= mapCount; ++mapId) {
            totalStates += (long) mapId * (long) mapId;
        }

        long perInstanceSum = 0L;
        for (int record = 0; record < TOTAL_RECORDS; ++record) {
            perInstanceSum += (long) record + totalStates;
        }
        return (long) PARALLELISM * perInstanceSum;
    }

    /**
     * Computes the sink total the test expects: {@code PARALLELISM} times the sum of the record
     * values {@code 0 .. TOTAL_RECORDS - 1}.
     *
     * @return the expected value of the result accumulator after a job run
     */
    private long calculateExpectedResultBeforeSavepoint() {
        long recordSum = 0L;
        for (int record = 0; record < TOTAL_RECORDS; ++record) {
            recordSum += record;
        }
        return (long) PARALLELISM * recordSum;
    }

    /**
     * Builds and runs the original (integer based) job, waits until the source has emitted all
     * of its records, and stops the job with a snapshot of the parameterized type.
     *
     * <p>Operator order in this job is MAP_5, MAP_1, MAP_6, MAP_4, MAP_2, MAP_3; each map uses its
     * {@link MapName} constant name as operator uid.
     *
     * @return the path of the snapshot returned by {@link #stopWithSnapshot(JobClient)}
     * @throws Exception if job submission, waiting, or the snapshot fails
     */
    @NotNull
    private String runOriginalJob() throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.getCheckpointConfig()
                .setExternalizedCheckpointCleanup(
                        CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);
        env.getCheckpointConfig().enableUnalignedCheckpoints(false);
        env.getCheckpointConfig()
                .setCheckpointStorage("file://" + temporaryFolder.getRoot().getAbsolutePath());
        env.setParallelism(PARALLELISM);
        // Largest possible interval so that, in practice, no periodic checkpoint fires during the
        // test; the snapshot is requested explicitly in stopWithSnapshot.
        env.enableCheckpointing(Integer.MAX_VALUE);

        env.addSource(new IntSource(this.allDataEmittedLatch))
                .map(new IntMap(MapName.MAP_5.id()))
                .uid(MapName.MAP_5.name())
                .forward()
                .map(new IntMap(MapName.MAP_1.id()))
                .uid(MapName.MAP_1.name())
                .slotSharingGroup("anotherSharingGroup")
                .keyBy(key -> key)
                .map(new IntMap(MapName.MAP_6.id()))
                .uid(MapName.MAP_6.name())
                .rebalance()
                .map(new IntMap(MapName.MAP_4.id()))
                .uid(MapName.MAP_4.name())
                .broadcast()
                .map(new IntMap(MapName.MAP_2.id()))
                .uid(MapName.MAP_2.name())
                .rescale()
                .map(new IntMap(MapName.MAP_3.id()))
                .uid(MapName.MAP_3.name())
                .addSink(new IntSink(this.result))
                .setParallelism(1);

        JobClient jobClient = env.executeAsync("Total sum");
        CommonTestUtils.waitForAllTaskRunning(
                CLUSTER.getMiniCluster(), jobClient.getJobID(), false);

        // Wait for the source to finish emitting, then re-arm the latch for the upgraded job.
        this.allDataEmittedLatch.get().await();
        this.allDataEmittedLatch.get().reset();

        return this.stopWithSnapshot(jobClient);
    }

    /**
     * Builds and runs the upgraded (string based) job, restoring from the given snapshot, waits
     * until the source has emitted all of its records, and stops the job with a canonical
     * savepoint.
     *
     * <p>Compared to the original job, the named maps appear in the order MAP_1 .. MAP_6 and two
     * additional maps with negative ids ({@code new_chained_map}, {@code new_map2}) are inserted.
     * Each {@link StringMap} validates its restored state in {@code initializeState}.
     *
     * @param snapshotPath path of the snapshot to restore from
     * @throws Exception if job submission, waiting, or the final savepoint fails
     */
    private void runUpgradedJob(String snapshotPath) throws Exception {
        Configuration conf = new Configuration();
        // SAVEPOINT_PATH is a String option; pass the String as-is (a cast to Object would not
        // match the option's value type).
        conf.set(SavepointConfigOptions.SAVEPOINT_PATH, snapshotPath);

        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(conf);
        env.setParallelism(PARALLELISM);

        env.addSource(new StringSource(this.allDataEmittedLatch))
                .map(new StringMap(MapName.MAP_1.id()))
                .uid(MapName.MAP_1.name())
                .forward()
                .map(new StringMap(MapName.MAP_2.id()))
                .uid(MapName.MAP_2.name())
                .slotSharingGroup("anotherSharingGroup")
                .keyBy(key -> key)
                .map(new StringMap(MapName.MAP_3.id()))
                .uid(MapName.MAP_3.name())
                .map(new StringMap(-1))
                .uid("new_chained_map")
                .rebalance()
                .map(new StringMap(-2))
                .uid("new_map2")
                .map(new StringMap(MapName.MAP_4.id()))
                .uid(MapName.MAP_4.name())
                .rescale()
                .map(new StringMap(MapName.MAP_5.id()))
                .uid(MapName.MAP_5.name())
                .broadcast()
                .map(new StringMap(MapName.MAP_6.id()))
                .uid(MapName.MAP_6.name())
                .addSink(new StringSink(this.result))
                .setParallelism(1);

        JobClient jobClient = env.executeAsync("Total sum");
        CommonTestUtils.waitForAllTaskRunning(
                CLUSTER.getMiniCluster(), jobClient.getJobID(), false);

        this.allDataEmittedLatch.get().await();

        jobClient
                .stopWithSavepoint(
                        true,
                        temporaryFolder.getRoot().getAbsolutePath(),
                        SavepointFormatType.CANONICAL)
                .get();
    }

    /**
     * Takes a snapshot of the running job according to {@link #checkpointType} and terminates the
     * job: an aligned checkpoint is triggered through the mini cluster and the job is cancelled
     * afterwards, while both savepoint types use stop-with-savepoint.
     *
     * @param jobClient client of the job to snapshot
     * @return the path of the created snapshot
     * @throws InterruptedException if waiting for the snapshot is interrupted
     * @throws ExecutionException if the snapshot or the cancellation fails
     * @throws IllegalArgumentException if {@link #checkpointType} is none of the known types
     */
    private String stopWithSnapshot(JobClient jobClient) throws InterruptedException, ExecutionException {
        TestCheckpointType type = this.checkpointType;

        if (type == TestCheckpointType.ALIGNED_CHECKPOINT) {
            String checkpointPath =
                    CLUSTER.getMiniCluster().triggerCheckpoint(jobClient.getJobID()).get();
            // A checkpoint does not stop the job, so cancel it explicitly.
            jobClient.cancel().get();
            return checkpointPath;
        }
        if (type == TestCheckpointType.CANONICAL_SAVEPOINT) {
            return this.stopToSavepoint(jobClient, SavepointFormatType.CANONICAL);
        }
        if (type == TestCheckpointType.NATIVE_SAVEPOINT) {
            return this.stopToSavepoint(jobClient, SavepointFormatType.NATIVE);
        }
        throw new IllegalArgumentException("Unknown checkpoint type: " + type);
    }

    /** Stops the job with a savepoint of the given format written below the temporary folder. */
    private String stopToSavepoint(JobClient jobClient, SavepointFormatType formatType)
            throws InterruptedException, ExecutionException {
        String targetDirectory = temporaryFolder.getRoot().getAbsolutePath();
        return jobClient.stopWithSavepoint(true, targetDirectory, formatType).get();
    }

    private static abstract class TestSource<T>
    implements SourceFunction<T> {
        private static final long serialVersionUID = 1L;
        private final SharedReference<OneShotLatch> dataEmitted;
        private volatile boolean isRunning = true;

        public TestSource(SharedReference<OneShotLatch> dataEmitted) {
            this.dataEmitted = dataEmitted;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void run(SourceFunction.SourceContext<T> ctx) throws Exception {
            int i = 100;
            while (i-- > 0) {
                Object object = ctx.getCheckpointLock();
                synchronized (object) {
                    this.collect(ctx, i);
                }
            }
            ((OneShotLatch)this.dataEmitted.get()).trigger();
            while (this.isRunning) {
                LockSupport.parkNanos(100000L);
            }
        }

        abstract void collect(SourceFunction.SourceContext<T> var1, int var2);

        public void cancel() {
            this.isRunning = false;
        }
    }

    /** {@link TestSource} that emits each index as its decimal string representation. */
    private static class StringSource extends TestSource<String> {
        public StringSource(SharedReference<OneShotLatch> dataEmitted) {
            super(dataEmitted);
        }

        @Override
        void collect(SourceFunction.SourceContext<String> ctx, int index) {
            // Emit a String directly; the context is typed SourceContext<String>.
            ctx.collect(String.valueOf(index));
        }
    }

    /** {@link TestSource} that emits each index as an {@link Integer}. */
    private static class IntSource extends TestSource<Integer> {
        public IntSource(SharedReference<OneShotLatch> dataEmitted) {
            super(dataEmitted);
        }

        @Override
        void collect(SourceFunction.SourceContext<Integer> ctx, int index) {
            // Autoboxes to Integer, matching the SourceContext<Integer> element type.
            ctx.collect(index);
        }
    }

    private static abstract class AbstractMap<T>
    extends RichMapFunction<T, T>
    implements CheckpointedFunction {
        protected ListState<Integer> valueState;
        protected final int id;

        private AbstractMap(int id) {
            this.id = id;
        }

        protected int calculate(int value) throws Exception {
            return value;
        }

        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            this.valueState.add((Object)this.id);
        }

        public void initializeState(FunctionInitializationContext context) throws Exception {
            this.valueState = context.getOperatorStateStore().getListState(new ListStateDescriptor("state", Types.INT));
        }
    }

    /**
     * String based map used by the upgraded job. On state initialization it validates the list
     * state: for a positive id the state must contain exactly one element equal to the id, for a
     * non-positive id the state must be empty.
     */
    private static class StringMap extends AbstractMap<String> {
        private StringMap(int id) {
            super(id);
        }

        @Override
        public String map(String value) throws Exception {
            return String.valueOf(this.calculate(Integer.parseInt(value)));
        }

        /**
         * Initializes the state via the super class and then checks its content.
         *
         * @throws IllegalStateException (via {@code Preconditions.checkState}) if the state does
         *     not have the expected content
         */
        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            super.initializeState(context);

            Iterator<Integer> restored = this.valueState.get().iterator();
            if (this.id > 0) {
                Preconditions.checkState(restored.hasNext(), "Value state can not be empty.");
                Integer state = restored.next();
                Preconditions.checkState(
                        this.id == state,
                        String.format(
                                "Value state(%s) should be equal to id(%s).", state, this.id));
            }
            // Nothing (else) may be left in the state at this point.
            Preconditions.checkState(!restored.hasNext(), "Value state should be empty.");
        }
    }

    private static class IntMap
    extends AbstractMap<Integer> {
        private IntMap(int id) {
            super(id);
        }

        public Integer map(Integer value) throws Exception {
            return this.calculate(value);
        }
    }

    private static class StringSink
    implements SinkFunction<String> {
        private final SharedReference<AtomicLong> result;

        public StringSink(SharedReference<AtomicLong> result) {
            this.result = result;
        }

        public void invoke(String value, SinkFunction.Context context) throws Exception {
            ((AtomicLong)this.result.get()).addAndGet(Integer.parseInt(value));
        }
    }

    private static class IntSink
    implements SinkFunction<Integer> {
        private final SharedReference<AtomicLong> result;

        public IntSink(SharedReference<AtomicLong> result) {
            this.result = result;
        }

        public void invoke(Integer value, SinkFunction.Context context) throws Exception {
            ((AtomicLong)this.result.get()).addAndGet(value.intValue());
        }
    }

    /**
     * Names of the map operators shared by the original and the upgraded job. The constant name is
     * used as operator uid and {@link #id()} is the value the map stores in its state.
     *
     * <p>Ids are declared explicitly (1 to 6) rather than derived from the declaration order, so
     * reordering or inserting constants cannot silently change them.
     */
    enum MapName {
        MAP_1(1),
        MAP_2(2),
        MAP_3(3),
        MAP_4(4),
        MAP_5(5),
        MAP_6(6);

        private final int id;

        MapName(int id) {
            this.id = id;
        }

        /** Returns the numeric id of this map (1 for {@code MAP_1} up to 6 for {@code MAP_6}). */
        int id() {
            return this.id;
        }
    }

    /** Kinds of snapshot the original job can be stopped with; used as the test parameter. */
    enum TestCheckpointType {
        /** Checkpoint triggered through the mini cluster, followed by cancelling the job. */
        ALIGNED_CHECKPOINT,
        /** Stop-with-savepoint using {@code SavepointFormatType.CANONICAL}. */
        CANONICAL_SAVEPOINT,
        /** Stop-with-savepoint using {@code SavepointFormatType.NATIVE}. */
        NATIVE_SAVEPOINT
    }
}

