diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 8403b8876382..e00b5b3e87af 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -11649,7 +11649,8 @@ def replace(tup, val): j = 0 for i in range(rank): if i == axis_int: - indices = _normalize_index(indices, axis_size) + if mode != 'promise_in_bounds': + indices = _normalize_index(indices, axis_size) gather_indices.append(lax.reshape(indices, gather_index_shape)) slice_sizes.append(1) start_index_map.append(i) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 92185cc2c23c..908b4760465e 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2129,7 +2129,11 @@ def _gather_lowering_rule( slice_sizes == (1, 1) and not unique_indices and not indices_are_sorted - and mode == lax.GatherScatterMode.FILL_OR_DROP + and mode + in ( + lax.GatherScatterMode.FILL_OR_DROP, + lax.GatherScatterMode.PROMISE_IN_BOUNDS, + ) ): if dimension_numbers == lax.GatherDimensionNumbers( offset_dims=(), diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index df2d5cb67310..73170aa7a25b 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -389,10 +389,10 @@ def kernel(x_ref, o_ref): ref = jax.jit(lambda x: round_fn(x).astype(target))(x) np.testing.assert_array_equal(out, ref) - @parameterized.product(axis=[0, 1]) - def test_dynamic_gather_along_axis(self, axis): - if not jtu.if_cloud_tpu_at_least(2025, 2, 3): - self.skipTest("Requires libtpu built after 2025-02-03") + @parameterized.product(axis=[0, 1], mode=["promise_in_bounds", None]) + def test_dynamic_gather_along_axis(self, axis, mode): + if not jtu.if_cloud_tpu_at_least(2025, 2, 5): + self.skipTest("Requires libtpu built after 2025-02-05") if (axis == 0 and not jtu.is_device_tpu_at_least(version=5)) or ( axis == 1 and not jtu.is_device_tpu_at_least(version=4) ): @@ -401,7 +401,7 @@ def test_dynamic_gather_along_axis(self, axis): shape = (8, 128) def kernel(x, indices, out): - out[...] = jnp.take_along_axis(x[...], indices[...], axis) + out[...] = jnp.take_along_axis(x[...], indices[...], axis, mode=mode) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) idx = jax.random.randint(