Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705359803
  • Loading branch information
niketkumar authored and Orbax Authors committed Dec 26, 2024
1 parent c227788 commit a8abbfe
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 2 deletions.
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def as_user_metadata(
is_ocdbt_checkpoint=use_ocdbt,
use_zarr3=self.use_zarr3,
ts_context=ts_context,
write_shape=value_meta.write_shape,
)
flat_restore_types[keypath] = value_meta.value_type

Expand Down
2 changes: 2 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from etils import epath
import jax
from jax import numpy as jnp
from orbax.checkpoint._src.arrays import types as arrays_types
from orbax.checkpoint._src.metadata import sharding as sharding_metadata


Expand Down Expand Up @@ -55,6 +56,7 @@ class StorageMetadata:
"""Metadata describing how arrays are stored in a checkpoint."""

chunk_shape: Optional[tuple[int, ...]]
write_shape: arrays_types.Shape | None = None


@dataclasses.dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import dataclasses
from typing import Any, Dict

from orbax.checkpoint._src.arrays import types as arrays_types
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.serialization import types


_VALUE_TYPE = 'value_type'
_SKIP_DESERIALIZE = 'skip_deserialize'
_WRITE_SHAPE = 'write_shape'


@dataclasses.dataclass
Expand All @@ -41,12 +43,16 @@ class ValueMetadataEntry:

value_type: str
skip_deserialize: bool = False
write_shape: arrays_types.Shape | None = None

def to_json(self) -> Dict[str, Any]:
return {
json_dict = {
_VALUE_TYPE: self.value_type,
_SKIP_DESERIALIZE: self.skip_deserialize,
}
if self.write_shape is not None:
json_dict[_WRITE_SHAPE] = self.write_shape
return json_dict

@classmethod
def from_json(
Expand All @@ -60,6 +66,11 @@ def from_json(
pytree_metadata_options,
),
skip_deserialize=json_dict[_SKIP_DESERIALIZE],
write_shape=(
tuple(json_dict[_WRITE_SHAPE])
if _WRITE_SHAPE in json_dict
else None
),
)

@classmethod
Expand All @@ -69,6 +80,7 @@ def build(
save_arg: types.SaveArgs,
) -> ValueMetadataEntry:
"""Builds a ValueMetadataEntry."""
# TODO(niket): Add support for `write_shape`.
del save_arg
if info.value_typestr is None:
raise AssertionError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def _build_array_tspec_write(
dtype=dtype,
target_dtype=(arg.dtype if arg is not None else None),
chunk_byte_size=(arg.chunk_byte_size if arg is not None else None),
shard_axes=(arg.shard_axes if arg is not None else None),
use_zarr3=info.use_zarr3,
use_ocdbt=use_ocdbt,
process_id=process_index,
Expand Down Expand Up @@ -503,7 +504,8 @@ def _array_metadata_from_tensorstore(
dtype=jnp.dtype(t.dtype.name),
sharding=sharding,
storage=value_metadata.StorageMetadata(
chunk_shape=t.chunk_layout.read_chunk_template.shape
chunk_shape=t.chunk_layout.read_chunk_template.shape,
write_shape=info.write_shape,
),
)

Expand Down
8 changes: 8 additions & 0 deletions checkpoint/orbax/checkpoint/_src/serialization/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import future
from orbax.checkpoint._src.arrays import types as arrays_types
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.metadata import value as value_metadata
Expand Down Expand Up @@ -116,6 +117,8 @@ class ParamInfo:
raise_array_data_missing_error:
Only used for restoring. See documentation in `tensorstore_utils.py`. Comes
from tree metadata and should be the same across all parameters.
write_shape:
Shape of the array shard. Used in the subchunking context.
"""

name: Optional[str] = None
Expand All @@ -130,6 +133,7 @@ class ParamInfo:
value_typestr: Optional[str] = None
enable_pinned_host_transfer: bool = True
raise_array_data_missing_error: bool = True
write_shape: arrays_types.Shape | None = None


@dataclasses.dataclass
Expand All @@ -153,11 +157,15 @@ class SaveArgs:
specified chunk_byte_size. Both the write_chunk_shape and read_chunk_shape
are automatically set to the chosen shape. This uses a greedy algorithm that
prioritizes splitting the largest dimensions first.
shard_axes: An optional list of axes that should be prioritized when
sharding array for storage. If empty, storage sharding implementation will
prioritize axes which are already sharded.
"""

aggregate: bool = False
dtype: Optional[jnp.dtype] = None
chunk_byte_size: Optional[int] = None
shard_axes: tuple[int, ...] = tuple()

def __post_init__(self):
if self.aggregate:
Expand Down

0 comments on commit a8abbfe

Please sign in to comment.