package com.vaadin.uitest.ai.services.vectordb;

import static com.vaadin.uitest.ai.services.AiServiceConstants.MODEL;
import static com.vaadin.uitest.ai.services.AiServiceConstants.OPENAI_API_COMPLETIONS;
import static com.vaadin.uitest.ai.services.AiServiceConstants.OPENAI_API_URL;
import static com.vaadin.uitest.ai.services.AiServiceConstants.TEMPERATURE;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.aot.hint.annotation.RegisterReflectionForBinding;
import org.springframework.http.MediaType;
import org.springframework.http.client.reactive.ReactorClientHttpConnector;
import org.springframework.web.reactive.function.client.WebClient;

import com.google.common.base.Joiner;
import com.vaadin.uitest.ai.utils.KeysUtils;
import com.vaadin.uitest.model.chat.ChatCompletionChunkResponse;
import com.vaadin.uitest.model.chat.ChatCompletionMessage;
import com.vaadin.uitest.model.chat.ChatCompletionMessageIn;
import com.vaadin.uitest.model.chat.Link;
import com.vaadin.uitest.model.vectordb.ChatIndexSource;
import com.vaadin.uitest.model.vectordb.EmbeddingResponse;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.netty.http.client.HttpClient;

/**
 * Thin reactive client for the OpenAI HTTP API built on Spring's
 * {@link WebClient}.
 * <p>
 * It exposes two operations: creating an embedding for a piece of text and
 * streaming a chat completion as a {@link Flux} of
 * {@link ChatCompletionMessage} chunks.
 */
public class OpenAIService {

    /**
     * Maximum number of streamed content fragments that are joined into one
     * emitted message. Can be adjusted to make server-to-client updates in
     * bigger batches (without buffering every received fragment is sent
     * separately).
     */
    public static final int FLUX_BUFFER_MAX_SIZE = 10;
    private static final Logger LOGGER = LoggerFactory
            .getLogger(OpenAIService.class);
    // Read once when the class is initialized.
    private static final String OPENAI_API_KEY = KeysUtils.getOpenAiKey();

    private final WebClient webClient;

    /**
     * Creates the service with a {@link WebClient} that targets
     * {@code OPENAI_API_URL}, sends JSON, authenticates with the configured
     * API key as a bearer token and uses a 45 second response timeout.
     */
    public OpenAIService() {
        var client = HttpClient.create()
                .responseTimeout(Duration.ofSeconds(45));
        this.webClient = WebClient.builder()
                .clientConnector(new ReactorClientHttpConnector(client))
                .baseUrl(OPENAI_API_URL)
                .defaultHeader("Content-Type", MediaType.APPLICATION_JSON_VALUE)
                .defaultHeader("Authorization", "Bearer " + OPENAI_API_KEY)
                .build();

    }

    /**
     * Requests an embedding for the given text from the
     * {@code /v1/embeddings} endpoint.
     *
     * @param text
     *            the text to embed; must not be {@code null} because the
     *            request body is built with {@link Map#of}
     * @return a {@link Mono} emitting the embedding extracted from the
     *         response via {@link EmbeddingResponse#getEmbedding()}
     */
    @RegisterReflectionForBinding(EmbeddingResponse.class)
    public Mono<List<Double>> createEmbedding(String text) {
        if (Boolean.getBoolean("ai.debug")) {
            LOGGER.info("Creating embedding for text: {}", text);
        }

        // NOTE(review): the same MODEL constant is sent to both the
        // embeddings and the chat completions endpoints -- confirm that it
        // names a model that is valid for embeddings.
        Map<String, Object> body = Map.of("model", MODEL, "input", text);

        return webClient.post().uri("/v1/embeddings").bodyValue(body).retrieve()
                .bodyToMono(EmbeddingResponse.class)
                .map(EmbeddingResponse::getEmbedding);
    }

    /**
     * Streams a chat completion for the given conversation.
     * <p>
     * Received content fragments are grouped into batches of at most
     * {@link #FLUX_BUFFER_MAX_SIZE} and each batch is emitted as one
     * {@link ChatCompletionMessage}. The links of the first input message are
     * attached to the first emitted message only; every later message carries
     * an empty list. The index looked up from {@code indexSource} and the
     * score of the first input message are attached to every emitted message.
     * <p>
     * The list returned by {@code messages.get(0).getLinks()} is copied and
     * never modified by this method.
     *
     * @param messages
     *            the conversation to send; must contain at least one message
     * @param indexSource
     *            source of the index for the given session
     * @param sessionId
     *            the session whose index is looked up
     * @return a stream of partial completion messages
     * @throws IndexOutOfBoundsException
     *             if {@code messages} is empty
     */
    public Flux<ChatCompletionMessage> generateCompletionStream(
            List<ChatCompletionMessage> messages, ChatIndexSource indexSource,
            String sessionId) {
        if (Boolean.getBoolean("ai.debug")) {
            LOGGER.info("Generating completion for messages: {}", messages);
        }
        ChatCompletionMessage firstMessage = messages.get(0);

        // Work on a private, mutable copy so that handing the links out
        // exactly once neither mutates the caller's message nor fails when
        // the caller's list is null or unmodifiable.
        List<Link> firstMessageLinks = firstMessage.getLinks();
        List<Link> pendingLinks = firstMessageLinks == null
                ? new ArrayList<>()
                : new ArrayList<>(firstMessageLinks);
        int index = indexSource.get(sessionId);
        float score = firstMessage.getScore();

        List<ChatCompletionMessageIn> requestMessages = messages.stream()
                .map(msg -> new ChatCompletionMessageIn(
                        ChatCompletionMessageIn.Role
                                .valueOf(msg.getRole().name()),
                        msg.getContent()))
                .collect(Collectors.toList());

        return webClient.post().uri(OPENAI_API_COMPLETIONS)
                .bodyValue(Map.of("model", MODEL, "messages", requestMessages,
                        "stream", true, "temperature", TEMPERATURE))
                .retrieve().bodyToFlux(ChatCompletionChunkResponse.class)
                .onErrorResume(error -> {
                    // The stream terminates with a `[DONE]` message, which
                    // causes a serialization error. Ignore this error and
                    // complete the stream instead. The message may be null
                    // for some throwables, so guard before inspecting it.
                    String message = error.getMessage();
                    if (message != null
                            && message.contains("JsonToken.START_ARRAY")) {
                        return Flux.empty();
                    }

                    // Any other error is propagated unchanged
                    return Flux.error(error);
                }).filter(response -> {
                    // Skip chunks that carry no usable content instead of
                    // failing on a missing choice or delta
                    var choices = response.getChoices();
                    if (choices == null || choices.isEmpty()) {
                        return false;
                    }
                    var delta = choices.get(0).getDelta();
                    if (delta == null) {
                        return false;
                    }
                    var content = delta.getContent();
                    return content != null && !content.equals("\n\n");
                })
                .map(response -> response.getChoices().get(0).getDelta()
                        .getContent())
                .buffer(FLUX_BUFFER_MAX_SIZE)
                .map(list -> Joiner.on("").join(list)).doOnNext(chunk -> {
                    if (Boolean.getBoolean("ai.debug")) {
                        LOGGER.debug("Received chunk: {}", chunk);
                    }
                }).map(content -> new ChatCompletionMessage(null, content,
                        getLinksOnce(pendingLinks), index, score));
    }

    /**
     * Hands out the contents of the given list exactly once: the first call
     * with a non-empty list returns a copy of its elements and empties the
     * list, every later call returns the (now empty) list itself.
     *
     * @param links
     *            a mutable list owned by the caller of this helper
     * @return the links on the first call, an empty list afterwards
     */
    private List<Link> getLinksOnce(List<Link> links) {
        if (links.isEmpty()) {
            return links;
        }
        List<Link> linksToReturn = new ArrayList<>(links);
        links.clear();
        return linksToReturn;
    }

}
