Skip to content

Commit

Permalink
Don't pass dtype to lax_internal._zero
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 719119309
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Jan 24, 2025
1 parent 8e1f956 commit aa46a56
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ def apply(self, func, *, indices_are_sorted=False, unique_indices=False,
def _scatter_apply(x, indices, y, dims, **kwargs):
return lax.scatter_apply(x, indices, func, dims, update_shape=y.shape, **kwargs)
return scatter._scatter_update(self.array, self.index,
lax_internal._zero(self.array.dtype),
lax_internal._zero(self.array),
_scatter_apply,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
Expand Down

0 comments on commit aa46a56

Please sign in to comment.