Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705359803
  • Loading branch information
niketkumar authored and Orbax Authors committed Jan 8, 2025
1 parent db6bd4a commit 1974e62
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def _build_array_tspec_write(
dtype=dtype,
target_dtype=(arg.dtype if arg is not None else None),
chunk_byte_size=(arg.chunk_byte_size if arg is not None else None),
shard_axes=(arg.shard_axes if arg is not None else None),
use_zarr3=info.use_zarr3,
use_ocdbt=use_ocdbt,
process_id=process_index,
Expand Down
4 changes: 4 additions & 0 deletions checkpoint/orbax/checkpoint/_src/serialization/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,15 @@ class SaveArgs:
specified chunk_byte_size. Both the write_chunk_shape and read_chunk_shape
are automatically set to the chosen shape. This uses a greedy algorithm that
prioritizes splitting the largest dimensions first.
shard_axes: An optional list of axes that should be prioritized when
sharding array for storage. If empty, storage sharding implementation will
prioritize axes which are already sharded.
"""

aggregate: bool = False
dtype: Optional[jnp.dtype] = None
chunk_byte_size: Optional[int] = None
shard_axes: tuple[int, ...] = tuple()

def __post_init__(self):
if self.aggregate:
Expand Down

0 comments on commit 1974e62

Please sign in to comment.