Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Check "jax_rocm_visible_devices" at client creation.
This aligns rocm with cuda when using jax.distributed in combination with one of the mechanisms for cluster-autodetection that set visible devices in the "jax_rocm_visible_devices" flag. Fixes jax-ml#26298
- Loading branch information