Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][dag] Add ascii based CG visualization #48315

Merged
merged 41 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
938efd2
fix
Bye-legumes Oct 29, 2024
26340df
fix
Bye-legumes Oct 29, 2024
450a1b6
fix
Bye-legumes Oct 29, 2024
3f5a7c5
fix
Bye-legumes Oct 29, 2024
eb413a7
Merge branch 'master' into add_ascii_visualization
Bye-legumes Oct 30, 2024
f5e94d2
Merge branch 'master' into add_ascii_visualization
Bye-legumes Nov 4, 2024
4823bf7
Merge branch 'master' into add_ascii_visualization
Bye-legumes Nov 6, 2024
a5b98b0
Merge branch 'master' into add_ascii_visualization
Bye-legumes Nov 7, 2024
5f3aacc
fix
Bye-legumes Nov 7, 2024
c3af022
fix
Bye-legumes Nov 7, 2024
7a8fdab
fix
Bye-legumes Nov 7, 2024
50e0cba
fix
Bye-legumes Nov 7, 2024
64a8de4
fix
Bye-legumes Nov 7, 2024
d6a62ea
fix
Bye-legumes Nov 7, 2024
723bba7
fix
Bye-legumes Nov 7, 2024
bf33506
fix
Bye-legumes Nov 7, 2024
572cde0
Merge branch 'master' into add_ascii_visualization
Bye-legumes Nov 7, 2024
65d0fe7
fix
Bye-legumes Nov 11, 2024
11465b8
fix
Bye-legumes Nov 12, 2024
71c682e
fix
Bye-legumes Nov 12, 2024
26923c2
fix
Bye-legumes Nov 12, 2024
7852205
fix
Bye-legumes Nov 12, 2024
352d613
fix
Bye-legumes Nov 12, 2024
07ba652
Merge branch 'master' into add_ascii_visualization
Bye-legumes Nov 12, 2024
eab7883
fix
Bye-legumes Nov 18, 2024
de0098e
Dfix
Bye-legumes Nov 18, 2024
4cd7a81
Update python/ray/dag/compiled_dag_node.py
Bye-legumes Nov 19, 2024
84154fb
fix
Bye-legumes Nov 19, 2024
a8ac11b
fix
Bye-legumes Nov 19, 2024
44e25dd
Merge branch 'master' into add_ascii_visualization
Bye-legumes Nov 19, 2024
6205e09
Merge branch 'master' into add_ascii_visualization
Bye-legumes Nov 20, 2024
c16f15a
fix
Bye-legumes Nov 20, 2024
e89a0d2
Merge branch 'master' into add_ascii_visualization
Bye-legumes Nov 25, 2024
e87e238
Update python/ray/dag/compiled_dag_node.py
Bye-legumes Dec 4, 2024
2dac343
fix
Bye-legumes Dec 4, 2024
3d64351
fix
Bye-legumes Dec 4, 2024
dbb322d
fix
Bye-legumes Dec 4, 2024
1c16049
Merge branch 'master' into add_ascii_visualization
Bye-legumes Dec 6, 2024
9692c6f
fix
Bye-legumes Dec 6, 2024
f801dbb
fix
Bye-legumes Dec 6, 2024
96e6bd4
fix
Bye-legumes Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
356 changes: 277 additions & 79 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2234,28 +2234,45 @@ async def execute_async(
self._execution_index += 1
return fut

def visualize(
self, filename="compiled_graph", format="png", view=False, return_dot=False
):
def visualize(self, format="png", filename="compiled_graph", view=False):
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved
"""
Visualize the compiled graph using Graphviz.

This method generates a graphical representation of the compiled graph,
showing tasks and their dependencies.This method should be called
**after** the graph has been compiled using `experimental_compile()`.
This method provides two modes for visualization:
1. **Graphviz PNG/PDF**: Generates a graphical file representing tasks as
nodes with edges representing dependencies.
2. **ASCII Format**: Prints a detailed text-based visualization of the CG,
including task nodes, types, and edges with type hints.

This method should be called
**after** compiling the graph with `experimental_compile()`.


Args:
format: The output format for the visualization. Options:
- `"png"` (default): Generates a PNG file using Graphviz.
- `"pdf"`: Generates a PDF file using Graphviz.
- `"ascii"`: Prints the CG structure in ASCII format to the console.
filename: The name of the output file (without extension).
format: The format of the output file (e.g., 'png', 'pdf').
view: Whether to open the file with the default viewer.
return_dot: If True, returns the DOT source as a string instead of figure.

Returns:
- **None** if `format` is `"png"` or `"pdf"`.
- **str** if `format` is `"ascii"`,
returning the ASCII representation of the CG.
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved

Raises:
ValueError: If the graph is empty or not properly compiled.
ImportError: If the `graphviz` package is not installed.

Example:
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved
```python
# Visualize the compiled CG in PNG format
compiled_dag.visualize(format='png')

# Print the CG structure in ASCII format
print(compiled_dag.visualize(format='ascii'))
"""
import graphviz
from ray.dag import (
InputAttributeNode,
InputNode,
Expand All @@ -2278,89 +2295,270 @@ def visualize(
f"Task at index {idx} does not have a valid 'dag_node'. "
"Ensure that 'experimental_compile()' completed successfully."
)
if format == "ascii":
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved
from collections import defaultdict, deque
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved

# Create adjacency list representation of the DAG
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved
adj_list = defaultdict(list)
indegree = defaultdict(int)

is_multi_output = defaultdict(bool)
child2parent = defaultdict(int)
ascii_visualization = ""
node_info = {}
edge_info = []

for idx, task in self.idx_to_task.items():
dag_node = task.dag_node
label = f"Task {idx}\n"

# Determine the type and label of the node
if isinstance(dag_node, InputNode):
label += "InputNode"
elif isinstance(dag_node, InputAttributeNode):
label += f"InputAttributeNode[{dag_node.key}]"
elif isinstance(dag_node, MultiOutputNode):
label += "MultiOutputNode"
elif isinstance(dag_node, ClassMethodNode):
if dag_node.is_class_method_call:
method_name = dag_node.get_method_name()
actor_handle = dag_node._get_actor_handle()
actor_id = (
actor_handle._actor_id.hex()[:6]
if actor_handle
else "unknown"
)
label += f"Actor: {actor_id}...\nMethod: {method_name}"
elif dag_node.is_class_method_output:
label += f"ClassMethodOutputNode[{dag_node.output_idx}]"
else:
label += "ClassMethodNode"
else:
label += type(dag_node).__name__

# Dot file for debuging
dot = graphviz.Digraph(name="compiled_graph", format=format)
node_info[idx] = label

# Add nodes with task information
for idx, task in self.idx_to_task.items():
dag_node = task.dag_node
for arg_index, arg in enumerate(dag_node.get_args()):
if isinstance(arg, DAGNode):
upstream_task_idx = self.dag_node_to_idx[arg]

# Get the type hint for this argument
if arg_index < len(task.arg_type_hints):
type_hint = type(task.arg_type_hints[arg_index]).__name__
else:
type_hint = "UnknownType"

adj_list[upstream_task_idx].append(idx)
indegree[idx] += 1
edge_info.append((upstream_task_idx, idx, type_hint))

width_adjust = 0
for upstream_task_idx, child_idx_list in adj_list.items():
# Mark as multi-output if the node has more than one output path
if len(child_idx_list) > 1:
for child in child_idx_list:
is_multi_output[child] = True
child2parent[child] = upstream_task_idx
width_adjust = max(width_adjust, len(child_idx_list))

# Topological sort to determine layers
layers = defaultdict(list)
zero_indegree = deque(
[idx for idx in self.idx_to_task if indegree[idx] == 0]
)
layer_index = 0

while zero_indegree:
next_layer = deque()
while zero_indegree:
task_idx = zero_indegree.popleft()
layers[layer_index].append(task_idx)
for downstream in adj_list[task_idx]:
indegree[downstream] -= 1
if indegree[downstream] == 0:
next_layer.append(downstream)
zero_indegree = next_layer
layer_index += 1

# Print detailed node information
ascii_visualization += "Nodes Information:\n"
for idx, info in node_info.items():
ascii_visualization += f'{idx} [label="{info}"] \n'

# Print edges
ascii_visualization += "\nEdges Information:\n"
for upstream_task, downstream_task, type_hint in edge_info:
ascii_visualization += (
f"{upstream_task} -> {downstream_task} [label={type_hint}]\n"
)

# Initialize the label and attributes
label = f"Task {idx}\n"
shape = "oval" # Default shape
style = "filled"
fillcolor = ""

# Handle different types of dag_node
if isinstance(dag_node, InputNode):
label += "InputNode"
shape = "rectangle"
fillcolor = "lightblue"
elif isinstance(dag_node, InputAttributeNode):
label += f"InputAttributeNode[{dag_node.key}]"
shape = "rectangle"
fillcolor = "lightblue"
elif isinstance(dag_node, MultiOutputNode):
label += "MultiOutputNode"
shape = "rectangle"
fillcolor = "yellow"
elif isinstance(dag_node, ClassMethodNode):
if dag_node.is_class_method_call:
# Class Method Call Node
method_name = dag_node.get_method_name()
actor_handle = dag_node._get_actor_handle()
if actor_handle:
actor_id = actor_handle._actor_id.hex()
label += f"Actor: {actor_id[:6]}...\nMethod: {method_name}"
# Find the maximum width (number of nodes in any layer)
max_width = max(len(layer) for layer in layers.values()) + width_adjust
height = len(layers)

# Build grid for ASCII visualization
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved
grid = [[" " for _ in range(max_width * 20)] for _ in range(height * 2 - 1)]

# Place nodes in the grid with more details
task_to_pos = {}
for layer_num, layer_tasks in layers.items():
layer_y = layer_num * 2 # Every second row is for nodes
for col_num, task_idx in enumerate(layer_tasks):
task = self.idx_to_task[task_idx]
task_info = f"{task_idx}:"

# Determine if it's an actor method or a regular task
if isinstance(task.dag_node, ClassMethodNode):
if task.dag_node.is_class_method_call:
method_name = task.dag_node.get_method_name()
task_info += f"Actor:{method_name}"
elif task.dag_node.is_class_method_output:
task_info += f"Output[{task.dag_node.output_idx}]"
else:
task_info += "UnknownMethod"
else:
label += f"Method: {method_name}"
shape = "oval"
fillcolor = "lightgreen"
elif dag_node.is_class_method_output:
# Class Method Output Node
label += f"ClassMethodOutputNode[{dag_node.output_idx}]"
task_info += type(task.dag_node).__name__

adjust_col_num = 0
if task_idx in is_multi_output:
adjust_col_num = layers[layer_num - 1].index(
child2parent[task_idx]
)
col_x = (
col_num + adjust_col_num
) * 20 # Every 7th column for spacing
# Place the task information into the grid
for i, char in enumerate(task_info):
if col_x + i < len(
grid[0]
): # Ensure we don't overflow the grid
grid[layer_y][col_x + i] = char

task_to_pos[task_idx] = (layer_y, col_x)

# Connect the nodes with lines
for upstream_task, downstream_tasks in adj_list.items():
upstream_y, upstream_x = task_to_pos[upstream_task]
for downstream_task in downstream_tasks:
downstream_y, downstream_x = task_to_pos[downstream_task]

# Draw vertical line
for y in range(upstream_y + 1, downstream_y):
if grid[y][upstream_x] == " ":
grid[y][upstream_x] = "|"

# Draw horizontal line if needed
if upstream_x != downstream_x:
for x in range(
min(upstream_x, downstream_x) + 1,
max(upstream_x, downstream_x),
):
if grid[downstream_y - 1][x] != "|":
grid[downstream_y - 1][x] = "-"

# Draw connection to the next task
grid[downstream_y - 1][downstream_x] = "|"

# Ensure proper multi-output task connection
for idx, task in self.idx_to_task.items():
if isinstance(task.dag_node, MultiOutputNode):
output_tasks = task.dag_node.get_args()
for i, output_task in enumerate(output_tasks):
if isinstance(output_task, DAGNode):
output_task_idx = self.dag_node_to_idx[output_task]
if output_task_idx in task_to_pos:
output_y, output_x = task_to_pos[output_task_idx]
grid[output_y - 1][output_x] = "|"

# Convert grid to string for printing
ascii_visualization += "\nExperimental Graph Built:\n"
ascii_visualization += "\n".join("".join(row) for row in grid)

return ascii_visualization

else:
import graphviz

# Dot file for debuging
dot = graphviz.Digraph(name="compiled_graph", format=format)

# Add nodes with task information
for idx, task in self.idx_to_task.items():
dag_node = task.dag_node

# Initialize the label and attributes
label = f"Task {idx}\n"
shape = "oval" # Default shape
style = "filled"
fillcolor = ""

# Handle different types of dag_node
if isinstance(dag_node, InputNode):
label += "InputNode"
shape = "rectangle"
fillcolor = "lightblue"
elif isinstance(dag_node, InputAttributeNode):
label += f"InputAttributeNode[{dag_node.key}]"
shape = "rectangle"
fillcolor = "orange"
fillcolor = "lightblue"
elif isinstance(dag_node, MultiOutputNode):
label += "MultiOutputNode"
shape = "rectangle"
fillcolor = "yellow"
elif isinstance(dag_node, ClassMethodNode):
if dag_node.is_class_method_call:
# Class Method Call Node
method_name = dag_node.get_method_name()
actor_handle = dag_node._get_actor_handle()
if actor_handle:
actor_id = actor_handle._actor_id.hex()
label += f"Actor: {actor_id[:6]}...\nMethod: {method_name}"
else:
label += f"Method: {method_name}"
shape = "oval"
fillcolor = "lightgreen"
elif dag_node.is_class_method_output:
# Class Method Output Node
label += f"ClassMethodOutputNode[{dag_node.output_idx}]"
shape = "rectangle"
fillcolor = "orange"
else:
# Unexpected ClassMethodNode
label += "ClassMethodNode"
shape = "diamond"
fillcolor = "red"
else:
# Unexpected ClassMethodNode
label += "ClassMethodNode"
# Unexpected node type
label += type(dag_node).__name__
shape = "diamond"
fillcolor = "red"
else:
# Unexpected node type
label += type(dag_node).__name__
shape = "diamond"
fillcolor = "red"

# Add the node to the graph with attributes
dot.node(str(idx), label, shape=shape, style=style, fillcolor=fillcolor)
# Add the node to the graph with attributes
dot.node(str(idx), label, shape=shape, style=style, fillcolor=fillcolor)

# Add edges with type hints based on argument mappings
for idx, task in self.idx_to_task.items():
current_task_idx = idx
# Add edges with type hints based on argument mappings
for idx, task in self.idx_to_task.items():
current_task_idx = idx

for arg_index, arg in enumerate(task.dag_node.get_args()):
if isinstance(arg, DAGNode):
# Get the upstream task index
upstream_task_idx = self.dag_node_to_idx[arg]

# Get the type hint for this argument
if arg_index < len(task.arg_type_hints):
type_hint = type(task.arg_type_hints[arg_index]).__name__
else:
type_hint = "UnknownType"
for arg_index, arg in enumerate(task.dag_node.get_args()):
if isinstance(arg, DAGNode):
# Get the upstream task index
upstream_task_idx = self.dag_node_to_idx[arg]

# Draw an edge from the upstream task to the
# current task with the type hint
dot.edge(
str(upstream_task_idx), str(current_task_idx), label=type_hint
)
# Get the type hint for this argument
if arg_index < len(task.arg_type_hints):
type_hint = type(task.arg_type_hints[arg_index]).__name__
else:
type_hint = "UnknownType"

# Draw an edge from the upstream task to the
# current task with the type hint
dot.edge(
str(upstream_task_idx),
str(current_task_idx),
label=type_hint,
)

if return_dot:
return dot.source
else:
# Render the graph to a file
dot.render(filename, view=view)

def teardown(self, kill_actors: bool = False):
Expand Down
Loading