🚀 Feature
It's worth providing some function to capture the model constructor (i.e., the torch ops that generate the random weights that make up the parameters of a model) as one StableHLO graph, and run that on accelerator devices.
Motivation
The primary motivation is to more closely match PyTorch eager UX during SPMD training.
Today, in order to initialize a large model on e.g. 256 TPUs, we randomly initialize every layer and then send each layer to the TPUs following a sharding spec:
(see xla/torchax/examples/train_llama_torchtitan/train_llama.py, lines 193 to 211 at commit 8e6ca60)
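Roughly, the pattern looks like the following hedged sketch (not the exact code from train_llama.py; it assumes an eagerly constructed `model` and an illustrative 1-D `"fsdp"` mesh):

```python
# Sketch of today's flow: init weights eagerly on the host, then shard each
# one onto the TPU mesh per a sharding spec.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("fsdp",))  # e.g. 256 TPU devices

def shard_weight(torch_tensor, spec):
    # Host-resident torch tensor -> device-resident, sharded jax array.
    return jax.device_put(jnp.asarray(torch_tensor.numpy()),
                          NamedSharding(mesh, spec))

sharded = {
    name: shard_weight(param.detach(), P("fsdp"))
    for name, param in model.named_parameters()  # `model` built eagerly on CPU
}
```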
This works but has some drawbacks:
- We're initializing the weights with `torch.randn`, but eager PyTorch initializes the weights with a variety of different distributions. When I tested training a Llama model with `randn` (Gaussian distributed) weights, the loss at step 0 was 10x larger than what eager PyTorch gives us to start with.
- We could probably work around this in the near term by having the user specify a dictionary of `module: initializer_fn` mappings (see the sketch after this list). But that's more code over eager PyTorch, and a cost that users pay without corresponding gains. In comparison, SPMD sharding annotations are a cost users pay to get automatic collectives/sharding propagation.
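A hedged sketch of what such a user-specified mapping could look like; the `INITIALIZERS` name and `apply_initializers` helper are hypothetical, though the per-module defaults below mirror eager PyTorch's `reset_parameters`:

```python
# Hypothetical module -> initializer_fn mapping; not an existing torchax API.
import math
import torch.nn as nn

INITIALIZERS = {
    # nn.Linear's eager default: kaiming_uniform_ with a=sqrt(5).
    nn.Linear: lambda m: nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)),
    nn.Embedding: lambda m: nn.init.normal_(m.weight),
    nn.LayerNorm: lambda m: nn.init.ones_(m.weight),
}

def apply_initializers(model: nn.Module):
    # Re-run the matching initializer for every submodule that has one.
    for module in model.modules():
        init_fn = INITIALIZERS.get(type(module))
        if init_fn is not None:
            init_fn(module)
```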
Pitch
In PyTorch/XLA, I could do this:
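A minimal sketch of that flow, assuming SPMD mode, an illustrative `MyModel`, and a made-up 1-D mesh (the mesh shape and partition specs are assumptions, not the exact code from the issue):

```python
# Sketch: build the model directly on the XLA device so the init ops are
# traced rather than executed eagerly, annotate shardings, then sync once.
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("fsdp",))

with torch.device("xla"):
    model = MyModel()  # hypothetical module; weights are lazy XLA tensors

for param in model.parameters():
    # Shard dim 0 across the "fsdp" axis; replicate the remaining dims.
    spec = ("fsdp",) + (None,) * (param.dim() - 1)
    xs.mark_sharding(param, mesh, spec)

torch_xla.sync()  # compile and run the init graph; each TPU fills its shard
```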
The above will compile down to a graph that outputs a bunch of model weights, and the outputs (weights) have sharding annotations on them. When this graph is executed, each TPU will initialize the shard of the model weights that it is responsible for. Note that in this case the only extra cost (besides the `torch_xla.sync()`) is the `mark_sharding` for SPMD sharding annotations, reflecting a principle of "only pay for what you use".

We propose looking into the feasibility of this sort of feature in torchax. For example, could we run the model constructor under some sort of `compile` function or `jit` context manager, where all the model weights are tracers, and we lower them into StableHLO?
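To make the goal concrete, here is a plain-JAX illustration of the end state (not a torchax API; the shapes and weight names are made up): the whole weight-initialization function traces into a single StableHLO module.

```python
# Plain JAX sketch of "model constructor as one StableHLO graph".
import jax

def init_weights(key):
    k1, k2 = jax.random.split(key)
    return {
        "wq": jax.random.normal(k1, (4096, 4096)),
        "embed": jax.random.normal(k2, (32000, 4096)),
    }

lowered = jax.jit(init_weights).lower(jax.random.key(0))
print(lowered.as_text())  # one StableHLO module whose outputs are the weights
```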
Alternatives
An alternative is to maintain a comprehensive `module: initializer_fn` mapping inside torchax itself (along the lines of the sketch above). However, that still won't cover cases where the user added custom initialization logic in their model constructor.

Additional context
We found this while bringing up Llama 3.1 405B training; cf. https://github.com/AI-Hypercomputer/torchprime/pull/25/files
cc @qihqi