Skip to content

Commit

Permalink
Fix segfault when old GPU plugins are installed.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726908959
  • Loading branch information
dfm authored and Google-ML-Automation committed Feb 14, 2025
1 parent 4df5961 commit 1755c43
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
8 changes: 6 additions & 2 deletions jaxlib/gpu_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@

if _cuda_linalg:
for _name, _value in _cuda_linalg.registrations().items():
# TODO(danfm): remove after JAX 0.5.1 release
api_version = 1 if "_ffi" in _name else 0
xla_client.register_custom_call_target(
_name, _value, platform="CUDA", api_version=1
_name, _value, platform="CUDA", api_version=api_version
)

if _hip_linalg:
for _name, _value in _hip_linalg.registrations().items():
# TODO(danfm): remove after JAX 0.5.1 release
api_version = 1 if "_ffi" in _name else 0
xla_client.register_custom_call_target(
_name, _value, platform="ROCM", api_version=1
_name, _value, platform="ROCM", api_version=api_version
)
6 changes: 4 additions & 2 deletions jaxlib/gpu_prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@

if _cuda_prng:
for _name, _value in _cuda_prng.registrations().items():
# TODO(danfm): remove after JAX 0.5.1 release
api_version = 1 if "_ffi" in _name else 0
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=1)
api_version=api_version)

for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
try:
Expand All @@ -51,7 +53,7 @@

if _hip_prng:
for _name, _value in _hip_prng.registrations().items():
# TODO(b/338022728): remove after 6 months, always api_version=1
# TODO(danfm): remove after JAX 0.5.1 release
api_version = 1 if "_ffi" in _name else 0
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
api_version=api_version)
Expand Down

0 comments on commit 1755c43

Please sign in to comment.