
PyTorch/XLA 2.6 release

@tengyifei released this on 30 Jan 21:54 (commit 0bb4f6f)

Highlights

Kernel improvements for vLLM: Multi-Queries Paged Attention Pallas Kernel

  • Added the multi-query paged attention Pallas kernel (#8328), which unlocks vLLM features such as prefix caching.
  • Performance improvement: write to HBM only at the last iteration (#8393).
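For readers new to paged attention, the sketch below is a plain-Python reference of what a multi-query paged attention computes. It is not the Pallas kernel and does not model its tiling or HBM writes. The names `gather_kv`, `paged_attention_reference`, and `block_table`, and the causal convention that the queries are the last `len(queries)` tokens of the sequence, are illustrative assumptions. The KV cache is split into fixed-size pages, a per-sequence block table maps logical pages to physical pages, and several query tokens (for example, new tokens appended after a cached prefix) attend over the gathered keys and values.

```python
import math
from typing import List

Vector = List[float]
Page = List[Vector]  # one physical page: page_size consecutive token vectors


def gather_kv(pages: List[Page], block_table: List[int], kv_len: int) -> List[Vector]:
    """Return the first kv_len cached vectors of one sequence in logical order.

    block_table[i] is the index of the physical page that holds logical page i.
    """
    page_size = len(pages[0])
    return [pages[block_table[pos // page_size]][pos % page_size] for pos in range(kv_len)]


def paged_attention_reference(
    queries: List[Vector], k_pages: List[Page], v_pages: List[Page],
    block_table: List[int], kv_len: int,
) -> List[Vector]:
    """Attention for query tokens that are the last len(queries) tokens of a kv_len-long sequence."""
    keys = gather_kv(k_pages, block_table, kv_len)
    values = gather_kv(v_pages, block_table, kv_len)
    q_len, head_dim, v_dim = len(queries), len(queries[0]), len(values[0])
    scale = 1.0 / math.sqrt(head_dim)
    outputs = []
    for i, q in enumerate(queries):
        # Causal mask (assumed convention): query i sits at absolute position
        # kv_len - q_len + i and may attend to every key up to and including it.
        visible = kv_len - q_len + i + 1
        scores = [scale * sum(a * b for a, b in zip(q, keys[j])) for j in range(visible)]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        outputs.append([sum(w * values[j][d] for j, w in enumerate(weights)) / total
                        for d in range(v_dim)])
    return outputs
```

With a block table such as `[1, 0]`, logical tokens 0 and 1 are read from physical page 1 and tokens 2 and 3 from physical page 0. This indirection is what lets sequences share cached prefix pages without copying them.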

Experimental scan operator (#7901)

Previously, when you looped over many nn.Modules of the same structure in PyTorch/XLA, the loop was unrolled during graph tracing, producing giant computation graphs. This unrolling results in long compilation times: up to an hour for large language models with many decoder layers. In this release we offer an experimental API called "scan" to reduce compilation times; it mirrors the jax.lax.scan transform in JAX. When you replace a Python for loop with scan, only the first iteration is compiled instead of every iteration individually, and the compiled HLO is reused for all subsequent iterations. Building on torch_xla.experimental.scan, torch_xla.experimental.scan_layers offers a convenient interface for looping over a sequence of nn.Modules without unrolling.

Documentation: https://pytorch.org/xla/release/r2.6/features/scan.html
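Because scan mirrors jax.lax.scan, its contract can be stated in a few lines of plain Python. `fn(carry, x)` returns `(new_carry, y)`, the carry threads through the iterations, and the per-step `y` values are collected. The sketch below uses hypothetical names (`reference_scan`, `apply_layers`) and shows only these eager semantics on Python lists. It does not trace, compile, or reuse HLO, which is the point of the real API. `apply_layers` illustrates, roughly, how a loop over layers maps onto a scan whose carry is the hidden state.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

Carry = TypeVar("Carry")
X = TypeVar("X")
Y = TypeVar("Y")


def reference_scan(
    fn: Callable[[Carry, X], Tuple[Carry, Y]], init: Carry, xs: Sequence[X]
) -> Tuple[Carry, List[Y]]:
    """Eager reference of the scan contract: thread a carry through fn, collect each y."""
    carry = init
    ys = []
    for x in xs:
        carry, y = fn(carry, x)
        ys.append(y)
    return carry, ys


def apply_layers(layers: Sequence[Callable[[float], float]], hidden: float) -> float:
    """Run layers in order: each layer is one element of xs, the hidden state is the carry."""
    def step(h: float, layer: Callable[[float], float]) -> Tuple[float, None]:
        return layer(h), None  # no per-layer output is collected here

    final, _ = reference_scan(step, hidden, layers)
    return final
```

For example, `reference_scan(lambda c, x: (c + x, 2 * (c + x)), 0, [1, 3, 5])` yields a final carry of `9` and outputs `[2, 8, 18]`. With the real API, the body of `fn` is what gets compiled once and reused.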

C++11 ABI builds

Starting with PyTorch/XLA 2.6, we provide wheels and docker images built with two C++ ABI flavors: C++11 and pre-C++11. Pre-C++11 is the default, to align with upstream PyTorch, but the C++11 ABI wheels and docker images have better lazy tensor tracing performance.

To install C++11 ABI flavored 2.6 wheels (Python 3.10 example):

pip install torch==2.6.0+cpu.cxx11.abi \
  https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl \
  'torch_xla[tpu]' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html \
  -f https://download.pytorch.org/whl/torch

The command above is for Python 3.10; C++11 ABI wheels are also available for Python 3.9 and 3.11.

To use the C++11 ABI flavored docker image:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11
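To confirm which flavor ended up in an environment, one option is to ask PyTorch directly. `torch.compiled_with_cxx11_abi()` reports whether the installed torch build uses the C++11 ABI. The helper names below are hypothetical. The sketch also assumes that torch_xla was installed to match torch's ABI, as the pinned `torch==2.6.0+cpu.cxx11.abi` in the install command suggests.

```python
from typing import Optional


def abi_flavor(uses_cxx11_abi: bool) -> str:
    """Map PyTorch's ABI flag to the flavor names used in these release notes."""
    return "C++11 ABI" if uses_cxx11_abi else "pre-C++11 ABI"


def installed_torch_abi() -> Optional[str]:
    """Return the ABI flavor of the installed torch build, or None if torch is missing."""
    try:
        import torch
    except ImportError:
        return None
    return abi_flavor(torch.compiled_with_cxx11_abi())


if __name__ == "__main__":
    print(installed_torch_abi())
```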

If your model is tracing-bound (e.g. the host CPU is busy tracing the model while the TPUs sit idle), switching to the C++11 ABI wheels or docker images can improve performance. Mixtral 8x7B benchmark results on v5p-256 with a global batch size of 1024:

  • Pre-C++11 ABI MFU: 33%
  • C++11 ABI MFU: 39%
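To put the two MFU figures in perspective: MFU is achieved model FLOPs per second divided by the hardware's peak. For the same model on the same v5p-256 slice, going from 33% to 39% is a gain of 6 percentage points, or roughly 18% more useful throughput. The helper below is only a worked version of that arithmetic.

```python
from typing import Tuple


def mfu_gain(before_pct: float, after_pct: float) -> Tuple[float, float]:
    """Return (absolute gain in percentage points, relative gain as a fraction of before_pct)."""
    absolute = after_pct - before_pct
    return absolute, absolute / before_pct


points, relative = mfu_gain(33.0, 39.0)
print(f"+{points:.0f} points, {relative:.1%} relative")  # prints "+6 points, 18.2% relative"
```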

GPU builds are temporarily skipped in 2.6

We do not offer a PyTorch/XLA:GPU wheel in the PyTorch/XLA 2.6 release. We understand this is important and plan to reinstate GPU support by the 2.7 release. PyTorch/XLA remains an open-source project and we welcome contributions from the community to help maintain and improve the project. To contribute, please start with the contributors guide.

The newest stable release that includes a PyTorch/XLA:GPU wheel is torch_xla 2.5.

Stable Features

Stable libtpu releases

Starting with PyTorch/XLA 2.6, TPU backend support is provided by a stable libtpu Python package. As a result, we expect fewer TPU-specific bugs and better overall test coverage. The libtpu-nightly Python package will be pinned to a special empty version to avoid conflicts. As long as you use our PyTorch/XLA docker images or follow the latest installation instructions in the README.md, no action is needed on your part: the right dependencies will be installed.
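As a quick sanity check that an environment picked up the stable package, Python's standard `importlib.metadata` can report which of the two distributions is installed. The helper name `tpu_runtime_versions` is hypothetical. The distribution names `libtpu` and `libtpu-nightly` are the package names mentioned above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Sequence


def tpu_runtime_versions(
    names: Sequence[str] = ("libtpu", "libtpu-nightly"),
) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None when it is absent."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in tpu_runtime_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

On a setup that follows the new instructions, `libtpu` should report a version. `libtpu-nightly`, if present at all, should be the pinned placeholder rather than a real nightly build.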

GSPMD

  • [LoweringContext] Support an optimized parameter mapping for SPMD (#8460)
  • [LoweringContext] SPMD propagation (#8471): the computation's sharding specs are now deduced from its inputs (scoped to parameter creation), and the input shardings are propagated to the outputs.

AMP

  • Add autocast support for einsum #8420
  • Add autocast support for XlaPatchedLinear #8421
  • Support S32/U32 indices for BWD embedding & Neuron implicit downcast #8462

Bug fixes

  • Fixed the "undefined symbol: _ZN5torch4lazy13MetricFnValueB5cxx11E" error with the torch-xla 2.6 nightly wheel (#8406)

Experimental Features

Support for host offloading (#8350, #8477)

When doing reverse-mode automatic differentiation, many tensors are saved during the forward pass to be used to compute the gradient during the backward pass. Previously you could use torch_xla.utils.checkpoint to discard tensors that are easy to recompute later, a technique called "checkpointing" or "rematerialization". PyTorch/XLA now also supports "host offloading", which moves tensors to host memory and later back to the device, adding another tool for saving memory. Use torch_xla.experimental.stablehlo_custom_call.place_to_host to move a tensor to the host and torch_xla.experimental.stablehlo_custom_call.place_to_device to move it back to the device. For example, you can move intermediate activations to the host during the forward pass and move them back to the device during the corresponding backward pass.

Because the XLA graph compiler aggressively reorders operations, host offloading is best used in combination with scan.
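The offloading pattern described above can be simulated in plain Python to show why it lowers peak device memory. Each saved activation is moved off the device during the forward pass and brought back in the matching backward step. This is a toy bookkeeping model, not PyTorch/XLA code. `OffloadingStore` and `run` are hypothetical, and in a real model the two moves would be the `place_to_host` and `place_to_device` calls named above. The model only counts how many saved activations are resident on the device at once, ignoring transfer time and overlap.

```python
from typing import Dict, List


class OffloadingStore:
    """Toy model of where saved activations live between the forward and backward passes."""

    def __init__(self, offload: bool):
        self.offload = offload
        self.device: Dict[int, List[float]] = {}
        self.host: Dict[int, List[float]] = {}
        self.peak_device_entries = 0

    def _track(self) -> None:
        self.peak_device_entries = max(self.peak_device_entries, len(self.device))

    def save(self, layer: int, activation: List[float]) -> None:
        # The activation is produced on the device first.
        self.device[layer] = activation
        self._track()
        if self.offload:
            # Stand-in for place_to_host: move it off the device right away.
            self.host[layer] = self.device.pop(layer)

    def load(self, layer: int) -> List[float]:
        if self.offload:
            # Stand-in for place_to_device: bring it back just before it is needed.
            self.device[layer] = self.host.pop(layer)
            self._track()
        return self.device.pop(layer)


def run(num_layers: int, offload: bool) -> int:
    """Simulate one forward and backward pass; return the peak count of device-resident saves."""
    store = OffloadingStore(offload)
    x = [1.0]
    for layer in range(num_layers):  # forward pass: save each layer's input
        store.save(layer, x)
        x = [v * 2.0 for v in x]
    for layer in reversed(range(num_layers)):  # backward pass, in reverse layer order
        saved = store.load(layer)
        assert saved == [2.0 ** layer]  # sanity check: the right activation comes back
    return store.peak_device_entries
```

Without offloading, all `num_layers` saved activations sit on the device at the end of the forward pass. With offloading, at most one does at a time. The cost is host-device traffic rather than the extra recomputation that checkpointing incurs.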

Updates to Flash Attention kernels

Support SegmentID in FlashAttention when doing data parallel SPMD #8425

Deprecations

See the Backward Compatibility proposal.

APIs that will be removed in the 2.7 release:

  1. Deprecate APIs (deprecated → new):
    • xla_model.xrt_world_size() → runtime.world_size() [#7679][#7743]
    • xla_model.get_ordinal() → runtime.global_ordinal() [#7679]
    • xla_model.get_local_ordinal() → runtime.local_ordinal() [#7679]
  2. Internalize APIs:
    • xla_model.parse_xla_device() [#7675]
  3. Improvements:
    • Automatic PJRT device detection when importing torch_xla [#7787]
  4. Add a deprecated decorator [#7703]
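For codebases that call the deprecated functions through the common `xm` alias, a mechanical rewrite can handle much of the migration. The sketch below assumes the conventional imports `import torch_xla.core.xla_model as xm` and `import torch_xla.runtime as xr`. The `migrate` helper and its regex approach are illustrative, not a tool shipped with PyTorch/XLA. It does not add the `xr` import for you and covers only the first two renames. Other calls can be added to `RENAMES` the same way.

```python
import re

# Deprecated xm.* calls and their torch_xla.runtime (xr) replacements.
RENAMES = {
    "xrt_world_size": "world_size",
    "get_ordinal": "global_ordinal",
}

# Match "xm.<name>(" only when "xm" is a standalone identifier.
_CALL = re.compile(r"\bxm\.(" + "|".join(RENAMES) + r")\(")


def migrate(source: str) -> str:
    """Rewrite xm.<deprecated>( calls to xr.<replacement>( in a source string."""
    return _CALL.sub(lambda m: f"xr.{RENAMES[m.group(1)]}(", source)
```

For example, `migrate("n = xm.xrt_world_size()")` returns `"n = xr.world_size()"`, and calls that are not in the table are left untouched.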

APIs that will be removed in the 2.8 release:

  1. The XLA_USE_BF16 environment variable is deprecated. Please convert your model to bf16 directly [#8474].
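Instead of setting XLA_USE_BF16, the usual PyTorch idiom is to cast the model with `Module.to(torch.bfloat16)`, which converts floating-point parameters and buffers and leaves integer ones alone. You would then create inputs in that dtype. The sketch below pairs that cast with a check for the deprecated variable. The helper names are hypothetical, and treating any value other than empty or "0" as "enabled" is an assumption about how the flag is typically set.

```python
import os
import warnings
from typing import Mapping


def deprecated_bf16_flag_set(env: Mapping[str, str] = os.environ) -> bool:
    """True when XLA_USE_BF16 looks enabled (assumption: any value other than '' or '0')."""
    return env.get("XLA_USE_BF16", "0") not in ("", "0")


def to_bf16(model):
    """Cast a torch.nn.Module's floating-point parameters and buffers to bfloat16."""
    import torch  # imported lazily so the flag check above works without torch installed

    if deprecated_bf16_flag_set():
        warnings.warn("XLA_USE_BF16 is deprecated; cast the model to bf16 explicitly instead.")
    return model.to(torch.bfloat16)
```

Typical use would be `model = to_bf16(model)`, together with inputs created as, e.g., `torch.randn(8, 16, dtype=torch.bfloat16)`.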