/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.listeners.profiler;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.lang.management.ManagementFactory;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.profiler.data.Phase;
import org.nd4j.autodiff.listeners.profiler.data.TraceEvent;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Listener that records how long each SameDiff op takes during the selected {@link Operation}s and
 * streams the measurements to {@link #getOutputFile()} as trace events.
 * <p>
 * Each recorded op becomes one {@link TraceEvent} with phase {@link Phase#X}, carrying a start time
 * ({@code ts}) and a duration ({@code dur}) in microseconds derived from {@link System#nanoTime()}.
 * <p>
 * The file is a JSON array. {@code "["} is written on construction and every event is followed by
 * {@code ",\n"}. The array is never closed with {@code "]"} and the file is never closed by this class.
 * NOTE(review): this appears to target the Chrome trace-event JSON-array format, which tolerates a
 * missing closing bracket. Confirm against the viewer in use.
 * <p>
 * Events are created on the thread that invokes the listener callbacks and are serialized by a
 * background daemon thread. {@link #operationEnd} waits for queued events to be written and then
 * flushes the file.
 */
public class ProfilingListener extends BaseListener {
    private static final Logger log = LoggerFactory.getLogger(ProfilingListener.class);

    private final File outputFile;
    /** When true, the millisecond limit ({@code nMs}) is not applied. */
    private final boolean all;
    /** Number of initial iterations whose ops are not recorded. */
    private final int warmup;
    /** Maximum number of (post-warmup) iterations to record; {@code <= 0} means unlimited. */
    private final int nIter;
    /** Maximum recording window in milliseconds; {@code <= 0} means unlimited. */
    private final long nMs;
    /** Operations to record; {@code null} means every operation. */
    private final Operation[] operations;
    private final long pid;
    /** Id of the thread that constructed this listener; used as the tid of every event. */
    private final long tid;
    /** {@link System#nanoTime()} of the first non-warmup op; only set when a millisecond limit applies. */
    private Long firstOpStart = null;
    /** Incremented at the end of each INFERENCE operation and on each iterationDone while logging. */
    private int countTotalIter = 0;
    /** Whether ops are currently being recorded. */
    private boolean logActive = false;
    private long opStartNano;
    private final Writer writer;
    private final ObjectMapper json;
    private final Thread fileWritingThread;
    private final BlockingQueue<TraceEvent> writeQueue;
    /** True while the writer thread is serializing one event (exposed via {@link #getWriting()}). */
    private final AtomicBoolean writing = new AtomicBoolean(false);
    /**
     * Events handed to {@link #writeQueue} but not yet appended to {@link #writer}.
     * <p>
     * The counter is incremented before queuing and decremented only after the event is written. The
     * older "queue empty and not writing" check had a gap between {@code take()} and
     * {@code writing.set(true)}; this counter has no such gap.
     */
    private final AtomicLong pendingWrites = new AtomicLong();

    /**
     * Creates the listener, opens {@code outputFile} and starts the background writer thread.
     *
     * @param outputFile file to create; must not already exist
     * @param all        if true, the millisecond limit is ignored
     * @param warmup     number of initial iterations to skip
     * @param nIter      maximum number of iterations to record after warmup ({@code <= 0}: unlimited)
     * @param nMs        maximum recording window in milliseconds ({@code <= 0}: unlimited)
     * @param operations operations to record, or {@code null} for all
     * @throws NullPointerException if {@code outputFile} is null
     * @throws RuntimeException     if the file already exists (Preconditions check) or cannot be opened
     */
    protected ProfilingListener(@NonNull File outputFile, boolean all, int warmup, int nIter, long nMs, Operation[] operations) {
        if (outputFile == null) {
            throw new NullPointerException("outputFile is marked non-null but is null");
        }
        Preconditions.checkArgument(!outputFile.exists(), "Output file already exists: %s", (Object) outputFile);
        this.outputFile = outputFile;
        this.all = all;
        this.warmup = warmup;
        this.nIter = nIter;
        this.nMs = nMs;
        this.operations = operations;
        this.pid = this.getProcessId();
        this.tid = Thread.currentThread().getId();
        this.writer = openTraceFile(outputFile);
        this.json = ProfilingListener.jsonMapper();
        this.writeQueue = new LinkedBlockingDeque<TraceEvent>();
        this.fileWritingThread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    writeQueuedEvents();
                } catch (Throwable t) {
                    log.error("Error when attempting to write results to file", t);
                }
            }
        });
        this.fileWritingThread.setDaemon(true);
        this.fileWritingThread.start();
    }

    /**
     * Opens {@code file} for UTF-8 output and writes the opening {@code "["} of the JSON array.
     * The stream is closed again if that initial write fails, so no file handle leaks.
     */
    private static Writer openTraceFile(File file) {
        Writer w = null;
        try {
            w = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file, false), StandardCharsets.UTF_8));
            w.write("[");
            return w;
        } catch (IOException e) {
            if (w != null) {
                try {
                    w.close();
                } catch (IOException suppressed) {
                    e.addSuppressed(suppressed);
                }
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Body of the writer thread: blocks on the queue and appends each event as JSON followed by ",\n".
     * An I/O failure ends the loop; the exception propagates to the thread's run() and is logged there.
     */
    private void writeQueuedEvents() throws InterruptedException {
        while (true) {
            TraceEvent te = writeQueue.take();
            writing.set(true);
            try {
                writer.append(json.writeValueAsString(te));
                writer.append(",\n");
            } catch (IOException e) {
                throw new RuntimeException(e);
            } finally {
                writing.set(false);
                pendingWrites.decrementAndGet();
            }
        }
    }

    /** Returns true if {@code op} is one of the configured operations (or no filter is configured). */
    private boolean recordsOperation(Operation op) {
        return operations == null || ArrayUtils.contains((Object[]) operations, (Object) op);
    }

    @Override
    public boolean isActive(Operation operation) {
        return recordsOperation(operation);
    }

    @Override
    public void operationStart(SameDiff sd, Operation op) {
        this.logActive = recordsOperation(op);
    }

    /**
     * Waits for the writer thread to drain, flushes the file, and advances the iteration counter
     * for inference operations.
     */
    @Override
    public void operationEnd(SameDiff sd, Operation op) {
        // Decide on the filter, not on logActive: opExecution clears logActive once a limit is reached,
        // and events recorded before that point must still be flushed to disk.
        if (recordsOperation(op)) {
            awaitPendingWrites();
            try {
                writer.flush();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        this.logActive = false;
        if (op == Operation.INFERENCE) {
            ++this.countTotalIter;
        }
    }

    /** Polls until every queued event has been written, or the writer thread has died. */
    private void awaitPendingWrites() {
        while (pendingWrites.get() > 0L && fileWritingThread.isAlive()) {
            try {
                Thread.sleep(100L);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
    }

    @Override
    public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
        if (this.logActive) {
            ++this.countTotalIter;
        }
    }

    @Override
    public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) {
        if (!this.logActive) {
            return;
        }
        this.opStartNano = System.nanoTime();
        // The time window starts at the first op that is actually recorded, i.e. after warmup.
        if (!this.all && this.nMs > 0L && this.firstOpStart == null && this.countTotalIter >= this.warmup) {
            this.firstOpStart = this.opStartNano;
        }
    }

    /**
     * Queues a trace event for the op that just finished.
     * <p>
     * Nothing is queued during warmup. Recording stops, by clearing logActive, once the iteration
     * limit or the time limit is reached.
     */
    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
        if (!this.logActive) {
            return;
        }
        long now = System.nanoTime();
        if (this.warmup > 0 && this.countTotalIter < this.warmup) {
            return;
        }
        // Iterations [warmup, warmup + nIter) are recorded; long arithmetic avoids int overflow.
        long terminationPt = this.nIter > 0 ? (long) this.nIter + Math.max(this.warmup, 0) : Long.MAX_VALUE;
        if (this.countTotalIter >= terminationPt) {
            this.logActive = false;
            return;
        }
        if (!this.all && this.nMs > 0L && this.firstOpStart != null
                && TimeUnit.NANOSECONDS.toMillis(now - this.firstOpStart) > this.nMs) {
            this.logActive = false;
            return;
        }
        TraceEvent event = TraceEvent.builder()
                .name(op.getOp().opName())
                .categories(Collections.singletonList("Op"))
                .ts(this.opStartNano / 1000L)
                .dur((now - this.opStartNano) / 1000L)
                .pid((int) this.pid)
                .tid(this.tid)
                .ph(Phase.X)
                .args(Collections.singletonMap("name", op.getName()))
                .build();
        // Count before queuing so the writer's decrement can never precede this increment.
        this.pendingWrites.incrementAndGet();
        this.writeQueue.add(event);
    }

    /**
     * Parses the process id from the RuntimeMXBean name, using the digits before the first '@'.
     * Returns 0 if there is no '@' past index 0 or the prefix is not a number.
     */
    private long getProcessId() {
        String jvmName = ManagementFactory.getRuntimeMXBean().getName();
        int at = jvmName.indexOf('@');
        if (at < 1) {
            return 0L;
        }
        try {
            return Long.parseLong(jvmName.substring(0, at));
        } catch (NumberFormatException ignored) {
            // Name not in the expected "<pid>@..." form; the pid is informational only.
            return 0L;
        }
    }

    /**
     * Creates the ObjectMapper used to serialize trace events.
     * Output is compact (no indentation); unknown properties and empty beans do not cause failures.
     */
    public static ObjectMapper jsonMapper() {
        ObjectMapper json = new ObjectMapper();
        json.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        json.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
        json.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false);
        json.disable(SerializationFeature.INDENT_OUTPUT);
        return json;
    }

    /** Returns a builder that will write the profile to {@code outputFile}. */
    public static Builder builder(File outputFile) {
        return new Builder(outputFile);
    }

    public File getOutputFile() {
        return this.outputFile;
    }

    public boolean isAll() {
        return this.all;
    }

    public int getWarmup() {
        return this.warmup;
    }

    public int getNIter() {
        return this.nIter;
    }

    public long getNMs() {
        return this.nMs;
    }

    public Operation[] getOperations() {
        return this.operations;
    }

    public long getPid() {
        return this.pid;
    }

    public long getTid() {
        return this.tid;
    }

    public Long getFirstOpStart() {
        return this.firstOpStart;
    }

    public int getCountTotalIter() {
        return this.countTotalIter;
    }

    public boolean isLogActive() {
        return this.logActive;
    }

    public long getOpStartNano() {
        return this.opStartNano;
    }

    public Writer getWriter() {
        return this.writer;
    }

    public ObjectMapper getJson() {
        return this.json;
    }

    public Thread getFileWritingThread() {
        return this.fileWritingThread;
    }

    public BlockingQueue<TraceEvent> getWriteQueue() {
        return this.writeQueue;
    }

    public AtomicBoolean getWriting() {
        return this.writing;
    }

    /**
     * Builder for {@link ProfilingListener}.
     * <p>
     * Defaults: record everything, no warmup, no iteration or time limit, all operations.
     */
    public static class Builder {
        private final File outputFile;
        private boolean all = true;
        private int warmup = 0;
        private int nIter = -1;
        private long nMs = -1L;
        private Operation[] operations;

        /** @param outputFile file to write the profile to; must not be null */
        public Builder(@NonNull File outputFile) {
            if (outputFile == null) {
                throw new NullPointerException("outputFile is marked non-null but is null");
            }
            this.outputFile = outputFile;
        }

        /** Removes any iteration or time limit so every (post-warmup) op is recorded. */
        public Builder recordAll() {
            this.all = true;
            this.nIter = -1;
            this.nMs = -1L;
            return this;
        }

        /** Number of initial iterations whose ops are not recorded. */
        public Builder warmup(int iterations) {
            this.warmup = iterations;
            return this;
        }

        /** Records at most {@code iterations} iterations after warmup; later ops are not recorded. */
        public Builder maxProfileIterations(int iterations) {
            this.nIter = iterations;
            this.all = false;
            return this;
        }

        /**
         * Stops recording once more than {@code ms} milliseconds have elapsed since the first
         * non-warmup op started.
         */
        public Builder maxProfilerMilliseconds(long ms) {
            this.nMs = ms;
            this.all = false;
            return this;
        }

        /** Restricts recording to the given operations; {@code null} (the default) records all. */
        public Builder operations(Operation ... operations) {
            this.operations = operations;
            return this;
        }

        /** Creates the listener; this opens the output file and starts the writer thread. */
        public ProfilingListener build() {
            return new ProfilingListener(this.outputFile, this.all, this.warmup, this.nIter, this.nMs, this.operations);
        }
    }
}

