/*
 * Decompiled with CFR 0.152.
 */
package org.graylog.shaded.opensearch2.org.apache.lucene.sandbox.document;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import org.graylog.shaded.opensearch2.org.apache.lucene.document.FloatPoint;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.LeafReaderContext;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.PointValues;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.FieldDoc;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.IndexSearcher;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.ScoreDoc;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.TopFieldDocs;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.TotalHits;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.Bits;

/**
 * K-nearest-neighbor search over an N-dimensional {@link FloatPoint} field.
 *
 * <p>Given an origin and {@code topN}, finds the {@code topN} indexed points in a field that are
 * closest (Euclidean distance) to the origin. The search runs best-first over the point trees of
 * all segments. Ties on distance go to the smaller global doc ID. Documents that are not live
 * are skipped.
 *
 * <p>NOTE(review): points are collected individually and never de-duplicated by document. A
 * multi-valued document may therefore appear more than once in the results.
 */
public class FloatPointNearestNeighbor {

    /**
     * Runs a best-first search over the point trees of all segments. It collects the
     * {@code topN} points closest to {@code origin}.
     *
     * <p>Cells (tree nodes) are expanded in increasing order of the smallest squared distance any
     * point inside them could have to {@code origin}. The search stops as soon as the closest
     * remaining cell is farther than the current worst collected hit. At that point no unvisited
     * point can improve the result.
     *
     * @param readers point values of each segment that has the field
     * @param liveDocs live docs of each segment, parallel to {@code readers}; a {@code null} entry
     *     means every document in that segment is considered live
     * @param docBases doc base of each segment, parallel to {@code readers}
     * @param topN maximum number of hits to collect; at least 1
     * @param origin query point, one value per dimension
     * @return the collected hits, closest first (ties by increasing global doc ID)
     * @throws IOException if reading a point tree fails
     */
    private static NearestHit[] nearest(List<PointValues> readers, List<Bits> liveDocs, List<Integer> docBases, int topN, float[] origin) throws IOException {
        // Max-heap on distance. The head is the current worst hit (farthest; largest doc ID on
        // ties), so it is the one evicted when a better point is found.
        PriorityQueue<NearestHit> hitQueue = new PriorityQueue<NearestHit>(initialHitQueueCapacity(readers, topN), (a, b) -> {
            int cmp = Double.compare(b.distanceSquared, a.distanceSquared);
            return cmp != 0 ? cmp : Integer.compare(b.docID, a.docID);
        });
        // Min-heap of cells still to explore, ordered by their distance lower bound.
        PriorityQueue<Cell> cellQueue = new PriorityQueue<Cell>();
        NearestVisitor visitor = new NearestVisitor(hitQueue, topN, origin);
        for (int i = 0; i < readers.size(); ++i) {
            PointValues reader = readers.get(i);
            byte[] minPackedValue = reader.getMinPackedValue();
            byte[] maxPackedValue = reader.getMaxPackedValue();
            double distanceSquared = pointToRectangleDistanceSquared(minPackedValue, maxPackedValue, origin);
            cellQueue.offer(new Cell(reader.getPointTree(), i, minPackedValue, maxPackedValue, distanceSquared));
        }
        while (!cellQueue.isEmpty()) {
            Cell cell = cellQueue.poll();
            if (cell.distanceSquared > visitor.bottomNearestDistanceSquared) {
                // Cells are polled in increasing distance order, so no remaining cell can hold a
                // point that beats the current worst hit.
                break;
            }
            if (!cell.index.moveToChild()) {
                // Leaf: visit every point in it and collect the competitive ones.
                visitor.curDocBase = docBases.get(cell.readerIndex);
                visitor.curLiveDocs = liveDocs.get(cell.readerIndex);
                cell.index.visitDocValues(visitor);
                continue;
            }
            // Inner node: cell.index now points at the left child. Clone it so the left and the
            // right subtrees can be queued and explored independently of each other.
            PointValues.PointTree leftIndex = cell.index.clone();
            byte[] leftMin = leftIndex.getMinPackedValue();
            byte[] leftMax = leftIndex.getMaxPackedValue();
            double distanceLeft = pointToRectangleDistanceSquared(leftMin, leftMax, origin);
            if (distanceLeft <= visitor.bottomNearestDistanceSquared) {
                cellQueue.offer(new Cell(leftIndex, cell.readerIndex, leftMin, leftMax, distanceLeft));
            }
            if (cell.index.moveToSibling()) {
                byte[] rightMin = cell.index.getMinPackedValue();
                byte[] rightMax = cell.index.getMaxPackedValue();
                double distanceRight = pointToRectangleDistanceSquared(rightMin, rightMax, origin);
                if (distanceRight <= visitor.bottomNearestDistanceSquared) {
                    cellQueue.offer(new Cell(cell.index, cell.readerIndex, rightMin, rightMax, distanceRight));
                }
            }
        }
        // The queue drains worst-first. Fill the array from the back so the result is closest-first.
        NearestHit[] hits = new NearestHit[hitQueue.size()];
        for (int downTo = hits.length - 1; downTo >= 0; --downTo) {
            hits[downTo] = hitQueue.poll();
        }
        return hits;
    }

    /**
     * Returns the initial capacity for the hit queue. The value is {@code topN}, capped at the
     * total number of indexed points, and never less than 1.
     *
     * <p>The queue receives at most one entry per visited point, so it can never hold more than
     * that total. {@link PriorityQueue} rejects capacities below 1 and allocates its backing array
     * eagerly, so sizing it straight from a huge {@code topN} (e.g. {@code Integer.MAX_VALUE})
     * would fail with an {@link OutOfMemoryError}.
     */
    private static int initialHitQueueCapacity(List<PointValues> readers, int topN) throws IOException {
        long totalPoints = 0L;
        for (PointValues reader : readers) {
            totalPoints += reader.size();
        }
        return (int) Math.max(1L, Math.min((long) topN, totalPoints));
    }

    /**
     * Returns the squared Euclidean distance from {@code value} to the nearest point of the
     * axis-aligned box {@code [minPackedValue, maxPackedValue]}. Returns 0 when {@code value} lies
     * inside the box.
     *
     * <p>Both bounds are decoded with {@link FloatPoint#decodeDimension} at {@code Float.BYTES}
     * per dimension, for {@code value.length} dimensions.
     */
    private static double pointToRectangleDistanceSquared(byte[] minPackedValue, byte[] maxPackedValue, float[] value) {
        double sumOfSquaredDiffs = 0.0;
        for (int dim = 0, offset = 0; dim < value.length; ++dim, offset += Float.BYTES) {
            double v = value[dim];
            double min = FloatPoint.decodeDimension(minPackedValue, offset);
            if (v < min) {
                double diff = min - v;
                sumOfSquaredDiffs += diff * diff;
                continue;
            }
            double max = FloatPoint.decodeDimension(maxPackedValue, offset);
            if (v > max) {
                double diff = v - max;
                sumOfSquaredDiffs += diff * diff;
            }
        }
        return sumOfSquaredDiffs;
    }

    /**
     * Finds the {@code topN} points of {@code field} nearest (Euclidean distance) to
     * {@code origin}.
     *
     * @param searcher searcher over the index to query
     * @param field name of the {@link FloatPoint} field
     * @param topN maximum number of hits to return; must be at least 1
     * @param origin query point; must have exactly one value per dimension of {@code field}
     * @return hits closest-first (ties by increasing doc ID). Each hit is a {@link FieldDoc} with
     *     score 0 and a single field: the {@code Float} distance to {@code origin}. The total
     *     hit count is the sum of {@link PointValues#getDocCount()} over all segments that have
     *     the field (NOTE(review): this may include deleted documents — confirm).
     * @throws IllegalArgumentException if {@code topN < 1}; if {@code field}, {@code searcher} or
     *     {@code origin} is {@code null}; or if {@code origin} does not match the field's
     *     dimensions or encoding
     * @throws IOException if reading the index fails
     */
    public static TopFieldDocs nearest(IndexSearcher searcher, String field, int topN, float ... origin) throws IOException {
        if (topN < 1) {
            throw new IllegalArgumentException("topN must be at least 1; got " + topN);
        }
        if (field == null) {
            throw new IllegalArgumentException("field must not be null");
        }
        if (searcher == null) {
            throw new IllegalArgumentException("searcher must not be null");
        }
        if (origin == null) {
            throw new IllegalArgumentException("origin must not be null");
        }
        List<PointValues> readers = new ArrayList<>();
        List<Integer> docBases = new ArrayList<>();
        List<Bits> liveDocs = new ArrayList<>();
        long totalHits = 0L;
        for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) {
            PointValues points = leaf.reader().getPointValues(field);
            if (points == null) {
                continue;
            }
            // Distances decode origin.length float dimensions from each packed value, so a
            // mismatch would read out of bounds or silently compute a wrong distance.
            int numDims = points.getNumDimensions();
            int numIndexDims = points.getNumIndexDimensions();
            int bytesPerDim = points.getBytesPerDimension();
            if (numDims != origin.length || numIndexDims != origin.length || bytesPerDim != Float.BYTES) {
                throw new IllegalArgumentException("origin has " + origin.length + " dimensions but field \"" + field
                        + "\" was indexed with numDims=" + numDims + ", numIndexDims=" + numIndexDims
                        + ", bytesPerDim=" + bytesPerDim + " (expected " + Float.BYTES + " bytes per dimension)");
            }
            totalHits += points.getDocCount();
            readers.add(points);
            docBases.add(leaf.docBase);
            liveDocs.add(leaf.reader().getLiveDocs());
        }
        NearestHit[] hits = nearest(readers, liveDocs, docBases, topN, origin);
        ScoreDoc[] scoreDocs = new ScoreDoc[hits.length];
        for (int i = 0; i < hits.length; ++i) {
            NearestHit hit = hits[i];
            scoreDocs[i] = new FieldDoc(hit.docID, 0.0f, new Object[]{Float.valueOf((float)Math.sqrt(hit.distanceSquared))});
        }
        return new TopFieldDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs, null);
    }

    /** A collected point: global doc ID plus squared distance to the origin. */
    static class NearestHit {
        public int docID;
        public double distanceSquared;

        NearestHit() {
        }

        @Override
        public String toString() {
            return "NearestHit(docID=" + this.docID + " distanceSquared=" + this.distanceSquared + ")";
        }
    }

    /**
     * Collects the visited points into the shared hit queue. It keeps at most {@code topN}
     * entries, evicting the worst one when a better point arrives.
     *
     * <p>{@link #curDocBase} and {@link #curLiveDocs} must be set to the current segment's values
     * before visiting one of its leaves.
     */
    private static class NearestVisitor
    implements PointValues.IntersectVisitor {
        /** Doc base of the segment currently being visited. */
        int curDocBase;
        /** Live docs of the segment currently being visited; {@code null} means all are live. */
        Bits curLiveDocs;
        final int topN;
        final PriorityQueue<NearestHit> hitQueue;
        final float[] origin;
        private final int dims;
        /** Squared distance of the worst hit once the queue is full; +infinity until then. */
        double bottomNearestDistanceSquared = Double.POSITIVE_INFINITY;
        /** Global doc ID of the worst hit once the queue is full. */
        int bottomNearestDistanceDoc = Integer.MAX_VALUE;

        public NearestVisitor(PriorityQueue<NearestHit> hitQueue, int topN, float[] origin) {
            this.hitQueue = hitQueue;
            this.topN = topN;
            this.origin = origin;
            this.dims = origin.length;
        }

        @Override
        public void visit(int docID) {
            // compare() never answers CELL_INSIDE_QUERY, so doc IDs always arrive with values.
            throw new AssertionError("visit(int) without a packed value is not expected");
        }

        @Override
        public void visit(int docID, byte[] packedValue) {
            if (this.curLiveDocs != null && !this.curLiveDocs.get(docID)) {
                return;
            }
            double distanceSquared = 0.0;
            for (int d = 0, offset = 0; d < this.dims; ++d, offset += Float.BYTES) {
                double diff = (double)FloatPoint.decodeDimension(packedValue, offset) - (double)this.origin[d];
                distanceSquared += diff * diff;
                if (distanceSquared > this.bottomNearestDistanceSquared) {
                    // The partial sum already exceeds the worst hit, so this point cannot compete.
                    return;
                }
            }
            int fullDocID = this.curDocBase + docID;
            if (this.hitQueue.size() == this.topN) {
                // Equal distance: the larger doc ID loses the tie, so keep the current bottom.
                if (distanceSquared == this.bottomNearestDistanceSquared && fullDocID > this.bottomNearestDistanceDoc) {
                    return;
                }
                // Recycle the evicted worst hit instead of allocating a new one.
                NearestHit bottom = this.hitQueue.poll();
                bottom.docID = fullDocID;
                bottom.distanceSquared = distanceSquared;
                this.hitQueue.offer(bottom);
                this.updateBottomNearestDistance();
            } else {
                NearestHit hit = new NearestHit();
                hit.docID = fullDocID;
                hit.distanceSquared = distanceSquared;
                this.hitQueue.offer(hit);
                if (this.hitQueue.size() == this.topN) {
                    this.updateBottomNearestDistance();
                }
            }
        }

        /** Caches the worst hit (the queue head) as the pruning threshold. */
        private void updateBottomNearestDistance() {
            NearestHit newBottom = this.hitQueue.peek();
            this.bottomNearestDistanceSquared = newBottom.distanceSquared;
            this.bottomNearestDistanceDoc = newBottom.docID;
        }

        /**
         * Prunes cells that cannot contain a competitive point once the queue is full. Otherwise
         * the cell must be visited.
         */
        @Override
        public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
            if (this.hitQueue.size() == this.topN && pointToRectangleDistanceSquared(minPackedValue, maxPackedValue, this.origin) > this.bottomNearestDistanceSquared) {
                return PointValues.Relation.CELL_OUTSIDE_QUERY;
            }
            return PointValues.Relation.CELL_CROSSES_QUERY;
        }
    }

    /** A pending node of one segment's point tree, ordered by its distance lower bound. */
    static class Cell
    implements Comparable<Cell> {
        final int readerIndex;
        /** Copies of the node's bounds, taken at construction; not read elsewhere in this class. */
        final byte[] minPacked;
        final byte[] maxPacked;
        final PointValues.PointTree index;
        /** Smallest squared distance from the origin to any point inside this cell. */
        final double distanceSquared;

        Cell(PointValues.PointTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceSquared) {
            this.index = index;
            this.readerIndex = readerIndex;
            // Defensive copies: the arrays passed in may be buffers the tree reuses as it is
            // navigated (NOTE(review): confirm against the PointTree implementation).
            this.minPacked = (byte[])minPacked.clone();
            this.maxPacked = (byte[])maxPacked.clone();
            this.distanceSquared = distanceSquared;
        }

        @Override
        public int compareTo(Cell other) {
            return Double.compare(this.distanceSquared, other.distanceSquared);
        }

        @Override
        public String toString() {
            return "Cell(readerIndex=" + this.readerIndex + " " + this.index.toString() + " distanceSquared=" + this.distanceSquared + ")";
        }
    }
}

