Highlights
Kernel improvements for vLLM: Multi-Queries Paged Attention Pallas Kernel
- Added the multi-queries paged attention Pallas kernel (#8328). This unlocks opportunities in vLLM such as prefix caching.
- Perf improvement: only write to HBM at the last iteration (#8393)
Experimental scan operator (#7901)
Previously, when you looped over many `nn.Module`s of the same structure in PyTorch/XLA, the loop was unrolled during graph tracing, leading to giant computation graphs. This unrolling results in long compilation times, up to an hour for large language models with many decoder layers. In this release we offer an experimental API called `scan` to reduce compilation times, mirroring the `jax.lax.scan` transform in JAX. When you replace a Python for loop with `scan`, instead of compiling every iteration individually, only the first iteration is compiled and the compiled HLO is reused for all subsequent iterations. Building upon `torch_xla.experimental.scan`, `torch_xla.experimental.scan_layers` offers a convenient interface for looping over a sequence of `nn.Module`s without unrolling.
Documentation: https://pytorch.org/xla/release/r2.6/features/scan.html
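For example, a stack of identically structured layers can be applied without unrolling. The following is a minimal sketch, assuming `scan_layers` takes the list of layers and the initial input; the `DecoderLayer` module, its sizes, and the layer count are made up for illustration:

```python
import torch
import torch.nn as nn
import torch_xla
from torch_xla.experimental.scan_layers import scan_layers

class DecoderLayer(nn.Module):  # hypothetical layer, for illustration only
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))

device = torch_xla.device()
layers = [DecoderLayer(128).to(device) for _ in range(16)]
x = torch.randn(4, 128, device=device)

# Instead of `for layer in layers: x = layer(x)`, which would trace and
# compile 16 unrolled iterations, scan_layers compiles one iteration and
# reuses the compiled HLO for every layer.
out = scan_layers(layers, x)
torch_xla.sync()
```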
C++11 ABI builds
Starting from PyTorch/XLA 2.6, we'll provide wheels and docker images built with two C++ ABI flavors: C++11 and pre-C++11. Pre-C++11 is the default to align with PyTorch upstream, but C++11 ABI wheels and docker images have better lazy tensor tracing performance.
To install C++11 ABI flavored 2.6 wheels (Python 3.10 example):
pip install torch==2.6.0+cpu.cxx11.abi \
https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl \
'torch_xla[tpu]' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html \
-f https://download.pytorch.org/whl/torch
The above command works for Python 3.10. Wheels are available for Python 3.9, 3.10, and 3.11:
- 3.9: https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp39-cp39-manylinux_2_28_x86_64.whl
- 3.10: https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl
- 3.11: https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp311-cp311-manylinux_2_28_x86_64.whl
To access the C++11 ABI flavored docker image:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11
If your model is tracing bound (e.g. you see that the host CPU is busy tracing the model while TPUs are idle), switching to the C++11 ABI wheels/docker images can improve performance. Mixtral 8x7B benchmarking results on v5p-256, global batch size 1024:
- Pre-C++11 ABI MFU: 33%
- C++11 ABI MFU: 39%
GPU builds are temporarily skipped in 2.6
We do not offer a PyTorch/XLA:GPU wheel in the PyTorch/XLA 2.6 release. We understand this is important and plan to reinstate GPU support by the 2.7 release. PyTorch/XLA remains an open-source project and we welcome contributions from the community to help maintain and improve the project. To contribute, please start with the contributors guide.
The most recent stable release with a PyTorch/XLA:GPU wheel is torch_xla 2.5.
Stable Features
Stable libtpu releases
Starting from PyTorch/XLA 2.6, TPU backend support will be provided by a stable `libtpu` Python package, which means we expect fewer TPU-specific bugs and improved test coverage overall. The `libtpu-nightly` Python package will be pinned to a special empty version to avoid conflicts. As long as you use our PyTorch/XLA docker images or follow the latest installation instructions in the README.md, no action is needed on your part and the right dependencies will be installed.
GSPMD
- [LoweringContext] Support an optimized parameter mapping for SPMD (#8460)
- [LoweringContext] SPMD propagation (#8471): this ensures that the computation has the sharding specs deduced from the inputs (scoped to the creation of the parameters) and propagates the input shardings to the output.
AMP
- Add autocast support for einsum #8420 (see the sketch after this list)
- Add autocast support for XlaPatchedLinear #8421
- Support S32/U32 indices for BWD embedding & Neuron implicit downcast #8462
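A minimal sketch of the einsum case; the shapes are arbitrary, and bfloat16 is passed explicitly as the autocast dtype:

```python
import torch
import torch_xla

device = torch_xla.device()
a = torch.randn(8, 16, device=device)
b = torch.randn(16, 32, device=device)

# einsum now participates in autocast on XLA devices, so it runs in
# bfloat16 inside this region instead of staying in float32.
with torch.autocast("xla", dtype=torch.bfloat16):
    c = torch.einsum("ij,jk->ik", a, b)

print(c.dtype)  # torch.bfloat16
```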
Bug fixes
- Getting "undefined symbol: _ZN5torch4lazy13MetricFnValueB5cxx11E" with torch-xla nightly wheel for 2.6 #8406
Experimental Features
Support for host offloading (#8350, #8477)
When doing reverse-mode automatic differentiation, many tensors are saved during the forward pass so they can be used to compute gradients during the backward pass. Previously you could use `torch_xla.utils.checkpoint` to discard tensors that are easy to recompute later, a technique called "checkpointing" or "rematerialization". Now PyTorch/XLA also supports "host offloading", i.e. moving tensors to the host and back, adding another tool to the arsenal for saving memory. Use `torch_xla.experimental.stablehlo_custom_call.place_to_host` to move a tensor to the host and `torch_xla.experimental.stablehlo_custom_call.place_to_device` to move it back to the device. For example, you can use this to move intermediate activations to the host during the forward pass and move them back to the device during the corresponding backward pass.
Because the XLA graph compiler aggressively reorders operations, host offloading is best used in combination with `scan`.
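A minimal, hypothetical sketch of the pattern; the `OffloadedSin` autograd function below is not a library API, it only illustrates pairing `place_to_host` in the forward pass with `place_to_device` in the backward pass:

```python
import torch
import torch_xla
from torch_xla.experimental.stablehlo_custom_call import (
    place_to_host,
    place_to_device,
)

class OffloadedSin(torch.autograd.Function):
    """Hypothetical example op: computes sin(x) while keeping the saved
    activation in host memory between forward and backward."""

    @staticmethod
    def forward(ctx, x):
        # Annotate the saved activation so XLA moves it to the host.
        ctx.save_for_backward(place_to_host(x))
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x_host,) = ctx.saved_tensors
        x = place_to_device(x_host)  # bring the activation back to the device
        return grad_output * torch.cos(x)

device = torch_xla.device()
x = torch.randn(1024, 1024, device=device, requires_grad=True)
loss = OffloadedSin.apply(x).sum()
loss.backward()
torch_xla.sync()
```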
Updates to Flash Attention kernels
- Support SegmentID in FlashAttention when doing data parallel SPMD #8425
Deprecations
See Backward Compatibility proposal.
APIs that will be removed in the 2.7 release:
- Deprecate APIs (deprecated → new; migration sketch below):
  - `xla_model.xrt_world_size()` → `runtime.world_size()` [#7679][#7743]
  - `xla_model.get_ordinal()` → `runtime.global_ordinal()` [#7679]
  - `xla_model.get_local_ordinal()` → `runtime.local_ordinal()` [#7679]
- Internalize APIs:
  - `xla_model.parse_xla_device()` [#7675]
- Improvement:
  - Automatic PJRT device detection when importing `torch_xla` [#7787]
- Add deprecated decorator [#7703]
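A minimal migration sketch for the renamed runtime APIs, assuming a script that previously queried rank and world size through `xla_model`:

```python
import torch_xla.runtime as xr

# Before (deprecated, removed in the 2.7 release):
#   import torch_xla.core.xla_model as xm
#   world_size = xm.xrt_world_size()
#   rank = xm.get_ordinal()
#   local_rank = xm.get_local_ordinal()

# After:
world_size = xr.world_size()
rank = xr.global_ordinal()
local_rank = xr.local_ordinal()
```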
APIs that will be removed in the 2.8 release:
- The `XLA_USE_BF16` environment variable is deprecated. Please convert your model to bf16 directly (see the sketch below) [#8474]
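A minimal sketch of the suggested migration; the module below is a placeholder, and nothing about the cast is PyTorch/XLA-specific:

```python
import torch
import torch.nn as nn

# Before: relying on XLA_USE_BF16=1 to implicitly downcast all floats.
# After: cast the model (and its inputs) to bfloat16 explicitly.
model = nn.Linear(128, 128).to(torch.bfloat16)
x = torch.randn(4, 128, dtype=torch.bfloat16)
y = model(x)
```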