/*
 * Decompiled with CFR 0.152.
 */
package com.vaadin.flow.component.ai.provider;

import com.vaadin.flow.component.ai.provider.AttachmentContentType;
import com.vaadin.flow.component.ai.provider.LLMProvider;
import com.vaadin.flow.component.ai.provider.LLMProviderHelpers;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.agent.tool.ToolSpecifications;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.AudioContent;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.PdfFileContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.message.VideoContent;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.service.tool.DefaultToolExecutor;
import dev.langchain4j.service.tool.ToolExecutor;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;

/**
 * {@link LLMProvider} implementation backed by a LangChain4j chat model.
 * <p>
 * The provider wraps either a {@link StreamingChatModel} or a blocking
 * {@link ChatModel}:
 * <ul>
 * <li>With a streaming model, partial responses are forwarded to the returned
 * {@link Flux} as they arrive.</li>
 * <li>With a blocking model, the complete response text is emitted as one
 * item.</li>
 * </ul>
 * <p>
 * Conversation history is kept in a {@link MessageWindowChatMemory} bounded to
 * {@code MAX_MESSAGES} messages. A single memory instance is created per
 * provider and is used by every request made through that provider.
 * <p>
 * Methods annotated with {@link Tool} on the objects returned by
 * {@code request.tools()} are exposed to the model. When the model requests
 * tool executions, each tool is invoked and its result is appended to the
 * memory. The model is then queried again, and this repeats until it answers
 * without requesting tools.
 */
public class LangChain4JLLMProvider implements LLMProvider {

    /** Maximum number of messages retained in the conversation memory. */
    private static final int MAX_MESSAGES = 30;

    private final transient StreamingChatModel streamingChatModel;
    private final transient ChatModel nonStreamingChatModel;
    private final transient ChatMemory chatMemory;

    /**
     * Creates a provider that streams the model's response incrementally.
     *
     * @param chatModel
     *            the streaming model to use, not {@code null}
     * @throws NullPointerException
     *             if {@code chatModel} is {@code null}
     */
    public LangChain4JLLMProvider(StreamingChatModel chatModel) {
        this(null, Objects.requireNonNull(chatModel, "StreamingChatModel must not be null"));
    }

    /**
     * Creates a provider that emits each complete model response as a single
     * item.
     *
     * @param chatModel
     *            the blocking model to use, not {@code null}
     * @throws NullPointerException
     *             if {@code chatModel} is {@code null}
     */
    public LangChain4JLLMProvider(ChatModel chatModel) {
        this(Objects.requireNonNull(chatModel, "ChatModel must not be null"), null);
    }

    /**
     * Shared constructor. Exactly one of the two models is expected to be
     * non-null; a non-null streaming model selects streaming mode.
     */
    private LangChain4JLLMProvider(ChatModel chatModel, StreamingChatModel streamingChatModel) {
        this.streamingChatModel = streamingChatModel;
        this.nonStreamingChatModel = chatModel;
        this.chatMemory = MessageWindowChatMemory.withMaxMessages(MAX_MESSAGES);
    }

    /**
     * Sends the request's user message (plus attachments) to the model and
     * returns the response text as a {@link Flux}.
     * <p>
     * The user message is added to the conversation memory before the model
     * is called. Any failure is delivered as an error signal on the returned
     * {@link Flux}.
     *
     * @param request
     *            the request, not {@code null}, with a non-null user message
     * @return a flux of response text fragments
     * @throws NullPointerException
     *             if {@code request} or its user message is {@code null}
     */
    @Override
    public Flux<String> stream(LLMProvider.LLMRequest request) {
        Objects.requireNonNull(request, "Request must not be null");
        Objects.requireNonNull(request.userMessage(), "User message must not be null");
        return Flux.create(sink -> {
            try {
                UserMessage userMessage = buildUserMessage(request);
                chatMemory.add(userMessage);
                ToolContext toolContext = new ToolContext(prepareToolExecutors(request),
                        prepareToolSpecifications(request));
                executeChat(new ChatExecutionContext(request, sink, chatMemory, toolContext));
            } catch (Exception e) {
                sink.error(e);
            }
        }, FluxSink.OverflowStrategy.BUFFER);
    }

    /**
     * Builds a map from tool name to executor for every {@link Tool}-annotated
     * method declared directly on the request's tool objects.
     * <p>
     * If two methods resolve to the same tool name, the one processed last
     * wins.
     *
     * @return the executors keyed by tool name, or an empty map when the
     *         request has no tools
     */
    private Map<String, ToolExecutor> prepareToolExecutors(LLMProvider.LLMRequest request) {
        Object[] tools = request.tools();
        if (tools == null) {
            return Collections.emptyMap();
        }
        Map<String, ToolExecutor> toolExecutors = new HashMap<>();
        for (Object toolObject : tools) {
            for (Method method : toolObject.getClass().getDeclaredMethods()) {
                if (method.isAnnotationPresent(Tool.class)) {
                    // Key by the same name the model sees in the specification.
                    String toolName = ToolSpecifications.toolSpecificationFrom(method).name();
                    toolExecutors.put(toolName, getToolExecutor(toolObject, method));
                }
            }
        }
        return toolExecutors;
    }

    /** Creates an executor that invokes {@code method} on {@code toolObject}. */
    private ToolExecutor getToolExecutor(Object toolObject, Method method) {
        DefaultToolExecutor baseExecutor = new DefaultToolExecutor(toolObject, method);
        return (toolRequest, memoryId) -> baseExecutor.execute(toolRequest, memoryId);
    }

    /**
     * Collects the tool specifications advertised to the model.
     *
     * @return the specifications, or an empty list when the request has no
     *         tools
     */
    private List<ToolSpecification> prepareToolSpecifications(LLMProvider.LLMRequest request) {
        if (request.tools() == null) {
            return Collections.emptyList();
        }
        return Arrays.stream(request.tools()).map(ToolSpecifications::toolSpecificationsFrom)
                .flatMap(Collection::stream).toList();
    }

    /**
     * Runs one round-trip with the model, using the current memory contents,
     * through whichever model this provider wraps.
     */
    private void executeChat(ChatExecutionContext context) {
        List<ChatMessage> messages = buildMessages(context.request(), context.chatMemory());
        if (isStreaming()) {
            executeStreamingChat(messages, context);
        } else {
            executeNonStreamingChat(messages, context);
        }
    }

    /**
     * Builds a chat request for {@code messages}. Tool specifications are
     * attached only when at least one tool is available.
     */
    private static ChatRequest buildChatRequest(List<ChatMessage> messages, ChatExecutionContext context) {
        ChatRequest.Builder builder = ChatRequest.builder().messages(messages);
        List<ToolSpecification> specifications = context.toolContext().specifications();
        if (!specifications.isEmpty()) {
            builder = builder.toolSpecifications(specifications);
        }
        return builder.build();
    }

    /**
     * Calls the streaming model, forwarding partial responses directly to the
     * sink and handling the complete response afterwards.
     */
    private void executeStreamingChat(List<ChatMessage> messages, ChatExecutionContext context) {
        streamingChatModel.chat(buildChatRequest(messages, context), new StreamingChatResponseHandler() {

            @Override
            public void onPartialResponse(String partialResponse) {
                context.sink().next(partialResponse);
            }

            @Override
            public void onCompleteResponse(ChatResponse response) {
                // Response handling may run tools and start another model call.
                // Surface any failure there to the subscriber instead of
                // letting it escape into the model's callback.
                try {
                    LangChain4JLLMProvider.this.handleResponse(context, response);
                } catch (Exception e) {
                    context.sink().error(e);
                }
            }

            @Override
            public void onError(Throwable error) {
                context.sink().error(error);
            }
        });
    }

    /**
     * Executes every tool call requested by {@code aiMessage} and appends each
     * result message to the conversation memory.
     */
    private void executeToolRequests(AiMessage aiMessage, ChatExecutionContext context) {
        List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
        for (ToolExecutionRequest toolExecRequest : toolExecutionRequests) {
            ToolExecutor toolExecutor = context.toolContext().executors().get(toolExecRequest.name());
            ToolExecutionResultMessage result = executeToolRequest(toolExecutor, toolExecRequest);
            context.chatMemory().add(result);
        }
    }

    /**
     * Calls the blocking model and handles its response. Any exception is
     * delivered to the sink as an error.
     */
    private void executeNonStreamingChat(List<ChatMessage> messages, ChatExecutionContext context) {
        try {
            ChatResponse response = nonStreamingChatModel.chat(buildChatRequest(messages, context));
            handleResponse(context, response);
        } catch (Exception e) {
            context.sink().error(e);
        }
    }

    /**
     * Processes a complete model response and records the AI message in
     * memory. In non-streaming mode, it also emits the response text.
     * <p>
     * If the model requested tools, they are executed and the model is queried
     * again. Otherwise the sink is completed. A missing response or AI message
     * also completes the sink.
     */
    private void handleResponse(ChatExecutionContext context, ChatResponse response) {
        AiMessage aiMessage = response == null ? null : response.aiMessage();
        if (aiMessage == null) {
            context.sink().complete();
            return;
        }
        context.chatMemory().add(aiMessage);
        // In streaming mode the text was already forwarded via partial responses.
        if (!isStreaming()) {
            String text = aiMessage.text();
            if (text != null && !text.isEmpty()) {
                context.sink().next(text);
            }
        }
        if (aiMessage.hasToolExecutionRequests()) {
            executeToolRequests(aiMessage, context);
            executeChat(context);
        } else {
            context.sink().complete();
        }
    }

    /**
     * Runs a single tool call. Unknown tools and tool failures are reported
     * back to the model as textual results rather than thrown.
     */
    private static ToolExecutionResultMessage executeToolRequest(ToolExecutor toolExecutor,
            ToolExecutionRequest toolExecRequest) {
        String result;
        if (toolExecutor == null) {
            result = "Tool not found: " + toolExecRequest.name();
        } else {
            try {
                result = toolExecutor.execute(toolExecRequest, null);
            } catch (Exception e) {
                // Fall back to the exception type so the model never sees "null".
                String reason = e.getMessage() != null ? e.getMessage() : e.getClass().getName();
                result = "Error executing tool: " + reason;
            }
        }
        return ToolExecutionResultMessage.from(toolExecRequest, result);
    }

    /**
     * Assembles the messages sent to the model.
     * <p>
     * A system message from the trimmed system prompt comes first, when that
     * prompt is non-empty. The full memory contents follow.
     */
    private List<ChatMessage> buildMessages(LLMProvider.LLMRequest request, ChatMemory chatMemory) {
        List<ChatMessage> messages = new ArrayList<>();
        if (request.systemPrompt() != null) {
            String systemPrompt = request.systemPrompt().trim();
            if (!systemPrompt.isEmpty()) {
                messages.add(SystemMessage.from(systemPrompt));
            }
        }
        messages.addAll(chatMemory.messages());
        return messages;
    }

    /**
     * Builds the user message from the request text followed by one content
     * item per supported attachment. Unsupported attachments are skipped.
     */
    private UserMessage buildUserMessage(LLMProvider.LLMRequest request) {
        List<Content> contents = new ArrayList<>();
        contents.add(TextContent.from(request.userMessage()));
        List<LLMProvider.Attachment> attachments = request.attachments();
        if (attachments != null) {
            attachments.stream().map(LangChain4JLLMProvider::getAttachmentContent).flatMap(Optional::stream)
                    .forEach(contents::add);
        }
        return UserMessage.from(contents);
    }

    /** Returns whether this provider wraps a streaming model. */
    private boolean isStreaming() {
        return streamingChatModel != null;
    }

    /**
     * Validates an attachment and converts it to LangChain4j content based on
     * its MIME type.
     *
     * @return the content, or empty for unsupported attachment types
     */
    private static Optional<Content> getAttachmentContent(LLMProvider.Attachment attachment) {
        LLMProviderHelpers.validateAttachment(attachment);
        AttachmentContentType contentType = AttachmentContentType.fromMimeType(attachment.contentType());
        return switch (contentType) {
            case IMAGE -> Optional.of(getImageAttachmentContent(attachment));
            case TEXT -> Optional.of(getTextAttachmentContent(attachment));
            case PDF -> Optional.of(getPdfAttachmentContent(attachment));
            case AUDIO -> Optional.of(getAudioAttachmentContent(attachment));
            case VIDEO -> Optional.of(getVideoAttachmentContent(attachment));
            case UNSUPPORTED -> Optional.empty();
        };
    }

    /** Decodes a text attachment as UTF-8 and wraps it with its file name. */
    private static TextContent getTextAttachmentContent(LLMProvider.Attachment attachment) {
        String textContent = LLMProviderHelpers.decodeAsUtf8(attachment.data(), attachment.fileName(), false);
        return TextContent.from(LLMProviderHelpers.formatTextAttachment(attachment.fileName(), textContent));
    }

    /** Converts a PDF attachment to base64-encoded PDF content. */
    private static PdfFileContent getPdfAttachmentContent(LLMProvider.Attachment attachment) {
        String base64 = LLMProviderHelpers.getBase64Data(attachment.data());
        return PdfFileContent.from(base64, attachment.contentType());
    }

    /** Converts an image attachment to base64-encoded image content. */
    private static ImageContent getImageAttachmentContent(LLMProvider.Attachment attachment) {
        String base64 = LLMProviderHelpers.getBase64Data(attachment.data());
        return ImageContent.from(base64, attachment.contentType());
    }

    /** Converts an audio attachment to base64-encoded audio content. */
    private static AudioContent getAudioAttachmentContent(LLMProvider.Attachment attachment) {
        String base64 = LLMProviderHelpers.getBase64Data(attachment.data());
        return AudioContent.from(base64, attachment.contentType());
    }

    /** Converts a video attachment to base64-encoded video content. */
    private static VideoContent getVideoAttachmentContent(LLMProvider.Attachment attachment) {
        String base64 = LLMProviderHelpers.getBase64Data(attachment.data());
        return VideoContent.from(base64, attachment.contentType());
    }

    /**
     * Per-request state carried through the model/tool round-trips: the
     * original request, the output sink, the memory and the available tools.
     */
    private record ChatExecutionContext(LLMProvider.LLMRequest request, FluxSink<String> sink,
            ChatMemory chatMemory, ToolContext toolContext) {
    }

    /** Tool executors keyed by tool name, plus the specifications sent to the model. */
    private record ToolContext(Map<String, ToolExecutor> executors, List<ToolSpecification> specifications) {
    }
}

