Check "jax_rocm_visible_devices" at client creation.
This aligns rocm with cuda when using jax.distributed in combination
with one of the mechanisms for cluster-autodetection that set visible
devices in the "jax_rocm_visible_devices" flag.

Fixes jax-ml#26298
Sebastian Kehl committed Feb 5, 2025
1 parent c07b6b5 commit 68fdba7
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions jax/_src/xla_bridge.py
@@ -639,8 +639,9 @@ def _options_from_jax_configs(plugin_name):
           "Should be in format 'key:value'")
     options[option_list[0]] = option_list[1]

-  if plugin_name == "cuda":
-    visible_devices = CUDA_VISIBLE_DEVICES.value
+  if plugin_name in ("cuda", "rocm"):
+    visible_devices = (CUDA_VISIBLE_DEVICES.value if plugin_name == "cuda"
+                       else _ROCM_VISIBLE_DEVICES.value)
     if visible_devices != 'all':
       options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
     mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None
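
For readers who want to try the selection logic outside of JAX, here is a minimal standalone sketch of what the changed branch does. The function name `visible_devices_option` and its parameters are hypothetical and stand in for the `CUDA_VISIBLE_DEVICES` and `_ROCM_VISIBLE_DEVICES` flag values; this is not the code in `xla_bridge.py`, only an illustration under those assumptions.

```python
def visible_devices_option(plugin_name, cuda_visible="all", rocm_visible="all"):
    """Return the device-index list for 'visible_devices', or None.

    Hypothetical helper: cuda_visible and rocm_visible play the role of the
    flag values read in the diff above (default 'all' means no restriction).
    """
    # Only the GPU plugins honour a visible-devices flag.
    if plugin_name not in ("cuda", "rocm"):
        return None
    # Pick the flag value that belongs to the requested plugin.
    value = cuda_visible if plugin_name == "cuda" else rocm_visible
    # 'all' means every device is visible, so no option is set.
    if value == "all":
        return None
    # Otherwise parse a comma-separated list of device indices.
    return [int(x) for x in value.split(",")]


if __name__ == "__main__":
    print(visible_devices_option("rocm", rocm_visible="0,1"))
    print(visible_devices_option("cuda", cuda_visible="2"))
    print(visible_devices_option("rocm"))
```

Before this commit, the equivalent of the `"rocm"` case never consulted its flag, so a device list set through cluster autodetection had no effect on the ROCm client.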