/*
 * Decompiled with CFR 0.152.
 */
package sootup.interceptors.typeresolving;

import com.google.common.collect.Lists;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sootup.core.graph.StmtGraph;
import sootup.core.jimple.basic.Immediate;
import sootup.core.jimple.basic.LValue;
import sootup.core.jimple.basic.Local;
import sootup.core.jimple.basic.Value;
import sootup.core.jimple.common.expr.AbstractBinopExpr;
import sootup.core.jimple.common.expr.JCastExpr;
import sootup.core.jimple.common.expr.JNegExpr;
import sootup.core.jimple.common.ref.JArrayRef;
import sootup.core.jimple.common.ref.JInstanceFieldRef;
import sootup.core.jimple.common.stmt.AbstractDefinitionStmt;
import sootup.core.jimple.common.stmt.Stmt;
import sootup.core.model.Body;
import sootup.core.types.ArrayType;
import sootup.core.types.ClassType;
import sootup.core.types.NullType;
import sootup.core.types.PrimitiveType;
import sootup.core.types.Type;
import sootup.core.views.View;
import sootup.interceptors.typeresolving.AugEvalFunction;
import sootup.interceptors.typeresolving.BytecodeHierarchy;
import sootup.interceptors.typeresolving.CastCounter;
import sootup.interceptors.typeresolving.TypePromotionVisitor;
import sootup.interceptors.typeresolving.Typing;
import sootup.interceptors.typeresolving.types.AugmentIntegerTypes;
import sootup.interceptors.typeresolving.types.BottomType;
import sootup.interceptors.typeresolving.types.TopType;
import sootup.java.core.JavaIdentifierFactory;
import sootup.java.core.views.JavaView;

/**
 * Infers types for the {@link Local}s of a method body and writes them back into the body.
 *
 * <p>Outline of {@link #resolve(Body.BodyBuilder)}:
 *
 * <ol>
 *   <li>Collect every definition statement whose left-hand side is a {@link Local} or a {@link
 *       JArrayRef}. Record, for each local, which of those statements read it.
 *   <li>Starting from an initial {@link Typing}, repeatedly replace the defined local's type by
 *       the least common ancestors of its current type and the evaluated right-hand-side type.
 *       Fork one extra typing per additional ancestor, until no statement is pending.
 *   <li>Drop typings for which {@link Typing#compare} returns {@code 1} against another typing.
 *   <li>Run {@link TypePromotionVisitor} on the remaining typings, replace underspecified types,
 *       and keep the typing whose {@link CastCounter} reports the fewest casts.
 *   <li>Insert the casts and replace every local whose type changed.
 * </ol>
 *
 * <p>The per-body bookkeeping ({@code assignments}, {@code depends}) is reset at the start of
 * every {@code resolve} call, so an instance may be reused for several bodies. It is mutated
 * without synchronization, so an instance must not be shared between threads.
 */
public class TypeResolver {
    /** Definition statements handled by the constraint solver; the list index is the statement id. */
    private final List<AbstractDefinitionStmt> assignments = new ArrayList<>();
    /** For each local, the ids of the collected statements whose right-hand side reads it. */
    private final Map<Local, BitSet> depends = new HashMap<>();
    private final JavaView view;
    /** {@code java.lang.Object}; the replacement for null/bottom/top types. */
    private final Type objectType;
    private static final Logger logger = LoggerFactory.getLogger(TypeResolver.class);

    public TypeResolver(@Nonnull JavaView view) {
        this.view = view;
        this.objectType = view.getIdentifierFactory().getClassType("java.lang.Object");
    }

    /**
     * Resolves types for all locals of {@code builder} and replaces the locals in place.
     *
     * @param builder the body whose locals are retyped; cast statements may be inserted into it
     * @return {@code false} if no candidate typing survives. This includes the case where the
     *     body has no definition statement with a local or array-element left-hand side. Returns
     *     {@code true} after the locals have been updated.
     */
    public boolean resolve(@Nonnull Body.BodyBuilder builder) {
        init(builder);
        BytecodeHierarchy hierarchy = new BytecodeHierarchy(view);
        AugEvalFunction evalFunction = new AugEvalFunction(view);
        List<Local> locals = Lists.newArrayList(builder.getLocals());
        Typing iniTyping = new Typing(locals);

        Collection<Typing> typings =
                applyAssignmentConstraint(builder.getStmtGraph(), iniTyping, evalFunction, hierarchy);
        if (typings.isEmpty()) {
            return false;
        }

        TypePromotionVisitor promotionVisitor =
                new TypePromotionVisitor(builder, evalFunction, hierarchy);
        // NOTE(review): null results are dropped defensively in case getPromotedTyping signals a
        // typing it cannot promote by returning null. Confirm against TypePromotionVisitor.
        typings =
                typings.stream()
                        .map(promotionVisitor::getPromotedTyping)
                        .filter(Objects::nonNull)
                        .collect(Collectors.toSet());
        if (typings.isEmpty()) {
            return false;
        }

        for (Typing typing : typings) {
            for (Local local : locals) {
                typing.set(local, convertUnderspecifiedType(typing.getType(local)));
            }
        }

        CastCounter minCastsCounter = getMinCastsCounter(builder, typings, evalFunction, hierarchy);
        minCastsCounter.insertCastStmts();
        Typing minCastsTyping = minCastsCounter.getTyping();

        for (Local local : locals) {
            Type type = minCastsTyping.getType(local);
            if (type == null) {
                continue;
            }
            Type convertedType = convertType(type);
            if (convertedType != null) {
                minCastsTyping.set(local, convertedType);
            }
        }

        for (Local local : locals) {
            Type oldType = local.getType();
            Type newType = minCastsTyping.getMap().getOrDefault(local, oldType);
            // Compare by value, not by reference. convertUnderspecifiedType rebuilds every array
            // type through Type.createArrayType, so a reference check would also replace array
            // locals whose type did not actually change.
            if (!Objects.equals(newType, oldType)) {
                builder.replaceLocal(local, local.withType(newType));
            }
        }
        return true;
    }

    /**
     * Collects the definition statements of {@code builder} that define a {@link Local} or an
     * array element, and records which locals each of them reads. Clears any state left over
     * from a previous {@code resolve} call first.
     */
    private void init(Body.BodyBuilder builder) {
        assignments.clear();
        depends.clear();
        for (Stmt stmt : builder.getStmtGraph()) {
            if (!(stmt instanceof AbstractDefinitionStmt)) {
                continue;
            }
            AbstractDefinitionStmt defStmt = (AbstractDefinitionStmt) stmt;
            LValue lhs = defStmt.getLeftOp();
            if (!(lhs instanceof Local) && !(lhs instanceof JArrayRef)) {
                continue;
            }
            int defStmtId = assignments.size();
            assignments.add(defStmt);
            addDependsForRHS(defStmt.getRightOp(), defStmtId);
        }
    }

    /**
     * Marks statement {@code id} as dependent on every local read directly by {@code rhs}.
     *
     * <p>Only these cases are inspected: a bare local, both binary-operator operands, the operand
     * of a negation or a cast, and the base of an array access.
     */
    private void addDependsForRHS(Value rhs, int id) {
        if (rhs instanceof Local) {
            addDependency((Local) rhs, id);
        } else if (rhs instanceof AbstractBinopExpr) {
            Immediate op1 = ((AbstractBinopExpr) rhs).getOp1();
            Immediate op2 = ((AbstractBinopExpr) rhs).getOp2();
            if (op1 instanceof Local) {
                addDependency((Local) op1, id);
            }
            if (op2 instanceof Local) {
                addDependency((Local) op2, id);
            }
        } else if (rhs instanceof JNegExpr) {
            Immediate op = ((JNegExpr) rhs).getOp();
            if (op instanceof Local) {
                addDependency((Local) op, id);
            }
        } else if (rhs instanceof JCastExpr) {
            Immediate op = ((JCastExpr) rhs).getOp();
            if (op instanceof Local) {
                addDependency((Local) op, id);
            }
        } else if (rhs instanceof JArrayRef) {
            addDependency(((JArrayRef) rhs).getBase(), id);
        }
    }

    /** Records that statement {@code id} reads {@code local}. */
    private void addDependency(@Nonnull Local local, int id) {
        depends.computeIfAbsent(local, k -> new BitSet()).set(id);
    }

    /**
     * Solves the assignment constraints with a work list of typings.
     *
     * <p>Each typing carries a {@link BitSet} of pending statement ids. Processing one id works as
     * follows:
     *
     * <ul>
     *   <li>Evaluate the statement's right-hand side under that typing.
     *   <li>Replace the type of the defined local (or the array base for {@code a[i] = x}) by
     *       each least common ancestor of its old type and the right-hand-side type.
     *   <li>The first differing ancestor updates the current typing in place. Every further one
     *       forks a copy that is appended to the queue.
     *   <li>Statements that read the changed local are then re-scheduled.
     * </ul>
     *
     * <p>A typing whose right-hand side cannot be evaluated ({@code evaluate} returns
     * {@code null}) is discarded.
     *
     * @return the typings with no pending statements, after {@link #minimize}. The result is
     *     empty if there are no collected statements or every typing was discarded.
     */
    private Collection<Typing> applyAssignmentConstraint(
            @Nonnull StmtGraph<?> graph,
            @Nonnull Typing typing,
            @Nonnull AugEvalFunction evalFunction,
            @Nonnull BytecodeHierarchy hierarchy) {
        int numOfAssignments = assignments.size();
        if (numOfAssignments == 0) {
            return Collections.emptyList();
        }
        ArrayDeque<Typing> workQueue = new ArrayDeque<>();
        List<Typing> ret = new ArrayList<>();
        BitSet stmtsList = new BitSet(numOfAssignments);
        stmtsList.set(0, numOfAssignments);
        typing.setStmtsIDList(stmtsList);
        workQueue.add(typing);

        while (!workQueue.isEmpty()) {
            Typing actualTyping = workQueue.getFirst();
            BitSet actualSL = actualTyping.getStmtsIDList();
            int stmtId = actualSL.nextSetBit(0);
            if (stmtId == -1) {
                // Nothing pending: this typing is complete.
                ret.add(actualTyping);
                workQueue.removeFirst();
                continue;
            }
            actualSL.clear(stmtId);

            AbstractDefinitionStmt defStmt = assignments.get(stmtId);
            LValue lhs = defStmt.getLeftOp();
            Local local;
            if (lhs instanceof Local) {
                local = (Local) lhs;
            } else if (lhs instanceof JArrayRef) {
                local = ((JArrayRef) lhs).getBase();
            } else if (lhs instanceof JInstanceFieldRef) {
                continue;
            } else {
                throw new IllegalStateException("can not handle " + lhs.getClass());
            }

            Type rhsType = evalFunction.evaluate(actualTyping, defStmt.getRightOp(), defStmt, graph);
            if (rhsType == null) {
                workQueue.removeFirst();
                continue;
            }
            Type oldType = actualTyping.getType(local);
            if (oldType == null) {
                logger.info("Body.locals do not match the Locals occurring in the Stmts.");
                continue;
            }

            Collection<Type> leastCommonAncestors;
            if (lhs instanceof JArrayRef) {
                if (oldType instanceof ArrayType) {
                    Type elementType = ((ArrayType) oldType).getElementType();
                    // Arrays with a primitive element type are left unchanged.
                    if (elementType instanceof PrimitiveType) {
                        continue;
                    }
                    leastCommonAncestors =
                            hierarchy.getLeastCommonAncestors(elementType, rhsType).stream()
                                    .map(ancestor -> Type.createArrayType(ancestor, 1))
                                    .collect(Collectors.<Type>toSet());
                } else {
                    leastCommonAncestors =
                            hierarchy.getLeastCommonAncestors(oldType, Type.createArrayType(rhsType, 1));
                }
            } else {
                leastCommonAncestors = hierarchy.getLeastCommonAncestors(oldType, rhsType);
            }
            assert !leastCommonAncestors.isEmpty();

            boolean isFirstType = true;
            for (Type candidate : leastCommonAncestors) {
                if (candidate.equals(oldType)) {
                    continue;
                }
                if (isFirstType) {
                    isFirstType = false;
                } else {
                    // Fork: the copy starts with the current pending set.
                    actualTyping = new Typing(actualTyping, (BitSet) actualSL.clone());
                    workQueue.add(actualTyping);
                    actualSL = actualTyping.getStmtsIDList();
                }
                actualTyping.set(local, candidate);
                BitSet dependStmtList = depends.get(local);
                if (dependStmtList != null) {
                    actualSL.or(dependStmtList);
                }
            }
        }
        minimize(ret, hierarchy);
        return ret;
    }

    /**
     * Removes from {@code typings} every typing for which {@link Typing#compare} returns {@code 1}
     * against some typing of the original list.
     *
     * <p>Locals whose types across all typings are exactly {@code Object}, {@code Serializable}
     * and {@code Cloneable} are passed to {@code compare} as locals to ignore.
     */
    private void minimize(@Nonnull List<Typing> typings, @Nonnull BytecodeHierarchy hierarchy) {
        JavaIdentifierFactory identifierFactory = view.getIdentifierFactory();
        Set<ClassType> objectLikeTypes = new HashSet<>();
        objectLikeTypes.add(identifierFactory.getClassType("java.lang.Object"));
        objectLikeTypes.add(identifierFactory.getClassType("java.io.Serializable"));
        objectLikeTypes.add(identifierFactory.getClassType("java.lang.Cloneable"));

        Set<Local> objectLikeLocals = new HashSet<>();
        for (Map.Entry<Local, Set<Type>> entry : getLocal2Types(typings).entrySet()) {
            if (entry.getValue().equals(objectLikeTypes)) {
                objectLikeLocals.add(entry.getKey());
            }
        }

        // Compare against an unchanged snapshot so removals do not affect later comparisons.
        List<Typing> snapshot = new ArrayList<>(typings);
        typings.removeIf(
                candidate ->
                        snapshot.stream()
                                .anyMatch(other -> candidate.compare(other, hierarchy, objectLikeLocals) == 1));
    }

    /** Collects, for every local, the set of types it has across all given typings. */
    private Map<Local, Set<Type>> getLocal2Types(@Nonnull List<Typing> typings) {
        Map<Local, Set<Type>> map = new HashMap<>();
        for (Typing typing : typings) {
            for (Local local : typing.getLocals()) {
                map.computeIfAbsent(local, k -> new HashSet<>()).add(typing.getType(local));
            }
        }
        return map;
    }

    /**
     * Builds a {@link CastCounter} for every typing and returns the one with the smallest cast
     * count.
     *
     * @throws IllegalStateException if {@code typings} is empty
     */
    private CastCounter getMinCastsCounter(
            @Nonnull Body.BodyBuilder builder,
            @Nonnull Collection<Typing> typings,
            @Nonnull AugEvalFunction evalFunction,
            @Nonnull BytecodeHierarchy hierarchy) {
        return typings.stream()
                .map(typing -> new CastCounter(builder, evalFunction, hierarchy, typing))
                .min(Comparator.comparingInt(CastCounter::getCastCount))
                .orElseThrow(() -> new IllegalStateException("no candidate typing to choose from"));
    }

    /**
     * Replaces types that cannot appear in a final body:
     *
     * <ul>
     *   <li>null, bottom and top types become {@code java.lang.Object};
     *   <li>augmented integer types become their primitive;
     *   <li>array types are rebuilt from their converted element type.
     * </ul>
     *
     * Any other type is returned unchanged.
     */
    private Type convertUnderspecifiedType(@Nonnull Type type) {
        if (type instanceof ArrayType) {
            Type elementType = convertUnderspecifiedType(((ArrayType) type).getElementType());
            return Type.createArrayType(elementType, 1);
        }
        if (type instanceof NullType || type instanceof BottomType || type instanceof TopType) {
            return objectType;
        }
        Type primitive = convertAugmentedInteger(type);
        return primitive != null ? primitive : type;
    }

    /**
     * Converts augmented integer types (also as array element types) to their primitive.
     *
     * @return the converted type, or {@code null} if {@code type} needs no conversion
     */
    private Type convertType(@Nonnull Type type) {
        Type primitive = convertAugmentedInteger(type);
        if (primitive != null) {
            return primitive;
        }
        if (type instanceof ArrayType) {
            Type elementType = convertType(((ArrayType) type).getElementType());
            return elementType != null ? Type.createArrayType(elementType, 1) : null;
        }
        return null;
    }

    /**
     * Maps the augmented integer types to Java primitives: {@code Integer1Type} to boolean,
     * {@code Integer127Type} to byte, and {@code Integer32767Type} to short.
     *
     * @return the primitive, or {@code null} for any other type
     */
    private static Type convertAugmentedInteger(Type type) {
        if (type instanceof AugmentIntegerTypes.Integer1Type) {
            return PrimitiveType.getBoolean();
        }
        if (type instanceof AugmentIntegerTypes.Integer127Type) {
            return PrimitiveType.getByte();
        }
        if (type instanceof AugmentIntegerTypes.Integer32767Type) {
            return PrimitiveType.getShort();
        }
        return null;
    }
}

