Skip to content

Commit

Permalink
Merge pull request #26498 from jakevdp:jnp-indexing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726917490
  • Loading branch information
Google-ML-Automation committed Feb 14, 2025
2 parents cdcf35f + f750d0b commit 794ae0f
Show file tree
Hide file tree
Showing 10 changed files with 1,293 additions and 1,250 deletions.
4 changes: 2 additions & 2 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def __format__(self, format_spec):

def __getitem__(self, idx):
from jax._src.lax import lax
from jax._src.numpy import lax_numpy
from jax._src.numpy import indexing
self._check_if_deleted()

if isinstance(self.sharding, PmapSharding):
Expand Down Expand Up @@ -418,7 +418,7 @@ def __getitem__(self, idx):
return ArrayImpl(
out.aval, sharding, [out], committed=False, _skip_checks=True)

return lax_numpy._rewriting_take(self, idx)
return indexing.rewriting_take(self, idx)

def __iter__(self):
if self.ndim == 0:
Expand Down
9 changes: 5 additions & 4 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc
from jax._src.numpy import array_api_metadata
from jax._src.numpy import indexing
from jax._src.numpy import lax_numpy
from jax._src.numpy import tensor_contractions
from jax._src import mesh as mesh_lib
Expand Down Expand Up @@ -382,8 +383,8 @@ def _take(self: Array, indices: ArrayLike, axis: int | None = None, out: None =
Refer to :func:`jax.numpy.take` for full documentation.
"""
return lax_numpy.take(self, indices, axis=axis, out=out, mode=mode, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, fill_value=fill_value)
return indexing.take(self, indices, axis=axis, out=out, mode=mode, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, fill_value=fill_value)

def _to_device(self: Array, device: xc.Device | Sharding, *,
stream: int | Any | None = None):
Expand Down Expand Up @@ -649,7 +650,7 @@ def _chunk_iter(x, size):
yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)

def _getitem(self, item):
return lax_numpy._rewriting_take(self, item)
return indexing.rewriting_take(self, item)

# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
Expand Down Expand Up @@ -777,7 +778,7 @@ def get(self, *, indices_are_sorted=False, unique_indices=False,
See :mod:`jax.ops` for details.
"""
take = partial(lax_numpy._rewriting_take,
take = partial(indexing.rewriting_take,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode,
fill_value=fill_value)
Expand Down
Loading

0 comments on commit 794ae0f

Please sign in to comment.