Skip to content

Commit

Permalink
up
Browse files: browse the repository at this point in the history
Signed-off-by: Rui Qiao <[email protected]>
  • Loading branch information
ruisearch42 committed Jan 31, 2025
1 parent 2586df2 commit a753ca2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 20 deletions.
34 changes: 16 additions & 18 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,9 @@ def _preprocess(self) -> None:
# Collect NCCL collective operations.
if isinstance(dag_node, CollectiveOutputNode):
self._track_communicator_usage(
dag_node, {actor_handle}, collective_op=True
dag_node,
dag_node._collective_op.actor_handles,
collective_op=True,
)
assert not self._overlap_gpu_communication, (
"Currently, the overlap_gpu_communication option is not "
Expand Down Expand Up @@ -1241,21 +1243,12 @@ def _init_communicators(self) -> None:
raise ValueError(
"Communicator creation is not allowed for collective operations."
)
actors = collective_op.actor_handles
if frozenset(actors) in self._actors_to_created_communicator_id:
communicator_id = self._actors_to_created_communicator_id[
frozenset(actors)
]
else:
communicator_id = _init_communicator(
actors,
None,
self._overlap_gpu_communication,
)
self._actors_to_created_communicator_id[
frozenset(actors)
] = communicator_id
collective_op.set_communicator_id(communicator_id)
actors = frozenset(collective_op.actor_handles)
communicator_id = collective_op.init_communicator(
self._actors_to_created_communicator_id.get(actors, None)
)
if actors not in self._actors_to_created_communicator_id:
self._actors_to_created_communicator_id[actors] = communicator_id

if self._pending_p2p_communicator_actors:
if (
Expand Down Expand Up @@ -1286,15 +1279,20 @@ def _track_communicator_usage(
"""
if None in actors:
raise ValueError("Driver cannot participate in the NCCL group.")
custom_communicator = dag_node.type_hint.get_custom_communicator()
if collective_op:
custom_communicator = (
dag_node._collective_op.type_hint.get_custom_communicator()
)
else:
custom_communicator = dag_node.type_hint.get_custom_communicator()
communicator = (
self._get_default_communicator(dag_node)
if custom_communicator is None
else custom_communicator
)
if communicator is None:
if collective_op:
self._pending_collective_ops.add(dag_node)
self._pending_collective_ops.add(dag_node._collective_op)
else:
self._pending_p2p_communicator_dag_nodes.add(dag_node)
self._pending_p2p_communicator_actors.update(actors)
Expand Down
4 changes: 2 additions & 2 deletions python/ray/experimental/channel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def get_custom_communicator(self) -> Optional[Communicator]:
"""
Return the custom NCCL group if one is specified.
"""
if self._contains_type is not None:
return self._contains_type.get_custom_nccl_group()
if hasattr(self, "contains_type") and self.contains_type is not None:
return self.contains_type.get_custom_nccl_group()
return None

def set_communicator_id(self, group_id: str) -> None:
Expand Down

0 comments on commit a753ca2

Please sign in to comment.