Skip to content

Commit

Permalink
Simplify Root/StepMetadata serialization.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713090169
  • Loading branch information
BlaziusMaximus authored and Orbax Authors committed Jan 8, 2025
1 parent db6bd4a commit 28e72df
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 66 deletions.
67 changes: 26 additions & 41 deletions checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,22 @@ def get_metadata_file_path(
elif metadata_type == RootMetadata:
return checkpoint.root_metadata_file_path(path or self.directory)

def assertMetadataEqual(
    self, a: StepMetadata | RootMetadata, b: StepMetadata | RootMetadata,
):
  """Asserts that two metadata objects agree on every compared field.

  For `StepMetadata` the compared fields are `format`, `item_handlers`,
  `metrics`, `performance_metrics`, `init_timestamp_nsecs`,
  `commit_timestamp_nsecs` and `custom`; `item_metadata` is deliberately
  ignored. For `RootMetadata` the compared fields are `format` and `custom`.

  Args:
    a: The first metadata object; its type selects which fields are compared.
    b: The second metadata object; must be of the same metadata type as `a`.

  Raises:
    AssertionError: If `b` is not the same metadata type as `a`, if `a` is
      neither a `StepMetadata` nor a `RootMetadata`, or if any compared field
      differs.
  """
  if isinstance(a, StepMetadata):
    self.assertIsInstance(b, StepMetadata)
    # item_metadata is intentionally left out of the comparison.
    compared_fields = (
        'format',
        'item_handlers',
        'metrics',
        'performance_metrics',
        'init_timestamp_nsecs',
        'commit_timestamp_nsecs',
        'custom',
    )
  elif isinstance(a, RootMetadata):
    self.assertIsInstance(b, RootMetadata)
    compared_fields = ('format', 'custom')
  else:
    # Without this branch an unexpected value (e.g. None) would pass the
    # helper without a single assertion being made.
    self.fail(f'Unsupported metadata type: {type(a).__name__}')

  for field_name in compared_fields:
    self.assertEqual(
        getattr(a, field_name),
        getattr(b, field_name),
        msg=f'Metadata field {field_name!r} differs.',
    )

@parameterized.parameters(True, False)
def test_read_unknown_path(self, blocking_write: bool):
self.assertIsNone(
Expand Down Expand Up @@ -221,7 +237,7 @@ def test_read_default_values(
serialized_metadata = self.write_metadata_store(blocking_write).read(
file_path=self.get_metadata_file_path(metadata_class),
)
self.assertEqual(
self.assertMetadataEqual(
self.deserialize_metadata(metadata_class, serialized_metadata),
metadata,
)
Expand Down Expand Up @@ -249,7 +265,7 @@ def test_read_with_values(
serialized_metadata = self.write_metadata_store(blocking_write).read(
file_path=self.get_metadata_file_path(metadata_class),
)
self.assertEqual(
self.assertMetadataEqual(
self.deserialize_metadata(metadata_class, serialized_metadata),
metadata,
)
Expand Down Expand Up @@ -296,7 +312,7 @@ def test_update_without_prior_data(
serialized_metadata = self.write_metadata_store(blocking_write).read(
file_path=self.get_metadata_file_path(metadata_class),
)
self.assertEqual(
self.assertMetadataEqual(
self.deserialize_metadata(metadata_class, serialized_metadata),
metadata_class(
format=_SAMPLE_FORMAT,
Expand Down Expand Up @@ -331,7 +347,7 @@ def test_update_with_prior_data(
serialized_metadata = self.write_metadata_store(blocking_write).read(
file_path=self.get_metadata_file_path(metadata_class)
)
self.assertEqual(
self.assertMetadataEqual(
self.deserialize_metadata(metadata_class, serialized_metadata),
metadata_class(
format=_SAMPLE_FORMAT,
Expand Down Expand Up @@ -366,7 +382,7 @@ def test_update_with_unknown_kwargs(
serialized_metadata = self.write_metadata_store(blocking_write).read(
file_path=self.get_metadata_file_path(metadata_class)
)
self.assertEqual(
self.assertMetadataEqual(
self.deserialize_metadata(metadata_class, serialized_metadata),
metadata_class(
format=_SAMPLE_FORMAT,
Expand Down Expand Up @@ -417,7 +433,7 @@ def test_non_blocking_write_request_enables_writes(
serialized_metadata = self.read_metadata_store(blocking_write=True).read(
file_path=self.get_metadata_file_path(metadata_class)
)
self.assertEqual(
self.assertMetadataEqual(
self.deserialize_metadata(metadata_class, serialized_metadata),
self.get_metadata(metadata_class),
)
Expand All @@ -436,14 +452,14 @@ def test_non_blocking_write_request_enables_writes(
serialized_metadata = self.read_metadata_store(blocking_write=False).read(
file_path=self.get_metadata_file_path(metadata_class)
)
self.assertEqual(
self.assertMetadataEqual(
self.deserialize_metadata(metadata_class, serialized_metadata),
self.get_metadata(metadata_class, custom={'a': 2}),
)
serialized_metadata = self.write_metadata_store(blocking_write=False).read(
file_path=self.get_metadata_file_path(metadata_class)
)
self.assertEqual(
self.assertMetadataEqual(
self.deserialize_metadata(metadata_class, serialized_metadata),
self.get_metadata(metadata_class, custom={'a': 2}),
)
Expand All @@ -457,14 +473,14 @@ def test_non_blocking_write_request_enables_writes(
serialized_metadata = self.read_metadata_store(blocking_write=False).read(
file_path=self.get_metadata_file_path(metadata_class)
)
self.assertEqual(
self.assertMetadataEqual(
self.deserialize_metadata(metadata_class, serialized_metadata),
self.get_metadata(metadata_class, custom={'a': 3}),
)
serialized_metadata = self.write_metadata_store(blocking_write=False).read(
file_path=self.get_metadata_file_path(metadata_class)
)
self.assertEqual(
self.assertMetadataEqual(
self.deserialize_metadata(metadata_class, serialized_metadata),
self.get_metadata(metadata_class, custom={'a': 3}),
)
Expand Down Expand Up @@ -528,37 +544,6 @@ def test_deserialize_wrong_types_step_metadata(
with self.assertRaises(ValueError):
self.deserialize_metadata(StepMetadata, wrong_metadata)

@parameterized.parameters(
(
RootMetadata(custom={'a': None}),
{'custom': {'a': None}}
),
(
RootMetadata(format=_SAMPLE_FORMAT),
{'format': _SAMPLE_FORMAT}
),
(
StepMetadata(format=_SAMPLE_FORMAT),
{'format': _SAMPLE_FORMAT},
),
(
StepMetadata(item_handlers={'a': 'a_handler'}),
{'item_handlers': {'a': 'a_handler'}},
),
(
StepMetadata(custom={'blah': 123}),
{'custom': {'blah': 123}},
),
)
def test_only_serialize_non_default_metadata_values(
self,
metadata: StepMetadata | RootMetadata,
expected_serialized_metadata: dict[str, Any],
):
self.assertEqual(
self.serialize_metadata(metadata), expected_serialized_metadata
)

@parameterized.parameters(StepMetadata, RootMetadata)
def test_unknown_key_in_metadata(
self, metadata_class: type[StepMetadata] | type[RootMetadata],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@

def serialize(metadata: RootMetadata) -> SerializedMetadata:
  """Serializes `metadata` to a dictionary.

  Both fields are always emitted, whatever their value, so the result always
  has exactly the keys `format` and `custom`.

  Args:
    metadata: The root metadata whose `format` and `custom` attributes are
      read.

  Returns:
    A dictionary mapping `'format'` to `metadata.format` and `'custom'` to
    `metadata.custom`.
  """
  return {
      'format': metadata.format,
      'custom': metadata.custom,
  }


def deserialize(metadata_dict: SerializedMetadata) -> RootMetadata:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,15 @@ def serialize(metadata: StepMetadata) -> SerializedMetadata:
if isinstance(val, float)
}

serialized_metadata = {}
if metadata.format is not None:
serialized_metadata['format'] = metadata.format
if metadata.item_handlers:
serialized_metadata['item_handlers'] = metadata.item_handlers
if metadata.metrics:
serialized_metadata['metrics'] = metadata.metrics
if float_metrics:
serialized_metadata['performance_metrics'] = float_metrics
if metadata.init_timestamp_nsecs is not None:
serialized_metadata['init_timestamp_nsecs'] = metadata.init_timestamp_nsecs
if metadata.commit_timestamp_nsecs is not None:
serialized_metadata['commit_timestamp_nsecs'] = (
metadata.commit_timestamp_nsecs
)
if metadata.custom:
serialized_metadata['custom'] = metadata.custom

return serialized_metadata
return {
'format': metadata.format,
'item_handlers': metadata.item_handlers,
'metrics': metadata.metrics,
'performance_metrics': float_metrics,
'init_timestamp_nsecs': metadata.init_timestamp_nsecs,
'commit_timestamp_nsecs': metadata.commit_timestamp_nsecs,
'custom': metadata.custom,
}


def deserialize(
Expand All @@ -87,6 +77,8 @@ def deserialize(
validated_metadata_dict['item_handlers'] = item_handlers
elif isinstance(item_handlers, CheckpointHandlerTypeStr):
validated_metadata_dict['item_handlers'] = item_handlers
elif item_handlers is None:
validated_metadata_dict['item_handlers'] = None

utils.validate_field(metadata_dict, 'item_metadata', [dict, str])
dict_item_metadata = metadata_dict.get('item_metadata')
Expand Down

0 comments on commit 28e72df

Please sign in to comment.