package com.vaadin.copilot;

import jakarta.annotation.security.DenyAll;
import jakarta.annotation.security.PermitAll;
import jakarta.annotation.security.RolesAllowed;

import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.vaadin.copilot.javarewriter.JavaRewriterUtil;
import com.vaadin.flow.router.Layout;
import com.vaadin.flow.router.Route;
import com.vaadin.flow.server.auth.AnonymousAllowed;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.expr.AnnotationExpr;
import com.github.javaparser.ast.expr.ArrayInitializerExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MarkerAnnotationExpr;
import com.github.javaparser.ast.expr.Name;
import com.github.javaparser.ast.expr.SingleMemberAnnotationExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;

/**
 * Utility class for figuring out access requirements for Java methods such as
 * routes and browser callables.
 */
public class AccessRequirementUtil {

    /**
     * Get the access requirement for the given annotated class or method based on
     * its annotations.
     *
     * @param annotatedClassOrMethod
     *            the annotated class or method
     * @param fallback
     *            the class or method to check if the given class has no access
     *            control annotations
     * @return the access requirement for using the class or method
     */
    public static AccessRequirement getAccessRequirement(AnnotatedElement annotatedClassOrMethod,
            AnnotatedElement fallback) {
        // Based on AccessAnnotationChecker.hasAccess
        if (annotatedClassOrMethod.isAnnotationPresent(DenyAll.class)) {
            return new AccessRequirement(AccessRequirement.Type.DENY_ALL);
        }
        if (annotatedClassOrMethod.isAnnotationPresent(AnonymousAllowed.class)) {
            return new AccessRequirement(AccessRequirement.Type.ANONYMOUS_ALLOWED);
        }
        RolesAllowed rolesAllowed = annotatedClassOrMethod.getAnnotation(RolesAllowed.class);
        if (rolesAllowed != null) {
            return new AccessRequirement(AccessRequirement.Type.ROLES_ALLOWED, rolesAllowed.value());
        } else if (annotatedClassOrMethod.isAnnotationPresent(PermitAll.class)) {
            return new AccessRequirement(AccessRequirement.Type.PERMIT_ALL);
        }

        if (fallback != null) {
            return getAccessRequirement(fallback, null);
        }
        return new AccessRequirement(AccessRequirement.Type.DENY_ALL);
    }

    /**
     * Removes all access control annotations from the given class.
     *
     * @param routeClass
     *            the class to remove access control annotations from
     */
    public static void removeAccessAnnotations(ClassOrInterfaceDeclaration routeClass) {
        Stream.of(AccessRequirement.Type.values())
                .forEach(type -> removeAnnotation(routeClass, type.getAnnotationType()));
    }

    /**
     * Sets the access control annotation for the given class.
     *
     * @param routeClass
     *            the class to set the access control annotation for
     * @param accessRequirement
     *            the access requirement to set
     */
    public static void setAccessAnnotation(ClassOrInterfaceDeclaration routeClass,
            AccessRequirement accessRequirement) {
        removeAccessAnnotations(routeClass);
        addAccessAnnotation(routeClass, accessRequirement);
    }

    public static Optional<AccessRequirement.Type> getAccessAnnotation(ClassOrInterfaceDeclaration routeClass) {
        return Stream.of(AccessRequirement.Type.values())
                .filter(type -> getAnnotations(routeClass, type.getAnnotationType()).findFirst().isPresent())
                .findFirst();
    }

    private static void removeAnnotation(ClassOrInterfaceDeclaration routeClass,
            Class<? extends Annotation> annotationType) {
        List<AnnotationExpr> annotations = getAnnotations(routeClass, annotationType).toList();
        annotations.forEach(AnnotationExpr::remove);
    }

    private static void addAccessAnnotation(ClassOrInterfaceDeclaration routeClass,
            AccessRequirement accessRequirement) {
        CompilationUnit cu = routeClass.findCompilationUnit()
                .orElseThrow(() -> new IllegalArgumentException("Route class is not inside a compilation unit"));

        Class<? extends Annotation> annotationType = accessRequirement.getType().getAnnotationType();
        JavaRewriterUtil.addImport(cu, annotationType.getName());
        if (annotationType == RolesAllowed.class) {
            List<Expression> rolesExpressions = Arrays.stream(accessRequirement.getRoles()).map(StringLiteralExpr::new)
                    .collect(Collectors.toUnmodifiableList());
            NodeList<Expression> rolesNodeList = NodeList.nodeList(rolesExpressions);
            routeClass.addAnnotation(new SingleMemberAnnotationExpr(new Name(annotationType.getSimpleName()),
                    new ArrayInitializerExpr(rolesNodeList)));
        } else {
            routeClass.addAnnotation(new MarkerAnnotationExpr(annotationType.getSimpleName()));
        }
    }

    static ClassOrInterfaceDeclaration findRouteClass(CompilationUnit compilationUnit, String path) {
        return compilationUnit.findAll(ClassOrInterfaceDeclaration.class).stream()
                .filter(c -> getAnnotations(c, Route.class)
                        .anyMatch(a -> (isRoute(a, path) || isRoute(a, "/" + path) || isRoute(a, path + "/"))))
                .findFirst().orElseThrow(() -> new IllegalArgumentException("No route found for path '" + path + "'"));
    }

    static ClassOrInterfaceDeclaration findLayoutClass(CompilationUnit compilationUnit) {
        return compilationUnit.findAll(ClassOrInterfaceDeclaration.class).stream()
                .filter(c -> getAnnotations(c, Layout.class).findFirst().isPresent()).findFirst()
                .orElseThrow(() -> new IllegalArgumentException("No layout found in " + compilationUnit));
    }

    private static Stream<AnnotationExpr> getAnnotations(ClassOrInterfaceDeclaration routeClass,
            Class<? extends Annotation> annotation) {
        return routeClass.getAnnotations().stream().filter(a -> a.getNameAsString().equals(annotation.getSimpleName())
                || a.getNameAsString().equals(annotation.getName()));
    }

    private static boolean isRoute(AnnotationExpr a, String path) {
        if (!a.getNameAsString().equals("Route")) {
            return false;
        }

        String value = JavaRewriterUtil.getAnnotationValue(a);
        return path.equals(value);

    }

}
