/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.batch.integration.partition;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.batch.core.Entity;
import org.springframework.batch.core.JobExecution;
import org.springframework.batch.core.StepExecution;
import org.springframework.batch.core.explore.JobExplorer;
import org.springframework.batch.core.explore.support.JobExplorerFactoryBean;
import org.springframework.batch.core.partition.support.AbstractPartitionHandler;
import org.springframework.batch.integration.partition.StepExecutionRequest;
import org.springframework.batch.poller.DirectPoller;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.integration.MessageTimeoutException;
import org.springframework.integration.annotation.Aggregator;
import org.springframework.integration.annotation.MessageEndpoint;
import org.springframework.integration.annotation.Payloads;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.core.MessagingTemplate;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.PollableChannel;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

@MessageEndpoint
public class MessageChannelPartitionHandler
extends AbstractPartitionHandler
implements InitializingBean {
    private static final Log logger = LogFactory.getLog(MessageChannelPartitionHandler.class);
    private MessagingTemplate messagingGateway;
    private String stepName;
    private long pollInterval = 10000L;
    private JobExplorer jobExplorer;
    private boolean pollRepositoryForResults = false;
    private long timeout = -1L;
    private DataSource dataSource;
    private PollableChannel replyChannel;

    public void afterPropertiesSet() throws Exception {
        Assert.state((this.stepName != null ? 1 : 0) != 0, (String)"A step name must be provided for the remote workers.");
        Assert.state((this.messagingGateway != null ? 1 : 0) != 0, (String)"The MessagingOperations must be set");
        boolean bl = this.pollRepositoryForResults = this.dataSource != null || this.jobExplorer != null;
        if (this.pollRepositoryForResults) {
            logger.debug((Object)"MessageChannelPartitionHandler is configured to poll the job repository for worker results");
        }
        if (this.dataSource != null && this.jobExplorer == null) {
            JobExplorerFactoryBean jobExplorerFactoryBean = new JobExplorerFactoryBean();
            jobExplorerFactoryBean.setDataSource(this.dataSource);
            jobExplorerFactoryBean.afterPropertiesSet();
            this.jobExplorer = jobExplorerFactoryBean.getObject();
        }
        if (!this.pollRepositoryForResults && this.replyChannel == null) {
            this.replyChannel = new QueueChannel();
        }
    }

    public void setTimeout(long timeout) {
        this.timeout = timeout;
    }

    public void setJobExplorer(JobExplorer jobExplorer) {
        this.jobExplorer = jobExplorer;
    }

    public void setPollInterval(long pollInterval) {
        this.pollInterval = pollInterval;
    }

    public void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    public void setMessagingOperations(MessagingTemplate messagingGateway) {
        this.messagingGateway = messagingGateway;
    }

    public void setStepName(String stepName) {
        this.stepName = stepName;
    }

    @Aggregator(sendPartialResultsOnExpiry="true")
    public List<?> aggregate(@Payloads List<?> messages) {
        return messages;
    }

    public void setReplyChannel(PollableChannel replyChannel) {
        this.replyChannel = replyChannel;
    }

    protected Set<StepExecution> doHandle(StepExecution managerStepExecution, Set<StepExecution> partitionStepExecutions) throws Exception {
        if (CollectionUtils.isEmpty(partitionStepExecutions)) {
            return partitionStepExecutions;
        }
        int count = 0;
        for (StepExecution stepExecution : partitionStepExecutions) {
            Message<StepExecutionRequest> request = this.createMessage(count++, partitionStepExecutions.size(), new StepExecutionRequest(this.stepName, stepExecution.getJobExecutionId(), stepExecution.getId()), this.replyChannel);
            if (logger.isDebugEnabled()) {
                logger.debug((Object)("Sending request: " + String.valueOf(request)));
            }
            this.messagingGateway.send(request);
        }
        if (!this.pollRepositoryForResults) {
            return this.receiveReplies(this.replyChannel);
        }
        return this.pollReplies(managerStepExecution, partitionStepExecutions);
    }

    private Set<StepExecution> pollReplies(StepExecution managerStepExecution, Set<StepExecution> split) throws Exception {
        Set partitionStepExecutionIds = split.stream().map(Entity::getId).collect(Collectors.toSet());
        Callable<Set> callback = () -> {
            JobExecution jobExecution = this.jobExplorer.getJobExecution(managerStepExecution.getJobExecutionId());
            Set finishedStepExecutions = jobExecution.getStepExecutions().stream().filter(stepExecution -> partitionStepExecutionIds.contains(stepExecution.getId())).filter(stepExecution -> !stepExecution.getStatus().isRunning()).collect(Collectors.toSet());
            if (logger.isDebugEnabled()) {
                logger.debug((Object)String.format("Currently waiting on %s partitions to finish", split.size()));
            }
            if (finishedStepExecutions.size() == split.size()) {
                return finishedStepExecutions;
            }
            return null;
        };
        DirectPoller poller = new DirectPoller(this.pollInterval);
        Future resultsFuture = poller.poll(callback);
        if (this.timeout >= 0L) {
            return (Set)resultsFuture.get(this.timeout, TimeUnit.MILLISECONDS);
        }
        return (Set)resultsFuture.get();
    }

    private Set<StepExecution> receiveReplies(PollableChannel currentReplyChannel) {
        Collection payload;
        Message message = this.messagingGateway.receive((Object)currentReplyChannel);
        if (message == null) {
            throw new MessageTimeoutException("Timeout occurred before all partitions returned");
        }
        if (logger.isDebugEnabled()) {
            logger.debug((Object)("Received replies: " + String.valueOf(message)));
        }
        return (payload = (Collection)message.getPayload()) instanceof Set ? (Set)payload : new HashSet((Collection)message.getPayload());
    }

    private Message<StepExecutionRequest> createMessage(int sequenceNumber, int sequenceSize, StepExecutionRequest stepExecutionRequest, PollableChannel replyChannel) {
        return ((MessageBuilder)((MessageBuilder)((MessageBuilder)((MessageBuilder)MessageBuilder.withPayload((Object)stepExecutionRequest).setSequenceNumber(Integer.valueOf(sequenceNumber))).setSequenceSize(Integer.valueOf(sequenceSize))).setCorrelationId((Object)(stepExecutionRequest.getJobExecutionId() + ":" + stepExecutionRequest.getStepName()))).setReplyChannel((MessageChannel)replyChannel)).build();
    }
}

