diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 51f00c56dd6a..11cfc42e9578 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -639,8 +639,9 @@ def _options_from_jax_configs(plugin_name): "Should be in format 'key:value'") options[option_list[0]] = option_list[1] - if plugin_name == "cuda": - visible_devices = CUDA_VISIBLE_DEVICES.value + if plugin_name in ("cuda", "rocm"): + visible_devices = (CUDA_VISIBLE_DEVICES.value if plugin_name == "cuda" + else _ROCM_VISIBLE_DEVICES.value) if visible_devices != 'all': options['visible_devices'] = [int(x) for x in visible_devices.split(',')] mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None