/*-
 * Copyright (C) 2024 Vaadin Ltd
 *
 * This program is available under Vaadin Commercial License and Service Terms.
 *
 * See <https://vaadin.com/commercial-license-and-service-terms> for the full license.
-*/
package com.vaadin.controlcenter.starter.idm;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;

import java.io.IOException;
import java.time.Instant;
import java.util.Objects;

import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.web.filter.GenericFilterBean;

/**
 * Checks the access token expiration and clears the security context if the
 * token is expired. This causes the OAuth2 client to refresh the token on the
 * next request. If the refresh token is expired, the user will be redirected to
 * the login page.
 * <p>
 * The authentication is also cleared when no authorized client can be found
 * for the current {@link OAuth2AuthenticationToken}, because there is then no
 * access token whose validity could be checked. Access tokens that carry no
 * expiration time are treated as not expired.
 */
class RefreshTokenFilter extends GenericFilterBean {

    private final OAuth2AuthorizedClientService clientService;

    /**
     * Constructs a new RefreshTokenFilter.
     *
     * @param clientService
     *            the OAuth2 authorized client service used to look up the
     *            authorized client of the current user, not {@code null}
     * @throws NullPointerException
     *             if {@code clientService} is {@code null}
     */
    public RefreshTokenFilter(OAuth2AuthorizedClientService clientService) {
        // Fail fast at construction instead of on the first filtered request.
        this.clientService = Objects.requireNonNull(clientService,
                "clientService");
    }

    /**
     * Clears the authentication of the current security context when it is an
     * {@link OAuth2AuthenticationToken} whose authorized client is missing or
     * whose access token has expired, then continues the filter chain.
     *
     * @param servletRequest
     *            the current request
     * @param servletResponse
     *            the current response
     * @param filterChain
     *            the remaining filter chain
     * @throws IOException
     *             if thrown by the filter chain
     * @throws ServletException
     *             if thrown by the filter chain
     */
    @Override
    public void doFilter(ServletRequest servletRequest,
            ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        Authentication authentication = SecurityContextHolder.getContext()
                .getAuthentication();
        if (authentication instanceof OAuth2AuthenticationToken token) {
            OAuth2AuthorizedClient client = clientService.loadAuthorizedClient(
                    token.getAuthorizedClientRegistrationId(), token.getName());
            // loadAuthorizedClient returns null when nothing is stored for
            // this principal (e.g. after a server restart); without a client
            // the token cannot be validated, so drop the authentication
            // rather than failing the request with a NullPointerException.
            if (client == null
                    || isExpired(client.getAccessToken(), Instant.now())) {
                // NOTE(review): this mutates the existing SecurityContext
                // instance in place; if that instance is the one stored in
                // the session, the cleared authentication is persisted too —
                // confirm this is intended.
                SecurityContextHolder.getContext().setAuthentication(null);
            }
        }
        filterChain.doFilter(servletRequest, servletResponse);
    }

    /**
     * Tells whether the given access token has expired at the given instant.
     *
     * @param accessToken
     *            the access token to check, not {@code null}
     * @param now
     *            the instant to compare the expiration time against
     * @return {@code true} if the token has an expiration time that is at or
     *         before {@code now}; {@code false} if it expires later or has no
     *         expiration time at all
     */
    private static boolean isExpired(OAuth2AccessToken accessToken,
            Instant now) {
        Instant expiresAt = accessToken.getExpiresAt();
        // A token without an expiration time cannot expire by time.
        return expiresAt != null && !expiresAt.isAfter(now);
    }
}
