package dev.langchain4j.model.googleai;

import com.google.gson.Gson;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import retrofit2.Call;

@Experimental
/**
 * Chat model that sends {@link ChatRequest}s to Gemini through {@link GeminiService} and maps the
 * result back to LangChain4j types.
 *
 * <p>Token counting ({@link TokenCountEstimator}) is delegated to a {@link GoogleAiGeminiTokenizer}
 * configured with the same model name, API key, timeout, retry count and logging flag.
 *
 * <p>Registered {@link ChatModelListener}s are notified before each request and after each
 * response or failure; exceptions thrown by listeners are logged and otherwise ignored.
 */
public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
    private static final Logger log = LoggerFactory.getLogger(GoogleAiGeminiChatModel.class);
    private static final Gson GSON = new Gson();
    /** Timeout used for both the generation service and the tokenizer when none is configured. */
    private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60L);

    private final GeminiService geminiService;
    private final String apiKey;
    private final String modelName;
    private final Integer maxRetries;
    private final Double temperature;
    private final Integer topK;
    private final Double topP;
    private final Integer maxOutputTokens;
    private final List<String> stopSequences;
    private final Integer candidateCount;
    private final ResponseFormat responseFormat;
    private final GeminiFunctionCallingConfig toolConfig;
    private final boolean allowCodeExecution;
    private final boolean includeCodeExecutionOutput;
    private final Boolean logRequestsAndResponses;
    private final List<GeminiSafetySetting> safetySettings;
    private final List<ChatModelListener> listeners;
    private final GoogleAiGeminiTokenizer geminiTokenizer;

    /** Fluent builder for {@link GoogleAiGeminiChatModel}; unset values fall back to the constructor defaults. */
    public static class GoogleAiGeminiChatModelBuilder {
        private String apiKey;
        private String modelName;
        private Integer maxRetries;
        private Double temperature;
        private Integer topK;
        private Double topP;
        private Integer maxOutputTokens;
        private Integer candidateCount;
        private Duration timeout;
        private ResponseFormat responseFormat;
        private List<String> stopSequences;
        private GeminiFunctionCallingConfig toolConfig;
        private Boolean allowCodeExecution;
        private Boolean includeCodeExecutionOutput;
        private Boolean logRequestsAndResponses;
        private List<GeminiSafetySetting> safetySettings;
        private List<ChatModelListener> listeners;

        GoogleAiGeminiChatModelBuilder() {
        }

        /**
         * Configures function calling.
         *
         * @param mode                 the Gemini function-calling mode
         * @param allowedFunctionNames names passed through to {@link GeminiFunctionCallingConfig}
         * @return this builder
         */
        public GoogleAiGeminiChatModelBuilder toolConfig(GeminiMode mode, String... allowedFunctionNames) {
            this.toolConfig = new GeminiFunctionCallingConfig(mode, Arrays.asList(allowedFunctionNames));
            return this;
        }

        /**
         * Configures safety settings from a category-to-threshold map.
         *
         * @param thresholdsByCategory one {@link GeminiSafetySetting} is created per entry
         * @return this builder
         */
        public GoogleAiGeminiChatModelBuilder safetySettings(Map<GeminiHarmCategory, GeminiHarmBlockThreshold> thresholdsByCategory) {
            this.safetySettings = thresholdsByCategory.entrySet().stream()
                    .map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue()))
                    .collect(Collectors.toList());
            return this;
        }

        public GoogleAiGeminiChatModelBuilder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder temperature(Double temperature) {
            this.temperature = temperature;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder topK(Integer topK) {
            this.topK = topK;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder topP(Double topP) {
            this.topP = topP;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder maxOutputTokens(Integer maxOutputTokens) {
            this.maxOutputTokens = maxOutputTokens;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder candidateCount(Integer candidateCount) {
            this.candidateCount = candidateCount;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder responseFormat(ResponseFormat responseFormat) {
            this.responseFormat = responseFormat;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder stopSequences(List<String> stopSequences) {
            this.stopSequences = stopSequences;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder allowCodeExecution(Boolean allowCodeExecution) {
            this.allowCodeExecution = allowCodeExecution;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder includeCodeExecutionOutput(Boolean includeCodeExecutionOutput) {
            this.includeCodeExecutionOutput = includeCodeExecutionOutput;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder logRequestsAndResponses(Boolean logRequestsAndResponses) {
            this.logRequestsAndResponses = logRequestsAndResponses;
            return this;
        }

        public GoogleAiGeminiChatModelBuilder listeners(List<ChatModelListener> listeners) {
            this.listeners = listeners;
            return this;
        }

        /** Creates the model; validation and defaulting happen in the model constructor. */
        public GoogleAiGeminiChatModel build() {
            return new GoogleAiGeminiChatModel(this.apiKey, this.modelName, this.maxRetries, this.temperature,
                    this.topK, this.topP, this.maxOutputTokens, this.candidateCount, this.timeout,
                    this.responseFormat, this.stopSequences, this.toolConfig, this.allowCodeExecution,
                    this.includeCodeExecutionOutput, this.logRequestsAndResponses, this.safetySettings,
                    this.listeners);
        }

        /**
         * Describes the builder state. The API key is masked so the secret does not end up in logs
         * or exception messages that print the builder.
         */
        @Override
        public String toString() {
            return "GoogleAiGeminiChatModel.GoogleAiGeminiChatModelBuilder(apiKey=" + (this.apiKey == null ? null : "***")
                    + ", modelName=" + this.modelName
                    + ", maxRetries=" + this.maxRetries
                    + ", temperature=" + this.temperature
                    + ", topK=" + this.topK
                    + ", topP=" + this.topP
                    + ", maxOutputTokens=" + this.maxOutputTokens
                    + ", candidateCount=" + this.candidateCount
                    + ", timeout=" + this.timeout
                    + ", responseFormat=" + this.responseFormat
                    + ", stopSequences=" + this.stopSequences
                    + ", toolConfig=" + this.toolConfig
                    + ", allowCodeExecution=" + this.allowCodeExecution
                    + ", includeCodeExecutionOutput=" + this.includeCodeExecutionOutput
                    + ", logRequestsAndResponses=" + this.logRequestsAndResponses
                    + ", safetySettings=" + this.safetySettings
                    + ", listeners=" + this.listeners + ")";
        }
    }

    /**
     * Creates the model, applying defaults for every optional argument.
     *
     * @param apiKey                     required, must not be blank
     * @param modelName                  required, must not be blank
     * @param maxRetries                 defaults to 3
     * @param temperature                defaults to 1.0
     * @param topK                       defaults to 64
     * @param topP                       defaults to 0.95
     * @param maxOutputTokens            defaults to 8192
     * @param candidateCount             defaults to 1
     * @param timeout                    defaults to 60 seconds; used by both the service and the tokenizer
     * @param responseFormat             default response format; may be overridden per request
     * @param stopSequences              defaults to an empty list
     * @param toolConfig                 function-calling configuration; may be {@code null}
     * @param allowCodeExecution         defaults to {@code false}
     * @param includeCodeExecutionOutput defaults to {@code false}
     * @param logRequestsAndResponses    defaults to {@code false}; when true the service receives this class's logger
     * @param safetySettings             copied if non-null
     * @param listeners                  copied; {@code null} means no listeners
     */
    public GoogleAiGeminiChatModel(String apiKey, String modelName, Integer maxRetries, Double temperature, Integer topK, Double topP, Integer maxOutputTokens, Integer candidateCount, Duration timeout, ResponseFormat responseFormat, List<String> stopSequences, GeminiFunctionCallingConfig toolConfig, Boolean allowCodeExecution, Boolean includeCodeExecutionOutput, Boolean logRequestsAndResponses, List<GeminiSafetySetting> safetySettings, List<ChatModelListener> listeners) {
        this.apiKey = ValidationUtils.ensureNotBlank(apiKey, "apiKey");
        this.modelName = ValidationUtils.ensureNotBlank(modelName, "modelName");
        this.maxRetries = Utils.getOrDefault(maxRetries, 3);
        this.temperature = Utils.getOrDefault(temperature, 1.0d);
        this.topK = Utils.getOrDefault(topK, 64);
        this.topP = Utils.getOrDefault(topP, 0.95d);
        this.maxOutputTokens = Utils.getOrDefault(maxOutputTokens, 8192);
        this.candidateCount = Utils.getOrDefault(candidateCount, 1);
        this.stopSequences = Utils.getOrDefault(stopSequences, Collections.<String>emptyList());
        this.toolConfig = toolConfig;
        this.allowCodeExecution = allowCodeExecution != null && allowCodeExecution;
        this.includeCodeExecutionOutput = includeCodeExecutionOutput != null && includeCodeExecutionOutput;
        this.logRequestsAndResponses = Utils.getOrDefault(logRequestsAndResponses, false);
        this.safetySettings = Utils.copyIfNotNull(safetySettings);
        this.responseFormat = responseFormat;
        this.listeners = listeners == null ? Collections.emptyList() : new ArrayList<>(listeners);

        // Resolve the timeout once so the service and the tokenizer are guaranteed to agree.
        Duration effectiveTimeout = Utils.getOrDefault(timeout, DEFAULT_TIMEOUT);
        this.geminiService = GeminiService.getGeminiService(this.logRequestsAndResponses ? log : null, effectiveTimeout);
        this.geminiTokenizer = GoogleAiGeminiTokenizer.builder()
                .modelName(this.modelName)
                .apiKey(this.apiKey)
                .timeout(effectiveTimeout)
                .maxRetries(this.maxRetries)
                .logRequestsAndResponses(this.logRequestsAndResponses)
                .build();
    }

    /**
     * Chooses the Gemini response MIME type for a response format.
     *
     * @return {@code "text/plain"} for no format or TEXT, {@code "text/x.enum"} for a JSON format whose
     *         schema root is a {@link JsonEnumSchema}, and {@code "application/json"} otherwise
     */
    private static String computeMimeType(ResponseFormat responseFormat) {
        if (responseFormat == null || ResponseFormatType.TEXT.equals(responseFormat.type())) {
            return "text/plain";
        }
        boolean enumSchema = ResponseFormatType.JSON.equals(responseFormat.type())
                && responseFormat.jsonSchema() != null
                && responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema;
        return enumSchema ? "text/x.enum" : "application/json";
    }

    /** Sends the messages without tools; delegates to {@link #chat(ChatRequest)}. */
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        ChatResponse chatResponse = chat(ChatRequest.builder().messages(messages).build());
        return Response.from(chatResponse.aiMessage(), chatResponse.tokenUsage(), chatResponse.finishReason());
    }

    /** Sends the messages with a single tool; delegates to {@link #generate(List, List)}. */
    public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
        return generate(messages, Collections.singletonList(toolSpecification));
    }

    /** Sends the messages with the given tools; delegates to {@link #chat(ChatRequest)}. */
    public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        ChatResponse chatResponse = chat(ChatRequest.builder().messages(messages).toolSpecifications(toolSpecifications).build());
        return Response.from(chatResponse.aiMessage(), chatResponse.tokenUsage(), chatResponse.finishReason());
    }

    /**
     * Calls {@code generateContent} and maps the first candidate to a {@link ChatResponse}.
     *
     * <p>The request's response format takes precedence over the model-level one.
     *
     * @throws RuntimeException if the API returns HTTP status {@code >= 300}, returns no body or no
     *                          candidates, or the call fails with an {@link IOException} (wrapped). In
     *                          every such case listeners receive {@code onError} first.
     */
    public ChatResponse chat(ChatRequest chatRequest) {
        GeminiGenerateContentRequest request = buildGenerateContentRequest(chatRequest);
        ChatModelRequest modelRequest = ChatModelRequest.builder()
                .model(this.modelName)
                .temperature(this.temperature)
                .topP(this.topP)
                .maxTokens(this.maxOutputTokens)
                .messages(chatRequest.messages())
                .toolSpecifications(chatRequest.toolSpecifications())
                .build();
        // Shared between the request, response and error contexts of this single call.
        Map<Object, Object> attributes = new ConcurrentHashMap<>();
        notifyRequest(new ChatModelRequestContext(modelRequest, attributes));

        try {
            // NOTE(review): withRetry wraps only the creation of the Call; execute() runs outside it,
            // so failures thrown by execute() are presumably not retried — confirm against RetryUtils.
            Call<GeminiGenerateContentResponse> call = RetryUtils.withRetry(
                    () -> this.geminiService.generateContent(this.modelName, this.apiKey, request), this.maxRetries);
            retrofit2.Response<GeminiGenerateContentResponse> httpResponse = call.execute();
            GeminiGenerateContentResponse geminiResponse = httpResponse.body();

            if (httpResponse.code() >= 300) {
                throw reportError(toApiError(httpResponse), modelRequest, attributes);
            }
            if (geminiResponse == null) {
                throw reportError(new RuntimeException("Gemini response was null"), modelRequest, attributes);
            }
            List<GeminiCandidate> candidates = geminiResponse.getCandidates();
            if (candidates == null || candidates.isEmpty()) {
                throw reportError(new RuntimeException("Gemini response contained no candidates"), modelRequest, attributes);
            }

            GeminiCandidate candidate = candidates.get(0);
            FinishReason finishReason = FinishReasonMapper.fromGFinishReasonToFinishReason(candidate.getFinishReason());
            AiMessage aiMessage = candidate.getContent() == null
                    ? AiMessage.from("No text was returned by the model. The model finished generating because of the following reason: " + finishReason)
                    : PartsAndContentsMapper.fromGPartsToAiMessage(candidate.getContent().getParts(), this.includeCodeExecutionOutput);
            TokenUsage tokenUsage = toTokenUsage(geminiResponse.getUsageMetadata());

            ChatModelResponse modelResponse = ChatModelResponse.builder()
                    .model(this.modelName)
                    .tokenUsage(tokenUsage)
                    .finishReason(finishReason)
                    .aiMessage(aiMessage)
                    .build();
            notifyResponse(new ChatModelResponseContext(modelResponse, modelRequest, attributes));

            return ChatResponse.builder()
                    .aiMessage(aiMessage)
                    .finishReason(finishReason)
                    .tokenUsage(tokenUsage)
                    .build();
        } catch (IOException e) {
            RuntimeException wrapped = new RuntimeException("An error occurred when calling the Gemini API endpoint.", e);
            // Listeners receive the original IOException; callers receive the wrapper.
            notifyError(e, modelRequest, attributes);
            throw wrapped;
        }
    }

    /** Maps a {@link ChatRequest} and this model's configuration to the Gemini wire request. */
    private GeminiGenerateContentRequest buildGenerateContentRequest(ChatRequest chatRequest) {
        // fromMessageToGContent fills this instance with the system-message parts it extracts.
        GeminiContent systemInstruction = new GeminiContent(GeminiRole.MODEL.toString());
        List<GeminiContent> contents = PartsAndContentsMapper.fromMessageToGContent(chatRequest.messages(), systemInstruction);

        ResponseFormat format = chatRequest.responseFormat() != null ? chatRequest.responseFormat() : this.responseFormat;
        String mimeType = computeMimeType(format);
        GeminiSchema schema = null;
        if (format != null && format.jsonSchema() != null) {
            schema = SchemaMapper.fromJsonSchemaToGSchema(format.jsonSchema());
        }

        GeminiGenerationConfig generationConfig = GeminiGenerationConfig.builder()
                .candidateCount(this.candidateCount)
                .maxOutputTokens(this.maxOutputTokens)
                .responseMimeType(mimeType)
                .responseSchema(schema)
                .stopSequences(this.stopSequences)
                .temperature(this.temperature)
                .topK(this.topK)
                .topP(this.topP)
                .build();

        return GeminiGenerateContentRequest.builder()
                .contents(contents)
                .systemInstruction(systemInstruction.getParts().isEmpty() ? null : systemInstruction)
                .generationConfig(generationConfig)
                .safetySettings(this.safetySettings)
                .tools(FunctionMapper.fromToolSepcsToGTool(chatRequest.toolSpecifications(), this.allowCodeExecution))
                .toolConfig(new GeminiToolConfig(this.toolConfig))
                .build();
    }

    /**
     * Builds the exception for a non-success HTTP response from its JSON error body.
     *
     * @throws IOException if reading the error body fails
     */
    private static RuntimeException toApiError(retrofit2.Response<?> httpResponse) throws IOException {
        String fallbackMessage = String.format("Gemini API call failed with HTTP status %d", httpResponse.code());
        if (httpResponse.errorBody() == null) {
            return new RuntimeException(fallbackMessage);
        }
        try {
            GeminiErrorContainer container = GSON.fromJson(httpResponse.errorBody().string(), GeminiErrorContainer.class);
            GeminiError error = container == null ? null : container.getError();
            if (error == null) {
                return new RuntimeException(fallbackMessage);
            }
            return new RuntimeException(String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
        } finally {
            // Release the body regardless of how reading it ended.
            httpResponse.errorBody().close();
        }
    }

    /** Maps Gemini usage metadata; all counts are {@code null} when the metadata is absent. */
    private static TokenUsage toTokenUsage(GeminiUsageMetadata usageMetadata) {
        if (usageMetadata == null) {
            return new TokenUsage((Integer) null, (Integer) null, (Integer) null);
        }
        return new TokenUsage(usageMetadata.getPromptTokenCount(), usageMetadata.getCandidatesTokenCount(), usageMetadata.getTotalTokenCount());
    }

    /** Notifies listeners of {@code error} and returns it so callers can write {@code throw reportError(...)}. */
    private RuntimeException reportError(RuntimeException error, ChatModelRequest modelRequest, Map<Object, Object> attributes) {
        notifyError(error, modelRequest, attributes);
        return error;
    }

    /** Calls {@code onRequest} on every listener, logging (not propagating) listener failures. */
    private void notifyRequest(ChatModelRequestContext context) {
        this.listeners.forEach(listener -> {
            try {
                listener.onRequest(context);
            } catch (Exception e) {
                log.warn("Exception while calling model listener (onRequest)", e);
            }
        });
    }

    /** Calls {@code onResponse} on every listener, logging (not propagating) listener failures. */
    private void notifyResponse(ChatModelResponseContext context) {
        this.listeners.forEach(listener -> {
            try {
                listener.onResponse(context);
            } catch (Exception e) {
                log.warn("Exception while calling model listener (onResponse)", e);
            }
        });
    }

    /** Calls {@code onError} on every listener, logging (not propagating) listener failures. */
    private void notifyError(Throwable error, ChatModelRequest modelRequest, Map<Object, Object> attributes) {
        ChatModelErrorContext context = new ChatModelErrorContext(error, modelRequest, (ChatModelResponse) null, attributes);
        this.listeners.forEach(listener -> {
            try {
                listener.onError(context);
            } catch (Exception e) {
                log.warn("Exception while calling model listener (onError)", e);
            }
        });
    }

    /** Estimates the token count of {@code messages} using the configured {@link GoogleAiGeminiTokenizer}. */
    public int estimateTokenCount(List<ChatMessage> messages) {
        return this.geminiTokenizer.estimateTokenCountInMessages(messages);
    }

    /**
     * Reports {@link Capability#RESPONSE_FORMAT_JSON_SCHEMA} only when the model-level response format
     * is JSON; otherwise returns an empty (mutable) set.
     */
    public Set<Capability> supportedCapabilities() {
        Set<Capability> capabilities = new HashSet<>();
        if (this.responseFormat != null && ResponseFormatType.JSON.equals(this.responseFormat.type())) {
            capabilities.add(Capability.RESPONSE_FORMAT_JSON_SCHEMA);
        }
        return capabilities;
    }

    /** Returns a new, empty builder. */
    public static GoogleAiGeminiChatModelBuilder builder() {
        return new GoogleAiGeminiChatModelBuilder();
    }
}
