Patch summary (free text ahead of the first file header):

  * checkpoint_test.py: adds an assertMetadataEqual helper that compares
    StepMetadata / RootMetadata field by field (item_metadata is deliberately
    not compared), switches the read/update/non-blocking-write tests from
    assertEqual to that helper, and removes
    test_only_serialize_non_default_metadata_values.
  * root_metadata_serialization.py, step_metadata_serialization.py:
    serialize() now always returns every key instead of only the keys whose
    values are set.
  * step_metadata_serialization.py: deserialize() gains an explicit branch
    that stores item_handlers=None when the field is None.

NOTE(review): line structure was restored from a whitespace-collapsed copy.
Line breaks are validated against the hunk header counts; leading indentation
is reconstructed (2-space indent, 4-space continuation) and should be
confirmed against the repository before applying.

diff --git a/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py b/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py
index d3ae79884..5e33e0405 100644
--- a/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py
+++ b/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py
@@ -140,6 +140,22 @@ def get_metadata_file_path(
     elif metadata_type == RootMetadata:
       return checkpoint.root_metadata_file_path(path or self.directory)
 
+  def assertMetadataEqual(
+      self, a: StepMetadata | RootMetadata, b: StepMetadata | RootMetadata,
+  ):
+    if isinstance(a, StepMetadata):
+      self.assertEqual(a.format, b.format)
+      self.assertEqual(a.item_handlers, b.item_handlers)
+      # ignore item_metadata
+      self.assertEqual(a.metrics, b.metrics)
+      self.assertEqual(a.performance_metrics, b.performance_metrics)
+      self.assertEqual(a.init_timestamp_nsecs, b.init_timestamp_nsecs)
+      self.assertEqual(a.commit_timestamp_nsecs, b.commit_timestamp_nsecs)
+      self.assertEqual(a.custom, b.custom)
+    elif isinstance(a, RootMetadata):
+      self.assertEqual(a.format, b.format)
+      self.assertEqual(a.custom, b.custom)
+
   @parameterized.parameters(True, False)
   def test_read_unknown_path(self, blocking_write: bool):
     self.assertIsNone(
@@ -221,7 +237,7 @@ def test_read_default_values(
     serialized_metadata = self.write_metadata_store(blocking_write).read(
         file_path=self.get_metadata_file_path(metadata_class),
     )
-    self.assertEqual(
+    self.assertMetadataEqual(
         self.deserialize_metadata(metadata_class, serialized_metadata),
         metadata,
     )
@@ -249,7 +265,7 @@ def test_read_with_values(
     serialized_metadata = self.write_metadata_store(blocking_write).read(
         file_path=self.get_metadata_file_path(metadata_class),
     )
-    self.assertEqual(
+    self.assertMetadataEqual(
         self.deserialize_metadata(metadata_class, serialized_metadata),
         metadata,
     )
@@ -296,7 +312,7 @@ def test_update_without_prior_data(
     serialized_metadata = self.write_metadata_store(blocking_write).read(
         file_path=self.get_metadata_file_path(metadata_class),
     )
-    self.assertEqual(
+    self.assertMetadataEqual(
         self.deserialize_metadata(metadata_class, serialized_metadata),
         metadata_class(
             format=_SAMPLE_FORMAT,
@@ -331,7 +347,7 @@ def test_update_with_prior_data(
     serialized_metadata = self.write_metadata_store(blocking_write).read(
         file_path=self.get_metadata_file_path(metadata_class)
     )
-    self.assertEqual(
+    self.assertMetadataEqual(
         self.deserialize_metadata(metadata_class, serialized_metadata),
         metadata_class(
             format=_SAMPLE_FORMAT,
@@ -366,7 +382,7 @@ def test_update_with_unknown_kwargs(
     serialized_metadata = self.write_metadata_store(blocking_write).read(
         file_path=self.get_metadata_file_path(metadata_class)
     )
-    self.assertEqual(
+    self.assertMetadataEqual(
         self.deserialize_metadata(metadata_class, serialized_metadata),
         metadata_class(
             format=_SAMPLE_FORMAT,
@@ -417,7 +433,7 @@ def test_non_blocking_write_request_enables_writes(
     serialized_metadata = self.read_metadata_store(blocking_write=True).read(
         file_path=self.get_metadata_file_path(metadata_class)
     )
-    self.assertEqual(
+    self.assertMetadataEqual(
         self.deserialize_metadata(metadata_class, serialized_metadata),
         self.get_metadata(metadata_class),
     )
@@ -436,14 +452,14 @@ def test_non_blocking_write_request_enables_writes(
     serialized_metadata = self.read_metadata_store(blocking_write=False).read(
         file_path=self.get_metadata_file_path(metadata_class)
     )
-    self.assertEqual(
+    self.assertMetadataEqual(
         self.deserialize_metadata(metadata_class, serialized_metadata),
         self.get_metadata(metadata_class, custom={'a': 2}),
     )
     serialized_metadata = self.write_metadata_store(blocking_write=False).read(
         file_path=self.get_metadata_file_path(metadata_class)
     )
-    self.assertEqual(
+    self.assertMetadataEqual(
         self.deserialize_metadata(metadata_class, serialized_metadata),
         self.get_metadata(metadata_class, custom={'a': 2}),
     )
@@ -457,14 +473,14 @@ def test_non_blocking_write_request_enables_writes(
     serialized_metadata = self.read_metadata_store(blocking_write=False).read(
         file_path=self.get_metadata_file_path(metadata_class)
     )
-    self.assertEqual(
+    self.assertMetadataEqual(
         self.deserialize_metadata(metadata_class, serialized_metadata),
         self.get_metadata(metadata_class, custom={'a': 3}),
     )
     serialized_metadata = self.write_metadata_store(blocking_write=False).read(
         file_path=self.get_metadata_file_path(metadata_class)
     )
-    self.assertEqual(
+    self.assertMetadataEqual(
         self.deserialize_metadata(metadata_class, serialized_metadata),
         self.get_metadata(metadata_class, custom={'a': 3}),
     )
@@ -528,37 +544,6 @@ def test_deserialize_wrong_types_step_metadata(
     with self.assertRaises(ValueError):
       self.deserialize_metadata(StepMetadata, wrong_metadata)
 
-  @parameterized.parameters(
-      (
-          RootMetadata(custom={'a': None}),
-          {'custom': {'a': None}}
-      ),
-      (
-          RootMetadata(format=_SAMPLE_FORMAT),
-          {'format': _SAMPLE_FORMAT}
-      ),
-      (
-          StepMetadata(format=_SAMPLE_FORMAT),
-          {'format': _SAMPLE_FORMAT},
-      ),
-      (
-          StepMetadata(item_handlers={'a': 'a_handler'}),
-          {'item_handlers': {'a': 'a_handler'}},
-      ),
-      (
-          StepMetadata(custom={'blah': 123}),
-          {'custom': {'blah': 123}},
-      ),
-  )
-  def test_only_serialize_non_default_metadata_values(
-      self,
-      metadata: StepMetadata | RootMetadata,
-      expected_serialized_metadata: dict[str, Any],
-  ):
-    self.assertEqual(
-        self.serialize_metadata(metadata), expected_serialized_metadata
-    )
-
   @parameterized.parameters(StepMetadata, RootMetadata)
   def test_unknown_key_in_metadata(
       self, metadata_class: type[StepMetadata] | type[RootMetadata],
diff --git a/checkpoint/orbax/checkpoint/_src/metadata/root_metadata_serialization.py b/checkpoint/orbax/checkpoint/_src/metadata/root_metadata_serialization.py
index 5117a4fbc..05f0e4c84 100644
--- a/checkpoint/orbax/checkpoint/_src/metadata/root_metadata_serialization.py
+++ b/checkpoint/orbax/checkpoint/_src/metadata/root_metadata_serialization.py
@@ -24,12 +24,10 @@
 
 def serialize(metadata: RootMetadata) -> SerializedMetadata:
   """Serializes `metadata` to a dictionary."""
-  serialized_metadata = {}
-  if metadata.format is not None:
-    serialized_metadata['format'] = metadata.format
-  if metadata.custom:
-    serialized_metadata['custom'] = metadata.custom
-  return serialized_metadata
+  return {
+      'format': metadata.format,
+      'custom': metadata.custom,
+  }
 
 
 def deserialize(metadata_dict: SerializedMetadata) -> RootMetadata:
diff --git a/checkpoint/orbax/checkpoint/_src/metadata/step_metadata_serialization.py b/checkpoint/orbax/checkpoint/_src/metadata/step_metadata_serialization.py
index c5c5e4dc4..78d987133 100644
--- a/checkpoint/orbax/checkpoint/_src/metadata/step_metadata_serialization.py
+++ b/checkpoint/orbax/checkpoint/_src/metadata/step_metadata_serialization.py
@@ -47,25 +47,15 @@ def serialize(metadata: StepMetadata) -> SerializedMetadata:
       if isinstance(val, float)
   }
 
-  serialized_metadata = {}
-  if metadata.format is not None:
-    serialized_metadata['format'] = metadata.format
-  if metadata.item_handlers:
-    serialized_metadata['item_handlers'] = metadata.item_handlers
-  if metadata.metrics:
-    serialized_metadata['metrics'] = metadata.metrics
-  if float_metrics:
-    serialized_metadata['performance_metrics'] = float_metrics
-  if metadata.init_timestamp_nsecs is not None:
-    serialized_metadata['init_timestamp_nsecs'] = metadata.init_timestamp_nsecs
-  if metadata.commit_timestamp_nsecs is not None:
-    serialized_metadata['commit_timestamp_nsecs'] = (
-        metadata.commit_timestamp_nsecs
-    )
-  if metadata.custom:
-    serialized_metadata['custom'] = metadata.custom
-
-  return serialized_metadata
+  return {
+      'format': metadata.format,
+      'item_handlers': metadata.item_handlers,
+      'metrics': metadata.metrics,
+      'performance_metrics': float_metrics,
+      'init_timestamp_nsecs': metadata.init_timestamp_nsecs,
+      'commit_timestamp_nsecs': metadata.commit_timestamp_nsecs,
+      'custom': metadata.custom,
+  }
 
 
 def deserialize(
@@ -87,6 +77,8 @@ def deserialize(
     validated_metadata_dict['item_handlers'] = item_handlers
   elif isinstance(item_handlers, CheckpointHandlerTypeStr):
     validated_metadata_dict['item_handlers'] = item_handlers
+  elif item_handlers is None:
+    validated_metadata_dict['item_handlers'] = None
 
   utils.validate_field(metadata_dict, 'item_metadata', [dict, str])
   dict_item_metadata = metadata_dict.get('item_metadata')