Make partition_spec as tuple to fix issue #8522 (#8613)
vealocia authored Jan 25, 2025
1 parent aef5f6e commit 79e4e72
Showing 1 changed file with 1 addition and 1 deletion.
torch_xla/experimental/spmd_fully_sharded_data_parallel.py (1 addition, 1 deletion)
@@ -20,7 +20,7 @@ def _prepare_spmd_partition_spec(param,
   partition_spec = [None] * len(shape)
   # Skip scalar tensors and it replicated.
   if len(partition_spec) == 0:
-    return partition_spec
+    return tuple(partition_spec)
 
   # Shard the 0th dimension of the parameter according to the
   # fsdp axis of the mesh, if shard_maximal is not specified.
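The change is small but easy to misread: _prepare_spmd_partition_spec builds its partition spec as a list, and before this commit the early-return path for scalar (0-dimensional) parameters handed that list back directly, while the rest of the function presumably returns a tuple-typed spec. The snippet below is a minimal sketch of the behavior after the fix, not the real torch_xla code: the function name, the bare shape argument, and the hard-coded "fsdp" axis are simplified stand-ins for the actual mesh and shard_maximal logic.

# Minimal sketch (not the torch_xla source): shows how the scalar early-return
# now matches the tuple type used for non-scalar parameters. The `shape`
# argument and the hard-coded "fsdp" axis are simplified assumptions.
from typing import Optional, Tuple


def _prepare_spmd_partition_spec_sketch(
    shape: Tuple[int, ...]) -> Tuple[Optional[str], ...]:
  partition_spec = [None] * len(shape)

  # Scalar tensors have no dimension to shard, so they are simply replicated.
  # Returning tuple(...) instead of the bare list is the fix in this commit.
  if len(partition_spec) == 0:
    return tuple(partition_spec)

  # Non-scalar case: shard the 0th dimension along the "fsdp" mesh axis
  # (stand-in for the real shard_maximal handling).
  partition_spec[0] = "fsdp"
  return tuple(partition_spec)


# Both code paths now yield tuples:
assert _prepare_spmd_partition_spec_sketch(()) == ()
assert _prepare_spmd_partition_spec_sketch((128, 64)) == ("fsdp", None)

With the old list return, any downstream consumer expecting a tuple-typed (and therefore hashable) partition spec would hit a type mismatch only for scalar parameters, which is presumably the failure mode reported in issue #8522.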
