/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.store.embedding.vespa;

import ai.vespa.client.dsl.A;
import ai.vespa.client.dsl.Annotation;
import ai.vespa.client.dsl.NearestNeighbor;
import ai.vespa.client.dsl.Q;
import ai.vespa.client.dsl.QueryChain;
import ai.vespa.feed.client.DocumentId;
import ai.vespa.feed.client.FeedClient;
import ai.vespa.feed.client.FeedClientBuilder;
import ai.vespa.feed.client.FeedException;
import ai.vespa.feed.client.JsonFeeder;
import ai.vespa.feed.client.Result;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Json;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.vespa.QueryResponse;
import dev.langchain4j.store.embedding.vespa.Record;
import dev.langchain4j.store.embedding.vespa.VespaQueryApi;
import dev.langchain4j.store.embedding.vespa.VespaQueryClient;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import retrofit2.Response;

public class VespaEmbeddingStore
implements EmbeddingStore<TextSegment> {
    private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5L);
    private static final String DEFAULT_NAMESPACE = "namespace";
    private static final String DEFAULT_DOCUMENT_TYPE = "langchain4j";
    private static final boolean DEFAULT_AVOID_DUPS = true;
    private static final String FIELD_NAME_TEXT_SEGMENT = "text_segment";
    private static final String FIELD_NAME_VECTOR = "vector";
    private static final String FIELD_NAME_DOCUMENT_ID = "documentid";
    private static final String DEFAULT_RANK_PROFILE = "cosine_similarity";
    private static final int DEFAULT_TARGET_HITS = 10;
    private final String url;
    private final Path keyPath;
    private final Path certPath;
    private final Duration timeout;
    private final String namespace;
    private final String documentType;
    private final String rankProfile;
    private final int targetHits;
    private final boolean avoidDups;
    private VespaQueryApi queryApi;

    public VespaEmbeddingStore(String url, String keyPath, String certPath, Duration timeout, String namespace, String documentType, String rankProfile, Integer targetHits, Boolean avoidDups) {
        this.url = url;
        this.keyPath = Paths.get(keyPath, new String[0]);
        this.certPath = Paths.get(certPath, new String[0]);
        this.timeout = timeout != null ? timeout : DEFAULT_TIMEOUT;
        this.namespace = namespace != null ? namespace : DEFAULT_NAMESPACE;
        this.documentType = documentType != null ? documentType : DEFAULT_DOCUMENT_TYPE;
        this.rankProfile = rankProfile != null ? rankProfile : DEFAULT_RANK_PROFILE;
        this.targetHits = targetHits != null ? targetHits : 10;
        this.avoidDups = avoidDups != null ? avoidDups : true;
    }

    public String add(Embedding embedding) {
        return this.add(null, embedding, null);
    }

    public void add(String id, Embedding embedding) {
        this.add(id, embedding, null);
    }

    public String add(Embedding embedding, TextSegment textSegment) {
        return this.add(null, embedding, textSegment);
    }

    public List<String> addAll(List<Embedding> embeddings) {
        return this.addAll(embeddings, null);
    }

    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (embedded != null && embeddings.size() != embedded.size()) {
            throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
        }
        try (JsonFeeder jsonFeeder = this.buildJsonFeeder();){
            ArrayList<Record> records = new ArrayList<Record>();
            for (int i = 0; i < embeddings.size(); ++i) {
                records.add(this.buildRecord(ids.get(i), embeddings.get(i), embedded != null ? embedded.get(i) : null));
            }
            jsonFeeder.feedMany(Json.toInputStream(records, List.class), new JsonFeeder.ResultCallback(){

                public void onNextResult(Result result, FeedException error) {
                    if (error != null) {
                        throw new RuntimeException(error.getMessage());
                    }
                }

                public void onError(FeedException error) {
                    throw new RuntimeException(error.getMessage());
                }
            });
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        try {
            String searchQuery = Q.select((String)FIELD_NAME_DOCUMENT_ID, (String[])new String[]{FIELD_NAME_TEXT_SEGMENT, FIELD_NAME_VECTOR}).from(this.documentType).where((QueryChain)this.buildNearestNeighbor()).fix().hits(maxResults).ranking(this.rankProfile).param("input.query(q)", Json.toJson((Object)referenceEmbedding.vectorAsList())).param("input.query(threshold)", String.valueOf(minScore)).build();
            Response response = this.getQueryApi().search(searchQuery).execute();
            if (response.isSuccessful()) {
                QueryResponse parsedResponse = (QueryResponse)response.body();
                return parsedResponse.getRoot().getChildren().stream().map(VespaEmbeddingStore::toEmbeddingMatch).collect(Collectors.toList());
            }
            throw new RuntimeException("Request failed");
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private String add(String id, Embedding embedding, TextSegment textSegment) {
        AtomicReference resId = new AtomicReference();
        try (JsonFeeder jsonFeeder = this.buildJsonFeeder();){
            jsonFeeder.feedSingle(Json.toJson((Object)this.buildRecord(id, embedding, textSegment))).whenComplete((result, throwable) -> {
                if (throwable != null) {
                    throw new RuntimeException((Throwable)throwable);
                }
                if (Result.Type.success.equals((Object)result.type())) {
                    resId.set(result.documentId().toString());
                }
            });
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        return (String)resId.get();
    }

    private JsonFeeder buildJsonFeeder() {
        return JsonFeeder.builder((FeedClient)FeedClientBuilder.create((URI)URI.create(this.url)).setCertificate(this.certPath, this.keyPath).build()).withTimeout(this.timeout).build();
    }

    private VespaQueryApi getQueryApi() {
        if (this.queryApi == null) {
            this.queryApi = VespaQueryClient.createInstance(this.url, this.certPath, this.keyPath);
        }
        return this.queryApi;
    }

    private static EmbeddingMatch<TextSegment> toEmbeddingMatch(Record in) {
        return new EmbeddingMatch(Double.valueOf(in.getRelevance()), in.getFields().getDocumentId(), Embedding.from(in.getFields().getVector().getValues()), (Object)TextSegment.from((String)in.getFields().getTextSegment()));
    }

    private Record buildRecord(String id, Embedding embedding, TextSegment textSegment) {
        String recordId = id != null ? id : (this.avoidDups && textSegment != null ? Utils.generateUUIDFrom((String)textSegment.text()) : Utils.randomUUID());
        DocumentId documentId = DocumentId.of((String)this.namespace, (String)this.documentType, (String)recordId);
        String text = textSegment != null ? textSegment.text() : null;
        return new Record(documentId.toString(), text, embedding.vectorAsList());
    }

    private Record buildRecord(Embedding embedding, TextSegment textSegment) {
        return this.buildRecord(null, embedding, textSegment);
    }

    private NearestNeighbor buildNearestNeighbor() throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
        NearestNeighbor nb = Q.nearestNeighbor((String)FIELD_NAME_VECTOR, (String)"q");
        Method method = NearestNeighbor.class.getDeclaredMethod("annotate", Annotation.class);
        method.setAccessible(true);
        method.invoke((Object)nb, A.a((String)"targetHits", (Object)this.targetHits));
        return nb;
    }

    public static VespaEmbeddingStoreBuilder builder() {
        return new VespaEmbeddingStoreBuilder();
    }

    public static class VespaEmbeddingStoreBuilder {
        private String url;
        private String keyPath;
        private String certPath;
        private Duration timeout;
        private String namespace;
        private String documentType;
        private String rankProfile;
        private Integer targetHits;
        private Boolean avoidDups;

        VespaEmbeddingStoreBuilder() {
        }

        public VespaEmbeddingStoreBuilder url(String url) {
            this.url = url;
            return this;
        }

        public VespaEmbeddingStoreBuilder keyPath(String keyPath) {
            this.keyPath = keyPath;
            return this;
        }

        public VespaEmbeddingStoreBuilder certPath(String certPath) {
            this.certPath = certPath;
            return this;
        }

        public VespaEmbeddingStoreBuilder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        public VespaEmbeddingStoreBuilder namespace(String namespace) {
            this.namespace = namespace;
            return this;
        }

        public VespaEmbeddingStoreBuilder documentType(String documentType) {
            this.documentType = documentType;
            return this;
        }

        public VespaEmbeddingStoreBuilder rankProfile(String rankProfile) {
            this.rankProfile = rankProfile;
            return this;
        }

        public VespaEmbeddingStoreBuilder targetHits(Integer targetHits) {
            this.targetHits = targetHits;
            return this;
        }

        public VespaEmbeddingStoreBuilder avoidDups(Boolean avoidDups) {
            this.avoidDups = avoidDups;
            return this;
        }

        public VespaEmbeddingStore build() {
            return new VespaEmbeddingStore(this.url, this.keyPath, this.certPath, this.timeout, this.namespace, this.documentType, this.rankProfile, this.targetHits, this.avoidDups);
        }

        public String toString() {
            return "VespaEmbeddingStore.VespaEmbeddingStoreBuilder(url=" + this.url + ", keyPath=" + this.keyPath + ", certPath=" + this.certPath + ", timeout=" + String.valueOf(this.timeout) + ", namespace=" + this.namespace + ", documentType=" + this.documentType + ", rankProfile=" + this.rankProfile + ", targetHits=" + this.targetHits + ", avoidDups=" + this.avoidDups + ")";
        }
    }
}

