
jax.numpy.set_printoptions is not applied to arrays with bfloat16 dtype #18820

Closed
mrTsjolder opened this issue Dec 5, 2023 · 1 comment
Labels
bug Something isn't working


mrTsjolder commented Dec 5, 2023

Description

```python
import jax.numpy as jnp
jnp.set_printoptions(precision=4)
print(jnp.array([1 / 3]))
print(jnp.array([1 / 3]).astype(jnp.bfloat16))
```

The snippet above prints:

```
[0.3333]
[0.333984]
```

However, I would have expected the code to produce

```
[0.3333]
[0.3340]
```

or

```
[0.3333]
[0.334]
```

The second form is exactly what `print(jnp.array([1 / 3]).astype(jnp.bfloat16).astype(jnp.float32))` prints, i.e. the same bfloat16 value upcast to float32 before printing.
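For context, the bfloat16 line is not a numerical error. `jnp.array([1 / 3])` is float32 by default, and rounding that float32 to bfloat16 (8 significant bits) gives exactly 0.333984375. The printed `0.333984` is that value shown to six significant digits, the same as C-style `%g` formatting, which suggests the bfloat16 element formatter is not consulting `precision` (an inference from the output, not something stated in the thread). The sketch below emulates the float32-to-bfloat16 cast with `struct`, assuming round-to-nearest-even; `to_bfloat16` is a hypothetical helper, and NaN/overflow handling is omitted.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value and return it as a Python float.

    Hypothetical helper for illustration: x is first rounded to float32
    (mirroring JAX's default float32 array), then to bfloat16 with
    round-to-nearest-even. NaN and overflow handling are omitted.
    """
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # bfloat16 keeps the top 16 bits of a float32; add a rounding bias that
    # implements round-to-nearest, ties-to-even, then clear the low 16 bits.
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    (y,) = struct.unpack("<f", struct.pack("<I", bits))
    return y


y = to_bfloat16(1 / 3)
print(repr(y))    # 0.333984375, the exact stored bfloat16 value
print(f"{y:g}")   # 0.333984, six significant digits, as in the bfloat16 line
print(f"{y:.4f}") # 0.3340, what precision=4 would give
```

So both expected outputs in the report are just different renderings of the same stored value 0.333984375.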

What jax/jaxlib version are you using?

jax 0.4.20, jaxlib 0.4.20

Which accelerator(s) are you using?

CPU

Additional system info?

1.23.5 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] uname_result(system='Linux', node='ad8ae4a090bf', release='5.15.120+', version='#1 SMP Wed Aug 30 11:19:59 UTC 2023', machine='x86_64')

NVIDIA GPU info

No response

jakevdp (Collaborator) commented Dec 5, 2023

Thanks for the report! I moved this issue over to the ml_dtypes repo, where the code for printing these values lives: jax-ml/ml_dtypes#125
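Since the formatting lives in ml_dtypes, a possible user-side workaround, pending a change there, is the one the report already hints at: upcast bfloat16 to float32 before printing, so NumPy's own float formatter (and therefore `precision`) is used. The upcast is exact because every bfloat16 value is representable in float32. This is a minimal sketch; `printable` is a hypothetical helper, not a JAX or ml_dtypes API, and `np.printoptions` is NumPy's context-manager counterpart of `set_printoptions`.

```python
import numpy as np
import jax.numpy as jnp


def printable(x):
    """Return a NumPy array whose float formatting honours NumPy print options.

    Hypothetical helper: bfloat16 values are upcast to float32 (an exact
    conversion); all other dtypes are returned unchanged.
    """
    arr = np.asarray(x)  # copy the JAX array to host as a NumPy array
    if arr.dtype == jnp.bfloat16:
        arr = arr.astype(np.float32)
    return arr


x = jnp.array([1 / 3]).astype(jnp.bfloat16)
with np.printoptions(precision=4):
    print(printable(x))  # prints [0.334]
```

The trade-off is that the printed dtype information is lost, so this is best kept to display code rather than applied to values used in further computation.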

jakevdp closed this as completed Dec 5, 2023