
[scan] Avoid re-tracing the combine function on every call #8632

Open
tengyifei opened this issue Jan 27, 2025 · 0 comments
Comments

@tengyifei
Collaborator

tengyifei commented Jan 27, 2025

🚀 Feature

It should be possible to cache the traced graphs in torch_xla.experimental.scan so that we don't re-trace the combine function on every call.

Motivation

Today, torch_xla.experimental.scan and scan_layers trace the user function with both AOTAutograd (to obtain the backward) and LazyTensor (to lower it to HLO). AOTAutograd is very slow, so we can easily become tracing-bound. For example, python3 examples/train_decoder_only_base.py takes 1 min 30 s, while python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan takes 4 min.
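For illustration, here is a minimal sketch of where the cost shows up, assuming the scan(fn, init, xs) signature where fn maps (carry, x) to (new_carry, y). Because tracing happens inside scan, every call (e.g. once per training step) pays the AOTAutograd and LazyTensor tracing cost again:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan


def combine(carry, x):
  # Combine function: consumes one slice of `xs` and updates the carry.
  # Today, both the forward and the AOTAutograd-derived backward of this
  # function are re-traced on every call to `scan`.
  new_carry = torch.sin(carry + x)
  return new_carry, new_carry


device = xm.xla_device()
init = torch.zeros(4, device=device, requires_grad=True)
xs = torch.randn(10, 4, device=device, requires_grad=True)

# Every invocation re-traces `combine`, even though the function object
# is identical across calls.
final_carry, ys = scan(combine, init, xs)
ys.sum().backward()
```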

Pitch

We could wait for torch.scan to support autograd (cf. #7901 (comment)), but that will take a long time. In the meantime, we can implement simple caching keyed on the id of the input function/module, as sketched below.
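A minimal sketch of that caching idea, where _trace_combine_fn is a hypothetical placeholder standing in for the existing AOTAutograd/LazyTensor tracing path; the only point shown is keying the cache on id(fn) so repeated calls with the same function object skip re-tracing:

```python
from typing import Any, Callable, Dict, Tuple

# Cache of traced artifacts keyed by the id of the combine function.
# NOTE: keying on id() alone is unsound if the function object is garbage
# collected and its id reused, or if it closes over tensors whose
# shapes/dtypes change; a real implementation would also key on the
# input specs.
_TRACE_CACHE: Dict[int, Any] = {}


def _get_or_trace(fn: Callable, example_inputs: Tuple) -> Any:
  key = id(fn)
  cached = _TRACE_CACHE.get(key)
  if cached is None:
    # `_trace_combine_fn` is a hypothetical stand-in for the existing
    # tracing step (AOTAutograd for the backward + LazyTensor lowering
    # to HLO).
    cached = _trace_combine_fn(fn, example_inputs)
    _TRACE_CACHE[key] = cached
  return cached
```

With this in place, scan would look up the cached traced graphs on every call after the first, instead of re-running AOTAutograd and LazyTensor tracing each time.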
