/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.plot;

import com.google.common.base.Function;
import com.google.common.primitives.Ints;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dimensionalityreduction.PCA;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Exact t-SNE (t-distributed Stochastic Neighbor Embedding).
 *
 * <p>Input affinities {@code P} come from per-point Gaussians whose precisions are calibrated
 * to a target perplexity ({@code x2p}). Output affinities use the Student-t kernel
 * {@code 1 / (1 + |y_i - y_j|^2)}. The embedding is updated by gradient descent with momentum
 * and per-coordinate adaptive gains. During an initial "lying" (early exaggeration) phase,
 * {@code P} is kept 4x larger than its normalised value.
 *
 * <p>Full n x n distance and affinity matrices are materialised, so time and memory are
 * quadratic in the number of points. {@link #calculate} mutates instance state
 * ({@link #Y}, {@link #momentum}), so an instance should not be shared between threads.
 *
 * <p>Instances are normally created through {@link Builder}.
 */
public class Tsne {
    // NOTE: the only constructor overwrites every field below except adaGrad and Y, so these
    // initializers are effectively documentation of "sensible" values, not real defaults.

    /** Number of gradient-descent iterations performed by {@link #calculate}. */
    protected int maxIter = 1000;
    /** Stored by the constructor but not read by this class. */
    protected double realMin = Nd4j.EPS_THRESHOLD;
    /** Momentum used for iterations before {@link #switchMomentumIteration}. */
    protected double initialMomentum = 0.5;
    /** Momentum used from {@link #switchMomentumIteration} onwards. */
    protected double finalMomentum = 0.8;
    /** Lower bound applied to every adaptive gain after each update. */
    protected double minGain = 0.01;
    /** Current momentum; overwritten on every iteration of {@link #calculate}. */
    protected double momentum = this.initialMomentum;
    /** Iteration index at which momentum switches from initial to final. */
    protected int switchMomentumIteration = 100;
    /** Whether to min-max scale and mean-centre the input when PCA is not used. */
    protected boolean normalize = true;
    /** Whether to first reduce the input to at most 50 columns with PCA. */
    protected boolean usePca = false;
    /** Iteration at which early exaggeration ends (it also ends once past maxIter / 2). */
    protected int stopLyingIteration = 250;
    /** Tolerance on |H - log(perplexity)| in the per-point precision search. */
    protected double tolerance = 1.0E-5;
    /** Step size applied to the gain-scaled gradient. */
    protected double learningRate = 500.0;
    /** Never assigned or read by this class. */
    protected AdaGrad adaGrad;
    /** Stored by the constructor but not consulted by {@link #calculate}. */
    protected boolean useAdaGrad = true;
    /** Target perplexity used by {@link #plot}. */
    protected double perplexity = 30.0;
    /** Most recent embedding produced by {@link #calculate}; one row per input point. */
    protected INDArray Y;
    protected static final Logger logger = LoggerFactory.getLogger(Tsne.class);

    /**
     * Creates a fully configured instance; see the field documentation for the meaning of each
     * parameter. {@link #init()} is invoked last.
     */
    public Tsne(int maxIter, double realMin, double initialMomentum, double finalMomentum, double minGain, double momentum, int switchMomentumIteration, boolean normalize, boolean usePca, int stopLyingIteration, double tolerance, double learningRate, boolean useAdaGrad, double perplexity) {
        this.maxIter = maxIter;
        this.realMin = realMin;
        this.initialMomentum = initialMomentum;
        this.finalMomentum = finalMomentum;
        this.minGain = minGain;
        this.momentum = momentum;
        this.switchMomentumIteration = switchMomentumIteration;
        this.normalize = normalize;
        this.usePca = usePca;
        this.stopLyingIteration = stopLyingIteration;
        this.tolerance = tolerance;
        this.learningRate = learningRate;
        this.useAdaGrad = useAdaGrad;
        this.perplexity = perplexity;
        this.init();
    }

    /**
     * Initialisation hook for subclasses; a no-op here.
     *
     * <p>Because it is called from the constructor, an override runs before the subclass's
     * own field initializers have executed.
     */
    protected void init() {
    }

    /**
     * Computes a t-SNE embedding of {@code X}.
     *
     * <p>If {@link #usePca} is set, {@code X} is first reduced with {@code PCA.pca} to at most
     * 50 columns. Otherwise, if {@link #normalize} is set, a copy of {@code X} is shifted by
     * its global minimum, divided by its global maximum and column-mean-centred. The caller's
     * matrix is left untouched on that path.
     *
     * @param X                data, one row per point
     * @param targetDimensions number of columns of the embedding
     * @param perplexity       target perplexity of the per-point conditional distributions
     * @return the embedding (rows of {@code X} x {@code targetDimensions}); also stored in {@link #Y}
     */
    public INDArray calculate(INDArray X, int targetDimensions, double perplexity) {
        if (usePca) {
            X = PCA.pca(X, Math.min(50, X.columns()), normalize);
        } else if (normalize) {
            // Work on a copy: scaling in place would silently rewrite the caller's matrix.
            X = X.dup();
            // Integer.MAX_VALUE asks ND4J for a reduction over all dimensions (global min/max).
            X.subi(X.min(Integer.MAX_VALUE));
            X.divi(X.max(Integer.MAX_VALUE));
            X.subiRowVector(X.mean(0));
        }

        int n = X.rows();
        Y = Nd4j.randn(n, targetDimensions, Nd4j.getRandom());
        INDArray iY = Nd4j.zeros(n, targetDimensions);
        INDArray gains = Nd4j.ones(n, targetDimensions);
        boolean stopLying = false;
        logger.debug("Y:Shape is = " + Arrays.toString(Y.shape()));

        // x2p already returns P multiplied by 4 (early exaggeration); undone below.
        INDArray P = x2p(X, tolerance, perplexity);

        for (int i = 0; i < maxIter; i++) {
            // Student-t kernel: qu[i][j] = 1 / (1 + |y_i|^2 + |y_j|^2 - 2 y_i.y_j).
            INDArray sumY = Transforms.pow(Y, 2).sum(1).transpose();
            INDArray qu = Y.mmul(Y.transpose()).muli(-2).addiRowVector(sumY).transpose().addiRowVector(sumY).addi(1).rdivi(1);
            // A point is not its own neighbour: without this, n self-terms of 1 inflate the
            // normaliser of Q (the reference implementation zeroes the diagonal too).
            for (int k = 0; k < n; k++) {
                qu.putScalar(k, k, 0.0);
            }
            INDArray Q = qu.div(qu.sumNumber().doubleValue());
            BooleanIndexing.applyWhere(Q, Conditions.lessThan(1.0E-12), new Value(1.0E-12));

            // Gradient: dY_i = 4 * sum_j (p_ij - q_ij) * qu_ij * (y_i - y_j).
            INDArray PQ = P.sub(Q).muli(qu);
            INDArray rowSums = PQ.sum(1);
            if (logger.isDebugEnabled()) {
                logger.debug("PQ shape is: " + Arrays.toString(PQ.shape()));
                logger.debug("PQ.sum(1) shape is: " + Arrays.toString(rowSums.shape()));
            }
            INDArray dY = diag(rowSums).subi(PQ).mmul(Y).muli(4);

            momentum = i < switchMomentumIteration ? initialMomentum : finalMomentum;

            // Adaptive gains: grow where the gradient sign flipped relative to the previous
            // step direction, shrink where it kept the same sign.
            INDArray signsDiffer = dY.cond(Conditions.greaterThan(0)).neqi(iY.cond(Conditions.greaterThan(0)));
            INDArray signsMatch = dY.cond(Conditions.greaterThan(0)).eqi(iY.cond(Conditions.greaterThan(0)));
            gains = gains.add(0.2).muli(signsDiffer).addi(gains.mul(0.8).muli(signsMatch));
            BooleanIndexing.applyWhere(gains, Conditions.lessThan(minGain), new Value(minGain));

            INDArray gradChange = gains.mul(dY).muli(learningRate);
            iY.muli(momentum).subi(gradChange);

            // KL(P || Q); P still carries the exaggeration factor while lying.
            double cost = P.mul(Transforms.log(P.div(Q), false)).sumNumber().doubleValue();
            logger.info("Iteration [{}] error is: [{}]", i, cost);

            Y.addi(iY);
            // Keep the embedding centred at the origin.
            Y.subiRowVector(Y.mean(0));

            if (!stopLying && (i > maxIter / 2 || i >= stopLyingIteration)) {
                P.divi(4);
                stopLying = true;
            }
        }
        return Y;
    }

    /**
     * Builds a square matrix whose diagonal holds the entries of {@code ds}.
     *
     * <p>The side length is {@code max(rows, columns)} of {@code ds}. When {@code ds} has more
     * rows than columns, diagonal entry i is element 0 of slice i. Otherwise it is element i of
     * slice 0. Off-diagonal entries are left as created by {@code Nd4j.create}.
     *
     * @param ds a vector (expected to be row- or column-shaped)
     * @return a new {@code dim x dim} matrix
     */
    public INDArray diag(INDArray ds) {
        boolean isLong = ds.rows() > ds.columns();
        INDArray sliceZero = ds.slice(0);
        int dim = Math.max(ds.columns(), ds.rows());
        INDArray result = Nd4j.create(dim, dim);
        // Only the diagonal is written, so a single pass suffices.
        for (int i = 0; i < dim; i++) {
            double value = isLong ? ds.slice(i).getDouble(0) : sliceZero.getDouble(i);
            result.slice(i).putScalar(i, value);
        }
        return result;
    }

    /**
     * Embeds {@code matrix} with {@link #calculate} (using {@link #perplexity}) and appends the
     * result to a text file.
     *
     * <p>For each row i below {@code min(Y.rows(), labels.size())} whose label is non-null, one
     * line is appended: the comma-separated coordinates, a comma, the label, a space and a
     * newline. The file is opened in append mode and written as UTF-8. It is closed even if
     * writing fails.
     *
     * @param matrix data, one row per point
     * @param nDims  embedding dimensionality
     * @param labels label per row; {@code null} entries are skipped
     * @param path   output file path
     * @throws IOException if the file cannot be opened or written
     */
    public void plot(INDArray matrix, int nDims, List<String> labels, String path) throws IOException {
        calculate(matrix, nDims, perplexity);
        try (BufferedWriter write = new BufferedWriter(new OutputStreamWriter(
                new FileOutputStream(new File(path), true), StandardCharsets.UTF_8))) {
            int rows = Math.min(Y.rows(), labels.size());
            for (int i = 0; i < rows; i++) {
                String word = labels.get(i);
                if (word == null) {
                    continue;
                }
                INDArray wordVector = Y.getRow(i);
                StringBuilder sb = new StringBuilder();
                for (long j = 0; j < wordVector.length(); j++) {
                    if (j > 0) {
                        sb.append(',');
                    }
                    sb.append(wordVector.getDouble(j));
                }
                sb.append(',').append(word).append(" \n");
                write.write(sb.toString());
            }
        }
    }

    /**
     * Computes the Gaussian conditional distribution and its entropy for one point.
     *
     * <p>{@code P = exp(-beta * d) / sum(exp(-beta * d))} and
     * {@code H = log(sumP) + beta * sum(d * P_unnormalised) / sumP} (natural log). If every
     * exponential underflows to 0, {@code sumP} is 0 and both results contain NaN. {@code x2p}
     * stops its search on a NaN entropy and later replaces NaN probabilities.
     *
     * @param d    squared distances from one point to all other points
     * @param beta precision (1 / (2 sigma^2)) of the Gaussian
     * @return (entropy H, normalised probabilities P)
     */
    public Pair<Double, INDArray> hBeta(INDArray d, double beta) {
        INDArray P = Transforms.exp(d.neg().muli(beta));
        double sumP = P.sumNumber().doubleValue();
        double logSumP = FastMath.log(sumP);
        double H = logSumP + beta * d.mul(P).sumNumber().doubleValue() / sumP;
        P.divi(sumP);
        return new Pair<>(H, P);
    }

    /**
     * Computes the symmetric, exaggerated input affinity matrix P.
     *
     * <p>For every point, the Gaussian precision beta is bisected (at most 50 steps) until the
     * entropy of its conditional distribution is within {@code tolerance} of
     * {@code log(perplexity)}. The conditionals are then symmetrised ({@code p + p^T}),
     * normalised to sum to ~1, multiplied by 4 (early exaggeration) and floored at 1e-12.
     *
     * @param X          data, one row per point
     * @param tolerance  allowed |H - log(perplexity)|
     * @param perplexity target perplexity
     * @return an n x n affinity matrix
     */
    private INDArray x2p(INDArray X, double tolerance, double perplexity) {
        int n = X.rows();
        INDArray p = Nd4j.zeros(n, n);
        INDArray beta = Nd4j.ones(n, 1);
        double logU = Math.log(perplexity);

        INDArray sumX = Transforms.pow(X, 2).sum(1);
        logger.debug("sumX shape: " + Arrays.toString(sumX.shape()));
        // Squared Euclidean distances: D[i][j] = |x_i|^2 + |x_j|^2 - 2 x_i.x_j
        INDArray D = X.mmul(X.transpose()).mul(-2).transpose().addColumnVector(sumX).addRowVector(sumX.transpose());
        logger.debug("D shape: " + Arrays.toString(D.shape()));

        logger.info("Calculating probabilities of data similarities...");
        logger.debug("Tolerance: " + tolerance);
        for (int i = 0; i < n; i++) {
            if (i % 500 == 0 && i > 0) {
                logger.info("Handled [" + i + "] records out of [" + n + "]");
            }
            double betaMin = Double.NEGATIVE_INFINITY;
            double betaMax = Double.POSITIVE_INFINITY;
            double betaI = beta.getDouble(i);

            // Every column except i: a point is not its own neighbour.
            int[] vals = Ints.concat(ArrayUtil.range(0, i), ArrayUtil.range(i + 1, n));
            INDArrayIndex[] range = {new SpecifiedIndex(vals)};
            INDArray row = D.slice(i).get(range);

            Pair<Double, INDArray> pair = hBeta(row, betaI);
            double hDiff = pair.getFirst() - logU;
            for (int tries = 0; Math.abs(hDiff) > tolerance && tries < 50; tries++) {
                if (hDiff > 0) {
                    // Entropy too high: sharpen the Gaussian (increase beta).
                    betaMin = betaI;
                    betaI = Double.isInfinite(betaMax) ? betaI * 2.0 : (betaI + betaMax) / 2.0;
                } else {
                    // Entropy too low: widen the Gaussian (decrease beta).
                    betaMax = betaI;
                    betaI = Double.isInfinite(betaMin) ? betaI / 2.0 : (betaI + betaMin) / 2.0;
                }
                pair = hBeta(row, betaI);
                hDiff = pair.getFirst() - logU;
            }
            beta.putScalar(i, betaI);
            p.slice(i).put(range, pair.getSecond());
        }
        logger.info("Mean value of sigma " + Transforms.sqrt(beta.rdiv(1)).mean(Integer.MAX_VALUE));

        BooleanIndexing.applyWhere(p, Conditions.isNan(), new Value(1.0E-12));
        INDArray pOut = p.add(p.transpose());
        pOut.divi(pOut.sumNumber().doubleValue() + 1.0E-6);
        pOut.muli(4);
        BooleanIndexing.applyWhere(pOut, Conditions.lessThan(1.0E-12), new Value(1.0E-12));
        return pOut;
    }

    /**
     * Fluent builder for {@link Tsne}.
     *
     * <p>Because the {@code Tsne} constructor overwrites every tunable field, the values here
     * are the effective defaults of built instances. Several differ from the {@code Tsne}
     * field initializers: {@code learningRate} 0.1 vs 500.0, {@code stopLyingIteration} 100
     * vs 250, {@code minGain} 0.1 vs 0.01, {@code useAdaGrad} false vs true.
     */
    public static class Builder {
        // Double literals on purpose: float literals such as 0.1f widen to imprecise doubles
        // (0.10000000149011612).
        protected int maxIter = 1000;
        protected double realMin = 1.0E-12;
        protected double initialMomentum = 0.5;
        protected double finalMomentum = 0.8;
        protected double momentum = 0.5;
        protected int switchMomentumIteration = 100;
        protected boolean normalize = true;
        protected boolean usePca = false;
        protected int stopLyingIteration = 100;
        protected double tolerance = 1.0E-5;
        protected double learningRate = 0.1;
        protected boolean useAdaGrad = false;
        protected double perplexity = 30.0;
        protected double minGain = 0.1;

        /** Sets the lower bound for adaptive gains. */
        public Builder minGain(double minGain) {
            this.minGain = minGain;
            return this;
        }

        /** Sets the perplexity used by {@link Tsne#plot}. */
        public Builder perplexity(double perplexity) {
            this.perplexity = perplexity;
            return this;
        }

        /** Sets the (currently unused by {@link Tsne#calculate}) AdaGrad flag. */
        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /** Sets the gradient step size. */
        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        /** Sets the entropy tolerance of the per-point precision search. */
        public Builder tolerance(double tolerance) {
            this.tolerance = tolerance;
            return this;
        }

        /** Sets the iteration at which early exaggeration ends. */
        public Builder stopLyingIteration(int stopLyingIteration) {
            this.stopLyingIteration = stopLyingIteration;
            return this;
        }

        /** Enables PCA pre-reduction to at most 50 columns. */
        public Builder usePca(boolean usePca) {
            this.usePca = usePca;
            return this;
        }

        /** Enables min-max scaling and mean-centring when PCA is not used. */
        public Builder normalize(boolean normalize) {
            this.normalize = normalize;
            return this;
        }

        /** Sets the number of gradient-descent iterations. */
        public Builder setMaxIter(int maxIter) {
            this.maxIter = maxIter;
            return this;
        }

        /** Sets the (currently unused by {@link Tsne}) realMin value. */
        public Builder setRealMin(double realMin) {
            this.realMin = realMin;
            return this;
        }

        /** Sets the momentum used before the switch iteration. */
        public Builder setInitialMomentum(double initialMomentum) {
            this.initialMomentum = initialMomentum;
            return this;
        }

        /** Sets the momentum used from the switch iteration onwards. */
        public Builder setFinalMomentum(double finalMomentum) {
            this.finalMomentum = finalMomentum;
            return this;
        }

        /** Sets the initial momentum field; {@link Tsne#calculate} overwrites it each iteration. */
        public Builder setMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        /** Sets the iteration at which momentum switches from initial to final. */
        public Builder setSwitchMomentumIteration(int switchMomentumIteration) {
            this.switchMomentumIteration = switchMomentumIteration;
            return this;
        }

        /** Creates a {@link Tsne} from the current settings. */
        public Tsne build() {
            return new Tsne(this.maxIter, this.realMin, this.initialMomentum, this.finalMomentum, this.minGain, this.momentum, this.switchMomentumIteration, this.normalize, this.usePca, this.stopLyingIteration, this.tolerance, this.learningRate, this.useAdaGrad, this.perplexity);
        }
    }
}

