Unexpected host-to-device transfer when slicing #26425

Open
alaurenzi opened this issue Feb 8, 2025 · 3 comments
Labels
question Questions for the JAX team

@alaurenzi

Description

Hi JAX maintainers and thanks for the awesome work!

It looks like slicing an on-device array is causing host-to-device transfers. I saw similar issues around (e.g. #16002) but wasn't sure if they're exactly the same.

import jax 

a = jax.random.uniform(jax.random.key(0), (10, 20))
print(a.device)
with jax.transfer_guard('disallow'):
    b = a[1:, :10]     

Output:

cuda:0
Traceback (most recent call last):
  File "/home/alaurenzi/code/mjx_playground/playground/src/kyon_mjx/ilqr/scratch.py", line 7, in <module>
    b = a[1:, :10]
        ~^^^^^^^^^
  File "/home/alaurenzi/code/mjx_playground/venv/mjx/lib/python3.12/site-packages/jax/_src/array.py", line 371, in __getitem__
    return lax_numpy._rewriting_take(self, idx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alaurenzi/code/mjx_playground/venv/mjx/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 11924, in _rewriting_take
    if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alaurenzi/code/mjx_playground/venv/mjx/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 11901, in _attempt_rewriting_take_via_slice
    arr = lax.slice(
          ^^^^^^^^^^
  File "/home/alaurenzi/code/mjx_playground/venv/mjx/lib/python3.12/site-packages/jax/_src/lax/slicing.py", line 107, in slice
    return slice_p.bind(operand, start_indices=tuple(start_indices),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alaurenzi/code/mjx_playground/venv/mjx/lib/python3.12/site-packages/jax/_src/core.py", line 463, in bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alaurenzi/code/mjx_playground/venv/mjx/lib/python3.12/site-packages/jax/_src/core.py", line 468, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alaurenzi/code/mjx_playground/venv/mjx/lib/python3.12/site-packages/jax/_src/core.py", line 941, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alaurenzi/code/mjx_playground/venv/mjx/lib/python3.12/site-packages/jax/_src/lax/slicing.py", line 1363, in _slice_impl
    return dispatch.apply_primitive(dynamic_slice_p, x, *start_indices,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alaurenzi/code/mjx_playground/venv/mjx/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Disallowed host-to-device transfer: aval=ShapedArray(int32[]), dst_sharding=SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.38
jaxlib: 0.4.38
numpy: 1.26.4
python: 3.12.3 (main, Nov 6 2024, 18:32:19) [GCC 13.2.0]
device info: NVIDIA GeForce RTX 3060-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='alaurenzi-iit-desktop', release='6.8.0-51-generic', version='#52-Ubuntu SMP PREEMPT_DYNAMIC Thu Dec 5 13:09:44 UTC 2024', machine='x86_64')

$ nvidia-smi
Sat Feb  8 16:01:23 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.08             Driver Version: 550.127.08     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3060        Off |   00000000:01:00.0 Off |                  N/A |
|  0%   40C    P2              8W / 170W  |    895MiB / 12288MiB   |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            3404      G   /usr/lib/xorg/Xorg                      550MiB |
|    0   N/A  N/A            3666      G   /usr/bin/gnome-shell                     10MiB |
|    0   N/A  N/A            4208      G   /usr/libexec/xdg-desktop-portal-gnome    83MiB |
|    0   N/A  N/A            4368      G   ...3/usr/bin/snapd-desktop-integration    6MiB |
|    0   N/A  N/A          588592      G   ...erProcess --variations-seed-version  107MiB |
|    0   N/A  N/A          937073      C   python                                  108MiB |
|    0   N/A  N/A          968851      G   ...zi/Qt/Tools/QtCreator/bin/qtcreator    2MiB |
|    0   N/A  N/A         1188083      G   ...zi/Qt/Tools/QtCreator/bin/qtcreator    2MiB |
|    0   N/A  N/A         2215079      G   ...er/Qt/Tools/QtCreator/bin/qtcreator    2MiB |
+-----------------------------------------------------------------------------------------+

@alaurenzi alaurenzi added the bug Something isn't working label Feb 8, 2025
@jakevdp
Collaborator

jakevdp commented Feb 10, 2025

Thanks for the report – this is an interesting one! It looks like this comes from the fact that the impl rule for slice_p calls out to dynamic_slice_p here:

jax/jax/_src/lax/slicing.py, lines 1364 to 1365 (at e64650e):

return dispatch.apply_primitive(dynamic_slice_p, x, *start_indices,
                                slice_sizes=slice_sizes)

This means that the indices (which are host-side Python integers) are transferred to the device within the implementation. In that sense, this error is being raised correctly: these host values are actually being transferred to the device during your computation.

This slice_p -> dynamic_slice_p eager impl was put in for the sake of efficiency: it means that repeated slices with different indices in eager mode don't incur compilation overhead for each unique index.
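
For illustration, a minimal sketch under the same setup as the report, using jax.transfer_guard's 'log' level (which reports transfers instead of raising) to surface the implicit index transfer on this eager path:

import jax

a = jax.random.uniform(jax.random.key(0), (10, 20))

# 'log' reports each implicit host-to-device transfer instead of raising,
# so the int32 scalar start indices of the eager slice show up in the log.
with jax.transfer_guard('log'):
    b = a[1:, :10]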

If you want to dispatch on static indices without incurring a host-to-device transfer for those indices, you can put the slice operation within jit so you're no longer on this eager dispatch path:

import jax 

a = jax.random.uniform(jax.random.key(0), (10, 20))
print(a.device)

with jax.transfer_guard('disallow'):
  b = jax.jit(lambda a: a[1:, :10])(a)
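
One caveat with the inline lambda: jit caches compiled programs per function object, so a lambda defined fresh at each call site is retraced every time. A sketch with a hypothetical take_slice helper, defined once so repeated calls hit the jit cache:

import jax

@jax.jit
def take_slice(a):
    # The slice indices are static here, so they are baked into the
    # compiled program; no index transfer happens at call time.
    return a[1:, :10]

a = jax.random.uniform(jax.random.key(0), (10, 20))
with jax.transfer_guard('disallow'):
    b = take_slice(a)  # traced/compiled once, then reused from the cache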

@jakevdp jakevdp self-assigned this Feb 10, 2025
@jakevdp jakevdp added question Questions for the JAX team and removed bug Something isn't working labels Feb 10, 2025
@alaurenzi
Author

Thanks, this makes sense!
Although I'm afraid I might have simplified my real use case into an MRE a little too aggressively :D In my real use case I believe I'm actually jitting everything, but I still see the unexpected transfer.

I'll try to come up with a more realistic MRE.

@amacati

amacati commented Feb 12, 2025

I have a somewhat related case, though complicated by the use of jax.default_device. It may or may not be the same issue.

Consider the following code:

import jax
import jax.numpy as jp

device_name = "cpu"
device = jax.devices(device_name)[0]

# Normal Arrays do not get moved to the device
x = jp.ones((1, 1), device=device)
print(x.device)  # TFRT_CPU_0
print(jp.roll(x, -1, axis=-1).device)  # TFRT_CPU_0

# An Array with default_device gets created on the correct device...
with jax.default_device(device):
    y = jax.random.normal(jax.random.PRNGKey(0), (1, 1))
assert y.device == device, f"device mismatch: {y.device} != {device}"  # No assert error
print(y.device)  # TFRT_CPU_0
# ... but then moves unexpectedly. This does not happen if we cast y = jp.array(y, device=device)
print(jp.roll(y, -1, axis=-1).device)  # cuda:0
with jax.transfer_guard("disallow"):  # <-- jaxlib.xla_extension.XlaRuntimeError
    jp.roll(y, -1, axis=-1)

Clearly, y is not supposed to be transferred. Interestingly, this only happens when placing the random call inside a jax.default_device context manager. Explicitly transferring to the CPU again (even though the array is already on the CPU, per the y.device assertion) prevents the error.
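
For reference, a minimal sketch of that workaround; jax.device_put is assumed here to commit the array to the device the same way the jp.array(y, device=device) cast does:

import jax
import jax.numpy as jp

device = jax.devices("cpu")[0]
with jax.default_device(device):
    y = jax.random.normal(jax.random.PRNGKey(0), (1, 1))

# Recommitting the already CPU-resident array pins it to the device; the
# roll then stays on the CPU and no transfer is attempted.
y = jax.device_put(y, device)
with jax.transfer_guard("disallow"):
    jp.roll(y, -1, axis=-1)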
