Description

jax.ops.segment_sum is more than 100x slower when segment_ids are sorted. This only happens for fp16 and fp8; it does not happen for fp32 (there, the sorted case is only about 50 percent slower than randomly shuffled indices).
Setting indices_are_sorted to False or True does not change this behavior.
There are some related issues with fp16, but they are not specific to sorted input: #20186 and #23136.
import jax
import jax.numpy as jnp

dtype = jnp.float16
key = jax.random.PRNGKey(0)
n = 100_000_000
num_segments = 1_000_000
x = jax.random.uniform(key, (n,), dtype=dtype)
idx = jax.random.randint(key, shape=(n,), minval=0, maxval=num_segments)
idx_sorted = jnp.sort(idx)

%timeit jax.ops.segment_sum(x, idx, num_segments).block_until_ready()
%timeit jax.ops.segment_sum(x, idx_sorted, num_segments, indices_are_sorted=True).block_until_ready()
# 7.58 ms ± 145 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 1.15 s ± 892 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
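
As a possible workaround (not part of the original report, just a sketch): accumulating in fp32 and casting back to the input dtype should sidestep the slow half-precision scatter path; the helper name segment_sum_via_f32 is made up for illustration.

def segment_sum_via_f32(data, segment_ids, num_segments, **kwargs):
    # Accumulate in fp32, then cast the result back to the original dtype (e.g. fp16).
    out = jax.ops.segment_sum(data.astype(jnp.float32), segment_ids, num_segments, **kwargs)
    return out.astype(data.dtype)

# Example, using the arrays from the repro above:
# %timeit segment_sum_via_f32(x, idx_sorted, num_segments, indices_are_sorted=True).block_until_ready()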
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.30
jaxlib: 0.4.30
numpy: 2.0.2
python: 3.9.21 (main, Dec 5 2024, 00:00:00) [GCC 11.5.0 20240719 (Red Hat 11.5.0-2)]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='localhost.localdomain', release='5.14.0-503.22.1.el9_5.x86_64', version='#1 SMP PREEMPT_DYNAMIC Wed Jan 15 08:02:15 EST 2025', machine='x86_64')
$ nvidia-smi
Thu Jan 30 16:30:07 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.86.15 Driver Version: 570.86.15 CUDA Version: 12.8 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A40 Off | 00000000:25:00.0 Off | 0 |
| 0% 39C P0 76W / 300W | 34806MiB / 46068MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA A40 Off | 00000000:81:00.0 Off | Off |
| 0% 38C P0 76W / 300W | 538MiB / 49140MiB | 3% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 2725 G /usr/libexec/Xorg 108MiB |
| 0 N/A N/A 2815 G /usr/bin/gnome-shell 9MiB |
| 0 N/A N/A 24564 C /root/rafael/.venv/bin/python 34394MiB |
| 0 N/A N/A 25966 C python 262MiB |
| 1 N/A N/A 24564 C /root/rafael/.venv/bin/python 262MiB |
| 1 N/A N/A 25966 C python 262MiB |
+-----------------------------------------------------------------------------------------+
@rafaelha Is this on GPU? If so, an implementation of sorted/vectorized scatter landed in XLA this January. Could you try it with a newer JAX?
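
For reference (the upgrade command is an assumption about the install path, not something stated in the thread), one way to pull a newer CUDA build and confirm the installed versions before re-running the repro:

# Assumed upgrade command; the exact extra depends on your CUDA setup:
#   pip install -U "jax[cuda12]"
import jax
jax.print_environment_info()  # prints jax/jaxlib versions, Python, and detected devices

This prints the same block as the system info section above, so the result can be pasted back into the issue after retesting.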