Skip to content

Commit

Permalink
fix: do not check ping messages. (#189)
Browse files Browse the repository at this point in the history
* fix: do not check ping messages.

* Fix typing typos.
  • Loading branch information
zhouaihui authored Dec 26, 2023
1 parent bb7092a commit a615474
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import copy
import logging
import time
from typing import Dict
from typing import Any, Dict

import ray

Expand Down Expand Up @@ -458,11 +458,20 @@ def _start_sender_receiver_proxy(
logger.info("Succeeded to create receiver proxy actor.")


def send(dest_party, data, upstream_seq_id, downstream_seq_id, is_error=False):
def send(
dest_party: str,
data: Any,
upstream_seq_id: int,
downstream_seq_id: int,
is_error: bool = False,
check_sending: bool = True,
):
"""
Args:
is_error: Whether the `data` is an error object or not. Default is False.
If True, the data will be sent to the error message queue.
check_sending: Whether to check the data sending. If true, the data will be
checked in the sending check loop.
"""
sender_proxy = ray.get_actor(sender_proxy_actor_name())
res = sender_proxy.send.remote(
Expand All @@ -471,13 +480,14 @@ def send(dest_party, data, upstream_seq_id, downstream_seq_id, is_error=False):
upstream_seq_id=upstream_seq_id,
downstream_seq_id=downstream_seq_id,
)
get_global_context().get_cleanup_manager().push_to_sending(
res, dest_party, upstream_seq_id, downstream_seq_id, is_error
)
if check_sending:
get_global_context().get_cleanup_manager().push_to_sending(
res, dest_party, upstream_seq_id, downstream_seq_id, is_error
)
return res


def recv(party: str, src_party: str, upstream_seq_id, curr_seq_id):
def recv(party: str, src_party: str, upstream_seq_id: int, curr_seq_id: int):
assert party, 'Party can not be None.'
receiver_proxy = ray.get_actor(receiver_proxy_actor_name())
return receiver_proxy.get_data.remote(src_party, upstream_seq_id, curr_seq_id)
Expand All @@ -496,7 +506,9 @@ def ping_others(addresses: Dict[str, Dict], self_party: str, max_retries=3600):
_party_ping_obj = {} # {$party_name: $ObjectRef}
# Batch ping all the other parties
for other in others:
_party_ping_obj[other] = send(other, b'data', 'ping', 'ping')
_party_ping_obj[other] = send(
other, b'data', 'ping', 'ping', check_sending=False
)
_, _unready = ray.wait(list(_party_ping_obj.values()), timeout=1)

# Keep the unready party for the next ping.
Expand Down

0 comments on commit a615474

Please sign in to comment.