/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.recovery;

import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import javax.annotation.concurrent.GuardedBy;
import org.apache.flink.api.common.ExecutionMode;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.configuration.UnmodifiableConfiguration;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.messages.webmonitor.JobIdsWithStatusOverview;
import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.rest.RestClient;
import org.apache.flink.runtime.rest.messages.EmptyMessageParameters;
import org.apache.flink.runtime.rest.messages.EmptyRequestBody;
import org.apache.flink.runtime.rest.messages.JobIdsWithStatusesOverviewHeaders;
import org.apache.flink.runtime.rest.messages.JobMessageParameters;
import org.apache.flink.runtime.rest.messages.JobVertexDetailsHeaders;
import org.apache.flink.runtime.rest.messages.JobVertexDetailsInfo;
import org.apache.flink.runtime.rest.messages.JobVertexMessageParameters;
import org.apache.flink.runtime.rest.messages.MessageHeaders;
import org.apache.flink.runtime.rest.messages.MessageParameters;
import org.apache.flink.runtime.rest.messages.RequestBody;
import org.apache.flink.runtime.rest.messages.ResponseBody;
import org.apache.flink.runtime.rest.messages.job.JobDetailsHeaders;
import org.apache.flink.runtime.rest.messages.job.JobDetailsInfo;
import org.apache.flink.runtime.rest.messages.job.SubtaskExecutionAttemptDetailsInfo;
import org.apache.flink.runtime.testutils.MiniClusterResource;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.test.util.TestEnvironment;
import org.apache.flink.util.Collector;
import org.apache.flink.util.ConfigurationException;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FlinkException;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.TemporaryClassLoaderContext;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.concurrent.ExecutorThreadFactory;
import org.apache.flink.util.concurrent.FutureUtils;
import org.hamcrest.CoreMatchers;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Integration test for fine-grained ("region") failover of a batch job.
 *
 * <p>The job generates the sequence {@code 0 .. EMITTED_RECORD_NUMBER - 1} and sends it through
 * {@value #MAP_NUMBER} map-partition operators appended one after another, each of which adds one
 * to every record. Every mapper carries a failure strategy that, after a random number of
 * processed records, either throws an exception or restarts a TaskExecutor of the
 * {@link MiniCluster}. The test checks that the job still produces the expected output and that
 * every mapper ended with the expected attempt number.
 */
public class BatchFineGrainedRecoveryITCase extends TestLogger {
    private static final Logger LOG = LoggerFactory.getLogger(BatchFineGrainedRecoveryITCase.class);

    /** Number of records produced by the source ({@code generateSequence(0, N - 1)}). */
    private static final int EMITTED_RECORD_NUMBER = 1000;

    /** Number of map-partition operators appended to the source. */
    private static final int MAP_NUMBER = 3;

    /** User-given operator name prefix of every mapper; the mapper index is appended to it. */
    private static final String TASK_NAME_PREFIX = "Test partition mapper ";

    /** Prefix of a mapper's task name as returned by the REST API. */
    private static final String MAP_PARTITION_TEST_PARTITION_MAPPER = "MapPartition (" + TASK_NAME_PREFIX;

    /** Extracts the mapper index from a task name such as {@code MapPartition (Test partition mapper 2)}. */
    private static final Pattern MAPPER_NUMBER_IN_TASK_NAME_PATTERN =
            Pattern.compile("MapPartition \\(Test partition mapper (\\d+)\\)");

    /** Sum {@code 0 + 1 + ... + MAP_NUMBER}; budget of job failures attributed to backtracking. */
    private static final int ALL_MAPPERS_BACKTRACK_FAILURES = IntStream.range(0, MAP_NUMBER + 1).sum();

    /**
     * Restart budget of the job: the backtracking budget plus two failures per mapper (every mapper
     * has one exception-based and one TaskExecutor-based failure strategy, each tracked globally so
     * it fires at most once).
     */
    private static final int MAX_JOB_RESTART_ATTEMPTS = ALL_MAPPERS_BACKTRACK_FAILURES + 2 * MAP_NUMBER;

    /**
     * Expected attempt number of mapper {@code i} at the end of the job. Upstream mappers (small
     * {@code i}) are expected to have more attempts than downstream ones.
     */
    private static final int[] EXPECTED_MAP_ATTEMPT_NUMBERS =
            IntStream.range(0, MAP_NUMBER).map(i -> 1 + (MAP_NUMBER - i - 1) + 1 + 1).toArray();

    /** Every mapper adds one to each record, so the output is the input shifted by MAP_NUMBER. */
    private static final List<Long> EXPECTED_JOB_OUTPUT =
            LongStream.range(MAP_NUMBER, EMITTED_RECORD_NUMBER + MAP_NUMBER).boxed().collect(Collectors.toList());

    @ClassRule
    public static final MiniClusterResource MINI_CLUSTER_RESOURCE =
            new MiniClusterResource(
                    new MiniClusterResourceConfiguration.Builder()
                            .setConfiguration(createConfiguration())
                            .setNumberTaskManagers(1)
                            .setNumberSlotsPerTaskManager(1)
                            .build());

    private static final Random rnd = new Random();

    private static MiniCluster miniCluster;

    private static MiniClusterClient client;

    /** Index of the next TaskExecutor (in MiniCluster start order) to terminate on a TM failure. */
    private static AtomicInteger lastTaskManagerIndexInMiniCluster;

    private static GlobalMapFailureTracker failureTracker;

    @Before
    public void setup() throws Exception {
        miniCluster = MINI_CLUSTER_RESOURCE.getMiniCluster();
        client = new MiniClusterClient(miniCluster);
        lastTaskManagerIndexInMiniCluster = new AtomicInteger(0);
        failureTracker = new GlobalMapFailureTracker(MAP_NUMBER);
    }

    @After
    public void teardown() throws Exception {
        if (client != null) {
            client.close();
        }
    }

    @Test
    public void testProgram() throws Exception {
        ExecutionEnvironment env = createExecutionEnvironment();

        DataSet<Long> input = env.generateSequence(0, EMITTED_RECORD_NUMBER - 1);
        for (int trackingIndex = 0; trackingIndex < MAP_NUMBER; trackingIndex++) {
            input =
                    input.mapPartition(new TestPartitionMapper(trackingIndex, createFailureStrategy(trackingIndex)))
                            .name(TASK_NAME_PREFIX + trackingIndex);
        }
        Assert.assertThat(input.collect(), CoreMatchers.is(EXPECTED_JOB_OUTPUT));

        failureTracker.verify(getMapperAttempts());
    }

    /** Cluster configuration enabling the region failover strategy. */
    private static Configuration createConfiguration() {
        Configuration configuration = new Configuration();
        configuration.set(JobManagerOptions.EXECUTION_FAILOVER_STRATEGY, "region");
        return configuration;
    }

    /**
     * Builds the failure strategy of one mapper: at most one failure per task attempt, chosen from
     * an exception and a TaskExecutor restart. Each is triggered after a record count drawn
     * uniformly from {@code [1, EMITTED_RECORD_NUMBER]}.
     */
    private static FailureStrategy createFailureStrategy(int trackingIndex) {
        int failWithExceptionAfterNumberOfProcessedRecords = rnd.nextInt(EMITTED_RECORD_NUMBER) + 1;
        int failTaskExecutorAfterNumberOfProcessedRecords = rnd.nextInt(EMITTED_RECORD_NUMBER) + 1;
        FailureStrategy failureStrategy =
                new OneTimeFailureStrategy(
                        new JoinedFailureStrategy(
                                new GloballyTrackingFailureStrategy(
                                        new ExceptionFailureStrategy(failWithExceptionAfterNumberOfProcessedRecords)),
                                new GloballyTrackingFailureStrategy(
                                        new TaskExecutorFailureStrategy(failTaskExecutorAfterNumberOfProcessedRecords))));
        LOG.info("FailureStrategy for the mapper {}: {}", trackingIndex, failureStrategy);
        return failureStrategy;
    }

    /** Batch environment on the shared MiniCluster with a fixed-delay restart budget. */
    private static ExecutionEnvironment createExecutionEnvironment() {
        TestEnvironment env = new TestEnvironment(miniCluster, 1, true);
        env.setRestartStrategy(RestartStrategies.fixedDelayRestart(MAX_JOB_RESTART_ATTEMPTS, Time.milliseconds(10)));
        // Forces blocking (batch) data exchanges between the operators.
        env.getConfig().setExecutionMode(ExecutionMode.BATCH_FORCED);
        return env;
    }

    /**
     * Terminates the next TaskExecutor and always starts a replacement, even if the termination
     * future fails.
     */
    private static void restartTaskManager() throws Exception {
        int tmi = lastTaskManagerIndexInMiniCluster.getAndIncrement();
        try {
            miniCluster.terminateTaskManager(tmi).get();
        } finally {
            miniCluster.startTaskManager();
        }
    }

    /** Reads each mapper's latest reported attempt number from the REST API, indexed by mapper. */
    private static int[] getMapperAttempts() {
        int[] attempts = new int[MAP_NUMBER];
        client.getInternalTaskInfos().stream()
                .filter(taskInfo -> taskInfo.name.startsWith(MAP_PARTITION_TEST_PARTITION_MAPPER))
                .forEach(taskInfo -> attempts[parseMapperNumberFromTaskName(taskInfo.name)] = taskInfo.attempt);
        return attempts;
    }

    /**
     * Parses the mapper index out of a task name.
     *
     * @throws FlinkRuntimeException if the name does not match {@link #MAPPER_NUMBER_IN_TASK_NAME_PATTERN}
     */
    private static int parseMapperNumberFromTaskName(String name) {
        Matcher m = MAPPER_NUMBER_IN_TASK_NAME_PATTERN.matcher(name);
        if (m.matches()) {
            return Integer.parseInt(m.group(1));
        }
        throw new FlinkRuntimeException("Failed to find mapper number in its task name: " + name);
    }

    /** Name and attempt number of a single subtask as reported by the REST API. */
    private static class InternalTaskInfo {
        private final String name;
        private final int attempt;

        private InternalTaskInfo(String name, SubtaskExecutionAttemptDetailsInfo vertexTaskDetail) {
            this.name = name;
            this.attempt = vertexTaskDetail.getAttempt();
        }

        @Override
        public String toString() {
            return name + " (Attempt #" + attempt + ')';
        }
    }

    /** Minimal REST client that queries job/vertex/subtask details from the MiniCluster. */
    private static class MiniClusterClient implements AutoCloseable {
        private final RestClient restClient;
        private final ExecutorService executorService;
        private final URI restAddress;

        private MiniClusterClient(MiniCluster miniCluster) throws ConfigurationException {
            this.restAddress = miniCluster.getRestAddress().join();
            this.executorService =
                    Executors.newSingleThreadScheduledExecutor(new ExecutorThreadFactory("Flink-RestClient-IO"));
            try {
                this.restClient = createRestClient();
            } catch (Throwable t) {
                // Do not leak the I/O thread when the client cannot be created.
                executorService.shutdownNow();
                throw t;
            }
        }

        private RestClient createRestClient() throws ConfigurationException {
            return new RestClient(new UnmodifiableConfiguration(new Configuration()), executorService);
        }

        /** Collects one entry per subtask of every vertex of every job known to the cluster. */
        private List<InternalTaskInfo> getInternalTaskInfos() {
            List<InternalTaskInfo> taskInfos = new ArrayList<>();
            for (JobID jobId : getJobs()) {
                for (JobDetailsInfo.JobVertexDetailsInfo vertexInfo : getJobDetails(jobId).join().getJobVertexInfos()) {
                    JobVertexDetailsInfo vertexDetails = getJobVertexDetailsInfo(jobId, vertexInfo.getJobVertexID());
                    for (SubtaskExecutionAttemptDetailsInfo subtask : vertexDetails.getSubtasks()) {
                        taskInfos.add(new InternalTaskInfo(vertexInfo.getName(), subtask));
                    }
                }
            }
            return taskInfos;
        }

        private Collection<JobID> getJobs() {
            JobIdsWithStatusOverview jobIds =
                    sendRequest(JobIdsWithStatusesOverviewHeaders.getInstance(), EmptyMessageParameters.getInstance())
                            .join();
            return jobIds.getJobsWithStatus().stream()
                    .map(JobIdsWithStatusOverview.JobIdWithStatus::getJobId)
                    .collect(Collectors.toList());
        }

        private CompletableFuture<JobDetailsInfo> getJobDetails(JobID jobId) {
            JobMessageParameters params = new JobMessageParameters();
            params.jobPathParameter.resolve(jobId);
            return sendRequest(JobDetailsHeaders.getInstance(), params);
        }

        private JobVertexDetailsInfo getJobVertexDetailsInfo(JobID jobId, JobVertexID jobVertexID) {
            JobVertexDetailsHeaders detailsHeaders = JobVertexDetailsHeaders.getInstance();
            JobVertexMessageParameters params = new JobVertexMessageParameters();
            params.jobPathParameter.resolve(jobId);
            params.jobVertexIdPathParameter.resolve(jobVertexID);
            return sendRequest(detailsHeaders, params).join();
        }

        /** Sends a body-less request; an {@link IOException} is returned as a failed future. */
        private <M extends MessageHeaders<EmptyRequestBody, P, U>, U extends MessageParameters, P extends ResponseBody>
                CompletableFuture<P> sendRequest(M messageHeaders, U messageParameters) {
            try {
                return restClient.sendRequest(
                        restAddress.getHost(),
                        restAddress.getPort(),
                        messageHeaders,
                        messageParameters,
                        EmptyRequestBody.getInstance());
            } catch (IOException e) {
                return FutureUtils.completedExceptionally(e);
            }
        }

        @Override
        public void close() throws Exception {
            try {
                restClient.close();
            } finally {
                // Shut down the I/O executor even if closing the client fails.
                executorService.shutdownNow();
            }
        }
    }

    /** Adds one to every record, consulting the failure strategy before each record. */
    private static class TestPartitionMapper extends RichMapPartitionFunction<Long, Long> {
        private static final long serialVersionUID = 1L;

        private final int trackingIndex;
        private final FailureStrategy failureStrategy;

        private TestPartitionMapper(int trackingIndex, FailureStrategy failureStrategy) {
            this.trackingIndex = trackingIndex;
            this.failureStrategy = failureStrategy;
        }

        @Override
        public void mapPartition(Iterable<Long> values, Collector<Long> out) throws Exception {
            for (Long value : values) {
                failureStrategy.failOrNot(trackingIndex);
                out.collect(value + 1L);
            }
        }
    }

    /**
     * JVM-wide record of which failure strategies have already fired per mapper, plus any
     * unexpected failure raised while injecting failures.
     */
    private static class GlobalMapFailureTracker {
        private final List<Set<FailureStrategy>> mapFailures;

        private final Object classLock = new Object();

        @GuardedBy("classLock")
        private Throwable unexpectedFailure;

        private GlobalMapFailureTracker(int numberOfMappers) {
            mapFailures = new ArrayList<>(numberOfMappers);
            IntStream.range(0, numberOfMappers).forEach(i -> addNewMapper());
        }

        private int addNewMapper() {
            mapFailures.add(new HashSet<>(2));
            return mapFailures.size() - 1;
        }

        /**
         * Delegates to {@code failureStrategy} unless it already fired for mapper {@code index}.
         * A strategy that returns {@code true} or throws is recorded as fired, so it never fires
         * again for that mapper.
         */
        private boolean failOrNot(int index, FailureStrategy failureStrategy) throws Exception {
            boolean alreadyFailed = mapFailures.get(index).contains(failureStrategy);
            boolean failedNow = false;
            try {
                failedNow = !alreadyFailed && failureStrategy.failOrNot(index);
            } catch (Exception e) {
                failedNow = true;
                throw e;
            } finally {
                if (failedNow) {
                    mapFailures.get(index).add(failureStrategy);
                }
            }
            return failedNow;
        }

        /** Records a failure that is not part of the injected test failures. */
        private void unrelatedFailure(Throwable failure) {
            synchronized (classLock) {
                unexpectedFailure = ExceptionUtils.firstOrSuppressed(failure, unexpectedFailure);
            }
        }

        /**
         * Fails if an unrelated failure was recorded, then checks the mapper attempt numbers.
         */
        private void verify(int[] mapAttemptNumbers) {
            synchronized (classLock) {
                if (unexpectedFailure != null) {
                    throw new AssertionError("Test failed due to unexpected exception.", unexpectedFailure);
                }
            }
            Assert.assertThat(mapAttemptNumbers, CoreMatchers.is(EXPECTED_MAP_ATTEMPT_NUMBERS));
        }
    }

    /**
     * Fires {@link #fail(int)} exactly when the per-instance call counter reaches
     * {@code failAfterCallNumber}. The counter is transient, so it starts at zero in every
     * deserialized copy. Equality is based on a random id that is serialized with the strategy,
     * so all copies of one strategy compare equal.
     */
    private abstract static class AbstractOnceAfterCallNumberFailureStrategy implements FailureStrategy {
        private static final long serialVersionUID = 1L;

        private final UUID id;
        private final int failAfterCallNumber;
        private transient int callCounter;

        private AbstractOnceAfterCallNumberFailureStrategy(int failAfterCallNumber) {
            this.failAfterCallNumber = failAfterCallNumber;
            this.id = UUID.randomUUID();
        }

        @Override
        public boolean failOrNot(int trackingIndex) throws Exception {
            callCounter++;
            boolean generateFailure = callCounter == failAfterCallNumber;
            if (generateFailure) {
                fail(trackingIndex);
            }
            return generateFailure;
        }

        /** Performs the actual failure for the mapper {@code trackingIndex}. */
        abstract void fail(int trackingIndex) throws Exception;

        @Override
        public String toString() {
            return getClass().getSimpleName() + " (fail after " + failAfterCallNumber + " calls)";
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            return Objects.equals(id, ((AbstractOnceAfterCallNumberFailureStrategy) o).id);
        }

        @Override
        public int hashCode() {
            return id.hashCode();
        }
    }

    /** Fails by restarting a TaskExecutor of the MiniCluster. */
    private static class TaskExecutorFailureStrategy extends AbstractOnceAfterCallNumberFailureStrategy {
        private static final long serialVersionUID = 1L;

        private TaskExecutorFailureStrategy(int failAfterCallNumber) {
            super(failAfterCallNumber);
        }

        @Override
        void fail(int trackingIndex) throws Exception {
            // Switch to the system class loader for the restart. NOTE(review): presumably so
            // cluster components are not created with the user-code class loader; confirm.
            try (TemporaryClassLoaderContext unused =
                    TemporaryClassLoaderContext.of(ClassLoader.getSystemClassLoader())) {
                try {
                    restartTaskManager();
                } catch (InterruptedException e) {
                    // Preserve the interrupt flag; the interruption itself is not reported.
                    Thread.currentThread().interrupt();
                } catch (Throwable t) {
                    failureTracker.unrelatedFailure(t);
                    throw t;
                }
            }
        }
    }

    /** Fails by throwing a {@link FlinkException} from the user function. */
    private static class ExceptionFailureStrategy extends AbstractOnceAfterCallNumberFailureStrategy {
        private static final long serialVersionUID = 1L;

        private ExceptionFailureStrategy(int failAfterCallNumber) {
            super(failAfterCallNumber);
        }

        @Override
        void fail(int trackingIndex) throws FlinkException {
            throw new FlinkException("BAGA-BOOM!!! The user function generated test failure.");
        }
    }

    /** Routes the wrapped strategy through the {@link GlobalMapFailureTracker}. */
    private static class GloballyTrackingFailureStrategy implements FailureStrategy {
        private static final long serialVersionUID = 1L;

        private final FailureStrategy wrappedFailureStrategy;

        private GloballyTrackingFailureStrategy(FailureStrategy wrappedFailureStrategy) {
            this.wrappedFailureStrategy = wrappedFailureStrategy;
        }

        @Override
        public boolean failOrNot(int trackingIndex) throws Exception {
            return failureTracker.failOrNot(trackingIndex, wrappedFailureStrategy);
        }

        @Override
        public String toString() {
            return "Tracked{" + wrappedFailureStrategy + '}';
        }
    }

    /** Consults the strategies in order and stops at the first one that fails. */
    private static class JoinedFailureStrategy implements FailureStrategy {
        private static final long serialVersionUID = 1L;

        private final FailureStrategy[] failureStrategies;

        private JoinedFailureStrategy(FailureStrategy... failureStrategies) {
            this.failureStrategies = failureStrategies;
        }

        @Override
        public boolean failOrNot(int trackingIndex) throws Exception {
            for (FailureStrategy failureStrategy : failureStrategies) {
                if (failureStrategy.failOrNot(trackingIndex)) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public String toString() {
            return Arrays.stream(failureStrategies).map(Object::toString).collect(Collectors.joining(" or "));
        }
    }

    /**
     * Lets the wrapped strategy fail at most once per instance. The flag is transient, so every
     * deserialized copy (e.g. a new task attempt) starts unfailed.
     */
    private static class OneTimeFailureStrategy implements FailureStrategy {
        private static final long serialVersionUID = 1L;

        private final FailureStrategy wrappedFailureStrategy;
        private transient boolean failed;

        private OneTimeFailureStrategy(FailureStrategy wrappedFailureStrategy) {
            this.wrappedFailureStrategy = wrappedFailureStrategy;
        }

        @Override
        public boolean failOrNot(int trackingIndex) throws Exception {
            if (failed) {
                return false;
            }
            try {
                failed = wrappedFailureStrategy.failOrNot(trackingIndex);
                return failed;
            } catch (Exception e) {
                failed = true;
                throw e;
            }
        }

        @Override
        public String toString() {
            return "FailingOnce{" + wrappedFailureStrategy + '}';
        }
    }

    /** Decides, per processed record, whether a mapper should fail. */
    @FunctionalInterface
    private interface FailureStrategy extends Serializable {
        /**
         * @param trackingIndex index of the calling mapper
         * @return {@code true} if a failure was triggered without throwing
         * @throws Exception if the failure is triggered by throwing
         */
        boolean failOrNot(int trackingIndex) throws Exception;
    }
}

