
[torchax][RFC] Anchor on the device API everywhere #8638

Open · tengyifei opened this issue Jan 28, 2025 · 3 comments

@tengyifei (Collaborator)
🚀 Feature

Today we have two related concepts in torchax:

  • Environment
  • The "jax" device

In particular, the user needs to know both in order to access JAX (e.g. TPU) features.

This RFC proposes that we refactor the API so that the user only needs to know the "jax" device in order to operate on TPUs.

Motivation

The device is a well-known concept in PyTorch: tutorials talk about tensor.cpu() and tensor.cuda(). It's commonly understood that if I create a tensor without the device argument, that tensor lives on the default device (usually CPU), and if I call .cuda(), the tensor is moved to the GPU.

Since people primarily choose torchax to be able to use the TPU (or access the XLA GPU backend), it makes sense to present this functionality as a PyTorch device. By the principle of symmetry, it's natural to introduce jax counterparts for various cuda APIs, where applicable. This gives people a clear mental model of when they are or aren't using JAX.

Let's look at a few examples (all of these assume we have imported torchax; a sketch of the proposed API follows the list):

  • I can call torch.cuda.current_device() to get the index of the current CUDA device.

    • I can also call torch.jax.current_device() to get the index of the current JAX (XLA) device.
  • I can call torch.cuda.is_available() to check if CUDA support is available.

    • I can call torch.jax.is_available() to check if the JAX backend is available.
  • I can run torch.randn(1, 2, device='cuda') to generate random numbers using the CUDA device.

    • ❌ If I run torch.randn(1, 2, device='jax'), it fails with a confusing dispatcher error: 1
  • I can run torch.set_default_device('cuda') to make all subsequent tensors live on the CUDA device.

    • ❌ If I run torch.set_default_device('jax') and then create some tensor, it fails with another confusing error: 2
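
To make the intended symmetry concrete, here is a hypothetical sketch of the proposed torch.jax namespace mirroring torch.cuda. None of the torch.jax.* calls below exist today; they are the API this RFC proposes:

import torch
import torchax  # registers the 'jax' device; under this proposal, no further setup is needed

if torch.jax.is_available():             # proposed counterpart of torch.cuda.is_available()
    idx = torch.jax.current_device()     # proposed counterpart of torch.cuda.current_device()
    x = torch.randn(1, 2, device='jax')  # should work, mirroring device='cuda'
    torch.set_default_device('jax')      # subsequent tensors land on the JAX device
    y = torch.zeros(3)                   # lives on 'jax' without an explicit device= argument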

This RFC proposes that we change torchax to close behavior divergences such as the ones in the two ❌ bullet points above. In the limit, using eager torchax should feel identical to using any other PyTorch backend.

Pitch

Always call enable_globally()

We're pretty close to closing the gaps above. If I run torchax.enable_globally() after importing torchax, then torch.randn(1, 2, device='jax') works, and the error after torch.set_default_device('jax') looks like a fixable bug. I propose we go one step further: call enable_globally() automatically on import (sketched below), and also fix the default device behavior.
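
A minimal sketch of import-time activation, assuming enable_globally() is cheap and idempotent; the internal import path shown is an assumption for illustration, not torchax's actual layout:

# torchax/__init__.py (proposed)
from torchax._core import enable_globally  # hypothetical internal location

# Activate the 'jax' device and the torchax modes at import time,
# so users never have to call enable_globally() themselves.
enable_globally()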

Always keep the torchax modes activated

Today, the environment object is what activates the XLAFunctionMode and XLADispatchMode that intercept PyTorch operations. However, these modes are an implementation detail of how torchax supports the JAX device. It should be possible to keep XLAFunctionMode and XLADispatchMode activated in the mode stack at all times without changing the behavior of non-JAX tensors. This is akin to how PyTorch already keeps a few modes, such as FuncTorchVmapMode and FuncTorchDynamicLayerFrontMode, in the stack most of the time. For testing purposes it could be useful to temporarily disable XLAFunctionMode and XLADispatchMode, but that should be an internal API that users don't know about.

As a pressure test, we could try running a subset of the PyTorch test suite with XLA{Function,Dispatch}Mode in the mode stack and make sure those tests don't fail. That would ensure that even if the user imports torchax, the behavior of their CPU tensors doesn't change.

This suggests we need to decouple XLAFunctionMode and XLADispatchMode from the environment. For example, perhaps those could be relocated to a torchax._internal.XLAModes context manager, as in the sketch below.
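
A minimal sketch of that context manager, assuming XLAFunctionMode and XLADispatchMode remain torch function/dispatch mode subclasses; the import path below is an assumption:

import contextlib

from torchax.tensor import XLAFunctionMode, XLADispatchMode  # assumed location

@contextlib.contextmanager
def XLAModes(env):
    # Keep both torchax modes on the PyTorch mode stack for the duration
    # of the scope. Non-JAX tensors should pass through them unchanged.
    with XLAFunctionMode(env), XLADispatchMode(env):
        yield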

Configuration context managers

The environment object also holds certain configuration (e.g. optimize for performance or accuracy). As a user, it's sometimes useful to change these settings. We can keep them in the environment and always provide sensible defaults in the default environment. We could also support a stack of environments via context managers, where configuration at the top of the stack takes precedence. That's a useful way to locally change some config and have it revert to the previous values when leaving the scope, as in the sketch below.
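
A minimal sketch of such a configuration stack; the Config fields and function names here are illustrative, not torchax's actual API:

import contextlib
import dataclasses

@dataclasses.dataclass
class Config:
    prefer_accuracy: bool = False  # example knob: accuracy vs. performance

_config_stack = [Config()]  # the bottom of the stack holds the defaults

def current_config() -> Config:
    return _config_stack[-1]  # the top of the stack takes precedence

@contextlib.contextmanager
def override_config(**changes):
    # Push a copy with local overrides; pop it (reverting) on scope exit.
    _config_stack.append(dataclasses.replace(current_config(), **changes))
    try:
        yield
    finally:
        _config_stack.pop()

# Usage: locally prefer accuracy, then revert automatically.
with override_config(prefer_accuracy=True):
    assert current_config().prefer_accuracy
assert not current_config().prefer_accuracy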

RNG seed

The environment object also holds a seed for the pseudo random number generator. That should probably change as part of solving #8636.

Alternatives

An alternative is to do nothing and stick to the status quo.

Another follow-up is to see what torch changes we need in order to reduce friction when using the JAX device. For example, today if I write tensor.jax() hoping to move the tensor to the JAX device, the Python type checker complains that jax() is not a known method on Tensor, unlike cuda().

Additional context

Anecdotally, some people have asked why they had to create an environment to use a torchax tensor, and they didn't understand the error message when the environment was missing.

@tengyifei (Collaborator, Author)
@qihqi here's my proposal

@tengyifei (Collaborator, Author)
With https://pytorch.org/tutorials/prototype/python_extension_autoload.html we might even be able to have torch.jax load automatically without any imports.
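
For reference, that tutorial has the out-of-tree package register an entry point in the "torch.backends" group, which torch invokes when it is imported. A hypothetical sketch of what this could look like for torchax (all names below are assumptions):

# setup.py (hypothetical)
from setuptools import setup

setup(
    name="torchax",
    entry_points={
        "torch.backends": [
            "torchax = torchax:_autoload",  # torch calls this entry point on import
        ],
    },
)

# torchax/__init__.py (hypothetical)
def _autoload():
    # Register the 'jax' device and enable torchax globally, so that
    # `import torch` alone is enough to use device='jax'.
    enable_globally()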

@tengyifei (Collaborator, Author)
Eventually the user should be able to just

import torch
a = torch.randn(100, device='jax')
print(a)

and it just works, with no other steps needed.
