From cc008927f10867243e032bd6004c2fece724f384 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 5 Feb 2025 19:31:28 +0800 Subject: [PATCH] fix(framework) Avoid processing requests from non-FleetStub (#4900) Co-authored-by: Javier --- .../fleet/grpc_rere/server_interceptor.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index c7b5370f415b..5f9f6c822ca5 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -43,9 +43,11 @@ MAX_TIMESTAMP_DIFF = TIMESTAMP_TOLERANCE + SYSTEM_TIME_TOLERANCE -def _unary_unary_rpc_terminator(message: str) -> grpc.RpcMethodHandler: +def _unary_unary_rpc_terminator( + message: str, code: Any = grpc.StatusCode.UNAUTHENTICATED +) -> grpc.RpcMethodHandler: def terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage: - context.abort(grpc.StatusCode.UNAUTHENTICATED, message) + context.abort(code, message) raise RuntimeError("Should not reach this point") # Make mypy happy return grpc.unary_unary_rpc_method_handler(terminate) @@ -68,7 +70,7 @@ def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False): self.state_factory = state_factory self.auto_auth = auto_auth - def intercept_service( + def intercept_service( # pylint: disable=too-many-return-statements self, continuation: Callable[[Any], Any], handler_call_details: grpc.HandlerCallDetails, @@ -79,6 +81,13 @@ def intercept_service( metadata sent by the node. Continue RPC call if node is authenticated, else, terminate RPC call by setting context to abort. """ + # Filter out non-Fleet service calls + if not handler_call_details.method.startswith("/flwr.proto.Fleet/"): + return _unary_unary_rpc_terminator( + "This request should be sent to a different service.", + grpc.StatusCode.FAILED_PRECONDITION, + ) + state = self.state_factory.state() metadata_dict = dict(handler_call_details.invocation_metadata)