[torchax] RNG handling in a jitted graph is unsound #8636

Open
tengyifei opened this issue Jan 28, 2025 · 0 comments

🐛 Bug

If I jit-compile some model code that uses the RNG (e.g. dropout layers), then all future invocations of that jitted function will use the same RNG value. The RNG output is burned into the compiled StableHLO.

To Reproduce

See this notebook:

https://github.com/tengyifei/playground/blob/master/torchax/rng-test.ipynb

The jitted function produces the same random values on every call.
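
For illustration only (this is not the notebook's code), here is a minimal JAX-level sketch of the failure mode: randomness that is materialized at trace time, instead of being threaded through the traced inputs, becomes a constant in the compiled graph. The `dropout_like` function below is a made-up stand-in for a dropout layer:

```python
import jax
import jax.numpy as jnp
import numpy as np

def dropout_like(x):
    # The mask is drawn with host-side RNG while the function is traced,
    # so it is captured as a constant in the compiled graph.
    mask = jnp.asarray(np.random.rand(*x.shape) > 0.5, dtype=x.dtype)
    return x * mask

f = jax.jit(dropout_like)
x = jnp.ones((4,))
print(f(x))  # some mask
print(f(x))  # identical mask: the RNG output was burned in at trace time
```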

Expected behavior

I'd expect each iteration in the loop to output a different random number.
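
In JAX terms, getting a fresh random number per call presumably means threading the RNG state through the jitted function as an explicit input. The sketch below illustrates that pattern with plain JAX; it is not the torchax API, and `dropout_keyed` is a hypothetical name:

```python
import jax
import jax.numpy as jnp

def dropout_keyed(x, key):
    # The key is a traced input, so each call can supply fresh randomness.
    mask = (jax.random.uniform(key, x.shape) > 0.5).astype(x.dtype)
    return x * mask

g = jax.jit(dropout_keyed)
x = jnp.ones((4,))
key = jax.random.PRNGKey(0)
for _ in range(3):
    key, subkey = jax.random.split(key)
    print(g(x, subkey))  # a different mask on each iteration
```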

Environment

  • Reproducible on XLA backend: CPU/TPU
  • torchax version: 8e6ca6000e83ccbc4365a9d9358e510504b71dea

Additional context

If we don't fix this, the jitted behavior of any model that uses dropout layers, random masking, or any other random operation is wrong. This may, for example, cause training to not converge.

cc @qihqi
