JAX numpy linalg gives different results from standard numpy linalg #26347
Comments
Thanks for the question! The operative difference between NumPy and JAX here is that NumPy defaults to float64 computation, while JAX defaults to float32. From your description, I assume this is the behavior you're observing:

```python
import numpy as np
import jax.numpy as jnp

x = np.array([[ 0.00000000e+00+0.00000000e+00j,  2.07513232e+13-9.24327713e+13j,
               -8.20832193e+22+1.91886004e+23j,  4.81831162e+22-1.99485396e+23j,
               -2.60903801e+15-4.05570916e+15j, -5.34020094e+14-8.78254643e+14j,
               -2.18062549e+21+8.56477478e+20j],
              [-1.00000000e+00+0.00000000e+00j,  1.59581653e+22-1.01532595e+22j,
               -3.86464909e+31+1.55825340e+31j,  3.49211618e+31-2.14336697e+31j,
                2.02278421e+23-9.41350825e+23j,  4.81831207e+22-1.99485432e+23j,
               -4.29166173e+29-1.86048470e+29j],
              [ 0.00000000e+00+0.00000000e+00j,  1.99471060e+20+8.26707576e+19j,
               -3.98871172e+29-2.59209827e+29j,  4.29165795e+29+1.86048621e+29j,
                1.01257470e+22-4.27620641e+21j,  2.18062507e+21-8.56475789e+20j,
               -1.07235373e+27-5.23109404e+27j],
              [ 0.00000000e+00+0.00000000e+00j,  1.77212434e+22-7.48042381e+21j,
               -4.13918163e+31+9.08451297e+30j,  3.86464788e+31-1.55825255e+31j,
                3.62179446e+23-9.09737069e+23j,  8.20832193e+22-1.91885968e+23j,
               -3.98871398e+29-2.59209676e+29j],
              [ 0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                1.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j],
              [ 0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,  1.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j],
              [ 0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                1.00000000e+00+0.00000000e+00j]])

print(np.linalg.det(x))
# (4.346537210043718e+44+5.453416380097976e+44j)
print(jnp.linalg.det(x))
# (inf+infj)
```

If you do the NumPy computation in float32, you'll see that the result is infinite, like JAX's (note that `astype('float32')` keeps only the real parts of the complex entries; even so, they are large enough to overflow):

```python
print(np.linalg.det(x.astype('float32')))
# inf
```

Likewise, if you restart the runtime and enable X64 mode, you'll see that the JAX float64 result matches the float64 NumPy result:

```python
import jax
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp
import numpy as np

x = np.array([...])  # the same matrix as above

print(jnp.linalg.det(x))
# (4.346537210043718e+44+5.453416380097974e+44j)
```

Regarding the relative speed of JAX vs. NumPy, what you describe sounds roughly like what you'd expect for the dispatch of a single linear algebra operation; see the JAX FAQ entry "Is JAX faster than NumPy?" for a discussion of this. Please let us know if any other questions come up!
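As an aside not raised in the thread above: when a determinant's magnitude exceeds the floating-point range, `numpy.linalg.slogdet` (which JAX mirrors as `jax.numpy.linalg.slogdet`) avoids the overflow by returning the sign and log-magnitude separately. A minimal sketch, using a stand-in random matrix rather than the one from the issue:

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in matrix with entries around 1e30, so |det| is roughly 1e211:
# far beyond float32's maximum of ~3.4e38.
a = rng.standard_normal((7, 7)) * 1e30

# Computing the determinant directly in float32 overflows to +/-inf.
print(np.linalg.det(a.astype(np.float32)))

# slogdet returns (sign, log|det|); the log-magnitude stays representable
# even in float32, because no intermediate product ever overflows.
sign, logabsdet = np.linalg.slogdet(a.astype(np.float32))
print(sign, logabsdet)
```

The same pattern applies in JAX, where it also avoids the float32 overflow without enabling X64 mode.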
I'm going to close this since it looks like you don't have further questions. Thanks!
Description
It seems that JAX's numpy.linalg is unable to correctly compute the determinant of a matrix with large elements, such as the one shown above.
It also takes much longer than standard NumPy.
I've tried jax 0.5.0 and 0.4.38.
System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.38
jaxlib: 0.4.38
numpy:  1.26.4
python: 3.12.7 | packaged by conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0]
device info: cpu-1, 1 local devices
process_count: 1
```