/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.checkpointing;

import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nonnull;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.accumulators.LongCounter;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.runtime.state.CheckpointListener;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.environment.LocalStreamEnvironment;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.util.TestLogger;
import org.hamcrest.Matchers;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ErrorCollector;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.Timeout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Integration test for unaligned checkpoints.
 *
 * <p>Every test builds the pipeline {@code LongSource -> FailingMapper -> VerifyingSink}, connected by
 * custom partitioners. It runs the pipeline with unaligned checkpoints and a fixed-delay restart strategy
 * enabled. The mapper deliberately fails several times, so the job is recovered from checkpoints.
 * Afterwards the test checks the job accumulators: no out-of-order, duplicate or lost records, a
 * positive number of inputs, and exactly as many outputs as inputs.
 */
public class UnalignedCheckpointITCase extends TestLogger {
    /** Accumulator name: total number of records emitted by all source subtasks. */
    public static final String NUM_INPUTS = "inputs";
    /** Accumulator name: total number of records received by all sink subtasks. */
    public static final String NUM_OUTPUTS = "outputs";
    private static final String NUM_OUT_OF_ORDER = "outOfOrder";
    private static final String NUM_DUPLICATES = "duplicates";
    private static final String NUM_LOST = "lost";
    private static final Logger LOG = LoggerFactory.getLogger(UnalignedCheckpointITCase.class);

    @Rule
    public ErrorCollector collector = new ErrorCollector();

    @Rule
    public final TemporaryFolder temp = new TemporaryFolder();

    @Rule
    public final Timeout timeout = Timeout.builder().withTimeout(300L, TimeUnit.SECONDS).build();

    @Test
    public void shouldPerformUnalignedCheckpointOnNonParallelLocalChannel() throws Exception {
        execute(1, 1, true);
    }

    @Test
    public void shouldPerformUnalignedCheckpointOnParallelLocalChannel() throws Exception {
        execute(5, 5, true);
    }

    @Test
    public void shouldPerformUnalignedCheckpointOnNonParallelRemoteChannel() throws Exception {
        execute(1, 1, false);
    }

    @Test
    public void shouldPerformUnalignedCheckpointOnParallelRemoteChannel() throws Exception {
        execute(5, 1, false);
    }

    @Test
    public void shouldPerformUnalignedCheckpointOnLocalAndRemoteChannel() throws Exception {
        execute(5, 3, true);
    }

    @Test
    public void shouldPerformUnalignedCheckpointMassivelyParallel() throws Exception {
        execute(20, 20, true);
    }

    /**
     * Runs the pipeline to completion and verifies the accumulators it reports.
     *
     * @param parallelism default parallelism of the job
     * @param slotsPerTaskManager number of task slots per task manager
     * @param slotSharing if true, all operators use the "default" slot sharing group; otherwise each
     *     operator gets its own group
     */
    private void execute(int parallelism, int slotsPerTaskManager, boolean slotSharing) throws Exception {
        LocalStreamEnvironment env = createEnv(parallelism, slotsPerTaskManager, slotSharing);
        long minCheckpoints = 10L;
        createDAG(env, minCheckpoints, slotSharing);
        JobExecutionResult result = env.execute();

        collector.checkThat(result.getAccumulatorResult(NUM_OUT_OF_ORDER), Matchers.equalTo(0L));
        collector.checkThat(result.getAccumulatorResult(NUM_DUPLICATES), Matchers.equalTo(0L));
        collector.checkThat(result.getAccumulatorResult(NUM_LOST), Matchers.equalTo(0L));
        Long inputs = result.getAccumulatorResult(NUM_INPUTS);
        collector.checkThat(inputs, Matchers.greaterThan(0L));
        collector.checkThat(result.getAccumulatorResult(NUM_OUTPUTS), Matchers.equalTo(inputs));
    }

    /**
     * Creates a local environment with filesystem checkpoints, unaligned checkpoints, a 100 ms
     * checkpoint interval and a fixed-delay restart strategy (5 attempts, 100 ms delay).
     */
    @Nonnull
    private LocalStreamEnvironment createEnv(int parallelism, int slotsPerTaskManager, boolean slotSharing) throws IOException {
        Configuration conf = new Configuration();
        conf.setInteger(TaskManagerOptions.NUM_TASK_SLOTS, slotsPerTaskManager);
        conf.setFloat(TaskManagerOptions.NETWORK_MEMORY_FRACTION, 0.9f);
        // With slot sharing, `parallelism` slots suffice (rounded up to whole task managers).
        // Without it the three operators of createDAG cannot share slots, so provision 3x as many.
        int numTaskManagers = slotSharing
                ? (parallelism + slotsPerTaskManager - 1) / slotsPerTaskManager
                : parallelism * 3;
        conf.setInteger("local.number-taskmanager", numTaskManagers);
        conf.setString(CheckpointingOptions.STATE_BACKEND, "filesystem");
        conf.setString(CheckpointingOptions.CHECKPOINTS_DIRECTORY, temp.newFolder().toURI().toString());

        LocalStreamEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(parallelism, conf);
        env.enableCheckpointing(100L);
        env.setRestartStrategy(RestartStrategies.fixedDelayRestart(5, Time.milliseconds(100L)));
        env.getCheckpointConfig().enableUnalignedCheckpoints(true);
        return env;
    }

    /**
     * Builds {@code LongSource -> ShiftingPartitioner -> FailingMapper -> DistributingPartitioner -> VerifyingSink}.
     *
     * <p>The mapper fails on subtask 0 at these points:
     * <ul>
     *   <li>in {@code map} after minCheckpoints/4 completed checkpoints on run 0, and after
     *       3*minCheckpoints/4 on run 2;</li>
     *   <li>in {@code snapshotState} after minCheckpoints/2 on run 1;</li>
     *   <li>in {@code initializeState} on run 3;</li>
     *   <li>in {@code close} on run 4.</li>
     * </ul>
     * The predicates only capture {@code minCheckpoints}, never {@code this}, so they stay serializable.
     */
    private void createDAG(StreamExecutionEnvironment env, long minCheckpoints, boolean slotSharing) {
        env.addSource(new LongSource(minCheckpoints))
                .slotSharingGroup(slotSharing ? "default" : "source")
                .partitionCustom(new ShiftingPartitioner(), l -> l)
                .map(new FailingMapper(
                        state -> state.completedCheckpoints == minCheckpoints / 4L && state.runNumber == 0L
                                || state.completedCheckpoints == minCheckpoints * 3L / 4L && state.runNumber == 2L,
                        state -> state.completedCheckpoints == minCheckpoints / 2L && state.runNumber == 1L,
                        state -> state.runNumber == 3L,
                        state -> state.runNumber == 4L))
                .slotSharingGroup(slotSharing ? "default" : "map")
                .partitionCustom(new DistributingPartitioner(), l -> l)
                .addSink(new VerifyingSink(minCheckpoints))
                .slotSharingGroup(slotSharing ? "default" : "sink");
    }

    /**
     * Logs {@code description} at INFO, appending the subtask index and attempt number.
     *
     * @param runtimeContext context supplying the subtask index and attempt number
     * @param description SLF4J message pattern; two more {@code {}} placeholders are appended to it
     * @param args arguments for the placeholders in {@code description}
     */
    static void info(RuntimeContext runtimeContext, String description, Object... args) {
        Object[] allArgs = ArrayUtils.addAll(
                args,
                new Object[] {runtimeContext.getIndexOfThisSubtask(), runtimeContext.getAttemptNumber()});
        LOG.info(description + " @ {} subtask ({} attempt)", allArgs);
    }

    /**
     * Identity mapper that throws at configurable points to force job restarts.
     *
     * <p>Each predicate is evaluated against the mapper's {@link FailingMapperState}, only on
     * subtask 0. A matching predicate makes the corresponding method throw.
     */
    private static class FailingMapper extends RichMapFunction<Long, Long>
            implements CheckpointedFunction, CheckpointListener {
        private static final ListStateDescriptor<FailingMapperState> FAILING_MAPPER_STATE_DESCRIPTOR =
                new ListStateDescriptor<>("state", FailingMapperState.class);

        private ListState<FailingMapperState> listState;
        private FailingMapperState state;
        private final FilterFunction<FailingMapperState> failDuringMap;
        private final FilterFunction<FailingMapperState> failDuringSnapshot;
        private final FilterFunction<FailingMapperState> failDuringRecovery;
        private final FilterFunction<FailingMapperState> failDuringClose;
        /** Last value seen by {@link #map}, included in the failure message. */
        private long lastValue;

        private FailingMapper(
                FilterFunction<FailingMapperState> failDuringMap,
                FilterFunction<FailingMapperState> failDuringSnapshot,
                FilterFunction<FailingMapperState> failDuringRecovery,
                FilterFunction<FailingMapperState> failDuringClose) {
            this.failDuringMap = failDuringMap;
            this.failDuringSnapshot = failDuringSnapshot;
            this.failDuringRecovery = failDuringRecovery;
            this.failDuringClose = failDuringClose;
        }

        @Override
        public Long map(Long value) throws Exception {
            lastValue = value;
            checkFail(failDuringMap, "map");
            return value;
        }

        /** Throws if this is subtask 0 and {@code failFunction} accepts the current state. */
        public void checkFail(FilterFunction<FailingMapperState> failFunction, String description) throws Exception {
            if (getRuntimeContext().getIndexOfThisSubtask() == 0 && failFunction.filter(state)) {
                failMapper(description);
            }
        }

        private void failMapper(String description) throws Exception {
            throw new Exception("Failing " + description + " @ " + state.completedCheckpoints
                    + " (" + state.runNumber + " attempt); last value " + lastValue);
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            state.completedCheckpoints++;
        }

        public void notifyCheckpointAborted(long checkpointId) {
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            checkFail(failDuringSnapshot, "snapshotState");
            listState.clear();
            listState.add(state);
        }

        @Override
        public void close() throws Exception {
            checkFail(failDuringClose, "close");
            super.close();
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            listState = context.getOperatorStateStore().getListState(FAILING_MAPPER_STATE_DESCRIPTOR);
            // Restore the completed-checkpoint count, or start from scratch when there is no state yet.
            state = Iterables.getOnlyElement(listState.get(), new FailingMapperState(0L, 0L));
            // The run number always reflects the current attempt, not the restored value.
            state.runNumber = getRuntimeContext().getAttemptNumber();
            checkFail(failDuringRecovery, "initializeState");
        }
    }

    /** Checkpointed state of {@link FailingMapper}: completed checkpoints and current attempt number. */
    private static class FailingMapperState {
        private long completedCheckpoints;
        private long runNumber;

        private FailingMapperState(long completedCheckpoints, long runNumber) {
            this.completedCheckpoints = completedCheckpoints;
            this.runNumber = runNumber;
        }
    }

    /** Routes key {@code k} to partition {@code (k / n) % n}, where {@code n} is the number of partitions. */
    private static class DistributingPartitioner implements Partitioner<Long> {
        private DistributingPartitioner() {
        }

        @Override
        public int partition(Long key, int numPartitions) {
            return (int) (key / numPartitions % numPartitions);
        }
    }

    /** Routes key {@code k} to partition {@code (k + 1) % n}, where {@code n} is the number of partitions. */
    private static class ShiftingPartitioner implements Partitioner<Long> {
        private ShiftingPartitioner() {
        }

        @Override
        public int partition(Long key, int numPartitions) {
            return (int) ((key + 1L) % numPartitions);
        }
    }

    /**
     * Sink that checks, per originating source subtask, that records arrive strictly increasing and
     * without gaps.
     *
     * <p>It counts out-of-order, duplicate and lost records. The counts are kept in checkpointed state
     * and published as accumulators in {@link #close()}.
     */
    private static class VerifyingSink extends RichSinkFunction<Long> implements CheckpointedFunction {
        private static final ListStateDescriptor<State> STATE_DESCRIPTOR =
                new ListStateDescriptor<>("state", State.class);

        private final LongCounter numOutputCounter = new LongCounter();
        private final LongCounter outOfOrderCounter = new LongCounter();
        private final LongCounter lostCounter = new LongCounter();
        private final LongCounter duplicatesCounter = new LongCounter();
        private ListState<State> stateList;
        private State state;
        private final long minCheckpoints;

        private VerifyingSink(long minCheckpoints) {
            this.minCheckpoints = minCheckpoints;
        }

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            getRuntimeContext().addAccumulator(NUM_OUTPUTS, numOutputCounter);
            getRuntimeContext().addAccumulator(NUM_OUT_OF_ORDER, outOfOrderCounter);
            getRuntimeContext().addAccumulator(NUM_DUPLICATES, duplicatesCounter);
            getRuntimeContext().addAccumulator(NUM_LOST, lostCounter);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            stateList = context.getOperatorStateStore().getListState(STATE_DESCRIPTOR);
            state = Iterables.getOnlyElement(stateList.get(), new State(getRuntimeContext().getNumberOfParallelSubtasks()));
            // Arrays.toString logs the contents; Arrays.asList(long[]) would only log the array reference.
            info("Initialized last snapshotted records {}", Arrays.toString(state.lastRecordInPartitions));
        }

        private void info(String description, Object... args) {
            UnalignedCheckpointITCase.info(getRuntimeContext(), description, args);
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            stateList.clear();
            stateList.add(state);
            info("Last snapshotted records {}", Arrays.toString(state.lastRecordInPartitions));
        }

        @Override
        public void close() throws Exception {
            numOutputCounter.add(state.numOutput);
            outOfOrderCounter.add(state.numOutOfOrderness);
            duplicatesCounter.add(state.numDuplicates);
            lostCounter.add(state.numLostValues);
            info("Last received records {}", Arrays.toString(state.lastRecordInPartitions));
            super.close();
        }

        @Override
        public void invoke(Long value, SinkFunction.Context context) throws Exception {
            int parallelism = state.lastRecordInPartitions.length;
            // LongSource subtask i emits i, i + p, i + 2p, ..., so value % p identifies the source subtask.
            int partition = (int) (value % parallelism);
            long lastRecord = state.lastRecordInPartitions[partition];
            if (value < lastRecord) {
                state.numOutOfOrderness++;
                info("Out of order records current={} and last={}", value, lastRecord);
            } else if (value == lastRecord) {
                state.numDuplicates++;
                info("Duplicate record {}", value);
            } else if (lastRecord != -1L) {
                // DistributingPartitioner sends key k to sink (k / p) % p. Consecutive records of one source
                // subtask that reach the same sink therefore differ by p * p. Widen before multiplying.
                long expectedValue = lastRecord + (long) parallelism * parallelism;
                if (value != expectedValue) {
                    state.numLostValues++;
                    info("Lost records {}-{}", expectedValue, value);
                }
            }
            state.lastRecordInPartitions[partition] = value;
            state.numOutput++;
        }

        /** Checkpointed sink state: error counters and the last record seen per source subtask (-1 = none). */
        private static class State {
            private long numOutOfOrderness;
            private long numLostValues;
            private long numDuplicates;
            private long numOutput = 0L;
            private long[] lastRecordInPartitions;

            private State(int numberOfParallelSubtasks) {
                lastRecordInPartitions = new long[numberOfParallelSubtasks];
                Arrays.fill(lastRecordInPartitions, -1L);
            }
        }
    }

    /**
     * Parallel source: subtask {@code i} emits {@code i, i + p, i + 2p, ...}, where {@code p} is the
     * source parallelism.
     *
     * <p>It stops once it has observed {@code minCheckpoints} completed checkpoints. The next number
     * and the completed-checkpoint count are checkpointed. At the end the number of emitted records
     * is published as the {@link UnalignedCheckpointITCase#NUM_INPUTS} accumulator.
     */
    private static class LongSource extends RichParallelSourceFunction<Long>
            implements CheckpointListener, CheckpointedFunction {
        private static final ListStateDescriptor<State> STATE_DESCRIPTOR =
                new ListStateDescriptor<>("state", State.class);

        private final long minCheckpoints;
        private volatile boolean running = true;
        private final LongCounter numInputsCounter = new LongCounter();
        private ListState<State> stateList;
        private State state;

        public LongSource(long minCheckpoints) {
            this.minCheckpoints = minCheckpoints;
        }

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            getRuntimeContext().addAccumulator(NUM_INPUTS, numInputsCounter);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            stateList = context.getOperatorStateStore().getListState(STATE_DESCRIPTOR);
            state = Iterables.getOnlyElement(stateList.get(), new State(0L, getRuntimeContext().getIndexOfThisSubtask()));
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            stateList.clear();
            stateList.add(state);
            info("Snapshotted next input {}", state.nextNumber);
        }

        private void info(String description, Object... args) {
            UnalignedCheckpointITCase.info(getRuntimeContext(), description, args);
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            state.numCompletedCheckpoints++;
        }

        public void notifyCheckpointAborted(long checkpointId) {
        }

        @Override
        public void run(SourceFunction.SourceContext<Long> ctx) throws Exception {
            int increment = getRuntimeContext().getNumberOfParallelSubtasks();
            info("First emitted input {}", state.nextNumber);
            while (running) {
                // Emit and advance atomically with respect to checkpoints.
                synchronized (ctx.getCheckpointLock()) {
                    ctx.collect(state.nextNumber);
                    state.nextNumber += increment;
                    if (state.numCompletedCheckpoints >= minCheckpoints) {
                        cancel();
                    }
                }
            }
            // nextNumber == subtaskIndex + emitted * increment, with subtaskIndex < increment.
            numInputsCounter.add(state.nextNumber / increment);
            info("Last emitted input {} = {} total emits", state.nextNumber - increment, numInputsCounter.getLocalValue());
        }

        @Override
        public void cancel() {
            running = false;
        }

        /** Checkpointed source state: completed checkpoints seen and the next number to emit. */
        private static class State {
            private long numCompletedCheckpoints;
            private long nextNumber;

            private State(long numCompletedCheckpoints, long nextNumber) {
                this.numCompletedCheckpoints = numCompletedCheckpoints;
                this.nextNumber = nextNumber;
            }
        }
    }
}

