jax.ops.segment_sum is 100x slower when segment_ids are sorted, but only for float16 #26227

Open
rafaelha opened this issue Jan 31, 2025 · 2 comments
Labels: bug (Something isn't working), NVIDIA GPU (Issues specific to NVIDIA GPUs)

Comments

@rafaelha

Description

jax.ops.segment_sum is more than 100x slower when segment_ids are sorted than when they are randomly shuffled. This happens only for fp16 and fp8. It does not happen for fp32, where the sorted case is only about 50 percent slower than randomly shuffled indices.

Setting indices_are_sorted to False or True does not change this behavior.

There are some related issues with fp16, but not specific to sorted input: #20186 and #23136.

import jax
import jax.numpy as jnp

dtype = jnp.float16

key = jax.random.PRNGKey(0)

n = 100_000_000
num_segments = 1_000_000
x = jax.random.uniform(key, (n, ), dtype=dtype)
idx = jax.random.randint(key, shape=(n,), minval=0, maxval=num_segments)
idx_sorted = jnp.sort(idx)

%timeit jax.ops.segment_sum(x, idx, num_segments).block_until_ready()
# shuffled indices: 7.58 ms ± 145 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit jax.ops.segment_sum(x, idx_sorted, num_segments, indices_are_sorted=True).block_until_ready()
# sorted indices: 1.15 s ± 892 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
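
The `%timeit` lines above only run in IPython or a notebook. For anyone who wants to reproduce the comparison from a plain script, here is a rough sketch. The helper name `bench_segment_sum` and its default sizes are assumptions made for illustration (far smaller than the 100M elements in the report), so absolute timings will not match the numbers above, and on CPU the slowdown may not appear at all.

```python
import timeit

import jax
import jax.numpy as jnp


def bench_segment_sum(dtype, n=1_000_000, num_segments=10_000, repeats=5):
    """Return (shuffled_seconds, sorted_seconds) for jax.ops.segment_sum.

    Hypothetical helper: sizes are illustrative, not those from the report.
    """
    # Use separate keys so the data and the indices are not correlated.
    key_x, key_idx = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.uniform(key_x, (n,), dtype=dtype)
    idx = jax.random.randint(key_idx, (n,), 0, num_segments)
    idx_sorted = jnp.sort(idx)

    def run(ids, is_sorted):
        out = jax.ops.segment_sum(
            x, ids, num_segments, indices_are_sorted=is_sorted
        )
        # Wait for the asynchronous computation to finish before timing stops.
        return out.block_until_ready()

    # Warm up once so compilation time is not included in the measurement.
    run(idx, False)
    run(idx_sorted, True)

    # Take the best of several single runs for each case.
    t_shuffled = min(timeit.repeat(lambda: run(idx, False), number=1, repeat=repeats))
    t_sorted = min(timeit.repeat(lambda: run(idx_sorted, True), number=1, repeat=repeats))
    return t_shuffled, t_sorted


if __name__ == "__main__":
    for dt in (jnp.float16, jnp.float32):
        shuffled, sorted_ = bench_segment_sum(dt)
        print(f"{jnp.dtype(dt).name}: shuffled {shuffled * 1e3:.2f} ms, "
              f"sorted {sorted_ * 1e3:.2f} ms")
```

Raising `n` and `num_segments` toward the values in the report should bring the measurement closer to the original setup, memory permitting.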

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.2
python: 3.9.21 (main, Dec  5 2024, 00:00:00)  [GCC 11.5.0 20240719 (Red Hat 11.5.0-2)]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='localhost.localdomain', release='5.14.0-503.22.1.el9_5.x86_64', version='#1 SMP PREEMPT_DYNAMIC Wed Jan 15 08:02:15 EST 2025', machine='x86_64')


$ nvidia-smi
Thu Jan 30 16:30:07 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.86.15              Driver Version: 570.86.15      CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A40                     Off |   00000000:25:00.0 Off |                    0 |
|  0%   39C    P0             76W /  300W |   34806MiB /  46068MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A40                     Off |   00000000:81:00.0 Off |                  Off |
|  0%   38C    P0             76W /  300W |     538MiB /  49140MiB |      3%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            2725      G   /usr/libexec/Xorg                       108MiB |
|    0   N/A  N/A            2815      G   /usr/bin/gnome-shell                      9MiB |
|    0   N/A  N/A           24564      C   /root/rafael/.venv/bin/python         34394MiB |
|    0   N/A  N/A           25966      C   python                                  262MiB |
|    1   N/A  N/A           24564      C   /root/rafael/.venv/bin/python           262MiB |
|    1   N/A  N/A           25966      C   python                                  262MiB |
+-----------------------------------------------------------------------------------------+
@rafaelha rafaelha added the bug Something isn't working label Jan 31, 2025
@pifon2a
Contributor

pifon2a commented Feb 3, 2025

@rafaelha Is it on GPU? If yes, then there was an implementation of sorted/vectorized scatter that landed in XLA this January. Could you try it with a newer JAX?
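
A quick way to answer both questions (which backend is in use, and which JAX version is installed) is sketched below. It only uses public JAX calls; the exact output depends on the installation.

```python
import jax

# Installed JAX version, to check whether a newer release is in use.
print("jax version:", jax.__version__)

# Platform JAX dispatches to by default: "cpu", "gpu", or "tpu".
print("default backend:", jax.default_backend())

# Devices visible to this process.
print("devices:", jax.devices())

# Full environment report, handy for pasting into an issue.
jax.print_environment_info()
```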

@hawkinsp hawkinsp added the NVIDIA GPU Issues specific to NVIDIA GPUs label Feb 3, 2025
@rafaelha
Author

rafaelha commented Feb 4, 2025

Yes, this is running on an A40 GPU. I upgraded JAX to version 0.5.0 (details below) but still see the same slowdown for sorted indices.

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.11.5 (main, Feb  3 2025, 18:20:00) [GCC 11.5.0 20240719 (Red Hat 11.5.0-2)]
device info: NVIDIA A40-2, 2 local devices"
process_count: 1
platform: uname_result(system='Linux', node='localhost.localdomain', release='5.14.0-503.22.1.el9_5.x86_64', version='#1 SMP PREEMPT_DYNAMIC Wed Jan 15 08:02:15 EST 2025', machine='x86_64')


$ nvidia-smi
Mon Feb  3 18:30:57 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.86.15              Driver Version: 570.86.15      CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A40                     Off |   00000000:25:00.0 Off |                    0 |
|  0%   38C    P0             77W /  300W |   34541MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A40                     Off |   00000000:81:00.0 Off |                  Off |
|  0%   36C    P0             74W /  300W |     271MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            2725      G   /usr/libexec/Xorg                       108MiB |
|    0   N/A  N/A            2815      G   /usr/bin/gnome-shell                      9MiB |
|    0   N/A  N/A           91412      C   /root/rafael3/.venv/bin/python        34396MiB |
|    1   N/A  N/A           91412      C   /root/rafael3/.venv/bin/python          262MiB |
+-----------------------------------------------------------------------------------------+
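
Since the description says fp32 is not affected to the same degree, one possible stopgap is to upcast to float32 for the reduction and cast the result back. This is only a sketch under that assumption: the wrapper name `segment_sum_f32_accum` is hypothetical, it has not been benchmarked against this issue, it needs a temporary float32 copy of the data, and its results can differ slightly from accumulating directly in fp16.

```python
import jax
import jax.numpy as jnp


def segment_sum_f32_accum(data, segment_ids, num_segments, indices_are_sorted=False):
    """Sum segments in float32, then cast back to the input dtype.

    Hypothetical workaround, not an official JAX API.
    """
    summed = jax.ops.segment_sum(
        data.astype(jnp.float32),  # upcast so the scatter-add runs in fp32
        segment_ids,
        num_segments,
        indices_are_sorted=indices_are_sorted,
    )
    # Return the same dtype the caller passed in (e.g. float16).
    return summed.astype(data.dtype)
```

Usage mirrors the original call, for example `segment_sum_f32_accum(x, idx_sorted, num_segments, indices_are_sorted=True)`.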

3 participants