jnp.mean(x, dtype=bfloat16) is not respected #26365

Open · rryan opened this issue Feb 6, 2025 · 4 comments · May be fixed by #26403
rryan commented Feb 6, 2025

Description

Regarding the accumulation dtype for np.mean, the NumPy docs say:

Note that for floating-point input, the mean is computed using the same precision the input has. Depending on the input data, this can cause the results to be inaccurate, especially for float32 (see example below). Specifying a higher-precision accumulator using the dtype keyword can alleviate this issue.

By default, float16 results are computed using float32 intermediates for extra precision.

This suggests that for a bfloat16/float16 input, the default value for dtype is float32, but the user can request a different precision.

In #17792, we fixed the default upcasting, but in my testing on TPU, jnp.mean(x, dtype=jnp.bfloat16) still casts to fp32, so the dtype parameter does not seem to allow the user to override it.

I reproduced this under pjit simply with:

import jax.numpy as jnp

x = jnp.zeros((2, 3, 5), dtype=jnp.bfloat16)
y = jnp.mean(x, axis=-1, keepdims=True, dtype=x.dtype)

and I observe the following:
[screenshot attached]
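
For reference, one way to see this without a screenshot is to print the lowered computation (a minimal sketch; the exact StableHLO text varies by backend, but the convert to f32 ahead of the reduce should be visible):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.mean(x, axis=-1, keepdims=True, dtype=x.dtype)

x = jnp.zeros((2, 3, 5), dtype=jnp.bfloat16)
# The lowered computation shows the bf16 -> f32 convert before the reduction,
# even though dtype=bfloat16 was requested.
print(jax.jit(f).lower(x).as_text())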

System info (python version, jaxlib version, accelerator, etc.)

(internal to google, running around cl/723930299 on borg)

rryan added the bug label Feb 6, 2025
jakevdp self-assigned this Feb 6, 2025
jakevdp (Collaborator) commented Feb 6, 2025

Thanks for the report!

jakevdp (Collaborator) commented Feb 7, 2025

I dug in a bit to understand what NumPy does here, and it looks like NumPy upcasts float16 regardless of whether the dtype is specified; for example:

In [1]: import numpy as np

In [2]: np.random.seed(0)

In [3]: x = (100 * np.random.randn(10000)).astype('float16')

In [4]: x.sum()
Out[4]: np.float16(-18430.0)

In [5]: x.sum(dtype='float16')  # same result when specifying dtype=float16
Out[5]: np.float16(-18430.0)

In [6]: x.astype('float32').sum().astype('float16')  # same result when upcasting to float32 for the sum
Out[6]: np.float16(-18430.0)

In [7]: np.cumsum(x)[-1]  # cumsum operates in float16 only, and has different rounding
Out[7]: np.float16(-18160.0)

So if jax.numpy functions are to stay true to the semantics of the NumPy functions they're implementing, we need to upcast float16 values regardless of whether the user specifies the dtype.
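
In other words, today the dtype argument effectively controls only the output dtype, and the 16-bit behavior is roughly modeled by accumulating in float32 and casting back. A sketch of that model (mean_like_today is a made-up name for illustration, single-axis only):

import jax.numpy as jnp

def mean_like_today(x, axis, keepdims=False, dtype=None):
    # Rough model of the current float16/bfloat16 behavior:
    # accumulate in float32, then cast to the requested output dtype.
    out_dtype = dtype if dtype is not None else x.dtype
    total = jnp.sum(x.astype(jnp.float32), axis=axis, keepdims=keepdims)
    return (total / x.shape[axis]).astype(out_dtype)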

That said, we do need a way to do what the original request asked for, namely allow the user to perform the reduction without a cast if they wish. I can think of a few options:

  1. Decide to diverge from NumPy, so that x.sum(dtype='float16') lets the user specify that accumulation should happen in float16.
  2. Expose lower-level summation APIs via jax.lax so that primitives like reduce_sum_p can be applied directly.
  3. Expose the upcast_f16_for_computation flag to the user-level jax.numpy APIs, so that users can configure this behavior.

All else being equal, I think (2) may be the nicest option: the fundamental request here is more control over which primitive operations are used to implement a reduction, and that approach would give you maximum control.
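
For illustration, something in the spirit of (2) can already be approximated by calling the existing jax.lax.reduce primitive directly; a minimal sketch (not a proposed API) that keeps the whole reduction in bfloat16:

import jax.numpy as jnp
from jax import lax

x = jnp.zeros((2, 3, 5), dtype=jnp.bfloat16)

# Reduce over the last axis with the reduce primitive directly; the
# accumulation runs in the operand dtype, with no float32 upcast inserted.
total = lax.reduce(x, jnp.array(0, dtype=x.dtype), lax.add, dimensions=(2,))
mean = total / x.shape[-1]
print(mean.dtype)  # bfloat16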

What do you think?

jakevdp (Collaborator) commented Feb 7, 2025

Actually, I'm kind of leaning toward doing both (1) and (2) here.

pearu (Collaborator) commented Feb 7, 2025

FWIW, the Python array API v2023.12 says: "If the data type (either specified or resolved) differs from the data type of x, the input array should be cast to the specified data type before computing the sum (rationale: the dtype keyword argument is intended to help prevent overflows)." That corresponds to option (1).

However, notice that the mean function as specified in the Python array API does not have a dtype argument.
