🚀 Feature
It should be possible to somehow cache the traced graphs in torch_xla.experimental.scan so we don't trace on every call.
Motivation
Today, torch_xla.experimental.scan and scan_layers trace the user function with both AOTAutograd (to obtain the backward pass) and with LazyTensor (to lower it to HLO). AOTAutograd is very slow, and we can easily become tracing-bound. For example, python3 examples/train_decoder_only_base.py takes 1min30s, but python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan takes 4min.
Pitch
We could wait for torch.scan to support autograd (cf. #7901 (comment)), but that will take a long time. In the meantime, we can implement some simple caching based on the id of the input function/module.
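As a rough illustration of the idea, here is a minimal sketch of an id-keyed cache. The names (_TRACE_CACHE, cached_trace, trace_fn) are hypothetical and do not correspond to real torch_xla APIs; trace_fn stands in for whatever internal step performs the expensive AOTAutograd + LazyTensor tracing.

```python
from typing import Any, Callable, Dict

# Hypothetical sketch: cache the traced/lowered graphs keyed by the id of the
# user function, so repeated scan calls can skip re-tracing.
_TRACE_CACHE: Dict[int, Any] = {}

def cached_trace(fn: Callable, trace_fn: Callable, *example_args) -> Any:
    """Trace `fn` once per function/module object and reuse the result."""
    key = id(fn)
    if key not in _TRACE_CACHE:
        # Expensive path: only taken the first time this function object is
        # seen. Subsequent scan calls reuse the cached graphs.
        _TRACE_CACHE[key] = trace_fn(fn, *example_args)
    return _TRACE_CACHE[key]
```

Note that keying on id alone is fragile: a real implementation would likely also need to include input shapes/dtypes in the cache key and handle the case where a garbage-collected function's id is reused.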