/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.cdc.connectors.tidb;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.TreeMap;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.state.CheckpointListener;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.cdc.connectors.tidb.TiKVChangeEventDeserializationSchema;
import org.apache.flink.cdc.connectors.tidb.TiKVSnapshotEventDeserializationSchema;
import org.apache.flink.cdc.connectors.tidb.table.StartupMode;
import org.apache.flink.cdc.connectors.tidb.table.utils.TableKeyRangeUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.shaded.guava31.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tikv.cdc.CDCClient;
import org.tikv.common.TiConfiguration;
import org.tikv.common.TiSession;
import org.tikv.common.key.RowKey;
import org.tikv.common.meta.TiTableInfo;
import org.tikv.kvproto.Cdcpb;
import org.tikv.kvproto.Coprocessor;
import org.tikv.kvproto.Kvrpcpb;
import org.tikv.shade.com.google.protobuf.ByteString;
import org.tikv.txn.KVClient;

/**
 * Parallel source that reads one TiDB table from TiKV: an optional snapshot scan followed by a
 * stream of row changes pulled from a TiKV {@link CDCClient}.
 *
 * <p>Each subtask handles the key range returned by {@code TableKeyRangeUtils.getTableKeyRange}
 * for its subtask index. Change rows are buffered as prewrite/commit pairs in two sorted maps.
 * Once the client's max resolved timestamp reaches a commit timestamp, the matching prewrite row
 * is queued, and a single worker thread deserializes and emits it. The resolved timestamp is
 * stored in operator state and used as the resume position after a restore.
 *
 * <p>Threading: updates of the prewrite/commit maps and of {@code resolvedTs}, as well as record
 * emission, are done while holding {@link SourceFunction.SourceContext#getCheckpointLock()}, as
 * the {@link SourceFunction} contract requires for checkpointed sources.
 *
 * @param <T> type of the records produced by the deserialization schemas
 */
public class TiKVRichParallelSourceFunction<T> extends RichParallelSourceFunction<T>
        implements CheckpointListener, CheckpointedFunction, ResultTypeQueryable<T> {

    private static final long serialVersionUID = 1L;
    private static final Logger LOG = LoggerFactory.getLogger(TiKVRichParallelSourceFunction.class);

    /** {@code resolvedTs} value before the snapshot phase has run (INITIAL startup mode). */
    private static final long SNAPSHOT_VERSION_EPOCH = -1L;

    /** {@code resolvedTs} value meaning "start streaming from the current TiKV timestamp". */
    private static final long STREAMING_VERSION_START_EPOCH = 0L;

    /** Seconds {@link #cancel()} waits for the emitter thread to terminate. */
    private static final long CLOSE_TIMEOUT = 30L;

    /** Upper bound on CDC rows pulled per iteration of the change-event loop. */
    private static final int MAX_ROWS_PER_POLL = 1000;

    private final TiKVSnapshotEventDeserializationSchema<T> snapshotEventDeserializationSchema;
    private final TiKVChangeEventDeserializationSchema<T> changeEventDeserializationSchema;
    private final TiConfiguration tiConf;
    private final StartupMode startupMode;
    private final String database;
    private final String tableName;

    private transient TiSession session = null;
    private transient Coprocessor.KeyRange keyRange = null;
    private transient CDCClient cdcClient = null;
    private transient SourceFunction.SourceContext<T> sourceContext = null;
    // Written by the source thread, read by the checkpointing thread.
    private volatile transient long resolvedTs = SNAPSHOT_VERSION_EPOCH;
    // Resume position found in restored operator state; null on a fresh start.
    private transient Long restoredResolvedTs;
    // Prewrite rows keyed by (startTs, row key); guarded by the checkpoint lock.
    private transient TreeMap<RowKeyWithTs, Cdcpb.Event.Row> prewrites = null;
    // Commit rows keyed by (commitTs, row key); guarded by the checkpoint lock.
    private transient TreeMap<RowKeyWithTs, Cdcpb.Event.Row> commits = null;
    // Prewrite rows whose commit is resolved, waiting for the emitter thread.
    private transient BlockingQueue<Cdcpb.Event.Row> committedEvents = null;
    private transient OutputCollector<T> outputCollector;
    // Cleared by cancel()/close() from other threads, hence volatile.
    private volatile transient boolean running = true;
    // Guarded by synchronized (this); makes releaseChangeStream() run at most once.
    private transient boolean changeStreamReleased;
    private transient ExecutorService executorService;
    private transient ListState<Long> offsetState;

    public TiKVRichParallelSourceFunction(
            TiKVSnapshotEventDeserializationSchema<T> snapshotEventDeserializationSchema,
            TiKVChangeEventDeserializationSchema<T> changeEventDeserializationSchema,
            TiConfiguration tiConf,
            StartupMode startupMode,
            String database,
            String tableName) {
        this.snapshotEventDeserializationSchema = snapshotEventDeserializationSchema;
        this.changeEventDeserializationSchema = changeEventDeserializationSchema;
        this.tiConf = tiConf;
        this.startupMode = startupMode;
        this.database = database;
        this.tableName = tableName;
    }

    /**
     * Connects to TiKV, resolves this subtask's key range and prepares the CDC client, buffers
     * and emitter thread.
     *
     * @throws RuntimeException if the configured table does not exist
     */
    @Override
    public void open(Configuration config) throws Exception {
        super.open(config);
        this.session = TiSession.create(this.tiConf);
        TiTableInfo tableInfo = this.session.getCatalog().getTable(this.database, this.tableName);
        if (tableInfo == null) {
            throw new RuntimeException(
                    String.format("Table %s.%s does not exist.", this.database, this.tableName));
        }
        long tableId = tableInfo.getId();
        int parallelism = this.getRuntimeContext().getNumberOfParallelSubtasks();
        int subtaskIndex = this.getRuntimeContext().getIndexOfThisSubtask();
        this.keyRange = TableKeyRangeUtils.getTableKeyRange(tableId, parallelism, subtaskIndex);
        this.cdcClient = new CDCClient(this.session, this.keyRange);
        this.prewrites = new TreeMap<>();
        this.commits = new TreeMap<>();
        this.committedEvents = new LinkedBlockingQueue<>();
        this.outputCollector = new OutputCollector<>();
        // initializeState() runs before open() in Flink's operator lifecycle: keep a restored
        // position instead of overwriting it with the startup-mode default.
        if (this.restoredResolvedTs != null) {
            this.resolvedTs = this.restoredResolvedTs;
        } else {
            this.resolvedTs =
                    this.startupMode == StartupMode.INITIAL
                            ? SNAPSHOT_VERSION_EPOCH
                            : STREAMING_VERSION_START_EPOCH;
        }
        // Transient field initializers do not run after deserialization; set the flag here.
        this.running = true;
        ThreadFactory threadFactory =
                new ThreadFactoryBuilder()
                        .setNameFormat("tidb-source-function-" + subtaskIndex)
                        .build();
        this.executorService = Executors.newSingleThreadExecutor(threadFactory);
    }

    /**
     * Runs the source.
     *
     * <p>If a streaming position was restored ({@code resolvedTs > 0}), the snapshot is skipped
     * and the change stream resumes from that position. Otherwise an INITIAL startup reads the
     * snapshot under the checkpoint lock, and any other startup mode starts streaming from the
     * current TiKV timestamp.
     */
    @Override
    public void run(SourceFunction.SourceContext<T> ctx) throws Exception {
        this.sourceContext = ctx;
        this.outputCollector.context = ctx;
        if (this.resolvedTs > STREAMING_VERSION_START_EPOCH) {
            LOG.info("Skip snapshot read, resume from restored resolvedTs: {}", this.resolvedTs);
        } else if (this.startupMode == StartupMode.INITIAL) {
            synchronized (ctx.getCheckpointLock()) {
                this.readSnapshotEvents();
            }
        } else {
            LOG.info("Skip snapshot read");
            this.resolvedTs = this.session.getTimestamp().getVersion();
        }
        if (!this.running) {
            // Cancelled during startup; do not open the change stream.
            return;
        }
        LOG.info("start read change events");
        this.cdcClient.start(this.resolvedTs);
        this.readChangeEvents();
    }

    /**
     * Files one CDC row into the prewrite/commit maps; rows whose key is not a record key are
     * ignored. Caller must hold the checkpoint lock.
     */
    private void handleRow(Cdcpb.Event.Row row) {
        if (!TableKeyRangeUtils.isRecordKey(row.getKey().toByteArray())) {
            return;
        }
        LOG.debug("binlog record, type: {}, data: {}", row.getType(), row);
        switch (row.getType()) {
            case COMMITTED:
                // Already committed: record it as both its prewrite and its commit.
                this.prewrites.put(RowKeyWithTs.ofStart(row), row);
                this.commits.put(RowKeyWithTs.ofCommit(row), row);
                break;
            case COMMIT:
                this.commits.put(RowKeyWithTs.ofCommit(row), row);
                break;
            case PREWRITE:
                this.prewrites.put(RowKeyWithTs.ofStart(row), row);
                break;
            case ROLLBACK:
                this.prewrites.remove(RowKeyWithTs.ofStart(row));
                break;
            default:
                LOG.warn("Unsupported row type:{}", row.getType());
        }
    }

    /**
     * Scans this subtask's key range at a fresh TiKV timestamp and emits every record-key pair
     * through the snapshot deserializer. Sets {@code resolvedTs} to the scan timestamp once the
     * range is exhausted. Stops early (without updating {@code resolvedTs}) if cancelled.
     */
    protected void readSnapshotEvents() throws Exception {
        LOG.info("read snapshot events");
        try (KVClient scanClient = this.session.createKVClient()) {
            long startTs = this.session.getTimestamp().getVersion();
            ByteString start = this.keyRange.getStart();
            while (this.running) {
                List<Kvrpcpb.KvPair> segment =
                        scanClient.scan(start, this.keyRange.getEnd(), startTs);
                if (segment.isEmpty()) {
                    this.resolvedTs = startTs;
                    return;
                }
                for (Kvrpcpb.KvPair pair : segment) {
                    if (TableKeyRangeUtils.isRecordKey(pair.getKey().toByteArray())) {
                        this.snapshotEventDeserializationSchema.deserialize(
                                pair, this.outputCollector);
                    }
                }
                // Continue right after the last key of this segment.
                ByteString lastKey = segment.get(segment.size() - 1).getKey();
                start = RowKey.toRawKey(lastKey).next().toByteString();
            }
        }
    }

    /**
     * Starts the emitter thread, then pulls CDC rows until cancelled or until {@code resolvedTs}
     * becomes negative.
     *
     * <p>Rows are pulled without the checkpoint lock. The lock is taken only when there is
     * something to apply (new rows or a changed resolved timestamp), so an idle loop does not
     * contend with checkpoints or the emitter.
     */
    protected void readChangeEvents() throws Exception {
        LOG.info("read change event from resolvedTs:{}", this.resolvedTs);
        this.executorService.execute(this::emitCommittedEvents);
        Object checkpointLock = this.sourceContext.getCheckpointLock();
        List<Cdcpb.Event.Row> batch = new ArrayList<>(MAX_ROWS_PER_POLL);
        while (this.running && this.resolvedTs >= STREAMING_VERSION_START_EPOCH) {
            batch.clear();
            Cdcpb.Event.Row row;
            while (batch.size() < MAX_ROWS_PER_POLL && (row = this.cdcClient.get()) != null) {
                batch.add(row);
            }
            // Read after draining the batch, as the original loop did.
            long maxResolvedTs = this.cdcClient.getMaxResolvedTs();
            if (batch.isEmpty() && maxResolvedTs == this.resolvedTs) {
                continue;
            }
            synchronized (checkpointLock) {
                for (Cdcpb.Event.Row pulled : batch) {
                    this.handleRow(pulled);
                }
                this.resolvedTs = maxResolvedTs;
                if (!this.commits.isEmpty()) {
                    this.flushRows(maxResolvedTs);
                }
            }
        }
    }

    /**
     * Emitter thread body. It takes resolved rows from {@code committedEvents} and emits them
     * under the checkpoint lock, so emission never interleaves with checkpoint barriers. Exits
     * when {@code running} is cleared or the thread is interrupted.
     */
    private void emitCommittedEvents() {
        Object checkpointLock = this.sourceContext.getCheckpointLock();
        while (this.running) {
            Cdcpb.Event.Row committedRow;
            try {
                committedRow = this.committedEvents.take();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            try {
                synchronized (checkpointLock) {
                    this.changeEventDeserializationSchema.deserialize(
                            committedRow, this.outputCollector);
                }
            } catch (Exception e) {
                if (e instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                    return;
                }
                // Best effort, as before: log and keep emitting subsequent rows.
                LOG.error("Failed to emit change event {}", committedRow, e);
            }
        }
    }

    /**
     * Moves every commit with timestamp {@code <= timestamp} out of the commit map. The matching
     * prewrite row is handed to the emitter queue (the prewrite row, not the commit row, is what
     * gets forwarded). A commit without a buffered prewrite is logged and skipped rather than
     * crashing on {@code offer(null)}.
     *
     * @param timestamp resolved timestamp up to which commits are released
     * @throws IllegalStateException if called before {@link #run} has set the source context
     */
    protected void flushRows(long timestamp) throws Exception {
        Preconditions.checkState(this.sourceContext != null, "sourceContext shouldn't be null");
        synchronized (this.sourceContext.getCheckpointLock()) {
            while (!this.commits.isEmpty() && this.commits.firstKey().timestamp <= timestamp) {
                Cdcpb.Event.Row commitRow = this.commits.pollFirstEntry().getValue();
                Cdcpb.Event.Row prewriteRow =
                        this.prewrites.remove(RowKeyWithTs.ofStart(commitRow));
                if (prewriteRow == null) {
                    LOG.warn("No prewrite found for commit row, skipping it: {}", commitRow);
                    continue;
                }
                this.committedEvents.offer(prewriteRow);
            }
        }
    }

    /**
     * Stops the source. Clears the running flag, closes the CDC client and interrupts the
     * emitter thread, waiting up to {@value #CLOSE_TIMEOUT} seconds for it to finish.
     */
    @Override
    public void cancel() {
        this.running = false;
        try {
            this.releaseChangeStream(true);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOG.error("Unable to close cdcClient", e);
        } catch (Exception e) {
            LOG.error("Unable to close cdcClient", e);
        }
    }

    /** Releases the change stream (without waiting for the emitter) and the TiKV session. */
    @Override
    public void close() throws Exception {
        this.running = false;
        try {
            this.releaseChangeStream(false);
        } finally {
            try {
                if (this.session != null) {
                    this.session.close();
                    this.session = null;
                }
            } finally {
                super.close();
            }
        }
    }

    /**
     * Closes the CDC client and shuts down the emitter executor. Only the first call does any
     * work. The monitor is held just for the check, so a concurrent caller never waits on it.
     *
     * @param awaitWorker whether to wait up to {@value #CLOSE_TIMEOUT} seconds for the emitter
     */
    private void releaseChangeStream(boolean awaitWorker) throws Exception {
        synchronized (this) {
            if (this.changeStreamReleased) {
                return;
            }
            this.changeStreamReleased = true;
        }
        try {
            if (this.cdcClient != null) {
                this.cdcClient.close();
            }
        } finally {
            if (this.executorService != null) {
                // shutdownNow() interrupts the emitter if it is blocked in take().
                this.executorService.shutdownNow();
                if (awaitWorker
                        && !this.executorService.awaitTermination(CLOSE_TIMEOUT, TimeUnit.SECONDS)) {
                    LOG.warn(
                            "Failed to close the tidb source function in {} seconds.",
                            CLOSE_TIMEOUT);
                }
            }
        }
    }

    /**
     * Flushes commits resolved up to {@code resolvedTs} into the emitter queue, then stores
     * {@code resolvedTs} as the single entry of the offset state.
     *
     * <p>NOTE(review): rows already queued in {@code committedEvents} may still be emitted by the
     * worker after this checkpoint's barrier. A restore from this checkpoint presumably does not
     * re-read them. Confirm the intended delivery guarantee.
     */
    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        LOG.info(
                "snapshotState checkpoint: {} at resolvedTs: {}",
                context.getCheckpointId(),
                this.resolvedTs);
        if (this.sourceContext != null) {
            // Before run() starts there is nothing buffered to flush.
            this.flushRows(this.resolvedTs);
        }
        long checkpointedTs = this.resolvedTs;
        this.offsetState.clear();
        this.offsetState.add(checkpointedTs);
    }

    /**
     * Registers the {@code resolvedTsState} list state. On restore, it remembers the smallest
     * stored timestamp (all of them are considered, in case several were redistributed to this
     * subtask) so that {@link #open} can resume from it.
     */
    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        LOG.info("initialize checkpoint");
        this.offsetState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<Long>(
                                        "resolvedTsState", LongSerializer.INSTANCE));
        this.restoredResolvedTs = null;
        if (context.isRestored()) {
            Long restored = null;
            for (Long offset : this.offsetState.get()) {
                if (offset != null && (restored == null || offset < restored)) {
                    restored = offset;
                }
            }
            if (restored != null) {
                this.restoredResolvedTs = restored;
                this.resolvedTs = restored;
                LOG.info("Restore State from resolvedTs: {}", this.resolvedTs);
            }
        } else {
            this.resolvedTs = STREAMING_VERSION_START_EPOCH;
            LOG.info("Initialize State from resolvedTs: {}", this.resolvedTs);
        }
    }

    /** No-op: nothing is committed externally when a checkpoint completes. */
    @Override
    public void notifyCheckpointComplete(long checkpointId) throws Exception {}

    @Override
    public TypeInformation<T> getProducedType() {
        return this.snapshotEventDeserializationSchema.getProducedType();
    }

    /** Collector that forwards every record to the current source context. */
    private static class OutputCollector<T> implements Collector<T> {
        private SourceFunction.SourceContext<T> context;

        private OutputCollector() {}

        @Override
        public void collect(T record) {
            this.context.collect(record);
        }

        @Override
        public void close() {}
    }

    /** Sort key ordering rows by timestamp, then table id, then row handle. */
    private static class RowKeyWithTs implements Comparable<RowKeyWithTs> {
        private final long timestamp;
        private final RowKey rowKey;

        private RowKeyWithTs(long timestamp, RowKey rowKey) {
            this.timestamp = timestamp;
            this.rowKey = rowKey;
        }

        private RowKeyWithTs(long timestamp, byte[] key) {
            this(timestamp, RowKey.decode(key));
        }

        @Override
        public int compareTo(RowKeyWithTs that) {
            int res = Long.compare(this.timestamp, that.timestamp);
            if (res == 0) {
                res = Long.compare(this.rowKey.getTableId(), that.rowKey.getTableId());
            }
            if (res == 0) {
                res = Long.compare(this.rowKey.getHandle(), that.rowKey.getHandle());
            }
            return res;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.timestamp, this.rowKey.getTableId(), this.rowKey.getHandle());
        }

        /** Equal exactly when {@link #compareTo} returns 0, matching {@link #hashCode}. */
        @Override
        public boolean equals(Object thatObj) {
            if (this == thatObj) {
                return true;
            }
            return thatObj instanceof RowKeyWithTs && this.compareTo((RowKeyWithTs) thatObj) == 0;
        }

        static RowKeyWithTs ofStart(Cdcpb.Event.Row row) {
            return new RowKeyWithTs(row.getStartTs(), row.getKey().toByteArray());
        }

        static RowKeyWithTs ofCommit(Cdcpb.Event.Row row) {
            return new RowKeyWithTs(row.getCommitTs(), row.getKey().toByteArray());
        }
    }
}

