From 819d6110ceab494d48cfcb8dfc72a94a9ad8187a Mon Sep 17 00:00:00 2001 From: Niket Kumar Bhumihar Date: Wed, 11 Dec 2024 21:40:51 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 705359803 --- .../orbax/checkpoint/_src/serialization/type_handlers.py | 1 + checkpoint/orbax/checkpoint/_src/serialization/types.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py index 0bae33d05..d892e9276 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py @@ -228,6 +228,7 @@ def _build_array_tspec_write( dtype=dtype, target_dtype=(arg.dtype if arg is not None else None), chunk_byte_size=(arg.chunk_byte_size if arg is not None else None), + shard_axes=(arg.shard_axes if arg is not None else None), use_zarr3=info.use_zarr3, use_ocdbt=use_ocdbt, process_id=process_index, diff --git a/checkpoint/orbax/checkpoint/_src/serialization/types.py b/checkpoint/orbax/checkpoint/_src/serialization/types.py index 77beff4e4..fbe7128bb 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/types.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/types.py @@ -157,11 +157,15 @@ class SaveArgs: specified chunk_byte_size. Both the write_chunk_shape and read_chunk_shape are automatically set to the chosen shape. This uses a greedy algorithm that prioritizes splitting the largest dimensions first. + shard_axes: An optional list of axes that should be prioritized when + sharding array for storage. If empty, storage sharding implementation will + prioritize axes which are already sharded. """ aggregate: bool = False dtype: Optional[jnp.dtype] = None chunk_byte_size: Optional[int] = None + shard_axes: tuple[int, ...] = tuple() def __post_init__(self): if self.aggregate: