[core][compiled-graphs] Register custom serializers for InputAttributeNodes and pass their type hints to downstream nodes #49236

Merged (14 commits) on Dec 15, 2024
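
For context, a minimal sketch of what this change enables, modeled on the new tests added in this PR: attaching a TorchTensorType hint to InputAttributeNodes (inp[0], inp[1]) now registers their custom serializers and passes the hints to the downstream tasks. The Worker actor below is illustrative (it mirrors the Worker added in the test file) and the sketch assumes a GPU is available.

import ray
import torch
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType


# Illustrative GPU actor, mirroring the Worker added in the test diff below.
@ray.remote(num_gpus=1)
class Worker:
    def no_op(self, tensor):
        assert isinstance(tensor, torch.Tensor)
        return tensor


worker = Worker.remote()

with InputNode() as inp:
    # Type hints on InputAttributeNodes are now passed to the downstream
    # tasks, and their custom serializers are registered at compile time.
    branch1 = worker.no_op.bind(inp[0].with_type_hint(TorchTensorType()))
    branch2 = worker.no_op.bind(inp[1].with_type_hint(TorchTensorType()))
    dag = MultiOutputNode([branch1, branch2])

compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(torch.tensor([1]), torch.tensor([2]))
# The custom serializer moves tensors to the actor's GPU on receive and back
# to CPU on send, so the driver gets CPU tensors equal to the inputs.
t1, t2 = ray.get(ref)
assert torch.equal(t1, torch.tensor([1])) and torch.equal(t2, torch.tensor([2]))
compiled_dag.teardown()
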
37 changes: 31 additions & 6 deletions python/ray/dag/compiled_dag_node.py
@@ -823,6 +823,8 @@ def __init__(
# Preprocessing identifies the input node and output node.
self.input_task_idx: Optional[int] = None
self.output_task_idx: Optional[int] = None
# List of task indices that are input attribute nodes.
self.input_attr_task_idx_list: List[int] = []
# Denotes whether execute/execute_async returns a list of refs/futures.
self._returns_list: bool = False
# Number of expected positional args and kwargs that may be passed to
@@ -949,11 +951,13 @@ def _preprocess(self) -> None:
nccl_actors_p2p: Set["ray.actor.ActorHandle"] = set()
nccl_collective_ops: Set[_CollectiveOperation] = set()

# Find the input node to the DAG.
# Find the input node and input attribute nodes in the DAG.
for idx, task in self.idx_to_task.items():
if isinstance(task.dag_node, InputNode):
assert self.input_task_idx is None, "More than one InputNode found"
self.input_task_idx = idx
elif isinstance(task.dag_node, InputAttributeNode):
self.input_attr_task_idx_list.append(idx)

# Find the (multi-)output node to the DAG.
for idx, task in self.idx_to_task.items():
@@ -1088,6 +1092,9 @@ def _preprocess(self) -> None:
):
downstream_actor_handle = dag_node._get_actor_handle()

# Add the type hint of the upstream node to the task.
task.arg_type_hints.append(upstream_task.dag_node.type_hint)
Resolve "Reason 1". In the logic below, if the upstream_task.dag_node is an InputAttributeNode, the upstream_task will be set to self.idx_to_task[self.input_task_idx]. This is why I moved this line above.


if isinstance(upstream_task.dag_node, InputAttributeNode):
# Record all of the keys used to index the InputNode.
# During execution, we will check that the user provides
@@ -1123,7 +1130,6 @@ def _preprocess(self) -> None:
direct_input = True

upstream_task.downstream_task_idxs[task_idx] = downstream_actor_handle
task.arg_type_hints.append(upstream_task.dag_node.type_hint)

if upstream_task.dag_node.type_hint.requires_nccl():
# Add all readers to the NCCL actors of P2P.
@@ -1496,8 +1502,6 @@ def _get_or_compile(
)

input_task = self.idx_to_task[self.input_task_idx]
# Register custom serializers for inputs provided to dag.execute().
input_task.dag_node.type_hint.register_custom_serializer()
self.dag_input_channels = input_task.output_channels
assert self.dag_input_channels is not None

@@ -1599,8 +1603,9 @@ def _get_or_compile(
task = self.idx_to_task[output_idx]
assert len(task.output_channels) == 1
self.dag_output_channels.append(task.output_channels[0])
# Register custom serializers for DAG outputs.
output.type_hint.register_custom_serializer()

# Register custom serializers for input, input attribute, and output nodes.
self._register_input_output_custom_serializer()

assert self.dag_input_channels
assert self.dag_output_channels
@@ -2780,6 +2785,26 @@ def visualize(
dot.render(filename, view=view)
return dot.source

def _register_input_output_custom_serializer(self):
"""
Register custom serializers for input, input attribute, and output nodes.
"""
assert self.input_task_idx is not None
assert self.output_task_idx is not None

# Register custom serializers for input node.
input_task = self.idx_to_task[self.input_task_idx]
input_task.dag_node.type_hint.register_custom_serializer()

# Register custom serializers for input attribute nodes.
for input_attr_task_idx in self.input_attr_task_idx_list:
Resolve "Reason 2"

input_attr_task = self.idx_to_task[input_attr_task_idx]
input_attr_task.dag_node.type_hint.register_custom_serializer()

# Register custom serializers for output nodes.
for output in self.idx_to_task[self.output_task_idx].args:
output.type_hint.register_custom_serializer()

def teardown(self, kill_actors: bool = False):
"""Teardown and cancel all actor tasks for this DAG. After this
function returns, the actors should be available to execute new tasks
191 changes: 191 additions & 0 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -118,6 +118,20 @@ def forward(self, inp):
return torch.randn(10, 10)


@ray.remote
class Worker:
def __init__(self):
self.device = None

def no_op(self, tensor):
assert isinstance(tensor, torch.Tensor)
self.device = tensor.device
return tensor

def get_device(self):
return self.device


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_p2p(ray_start_regular):
if USE_GPU:
@@ -1270,6 +1284,183 @@ def recv(self, tensor):
compiled_dag.teardown()


class TestTorchTensorTypeHintCustomSerializer:
@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
@pytest.mark.parametrize("tensor_device", ["cpu", "cuda"])
def test_input_node_without_type_hint(self, ray_start_regular, tensor_device):
"""
Since no TorchTensorType hint is provided in this compiled graph,
normal serialization and deserialization functions are used, which will
not move the tensor to GPU/CPU.
"""
if not USE_GPU:
pytest.skip("Test requires GPU")

worker = Worker.options(num_gpus=1).remote()

with InputNode() as inp:
dag = worker.no_op.bind(inp)

compiled_dag = dag.experimental_compile()
tensor = torch.tensor([1])
if tensor_device == "cuda":
tensor = tensor.cuda()
ref = compiled_dag.execute(tensor)
t = ray.get(ref)
assert torch.equal(t, tensor)

device = ray.get(worker.get_device.remote())
assert device.type == tensor_device

@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
@pytest.mark.parametrize("tensor_device", ["cpu", "cuda"])
def test_input_node_with_type_hint(self, ray_start_regular, tensor_device):
"""
Since `inp` has a TorchTensorType hint, both the driver and `worker` will
use the custom serializer.

Step 1: The driver calls `serialize_tensor` to serialize `input_tensor` and
move the tensor to CPU if it is on GPU.
Step 2: The `worker` calls `deserialize_tensor` to deserialize `input_tensor`
and moves it to GPU.
Step 3: The `worker` calls `serialize_tensor` to serialize the result of
`no_op` and moves it to CPU.
Step 4: The driver calls `deserialize_tensor` to deserialize the result of
`no_op`. Since the driver's `ChannelContext.torch_device` is CPU,
the tensor will not be moved to GPU.
"""
if not USE_GPU:
pytest.skip("Test requires GPU")

worker = Worker.options(num_gpus=1).remote()

with InputNode() as inp:
dag = worker.no_op.bind(inp.with_type_hint(TorchTensorType()))
compiled_dag = dag.experimental_compile()
cpu_tensor = torch.tensor([1])
input_tensor = cpu_tensor
if tensor_device == "cuda":
input_tensor = input_tensor.cuda()
ref = compiled_dag.execute(input_tensor)
# Verify Step 4
t = ray.get(ref)
assert torch.equal(t, cpu_tensor)

# Verify Step 2
device = ray.get(worker.get_device.remote())
assert device.type == "cuda"

@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_input_attr_nodes_with_all_tensor_type_hint(self, ray_start_regular):
"""
Since both `inp[0]` and `inp[1]` have tensor type hint, both workers will
use the custom serializer.

Step 1: The driver calls `serialize_tensor` to serialize `cpu_tensor_1`
and `cpu_tensor_2`.

Step 2:
* The `worker1` calls `deserialize_tensor` to deserialize `cpu_tensor_1`
and moves it to GPU.
* The `worker2` calls `deserialize_tensor` to deserialize `cpu_tensor_2`
and moves it to GPU.

Step 3:
* The `worker1` calls `serialize_tensor` to serialize the result of
`no_op` and moves it to CPU.
* The `worker2` calls `serialize_tensor` to serialize the result of
`no_op` and moves it to CPU.

Step 4: The driver calls `deserialize_tensor` to deserialize the result
of `no_op`. Since the driver's `ChannelContext.torch_device` is CPU,
the tensor will not be moved to GPU.
"""
worker1 = Worker.options(num_gpus=1).remote()
worker2 = Worker.options(num_gpus=1).remote()
with InputNode() as inp:
dag = inp[0].with_type_hint(TorchTensorType())
branch1 = worker1.no_op.bind(dag)
dag = inp[1].with_type_hint(TorchTensorType())
branch2 = worker2.no_op.bind(dag)
dag = MultiOutputNode([branch1, branch2])

compiled_dag = dag.experimental_compile()
cpu_tensor_1 = torch.tensor([1])
cpu_tensor_2 = torch.tensor([2])
ref = compiled_dag.execute(cpu_tensor_1, cpu_tensor_2)

# Verify Step 4
t1, t2 = ray.get(ref)
assert torch.equal(t1, cpu_tensor_1)
assert torch.equal(t2, cpu_tensor_2)

# Verify Step 2
device1 = ray.get(worker1.get_device.remote())
device2 = ray.get(worker2.get_device.remote())
assert device1.type == "cuda"
assert device2.type == "cuda"

@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_partial_input_attr_nodes_with_tensor_type_hint(self, ray_start_regular):
"""
Only `inp[0]` has a tensor type hint, so only `worker1` will use the custom
serializer. Note that although we don't register the custom serializer for
`worker2`, it still uses the custom deserializer. This is because when custom
serializers are registered with Ray, the registered deserializer is shipped
with the serialized value and used on the receiving end. See the comment in
`ChannelOutputType.register_custom_serializer` for more details.

Step 1: The driver calls `serialize_tensor` to serialize `cpu_tensor_1`
and `cpu_tensor_2`.

Step 2:
* The `worker1` calls `deserialize_tensor` to deserialize `cpu_tensor_1`
and moves it to GPU.
* The `worker2` calls `deserialize_tensor` to deserialize `cpu_tensor_2`
and moves it to GPU.

Step 3:
* The `worker1` calls `serialize_tensor` to serialize the result of `no_op`
and moves it to CPU.
* The `worker2` calls the normal serialization function to serialize the
result of `no_op` because it doesn't have a custom serializer, so the
tensor is still on GPU.

Step 4:
* The driver calls `deserialize_tensor` to deserialize the tensor from
`worker1`. Since the driver's `ChannelContext.torch_device` is CPU,
the tensor will not be moved to GPU.
* The driver calls normal deserialization function to deserialize the
tensor from `worker2`.
"""
worker1 = Worker.options(num_gpus=1).remote()
worker2 = Worker.options(num_gpus=1).remote()

with InputNode() as inp:
dag = inp[0].with_type_hint(TorchTensorType())
branch1 = worker1.no_op.bind(dag)
dag = inp[1]
branch2 = worker2.no_op.bind(dag)
dag = MultiOutputNode([branch1, branch2])

compiled_dag = dag.experimental_compile()
cpu_tensor_1 = torch.tensor([1])
cpu_tensor_2 = torch.tensor([2])
ref = compiled_dag.execute(cpu_tensor_1, cpu_tensor_2)
t1, t2 = ray.get(ref)
# Verify Step 3-1
assert torch.equal(t1, cpu_tensor_1)
# Verify Step 3-2
gpu_tensor_2 = cpu_tensor_2.cuda()
assert torch.equal(t2, gpu_tensor_2)

# Verify Step 2
device1 = ray.get(worker1.get_device.remote())
device2 = ray.get(worker2.get_device.remote())
assert device1.type == "cuda"
assert device2.type == "cuda"


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
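
For reference, the docstring of test_partial_input_attr_nodes_with_tensor_type_hint above notes that worker2 still uses the custom deserializer even though nothing was registered in that process, because Ray ships the registered deserializer along with the serialized value. Below is a minimal standalone sketch of that mechanism using ray.util.register_serializer; the Payload class and get_value task are illustrative and not part of this PR.

import ray


class Payload:
    def __init__(self, value):
        self.value = value


def serialize(obj):
    return obj.value


def deserialize(value):
    # Runs in whichever process deserializes the object, even though that
    # process never called register_serializer itself.
    return Payload(value * 10)


ray.init()
ray.util.register_serializer(Payload, serializer=serialize, deserializer=deserialize)


@ray.remote
def get_value(payload):
    return payload.value


# The worker receives Payload(10): the shipped deserializer ran on its side.
assert ray.get(get_value.remote(Payload(1))) == 10
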