/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.state;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.runtime.state.TaskStateManager;
import org.apache.flink.runtime.state.TestTaskStateManager;
import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
import org.apache.flink.streaming.runtime.tasks.OneInputStreamTaskTestHarness;
import org.apache.flink.streaming.runtime.tasks.StreamMockEnvironment;
import org.apache.flink.streaming.util.TestHarnessUtil;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests that a task running a chain of two stateful operators ("head" and "tail") reports state
 * for both operators when checkpointed, and that both operators are restored from that state when
 * the chain is started again with the reported snapshot.
 *
 * <p>The task is configured with the "rocksdb" state backend and incremental checkpoints enabled.
 */
public class StatefulOperatorChainedTaskTest {

    /** IDs of operators whose {@code initializeState} ran with {@code context.isRestored()} true. */
    private static final Set<OperatorID> RESTORED_OPERATORS = ConcurrentHashMap.newKeySet();

    /**
     * Counter value each {@link CounterOperator} wrote in its most recent snapshot, keyed by the
     * operator's prefix.
     *
     * <p>It is static because the restoring run uses new operator instances, so instance fields do
     * not carry over. It is keyed per prefix so that the head and tail operators do not overwrite
     * each other's expected value.
     */
    private static final Map<String, Long> SNAPSHOTTED_COUNTERS = new ConcurrentHashMap<>();

    private TemporaryFolder temporaryFolder;

    @Before
    public void setup() throws IOException {
        RESTORED_OPERATORS.clear();
        SNAPSHOTTED_COUNTERS.clear();
        temporaryFolder = new TemporaryFolder();
        temporaryFolder.create();
    }

    /**
     * Deletes the temporary folder and every checkpoint and local-state directory created in it.
     * The folder is created manually in {@link #setup()}, so nothing else cleans it up.
     */
    @After
    public void tearDown() {
        if (temporaryFolder != null) {
            temporaryFolder.delete();
        }
    }

    @Test
    public void testMultipleStatefulOperatorChainedSnapshotAndRestore() throws Exception {
        OperatorID headOperatorID = new OperatorID(42L, 42L);
        OperatorID tailOperatorID = new OperatorID(44L, 44L);

        // First run: start from scratch, process records and take a checkpoint.
        JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain(
                headOperatorID, new CounterOperator("head"),
                tailOperatorID, new CounterOperator("tail"),
                Optional.empty());

        TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
        // Expect one subtask-state mapping for each of the two chained operators.
        Assert.assertEquals(2L, stateHandles.getSubtaskStateMappings().size());

        // Second run: fresh operator instances are restored from the first run's snapshot.
        createRunAndCheckpointOperatorChain(
                headOperatorID, new CounterOperator("head"),
                tailOperatorID, new CounterOperator("tail"),
                Optional.of(restore));

        Assert.assertEquals(
                new HashSet<>(Arrays.asList(headOperatorID, tailOperatorID)), RESTORED_OPERATORS);
    }

    /**
     * Builds a one-input task running {@code headOperator -> tailOperator}, optionally restores it,
     * processes three records, triggers checkpoint 1 and then ends the task.
     *
     * @param headId operator ID of the chain head
     * @param headOperator the chain head
     * @param tailId operator ID of the chained tail
     * @param tailOperator the chained tail
     * @param restore snapshot to restore the task from, or empty for a fresh start
     * @return the checkpoint ID and task state snapshot reported to the task state manager
     */
    private JobManagerTaskRestore createRunAndCheckpointOperatorChain(
            OperatorID headId,
            OneInputStreamOperator<String, String> headOperator,
            OperatorID tailId,
            OneInputStreamOperator<String, String> tailOperator,
            Optional<JobManagerTaskRestore> restore) throws Exception {

        File localRootDir = temporaryFolder.newFolder();
        OneInputStreamTaskTestHarness<String, String> testHarness =
                new OneInputStreamTaskTestHarness<>(
                        OneInputStreamTask::new,
                        1, 1,
                        BasicTypeInfo.STRING_TYPE_INFO,
                        BasicTypeInfo.STRING_TYPE_INFO,
                        localRootDir);

        testHarness.setupOperatorChain(headId, headOperator)
                .chain(tailId, tailOperator, StringSerializer.INSTANCE, true)
                .finish();

        if (restore.isPresent()) {
            JobManagerTaskRestore taskRestore = restore.get();
            testHarness.setTaskStateSnapshot(
                    taskRestore.getRestoreCheckpointId(), taskRestore.getTaskStateSnapshot());
        }

        StreamMockEnvironment environment = new StreamMockEnvironment(
                testHarness.jobConfig,
                testHarness.taskConfig,
                testHarness.getExecutionConfig(),
                testHarness.memorySize,
                new MockInputSplitProvider(),
                testHarness.bufferSize,
                testHarness.getTaskStateManager());

        // Run with the RocksDB backend and incremental checkpoints, writing into the temp folder.
        Configuration configuration = new Configuration();
        configuration.setString(CheckpointingOptions.STATE_BACKEND.key(), "rocksdb");
        File checkpointDir = temporaryFolder.newFolder();
        configuration.setString(
                CheckpointingOptions.CHECKPOINTS_DIRECTORY.key(), checkpointDir.toURI().toString());
        configuration.setString(CheckpointingOptions.INCREMENTAL_CHECKPOINTS.key(), "true");
        environment.setTaskManagerInfo(
                new TestingTaskManagerRuntimeInfo(
                        configuration,
                        System.getProperty("java.io.tmpdir").split(",|" + File.pathSeparator)));

        testHarness.invoke(environment);
        testHarness.waitForTaskRunning();

        OneInputStreamTask<String, String> streamTask = testHarness.getTask();

        processRecords(testHarness);
        triggerCheckpoint(testHarness, streamTask);

        TestTaskStateManager taskStateManager = testHarness.getTaskStateManager();
        JobManagerTaskRestore jobManagerTaskRestore = new JobManagerTaskRestore(
                taskStateManager.getReportedCheckpointId(),
                taskStateManager.getLastJobManagerTaskStateSnapshot());

        testHarness.endInput();
        testHarness.waitForTaskCompletion();
        return jobManagerTaskRestore;
    }

    /**
     * Triggers checkpoint 1 on {@code streamTask}, waits until the task state manager's report
     * latch fires, and asserts that the reported checkpoint ID is 1.
     */
    private void triggerCheckpoint(
            OneInputStreamTaskTestHarness<String, String> testHarness,
            OneInputStreamTask<String, String> streamTask) throws Exception {
        long checkpointId = 1L;
        CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 1L);
        testHarness.getTaskStateManager().setWaitForReportLatch(new OneShotLatch());

        // Keep retrying until triggerCheckpoint returns true; yield between attempts instead of
        // spinning on an empty loop body.
        while (!streamTask.triggerCheckpoint(
                checkpointMetaData, CheckpointOptions.forCheckpointWithDefaultLocation(), false)) {
            Thread.yield();
        }

        testHarness.getTaskStateManager().getWaitForReportLatch().await();
        long reportedCheckpointId = testHarness.getTaskStateManager().getReportedCheckpointId();
        Assert.assertEquals(checkpointId, reportedCheckpointId);
    }

    /**
     * Feeds the records "10", "20" and "30" into input gate 0 / channel 0. Waits for them to be
     * processed, then asserts that the chain emitted them unchanged and in order.
     */
    private void processRecords(OneInputStreamTaskTestHarness<String, String> testHarness)
            throws Exception {
        Queue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.processElement(new StreamRecord<>("10"), 0, 0);
        testHarness.processElement(new StreamRecord<>("20"), 0, 0);
        testHarness.processElement(new StreamRecord<>("30"), 0, 0);

        testHarness.waitForInputProcessing();

        expectedOutput.add(new StreamRecord<>("10"));
        expectedOutput.add(new StreamRecord<>("20"));
        expectedOutput.add(new StreamRecord<>("30"));
        TestHarnessUtil.assertOutputEquals(
                "Output was not correct.", expectedOutput, testHarness.getOutput());
    }

    /**
     * Pass-through operator that counts the elements it forwards. On snapshot it stores the count
     * in keyed value state and records it in {@link #SNAPSHOTTED_COUNTERS}. On restore it reads the
     * count back, checks it against the recorded value, and clears the state.
     */
    private static class CounterOperator extends RestoreWatchOperator<String, String> {
        private static final long serialVersionUID = 2048954179291813243L;

        private ValueState<Long> counterState;
        private long counter = 0L;
        private final String prefix;

        CounterOperator(String prefix) {
            this.prefix = prefix;
        }

        @Override
        public void processElement(StreamRecord<String> element) throws Exception {
            ++counter;
            output.collect(element);
        }

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            counterState = context.getKeyedStateStore().getState(
                    new ValueStateDescriptor<>(prefix + "counter-state", LongSerializer.INSTANCE));
            // Keyed state is accessed here and in snapshotState, outside of any element, so a fixed
            // current key is set once for all of those accesses.
            setCurrentKey("10");
            if (context.isRestored()) {
                Long restoredCounter = counterState.value();
                Assert.assertNotNull(
                        "No counter state restored for operator " + prefix, restoredCounter);
                counter = restoredCounter;

                Long snapshottedCounter = SNAPSHOTTED_COUNTERS.get(prefix);
                Assert.assertNotNull(
                        "No snapshot recorded for operator " + prefix, snapshottedCounter);
                Assert.assertEquals(snapshottedCounter.longValue(), counter);

                counterState.clear();
            }
        }

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            counterState.update(counter);
            SNAPSHOTTED_COUNTERS.put(prefix, counter);
        }
    }

    /**
     * Base operator that records its operator ID in {@link #RESTORED_OPERATORS} whenever it is
     * initialized from restored state.
     */
    private abstract static class RestoreWatchOperator<IN, OUT>
            extends AbstractStreamOperator<OUT>
            implements OneInputStreamOperator<IN, OUT> {

        private RestoreWatchOperator() {
        }

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            if (context.isRestored()) {
                RESTORED_OPERATORS.add(getOperatorID());
            }
        }
    }
}

