🐛 Bug
If I `jit` compile some model code that uses the RNG (e.g. dropout layers), then all future invocations of that jitted function will use the same RNG value. The RNG output is burned into the compiled StableHLO.

To Reproduce
See this notebook:
https://github.com/tengyifei/playground/blob/master/torchax/rng-test.ipynb
The jitted function gets the same RNG on every call.
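For reference, here is a minimal sketch in plain JAX (not the torchax repro itself, and the function names are made up) of the same failure mode: when a PRNG key is captured as a Python-level constant, it gets baked into the compiled program, whereas threading the key through as a traced argument gives fresh randomness on each call.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

@jax.jit
def dropout_like_bad(x):
    # `key` is a constant captured from the enclosing scope, so it is burned
    # into the compiled StableHLO: every call returns the same "random" mask.
    mask = jax.random.bernoulli(key, 0.5, x.shape)
    return jnp.where(mask, x / 0.5, 0.0)

@jax.jit
def dropout_like_good(x, key):
    # The key is a traced argument, so each call with a fresh key produces
    # a different mask from the same compiled executable.
    mask = jax.random.bernoulli(key, 0.5, x.shape)
    return jnp.where(mask, x / 0.5, 0.0)

x = jnp.ones((4,))
print(dropout_like_bad(x))   # identical output on every call
print(dropout_like_bad(x))

k = jax.random.PRNGKey(0)
for _ in range(2):
    k, sub = jax.random.split(k)
    print(dropout_like_good(x, sub))  # differs across calls
```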
Expected behavior
I'd expect each iteration in the loop to output a different random number.
Environment
8e6ca6000e83ccbc4365a9d9358e510504b71dea
Additional context
If we don't fix this, then the jitted behavior of any model with dropout layers, random masking, or any other random operation is wrong. This may, for example, cause training to not converge.
cc @qihqi