package com.vaadin.uitest.ai.prompts;

import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import com.vaadin.uitest.ai.AiArguments;
import com.vaadin.uitest.codesnippetgeneration.CodeSnippetProvider;
import com.vaadin.uitest.codesnippetgeneration.PlaywrightNodeImports;
import com.vaadin.uitest.model.TestFramework;
import com.vaadin.uitest.model.codesnippetgeneration.CodeSnippet;
import com.vaadin.uitest.model.scenario.TestScenario;
import com.vaadin.uitest.model.scenario.TestScenarios;

/**
 * Utility methods that derive code snippets from the Gherkin scenarios in
 * {@link AiArguments} and render them, together with the imports they
 * require, as plain-text arguments.
 */
public abstract class AiSnippets {

    /**
     * Looks up a code snippet for every step of every scenario parsed from
     * the Gherkin text of the given arguments.
     *
     * @param aiArguments
     *            supplies the Gherkin text and the target test framework
     * @return an unmodifiable list containing the snippets returned by
     *         {@link CodeSnippetProvider}; steps for which the provider
     *         returns {@code null} are skipped
     */
    public static List<CodeSnippet> getCodeSnippets(AiArguments aiArguments) {
        TestScenarios testScenarios = TestScenarios
                .parse(aiArguments.getGherkin());
        return testScenarios.getScenarios().stream().map(TestScenario::getSteps)
                .flatMap(List::stream)
                .map(step -> CodeSnippetProvider.getCodeSnippet(step,
                        aiArguments.getTestFramework()))
                .filter(Objects::nonNull).toList();
    }

    /**
     * Renders all snippets as one text, each formatted as a "Step:" /
     * "Code snippet:" section, with sections separated by a newline.
     *
     * @param codeSnippets
     *            the snippets to render
     * @return the concatenated sections, or an empty string for an empty list
     */
    public static String getCodeSnippetsArg(List<CodeSnippet> codeSnippets) {
        return codeSnippets.stream().map(AiSnippets::getFormattedCodeSnippet)
                .collect(Collectors.joining("\n"));
    }

    /**
     * Formats a single snippet as its step description followed by its code.
     * The resulting text ends with a trailing newline.
     */
    private static String getFormattedCodeSnippet(CodeSnippet codeSnippet) {
        return """
                Step: %s
                Code snippet:
                %s
                """.formatted(codeSnippet.getDescription(),
                codeSnippet.getCode());
    }

    /**
     * Builds the import section required by the given snippets for the target
     * test framework.
     *
     * @param testFramework
     *            the framework the snippets target; must not be {@code null}
     * @param codeSnippets
     *            the snippets whose imports are collected
     * @return the distinct imports, sorted and separated by newlines; or
     *         {@code null} if the framework is not one of the cases handled
     *         below
     * @throws NullPointerException
     *             if {@code testFramework} is {@code null}
     * @throws UnsupportedOperationException
     *             for {@link TestFramework#TEST_BENCH}, for which import
     *             generation is not implemented
     */
    public static String getImportsArg(TestFramework testFramework,
            List<CodeSnippet> codeSnippets) {
        Objects.requireNonNull(testFramework, "testFramework");
        switch (testFramework) {
        case PLAYWRIGHT_JAVA -> {
            return getJavaImportsArg(codeSnippets);
        }
        case PLAYWRIGHT_NODE -> {
            return getNodeImportsArg(codeSnippets);
        }
        case TEST_BENCH -> throw new UnsupportedOperationException(
                "Import generation is not supported for test framework "
                        + testFramework);
        }
        // Only reachable if TestFramework declares constants not handled
        // above; kept as null for backward compatibility with callers.
        return null;
    }

    /**
     * Collects the imports of all snippets for Playwright (Node). When the
     * {@code expect} import is present, it and the {@code test} import are
     * replaced by a single combined import from {@code @playwright/test}.
     */
    private static String getNodeImportsArg(List<CodeSnippet> codeSnippets) {
        Set<String> imports = collectImports(codeSnippets);
        // Set.remove returns true only if the element was present, which
        // combines the membership check and the removal.
        if (imports.remove(PlaywrightNodeImports.EXPECT)) {
            // The combined line also brings in `test`, so drop any separate
            // `test` import to avoid importing it twice.
            imports.remove(PlaywrightNodeImports.TEST);
            imports.add("import { test, expect } from '@playwright/test';");
        }
        return joinSorted(imports);
    }

    /** Collects the distinct imports of all snippets for Playwright (Java). */
    private static String getJavaImportsArg(List<CodeSnippet> codeSnippets) {
        return joinSorted(collectImports(codeSnippets));
    }

    /**
     * Gathers the imports declared by all snippets into a new mutable set,
     * dropping duplicates.
     */
    private static Set<String> collectImports(List<CodeSnippet> codeSnippets) {
        return codeSnippets.stream().map(CodeSnippet::getImports)
                .flatMap(Set::stream)
                .collect(Collectors.toCollection(HashSet::new));
    }

    /**
     * Sorts the imports in natural string order and joins them with
     * newlines.
     */
    private static String joinSorted(Set<String> imports) {
        return imports.stream().sorted().collect(Collectors.joining("\n"));
    }

}
