Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Need to add because parent CL updates Orbax codebase. #1369

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def _param_info(name, value):
byte_limiter=byte_limiter,
ts_context=ts_context,
value_typestr=types.get_param_typestr(
value, self._type_handler_registry
value, self._type_handler_registry, self._pytree_metadata_options
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ def _get_param_info(
name: str,
meta_or_value: Union[Any, tree_metadata.ValueMetadataEntry],
) -> Union[ParamInfo, Any]:
if empty_values.is_supported_empty_value(meta_or_value):
if empty_values.is_supported_empty_value(
meta_or_value, pytree_metadata_options
):
# Empty node, ParamInfo should not be returned.
return meta_or_value
elif not isinstance(meta_or_value, tree_metadata.ValueMetadataEntry):
Expand Down Expand Up @@ -917,7 +919,7 @@ def _get_internal_metadata(

def _is_empty_value(value):
return empty_values.is_supported_empty_value(
value
value, self._pytree_metadata_options
) or not utils.leaf_is_placeholder(value)

def _process_aggregate_leaf(value):
Expand Down
34 changes: 26 additions & 8 deletions checkpoint/orbax/checkpoint/_src/metadata/empty_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,48 @@

"""Handles empty values in the checkpoint PyTree."""

import collections
from typing import Any, Mapping
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.tree import utils as tree_utils

PyTreeMetadataOptions = pytree_metadata_options_lib.PyTreeMetadataOptions

RESTORE_TYPE_NONE = 'None'
RESTORE_TYPE_DICT = 'Dict'
RESTORE_TYPE_LIST = 'List'
RESTORE_TYPE_TUPLE = 'Tuple'
RESTORE_TYPE_NAMED_TUPLE = 'NamedTuple'
RESTORE_TYPE_UNKNOWN = 'Unknown'
# TODO: b/365169723 - Handle empty NamedTuple.


# TODO: b/365169723 - Handle empty NamedTuple.
def is_supported_empty_value(value: Any) -> bool:
def is_supported_empty_value(
value: Any,
pytree_metadata_options: PyTreeMetadataOptions = (
pytree_metadata_options_lib.PYTREE_METADATA_OPTIONS
),
) -> bool:
"""Determines if the *empty* `value` is supported without custom TypeHandler."""
# Check isinstance first to avoid `not` checks on jax.Arrays (raises error).
if tree_utils.isinstance_of_namedtuple(value):
if pytree_metadata_options.support_rich_types and not value:
return True
return False
return (
isinstance(value, (dict, list, tuple, type(None), Mapping)) and not value
)


# TODO: b/365169723 - Handle empty NamedTuple.
def get_empty_value_typestr(value: Any) -> str:
def get_empty_value_typestr(
value: Any, pytree_metadata_options: PyTreeMetadataOptions
) -> str:
"""Returns the typestr constant for the empty value."""
if not is_supported_empty_value(value):
if not is_supported_empty_value(value, pytree_metadata_options):
raise ValueError(f'{value} is not a supported empty type.')
if isinstance(value, list):
return RESTORE_TYPE_LIST
if tree_utils.isinstance_of_namedtuple(value): # Call before tuple check.
return RESTORE_TYPE_NAMED_TUPLE
if isinstance(value, tuple):
return RESTORE_TYPE_TUPLE
if isinstance(value, (dict, Mapping)):
Expand All @@ -52,20 +65,25 @@ def get_empty_value_typestr(value: Any) -> str:
raise ValueError(f'Unrecognized empty type: {value}.')


# TODO: b/365169723 - Handle empty NamedTuple.
def is_empty_typestr(typestr: str) -> bool:
return (
typestr == RESTORE_TYPE_LIST
or typestr == RESTORE_TYPE_NAMED_TUPLE
or typestr == RESTORE_TYPE_TUPLE
or typestr == RESTORE_TYPE_DICT
or typestr == RESTORE_TYPE_NONE
)


# TODO: b/365169723 - Handle empty NamedTuple.
class OrbaxEmptyNamedTuple(collections.namedtuple('OrbaxEmptyNamedTuple', ())):
pass


def get_empty_value_from_typestr(typestr: str) -> Any:
if typestr == RESTORE_TYPE_LIST:
return []
if typestr == RESTORE_TYPE_NAMED_TUPLE:
return OrbaxEmptyNamedTuple()
if typestr == RESTORE_TYPE_TUPLE:
return tuple()
if typestr == RESTORE_TYPE_DICT:
Expand Down
52 changes: 38 additions & 14 deletions checkpoint/orbax/checkpoint/_src/metadata/empty_values_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,51 @@
from absl.testing import absltest
from absl.testing import parameterized
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import pytree_metadata_options
from orbax.checkpoint._src.testing import test_tree_utils


class EmptyValuesTest(parameterized.TestCase):

@parameterized.parameters(
(1, False),
(dict(), True),
({}, True),
({"a": {}}, False),
([], True),
([[]], False),
(None, True),
((1, 2), False),
(test_tree_utils.EmptyNamedTuple(), False),
(test_tree_utils.MuNu(mu=None, nu=None), False),
(test_tree_utils.NamedTupleWithNestedAttributes(), False),
(test_tree_utils.NamedTupleWithNestedAttributes(nested_dict={}), False),
(1, False, False),
(dict(), True, True),
({}, True, True),
({"a": {}}, False, False),
([], True, True),
([[]], False, False),
(None, True, True),
((1, 2), False, False),
(test_tree_utils.EmptyNamedTuple(), False, True),
(test_tree_utils.MuNu(mu=None, nu=None), False, False),
(test_tree_utils.NamedTupleWithNestedAttributes(), False, False),
(
test_tree_utils.NamedTupleWithNestedAttributes(nested_dict={}),
False,
False,
),
)
def test_is_supported_empty_value(self, value, expected):
self.assertEqual(expected, empty_values.is_supported_empty_value(value))
def test_is_supported_empty_value(self, value, expected, expected_rich_type):
with self.subTest("legacy_metadata"):
self.assertEqual(
expected,
empty_values.is_supported_empty_value(
value,
pytree_metadata_options.PyTreeMetadataOptions(
support_rich_types=False
),
),
)
with self.subTest("rich_typed_metadata"):
self.assertEqual(
expected_rich_type,
empty_values.is_supported_empty_value(
value,
pytree_metadata_options.PyTreeMetadataOptions(
support_rich_types=True
),
),
)


if __name__ == "__main__":
Expand Down
43 changes: 10 additions & 33 deletions checkpoint/orbax/checkpoint/_src/metadata/tree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@
from orbax.checkpoint._src.tree import utils as tree_utils


def _to_param_infos(tree: Any):
def _to_param_infos(
tree: Any,
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions,
):
return jax.tree.map(
# Other properties are not relevant.
lambda x: types.ParamInfo(
value_typestr=types.get_param_typestr(
x, type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY
x,
type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
pytree_metadata_options,
)
),
tree,
Expand All @@ -40,38 +45,10 @@ def _to_param_infos(tree: Any):

class InternalTreeMetadataEntryTest(parameterized.TestCase):

@parameterized.named_parameters(
test_tree_utils.TEST_PYTREES_FOR_NAMED_PARAMETERS
)
def test_json_conversion(self, test_pytree: test_tree_utils.TestPyTree):
tree = test_pytree.provide_tree()
param_infos_tree = _to_param_infos(tree)
internal_tree_metadata = tree_metadata.InternalTreeMetadata.build(
param_infos_tree
)
internal_tree_metadata_json = internal_tree_metadata.to_json()

# Round trip check for json conversion.
self.assertCountEqual(
internal_tree_metadata.tree_metadata_entries,
(
tree_metadata.InternalTreeMetadata.from_json(
internal_tree_metadata_json
)
).tree_metadata_entries,
)

# Specifically check _TREE_METADATA_KEY.
self.assertDictEqual(
test_pytree.expected_tree_metadata_key_json,
internal_tree_metadata_json[tree_metadata._TREE_METADATA_KEY],
)

@parameterized.product(
test_pytree=test_tree_utils.TEST_PYTREES,
pytree_metadata_options=[
# TODO: b/365169723 - Re-enable if needed at all.
# tree_metadata.PyTreeMetadataOptions(support_rich_types=False),
tree_metadata.PyTreeMetadataOptions(support_rich_types=False),
tree_metadata.PyTreeMetadataOptions(support_rich_types=True),
],
)
Expand All @@ -82,7 +59,7 @@ def test_as_nested_tree(
):
tree = test_pytree.provide_tree()
original_internal_tree_metadata = tree_metadata.InternalTreeMetadata.build(
param_infos=_to_param_infos(tree),
param_infos=_to_param_infos(tree, pytree_metadata_options),
pytree_metadata_options=pytree_metadata_options,
)
json_object = original_internal_tree_metadata.to_json()
Expand All @@ -97,7 +74,7 @@ def test_as_nested_tree(
test_pytree.expected_nested_tree_metadata_with_rich_types
)
else:
raise NotImplementedError('Test to be added for non-rich types.')
expected_tree_metadata = test_pytree.expected_nested_tree_metadata
restored_tree_metadata = restored_internal_tree_metadata.as_nested_tree()
logging.info('expected_tree_metadata: \n%s', expected_tree_metadata)
logging.info('restored_tree_metadata: \n%s', restored_tree_metadata)
Expand Down
24 changes: 19 additions & 5 deletions checkpoint/orbax/checkpoint/_src/serialization/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,37 @@
import numpy as np
from orbax.checkpoint import future
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.metadata import value as value_metadata
from orbax.checkpoint._src.serialization import serialization
import tensorstore as ts

PyTreeMetadataOptions = pytree_metadata_options_lib.PyTreeMetadataOptions

def is_supported_type(value: Any) -> bool:

def is_supported_type(
value: Any,
pytree_metadata_options: PyTreeMetadataOptions = (
pytree_metadata_options_lib.PYTREE_METADATA_OPTIONS
),
) -> bool:
"""Determines if the `value` is supported without custom TypeHandler."""
return isinstance(
value,
(str, int, float, np.number, np.ndarray, bytes, jax.Array),
) or empty_values.is_supported_empty_value(value)
) or empty_values.is_supported_empty_value(value, pytree_metadata_options)


def get_param_typestr(value: Any, registry: TypeHandlerRegistry) -> str:
def get_param_typestr(
value: Any,
registry: TypeHandlerRegistry,
pytree_metadata_options: PyTreeMetadataOptions,
) -> str:
"""Retrieves the typestr for a given value."""
if empty_values.is_supported_empty_value(value):
typestr = empty_values.get_empty_value_typestr(value)
if empty_values.is_supported_empty_value(value, pytree_metadata_options):
typestr = empty_values.get_empty_value_typestr(
value, pytree_metadata_options
)
else:
try:
handler = registry.get(type(value))
Expand Down
Loading
Loading