jax.numpy.set_print_options is not applied to arrays with bfloat16 dtype #18820
Labels: bug
Description
The output of the snippet of code above is
However, I would have expected the code to produce
or
which corresponds to the output of
print(jnp.array([1 / 3]).astype(jnp.bfloat16).astype(jnp.float32))
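A minimal sketch of a reproduction along these lines. It assumes a precision setting of 3 (the value used in the report is not shown here) and uses `jnp.set_printoptions`, the NumPy-style function that JAX re-exports. The variable names are illustrative only, and no expected output is given because the printed form may differ between JAX versions.

```python
import jax.numpy as jnp

# Assumed setting: the precision used in the original report is not shown.
jnp.set_printoptions(precision=3)

# Same value in two dtypes (names are illustrative).
x32 = jnp.array([1 / 3], dtype=jnp.float32)
x16 = x32.astype(jnp.bfloat16)

# float32 array: printing is expected to follow the precision setting.
print(x32)

# bfloat16 array: the report says the setting is not applied here.
print(x16)

# Round-trip through float32, as in the line quoted in the report.
print(x16.astype(jnp.float32))
```

Comparing the second and third printed lines shows whether the print options take effect for the bfloat16 array on a given JAX version.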
What jax/jaxlib version are you using?
jax 0.4.20, jaxlib 0.4.20
Which accelerator(s) are you using?
CPU
Additional system info?
NumPy: 1.23.5
Python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
Platform: uname_result(system='Linux', node='ad8ae4a090bf', release='5.15.120+', version='#1 SMP Wed Aug 30 11:19:59 UTC 2023', machine='x86_64')
NVIDIA GPU info
No response