[torchax] jit compile the model constructor #8635

Open
tengyifei opened this issue Jan 28, 2025 · 0 comments

tengyifei (Collaborator) commented Jan 28, 2025

🚀 Feature

It's worth providing a function that captures the model constructor (i.e., the torch ops that generate the random weights making up a model's parameters) as one StableHLO graph, and runs that graph on accelerator devices.

Motivation

The primary motivation is to more closely match PyTorch eager UX during SPMD training.

Today, in order to initialize a large model on, say, 256 TPUs, we randomly initialize every layer on the host and then send each layer to the TPUs following a sharding spec:

import jax
import torch
import torchax
from jax.sharding import NamedSharding, PartitionSpec as P

def create_sharded_weights(model, mesh, sharding_map):
  res = {}
  env = torchax.default_env()
  for name, weight_meta in model.state_dict().items():
    # _process_sharding_name: helper (defined elsewhere in the original code)
    # that maps a parameter name to its sharding_map key.
    sharding_spec = sharding_map.get(_process_sharding_name(name))
    if sharding_spec is None:
      print('Skipping weight:', name)
      continue
    sharding = NamedSharding(mesh, P(*sharding_spec))
    # Materialize the full weight on the host CPU first.
    with jax.default_device(jax.devices('cpu')[0]):
      weight_torch = torch.randn(
          weight_meta.shape,
          dtype=weight_meta.dtype)
      weight_jax = env.to_xla(weight_torch).jax()
    # Then build a sharded jax.Array by copying each device's slice of the host array.
    res[name] = env.j2t_iso(jax.make_array_from_callback(
        weight_jax.shape, sharding, lambda a: weight_jax[a]))
  return res
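
For concreteness, here is a minimal, hypothetical way this helper might be called. The mesh axes, the sharding_map contents, the toy model, and the trivial _process_sharding_name stand-in are all illustrative, not the actual torchprime setup:

import jax
import torch.nn as nn
from jax.sharding import Mesh

# Trivial stand-in for the _process_sharding_name helper referenced above
# (the real one lives elsewhere in the original code).
def _process_sharding_name(name):
  return name

mesh = Mesh(jax.devices(), axis_names=('fsdp',))
sharding_map = {'weight': ('fsdp', None)}      # PartitionSpec axes per parameter

model = nn.Linear(4096, 4096, bias=False)
sharded_state = create_sharded_weights(model, mesh, sharding_map)
# The returned torchax tensors could then be swapped into the module, e.g.
# model.load_state_dict(sharded_state, assign=True)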

This works but has some drawbacks:

  • We initialize every weight with torch.randn, but eager PyTorch initializes different layers with a variety of distributions. When I tested training a Llama model with randn (Gaussian-distributed) weights, the loss at step 0 was 10x larger than what eager PyTorch gives us to start with.
  • We could probably work around this in the near term by having the user specify a dictionary of module: initializer_fn mappings. But that is extra code on top of eager PyTorch, and a cost users pay without a corresponding gain. In comparison, SPMD sharding annotations are a cost that users pay to get automatic collectives and sharding propagation.

Pitch

In PyTorch/XLA, I could do this:

import torch_xla
import torch_xla.distributed.spmd as xs

with torch_xla.device():
  model = Model()
  xs.mark_sharding(model.some_weight, ...)
torch_xla.sync()

The above compiles down to a graph whose outputs are the model weights, and those outputs carry sharding annotations. When the graph is executed, each TPU initializes only the shard of each weight that it is responsible for. Note that in this case the only extra cost (besides the torch_xla.sync()) is the mark_sharding call for SPMD sharding annotations, reflecting a principle of "only pay for what you use".

We propose looking into the feasibility of this sort of feature in torchax: for example, could we run the model constructor under some kind of compile function or jit context manager, where all the model weights are tracers, and then lower the whole thing into StableHLO?
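
As a strawman, a sketch of what that could look like is below. It assumes the torchax environment can be entered as a context manager (as in torchax's examples), that the resulting tensors expose .jax() as in the snippet above, and that Model, mesh, and sharding_map are user-supplied placeholders. Whether the constructor's ops can actually be traced through jax.jit this way is exactly the feasibility question.

import jax
import torchax
from jax.sharding import NamedSharding, PartitionSpec as P

env = torchax.default_env()

def init_weights():
  # Hypothetical: under the torchax environment, the constructor's init ops
  # (randn, kaiming_uniform_, ...) would run on tracers instead of eagerly
  # materializing full weights.
  with env:
    model = Model()                      # placeholder for the user's model class
    state = model.state_dict()
  # Hand plain jax values back to jax.jit so it can apply the output shardings.
  return {name: w.jax() for name, w in state.items()}

# Placeholder mesh / sharding_map, keyed by the same names state_dict() uses.
out_shardings = {name: NamedSharding(mesh, P(*spec))
                 for name, spec in sharding_map.items()}

sharded_init = jax.jit(init_weights, out_shardings=out_shardings)
weights = sharded_init()   # one StableHLO graph; each device fills in only its shards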

Alternatives

An alternative is to maintain a comprehensive module: initializer_fn mapping inside torchax itself. However, that still would not cover cases where the user adds custom initialization logic in their model constructor.
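
For illustration, such a mapping might look like the hypothetical sketch below; the registry and reset logic are invented for this example and are not an existing torchax API. Anything the table does not cover, including custom constructor logic, would be silently skipped.

import math
import torch.nn as nn

# Hypothetical registry; the entries and coverage are illustrative only.
INITIALIZERS = {
    nn.Linear: lambda m: nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)),
    nn.Embedding: lambda m: nn.init.normal_(m.weight),
    nn.LayerNorm: lambda m: (nn.init.ones_(m.weight), nn.init.zeros_(m.bias)),
}

def reinitialize(model: nn.Module):
  # Re-run a known default init for every module type in the table. Anything
  # the table misses -- including custom init logic inside the user's own
  # constructor -- is not covered, which is the limitation noted above.
  for module in model.modules():
    init_fn = INITIALIZERS.get(type(module))
    if init_fn is not None:
      init_fn(module)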

Additional context

We found this while bringing up Llama 3.1 405B training; cf. https://github.com/AI-Hypercomputer/torchprime/pull/25/files

cc @qihqi
