/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.api.operators;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.OptionalLong;
import java.util.Set;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.KeyedStateStore;
import org.apache.flink.api.common.state.OperatorStateStore;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.SubTaskInitializationMetricsBuilder;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
import org.apache.flink.runtime.state.CheckpointableKeyedStateBackend;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.OperatorStreamStateHandle;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.StateInitializationContextImpl;
import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.TaskExecutorStateChangelogStoragesManager;
import org.apache.flink.runtime.state.TaskLocalStateStore;
import org.apache.flink.runtime.state.TaskStateManager;
import org.apache.flink.runtime.state.TaskStateManagerImpl;
import org.apache.flink.runtime.state.TestTaskLocalStateStore;
import org.apache.flink.runtime.state.changelog.StateChangelogStorage;
import org.apache.flink.runtime.state.changelog.inmemory.InMemoryStateChangelogStorage;
import org.apache.flink.runtime.state.hashmap.HashMapStateBackend;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
import org.apache.flink.runtime.taskmanager.CheckpointResponder;
import org.apache.flink.runtime.util.LongArrayList;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
import org.apache.flink.streaming.api.operators.KeyContext;
import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializerImpl;
import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
import org.apache.flink.streaming.runtime.tasks.StreamTaskCancellationContext;
import org.apache.flink.util.clock.SystemClock;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

class StateInitializationContextImplTest {
    static final int NUM_HANDLES = 10;
    private StateInitializationContextImpl initializationContext;
    private CloseableRegistry closableRegistry;
    private int writtenKeyGroups;
    private Set<Integer> writtenOperatorStates;

    StateInitializationContextImplTest() {
    }

    @BeforeEach
    void setUp() throws Exception {
        KeyGroupRangeOffsets offsets;
        DataOutputViewStreamWrapper dov;
        this.writtenKeyGroups = 0;
        this.writtenOperatorStates = new HashSet<Integer>();
        this.closableRegistry = new CloseableRegistry();
        ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(64);
        ArrayList<KeyGroupsStateHandle> keyedStateHandles = new ArrayList<KeyGroupsStateHandle>(10);
        int prev = 0;
        for (int i = 0; i < 10; ++i) {
            out.reset();
            int size = i % 4;
            int end = prev + size;
            dov = new DataOutputViewStreamWrapper((OutputStream)out);
            offsets = new KeyGroupRangeOffsets(i == 9 ? KeyGroupRange.EMPTY_KEY_GROUP_RANGE : new KeyGroupRange(prev, end));
            prev = end + 1;
            Iterator iterator = offsets.getKeyGroupRange().iterator();
            while (iterator.hasNext()) {
                int kg = (Integer)iterator.next();
                offsets.setKeyGroupOffset(kg, (long)out.getPosition());
                dov.writeInt(kg);
                ++this.writtenKeyGroups;
            }
            KeyGroupsStateHandle handle = new KeyGroupsStateHandle(offsets, (StreamStateHandle)new ByteStateHandleCloseChecking("kg-" + i, out.toByteArray()));
            keyedStateHandles.add(handle);
        }
        ArrayList<OperatorStreamStateHandle> operatorStateHandles = new ArrayList<OperatorStreamStateHandle>(10);
        for (int i = 0; i < 10; ++i) {
            int size = i % 4;
            out.reset();
            dov = new DataOutputViewStreamWrapper((OutputStream)out);
            offsets = new LongArrayList(size);
            for (int s = 0; s < size; ++s) {
                offsets.add((long)out.getPosition());
                int val = i * 10 + s;
                dov.writeInt(val);
                this.writtenOperatorStates.add(val);
            }
            HashMap<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<String, OperatorStateHandle.StateMetaInfo>();
            offsetsMap.put("_default_", new OperatorStateHandle.StateMetaInfo(offsets.toArray(), OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
            OperatorStreamStateHandle operatorStateHandle = new OperatorStreamStateHandle(offsetsMap, (StreamStateHandle)new ByteStateHandleCloseChecking("os-" + i, out.toByteArray()));
            operatorStateHandles.add(operatorStateHandle);
        }
        OperatorSubtaskState operatorSubtaskState = OperatorSubtaskState.builder().setRawOperatorState(new StateObjectCollection(operatorStateHandles)).setRawKeyedState(new StateObjectCollection(keyedStateHandles)).build();
        OperatorID operatorID = new OperatorID();
        TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
        taskStateSnapshot.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState);
        JobManagerTaskRestore jobManagerTaskRestore = new JobManagerTaskRestore(0L, taskStateSnapshot);
        TaskStateManagerImpl manager = new TaskStateManagerImpl(new JobID(), ExecutionGraphTestUtils.createExecutionAttemptId(), (TaskLocalStateStore)new TestTaskLocalStateStore(), null, (StateChangelogStorage)new InMemoryStateChangelogStorage(), new TaskExecutorStateChangelogStoragesManager(), jobManagerTaskRestore, (CheckpointResponder)Mockito.mock(CheckpointResponder.class));
        DummyEnvironment environment = new DummyEnvironment("test", 1, 0, prev);
        environment.setTaskStateManager((TaskStateManager)manager);
        HashMapStateBackend stateBackend = new HashMapStateBackend();
        StreamTaskStateInitializerImpl streamTaskStateManager = new StreamTaskStateInitializerImpl((Environment)environment, (StateBackend)stateBackend, new SubTaskInitializationMetricsBuilder(SystemClock.getInstance().absoluteTimeMillis()), TtlTimeProvider.DEFAULT, new InternalTimeServiceManager.Provider(){

            public <K> InternalTimeServiceManager<K> create(TaskIOMetricGroup taskIOMetricGroup, CheckpointableKeyedStateBackend<K> keyedStatedBackend, ClassLoader userClassloader, KeyContext keyContext, ProcessingTimeService processingTimeService, Iterable<KeyGroupStatePartitionStreamProvider> rawKeyedStates, StreamTaskCancellationContext cancellationContext) throws Exception {
                return null;
            }
        }, StreamTaskCancellationContext.alwaysRunning());
        AbstractStreamOperator mockOperator = (AbstractStreamOperator)Mockito.mock(AbstractStreamOperator.class);
        Mockito.when((Object)mockOperator.getOperatorID()).thenReturn((Object)operatorID);
        StreamOperatorStateContext stateContext = streamTaskStateManager.streamOperatorStateContext(operatorID, "TestOperatorClass", (ProcessingTimeService)Mockito.mock(ProcessingTimeService.class), (KeyContext)mockOperator, (TypeSerializer)IntSerializer.INSTANCE, this.closableRegistry, (MetricGroup)new UnregisteredMetricsGroup(), 1.0, false);
        OptionalLong restoredCheckpointId = stateContext.getRestoredCheckpointId();
        this.initializationContext = new StateInitializationContextImpl(restoredCheckpointId.isPresent() ? Long.valueOf(restoredCheckpointId.getAsLong()) : null, (OperatorStateStore)stateContext.operatorStateBackend(), (KeyedStateStore)Mockito.mock(KeyedStateStore.class), (Iterable)stateContext.rawKeyedStateInputs(), (Iterable)stateContext.rawOperatorStateInputs());
    }

    @Test
    void getOperatorStateStreams() throws Exception {
        int i = 0;
        int s = 0;
        for (StatePartitionStreamProvider streamProvider : this.initializationContext.getRawOperatorStateInputs()) {
            if (0 == i % 4) {
                ++i;
            }
            Assertions.assertThat((Object)streamProvider).isNotNull();
            try (InputStream is = streamProvider.getStream();){
                DataInputViewStreamWrapper div = new DataInputViewStreamWrapper(is);
                int val = div.readInt();
                Assertions.assertThat((int)val).isEqualTo(i * 10 + s);
            }
            if (++s != i % 4) continue;
            s = 0;
            ++i;
        }
    }

    @Test
    void getKeyedStateStreams() throws Exception {
        int readKeyGroupCount = 0;
        for (KeyGroupStatePartitionStreamProvider stateStreamProvider : this.initializationContext.getRawKeyedStateInputs()) {
            Assertions.assertThat((Object)stateStreamProvider).isNotNull();
            InputStream is = stateStreamProvider.getStream();
            try {
                DataInputViewStreamWrapper div = new DataInputViewStreamWrapper(is);
                int val = div.readInt();
                ++readKeyGroupCount;
                Assertions.assertThat((int)val).isEqualTo(stateStreamProvider.getKeyGroupId());
            }
            finally {
                if (is == null) continue;
                is.close();
            }
        }
        Assertions.assertThat((int)readKeyGroupCount).isEqualTo(this.writtenKeyGroups);
    }

    @Test
    void getOperatorStateStore() throws Exception {
        HashSet<Integer> readStatesCount = new HashSet<Integer>();
        for (StatePartitionStreamProvider statePartitionStreamProvider : this.initializationContext.getRawOperatorStateInputs()) {
            Assertions.assertThat((Object)statePartitionStreamProvider).isNotNull();
            InputStream is = statePartitionStreamProvider.getStream();
            try {
                DataInputViewStreamWrapper div = new DataInputViewStreamWrapper(is);
                Assertions.assertThat((boolean)readStatesCount.add(div.readInt())).isTrue();
            }
            finally {
                if (is == null) continue;
                is.close();
            }
        }
        Assertions.assertThat(readStatesCount).isEqualTo(this.writtenOperatorStates);
    }

    @Test
    void close() throws Exception {
        int count = 0;
        int stopCount = 5;
        boolean isClosed = false;
        try {
            for (KeyGroupStatePartitionStreamProvider stateStreamProvider : this.initializationContext.getRawKeyedStateInputs()) {
                Assertions.assertThat((Object)stateStreamProvider).isNotNull();
                if (count == stopCount) {
                    this.closableRegistry.close();
                    isClosed = true;
                }
                InputStream is = stateStreamProvider.getStream();
                try {
                    DataInputViewStreamWrapper div = new DataInputViewStreamWrapper(is);
                    try {
                        int val = div.readInt();
                        Assertions.assertThat((int)val).isEqualTo(stateStreamProvider.getKeyGroupId());
                        ((AbstractBooleanAssert)Assertions.assertThat((boolean)isClosed).as("Close was ignored: stream", new Object[0])).isFalse();
                        ++count;
                    }
                    catch (IOException ioex) {
                        if (isClosed) continue;
                        throw ioex;
                    }
                }
                finally {
                    if (is == null) continue;
                    is.close();
                }
            }
            Assertions.fail((String)"Close was ignored: registry");
        }
        catch (IOException iex) {
            Assertions.assertThat((boolean)isClosed).isTrue();
            Assertions.assertThat((int)count).isEqualTo(stopCount);
        }
    }

    static final class ByteStateHandleCloseChecking
    extends ByteStreamStateHandle {
        private static final long serialVersionUID = -6201941296931334140L;

        public ByteStateHandleCloseChecking(String handleName, byte[] data) {
            super(handleName, data);
        }

        public FSDataInputStream openInputStream() throws IOException {
            final FSDataInputStream original = super.openInputStream();
            return new FSDataInputStream(){
                private boolean closed = false;

                public void seek(long desired) throws IOException {
                    original.seek(desired);
                }

                public long getPos() throws IOException {
                    return original.getPos();
                }

                public int read() throws IOException {
                    if (this.closed) {
                        throw new IOException("Stream closed");
                    }
                    return original.read();
                }

                public void close() throws IOException {
                    original.close();
                    this.closed = true;
                }
            };
        }
    }
}

