/*
 * Decompiled with CFR 0.152.
 */
package apoc.ml.watson;

import apoc.ApocConfig;
import apoc.Extended;
import apoc.ml.watson.WatsonHandler;
import apoc.result.MapResult;
import apoc.util.JsonUtil;
import java.io.IOException;
import java.lang.invoke.CallSite;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Cypher procedures that send embedding, chat and completion requests to a Watson endpoint.
 *
 * <p>Every procedure delegates to {@link #executeRequest}, which asks the supplied
 * {@link WatsonHandler} for the endpoint and the payload, adds a bearer-token
 * {@code Authorization} header and loads the JSON response through {@link JsonUtil}.
 */
@Extended
public class Watson {
    static final String PROJECT_ID_KEY = "project_id";
    static final String SPACE_ID_KEY = "space_id";
    static final String MODEL_ID_KEY = "model_id";
    static final String WML_INSTANCE_CRN_KEY = "wml_instance_crn";
    static final String DEFAULT_COMPLETION_MODEL_ID = "ibm/granite-13b-chat-v2";
    static final String DEFAULT_EMBEDDING_MODEL_ID = "ibm/slate-30m-english-rtrvr";
    static final String DEFAULT_VERSION_DATE = "2023-05-29";
    static final String DEFAULT_REGION = "eu-de";

    /** APOC config key read as a fallback when the request configuration carries no id key. */
    private static final String APOC_CONF_PROJECT_ID = "apoc.ml.watson.project.id";
    /** Message thrown by every procedure when its main input argument is {@code null}. */
    private static final String ERROR_NULL_INPUT = "Null, blank or empty input provided. Please specify a valid input";

    @Context
    public ApocConfig apocConfig;
    @Context
    public URLAccessChecker urlAccessChecker;

    /**
     * Requests embeddings for the given texts.
     *
     * <p>Each entry of the response's {@code results} list is paired with the input text at the
     * same running position, so the response is assumed to list results in input order.
     *
     * @param texts the texts to embed; must not be {@code null}
     * @param accessToken the token sent as {@code Authorization: Bearer <token>}
     * @param configuration extra request configuration; never modified by this call
     * @return one {@link EmbeddingResult} per entry in the response's {@code results} list
     * @throws RuntimeException if {@code texts} is {@code null} or the request fails
     */
    @Procedure(value="apoc.ml.watson.embedding")
    @Description(value="apoc.ml.watson.embedding([texts], accessToken, $configuration) - returns the embeddings for a given text")
    @SuppressWarnings("unchecked") // the JSON response is untyped; shapes are cast as the original code did
    public Stream<EmbeddingResult> embedding(@Name(value="texts") List<String> texts, @Name(value="accessToken") String accessToken, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        if (texts == null) {
            throw new RuntimeException(ERROR_NULL_INPUT);
        }
        // Running position across all response maps; the stream is consumed sequentially.
        AtomicInteger position = new AtomicInteger();
        return this.executeRequest(texts, accessToken, configuration, WatsonHandler.Type.EMBEDDING.get())
                .flatMap(response -> ((List<Map<String, Object>>) response.get("results")).stream())
                .map(result -> {
                    int index = position.getAndIncrement();
                    List<Double> vector = (List<Double>) result.get("embedding");
                    return new EmbeddingResult(index, texts.get(index), vector);
                });
    }

    /**
     * Flattens chat messages into a single prompt and forwards it to {@link #completion}.
     *
     * <p>Each message becomes {@code "<role>: <content>"}; messages are joined by a blank line.
     *
     * @param messages chat messages, each holding non-null {@code role} and {@code content} entries
     * @param accessToken the token sent as {@code Authorization: Bearer <token>}
     * @param configuration extra request configuration; never modified by this call
     * @return the response maps wrapped in {@link MapResult}
     * @throws RuntimeException if {@code messages} is {@code null} or a message lacks
     *     {@code role} or {@code content}
     */
    @Procedure(value="apoc.ml.watson.chat")
    @Description(value="apoc.ml.watson.chat(messages, accessToken, $configuration) - prompts the completion API")
    public Stream<MapResult> chatCompletion(@Name(value="messages") List<Map<String, Object>> messages, @Name(value="accessToken") String accessToken, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        if (messages == null) {
            throw new RuntimeException(ERROR_NULL_INPUT);
        }
        String prompt = messages.stream()
                .map(Watson::toPromptLine)
                .collect(Collectors.joining("\n\n"));
        return this.completion(prompt, accessToken, configuration);
    }

    /**
     * Sends a prompt using the completion handler.
     *
     * @param prompt the prompt text; must not be {@code null}
     * @param accessToken the token sent as {@code Authorization: Bearer <token>}
     * @param configuration extra request configuration; never modified by this call
     * @return the response maps wrapped in {@link MapResult}
     * @throws RuntimeException if {@code prompt} is {@code null} or the request fails
     */
    @Procedure(value="apoc.ml.watson.completion")
    @Description(value="apoc.ml.watson.completion(prompt, accessToken, $configuration) - prompts the completion API")
    public Stream<MapResult> completion(@Name(value="prompt") String prompt, @Name(value="accessToken") String accessToken, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        if (prompt == null) {
            throw new RuntimeException(ERROR_NULL_INPUT);
        }
        return this.executeRequest(prompt, accessToken, configuration, WatsonHandler.Type.COMPLETION.get()).map(MapResult::new);
    }

    /** Renders one chat message as {@code "<role>: <content>"}, rejecting messages missing either key. */
    private static String toPromptLine(Map<String, Object> message) {
        Object role = message.get("role");
        Object content = message.get("content");
        if (role == null || content == null) {
            throw new RuntimeException("The `messages` items must have the keys: `role` and `content`");
        }
        return role + ": " + content;
    }

    /**
     * Builds and executes the HTTP request for the given handler type.
     *
     * <p>If the configuration has none of {@code project_id}, {@code space_id} or
     * {@code wml_instance_crn}, the project id is taken from the APOC config; when that is
     * absent too, the call fails.
     *
     * @param input the handler-specific input (a prompt string or a list of texts)
     * @param accessToken the bearer token placed in the {@code Authorization} header
     * @param configuration caller configuration; copied before use, {@code null} is treated as empty
     * @param type the handler that supplies the endpoint and the payload
     * @return the parsed JSON response, one map per loaded element
     * @throws RuntimeException if no project/space/instance id can be resolved, or wrapping an
     *     {@link IOException} raised while serializing the payload
     */
    @SuppressWarnings("unchecked") // loaded JSON values are cast to maps exactly as the original code did
    private Stream<Map<String, Object>> executeRequest(Object input, String accessToken, Map<String, Object> configuration, WatsonHandler type) {
        // Work on a private copy: the caller's map may be unmodifiable and must not be altered.
        Map<String, Object> config = configuration == null ? new HashMap<>() : new HashMap<>(configuration);
        try {
            boolean hasTargetId = config.containsKey(PROJECT_ID_KEY)
                    || config.containsKey(SPACE_ID_KEY)
                    || config.containsKey(WML_INSTANCE_CRN_KEY);
            if (!hasTargetId) {
                String apocConfProjectId = this.apocConfig.getString(APOC_CONF_PROJECT_ID, null);
                if (apocConfProjectId == null) {
                    String errMessage = "The body request has none of %s, %s, and %s and the APOC config `%s` is not present.%nPlease, define one of these".formatted(PROJECT_ID_KEY, SPACE_ID_KEY, WML_INSTANCE_CRN_KEY, APOC_CONF_PROJECT_ID);
                    throw new RuntimeException(errMessage);
                }
                config.put(PROJECT_ID_KEY, apocConfProjectId);
            }
            // Endpoint is resolved before the payload, matching the original call order.
            String endpoint = type.getEndpoint(config);
            Map<String, Object> headers = Map.of(
                    "Content-Type", "application/json",
                    "accept", "application/json",
                    "Authorization", "Bearer " + accessToken);
            Map<String, Object> payloadMap = type.getPayload(config, input);
            String payload = JsonUtil.OBJECT_MAPPER.writeValueAsString(payloadMap);
            return JsonUtil.loadJson(endpoint, headers, payload, "$", true, List.of(), this.urlAccessChecker)
                    .map(value -> (Map<String, Object>) value);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * One embedding row returned by {@link #embedding}.
     *
     * @param index position of the result in the response, starting at 0
     * @param text the input text found at that position
     * @param embedding the value of the result's {@code embedding} entry
     */
    public record EmbeddingResult(long index, String text, List<Double> embedding) {
    }
}

