Description
Hello! I have just started working with JAX and am currently doing a GPU inference optimization project related to alphafold3. I tried to optimize the program with the persistent compilation cache and added the following configuration (no other modifications were made):
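Roughly the following; this is a minimal sketch using the standard JAX persistent-cache options, with a placeholder cache directory rather than my actual path:

# Minimal sketch of the persistent compilation cache setup
# (placeholder cache directory; the two threshold options force
# every compiled module to be written to the cache).
import jax

jax.config.update("jax_compilation_cache_dir", "/path/to/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)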
In the version where all operators are compiled with XLA, the second run after adding the persistent cache configuration is significantly faster (about 2.5x faster than the first run). However, when I replaced some of the operators with Triton implementations and used the same persistent cache configuration, the second run showed no improvement over the first. (Moreover, in this case a new cache file is generated on every run instead of reusing the previously created one.)
Does this indicate that the persistent cache implementation inside the JAX framework is incompatible with Triton kernels? If not, what could make it ineffective here? I am also not sure whether this is more related to the internal implementation of alphafold3.
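For reference, this is how I check whether the cache is read or written on each run (a sketch assuming the documented jax_debug_log_modules option; the exact log messages may differ):

# Enable debug logging for the compilation cache so that cache
# reads/writes are printed on each run (module name as documented
# for JAX's persistent compilation cache).
import jax

jax.config.update("jax_debug_log_modules", "jax._src.compilation_cache")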
Appendix:
Alphafold3 version: 3.0.0
For alphafold3, the flash attention operator is implemented with Triton by default. To use the XLA implementation, the --flash_attention_implementation=xla argument must be added to the run command.
The command to run the XLA version is:
python run_alphafold.py --json_path=.../XXX.json --model_dir=... --norun_data_pipeline --output_dir=... --flash_attention_implementation=xla
The command to run the Triton version is:
python run_alphafold.py --json_path=.../XXX.json --model_dir=... --norun_data_pipeline --output_dir=...
(omitting file paths)
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.34
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0]
jax.devices (4 total, 4 local): [CudaDevice(id=0) CudaDevice(id=1) CudaDevice(id=2) CudaDevice(id=3)]
process_count: 1
platform: uname_result(system='Linux', node='gpu20.pi.sjtu.edu.cn', release='4.18.0-513.9.1.el8_9.x86_64', version='#1 SMP Wed Nov 29 18:55:19 UTC 2023', machine='x86_64')
$ nvidia-smi
Tue Feb 4 20:51:24 2025
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100-SXM4-40GB On | 00000000:31:00.0 Off | 0 |
| N/A 39C P0 58W / 400W | 430MiB / 40960MiB | 3% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA A100-SXM4-40GB On | 00000000:4B:00.0 Off | 0 |
| N/A 39C P0 64W / 400W | 430MiB / 40960MiB | 1% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA A100-SXM4-40GB On | 00000000:CA:00.0 Off | 0 |
| N/A 40C P0 63W / 400W | 430MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA A100-SXM4-40GB On | 00000000:E3:00.0 Off | 0 |
| N/A 39C P0 56W / 400W | 430MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 874029 C python3 416MiB |
| 1 N/A N/A 874029 C python3 416MiB |
| 2 N/A N/A 874029 C python3 416MiB |
| 3 N/A N/A 874029 C python3 416MiB |
+---------------------------------------------------------------------------------------+