jnp.mean(x, dtype=bfloat16) is not respected #26365
Comments
Thanks for the report!
I dug in a bit to understand what NumPy does here, and it looks like NumPy upcasts float16 to float32 internally regardless of whether the `dtype` is specified; for example:

```python
In [1]: import numpy as np

In [2]: np.random.seed(0)

In [3]: x = (100 * np.random.randn(10000)).astype('float16')

In [4]: x.sum()
Out[4]: np.float16(-18430.0)

In [5]: x.sum(dtype='float16')  # same result when specifying dtype=float16
Out[5]: np.float16(-18430.0)

In [6]: x.astype('float32').sum().astype('float16')  # same result when upcasting to float32 for the sum
Out[6]: np.float16(-18430.0)

In [7]: np.cumsum(x)[-1]  # cumsum operates in float16 only, and has different rounding
Out[7]: np.float16(-18160.0)
```

That said, we do need a way to do what the original request asked for, namely allow the user to perform the reduction without a cast if they wish. I can think of a few options:

1. Cast the input to the requested `dtype` before the reduction, as the array API specifies.
2. Expose control over which primitive operations (and thus which accumulation dtype) are used to implement the reduction.
All else being equal I think (2) may be the nicest option, because the fundamental request here is to be able to have more control over which primitive operations are being used to implement a reduction, and that would give you maximum control. What do you think?
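To make (2) concrete, here is a workaround sketch (not from this thread) of the kind of control a user can already get by dropping below `jnp` to `jax.lax.reduce`; the bfloat16 input shape is an assumption:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones(10000, dtype=jnp.bfloat16)

# The init value and operand share a dtype, so the reduction itself
# runs in bfloat16, with no float32 accumulator inserted by jnp.
total = lax.reduce(x, jnp.zeros((), dtype=jnp.bfloat16), lax.add, dimensions=(0,))
mean = total / x.size  # x.size is a weakly-typed Python int, so this stays bfloat16
print(total.dtype, mean.dtype)  # bfloat16 bfloat16
```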
Actually, I'm kind of leaning toward doing both (1) and (2) here.
FWIW, Python array API v2023.12 says: "If the data type (either specified or resolved) differs from the data type of x, the input array should be cast to the specified data type before computing the sum (rationale: the dtype keyword argument is intended to help prevent overflows)." That corresponds to (1).
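As a sketch of those array API semantics (a hypothetical helper, not an existing `jnp` function; what happens after the cast remains implementation-defined):

```python
import jax.numpy as jnp

def sum_array_api(x, dtype=None):
    # Hypothetical helper illustrating the array API wording: cast the
    # input to the requested dtype *before* computing the sum.
    if dtype is not None and jnp.dtype(dtype) != x.dtype:
        x = x.astype(dtype)
    # How the sum accumulates after the cast is left to the implementation.
    return jnp.sum(x)
```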
Description
Regarding the accumulation dtype for `np.mean`, the NumPy docs say:

> By default, float16 results are computed using float32 intermediates for extra precision.
This suggests that for a bfloat16/float16 input, the default value for `dtype` is float32, but the user can request a different precision. In #17792, we fixed the default upcasting, but in my testing on TPU, `jnp.mean(x, dtype=jnp.bfloat16)` still casts to fp32, so the `dtype` parameter does not seem to allow the user to override it. I reproduced this under pjit quite simply:
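A minimal sketch of such a repro (the shapes, the `jax.jit` wrapper standing in for pjit, and the jaxpr inspection are assumptions, not the original snippet):

```python
import jax
import jax.numpy as jnp

@jax.jit  # stand-in for pjit in this sketch
def f(x):
    return jnp.mean(x, dtype=jnp.bfloat16)

x = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
print(jax.make_jaxpr(f)(x))
# If the bug is present, the jaxpr shows a convert_element_type to
# float32 feeding the reduce_sum, despite dtype=jnp.bfloat16.
```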
and I observe the following:
[Image: screenshot showing the observed behavior]
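The same thing can be checked without a screenshot by inspecting the lowered HLO text (a verification sketch continuing from the repro above; assumes a jax version where `Lowered.as_text()` is available):

```python
hlo = jax.jit(lambda x: jnp.mean(x, dtype=jnp.bfloat16)).lower(x).as_text()
# Look for a convert to f32 ahead of the reduce; its presence means the
# requested bfloat16 accumulation was not respected.
print("f32" in hlo, hlo.count("convert"))
```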
System info (python version, jaxlib version, accelerator, etc.)
(internal to google, running around cl/723930299 on borg)