/*
 * Decompiled with CFR 0.152.
 */
package org.legendofdragoon.scripting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.OptionalInt;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Predicate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.Marker;
import org.apache.logging.log4j.MarkerManager;
import org.legendofdragoon.scripting.OpType;
import org.legendofdragoon.scripting.ParameterType;
import org.legendofdragoon.scripting.State;
import org.legendofdragoon.scripting.StringInfo;
import org.legendofdragoon.scripting.meta.Meta;
import org.legendofdragoon.scripting.tokens.Data;
import org.legendofdragoon.scripting.tokens.Entry;
import org.legendofdragoon.scripting.tokens.Entrypoint;
import org.legendofdragoon.scripting.tokens.LodString;
import org.legendofdragoon.scripting.tokens.Op;
import org.legendofdragoon.scripting.tokens.Param;
import org.legendofdragoon.scripting.tokens.PointerTable;
import org.legendofdragoon.scripting.tokens.Script;

public class Disassembler {
    private static final Logger LOGGER = LogManager.getFormatterLogger();
    private static final Marker DISASSEMBLY = MarkerManager.getMarker((String)"DISASSEMBLY");
    private final Meta meta;
    private State state;

    public Disassembler(Meta meta) {
        this.meta = meta;
    }

    public Script disassemble(byte[] bytes, int[] extraBranches) {
        this.state = new State(bytes);
        Script script = new Script(this.state.length() / 4);
        this.getEntrypoints(script);
        for (int entrypoint : script.entrypoints) {
            this.probeBranch(script, entrypoint);
        }
        block1: for (int entryIndex = 0; entryIndex < script.entries.length; ++entryIndex) {
            Entry entry = script.entries[entryIndex];
            if (!(entry instanceof PointerTable)) continue;
            PointerTable rel = (PointerTable)entry;
            ++entryIndex;
            for (int labelIndex = 1; labelIndex < rel.labels.length; ++labelIndex) {
                if (script.entries[entryIndex] != null && !(script.entries[entryIndex] instanceof Data)) {
                    LOGGER.warn("Jump table overrun at %x", (Object)entry.address);
                    for (int toRemove = labelIndex; toRemove < rel.labels.length; ++toRemove) {
                        if (script.labelUsageCount.get(rel.labels[toRemove]) > 1) continue;
                        for (List<String> labels : script.labels.values()) {
                            labels.remove(rel.labels[toRemove]);
                        }
                    }
                    rel.labels = Arrays.copyOfRange(rel.labels, 0, labelIndex);
                    continue block1;
                }
                ++entryIndex;
            }
        }
        for (int extraBranch : extraBranches) {
            this.probeBranch(script, extraBranch);
        }
        script.buildStrings.forEach(Runnable::run);
        this.fillStrings(script);
        this.fillData(script);
        LOGGER.info(DISASSEMBLY, "Probing complete");
        return script;
    }

    /*
     * Enabled aggressive block sorting
     */
    private void probeBranch(Script script, int offset) {
        if (script.branches.contains(offset)) {
            return;
        }
        LOGGER.info(DISASSEMBLY, "Probing branch %x", (Object)offset);
        script.branches.add(offset);
        int oldHeaderOffset = this.state.headerOffset();
        int oldCurrentOffset = this.state.currentOffset();
        this.state.jump(offset);
        block20: while (this.state.hasMore()) {
            this.state.step();
            Op op = this.parseHeader(this.state.currentOffset());
            if (op == null) break;
            this.state.advance();
            int entryOffset = this.state.headerOffset() / 4;
            script.entries[entryOffset++] = op;
            for (int i = 0; i < op.params.length; ++i) {
                ParameterType paramType = ParameterType.byOpcode(this.state.paramType());
                int[] rawValues = new int[paramType.width];
                for (int n = 0; n < paramType.width; ++n) {
                    rawValues[n] = this.state.wordAt(this.state.currentOffset() + n * 4);
                }
                int paramOffset = this.state.currentOffset();
                OptionalInt resolved = this.parseParamValue(this.state, paramType);
                Param param = new Param(paramOffset, paramType, rawValues, resolved, paramType.isInline() && resolved.isPresent() ? script.addLabel(resolved.getAsInt(), "LABEL_" + script.getLabelCount()) : null);
                for (int n = 0; n < paramType.width; ++n) {
                    script.entries[entryOffset++] = param;
                }
                op.params[i] = param;
                if (!paramType.isInlineTable() || op.type == OpType.GOSUB_TABLE || op.type == OpType.JMP_TABLE) continue;
                if (op.type == OpType.CALL && !"none".equalsIgnoreCase(this.meta.methods[op.headerParam].params[i].branch)) {
                    HashSet tableDestinations = switch (this.meta.methods[op.headerParam].params[i].branch.toLowerCase()) {
                        case "jump" -> script.jumpTableDests;
                        case "subroutine" -> script.subs;
                        case "reentry" -> script.reentries;
                        default -> {
                            LOGGER.warn("Unknown branch type %s", (Object)this.meta.methods[op.headerParam].params[i].branch);
                            yield new HashSet();
                        }
                    };
                    param.resolvedValue.ifPresent(tableAddress -> this.probeTableOfBranches(script, tableDestinations, tableAddress));
                    continue;
                }
                int finalI = i;
                param.resolvedValue.ifPresent(tableAddress -> this.handlePointerTable(script, op, finalI, tableAddress, script.buildStrings));
            }
            switch (op.type) {
                case CALL: {
                    Meta.ScriptMethod method = this.meta.methods[op.headerParam];
                    if (this.meta.methods[op.headerParam].params.length != op.params.length) {
                        // empty if block
                    }
                    for (int i = 0; i < this.meta.methods[op.headerParam].params.length; ++i) {
                        Meta.ScriptParam param = method.params[i];
                        if ("none".equalsIgnoreCase(param.branch)) continue;
                        op.params[i].resolvedValue.ifPresentOrElse(offset1 -> {
                            if ("gosub".equalsIgnoreCase(param.branch)) {
                                script.subs.add(offset1);
                            } else if ("reentry".equalsIgnoreCase(param.branch)) {
                                script.reentries.add(offset1);
                            }
                            this.probeBranch(script, offset1);
                        }, () -> LOGGER.warn("Skipping CALL at %x due to unknowable parameter", (Object)this.state.headerOffset()));
                    }
                    break;
                }
                case JMP: {
                    op.params[0].resolvedValue.ifPresentOrElse(offset1 -> this.probeBranch(script, offset1), () -> LOGGER.warn("Skipping JUMP at %x due to unknowable parameter", (Object)this.state.headerOffset()));
                    if (!op.params[0].resolvedValue.isPresent()) break;
                    break block20;
                }
                case JMP_CMP: 
                case JMP_CMP_0: {
                    op.params[op.params.length - 1].resolvedValue.ifPresentOrElse(addr -> {
                        this.probeBranch(script, this.state.currentOffset());
                        this.probeBranch(script, addr);
                    }, () -> LOGGER.warn("Skipping %s at %x due to unknowable parameter", (Object)op.type, (Object)this.state.headerOffset()));
                    break block20;
                }
                case JMP_TABLE: {
                    op.params[1].resolvedValue.ifPresentOrElse(tableOffset -> {
                        if (op.params[1].type.isInlineTable()) {
                            this.probeTableOfTables(script, script.jumpTableDests, tableOffset);
                        } else {
                            this.probeTableOfBranches(script, script.jumpTableDests, tableOffset);
                        }
                    }, () -> LOGGER.warn("Skipping JMP_TABLE at %x due to unknowable parameter", (Object)this.state.headerOffset()));
                    break block20;
                }
                case GOSUB: {
                    op.params[0].resolvedValue.ifPresentOrElse(offset1 -> {
                        script.subs.add(offset1);
                        this.probeBranch(script, offset1);
                    }, () -> LOGGER.warn("Skipping GOSUB at %x due to unknowable parameter", (Object)this.state.headerOffset()));
                    break;
                }
                case GOSUB_TABLE: {
                    op.params[1].resolvedValue.ifPresentOrElse(tableOffset -> {
                        if (op.params[1].type.isInlineTable()) {
                            this.probeTableOfTables(script, script.subs, tableOffset);
                        } else {
                            this.probeTableOfBranches(script, script.subs, tableOffset);
                        }
                    }, () -> LOGGER.warn("Skipping GOSUB_TABLE at %x due to unknowable parameter", (Object)this.state.headerOffset()));
                    break;
                }
                case REWIND: 
                case RETURN: 
                case DEALLOCATE: 
                case DEALLOCATE82: 
                case CONSUME: {
                    break block20;
                }
                case FORK: {
                    op.params[1].resolvedValue.ifPresentOrElse(offset1 -> {
                        script.reentries.add(offset1);
                        this.probeBranch(script, offset1);
                    }, () -> LOGGER.warn("Skipping FORK at %x due to unknowable parameter", (Object)this.state.headerOffset()));
                }
            }
        }
        this.state.headerOffset(oldHeaderOffset);
        this.state.currentOffset(oldCurrentOffset);
    }

    private void probeTableOfTables(Script script, Set<Integer> tableDestinations, int tableAddress) {
        this.probeTable(script, script.subTables, tableDestinations, tableAddress, subtableAddress -> !this.isProbablyOp(script, (int)subtableAddress), subtableAddress -> this.probeTableOfBranches(script, tableDestinations, (int)subtableAddress));
    }

    private void probeTableOfBranches(Script script, Set<Integer> tableDestinations, int subtableAddress) {
        this.probeTable(script, script.subTables, tableDestinations, subtableAddress, this::isValidOp, branchAddress -> this.probeBranch(script, (int)branchAddress));
    }

    private void probeTable(Script script, Set<Integer> tables, Set<Integer> tableDestinations, int tableAddress, Predicate<Integer> destinationAddressHeuristic, Consumer<Integer> visitor) {
        int destAddress;
        if (tables.contains(tableAddress)) {
            return;
        }
        tables.add(tableAddress);
        int earliestDestination = this.state.length();
        int latestDestination = 0;
        ArrayList<Integer> destinations = new ArrayList<Integer>();
        ArrayList<String> labels = new ArrayList<String>();
        for (int entryAddress = tableAddress; entryAddress <= this.state.length() - 4 && script.entries[entryAddress / 4] == null && (this.state.wordAt(entryAddress) > 0 ? entryAddress < earliestDestination : entryAddress > latestDestination) && (!this.isProbablyOp(script, entryAddress) || this.isValidOp(tableAddress + this.state.wordAt(entryAddress) * 4)) && (destAddress = tableAddress + this.state.wordAt(entryAddress) * 4) >= 4 && destAddress < this.state.length() - 4 && destinationAddressHeuristic.test(destAddress); entryAddress += 4) {
            if (earliestDestination > destAddress) {
                earliestDestination = destAddress;
            }
            if (latestDestination < destAddress) {
                latestDestination = destAddress;
            }
            tableDestinations.add(destAddress);
            destinations.add(destAddress);
            labels.add(script.addLabel(destAddress, "JMP_%x_%d".formatted(tableAddress, labels.size())));
        }
        if (labels.isEmpty()) {
            throw new RuntimeException("Empty table at 0x%x".formatted(tableAddress));
        }
        script.entries[tableAddress / 4] = new PointerTable(tableAddress, (String[])labels.toArray(String[]::new));
        destinations.stream().distinct().sorted(Comparator.reverseOrder()).forEach(visitor);
    }

    private void handlePointerTable(Script script, Op op, int paramIndex, int tableAddress, List<Runnable> buildStrings) {
        if (tableAddress / 4 >= script.entries.length) {
            LOGGER.warn("Op %s param %d points to invalid pointer table %x", (Object)op, (Object)paramIndex, (Object)tableAddress);
            return;
        }
        if (script.entries[tableAddress / 4] != null) {
            return;
        }
        ArrayList<Integer> destinations = new ArrayList<Integer>();
        int entryCount = 0;
        int earliestDestination = this.state.length();
        int latestDestination = 0;
        for (int entryAddress = tableAddress; entryAddress <= this.state.length() - 4 && script.entries[entryAddress / 4] == null && (this.state.wordAt(entryAddress) > 0 ? entryAddress < earliestDestination : entryAddress > latestDestination); entryAddress += 4) {
            int destination = tableAddress + this.state.wordAt(entryAddress) * 4;
            if (op.type == OpType.CALL && "string".equalsIgnoreCase(this.meta.methods[op.headerParam].params[paramIndex].type)) {
                if (script.entries[entryAddress / 4] instanceof Op) break;
                if (this.isProbablyOp(script, entryAddress)) {
                    boolean foundTerminator = false;
                    for (int i = destination / 4; i < destination / 4 + 300 && i < script.entries.length && script.entries[i] == null; ++i) {
                        int word = this.state.wordAt(i * 4);
                        if ((word & 0xFFFF) != 41215 && (word >> 16 & 0xFFFF) != 41215) continue;
                        foundTerminator = true;
                        break;
                    }
                    if (!foundTerminator) {
                        break;
                    }
                }
            } else if (this.isProbablyOp(script, entryAddress)) break;
            if (destination >= this.state.length() - 4) break;
            if (earliestDestination > destination) {
                earliestDestination = destination;
            }
            if (latestDestination < destination) {
                latestDestination = destination;
            }
            if (op.type == OpType.GOSUB_TABLE || op.type == OpType.JMP_TABLE) {
                destination = tableAddress + this.state.wordAt(destination) * 4;
            }
            destinations.add(destination);
            ++entryCount;
        }
        String[] labels = new String[entryCount];
        for (int entryIndex = 0; entryIndex < entryCount; ++entryIndex) {
            labels[entryIndex] = script.addLabel((Integer)destinations.get(entryIndex), "PTR_%x_%d".formatted(tableAddress, entryIndex));
        }
        PointerTable table = new PointerTable(tableAddress, labels);
        script.entries[tableAddress / 4] = table;
        if (op.type == OpType.CALL && "string".equalsIgnoreCase(this.meta.methods[op.headerParam].params[paramIndex].type)) {
            buildStrings.add(() -> {
                while (destinations.size() > table.labels.length) {
                    destinations.removeLast();
                }
                destinations.sort(Integer::compareTo);
                for (int i = 0; i < destinations.size(); ++i) {
                    if (i < destinations.size() - 1) {
                        script.strings.add(new StringInfo((Integer)destinations.get(i), (Integer)destinations.get(i + 1) - (Integer)destinations.get(i)));
                        continue;
                    }
                    script.strings.add(new StringInfo((Integer)destinations.get(i), -1));
                }
            });
        }
    }

    private void fillStrings(Script script) {
        for (StringInfo string : script.strings) {
            this.fillString(script, string.start, string.maxLength);
        }
    }

    private void fillString(Script script, int address, int maxLength) {
        int chr;
        ArrayList<Integer> chars = new ArrayList<Integer>();
        for (int i = 0; i < (maxLength != -1 ? maxLength : script.entries.length * 4 - address) && (chr = this.state.wordAt(address + i / 2 * 4) >>> i % 2 * 16 & 0xFFFF) != 41215; ++i) {
            chars.add(chr);
        }
        LodString string = new LodString(address, chars.stream().mapToInt(Integer::intValue).toArray());
        for (int i = 0; i < string.chars.length / 2; ++i) {
            script.entries[address / 4 + i] = string;
        }
    }

    private void fillData(Script script) {
        for (int i = 0; i < script.entries.length; ++i) {
            if (script.entries[i] != null) continue;
            script.entries[i] = new Data(i * 4, this.state.wordAt(i * 4));
        }
    }

    private void getEntrypoints(Script script) {
        int entrypoint;
        for (int i = 0; i < 32 && this.isValidOp(entrypoint = this.state.currentWord()); ++i) {
            String label = "ENTRYPOINT_" + i;
            script.entries[i] = new Entrypoint(i * 4, label);
            script.entrypoints.add(entrypoint);
            script.addUniqueLabel(entrypoint, label);
            this.state.advance();
        }
    }

    private Op parseHeader(int offset) {
        int opcode = this.state.wordAt(offset);
        OpType type = OpType.byOpcode(opcode & 0xFF);
        if (type == null) {
            return null;
        }
        int paramCount = opcode >> 8 & 0xFF;
        if (type != OpType.CALL && type.paramNames.length != paramCount) {
            return null;
        }
        int opParam = opcode >> 16;
        if (type.headerParamName == null && opParam != 0) {
            return null;
        }
        return new Op(offset, type, opParam, paramCount);
    }

    private boolean isValidOp(int offset) {
        if ((offset & 3) != 0) {
            return false;
        }
        if (offset < 4 || offset >= this.state.length()) {
            return false;
        }
        return this.parseHeader(offset) != null;
    }

    private boolean isProbablyOp(Script script, int address) {
        if ((address & 3) != 0) {
            return false;
        }
        if (address < 4 || address >= this.state.length()) {
            return false;
        }
        if (script.entries[address / 4] instanceof Op) {
            return true;
        }
        int testCount = 3;
        int certainty = 0;
        for (int opIndex = 0; opIndex < 3; ++opIndex) {
            Op op = this.parseHeader(address);
            if (op == null) {
                certainty -= 3 - opIndex;
                break;
            }
            certainty += opIndex + 1;
            address += 4;
            for (int paramIndex = 0; paramIndex < op.type.paramNames.length; ++paramIndex) {
                ParameterType parameterType = ParameterType.byOpcode(this.state.wordAt(address));
                if (parameterType != ParameterType.IMMEDIATE) {
                    ++certainty;
                }
                address += parameterType.width * 4;
            }
        }
        return certainty >= 2;
    }

    private OptionalInt parseParamValue(State state, ParameterType param) {
        OptionalInt value = switch (param) {
            case ParameterType.IMMEDIATE -> OptionalInt.of(state.currentWord());
            case ParameterType.NEXT_IMMEDIATE -> OptionalInt.of(state.wordAt(state.currentOffset() + 4));
            case ParameterType.INLINE_1, ParameterType.INLINE_2, ParameterType.INLINE_TABLE_1, ParameterType.INLINE_TABLE_3 -> OptionalInt.of(state.headerOffset() + (short)state.currentWord() * 4);
            case ParameterType.INLINE_TABLE_2, ParameterType.INLINE_TABLE_4 -> OptionalInt.of(state.headerOffset() + 4);
            case ParameterType.INLINE_3 -> OptionalInt.of(state.headerOffset() + ((short)state.currentWord() + state.param2()) * 4);
            default -> OptionalInt.empty();
        };
        this.state.advance(param.width);
        return value;
    }
}

