package com.vaadin.uitest.ai.services.vectordb;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.vaadin.uitest.model.Framework;
import com.vaadin.uitest.model.chat.ChatCompletionMessage;
import com.vaadin.uitest.model.chat.Link;
import com.vaadin.uitest.model.vectordb.ChatIndexSource;
import com.vaadin.uitest.model.vectordb.Document;

import reactor.core.publisher.Flux;

/**
 * Builds chat prompts that are enriched with documents found in the vector
 * database and streams the resulting OpenAI chat completion.
 * <p>
 * For every request the question is embedded, similar documents are looked
 * up per configured {@link Namespace}, the documents are turned into
 * additional chat messages (within a token budget) and the combined message
 * list is sent to {@link OpenAIService}.
 */
@Service
public class DocsAssistantService {
    /** Maximum number of links attached to a Javadoc context message. */
    public static final int LINKS_LIMIT_JAVADOCS = 3;
    /** Maximum number of links attached to any other context message. */
    public static final int LINKS_LIMIT_DOCUMENTATION = 2;

    private static final Logger LOGGER = LoggerFactory
            .getLogger(DocsAssistantService.class);

    /** Total token budget (prompt + response) used when capping messages. */
    private static final int MAX_TOKENS = 16384;
    /** Tokens reserved for the response when capping messages. */
    private static final int MAX_RESPONSE_TOKENS = 1250;
    /** Token budget shared by all context documents of one request. */
    private static final int MAX_CONTEXT_TOKENS_DEFAULT = 6144;
    /** Upper bound for the sum of per-namespace document counts. */
    private static final int MAX_DOCS_RESULTS = 20;
    /** Documents requested per namespace when prompting for the parser. */
    private static final int PARSER_MAX_DOCS_PER_NAMESPACE = 12;

    /** System property that switches on verbose per-document logging. */
    private static final String DEBUG_PROPERTY = "ai.debug";
    /** Separator appended after each document in the context string. */
    private static final String CONTEXT_DOC_SEPARATOR = "\n---\n";

    private static final String DOCUMENTATION_LINK_PREFIX_FLOW = "https://vaadin.com/docs/latest";
    private static final String DOCUMENTATION_LINK_PREFIX_HILLA_REACT = "https://hilla.dev/docs";
    private static final String DOCUMENTATION_LINK_PREFIX_HILLA_LIT = "https://hilla.dev/docs";
    private static final String JAVADOCS_LINK_PREFIX = "https://vaadin.com/api/platform/24.3.8/";

    /** Path fragment that precedes the framework segment in stored links. */
    private static final String PUBLISHED_DOCS_MARKER = "dspublisher/out/public";
    /**
     * Offset from the start of {@link #PUBLISHED_DOCS_MARKER} to the framework
     * segment: the 22 marker characters plus one more character (presumably
     * the '/' separator).
     */
    private static final int PUBLISHED_DOCS_PATH_OFFSET = 23;
    /**
     * Number of trailing characters stripped from a stored link (presumably a
     * trailing "/index.html" - TODO confirm against the indexed data).
     */
    private static final int LINK_SUFFIX_LENGTH = 11;

    private final OpenAIService openAIService;
    private final PineconeService pineconeService;
    private final Encoding tokenizer;

    /**
     * Index source handed to every completion request; it can be reset per
     * session through {@link #resetIndexSource(String)}.
     */
    private final ChatIndexSource indexSource = new ChatIndexSource();

    public DocsAssistantService(OpenAIService openAIService,
            PineconeService pineconeService) {
        this.openAIService = openAIService;
        this.pineconeService = pineconeService;
        EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
        tokenizer = registry.getEncoding(EncodingType.CL100K_BASE);
    }

    /**
     * Finds similar documents in the documentation and calls the OpenAI chat
     * completion API.
     *
     * @param history
     *            the conversation so far; must contain at least one message
     * @param framework
     *            the framework the question is about
     * @param question
     *            the text to embed for the similarity search, or
     *            {@code null} to use the content of the last history message
     * @param sessionId
     *            the session id forwarded to the completion call
     * @param forParser
     *            {@code true} to use the parser namespace configuration,
     *            {@code false} to use the generator configuration
     * @return The completion as a stream of chunks, or an error signal when
     *         the history is empty
     */
    public Flux<ChatCompletionMessage> getCompletionStream(
            List<ChatCompletionMessage> history, Framework framework,
            String question, String sessionId, boolean forParser) {
        int historySize = history == null ? 0 : history.size();
        LOGGER.info("History size {}, {}", historySize, framework);
        if (historySize == 0) {
            return Flux
                    .error(new IllegalArgumentException("History is empty"));
        }

        Map<Namespace, Integer> maxDocsByNamespace = getNamespacesWithMaxDocumentCount(
                framework, forParser);
        assert maxDocsByNamespace.values().stream()
                .mapToInt(Integer::intValue).sum() <= MAX_DOCS_RESULTS;

        // Fall back to the latest message when no explicit question is given.
        String query = question != null ? question
                : history.get(historySize - 1).getContent();

        // NOTE(review): the embedding is requested even when the namespace
        // map is empty, although no lookup uses it in that case.
        return openAIService.createEmbedding(query)
                .flatMapMany(embedding -> Flux
                        .fromIterable(maxDocsByNamespace.entrySet())
                        .flatMap(entry -> pineconeService
                                .findSimilarDocuments(embedding,
                                        entry.getKey(), entry.getValue())))
                .collectList()
                .map(documentsQueryResult -> getPromptWithContext(history,
                        framework, documentsQueryResult, forParser))
                .flatMapMany(prompts -> openAIService.generateCompletionStream(
                        prompts, indexSource, sessionId));
    }

    /**
     * Returns whether verbose logging is enabled through the
     * {@code ai.debug} system property. The property is read on every call.
     */
    private static boolean isDebugEnabled() {
        return Boolean.getBoolean(DEBUG_PROPERTY);
    }

    /**
     * Selects the parser or generator namespace configuration.
     *
     * @return the maximum number of documents to request per namespace
     */
    private Map<Namespace, Integer> getNamespacesWithMaxDocumentCount(
            Framework framework, boolean forParser) {
        return forParser ? getParserNamespacesWithMaxDocumentCount(framework)
                : getGeneratorNamespacesWithMaxDocumentCount(framework);
    }

    /**
     * Computes the share of the context token budget for one namespace. The
     * budget ({@code MAX_CONTEXT_TOKENS_DEFAULT}, or a quarter of it for the
     * parser) is split proportionally to the configured document counts.
     *
     * @return the token budget for the namespace, or 0 when the namespace is
     *         not configured or no documents are configured at all
     */
    private int getMaxContextTokens(Framework framework, Namespace namespace,
            boolean forParser) {
        Map<Namespace, Integer> maxDocsByNamespace = getNamespacesWithMaxDocumentCount(
                framework, forParser);
        int totalDocumentCount = maxDocsByNamespace.values().stream()
                .mapToInt(Integer::intValue).sum();
        Integer namespaceDocumentCount = maxDocsByNamespace.get(namespace);
        if (namespaceDocumentCount == null || totalDocumentCount <= 0) {
            // Avoids a null unboxing / division by zero for a namespace that
            // is missing from the configuration.
            LOGGER.warn("No document count configured for namespace {}",
                    namespace);
            return 0;
        }
        int totalMaxContextTokens = forParser ? MAX_CONTEXT_TOKENS_DEFAULT / 4
                : MAX_CONTEXT_TOKENS_DEFAULT;
        return namespaceDocumentCount * totalMaxContextTokens
                / totalDocumentCount;
    }

    /**
     * Returns the namespaces queried when prompting for the parser: the
     * namespace matching the framework, if it is FLOW, LIT or REACT.
     */
    private Map<Namespace, Integer> getParserNamespacesWithMaxDocumentCount(
            Framework framework) {
        Map<Namespace, Integer> maxDocsByNamespace = new HashMap<>();
        switch (framework) {
        case FLOW -> maxDocsByNamespace.put(Namespace.FLOW,
                PARSER_MAX_DOCS_PER_NAMESPACE);
        case LIT -> maxDocsByNamespace.put(Namespace.LIT,
                PARSER_MAX_DOCS_PER_NAMESPACE);
        case REACT -> maxDocsByNamespace.put(Namespace.REACT,
                PARSER_MAX_DOCS_PER_NAMESPACE);
        }
        return maxDocsByNamespace;
    }

    /**
     * Returns the namespaces queried when prompting for the generator. All
     * entries are currently commented out, so the returned map is empty and
     * no documents are looked up.
     */
    private Map<Namespace, Integer> getGeneratorNamespacesWithMaxDocumentCount(
            Framework framework) {
        Map<Namespace, Integer> namespacesWithMaxDocumentCount = new HashMap<>();
        // switch (framework) {
        // case FLOW -> {
        // namespacesWithMaxDocumentCount.put(Namespace.FLOW, 2);
        //// namespacesWithMaxDocumentCount.put(Namespace.FLOW_API, 2);
        // }
        // case LIT -> namespacesWithMaxDocumentCount.put(Namespace.LIT, 4);
        // case REACT -> namespacesWithMaxDocumentCount.put(Namespace.REACT, 4);
        // }
        // namespacesWithMaxDocumentCount.put(Namespace.RECOMMENDED_ACTIONS,
        // 12);
        // namespacesWithMaxDocumentCount.put(Namespace.PLAYWRIGHT_API, 4);
        return namespacesWithMaxDocumentCount;
    }

    /**
     * Appends one context message per non-empty query result to a copy of
     * the history and caps the result to the token budget.
     *
     * @return the history itself when no result contains documents,
     *         otherwise a new capped list
     */
    private List<ChatCompletionMessage> getPromptWithContext(
            List<ChatCompletionMessage> history, Framework framework,
            List<DocumentsQueryResult> documentsQueryResults,
            boolean forParser) {
        List<ChatCompletionMessage> contextMessages = documentsQueryResults
                .stream().filter(result -> !result.getDocs().isEmpty())
                .map(result -> getChatCompletionMessage(framework, result,
                        forParser))
                .toList();
        if (contextMessages.isEmpty()) {
            return history;
        }
        List<ChatCompletionMessage> messages = new ArrayList<>(history);
        messages.addAll(contextMessages);
        return capMessages(messages);
    }

    /**
     * Creates a USER message holding the documents of one query result,
     * together with links to the documents and the highest document score.
     */
    private ChatCompletionMessage getChatCompletionMessage(Framework framework,
            DocumentsQueryResult documentsQueryResult, boolean forParser) {
        Namespace namespace = documentsQueryResult.getNamespace();
        DocumentTemplate documentTemplate = namespace.getDocumentTemplate();
        List<Document> docs = documentsQueryResult.getDocs();

        String contextString = getContextString(docs,
                getMaxContextTokens(framework, namespace, forParser));
        String content = String.format(
                getMessageTemplate(framework, documentTemplate),
                contextString);

        ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage(
                ChatCompletionMessage.Role.USER, content);
        chatCompletionMessage.addLinks(createLinks(docs, framework,
                documentTemplate, getLinksLimit(namespace)));
        chatCompletionMessage.setScore(calculateScore(docs));
        return chatCompletionMessage;
    }

    /**
     * Returns the introductory sentence for a context message, or an empty
     * string for a template without a dedicated sentence.
     */
    private String getMessageDescription(Framework framework,
            DocumentTemplate documentTemplate) {
        return switch (documentTemplate) {
        case JAVA_DOC -> "Here is the summary of related Vaadin Javadocs API:";
        case REGULAR_DOC -> "Here are a few selected pieces of latest VAADIN "
                + framework + " documentation:";
        case RECOMMENDED_ACTIONS ->
            "When implementing a scenario step, please consider the following Playwright Java code snippets as references:";
        case FULL_TEST_DATA -> String.format(
                "Here are the %s samples and recommended Playwright Java actions for reference:",
                framework.getLabel());
        default -> "";
        };
    }

    /**
     * Returns a format string consisting of the message description followed
     * by a single {@code %s} placeholder (for the context documents) that is
     * enclosed in {@code ===} lines.
     */
    private String getMessageTemplate(Framework framework,
            DocumentTemplate documentTemplate) {
        // The result is used as a format string, so a literal '%' in the
        // description has to be escaped.
        String escapedDescription = getMessageDescription(framework,
                documentTemplate).replace("%", "%%");
        return """
                %s
                ===
                %s
                ===
                """.formatted(escapedDescription, "%s");
    }

    /** Returns how many links may be attached for the given namespace. */
    private int getLinksLimit(Namespace namespace) {
        if (DocumentTemplate.JAVA_DOC
                .equals(namespace.getDocumentTemplate())) {
            return LINKS_LIMIT_JAVADOCS;
        }
        return LINKS_LIMIT_DOCUMENTATION;
    }

    /** Returns the highest score among the documents, or 0 if there are none. */
    private float calculateScore(List<Document> documents) {
        return documents.stream().map(Document::getScore).max(Float::compare)
                .orElse(0f);
    }

    /**
     * Creates up to {@code limit} links from the distinct non-null links of
     * the documents, in document order.
     */
    private Collection<Link> createLinks(List<Document> documents,
            Framework framework, DocumentTemplate documentTemplate, int limit) {
        return documents.stream().map(Document::getLink)
                .filter(Objects::nonNull).distinct().limit(limit)
                .map(link -> new Link(documentTemplate.getLabel(),
                        createRealLink(link, framework, documentTemplate)))
                .toList();
    }

    /** Returns the documentation site prefix for the framework. */
    private String getDocumentationLinkPrefix(Framework framework) {
        if (Framework.REACT.equals(framework)) {
            return DOCUMENTATION_LINK_PREFIX_HILLA_REACT;
        }
        if (Framework.LIT.equals(framework)) {
            return DOCUMENTATION_LINK_PREFIX_HILLA_LIT;
        }
        return DOCUMENTATION_LINK_PREFIX_FLOW;
    }

    /**
     * Converts a stored document link into a public URL.
     * <p>
     * Javadoc links are appended to the Javadoc prefix. For other links the
     * part after {@code dspublisher/out/public}, one more character and the
     * framework value is kept, minus the last 11 characters, and appended to
     * the documentation prefix of the framework. A blank link, or a link that
     * does not have that shape, yields the plain documentation prefix.
     */
    private String createRealLink(String link, Framework framework,
            DocumentTemplate documentTemplate) {
        if (DocumentTemplate.JAVA_DOC.equals(documentTemplate)) {
            String javadocLink = JAVADOCS_LINK_PREFIX + link;
            if (isDebugEnabled()) {
                LOGGER.info("Created Javadoc link {}", javadocLink);
            }
            return javadocLink;
        }

        String documentationLinkPrefix = getDocumentationLinkPrefix(framework);
        if (link.isBlank()) {
            if (isDebugEnabled()) {
                LOGGER.info("Created default link {}", documentationLinkPrefix);
            }
            return documentationLinkPrefix;
        }

        int markerIndex = link.indexOf(PUBLISHED_DOCS_MARKER);
        int articleStart = markerIndex + PUBLISHED_DOCS_PATH_OFFSET
                + framework.getValue().length();
        int articleEnd = link.length() - LINK_SUFFIX_LENGTH;
        if (markerIndex < 0 || articleStart > articleEnd) {
            // A link without the expected marker or that is too short would
            // otherwise make substring throw or produce a garbage path.
            LOGGER.warn(
                    "Unexpected documentation link format '{}', using default link {}",
                    link, documentationLinkPrefix);
            return documentationLinkPrefix;
        }

        String realLink = documentationLinkPrefix
                + link.substring(articleStart, articleEnd);
        if (isDebugEnabled()) {
            LOGGER.info("Created link {}", realLink);
        }
        return realLink;
    }

    /**
     * Returns a string of up to maxToken tokens from the contextDocs. The
     * documents are appended in order, each followed by a separator, until
     * the next document would exceed the budget.
     *
     * @param contextDocs
     *            The context documents
     * @param maxTokens
     *            max total token count
     */
    private String getContextString(List<Document> contextDocs, int maxTokens) {
        var usedTokens = 0;
        var context = new StringBuilder();
        for (var doc : contextDocs) {
            String entry = doc.getContent() + CONTEXT_DOC_SEPARATOR;
            int entryTokens = tokenizer.encode(entry).size();
            if (usedTokens + entryTokens > maxTokens) {
                break;
            }
            // Count only documents that are actually appended so that the
            // summary log below reports the real context size.
            usedTokens += entryTokens;
            context.append(entry);
            if (isDebugEnabled()) {
                LOGGER.info("Appended a doc with score {}, doc: {} ",
                        doc.getScore(), doc.getContent());
            }
        }
        LOGGER.info(
                "ContextDocs article count {}, context formed with {}/{} tokens",
                contextDocs.size(), usedTokens, maxTokens);
        return context.toString();
    }

    /**
     * Removes messages until the total number of tokens + MAX_RESPONSE_TOKENS
     * stays under MAX_TOKENS. The passed list is modified in place.
     *
     * @param systemMessages
     *            The system messages including context and prompt; must be
     *            modifiable
     * @return The capped messages that can be sent to the OpenAI API.
     */
    private List<ChatCompletionMessage> capMessages(
            List<ChatCompletionMessage> systemMessages) {
        var availableTokens = MAX_TOKENS - MAX_RESPONSE_TOKENS;

        var tokens = getTokenCount(systemMessages);
        LOGGER.info("Current total context size {} tokens", tokens);
        while (tokens > availableTokens && !systemMessages.isEmpty()) {
            // Remove exactly one message from the end (documentation) per
            // iteration; remove(int) must not be followed by remove(Object),
            // which would also drop an equal message earlier in the list.
            systemMessages.remove(systemMessages.size() - 1);
            tokens = getTokenCount(systemMessages);
            LOGGER.info(
                    "Messages capped to {} entries to reduce tokens, current context size {} tokens",
                    systemMessages.size(), tokens);
        }

        return new ArrayList<>(systemMessages);
    }

    /**
     * Returns the number of tokens in the messages. See
     * https://github.com/openai/openai-cookbook/blob/834181d5739740eb8380096dac7056c925578d9a/examples/How_to_count_tokens_with_tiktoken.ipynb
     *
     * @param messages
     *            The messages to count the tokens of
     * @return The number of tokens in the messages
     */
    private int getTokenCount(List<ChatCompletionMessage> messages) {
        // every reply is primed with <|start|>assistant<|message|>
        var tokenCount = 3;
        for (var message : messages) {
            tokenCount += getMessageTokenCount(message);
        }
        return tokenCount;
    }

    /**
     * Returns the number of tokens in the message.
     *
     * @param message
     *            The message to count the tokens of
     * @return The number of tokens in the message
     */
    private int getMessageTokenCount(ChatCompletionMessage message) {
        // every message follows <|start|>{role/name}\n{content}<|end|>\n
        var tokens = 4;
        tokens += tokenizer.encode(message.getRole().toString()).size();
        tokens += tokenizer.encode(message.getContent()).size();
        return tokens;
    }

    /**
     * Resets the running index for the given sessionId
     *
     * @param sessionId
     *            the session whose index is reset
     */
    public void resetIndexSource(String sessionId) {
        indexSource.reset(sessionId);
    }
}
