Fix str + repr of multi-slice Devices #26418

Open
skye opened this issue Feb 8, 2025 · 0 comments
Labels: enhancement (New feature or request), good first issue (Good for newcomers)

skye (Member) commented Feb 8, 2025

I think multi-slice TPU devices should still be called TpuDevice instead of MegaScalePjRtDevice. You can already tell that a device is part of a multi-slice configuration because it has a slice_id. Right now the output is very verbose:

```
In [2]: d = jax.devices()[0]

In [3]: str(d)
Out[3]: 'MegaScalePjRtDevice(wrapped=TPU_0(process=0,(0,0,0,0)), slice_id=0)'

In [4]: repr(d)
Out[4]: 'MegaScalePjRtDevice(wrapped=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), slice_id=0)'
```
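For illustration, here is a minimal Python sketch of the kind of output being proposed: the multi-slice device keeps the single-slice TPU_…/TpuDevice formatting and simply appends slice_id, instead of nesting everything inside a MegaScalePjRtDevice(wrapped=…) wrapper. This does not reflect how the real device types are implemented; TpuDeviceInfo, MultiSliceTpuDevice, and the exact output strings are hypothetical, and reading the fourth element of the str tuple as core_on_chip is an assumption.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TpuDeviceInfo:
    """Hypothetical stand-in for the fields of the wrapped TPU device."""
    id: int
    process_index: int
    coords: Tuple[int, int, int]
    core_on_chip: int


def _fmt_coords(coords) -> str:
    # Render a tuple without spaces, matching the existing "(0,0,0)" style.
    return "(" + ",".join(str(c) for c in coords) + ")"


@dataclass(frozen=True)
class MultiSliceTpuDevice:
    """Hypothetical multi-slice device: a single-slice device plus a slice_id."""
    wrapped: TpuDeviceInfo
    slice_id: int

    def __str__(self) -> str:
        w = self.wrapped
        # Assumption: the 4-tuple in the single-slice str is (x, y, z, core_on_chip).
        chip = _fmt_coords(w.coords + (w.core_on_chip,))
        # Keep the familiar TPU_<id>(...) form and append slice_id at the end.
        return f"TPU_{w.id}(process={w.process_index},{chip}, slice_id={self.slice_id})"

    def __repr__(self) -> str:
        w = self.wrapped
        # Report the device as a TpuDevice; slice_id marks it as multi-slice.
        return (
            f"TpuDevice(id={w.id}, process_index={w.process_index}, "
            f"coords={_fmt_coords(w.coords)}, core_on_chip={w.core_on_chip}, "
            f"slice_id={self.slice_id})"
        )


if __name__ == "__main__":
    d = MultiSliceTpuDevice(TpuDeviceInfo(0, 0, (0, 0, 0), 0), slice_id=0)
    print(str(d))   # TPU_0(process=0,(0,0,0,0), slice_id=0)
    print(repr(d))  # TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0, slice_id=0)
```

The design choice here is to reuse the single-slice format rather than wrap it, so users see the same device name regardless of topology, while the trailing slice_id still makes the multi-slice setup visible.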
skye added the enhancement and good first issue labels on Feb 8, 2025