
Rewriting the kernel in Triton hinders the overall optimization benefits of persistent caching #26304

Open
MelodicDrumstep opened this issue Feb 4, 2025 · 0 comments
Labels
bug Something isn't working
MelodicDrumstep commented Feb 4, 2025

Description

Hello! I have just started working with JAX and am currently doing a GPU inference optimization project based on AlphaFold 3. I tried to speed up the program using the persistent compilation cache by adding the following configuration (no other modifications were made):

jax.config.update("jax_compilation_cache_dir", my_cache_dir)
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In the version where all operators are compiled by XLA, the second run of the program after enabling the persistent cache shows a significant speedup (around 2.5x faster than the first run). However, when I replace some of the operators with Triton implementations and use the same persistent-cache configuration, the second run shows no improvement over the first. (Moreover, in this scenario a new cache file is generated on every run instead of the previously created file being reused.)
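For reference, here is a minimal standalone sketch (not AlphaFold 3 itself) of how I would expect the cache to behave: run the same script twice with the same cache directory and compare the time of the first call. The function, shapes, and cache path below are arbitrary and only for illustration.

import time
import jax
import jax.numpy as jnp

# Hypothetical cache directory for this illustration.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

@jax.jit
def f(x):
    return jnp.sin(x) @ jnp.cos(x).T

x = jnp.ones((2048, 2048))
start = time.perf_counter()
f(x).block_until_ready()  # first call either compiles or loads from the persistent cache
print(f"first call took {time.perf_counter() - start:.2f}s")

On the second invocation of such a script, I would expect the first call to be noticeably faster if the executable was restored from the cache.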

Does this indicate that the persistent-cache implementation in the JAX framework is incompatible with Triton kernels? If not, what could be the reasons for the cache being ineffective here? I am not sure whether this is instead related to the internal implementation of AlphaFold 3.
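If it helps, a small stand-in for the Triton path could be a tiny Pallas kernel (Pallas lowers to Triton on NVIDIA GPUs); this is only a hedged sketch, not the actual AlphaFold 3 flash-attention kernel. Running a script like this twice with the persistent-cache configuration above should show whether an executable containing a Triton-backed custom kernel is reloaded from the cache or recompiled.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Element-wise add over the whole block.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.ones((1024, 1024), jnp.float32)
print(add(x, x)[0, 0])  # expect 2.0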

Appendix:

  • AlphaFold 3 version: 3.0.0
  • Command:
    For AlphaFold 3, the flash attention operator is implemented with Triton by default. To use the XLA implementation, the --flash_attention_implementation=xla argument must be added to the run command.
    The command to run the XLA version is python run_alphafold.py --json_path=.../XXX.json --model_dir=... --norun_data_pipeline --output_dir=... --flash_attention_implementation=xla
    The command to run the Triton version is python run_alphafold.py --json_path=.../XXX.json --model_dir=... --norun_data_pipeline --output_dir=...
    (file paths omitted)

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.34
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0]
jax.devices (4 total, 4 local): [CudaDevice(id=0) CudaDevice(id=1) CudaDevice(id=2) CudaDevice(id=3)]
process_count: 1
platform: uname_result(system='Linux', node='gpu20.pi.sjtu.edu.cn', release='4.18.0-513.9.1.el8_9.x86_64', version='#1 SMP Wed Nov 29 18:55:19 UTC 2023', machine='x86_64')

$ nvidia-smi
Tue Feb 4 20:51:24 2025
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100-SXM4-40GB On | 00000000:31:00.0 Off | 0 |
| N/A 39C P0 58W / 400W | 430MiB / 40960MiB | 3% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA A100-SXM4-40GB On | 00000000:4B:00.0 Off | 0 |
| N/A 39C P0 64W / 400W | 430MiB / 40960MiB | 1% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA A100-SXM4-40GB On | 00000000:CA:00.0 Off | 0 |
| N/A 40C P0 63W / 400W | 430MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA A100-SXM4-40GB On | 00000000:E3:00.0 Off | 0 |
| N/A 39C P0 56W / 400W | 430MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 874029 C python3 416MiB |
| 1 N/A N/A 874029 C python3 416MiB |
| 2 N/A N/A 874029 C python3 416MiB |
| 3 N/A N/A 874029 C python3 416MiB |
+---------------------------------------------------------------------------------------+
