/*
 * Decompiled with CFR 0.152.
 */
package qilin.pta.toolkits.zipper.analysis;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import qilin.core.PTA;
import qilin.core.pag.ContextField;
import qilin.core.pag.LocalVarNode;
import qilin.core.pag.Node;
import qilin.core.pag.ValNode;
import qilin.core.pag.VarNode;
import qilin.pta.toolkits.common.OAG;
import qilin.pta.toolkits.common.ToolUtil;
import qilin.pta.toolkits.zipper.Global;
import qilin.pta.toolkits.zipper.analysis.PotentialContextElement;
import qilin.pta.toolkits.zipper.flowgraph.FlowAnalysis;
import qilin.pta.toolkits.zipper.flowgraph.ObjectFlowGraph;
import qilin.util.ANSIColor;
import qilin.util.Stopwatch;
import qilin.util.graph.ConcurrentDirectedGraphImpl;
import sootup.core.model.SootClass;
import sootup.core.model.SootClassMember;
import sootup.core.model.SootMethod;
import sootup.core.types.ClassType;
import sootup.core.types.Type;

/**
 * Builds an Object Flow Graph (OFG) for a finished pointer analysis and, for every
 * allocated class type, runs a {@link FlowAnalysis} to build a Pollution Flow Graph (PFG)
 * and derive the set of precision-critical methods (PCMs).
 *
 * <p>Per-type results are accumulated in atomic counters, a {@link ConcurrentHashMap} and a
 * {@link ConcurrentDirectedGraphImpl} because {@link #analyze()} may dispatch the per-type
 * work to a thread pool (see {@code Global.getThread()}).
 */
public class Zipper {
    private final PTA pta;
    private final PotentialContextElement pce;
    private final ObjectFlowGraph ofg;
    /** Number of class types analyzed since the last {@link #reset()}. */
    private final AtomicInteger analyzedClasses = new AtomicInteger(0);
    /** Sum of the node counts of all per-type PFGs since the last {@link #reset()}. */
    private final AtomicInteger totalPFGNodes = new AtomicInteger(0);
    /** Sum of the edge counts of all per-type PFGs since the last {@link #reset()}. */
    private final AtomicInteger totalPFGEdges = new AtomicInteger(0);
    /** Union of every per-type PFG; note that {@link #reset()} does not clear it. */
    private final ConcurrentDirectedGraphImpl<Node> overallPFG = new ConcurrentDirectedGraphImpl<>();
    /** Accumulated points-to set size of the local variables of each method. */
    private final Map<SootMethod, Integer> methodPts;
    /** Precision-critical methods found for each analyzed type. */
    private final Map<Type, Collection<SootMethod>> pcmMap = new ConcurrentHashMap<>(1024);

    /**
     * Builds the object allocation graph, the potential-context-element information, the
     * OFG and the per-method points-to sizes from the given analysis result.
     *
     * @param pta the pointer analysis whose results are inspected
     */
    public Zipper(PTA pta) {
        this.pta = pta;
        OAG oag = new OAG(pta);
        oag.build();
        System.out.println("#OAG:" + oag.allNodes().size());
        this.pce = new PotentialContextElement(pta, oag);
        this.ofg = this.buildObjectFlowGraph();
        this.methodPts = this.getMethodPointsToSize();
    }

    /**
     * Prints the number of distinct types among the allocation nodes of {@code pta}.
     *
     * @param pta the pointer analysis to inspect
     */
    public static void outputNumberOfClasses(PTA pta) {
        int nrClasses = (int)pta.getPag().getAllocNodes().stream().map(Node::getType).distinct().count();
        System.out.println("#classes: \u001b[1m\u001b[32m" + nrClasses + "\u001b[0m");
        System.out.println();
    }

    /**
     * @return the number of nodes in the merged PFG
     */
    public int numberOfOverallPFGNodes() {
        return this.overallPFG.allNodes().size();
    }

    /**
     * @return the number of edges in the merged PFG, counted as the sum of every node's
     *         successor count
     */
    public int numberOfOverallPFGEdges() {
        int nrEdges = 0;
        for (Node node : this.overallPFG.allNodes()) {
            nrEdges += this.overallPFG.succsOf(node).size();
        }
        return nrEdges;
    }

    /**
     * Constructs the OFG for the wrapped pointer analysis, printing the elapsed time and
     * the resulting graph size.
     *
     * @return the newly built object flow graph
     */
    public ObjectFlowGraph buildObjectFlowGraph() {
        Stopwatch ofgTimer = Stopwatch.newAndStart("Object Flow Graph Timer");
        System.out.println("Building OFG (Object Flow Graph) ... ");
        ObjectFlowGraph ofg = new ObjectFlowGraph(this.pta);
        ofgTimer.stop();
        System.out.println(ofgTimer);
        Zipper.outputObjectFlowGraphSize(ofg);
        return ofg;
    }

    /**
     * Prints the node and edge counts of {@code ofg}.
     *
     * @param ofg the object flow graph to measure
     */
    public static void outputObjectFlowGraphSize(ObjectFlowGraph ofg) {
        int nrNodes = ofg.allNodes().size();
        int nrEdges = 0;
        for (Node node : ofg.allNodes()) {
            nrEdges += ofg.outEdgesOf(node).size();
        }
        System.out.println("#nodes in OFG: \u001b[1m\u001b[32m" + nrNodes + "\u001b[0m");
        System.out.println("#edges in OFG: \u001b[1m\u001b[32m" + nrEdges + "\u001b[0m");
        System.out.println();
    }

    /**
     * Analyzes every allocated {@link ClassType} (in order of its string form) and collects
     * the precision-critical methods.
     *
     * <p>The work runs sequentially when {@code Global.getThread()} is {@code -1}, otherwise
     * on a fixed thread pool of that size. Statistics are printed to standard output.
     *
     * @return the precision-critical methods that passed the threshold filter
     */
    public Set<SootMethod> analyze() {
        this.reset();
        System.out.println("Building PFGs (Pollution Flow Graphs) and computing precision-critical methods ...");
        List<ClassType> types = this.pta.getPag().getAllocNodes().stream()
                .map(Node::getType)
                .distinct()
                .sorted(Comparator.comparing(Object::toString))
                .filter(t -> t instanceof ClassType)
                .map(t -> (ClassType)t)
                .collect(Collectors.toList());
        if (Global.getThread() == -1) {
            this.computePCM(types);
        } else {
            this.computePCMConcurrent(types, Global.getThread());
        }
        // With zero analyzed classes the quotient is NaN, which Math.round maps to 0.
        float nrClasses = (float)this.analyzedClasses.get();
        System.out.println("#avg. nodes in PFG: \u001b[1m\u001b[32m" + Math.round(this.totalPFGNodes.floatValue() / nrClasses) + "\u001b[0m");
        System.out.println("#avg. edges in PFG: \u001b[1m\u001b[32m" + Math.round(this.totalPFGEdges.floatValue() / nrClasses) + "\u001b[0m");
        System.out.println("#Node:" + this.totalPFGNodes.intValue());
        System.out.println("#Edge:" + this.totalPFGEdges.intValue());
        System.out.println("#Node2:" + this.numberOfOverallPFGNodes());
        System.out.println("#Edge2:" + this.numberOfOverallPFGEdges());
        System.out.println();
        Set<SootMethod> pcm = this.collectAllPrecisionCriticalMethods(this.pcmMap, this.computePCMThreshold());
        System.out.println("#Precision-critical methods: \u001b[1m\u001b[32m" + pcm.size() + "\u001b[0m");
        return pcm;
    }

    /** Analyzes all types one after another, sharing a single {@link FlowAnalysis}. */
    private void computePCM(List<ClassType> types) {
        FlowAnalysis fa = new FlowAnalysis(this.pta, this.pce, this.ofg);
        for (ClassType type : types) {
            this.analyze(type, fa);
        }
    }

    /**
     * Analyzes all types on a fixed pool of {@code nThread} threads; every task uses its
     * own {@link FlowAnalysis}. Blocks until all tasks have finished.
     *
     * @throws RuntimeException wrapping the {@link InterruptedException} if the calling
     *         thread is interrupted while waiting
     */
    private void computePCMConcurrent(List<ClassType> types, int nThread) {
        ExecutorService executorService = Executors.newFixedThreadPool(nThread);
        for (ClassType type : types) {
            executorService.execute(() -> {
                FlowAnalysis fa = new FlowAnalysis(this.pta, this.pce, this.ofg);
                this.analyze(type, fa);
            });
        }
        executorService.shutdown();
        try {
            executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        }
        catch (InterruptedException e) {
            // Stop outstanding work and keep the interrupt visible to callers up the stack.
            executorService.shutdownNow();
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs the flow analysis for a single type and merges its results into the shared
     * counters, {@link #pcmMap} and {@link #overallPFG}.
     *
     * @param type the class type to analyze
     * @param fa   the flow analysis instance to (re)use; it is cleared before returning
     */
    private void analyze(ClassType type, FlowAnalysis fa) {
        if (Global.isDebug()) {
            System.out.println("----------------------------------------");
        }
        // All methods invoked on any allocation node of exactly this type.
        Set<SootMethod> ms = this.pta.getPag().getAllocNodes().stream()
                .filter(o -> o.getType().equals(type))
                .map(this.pce::methodsInvokedOn)
                .flatMap(Collection::stream)
                .collect(Collectors.toSet());
        // In methods: non-private, with at least one parameter whose points-to set is non-empty.
        Set<SootMethod> inms = ms.stream()
                .filter(m -> !m.isPrivate())
                .filter(m -> ToolUtil.getParameters(this.pta.getPag(), m).stream()
                        .anyMatch(p -> !this.pta.reachingObjects((Node)p).toCIPointsToSet().isEmpty()))
                .collect(Collectors.toSet());
        // Out methods: non-private, with at least one return variable whose points-to set is non-empty ...
        Set<SootMethod> outms = new HashSet<>();
        ms.stream()
                .filter(m -> !m.isPrivate())
                .filter(m -> ToolUtil.getRetVars(this.pta.getPag(), m).stream()
                        .anyMatch(r -> !this.pta.reachingObjects((Node)r).toCIPointsToSet().isEmpty()))
                .forEach(outms::add);
        // ... plus non-private instance PCE methods that are declared in an inner class of the
        // type (or of one of its superclasses), or are "access$" methods declared on the type itself.
        this.pce.PCEMethodsOf((Type)type).stream()
                .filter(m -> !m.isPrivate() && !m.isStatic())
                .filter(m -> this.isInnerType(m.getDeclaringClassType(), type)
                        || (m.getDeclaringClassType().equals(type) && m.toString().contains("access$")))
                .forEach(outms::add);
        if (Global.isDebug()) {
            this.printMethods("In methods:", inms);
            this.printMethods("Out methods:", outms);
        }
        fa.initialize((Type)type, inms, outms);
        inms.forEach(fa::analyze);
        Set<Node> flowNodes = fa.getFlowNodes();
        Set<SootMethod> precisionCriticalMethods = this.getPrecisionCriticalMethods(type, flowNodes);
        if (Global.isDebug() && !precisionCriticalMethods.isEmpty()) {
            System.out.println(ANSIColor.color("\u001b[34m", "Flow found: ") + type);
        }
        this.mergeAnalysisResults(type, fa.numberOfPFGNodes(), fa.numberOfPFGEdges(), precisionCriticalMethods);
        this.mergeSinglePFG(fa.getPFG());
        fa.clear();
    }

    /** Prints a colored header followed by the given methods sorted by their string form. */
    private void printMethods(String header, Set<SootMethod> methods) {
        System.out.println(ANSIColor.color("\u001b[33m", header));
        methods.stream().sorted(Comparator.comparing(SootClassMember::toString)).forEach(m -> System.out.println("  " + m));
    }

    /**
     * Tests, purely by name, whether {@code pInner} is nested inside {@code pOuter} or inside
     * one of {@code pOuter}'s superclasses: the name of {@code pInner} must start with the
     * name of such a class followed by {@code '$'}.
     *
     * @param pInner the candidate nested type
     * @param pOuter the type whose superclass chain is walked
     * @return {@code true} if a matching prefix is found before the chain ends
     */
    public boolean isInnerType(ClassType pInner, ClassType pOuter) {
        String pInnerStr = pInner.toString();
        while (!pInnerStr.startsWith(pOuter.toString() + "$")) {
            // NOTE(review): get() assumes the view can resolve every class on the chain — confirm.
            SootClass sc = (SootClass)this.pta.getView().getClass(pOuter).get();
            if (!sc.hasSuperclass()) {
                return false;
            }
            pOuter = (ClassType)sc.getSuperclass().get();
        }
        return true;
    }

    /** Copies every node and edge of {@code pfg} into {@link #overallPFG}. */
    private void mergeSinglePFG(ConcurrentDirectedGraphImpl<Node> pfg) {
        for (Node node : pfg.allNodes()) {
            this.overallPFG.addNode(node);
            for (Node succ : pfg.succsOf(node)) {
                this.overallPFG.addEdge(node, succ);
            }
        }
    }

    /** Records the outcome of one per-type analysis in the shared counters and {@link #pcmMap}. */
    private void mergeAnalysisResults(Type type, int nrPFGNodes, int nrPFGEdges, Set<SootMethod> precisionCriticalMethods) {
        this.analyzedClasses.incrementAndGet();
        this.totalPFGNodes.addAndGet(nrPFGNodes);
        this.totalPFGEdges.addAndGet(nrPFGEdges);
        this.pcmMap.put(type, new ArrayList<>(precisionCriticalMethods));
    }

    /**
     * Unions the per-type PCM sets. When {@code Global.isExpress()} is set, a type whose
     * PCMs have an accumulated points-to size above {@code pcmThreshold} is reported and
     * left out.
     *
     * @param pcmMap       per-type precision-critical methods
     * @param pcmThreshold upper bound on the accumulated points-to size in express mode
     * @return the union of all retained PCM sets
     */
    private Set<SootMethod> collectAllPrecisionCriticalMethods(Map<Type, Collection<SootMethod>> pcmMap, int pcmThreshold) {
        System.out.println("PCM Threshold:" + pcmThreshold);
        Set<SootMethod> pcm = new HashSet<>();
        pcmMap.forEach((type, pcms) -> {
            if (Global.isExpress()) {
                long accumulativePTSize = this.getAccumulativePointsToSetSize(pcms);
                if (accumulativePTSize > (long)pcmThreshold) {
                    System.out.println("type: " + type + ", accumulativePTSize: " + accumulativePTSize);
                    return;
                }
            }
            pcm.addAll(pcms);
        });
        return pcm;
    }

    /**
     * Computes the express-mode threshold as {@code Global.getExpressThreshold()} times the
     * total context-insensitive points-to set size of all {@link VarNode}s.
     */
    private int computePCMThreshold() {
        // Accumulate in long so that large programs cannot overflow the running total.
        long totalPTSSize = 0L;
        for (ValNode var : this.pta.getPag().getValNodes()) {
            if (!(var instanceof VarNode)) continue;
            VarNode varNode = (VarNode)var;
            totalPTSSize += this.pta.reachingObjects(varNode).toCIPointsToSet().size();
        }
        return (int)(Global.getExpressThreshold() * (float)totalPTSSize);
    }

    /**
     * Maps the flow nodes to their containing methods and keeps those that are PCE methods
     * of {@code type}.
     */
    private Set<SootMethod> getPrecisionCriticalMethods(Type type, Set<Node> nodes) {
        return nodes.stream()
                .map(this::node2ContainingMethod)
                .filter(Objects::nonNull)
                .filter(this.pce.PCEMethodsOf(type)::contains)
                .collect(Collectors.toSet());
    }

    /**
     * Returns the method of a {@link LocalVarNode}, or the method of the base of a
     * {@link ContextField}.
     *
     * @throws ClassCastException if {@code node} is neither of the two node kinds
     */
    private SootMethod node2ContainingMethod(Node node) {
        if (node instanceof LocalVarNode) {
            return ((LocalVarNode)node).getMethod();
        }
        ContextField ctxField = (ContextField)node;
        return ctxField.getBase().getMethod();
    }

    /** Clears the counters and the per-type PCM map; {@link #overallPFG} is left untouched. */
    private void reset() {
        this.analyzedClasses.set(0);
        this.totalPFGNodes.set(0);
        this.totalPFGEdges.set(0);
        this.pcmMap.clear();
    }

    /**
     * Sums, for each method, the points-to set sizes (as reported by
     * {@code ToolUtil.pointsToSetSizeOf}) of all its {@link LocalVarNode}s.
     */
    private Map<SootMethod, Integer> getMethodPointsToSize() {
        Map<SootMethod, Integer> results = new HashMap<>();
        for (ValNode valnode : this.pta.getPag().getValNodes()) {
            if (!(valnode instanceof LocalVarNode)) continue;
            LocalVarNode lvn = (LocalVarNode)valnode;
            results.merge(lvn.getMethod(), ToolUtil.pointsToSetSizeOf(this.pta, lvn), Integer::sum);
        }
        return results;
    }

    /**
     * Sums the recorded points-to sizes of {@code methods}. A method without an entry in
     * {@link #methodPts} (no local variable node was seen for it) contributes zero.
     */
    private long getAccumulativePointsToSetSize(Collection<SootMethod> methods) {
        return methods.stream().mapToLong(m -> this.methodPts.getOrDefault(m, 0)).sum();
    }
}

