Torch XLA Model all_gather does not work with tensors of different sizes along dimension 0 #8660

Open
ajayvohra2005 opened this issue Jan 31, 2025 · 1 comment


ajayvohra2005 commented Jan 31, 2025

🐛 Bug

Torch XLA Model all_gather works with tensors of the same size along dim=0, but if the tensor sizes differ along dim=0, it hangs.

To Reproduce

Save this code in test_all_gather.py

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend as xb
import torch.distributed


def test_all_gather():

    same = [512, 512, 512, 512, 512, 512, 512, 512]

    different = [416, 536, 560, 544, 576, 512, 592, 360]
    torch.distributed.init_process_group(backend="xla", init_method="xla://")       

    rank = torch.distributed.get_rank()
    device = xm.xla_device()
    input = torch.randn((same[rank], 16), dtype=torch.float32, device=device)
    
    all_inputs = xm.all_gather(input, dim=0, groups=[[0,1,2,3,4,5,6,7]], pin_layout=False)
    print(f"!!!!!! rank: {rank}, all_inputs: {all_inputs}")
    
    input = torch.randn((different[rank], 16), dtype=torch.float32, device=device)
    
    all_inputs = xm.all_gather(input, dim=0, groups=[[0,1,2,3,4,5,6,7]], pin_layout=False)
    
    print(f"!!!!!! rank: {rank}, all_inputs: {all_inputs}")
    torch.distributed.destroy_process_group()
    
if __name__ == "__main__":
    test_all_gather()

Run with:

torchrun --nproc_per_node=8 test_all_gather.py

Expected behavior

all_gather should gather the tensors from all devices along dim=0, even when their sizes along dim=0 differ.

Environment

Docker image
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4

Additional context

According to the documentation for all_gather (https://pytorch.org/docs/stable/distributed.html), uneven tensor sizes are supported.

zpcore (Collaborator) commented Feb 3, 2025

Thanks for pointing this out. Currently we don't support dist ops with uneven sizes; XLA needs to know the gathered output tensor's shape beforehand. One workaround is to pad the tensors to the same size before calling all_gather.
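
A minimal sketch of that padding workaround (the helper name all_gather_uneven and the zero-padding choice are mine, not part of torch_xla): gather every rank's length first, pad each local tensor to the common maximum along dim, all_gather the padded tensors, then strip the padding off.

import torch
import torch_xla.core.xla_model as xm

def all_gather_uneven(tensor, dim=0, groups=None, pin_layout=False):
    device = tensor.device
    # Share every rank's length along `dim` so each rank can unpad later.
    local_len = torch.tensor([tensor.size(dim)], dtype=torch.int32, device=device)
    all_lens = xm.all_gather(local_len, dim=0, groups=groups, pin_layout=pin_layout)
    lens = [int(n) for n in all_lens.cpu()]  # materialize the lengths on host
    max_len = max(lens)

    # Zero-pad the local tensor along `dim` up to the common maximum length.
    pad = max_len - tensor.size(dim)
    if pad > 0:
        pad_shape = list(tensor.shape)
        pad_shape[dim] = pad
        tensor = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=dim)

    # Every rank now contributes the same shape, so all_gather no longer hangs.
    gathered = xm.all_gather(tensor, dim=dim, groups=groups, pin_layout=pin_layout)

    # Split back into equal-sized per-rank chunks and drop the padding rows.
    chunks = gathered.split(max_len, dim=dim)
    return torch.cat([c.narrow(dim, 0, n) for c, n in zip(chunks, lens)], dim=dim)

In the repro above, each rank would call all_gather_uneven(input, dim=0, groups=[[0,1,2,3,4,5,6,7]], pin_layout=False) in place of the second xm.all_gather call.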

By the way, we now support calling the dist ops directly, e.g. dist.all_gather(output_tensor, input, None):

import torch.distributed as dist  # the repro script above already imports torch, xm, and xr

output_tensor = [
    torch.empty((512, 16), dtype=torch.float32).to(device)
    for _ in range(xr.world_size())
]
all_inputs = torch.randn((512, 16), dtype=torch.float32, device=device)
dist.all_gather(output_tensor, all_inputs)

We also suggest launching the worker processes with the launch helper:

def launch(

This should handle both a torchrun invocation and xmp.spawn.
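
A hedged sketch of that launch path, assuming the helper being referenced is torch_xla.launch from recent torch_xla releases (the callable receives the process index, as with xmp.spawn):

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend  # registers the "xla" backend

def _mp_fn(index):
    dist.init_process_group(backend="xla", init_method="xla://")
    device = xm.xla_device()
    local = torch.randn((512, 16), dtype=torch.float32, device=device)
    outputs = [torch.empty((512, 16), dtype=torch.float32, device=device)
               for _ in range(xr.world_size())]
    dist.all_gather(outputs, local)
    print(f"rank {dist.get_rank()}: gathered {len(outputs)} tensors")
    dist.destroy_process_group()

if __name__ == "__main__":
    # Spawns one process per device, or reuses the processes already created
    # when the script is started via torchrun.
    torch_xla.launch(_mp_fn)

With this entry point, both python test_launch.py and torchrun --nproc_per_node=8 test_launch.py (hypothetical file name) should start the workers the same way.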
