package dev.hilla.push;

import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import dev.hilla.ConditionalOnFeatureFlag;
import dev.hilla.EndpointInvocationException;
import dev.hilla.EndpointInvoker;
import dev.hilla.push.messages.fromclient.AbstractServerMessage;
import dev.hilla.push.messages.fromclient.SubscribeMessage;
import dev.hilla.push.messages.fromclient.UnsubscribeMessage;
import dev.hilla.push.messages.toclient.AbstractClientMessage;
import dev.hilla.push.messages.toclient.ClientMessageComplete;
import dev.hilla.push.messages.toclient.ClientMessageError;
import dev.hilla.push.messages.toclient.ClientMessageUpdate;
import java.security.Principal;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Service;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;

@ConditionalOnFeatureFlag(PushMessageHandler.PUSH_FEATURE_FLAG)
@Service
/* loaded from: input_file:dev/hilla/push/PushMessageHandler.class */
public class PushMessageHandler {
    static final String PUSH_FEATURE_FLAG = "hillaPush";
    private final EndpointInvoker endpointInvoker;
    private Map<String, Disposable> closeHandlers = new ConcurrentHashMap();

    public PushMessageHandler(EndpointInvoker endpointInvoker) {
        this.endpointInvoker = endpointInvoker;
    }

    public void handleMessage(AbstractServerMessage abstractServerMessage, Consumer<AbstractClientMessage> consumer) {
        if (abstractServerMessage instanceof SubscribeMessage) {
            handleSubscribe((SubscribeMessage) abstractServerMessage, consumer);
        } else {
            if (!(abstractServerMessage instanceof UnsubscribeMessage)) {
                throw new IllegalArgumentException("Unknown message type: " + abstractServerMessage.getClass().getName());
            }
            handleClose((UnsubscribeMessage) abstractServerMessage);
        }
    }

    private void handleSubscribe(SubscribeMessage subscribeMessage, Consumer<AbstractClientMessage> consumer) {
        if (this.endpointInvoker.getReturnType(subscribeMessage.getEndpointName(), subscribeMessage.getMethodName()) != Flux.class) {
            consumer.accept(new ClientMessageError(subscribeMessage.getId(), "Method " + subscribeMessage.getEndpointName() + "/" + subscribeMessage.getMethodName() + " is not a Flux method"));
            return;
        }
        ArrayNode params = subscribeMessage.getParams();
        ObjectNode objectNode = params.objectNode();
        for (int i = 0; i < params.size(); i++) {
            objectNode.set(i + "", params.get(i));
        }
        Principal authentication = SecurityContextHolder.getContext().getAuthentication();
        try {
            this.closeHandlers.put(subscribeMessage.getId(), ((Flux) this.endpointInvoker.invoke(subscribeMessage.getEndpointName(), subscribeMessage.getMethodName(), objectNode, authentication, str -> {
                return Boolean.valueOf(authentication.getAuthorities().stream().anyMatch(grantedAuthority -> {
                    return grantedAuthority.getAuthority().equals("ROLE_" + str);
                }));
            })).subscribe(obj -> {
                send(consumer, new ClientMessageUpdate(subscribeMessage.getId(), obj));
            }, th -> {
                this.closeHandlers.remove(subscribeMessage.getId());
                send(consumer, new ClientMessageError(subscribeMessage.getId(), "Exception in Flux"));
                getLogger().error("Exception in Flux", th);
            }, () -> {
                this.closeHandlers.remove(subscribeMessage.getId());
                send(consumer, new ClientMessageComplete(subscribeMessage.getId()));
            }));
        } catch (EndpointInvocationException.EndpointAccessDeniedException | EndpointInvocationException.EndpointBadRequestException | EndpointInvocationException.EndpointInternalException e) {
            consumer.accept(new ClientMessageError(subscribeMessage.getId(), e.getMessage()));
        } catch (EndpointInvocationException.EndpointNotFoundException e2) {
            consumer.accept(new ClientMessageError(subscribeMessage.getId(), "No such endpoint"));
        }
    }

    private void send(Consumer<AbstractClientMessage> consumer, AbstractClientMessage abstractClientMessage) {
        consumer.accept(abstractClientMessage);
    }

    private void handleClose(UnsubscribeMessage unsubscribeMessage) {
        Disposable remove = this.closeHandlers.remove(unsubscribeMessage.getId());
        if (remove == null) {
            getLogger().warn("Trying to close an unknown flux with id " + unsubscribeMessage.getId());
        } else {
            remove.dispose();
        }
    }

    private Logger getLogger() {
        return LoggerFactory.getLogger(getClass());
    }
}
