/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.procedure.builtin;

import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Spliterator;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.neo4j.common.EntityType;
import org.neo4j.exceptions.KernelException;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.Label;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.NotFoundException;
import org.neo4j.graphdb.Transaction;
import org.neo4j.graphdb.schema.IndexCreator;
import org.neo4j.graphdb.schema.IndexSetting;
import org.neo4j.internal.kernel.api.IndexQueryConstraints;
import org.neo4j.internal.kernel.api.IndexReadSession;
import org.neo4j.internal.kernel.api.NodeValueIndexCursor;
import org.neo4j.internal.kernel.api.PropertyIndexQuery;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.internal.schema.IndexConfig;
import org.neo4j.internal.schema.IndexDescriptor;
import org.neo4j.internal.schema.IndexType;
import org.neo4j.kernel.api.KernelTransaction;
import org.neo4j.kernel.api.impl.schema.vector.VectorSimilarityFunction;
import org.neo4j.kernel.api.impl.schema.vector.VectorUtils;
import org.neo4j.kernel.api.txstate.TxStateHolder;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
import org.neo4j.util.FeatureToggles;
import org.neo4j.util.Preconditions;

/**
 * Built-in procedures for creating and querying node vector indexes, and for writing vector-valued
 * node properties.
 */
public class VectorIndexProcedures {
    /** Inclusive lower bound on the {@code vectorDimension} accepted by {@code db.index.vector.createNodeIndex}. */
    private static final int MIN_VECTOR_DIMENSIONS = 1;

    /** Inclusive upper bound on the {@code vectorDimension} accepted by {@code db.index.vector.createNodeIndex}. */
    private static final int MAX_VECTOR_DIMENSIONS = 2048;

    /**
     * Seconds a query waits for its index to come online. Read from a feature toggle, defaulting to 30.
     */
    private static final long INDEX_ONLINE_QUERY_TIMEOUT_SECONDS =
            FeatureToggles.getInteger(VectorIndexProcedures.class, "INDEX_ONLINE_QUERY_TIMEOUT_SECONDS", 30);

    @Context
    public GraphDatabaseAPI db;

    @Context
    public Transaction tx;

    @Context
    public KernelTransaction ktx;

    @Context
    public ProcedureCallContext callContext;

    /**
     * Creates a named vector index on {@code label}/{@code propertyKey}.
     *
     * @param name index name; must not be null
     * @param label node label to index; must not be null
     * @param propertyKey property to index; must not be null
     * @param vectorDimension dimensionality; must be in [{@value #MIN_VECTOR_DIMENSIONS}, {@value #MAX_VECTOR_DIMENSIONS}]
     * @param vectorSimilarityFunction similarity function name, checked via {@code VectorSimilarityFunction.fromName}
     * @throws NullPointerException if any argument is null
     */
    @Description(value="Create a named node vector index for the given label and property for a specified vector dimensionality.\nValid similarity functions are 'EUCLIDEAN' and 'COSINE', and are case-insensitive.\nUse the `db.index.vector.queryNodes` procedure to query the named index.\n")
    @Procedure(name="db.index.vector.createNodeIndex", mode=Mode.SCHEMA)
    public void createIndex(@Name(value="indexName") String name, @Name(value="label") String label, @Name(value="propertyKey") String propertyKey, @Name(value="vectorDimension") Long vectorDimension, @Name(value="vectorSimilarityFunction") String vectorSimilarityFunction) {
        Objects.requireNonNull(name, "'indexName' must not be null");
        Objects.requireNonNull(label, "'label' must not be null");
        Objects.requireNonNull(propertyKey, "'propertyKey' must not be null");
        Objects.requireNonNull(vectorDimension, "'vectorDimension' must not be null");
        Preconditions.checkArgument(
                MIN_VECTOR_DIMENSIONS <= vectorDimension && vectorDimension <= MAX_VECTOR_DIMENSIONS,
                "'vectorDimension' must be between %d and %d inclusively".formatted(MIN_VECTOR_DIMENSIONS, MAX_VECTOR_DIMENSIONS));
        Objects.requireNonNull(vectorSimilarityFunction, "'vectorSimilarityFunction' must not be null");
        // Resolve the name before touching the schema; presumably fromName rejects unknown names.
        VectorSimilarityFunction.fromName(vectorSimilarityFunction);

        tx.schema()
                .indexFor(Label.label(label))
                .on(propertyKey)
                .withIndexType(IndexType.VECTOR.toPublicApi())
                .withIndexConfiguration(Map.of(
                        IndexSetting.vector_Dimensions(), vectorDimension,
                        IndexSetting.vector_Similarity_Function(), vectorSimilarityFunction))
                .withName(name)
                .create();
    }

    /**
     * Streams up to {@code numberOfNearestNeighbours} nodes nearest to {@code query} in the named
     * node vector index, together with their similarity scores.
     *
     * <p>Returns an empty stream when invoked against the system database. The index cursor is
     * closed when the stream is exhausted or closed, and also if setting up the seek fails.
     *
     * @param name name of an existing vector index on nodes
     * @param numberOfNearestNeighbours positive count; must fit in an {@code int}
     * @param query query vector; its size must match the index's configured dimensionality
     * @throws IllegalArgumentException if the index is missing, not a vector index, not on nodes,
     *     or the query vector is invalid
     * @throws ArithmeticException if {@code numberOfNearestNeighbours} overflows an {@code int}
     * @throws KernelException if the index read session or seek fails
     */
    @Description(value="Query the given vector index.\nReturns requested number of nearest neighbors to the provided query vector,\nand their similarity score to that query vector, based on the configured similarity function for the index.\nThe similarity score is a value between [0, 1]; where 0 indicates least similar, 1 most similar.\n")
    @Procedure(name="db.index.vector.queryNodes", mode=Mode.READ)
    public Stream<Neighbor> queryVectorIndex(@Name(value="indexName") String name, @Name(value="numberOfNearestNeighbours") Long numberOfNearestNeighbours, @Name(value="query") List<Double> query) throws KernelException {
        Objects.requireNonNull(name, "'indexName' must not be null");
        Objects.requireNonNull(numberOfNearestNeighbours, "'numberOfNearestNeighbours' must not be null");
        Preconditions.checkArgument(numberOfNearestNeighbours > 0L, "'numberOfNearestNeighbours' must be positive");
        Objects.requireNonNull(query, "'query' must not be null");
        if (callContext.isSystemDatabase()) {
            return Stream.empty();
        }

        IndexDescriptor index = getValidIndex(name);
        float[] validatedQuery = validateAndConvertQuery(index, query);

        // Reject non-node indexes before waiting on them: awaiting an index we will refuse anyway
        // could block for up to INDEX_ONLINE_QUERY_TIMEOUT_SECONDS for nothing.
        EntityType entityType = index.schema().entityType();
        if (entityType != EntityType.NODE) {
            throw new IllegalArgumentException("The '%s' index (%s) is an index on %s, so it cannot be queried for nodes.".formatted(name, index, entityType));
        }
        awaitOnline(index);

        // Compute k before allocating the cursor so an overflow cannot leak it.
        int k = Math.toIntExact(numberOfNearestNeighbours);

        NodeValueIndexCursor cursor = ktx.cursors().allocateNodeValueIndexCursor(ktx.cursorContext(), ktx.memoryTracker());
        try {
            IndexReadSession session = ktx.dataRead().indexReadSession(index);
            ktx.dataRead().nodeIndexSeek(
                    ktx.queryContext(),
                    session,
                    cursor,
                    IndexQueryConstraints.unconstrained(),
                    PropertyIndexQuery.nearestNeighbors(k, validatedQuery));
        } catch (KernelException | RuntimeException e) {
            // Ownership has not yet passed to the returned stream, so release the cursor here.
            cursor.close();
            throw e;
        }
        return new NeighborSpliterator(tx, cursor, k).stream();
    }

    /**
     * Stores {@code vector} on {@code node} under {@code propKey} as a {@code float[]}.
     *
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if the vector contains non-finite values
     */
    @Description(value="Set a vector property on a given node in a more space efficient representation than Cypher's SET.")
    @Procedure(name="db.create.setNodeVectorProperty", mode=Mode.WRITE)
    public void setNodeVectorProperty(@Name(value="node") Node node, @Name(value="key") String propKey, @Name(value="vector") List<Double> vector) {
        setVectorProperty(
                Objects.requireNonNull(node, "'node' must not be null"),
                Objects.requireNonNull(propKey, "'key' must not be null"),
                Objects.requireNonNull(vector, "'vector' must not be null"));
    }

    /**
     * Same as {@link #setNodeVectorProperty}, but additionally yields the updated node.
     *
     * @deprecated use {@code db.create.setNodeVectorProperty} instead.
     */
    @Description(value="Set a vector property on a given node in a more space efficient representation than Cypher's SET.")
    @Procedure(name="db.create.setVectorProperty", mode=Mode.WRITE, deprecatedBy="db.create.setNodeVectorProperty")
    @Deprecated(since="5.13.0", forRemoval=true)
    public Stream<NodeRecord> deprecatedSetVectorProperty(@Name(value="node") Node node, @Name(value="key") String propKey, @Name(value="vector") List<Double> vector) {
        setNodeVectorProperty(node, propKey, vector);
        return Stream.of(new NodeRecord(node));
    }

    /** Validates {@code vector} with the EUCLIDEAN rules (finite values only) and writes it to {@code entity}. */
    private void setVectorProperty(Entity entity, String propKey, List<Double> vector) {
        entity.setProperty(propKey, validVector(VectorSimilarityFunction.EUCLIDEAN, vector));
    }

    /**
     * Converts {@code candidate} to a {@code float[]} valid for {@code similarityFunction}.
     *
     * @throws IllegalArgumentException if {@code maybeToValidVector} rejects the candidate (returns null)
     */
    private float[] validVector(VectorSimilarityFunction similarityFunction, List<Double> candidate) {
        float[] vector = similarityFunction.maybeToValidVector(candidate);
        if (vector == null) {
            throw switch (similarityFunction) {
                case EUCLIDEAN -> new IllegalArgumentException("Index query vector must contain finite values. Provided: " + candidate);
                case COSINE -> new IllegalArgumentException("Index query vector must contain finite values, and have positive and finite l2-norm. Provided: " + candidate);
            };
        }
        return vector;
    }

    /**
     * Checks that {@code query} matches the index's configured dimensionality and is valid for its
     * similarity function, returning the converted vector.
     */
    private float[] validateAndConvertQuery(IndexDescriptor index, List<Double> query) {
        IndexConfig config = index.getIndexConfig();
        int dimensions = VectorUtils.vectorDimensionsFrom(config);
        if (dimensions != query.size()) {
            throw new IllegalArgumentException("Index query vector has %d dimensions, but indexed vectors have %d.".formatted(query.size(), dimensions));
        }
        return validVector(VectorUtils.vectorSimilarityFunctionFrom(config), query);
    }

    /**
     * Looks up the index by name.
     *
     * @throws IllegalArgumentException if no index has that name or it is not a vector index
     */
    private IndexDescriptor getValidIndex(String name) {
        IndexDescriptor index = ktx.schemaRead().indexGetForName(name);
        boolean isVectorIndex = index != IndexDescriptor.NO_INDEX && index.getIndexType() == IndexType.VECTOR;
        if (!isVectorIndex) {
            throw new IllegalArgumentException("There is no such vector schema index: " + name);
        }
        return index;
    }

    /**
     * Waits up to {@link #INDEX_ONLINE_QUERY_TIMEOUT_SECONDS} for the index to come online. Skips the
     * wait when this transaction's state records the index as added, i.e. it was created in this
     * same, still-uncommitted transaction (presumably it cannot come online before commit).
     */
    private void awaitOnline(IndexDescriptor index) {
        TxStateHolder txStateHolder = (TxStateHolder) ktx;
        boolean createdInThisTransaction = txStateHolder.hasTxStateWithChanges()
                && txStateHolder.txState().indexDiffSetsBySchema(index.schema()).isAdded(index);
        if (!createdInThisTransaction) {
            tx.schema().awaitIndexOnline(index.getName(), INDEX_ONLINE_QUERY_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        }
    }

    /**
     * Adapts an already-seeked index cursor into a stream of {@link Neighbor}s, skipping hits whose
     * node can no longer be loaded. Closes the cursor when exhausted and when the stream is closed.
     */
    private record NeighborSpliterator(Transaction tx, NodeValueIndexCursor cursor, int k) implements Spliterator<Neighbor>
    {
        @Override
        public boolean tryAdvance(Consumer<? super Neighbor> action) {
            while (cursor.next()) {
                Neighbor neighbor = Neighbor.forExistingEntityOrNull(tx, cursor.nodeReference(), cursor.score());
                if (neighbor != null) {
                    action.accept(neighbor);
                    return true;
                }
            }
            cursor.close();
            return false;
        }

        @Override
        public Spliterator<Neighbor> trySplit() {
            // Sequential only: the underlying cursor cannot be partitioned.
            return null;
        }

        @Override
        public long estimateSize() {
            // Upper bound only (SIZED is not reported): missing nodes are skipped.
            return k;
        }

        @Override
        public int characteristics() {
            // NOTE(review): SORTED with a null comparator declares natural order (Neighbor.compareTo:
            // descending score, then node id). This relies on the index cursor yielding hits in that
            // order — confirm.
            return Spliterator.ORDERED
                    | Spliterator.SORTED
                    | Spliterator.DISTINCT
                    | Spliterator.NONNULL
                    | Spliterator.IMMUTABLE;
        }

        @Override
        public Comparator<? super Neighbor> getComparator() {
            // Null means "natural ordering" for a SORTED spliterator.
            return null;
        }

        /** Wraps this spliterator in a sequential stream that closes the cursor on {@code close()}. */
        Stream<Neighbor> stream() {
            return StreamSupport.stream(this, false).onClose(cursor::close);
        }
    }

    /** Single-column result row of the deprecated {@code db.create.setVectorProperty} procedure. */
    public record NodeRecord(Node node) {
    }

    /**
     * Result row of {@code db.index.vector.queryNodes}: a node and its similarity score. Natural
     * order is descending score, ties broken by ascending node id.
     */
    public record Neighbor(Node node, double score) implements Comparable<Neighbor>
    {
        @Override
        public int compareTo(Neighbor o) {
            // Arguments swapped to sort by descending score.
            int byScore = Double.compare(o.score, this.score);
            return byScore != 0 ? byScore : Long.compare(this.node.getId(), o.node.getId());
        }

        /**
         * Builds a neighbor for {@code nodeId}, or returns {@code null} if {@code tx.getNodeById}
         * throws {@link NotFoundException}.
         */
        public static Neighbor forExistingEntityOrNull(Transaction tx, long nodeId, float score) {
            try {
                return new Neighbor(tx.getNodeById(nodeId), score);
            }
            catch (NotFoundException ignored) {
                // Deliberately skipped: presumably the node was deleted after being indexed; the
                // caller drops null results.
                return null;
            }
        }
    }
}

