From 3a8acf7a81ec62b74b4b69f176fdfe7822a93079 Mon Sep 17 00:00:00 2001 From: Yaning Liang Date: Tue, 21 Jan 2025 10:04:12 -0800 Subject: [PATCH] Add checkpoint/BUILD update dependencies on checkpoint/BUILD to these targets. PiperOrigin-RevId: 717965180 --- checkpoint/orbax/checkpoint/BUILD | 355 ++++++++++++++++++ .../orbax/checkpoint/_src/checkpointers/BUILD | 10 + .../orbax/checkpoint/_src/handlers/BUILD | 48 ++- .../orbax/checkpoint/_src/metadata/BUILD | 1 + checkpoint/orbax/checkpoint/_src/path/BUILD | 17 +- .../orbax/checkpoint/_src/serialization/BUILD | 4 + checkpoint/orbax/checkpoint/_src/tree/BUILD | 1 + 7 files changed, 432 insertions(+), 4 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/BUILD diff --git a/checkpoint/orbax/checkpoint/BUILD b/checkpoint/orbax/checkpoint/BUILD new file mode 100644 index 000000000..9a7809fe9 --- /dev/null +++ b/checkpoint/orbax/checkpoint/BUILD @@ -0,0 +1,355 @@ +package( + default_applicable_licenses = [":package_license"], + default_visibility = ["//visibility:public"], +) + +license( + name = "package_license", + package_name = "orbax-checkpoint", +) + +py_library( + name = "checkpoint", + srcs = ["__init__.py"], + lib_rule = pytype_strict_library, + visibility = ["//visibility:public"], + deps = [ + ":abstract_checkpoint_manager", + ":aggregate_handlers", + ":args", + ":arrays", + ":checkpoint_manager", + ":checkpoint_utils", + ":checkpointers", + ":future", + ":handlers", + ":logging", + ":msgpack_utils", + ":options", + ":path", + ":test_utils", + ":transform_utils", + ":tree", + ":type_handlers", + ":utils", + ":version", + "//checkpoint/orbax/checkpoint/_src/handlers:async_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:composite_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:standard_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/path:step", + "//orbax/checkpoint/_src/handlers:array_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:base_pytree_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:handler_registration", + "//orbax/checkpoint/_src/handlers:json_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:proto_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:random_key_checkpoint_handler", + "//orbax/checkpoint/metadata", + "//orbax/checkpoint/serialization", + ], +) + +py_library( + name = "handlers", + srcs = ["handlers.py"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/handlers:async_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:composite_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:standard_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:array_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:handler_registration", + "//orbax/checkpoint/_src/handlers:handler_type_registry", + "//orbax/checkpoint/_src/handlers:json_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:proto_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:random_key_checkpoint_handler", + ], +) + +py_library( + name = "checkpoint_args", + srcs = ["checkpoint_args.py"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/handlers:checkpoint_handler", + "//orbax/checkpoint/_src/handlers:handler_type_registry", + ], +) + +py_test( + name = "checkpoint_args_test", + srcs = ["checkpoint_args_test.py"], + python_version = "PY3", + deps = [ + ":checkpoint_args", + "//checkpoint/orbax/checkpoint/_src/handlers:checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:standard_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:handler_type_registry", + ], +) + +py_library( + name = "args", + srcs = ["args.py"], + deps = [ + ":checkpoint_args", + "//checkpoint/orbax/checkpoint/_src/handlers:composite_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:standard_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:array_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:json_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:proto_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:random_key_checkpoint_handler", + ], +) + +py_library( + name = "abstract_checkpoint_manager", + srcs = ["abstract_checkpoint_manager.py"], + deps = [":args"], +) + +py_library( + name = "checkpoint_manager", + srcs = ["checkpoint_manager.py"], + srcs_version = "PY3", + tags = ["ignore_for_dep=//orbax/checkpoint/google:storage_configuration_alerter"], + deps = [ + ":abstract_checkpoint_manager", + ":args", + ":checkpoint_args", + ":logging", + ":options", + ":utils", + "//checkpoint/orbax/checkpoint/_src/checkpointers:checkpointer", + "//checkpoint/orbax/checkpoint/_src/handlers:checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:composite_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/path:atomicity_types", + "//checkpoint/orbax/checkpoint/_src/path:deleter", + "//checkpoint/orbax/checkpoint/_src/path:step", + "//checkpoint/orbax/checkpoint/_src/path:utils", + "//third_party/py/jax/experimental/array_serialization:serialization", + "//orbax/checkpoint/_src/checkpointers:abstract_checkpointer", + "//orbax/checkpoint/_src/checkpointers:async_checkpointer", + "//orbax/checkpoint/_src/handlers:handler_registration", + "//orbax/checkpoint/_src/handlers:json_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:proto_checkpoint_handler", + "//orbax/checkpoint/_src/metadata:root_metadata_serialization", + "//orbax/checkpoint/google:storage_configuration_alerter", + ], +) + +py_library( + name = "test_utils", + srcs = ["test_utils.py"], + srcs_version = "PY3", + deps = [ + ":checkpoint_args", + "//checkpoint/orbax/checkpoint/_src/handlers:async_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", + "//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/multihost:multislice", + "//checkpoint/orbax/checkpoint/_src/path:atomicity", + "//checkpoint/orbax/checkpoint/_src/path:step", + "//checkpoint/orbax/checkpoint/_src/serialization", + "//checkpoint/orbax/checkpoint/_src/serialization:replica_slices", + "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//checkpoint/orbax/checkpoint/_src/tree:utils", + ], +) + +py_test( + name = "test_utils_test", + srcs = ["test_utils_test.py"], + deps = [ + ":test_utils", + "//checkpoint/orbax/checkpoint/_src/multihost", + ], +) + +py_library( + name = "utils", + srcs = ["utils.py"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/path:async_utils", + "//checkpoint/orbax/checkpoint/_src/path:step", + "//checkpoint/orbax/checkpoint/_src/tree:utils", + ], +) + +py_library( + name = "transform_utils", + srcs = ["transform_utils.py"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//checkpoint/orbax/checkpoint/_src/tree:utils", + ], +) + +py_library( + name = "future", + srcs = ["future.py"], + deps = ["//orbax/checkpoint/_src/futures:future"], +) + +py_library( + name = "aggregate_handlers", + srcs = ["aggregate_handlers.py"], + deps = [ + ":future", + ":msgpack_utils", + ":utils", + "//checkpoint/orbax/checkpoint/_src/metadata:tree", + ], +) + +py_library( + name = "checkpoint_utils", + srcs = ["checkpoint_utils.py"], + deps = [ + ":utils", + "//checkpoint/orbax/checkpoint/_src/metadata:tree", + "//checkpoint/orbax/checkpoint/_src/metadata:value", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/path:step", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//orbax/checkpoint/_src/path/snapshot", + ], +) + +py_library( + name = "msgpack_utils", + srcs = ["msgpack_utils.py"], + deps = ["//third_party/py/msgpack"], +) + +py_test( + name = "msgpack_utils_test", + srcs = ["msgpack_utils_test.py"], + deps = [":msgpack_utils"], +) + +py_test( + name = "checkpoint_utils_test", + srcs = ["checkpoint_utils_test.py"], + python_version = "PY3", + deps = [ + ":args", + ":checkpoint_manager", + ":checkpoint_utils", + ":test_utils", + ":utils", + "//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/metadata:value", + "//checkpoint/orbax/checkpoint/_src/path:step", + "//orbax/checkpoint/_src/checkpointers:pytree_checkpointer", + ], +) + +py_test( + name = "transform_utils_test", + srcs = ["transform_utils_test.py"], + deps = [ + ":test_utils", + ":transform_utils", + "//checkpoint/orbax/checkpoint/_src/tree:utils", + ], +) + +py_test( + name = "single_host_test", + srcs = ["single_host_test.py"], + deps = [ + ":test_utils", + "//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//third_party/py/ml_dtypes", + "//orbax/checkpoint/_src/handlers:standard_checkpoint_handler_test_utils", + ], +) + +py_library( + name = "conftest", + srcs = ["conftest.py"], +) + +py_library( + name = "options", + srcs = ["options.py"], + deps = ["//checkpoint/orbax/checkpoint/_src/multihost"], +) + +py_library( + name = "version", + srcs = ["version.py"], +) + +py_library( + name = "logging", + srcs = ["logging.py"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/logging:step_statistics", + "//orbax/checkpoint/_src/logging:abstract_logger", + "//orbax/checkpoint/_src/logging:cloud_logger", # buildcleaner:keep + "//orbax/checkpoint/_src/logging:composite_logger", + "//orbax/checkpoint/_src/logging:standard_logger", + ], +) + +py_library( + name = "tree", + srcs = ["tree.py"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/tree:utils", + "//orbax/checkpoint/_src/tree:types", + ], +) + +py_library( + name = "path", + srcs = ["path.py"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/path:async_utils", + "//checkpoint/orbax/checkpoint/_src/path:atomicity", + "//checkpoint/orbax/checkpoint/_src/path:atomicity_defaults", + "//checkpoint/orbax/checkpoint/_src/path:atomicity_types", + "//checkpoint/orbax/checkpoint/_src/path:deleter", + "//checkpoint/orbax/checkpoint/_src/path:format_utils", + "//checkpoint/orbax/checkpoint/_src/path:step", + ], +) + +py_library( + name = "checkpointers", + srcs = ["checkpointers.py"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/checkpointers:checkpointer", + "//orbax/checkpoint/_src/checkpointers:abstract_checkpointer", + "//orbax/checkpoint/_src/checkpointers:async_checkpointer", + "//orbax/checkpoint/_src/checkpointers:pytree_checkpointer", + "//orbax/checkpoint/_src/checkpointers:standard_checkpointer", + ], +) + +py_library( + name = "type_handlers", + srcs = ["type_handlers.py"], + deps = ["//checkpoint/orbax/checkpoint/_src/serialization:type_handlers"], +) + +py_library( + name = "arrays", + srcs = ["arrays.py"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/arrays:abstract_arrays", + "//checkpoint/orbax/checkpoint/_src/arrays:types", + ], +) diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/BUILD b/checkpoint/orbax/checkpoint/_src/checkpointers/BUILD index 496d31b6f..8c9ca656a 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/BUILD +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/BUILD @@ -6,6 +6,7 @@ package( py_library( name = "abstract_checkpointer", srcs = ["abstract_checkpointer.py"], + deps = ["//checkpoint/orbax/checkpoint:version"], ) py_library( @@ -13,6 +14,8 @@ py_library( srcs = ["checkpointer.py"], deps = [ ":abstract_checkpointer", + "//checkpoint/orbax/checkpoint:checkpoint_args", + "//checkpoint/orbax/checkpoint:options", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/handlers:checkpoint_handler", "//checkpoint/orbax/checkpoint/_src/handlers:composite_checkpoint_handler", @@ -22,6 +25,7 @@ py_library( "//checkpoint/orbax/checkpoint/_src/path:atomicity", "//checkpoint/orbax/checkpoint/_src/path:atomicity_defaults", "//checkpoint/orbax/checkpoint/_src/path:atomicity_types", + "//orbax/checkpoint:utils", ], ) @@ -30,6 +34,7 @@ py_library( srcs = ["pytree_checkpointer.py"], deps = [ ":checkpointer", + "//checkpoint/orbax/checkpoint:options", "//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler", ], ) @@ -39,6 +44,7 @@ py_library( srcs = ["standard_checkpointer.py"], deps = [ ":async_checkpointer", + "//checkpoint/orbax/checkpoint:options", "//checkpoint/orbax/checkpoint/_src/handlers:standard_checkpoint_handler", "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", "//checkpoint/orbax/checkpoint/_src/path:atomicity_types", @@ -50,6 +56,9 @@ py_library( srcs = ["async_checkpointer.py"], deps = [ ":checkpointer", + "//checkpoint/orbax/checkpoint:checkpoint_args", + "//checkpoint/orbax/checkpoint:future", + "//checkpoint/orbax/checkpoint:options", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/handlers:async_checkpoint_handler", "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", @@ -57,5 +66,6 @@ py_library( "//checkpoint/orbax/checkpoint/_src/path:async_utils", "//checkpoint/orbax/checkpoint/_src/path:atomicity", "//checkpoint/orbax/checkpoint/_src/path:atomicity_types", + "//orbax/checkpoint:utils", ], ) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/BUILD b/checkpoint/orbax/checkpoint/_src/handlers/BUILD index 163dee32e..b9d31bd8d 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/BUILD +++ b/checkpoint/orbax/checkpoint/_src/handlers/BUILD @@ -16,6 +16,9 @@ py_library( ":checkpoint_handler", ":handler_registration", ":proto_checkpoint_handler", + "//checkpoint/orbax/checkpoint:checkpoint_args", + "//checkpoint/orbax/checkpoint:future", + "//checkpoint/orbax/checkpoint:options", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src:composite", "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", @@ -36,6 +39,9 @@ py_test( ":json_checkpoint_handler", ":proto_checkpoint_handler", ":standard_checkpoint_handler", + "//checkpoint/orbax/checkpoint:args", + "//checkpoint/orbax/checkpoint:logging", + "//checkpoint/orbax/checkpoint:test_utils", "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", "//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization", "//checkpoint/orbax/checkpoint/_src/metadata:value", @@ -51,6 +57,11 @@ py_library( deps = [ ":async_checkpoint_handler", ":base_pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint:aggregate_handlers", + "//checkpoint/orbax/checkpoint:checkpoint_args", + "//checkpoint/orbax/checkpoint:future", + "//checkpoint/orbax/checkpoint:options", + "//checkpoint/orbax/checkpoint:transform_utils", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", "//checkpoint/orbax/checkpoint/_src/metadata:tree", @@ -58,6 +69,7 @@ py_library( "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", "//checkpoint/orbax/checkpoint/_src/tree:utils", + "//orbax/checkpoint:utils", ], ) @@ -67,6 +79,9 @@ py_library( srcs_version = "PY3", deps = [ ":async_checkpoint_handler", + "//checkpoint/orbax/checkpoint:checkpoint_args", + "//checkpoint/orbax/checkpoint:future", + "//checkpoint/orbax/checkpoint:options", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", "//checkpoint/orbax/checkpoint/_src/metadata:tree", @@ -77,6 +92,7 @@ py_library( "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", "//checkpoint/orbax/checkpoint/_src/serialization:types", "//checkpoint/orbax/checkpoint/_src/tree:utils", + "//orbax/checkpoint:utils", "//orbax/checkpoint/_src/metadata:array_metadata_store", ], ) @@ -86,14 +102,21 @@ py_library( srcs = ["json_checkpoint_handler.py"], deps = [ ":async_checkpoint_handler", + "//checkpoint/orbax/checkpoint:checkpoint_args", + "//checkpoint/orbax/checkpoint:future", + "//checkpoint/orbax/checkpoint:options", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + "//orbax/checkpoint:utils", ], ) py_library( name = "async_checkpoint_handler", srcs = ["async_checkpoint_handler.py"], - deps = [":checkpoint_handler"], + deps = [ + ":checkpoint_handler", + "//checkpoint/orbax/checkpoint:future", + ], ) py_library( @@ -101,9 +124,13 @@ py_library( srcs = ["array_checkpoint_handler.py"], deps = [ ":async_checkpoint_handler", + "//checkpoint/orbax/checkpoint:aggregate_handlers", + "//checkpoint/orbax/checkpoint:checkpoint_args", + "//checkpoint/orbax/checkpoint:future", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//orbax/checkpoint:utils", ], ) @@ -122,7 +149,11 @@ py_library( srcs = ["proto_checkpoint_handler.py"], deps = [ ":async_checkpoint_handler", + "//checkpoint/orbax/checkpoint:checkpoint_args", + "//checkpoint/orbax/checkpoint:future", + "//checkpoint/orbax/checkpoint:options", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + "//orbax/checkpoint:utils", ], ) @@ -132,6 +163,10 @@ py_library( deps = [ ":async_checkpoint_handler", ":pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint:checkpoint_args", + "//checkpoint/orbax/checkpoint:checkpoint_utils", + "//checkpoint/orbax/checkpoint:future", + "//checkpoint/orbax/checkpoint:options", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options", "//checkpoint/orbax/checkpoint/_src/metadata:tree", @@ -153,8 +188,10 @@ py_library( srcs = ["standard_checkpoint_handler_test_utils.py"], deps = [ ":standard_checkpoint_handler", + "//checkpoint/orbax/checkpoint:test_utils", "//checkpoint/orbax/checkpoint/_src/multihost", "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//orbax/checkpoint:utils", ], ) @@ -167,6 +204,8 @@ py_library( ":composite_checkpoint_handler", ":json_checkpoint_handler", ":pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint:checkpoint_args", + "//checkpoint/orbax/checkpoint:future", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", ], @@ -179,13 +218,17 @@ py_test( deps = [ ":composite_checkpoint_handler", ":random_key_checkpoint_handler", + "//checkpoint/orbax/checkpoint:args", ], ) py_library( name = "handler_registration", srcs = ["handler_registration.py"], - deps = [":checkpoint_handler"], + deps = [ + ":checkpoint_handler", + "//checkpoint/orbax/checkpoint:checkpoint_args", + ], ) py_test( @@ -195,6 +238,7 @@ py_test( ":checkpoint_handler", ":handler_registration", ":standard_checkpoint_handler", + "//checkpoint/orbax/checkpoint:checkpoint_args", ], ) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/BUILD b/checkpoint/orbax/checkpoint/_src/metadata/BUILD index 4d6d49a0d..f2386850a 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/BUILD +++ b/checkpoint/orbax/checkpoint/_src/metadata/BUILD @@ -174,6 +174,7 @@ py_test( deps = [ ":array_metadata", ":array_metadata_store", + "//checkpoint/orbax/checkpoint:test_utils", "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", ], ) diff --git a/checkpoint/orbax/checkpoint/_src/path/BUILD b/checkpoint/orbax/checkpoint/_src/path/BUILD index 2e54703da..17b22a5b3 100644 --- a/checkpoint/orbax/checkpoint/_src/path/BUILD +++ b/checkpoint/orbax/checkpoint/_src/path/BUILD @@ -7,6 +7,7 @@ py_library( name = "utils", srcs = ["utils.py"], srcs_version = "PY3", + deps = ["//checkpoint/orbax/checkpoint:options"], ) py_library( @@ -30,6 +31,7 @@ py_test( deps = [ ":atomicity", ":step", + "//checkpoint/orbax/checkpoint:test_utils", "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", "//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization", ], @@ -38,7 +40,10 @@ py_test( py_library( name = "deleter", srcs = ["deleter.py"], - deps = [":step"], + deps = [ + ":step", + "//orbax/checkpoint:utils", + ], ) py_test( @@ -68,6 +73,7 @@ py_library( ":path", ":step", ":utils", + "//checkpoint/orbax/checkpoint:options", "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", "//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization", "//checkpoint/orbax/checkpoint/_src/multihost", @@ -81,6 +87,8 @@ py_test( ":atomicity", ":atomicity_types", ":step", + "//checkpoint/orbax/checkpoint:options", + "//checkpoint/orbax/checkpoint:test_utils", "//checkpoint/orbax/checkpoint/_src/multihost", ], ) @@ -88,7 +96,10 @@ py_test( py_library( name = "atomicity_types", srcs = ["atomicity_types.py"], - deps = ["//checkpoint/orbax/checkpoint/_src/metadata:checkpoint"], + deps = [ + "//checkpoint/orbax/checkpoint:options", + "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", + ], ) py_library( @@ -112,6 +123,8 @@ py_test( srcs = ["format_utils_test.py"], deps = [ ":format_utils", + "//checkpoint/orbax/checkpoint:args", + "//checkpoint/orbax/checkpoint:checkpoint_manager", "//checkpoint/orbax/checkpoint/_src/checkpointers:checkpointer", "//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler", "//checkpoint/orbax/checkpoint/_src/handlers:standard_checkpoint_handler", diff --git a/checkpoint/orbax/checkpoint/_src/serialization/BUILD b/checkpoint/orbax/checkpoint/_src/serialization/BUILD index f9e0b2732..48d20466b 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/BUILD +++ b/checkpoint/orbax/checkpoint/_src/serialization/BUILD @@ -19,6 +19,7 @@ py_library( srcs = ["types.py"], deps = [ ":serialization", + "//checkpoint/orbax/checkpoint:future", "//checkpoint/orbax/checkpoint/_src/arrays:types", "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", "//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options", @@ -34,6 +35,7 @@ py_library( ":serialization", ":tensorstore_utils", ":types", + "//checkpoint/orbax/checkpoint:future", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/arrays:subchunking", "//checkpoint/orbax/checkpoint/_src/arrays:types", @@ -91,6 +93,8 @@ py_test( deps = [ ":serialization", ":tensorstore_utils", + "//checkpoint/orbax/checkpoint:future", + "//checkpoint/orbax/checkpoint:test_utils", "//checkpoint/orbax/checkpoint/_src:asyncio_utils", ], ) diff --git a/checkpoint/orbax/checkpoint/_src/tree/BUILD b/checkpoint/orbax/checkpoint/_src/tree/BUILD index e2ddcc372..dd5d3b34c 100644 --- a/checkpoint/orbax/checkpoint/_src/tree/BUILD +++ b/checkpoint/orbax/checkpoint/_src/tree/BUILD @@ -23,6 +23,7 @@ py_test( srcs = ["utils_test.py"], deps = [ ":utils", + "//checkpoint/orbax/checkpoint:test_utils", "//checkpoint/orbax/checkpoint/_src/testing:test_tree_utils", ], )