From 6153a9c76261ebdddee0011ae0a59a236ee26262 Mon Sep 17 00:00:00 2001 From: zhouaihui Date: Tue, 2 Jan 2024 15:11:02 +0800 Subject: [PATCH] feat: expose sending error to drive main thread (#190) * feat: expose sending error to drive main thread and fix some cross-silo error bugs. * Revert unintended modifications. * Fix lint. * Fix a typing. * Fix typing continue. --- fed/_private/global_context.py | 30 +++++++---- fed/_private/message_queue.py | 29 +++++----- fed/api.py | 65 ++++++++++++++++------- fed/cleanup.py | 22 ++++++-- fed/proxy/barriers.py | 3 +- fed/tests/test_cross_silo_error.py | 14 +---- fed/tests/test_exit_on_failure_sending.py | 32 +++++------ 7 files changed, 118 insertions(+), 77 deletions(-) diff --git a/fed/_private/global_context.py b/fed/_private/global_context.py index 8e5c550..0a7c522 100644 --- a/fed/_private/global_context.py +++ b/fed/_private/global_context.py @@ -20,11 +20,16 @@ class GlobalContext: def __init__( - self, job_name: str, current_party: str, failure_handler: Callable[[], None] + self, + job_name: str, + current_party: str, + sending_failure_handler: Callable[[Exception], None], + exit_on_sending_failure=False, ) -> None: self._job_name = job_name self._seq_count = 0 - self._failure_handler = failure_handler + self._sending_failure_handler = sending_failure_handler + self._exit_on_sending_failure = exit_on_sending_failure self._atomic_shutdown_flag_lock = threading.Lock() self._atomic_shutdown_flag = True self._cleanup_manager = CleanupManager( @@ -41,13 +46,16 @@ def get_cleanup_manager(self) -> CleanupManager: def get_job_name(self) -> str: return self._job_name - def get_failure_handler(self) -> Callable[[], None]: - return self._failure_handler + def get_sending_failure_handler(self) -> Callable[[], None]: + return self._sending_failure_handler + + def get_exit_on_sending_failure(self) -> bool: + return self._exit_on_sending_failure def acquire_shutdown_flag(self) -> bool: """ Acquiring a lock and set the flag to make sure - `fed.shutdown(intended=False)` can be called only once. + `fed.shutdown()` can be called only once. The unintended shutdown, i.e. `fed.shutdown(intended=False)`, needs to be executed only once. However, `fed.shutdown` may get called duing @@ -68,11 +76,15 @@ def acquire_shutdown_flag(self) -> bool: def init_global_context( - current_party: str, job_name: str, failure_handler: Callable[[], None] = None + current_party: str, + job_name: str, + sending_failure_handler: Callable[[Exception], None] = None, ) -> None: global _global_context if _global_context is None: - _global_context = GlobalContext(job_name, current_party, failure_handler) + _global_context = GlobalContext( + job_name, current_party, sending_failure_handler + ) def get_global_context(): @@ -80,8 +92,8 @@ def get_global_context(): return _global_context -def clear_global_context(): +def clear_global_context(graceful=True): global _global_context if _global_context is not None: - _global_context.get_cleanup_manager().stop() + _global_context.get_cleanup_manager().stop(graceful=graceful) _global_context = None diff --git a/fed/_private/message_queue.py b/fed/_private/message_queue.py index 9fdff7d..e4dabe9 100644 --- a/fed/_private/message_queue.py +++ b/fed/_private/message_queue.py @@ -63,20 +63,26 @@ def _loop(): def append(self, message): self._queue.append(message) - def notify_to_exit(self): + def appendleft(self, message): + self._queue.appendleft(message) + + def _notify_to_exit(self, immediately=False): logger.info(f"Notify message polling thread[{self._thread_name}] to exit.") - self.append(STOP_SYMBOL) + if immediately: + self.appendleft(STOP_SYMBOL) + else: + self.append(STOP_SYMBOL) - def stop(self): + def stop(self, immediately=False): """ Stop the message queue. Args: - graceful (bool): A flag indicating whether to stop the queue - gracefully or not. Default is True. - If True: insert the STOP_SYMBOL at the end of the queue - and wait for it to be processed, which will break the for-loop; - If False: forcelly kill the for-loop sub-thread. + immediately (bool): A flag indicating whether to stop the queue + immediately or not. Default is True. + If True: insert the STOP_SYMBOL at the begin of the queue. + If False: insert the STOP_SYMBOL at the end of the queue, which means + stop the for loop until all messages in queue are all sent. """ if threading.current_thread() == self._thread: logger.error( @@ -90,11 +96,10 @@ def stop(self): # encounter AssertionError because sub-thread's lock is not released. # Therefore, currently, not support forcelly kill thread if self.is_started(): - logger.debug(f"Gracefully killing thread[{self._thread_name}].") - self.notify_to_exit() + logger.debug(f"Killing thread[{self._thread_name}].") + self._notify_to_exit(immediately=immediately) self._thread.join() - - logger.info(f"The message polling thread[{self._thread_name}] was exited.") + logger.info(f"The message polling thread[{self._thread_name}] was exited.") def is_started(self): return self._thread is not None and self._thread.is_alive() diff --git a/fed/api.py b/fed/api.py index 66d4aba..caf8ef1 100644 --- a/fed/api.py +++ b/fed/api.py @@ -21,6 +21,7 @@ import cloudpickle import ray from ray.exceptions import RayError +import sys import fed._private.compatible_utils as compatible_utils import fed.config as fed_config @@ -74,7 +75,7 @@ def init( receiver_proxy_cls: ReceiverProxy = None, receiver_sender_proxy_cls: SenderReceiverProxy = None, job_name: str = constants.RAYFED_DEFAULT_JOB_NAME, - failure_handler: Callable[[], None] = None, + sending_failure_handler: Callable[[Exception], None] = None, ): """ Initialize a RayFed client. @@ -146,6 +147,9 @@ def init( default fixed name will be assigned, therefore messages of all anonymous jobs will be mixed together, which should only be used in the single job scenario or test mode. + sending_failure_handler: optional; a callback which will be triggeed if + cross-silo message sending failed and exit_on_sending_failure in config is + True. Examples: >>> import fed >>> import ray @@ -164,7 +168,9 @@ def init( fed_utils.validate_addresses(addresses) init_global_context( - current_party=party, job_name=job_name, failure_handler=failure_handler + current_party=party, + job_name=job_name, + sending_failure_handler=sending_failure_handler, ) tls_config = {} if tls_config is None else tls_config if tls_config: @@ -281,16 +287,42 @@ def _shutdown(intended=True): Shutdown a RayFed client. Args: - intended: (Optional) Whether this is a intended exit. If not - a "failure handler" will be triggered. + intended: (Optional) Whether this is a intended shutdown. If not + a "failure handler" will be triggered and sys.exit will be called then. """ - if get_global_context() is not None: - # Job has inited, can be shutdown - failure_handler = get_global_context().get_failure_handler() + + if get_global_context() is None: + # Do nothing since job has not been inited or is cleaned already. + return + + if intended: + logger.info('Shutdowning rayfed intendedly...') + else: + logger.warn('Shutdowning rayfed unintendedly...') + global_context = get_global_context() + last_sending_error = global_context.get_cleanup_manager().get_last_sending_error() + if last_sending_error is not None: + logging.error(f'Cross-silo sending error occured. {last_sending_error}') + + if not intended: + # Execute failure_handler fisrtly. + failure_handler = global_context.get_sending_failure_handler() + if failure_handler is not None: + logger.info('Executing failure handler...') + failure_handler(last_sending_error) + + # Clean context. compatible_utils._clear_internal_kv() - clear_global_context() - if not intended and failure_handler is not None: - failure_handler() + clear_global_context(graceful=intended) + logger.info('Shutdowned rayfed.') + + # Exit with error. + logger.critical('Exit now due to the previous error.') + sys.exit(1) + else: + # Clean context. + compatible_utils._clear_internal_kv() + clear_global_context(graceful=intended) logger.info('Shutdowned rayfed.') @@ -474,14 +506,11 @@ def get( if is_individual_id: values = values[0] return values - except RayError as e: - if isinstance(e, FedRemoteError): - logger.warning( - "Encounter RemoteError happend in other parties" - f", prepare to exit, error message: {e.cause}" - ) - if get_global_context().acquire_shutdown_flag(): - _shutdown(intended=False) + except FedRemoteError as e: + logger.warning( + "Encounter RemoteError happend in other parties" + f", error message: {e.cause}" + ) raise e diff --git a/fed/cleanup.py b/fed/cleanup.py index 91a0739..03aab89 100644 --- a/fed/cleanup.py +++ b/fed/cleanup.py @@ -57,6 +57,7 @@ def __init__(self, current_party, acquire_shutdown_flag) -> None: self._current_party = current_party self._acquire_shutdown_flag = acquire_shutdown_flag + self._last_sending_error = None def start(self, exit_on_sending_failure=False, expose_error_trace=False): self._exit_on_sending_failure = exit_on_sending_failure @@ -70,18 +71,25 @@ def start(self, exit_on_sending_failure=False, expose_error_trace=False): def _main_thread_monitor(): main_thread = threading.main_thread() main_thread.join() - self.stop() + logging.debug('Stoping sending queue ...') + self.stop(graceful=True) self._monitor_thread = threading.Thread(target=_main_thread_monitor) self._monitor_thread.start() logger.info('Start check sending monitor thread.') - def stop(self): + def stop(self, graceful=True): # NOTE(NKcqx): MUST firstly stop the data queue, because it # may still throw errors during the termination which need to # be sent to the error queue. - self._sending_data_q.stop() - self._sending_error_q.stop() + if graceful: + self._sending_data_q.stop(immediately=False) + self._sending_error_q.stop(immediately=False) + else: + # Stop data queue immediately, but stop error queue not immediately always + # to sure that error can be sent to peers. + self._sending_data_q.stop(immediately=True) + self._sending_error_q.stop(immediately=False) def push_to_sending( self, @@ -114,6 +122,9 @@ def push_to_sending( else: self._sending_data_q.append(msg_pack) + def get_last_sending_error(self): + return self._last_sending_error + def _signal_exit(self): """ Exit the current process immediately. The signal will be captured @@ -129,7 +140,7 @@ def _signal_exit(self): # once and avoid dead lock, the lock must be checked before sending # signals. if self._acquire_shutdown_flag(): - logger.debug("Signal SIGINT to exit.") + logger.warn("Signal SIGINT to exit.") os.kill(os.getpid(), signal.SIGINT) def _process_data_sending_task_return(self, message): @@ -161,6 +172,7 @@ def _process_data_sending_task_return(self, message): f'upstream_seq_id: {upstream_seq_id}, ' f'downstream_seq_id: {downstream_seq_id}.' ) + self._last_sending_error = e if isinstance(e, RayError): logger.info(f"Sending error {e.cause} to {dest_party}.") from fed.proxy.barriers import send diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index a7b1e0f..0f55afb 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -20,6 +20,7 @@ import ray import fed.config as fed_config +from fed.exceptions import FedRemoteError from fed._private import constants from fed._private.global_context import get_global_context from fed.proxy.base_proxy import ReceiverProxy, SenderProxy, SenderReceiverProxy @@ -223,7 +224,7 @@ async def get_data(self, src_party, upstream_seq_id, curr_seq_id): data = await self._proxy_instance.get_data( src_party, upstream_seq_id, curr_seq_id ) - if isinstance(data, Exception): + if isinstance(data, FedRemoteError): logger.debug( f"Receiving exception: {type(data)}, {data} from {src_party}, " f"upstream_seq_id: {upstream_seq_id}, " diff --git a/fed/tests/test_cross_silo_error.py b/fed/tests/test_cross_silo_error.py index f2f623e..c79a108 100644 --- a/fed/tests/test_cross_silo_error.py +++ b/fed/tests/test_cross_silo_error.py @@ -44,7 +44,6 @@ def error_func(self): def run(party): - my_failure_handler = Mock() compatible_utils.init_ray(address='local') addresses = { 'alice': '127.0.0.1:11012', @@ -57,12 +56,10 @@ def run(party): logging_level='debug', config={ 'cross_silo_comm': { - 'exit_on_sending_failure': True, 'timeout_ms': 20 * 1000, 'expose_error_trace': True, }, }, - failure_handler=my_failure_handler, ) # Both party should catch the error @@ -76,7 +73,6 @@ def run(party): else: assert isinstance(e.value.cause, MyError) assert "normal task Error" in str(e.value.cause) - my_failure_handler.assert_called_once() fed.shutdown() ray.shutdown() @@ -93,7 +89,6 @@ def test_cross_silo_normal_task_error(): def run2(party): - my_failure_handler = Mock() compatible_utils.init_ray(address='local') addresses = { 'alice': '127.0.0.1:11012', @@ -105,12 +100,10 @@ def run2(party): logging_level='debug', config={ 'cross_silo_comm': { - 'exit_on_sending_failure': True, 'timeout_ms': 20 * 1000, 'expose_error_trace': True, }, }, - failure_handler=my_failure_handler, ) # Both party should catch the error @@ -123,11 +116,9 @@ def run2(party): assert isinstance(e.value.cause, FedRemoteError) assert 'RemoteError occurred at alice' in str(e.value.cause) assert "actor task Error" in str(e.value.cause) - my_failure_handler.assert_called_once() else: assert isinstance(e.value.cause, MyError) assert "actor task Error" in str(e.value.cause) - my_failure_handler.assert_called_once() fed.shutdown() ray.shutdown() @@ -145,7 +136,6 @@ def test_cross_silo_actor_task_error(): def run3(party): - my_failure_handler = Mock() compatible_utils.init_ray(address='local') addresses = { 'alice': '127.0.0.1:11012', @@ -158,11 +148,10 @@ def run3(party): logging_level='debug', config={ 'cross_silo_comm': { - 'exit_on_sending_failure': True, 'timeout_ms': 20 * 1000, + 'expose_error_trace': False, }, }, - failure_handler=my_failure_handler, ) # Both party should catch the error @@ -176,7 +165,6 @@ def run3(party): else: assert isinstance(e.value.cause, MyError) assert "normal task Error" in str(e.value.cause) - my_failure_handler.assert_called_once() fed.shutdown() ray.shutdown() diff --git a/fed/tests/test_exit_on_failure_sending.py b/fed/tests/test_exit_on_failure_sending.py index 3da50ae..8414b48 100644 --- a/fed/tests/test_exit_on_failure_sending.py +++ b/fed/tests/test_exit_on_failure_sending.py @@ -13,27 +13,17 @@ # limitations under the License. import multiprocessing -import os -import signal import sys import pytest -import ray import fed import fed._private.compatible_utils as compatible_utils -def signal_handler(sig, frame): - if sig == signal.SIGTERM.value: - fed.shutdown() - ray.shutdown() - os._exit(0) - - @fed.remote def f(): - return 100 + raise Exception('By design.') @fed.remote @@ -45,9 +35,7 @@ def get_value(self): return self._value -def run(party): - signal.signal(signal.SIGTERM, signal_handler) - +def run(party: str, q: multiprocessing.Queue): compatible_utils.init_ray(address='local') addresses = { 'alice': '127.0.0.1:11012', @@ -61,6 +49,9 @@ def run(party): "retryableStatusCodes": ["UNAVAILABLE"], } + def failure_handler(error): + q.put('failure handler') + fed.init( addresses=addresses, party=party, @@ -72,22 +63,25 @@ def run(party): 'timeout_ms': 20 * 1000, }, }, - failure_handler=lambda: os.kill(os.getpid(), signal.SIGTERM), + sending_failure_handler=failure_handler, ) - o = f.party("alice").remote() My.party("bob").remote(o) + import time - # Wait for SIGTERM as failure on sending. + # Wait a long time. + # If the test takes effect, the main loop here will be broken. time.sleep(86400) def test_exit_when_failure_on_sending(): - p_alice = multiprocessing.Process(target=run, args=('alice',)) + q = multiprocessing.Queue() + p_alice = multiprocessing.Process(target=run, args=('alice', q)) p_alice.start() p_alice.join() - assert p_alice.exitcode == 0 + assert p_alice.exitcode == 1 + assert q.get() == 'failure handler' if __name__ == "__main__":