Unexpected host-to-device transfer when slicing #26425
Comments
Thanks for the report – this is an interesting one! It looks like this comes from the impl rule referenced at lines 1364 to 1365 in e64650e.
This means that the indices (which are host-side Python integers) are transferred to the device within the implementation. In that sense, this error is being raised correctly: these host values are actually being transferred to device during your computation. If you want to dispatch on static indices, without incurring a host-to-device transfer for those indices, you could do so by putting the slice operation within a jit-compiled function:
import jax
a = jax.random.uniform(jax.random.key(0), (10, 20))
print(a.device)
with jax.transfer_guard('disallow'):
    b = jax.jit(lambda a: a[1:, :10])(a)
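To see the difference between the eager and the jit-compiled slice on a given setup, a small probe along these lines may help. This is only a sketch: `take_block` and `runs_under_guard` are hypothetical helper names, not part of JAX, and whether the eager slice trips the guard can depend on the JAX version and backend, so no particular output is claimed here.

```python
import jax


@jax.jit
def take_block(a):
    # The slice bounds are Python constants, so under jit they become part
    # of the compiled computation instead of being passed in at call time.
    return a[1:, :10]


def runs_under_guard(fn, *args):
    """Hypothetical helper: True if fn(*args) completes with transfers disallowed."""
    try:
        with jax.transfer_guard("disallow"):
            jax.block_until_ready(fn(*args))
        return True
    except Exception:
        # Caught broadly on purpose: the exception type raised by the guard
        # is treated as an unknown here, so other failures also land here.
        return False


a = jax.random.uniform(jax.random.PRNGKey(0), (10, 20))

jit_ok = runs_under_guard(take_block, a)
eager_ok = runs_under_guard(lambda x: x[1:, :10], a)
print("jit slice allowed:  ", jit_ok)
print("eager slice allowed:", eager_ok)
```

If the eager line reports False while the jit line reports True, that matches the behaviour described in this comment.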
Thanks, this makes sense! I'll try to come up with a more realistic MRE
I have a somewhat related case, though complicated by the use of [...]. Consider the following code:
import jax
import jax.numpy as jp
device_name = "cpu"
device = jax.devices(device_name)[0]
# Normal Arrays do not get moved to the device
x = jp.ones((1, 1), device=device)
print(x.device)  # TFRT_CPU_0
print(jp.roll(x, -1, axis=-1).device)  # TFRT_CPU_0
# An Array with default_device gets created on the correct device...
with jax.default_device(device):
    y = jax.random.normal(jax.random.PRNGKey(0), (1, 1))
assert y.device == device, f"device mismatch: {y.device} != {device}"  # No assert error
print(y.device)  # TFRT_CPU_0
# ... but then moves unexpectedly. This does not happen if we cast y = jp.array(y, device=device)
print(jp.roll(y, -1, axis=-1).device)  # cuda:0
with jax.transfer_guard("disallow"):  # <-- jaxlib.xla_extension.XlaRuntimeError
    jp.roll(y, -1, axis=-1)
Clearly, [...]
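One possible workaround, in the same spirit as the `jp.array(y, device=device)` cast mentioned in the code comment, is to pin the array explicitly with `jax.device_put`. The sketch below swaps in `jax.device_put` for that cast and rests on an assumption that is not confirmed in this thread: that the array created under `jax.default_device` is uncommitted, which would explain why later operations are free to run elsewhere.

```python
import jax
import jax.numpy as jnp

device = jax.devices("cpu")[0]

# Created under the context manager; assumed (not confirmed here) to be
# uncommitted, so later ops may follow the global default device instead.
with jax.default_device(device):
    y = jax.random.normal(jax.random.PRNGKey(0), (1, 1))

# Explicitly commit the array to the chosen device.
y_committed = jax.device_put(y, device)

# Operations on a committed array run on the device it is committed to.
rolled = jnp.roll(y_committed, -1, axis=-1)
print(rolled.devices())
```

On a machine with a GPU, comparing `rolled.devices()` with `jnp.roll(y, -1, axis=-1).devices()` should show whether the commit makes a difference for this case.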
Description
Hi JAX maintainers and thanks for the awesome work!
It looks like slicing an on-device array is causing host-to-device transfers. I saw similar issues around (e.g. #16002) but wasn't sure if they're exactly the same.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.38
jaxlib: 0.4.38
numpy: 1.26.4
python: 3.12.3 (main, Nov 6 2024, 18:32:19) [GCC 13.2.0]
device info: NVIDIA GeForce RTX 3060-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='alaurenzi-iit-desktop', release='6.8.0-51-generic', version='#52-Ubuntu SMP PREEMPT_DYNAMIC Thu Dec 5 13:09:44 UTC 2024', machine='x86_64')
$ nvidia-smi
Sat Feb 8 16:01:23 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.08 Driver Version: 550.127.08 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 Off | N/A |
| 0% 40C P2 8W / 170W | 895MiB / 12288MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 3404 G /usr/lib/xorg/Xorg 550MiB |
| 0 N/A N/A 3666 G /usr/bin/gnome-shell 10MiB |
| 0 N/A N/A 4208 G /usr/libexec/xdg-desktop-portal-gnome 83MiB |
| 0 N/A N/A 4368 G ...3/usr/bin/snapd-desktop-integration 6MiB |
| 0 N/A N/A 588592 G ...erProcess --variations-seed-version 107MiB |
| 0 N/A N/A 937073 C python 108MiB |
| 0 N/A N/A 968851 G ...zi/Qt/Tools/QtCreator/bin/qtcreator 2MiB |
| 0 N/A N/A 1188083 G ...zi/Qt/Tools/QtCreator/bin/qtcreator 2MiB |
| 0 N/A N/A 2215079 G ...er/Qt/Tools/QtCreator/bin/qtcreator 2MiB |
+-----------------------------------------------------------------------------------------+