
JAX numpy linalg gives different results from standard numpy linalg #26347

Closed
Shaoqigit opened this issue Feb 6, 2025 · 2 comments
Labels: question (Questions for the JAX team)


Shaoqigit commented Feb 6, 2025

Description

It seems that jax.numpy.linalg cannot correctly compute the determinant of a matrix with large elements, such as the following:

[[ 0.00000000e+00+0.00000000e+00j  2.07513232e+13-9.24327713e+13j
  -8.20832193e+22+1.91886004e+23j  4.81831162e+22-1.99485396e+23j
  -2.60903801e+15-4.05570916e+15j -5.34020094e+14-8.78254643e+14j
  -2.18062549e+21+8.56477478e+20j]
 [-1.00000000e+00+0.00000000e+00j  1.59581653e+22-1.01532595e+22j
  -3.86464909e+31+1.55825340e+31j  3.49211618e+31-2.14336697e+31j
   2.02278421e+23-9.41350825e+23j  4.81831207e+22-1.99485432e+23j
  -4.29166173e+29-1.86048470e+29j]
 [ 0.00000000e+00+0.00000000e+00j  1.99471060e+20+8.26707576e+19j
  -3.98871172e+29-2.59209827e+29j  4.29165795e+29+1.86048621e+29j
   1.01257470e+22-4.27620641e+21j  2.18062507e+21-8.56475789e+20j
  -1.07235373e+27-5.23109404e+27j]
 [ 0.00000000e+00+0.00000000e+00j  1.77212434e+22-7.48042381e+21j
  -4.13918163e+31+9.08451297e+30j  3.86464788e+31-1.55825255e+31j
   3.62179446e+23-9.09737069e+23j  8.20832193e+22-1.91885968e+23j
  -3.98871398e+29-2.59209676e+29j]
 [ 0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
   1.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j]
 [ 0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j  1.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j]
 [ 0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
   1.00000000e+00+0.00000000e+00j]]

It also takes much longer to run than the standard NumPy version.
I've tried JAX 0.5.0 and 0.4.38.

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.38
jaxlib: 0.4.38
numpy: 1.26.4
python: 3.12.7 | packaged by conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0]
device info: cpu-1, 1 local devices"
process_count: 1
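The system-info block above appears to follow the layout of JAX's built-in environment report. A minimal way to regenerate it for a bug report, assuming a JAX release that provides jax.print_environment_info (recent releases do):

```python
import jax

# Prints the jax/jaxlib/numpy/python versions and the visible devices,
# in the same layout as the "System info" block above.
jax.print_environment_info()
```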

Shaoqigit added the bug (Something isn't working) label Feb 6, 2025
jakevdp (Collaborator) commented Feb 6, 2025

Thanks for the question! The operative difference between NumPy and JAX here is that NumPy defaults to 64-bit computation (float64, or complex128 for complex inputs like yours), while JAX defaults to 32-bit (float32, or complex64). From your description, I assume this is the behavior you're observing:

import numpy as np
import jax.numpy as jnp

x = np.array([[ 0.00000000e+00+0.00000000e+00j,  2.07513232e+13-9.24327713e+13j,
               -8.20832193e+22+1.91886004e+23j,  4.81831162e+22-1.99485396e+23j,
               -2.60903801e+15-4.05570916e+15j, -5.34020094e+14-8.78254643e+14j,
               -2.18062549e+21+8.56477478e+20j,],
              [-1.00000000e+00+0.00000000e+00j,  1.59581653e+22-1.01532595e+22j,
               -3.86464909e+31+1.55825340e+31j,  3.49211618e+31-2.14336697e+31j,
                2.02278421e+23-9.41350825e+23j,  4.81831207e+22-1.99485432e+23j,
               -4.29166173e+29-1.86048470e+29j,],
              [ 0.00000000e+00+0.00000000e+00j,  1.99471060e+20+8.26707576e+19j,
               -3.98871172e+29-2.59209827e+29j,  4.29165795e+29+1.86048621e+29j,
                1.01257470e+22-4.27620641e+21j,  2.18062507e+21-8.56475789e+20j,
               -1.07235373e+27-5.23109404e+27j,],
              [ 0.00000000e+00+0.00000000e+00j,  1.77212434e+22-7.48042381e+21j,
               -4.13918163e+31+9.08451297e+30j,  3.86464788e+31-1.55825255e+31j,
                3.62179446e+23-9.09737069e+23j,  8.20832193e+22-1.91885968e+23j,
               -3.98871398e+29-2.59209676e+29j,],
              [ 0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                1.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,],
              [ 0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,  1.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,],
              [ 0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                0.00000000e+00+0.00000000e+00j,  0.00000000e+00+0.00000000e+00j,
                1.00000000e+00+0.00000000e+00j]])

print(np.linalg.det(x))
# (4.346537210043718e+44+5.453416380097976e+44j)
print(jnp.linalg.det(x))
# (inf+infj)
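The dtype change and the size of the result can be checked directly. A minimal sketch, assuming default JAX settings (x64 mode off); a small all-ones array stands in for the full matrix, and det64 is simply the float64 determinant printed above:

```python
import numpy as np
import jax.numpy as jnp

# Stand-in array with the same dtype as the 7x7 matrix above.
x = np.ones((7, 7), dtype=np.complex128)
print(x.dtype)               # complex128: NumPy keeps 64-bit precision
print(jnp.asarray(x).dtype)  # complex64: JAX converts to 32-bit by default

# The float64 determinant from np.linalg.det has magnitude of about 7.0e44,
# far beyond the largest finite float32 value (about 3.4e38).
det64 = 4.346537210043718e+44 + 5.453416380097976e+44j
f32_max = float(np.finfo(np.float32).max)
print(abs(det64) > f32_max)  # True

# Rounding that value to complex64 therefore overflows both parts to inf
# (NumPy may emit a RuntimeWarning about overflow in the cast).
det32 = np.array([det64]).astype(np.complex64)[0]
print(det32)
```

Since the true determinant is not representable at 32-bit precision, a complex64 result cannot be finite and correct at the same time; inf is the IEEE overflow value.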

If you do the NumPy computation in float32, you'll see that the result is infinite like JAX:

print(np.linalg.det(x.astype('float32')))
# inf
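To confirm that an inf like this comes from overflow of the final value rather than from a failed factorization, one can compare det with slogdet, which returns the sign and the log of the absolute determinant. A minimal sketch on a hypothetical diagonal float32 matrix (not the matrix from this issue) whose determinant, 1e60, exceeds the float32 range:

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical 3x3 float32 matrix: every entry fits in float32,
# but the determinant (1e20 ** 3 = 1e60) does not.
a = np.diag(np.full(3, 1e20, dtype=np.float32))

print(jnp.linalg.det(a))  # inf: the determinant overflows float32

# slogdet works in log space, so the magnitude stays representable.
sign, logabsdet = jnp.linalg.slogdet(a)
print(sign, logabsdet)    # sign 1.0, log|det| close to 60 * ln(10), about 138.16
```

This only diagnoses the overflow; it does not make the 32-bit factorization itself any more accurate.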

Likewise, if you restart the runtime and enable X64 mode, you'll see that the JAX float64 result matches the float64 NumPy result:

import jax
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp
import numpy as np

x = np.array([...

print(jnp.linalg.det(x))
# (4.346537210043718e+44+5.453416380097974e+44j)
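A smaller self-contained check of the same setting, as a sketch assuming the flag is set at startup before any JAX arrays are created. The 2x2 matrix is a hypothetical example whose determinant, 1e60, would overflow at 32-bit precision:

```python
import jax
jax.config.update("jax_enable_x64", True)  # enable 64-bit types at startup

import numpy as np
import jax.numpy as jnp

a = np.array([[1e30, 0], [0, 1e30]], dtype=np.complex128)
print(jnp.asarray(a).dtype)  # complex128 now that x64 mode is on

d = jnp.linalg.det(a)
print(d)                     # finite, approximately 1e60
```

Setting the environment variable JAX_ENABLE_X64=True before starting Python has the same effect as the config.update call.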

Regarding the relative speed of JAX vs. NumPy, what you describe sounds roughly like what you'd expect when dispatching a single linear algebra operation; see the JAX FAQ entry "Is JAX faster than NumPy?" for a discussion.
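A rough way to see where the time goes, as a sketch assuming CPU execution, a random 7x7 float matrix as a hypothetical stand-in, and an arbitrary iteration count. It separates the first call, which includes tracing and compilation, from steady-state calls, and calls block_until_ready() because JAX dispatches work asynchronously. Passing a NumPy array on every call also includes the host-to-device transfer, as in the original usage:

```python
import time

import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-in matrix; any small square matrix shows the same effect.
m = np.random.default_rng(0).standard_normal((7, 7))
n_iter = 100  # arbitrary

det_jit = jax.jit(jnp.linalg.det)

# First call: includes tracing and compilation.
t0 = time.perf_counter()
det_jit(m).block_until_ready()
first_call = time.perf_counter() - t0

# Steady state: reuses the compiled function; block_until_ready() makes the
# timer wait for the asynchronously dispatched computation to finish.
t0 = time.perf_counter()
for _ in range(n_iter):
    det_jit(m).block_until_ready()
jax_per_call = (time.perf_counter() - t0) / n_iter

# NumPy baseline on the same input.
t0 = time.perf_counter()
for _ in range(n_iter):
    np.linalg.det(m)
numpy_per_call = (time.perf_counter() - t0) / n_iter

print(f"JAX first call:        {first_call:.2e} s")
print(f"JAX per call (cached): {jax_per_call:.2e} s")
print(f"NumPy per call:        {numpy_per_call:.2e} s")
```

The absolute numbers depend on hardware and versions; for a matrix this small, fixed per-call overhead rather than arithmetic is likely to dominate both timings.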

Please let us know if any other questions come up!

jakevdp self-assigned this Feb 6, 2025
jakevdp added the question (Questions for the JAX team) label and removed the bug (Something isn't working) label Feb 6, 2025
jakevdp (Collaborator) commented Feb 12, 2025

I'm going to close this since it looks like you don't have further questions. Thanks!

jakevdp closed this as completed Feb 12, 2025