Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax.numpy reductions: avoid upcast of f16 when dtype is specified by user #26403

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
bool_op=lax.bitwise_or, upcast_f16_for_computation=(dtype is None),
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.psum,
promote_integers=promote_integers)
Expand Down Expand Up @@ -319,7 +319,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric,
bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
bool_op=lax.bitwise_and, upcast_f16_for_computation=(dtype is None),
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, promote_integers=promote_integers)

Expand Down Expand Up @@ -865,9 +865,10 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
[6. ]], dtype=float32)
"""
return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
where=where)
where=where, upcast_f16_for_computation=(dtype is None))

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'upcast_f16_for_computation'),
inline=True)
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, *,
upcast_f16_for_computation: bool = True,
Expand Down
Loading