Skip to content

Commit

Permalink
[torch.compile] fix sym_tensor_indices (vllm-project#12191)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Jan 20, 2025
1 parent df450aa commit 51ef828
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,9 +624,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
]

# index of tensors that have symbolic shapes (batch size)
# for weights and static buffers, they will have concrete shapes.
# symbolic shape only happens for input tensors.
from torch.fx.experimental.symbolic_shapes import is_symbolic
self.sym_tensor_indices = [
i for i, x in enumerate(fake_args)
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
any(is_symbolic(d) for d in x.size())
]

# compiler managed cudagraph input buffers
Expand Down

0 comments on commit 51ef828

Please sign in to comment.