/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.reactive.reactor;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Preconditions;
import org.openrewrite.Recipe;
import org.openrewrite.Tree;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.MethodCall;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.java.tree.TypeUtils;

public class ReactorDoAfterSuccessOrErrorToTap
extends Recipe {
    private static final String DEFAULT_SIGNAL_LISTENER = "reactor.core.observability.DefaultSignalListener";
    private static final String OPERATORS = "reactor.core.publisher.Operators";
    private static final String SIGNAL_TYPE = "reactor.core.publisher.SignalType";
    private static final String REACTOR_CONTEXT = "reactor.util.context.Context";
    private static final MethodMatcher DO_AFTER_SUCCESS_OR_ERROR = new MethodMatcher("reactor.core.publisher.Mono doAfterSuccessOrError(..)");
    private static final MethodMatcher DO_ON_FINALLY = new MethodMatcher("* doFinally(..)");

    public String getDisplayName() {
        return "Replace `doAfterSuccessOrError` calls with `tap` operator";
    }

    public String getDescription() {
        return "As of reactor-core 3.5 the `doAfterSuccessOrError` method is removed, this recipe replaces it with the `tap` operator.";
    }

    public TreeVisitor<?, ExecutionContext> getVisitor() {
        return Preconditions.check((TreeVisitor)new UsesMethod(DO_AFTER_SUCCESS_OR_ERROR), (TreeVisitor)new JavaIsoVisitor<ExecutionContext>(){

            public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
                J.MethodInvocation mi = super.visitMethodInvocation(method, (Object)ctx);
                if (DO_AFTER_SUCCESS_OR_ERROR.matches((MethodCall)mi)) {
                    boolean argumentLambda = mi.getArguments().get(0) instanceof J.Lambda;
                    Object[] templateArgs = new Object[]{mi.getSelect()};
                    if (!argumentLambda) {
                        templateArgs = new Expression[]{mi.getSelect(), (Expression)mi.getArguments().get(0)};
                    }
                    String signalListenerTemplate = this.newDefaultSignalListenerTemplate(mi);
                    J.MethodInvocation replacement = (J.MethodInvocation)JavaTemplate.builder((String)("#{any()}.tap(() -> " + signalListenerTemplate + ")")).contextSensitive().imports(new String[]{ReactorDoAfterSuccessOrErrorToTap.DEFAULT_SIGNAL_LISTENER, ReactorDoAfterSuccessOrErrorToTap.OPERATORS, ReactorDoAfterSuccessOrErrorToTap.SIGNAL_TYPE, ReactorDoAfterSuccessOrErrorToTap.REACTOR_CONTEXT}).javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, new String[]{"reactor-core-3.5.+", "reactive-streams-1.+"})).build().apply(this.updateCursor((Tree)mi), mi.getCoordinates().replace(), templateArgs);
                    this.maybeAddImport(ReactorDoAfterSuccessOrErrorToTap.DEFAULT_SIGNAL_LISTENER);
                    this.maybeAddImport(ReactorDoAfterSuccessOrErrorToTap.OPERATORS);
                    this.maybeAddImport(ReactorDoAfterSuccessOrErrorToTap.SIGNAL_TYPE);
                    this.maybeAddImport(ReactorDoAfterSuccessOrErrorToTap.REACTOR_CONTEXT);
                    if (argumentLambda) {
                        List originalStatements = ((J.Block)((J.Lambda)mi.getArguments().get(0)).getBody()).getStatements();
                        mi = replacement.withArguments(ListUtils.map((List)replacement.getArguments(), arg -> {
                            if (arg instanceof J.Lambda && ((J.Lambda)arg).getBody() instanceof J.NewClass) {
                                J.NewClass defaultSignalClass = (J.NewClass)((J.Lambda)arg).getBody();
                                arg = ((J.Lambda)arg).withBody((J)defaultSignalClass.withBody(defaultSignalClass.getBody().withStatements(ListUtils.map((List)defaultSignalClass.getBody().getStatements(), stmt -> {
                                    if (stmt instanceof J.MethodDeclaration) {
                                        J.ClassDeclaration cd = (J.ClassDeclaration)this.getCursor().firstEnclosing(J.ClassDeclaration.class);
                                        J.MethodDeclaration md = (J.MethodDeclaration)stmt;
                                        if (DO_ON_FINALLY.matches(md, cd)) {
                                            List newStatements = ListUtils.concatAll((List)md.getBody().getStatements(), (List)originalStatements);
                                            stmt = md.withBody(md.getBody().withStatements(newStatements));
                                        }
                                    }
                                    return stmt;
                                }))));
                            }
                            return arg;
                        }));
                    } else {
                        mi = replacement;
                    }
                    return (J.MethodInvocation)this.maybeAutoFormat((J)method, (J)mi, ctx);
                }
                return mi;
            }

            private String newDefaultSignalListenerTemplate(J.MethodInvocation doAfterSuccessOrError) {
                String clazz = TypeUtils.asFullyQualified((JavaType)((JavaType)((JavaType.Parameterized)doAfterSuccessOrError.getMethodType().getReturnType()).getTypeParameters().get(0))).getClassName();
                String result = "result";
                String error = "error";
                String impl = "#{any()}.accept(result, error);";
                Expression firstArgument = (Expression)doAfterSuccessOrError.getArguments().get(0);
                if (firstArgument instanceof J.Lambda) {
                    List doAfterSuccessOrErrorLambdaParams = ((J.Lambda)firstArgument).getParameters().getParameters().stream().map(J.VariableDeclarations.class::cast).collect(Collectors.toList());
                    result = ((J.VariableDeclarations.NamedVariable)((J.VariableDeclarations)doAfterSuccessOrErrorLambdaParams.get(0)).getVariables().get(0)).getSimpleName();
                    error = ((J.VariableDeclarations.NamedVariable)((J.VariableDeclarations)doAfterSuccessOrErrorLambdaParams.get(1)).getVariables().get(0)).getSimpleName();
                    impl = "";
                }
                return "new DefaultSignalListener<>() {    " + clazz + " " + result + ";    Throwable " + error + ";    boolean done;    boolean processedOnce;    Context currentContext;    @Override    public synchronized void doFinally(SignalType signalType) {      if (processedOnce) {          return;      }      processedOnce = true;      if (signalType == SignalType.CANCEL) {          return;      }      " + impl + "    }    @Override    public synchronized void doOnNext(" + clazz + " " + result + ") {        if (done) {            Operators.onDiscard(" + result + ", currentContext);            return;        }        this." + result + " = " + result + ";    }    @Override    public synchronized void doOnComplete() {        this.done = true;    }    @Override    public synchronized void doOnError(Throwable " + error + ") {        if (done) {            Operators.onErrorDropped(" + error + ", currentContext);            return;        }        this." + error + " = " + error + ";        this.done = true;    }    @Override    public Context addToContext(Context originalContext) {        currentContext = originalContext;        return originalContext;    }    @Override    public synchronized void doOnCancel() {        if (done) {            return;        }        this.done = true;        if (" + result + " != null) {            Operators.onDiscard(" + result + ", currentContext);        }    }}";
            }
        });
    }
}

