From 9dedd1949c6c7ff3e2611bea4efbcffc96689c9c Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Tue, 21 Jan 2025 13:35:26 -0800 Subject: [PATCH] Support custom PyTree metadata. Standardize naming of the "custom metadata" field (user-supplied metadata) as `custom_metadata`. PiperOrigin-RevId: 718050751 --- checkpoint/CHANGELOG.md | 9 ++- checkpoint/orbax/checkpoint/BUILD | 2 +- .../_src/checkpointers/checkpointer.py | 2 +- .../orbax/checkpoint/_src/handlers/BUILD | 3 + .../base_pytree_checkpoint_handler.py | 22 +++++- .../composite_checkpoint_handler_test.py | 12 ++-- .../handlers/pytree_checkpoint_handler.py | 6 ++ .../handlers/standard_checkpoint_handler.py | 11 ++- .../standard_checkpoint_handler_test_utils.py | 22 +++--- .../orbax/checkpoint/_src/metadata/BUILD | 1 + .../checkpoint/_src/metadata/checkpoint.py | 14 ++-- .../_src/metadata/checkpoint_test.py | 52 ++++++++------ .../metadata/root_metadata_serialization.py | 6 +- .../metadata/step_metadata_serialization.py | 24 ++++--- .../orbax/checkpoint/_src/metadata/tree.py | 71 ++++++++++++++----- .../checkpoint/_src/metadata/tree_test.py | 46 ++++++++++-- .../orbax/checkpoint/_src/tree/types.py | 5 ++ .../orbax/checkpoint/checkpoint_manager.py | 6 +- .../orbax/checkpoint/checkpoint_utils.py | 2 +- 19 files changed, 230 insertions(+), 86 deletions(-) diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 3bc43280..a494e0c5 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -13,12 +13,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 make the change unnoticeable to most users, but also has additional accessible properties not included in any tree mapping operations. - `Checkpointer.save()`, `AsyncCheckpointer.save()` also saves `StepMetadata`. -- Added github actions CI testing using Python versions 3.10-3.13 +- Added github actions CI testing using Python versions 3.10-3.13. +- Standardize naming of the "custom metadata" field (user-supplied metadata) as +`custom_metadata`. ### Added - The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`. -- `CommitFuture` and `HandlerAwaitableSignal` for signalling between Checkpointing layers to enable async -directory creation. +- `CommitFuture` and `HandlerAwaitableSignal` for signalling between +Checkpointing layers to enable async directory creation. +- User-provided custom PyTree metadata. ### Fixed - Fix a bug where snapshots are not released by `wait_for_new_checkpoint` diff --git a/checkpoint/orbax/checkpoint/BUILD b/checkpoint/orbax/checkpoint/BUILD index 9a7809fe..a8370591 100644 --- a/checkpoint/orbax/checkpoint/BUILD +++ b/checkpoint/orbax/checkpoint/BUILD @@ -308,8 +308,8 @@ py_library( name = "tree", srcs = ["tree.py"], deps = [ + "//checkpoint/orbax/checkpoint/_src/tree:types", "//checkpoint/orbax/checkpoint/_src/tree:utils", - "//orbax/checkpoint/_src/tree:types", ], ) diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py index f733bddd..c271c11c 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py @@ -275,7 +275,7 @@ def _save_step_metadata( ): """Saves StepMetadata to the checkpoint directory.""" update_dict = { - 'custom': custom_metadata, + 'custom_metadata': custom_metadata, } if isinstance( self._handler, composite_checkpoint_handler.CompositeCheckpointHandler diff --git a/checkpoint/orbax/checkpoint/_src/handlers/BUILD b/checkpoint/orbax/checkpoint/_src/handlers/BUILD index b9d31bd8..41eea7f0 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/BUILD +++ b/checkpoint/orbax/checkpoint/_src/handlers/BUILD @@ -68,6 +68,7 @@ py_library( "//checkpoint/orbax/checkpoint/_src/serialization", "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//checkpoint/orbax/checkpoint/_src/tree:types", "//checkpoint/orbax/checkpoint/_src/tree:utils", "//orbax/checkpoint:utils", ], @@ -91,6 +92,7 @@ py_library( "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", "//checkpoint/orbax/checkpoint/_src/serialization:types", + "//checkpoint/orbax/checkpoint/_src/tree:types", "//checkpoint/orbax/checkpoint/_src/tree:utils", "//orbax/checkpoint:utils", "//orbax/checkpoint/_src/metadata:array_metadata_store", @@ -170,6 +172,7 @@ py_library( "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options", "//checkpoint/orbax/checkpoint/_src/metadata:tree", + "//checkpoint/orbax/checkpoint/_src/tree:types", "//checkpoint/orbax/checkpoint/_src/tree:utils", ], ) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index f4f1f1e7..8b9cd955 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -50,6 +50,7 @@ from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import type_handlers from orbax.checkpoint._src.serialization import types +from orbax.checkpoint._src.tree import types as tree_types from orbax.checkpoint._src.tree import utils as tree_utils import tensorstore as ts @@ -444,6 +445,7 @@ async def async_save( raise ValueError('Found empty item.') save_args = args.save_args ocdbt_target_data_file_size = args.ocdbt_target_data_file_size + custom_metadata = args.custom_metadata save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save') byte_limiter = serialization.get_byte_limiter(self._save_concurrent_bytes) @@ -491,6 +493,7 @@ async def async_save( checkpoint_dir=directory, param_infos=param_infos, save_args=save_args, + custom_metadata=custom_metadata, use_zarr3=self._use_zarr3, ) ) @@ -799,8 +802,10 @@ def update_param_info(param_info: types.ParamInfo) -> types.ParamInfo: def _write_metadata_file( self, directory: epath.Path, + *, param_infos: PyTree, save_args: PyTree, + custom_metadata: tree_types.JsonType | None = None, use_zarr3: bool = False, ) -> future.Future: def _save_fn(param_infos): @@ -811,6 +816,7 @@ def _save_fn(param_infos): param_infos, save_args=save_args, use_zarr3=use_zarr3, + custom_metadata=custom_metadata, pytree_metadata_options=self._pytree_metadata_options, ) logging.vlog( @@ -832,8 +838,10 @@ def _write_metadata_after_commits( self, commit_futures: List[future.Future], checkpoint_dir: epath.Path, + *, param_infos: PyTree, save_args: PyTree, + custom_metadata: tree_types.JsonType | None = None, use_zarr3: bool, ) -> None: if not utils.is_primary_host(self._primary_host): @@ -853,7 +861,11 @@ def _write_metadata_after_commits( param_infos, checkpoint_dir, self._array_metadata_store ) self._write_metadata_file( - checkpoint_dir, param_infos, save_args, use_zarr3 + checkpoint_dir, + param_infos=param_infos, + save_args=save_args, + custom_metadata=custom_metadata, + use_zarr3=use_zarr3, ).result() def _read_metadata_file( @@ -915,12 +927,14 @@ def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata: tree containing metadata. """ is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory) + internal_tree_metadata = self._read_metadata_file(directory) return tree_metadata.build_default_tree_metadata( - self._read_metadata_file(directory).as_user_metadata( + internal_tree_metadata.as_custom_metadata( directory, self._type_handler_registry, use_ocdbt=is_ocdbt_checkpoint, ), + custom_metadata=internal_tree_metadata.custom_metadata, ) def finalize(self, directory: epath.Path) -> None: @@ -972,12 +986,16 @@ class BasePyTreeSaveArgs(CheckpointArgs): enable_pinned_host_transfer: True by default. If False, disables transfer to pinned host when copying from device to host, regardless of the presence of pinned host memory. + custom_metadata: User-provided custom metadata. An arbitrary + JSON-serializable dictionary the user can use to store additional + information. The field is treated as opaque by Orbax. """ item: PyTree save_args: Optional[PyTree] = None ocdbt_target_data_file_size: Optional[int] = None enable_pinned_host_transfer: bool = True + custom_metadata: tree_types.JsonType | None = None @register_with_handler(BasePyTreeCheckpointHandler, for_restore=True) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py index a2c5c414..9032ba1e 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py @@ -728,7 +728,7 @@ def test_metadata_no_save(self, use_handler_registry): ) self.assertIsNone(step_metadata.init_timestamp_nsecs) self.assertIsNone(step_metadata.commit_timestamp_nsecs) - self.assertEmpty(step_metadata.custom) + self.assertEmpty(step_metadata.custom_metadata) def test_metadata_handler_registry(self): registry = handler_registration.DefaultCheckpointHandlerRegistry() @@ -779,7 +779,7 @@ def test_metadata_handler_registry(self): ) self.assertIsNone(step_metadata.init_timestamp_nsecs) self.assertIsNone(step_metadata.commit_timestamp_nsecs) - self.assertEmpty(step_metadata.custom) + self.assertEmpty(step_metadata.custom_metadata) def test_metadata_after_step_metadata_write(self): handler = CompositeCheckpointHandler( @@ -795,7 +795,7 @@ def test_metadata_after_step_metadata_write(self): ) self.assertIsNone(step_metadata.init_timestamp_nsecs) self.assertIsNone(step_metadata.commit_timestamp_nsecs) - self.assertEmpty(step_metadata.custom) + self.assertEmpty(step_metadata.custom_metadata) metadata_to_write = checkpoint.StepMetadata( item_handlers={ @@ -813,7 +813,7 @@ def test_metadata_after_step_metadata_write(self): ), init_timestamp_nsecs=1000, commit_timestamp_nsecs=2000, - custom={ + custom_metadata={ 'custom_key': 'custom_value', }, ) @@ -837,7 +837,9 @@ def test_metadata_after_step_metadata_write(self): ) self.assertEqual(step_metadata.init_timestamp_nsecs, 1000) self.assertEqual(step_metadata.commit_timestamp_nsecs, 2000) - self.assertEqual(step_metadata.custom, {'custom_key': 'custom_value'}) + self.assertEqual( + step_metadata.custom_metadata, {'custom_key': 'custom_value'} + ) def test_metadata_existing_items_updates_step_metadata(self): handler = CompositeCheckpointHandler( diff --git a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py index c525ed89..fcc0d35a 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py @@ -46,6 +46,7 @@ from orbax.checkpoint._src.serialization import serialization from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import type_handlers +from orbax.checkpoint._src.tree import types as tree_types from orbax.checkpoint._src.tree import utils as tree_utils import tensorstore as ts @@ -431,6 +432,7 @@ def _get_impl_save_args( save_args=args.save_args, ocdbt_target_data_file_size=args.ocdbt_target_data_file_size, enable_pinned_host_transfer=args.enable_pinned_host_transfer, + custom_metadata=args.custom_metadata, ) @@ -1052,12 +1054,16 @@ class PyTreeSaveArgs(CheckpointArgs): enable_pinned_host_transfer: True by default. If False, disables transfer to pinned host when copying from device to host, regardless of the presence of pinned host memory. + custom_metadata: User-provided custom metadata. An arbitrary + JSON-serializable dictionary the user can use to store additional + information. The field is treated as opaque by Orbax. """ item: PyTree save_args: Optional[PyTree] = None ocdbt_target_data_file_size: Optional[int] = None enable_pinned_host_transfer: bool = True + custom_metadata: tree_types.JsonType | None = None def __post_init__(self): if isinstance(self.item, tree_metadata.TreeMetadata): diff --git a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py index 17fc7d4e..35e26151 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py @@ -32,6 +32,7 @@ from orbax.checkpoint._src.handlers import pytree_checkpoint_handler from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib from orbax.checkpoint._src.metadata import tree as tree_metadata +from orbax.checkpoint._src.tree import types as tree_types from orbax.checkpoint._src.tree import utils as tree_utils @@ -144,15 +145,19 @@ async def async_save( 'Make sure to specify kwarg name `args=` when providing' ' `StandardSaveArgs`.' ) + custom_metadata = None if args is not None: item = args.item save_args = args.save_args + custom_metadata = args.custom_metadata self._validate_save_state(item, save_args=save_args) return await self._impl.async_save( directory, args=pytree_checkpoint_handler.PyTreeSaveArgs( - item=item, save_args=save_args + item=item, + save_args=save_args, + custom_metadata=custom_metadata, ), ) @@ -266,10 +271,14 @@ class StandardSaveArgs(CheckpointArgs): save_args: a PyTree with the same structure of `item`, which consists of `ocp.SaveArgs` objects as values. `None` can be used for values where no `SaveArgs` are specified. + custom_metadata: User-provided custom metadata. An arbitrary + JSON-serializable dictionary the user can use to store additional + information. The field is treated as opaque by Orbax. """ item: PyTree save_args: Optional[PyTree] = None + custom_metadata: tree_types.JsonType | None = None def __post_init__(self): if isinstance(self.item, tree_metadata.TreeMetadata): diff --git a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py index 86a92d12..d6f39c72 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py @@ -14,6 +14,8 @@ """Test for standard_checkpoint_handler.py.""" +# pylint: disable=protected-access, missing-function-docstring + import functools from typing import Any @@ -127,7 +129,6 @@ def test_basic_no_item_arg(self): test_utils.assert_tree_equal(self, self.pytree, restored) def test_shape_dtype_struct(self): - """Test case.""" self.handler.save( self.directory, args=self.save_args_cls(self.mixed_pytree) ) @@ -162,7 +163,7 @@ def test_custom_layout(self): custom_layout = Layout( device_local_layout=DLL( major_to_minor=arr.layout.device_local_layout.major_to_minor[::-1], # pytype: disable=attribute-error - _tiling=arr.layout.device_local_layout._tiling, # pylint: disable=protected-access # pytype: disable=attribute-error + _tiling=arr.layout.device_local_layout._tiling, # pytype: disable=attribute-error ), sharding=arr.sharding, ) @@ -210,7 +211,6 @@ def test_custom_layout(self): @parameterized.parameters((True,), (False,)) def test_change_shape(self, strict: bool): - """Test case.""" if not hasattr(self.restore_args_cls, 'strict'): self.skipTest('strict option not supported for this handler') mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',)) @@ -255,7 +255,6 @@ def test_restore_unsupported_type(self): self.handler.restore(self.directory, args=self.restore_args_cls(pytree)) def test_cast(self): - """Test case.""" # TODO(dicentra): casting from int dtypes currently doesn't work # in the model surgery context. save_args = jax.tree.map( @@ -289,7 +288,6 @@ def check_dtype(x, dtype): jax.tree.map(lambda x: check_dtype(x, jnp.bfloat16), restored) def test_flax_model(self): - """Test case.""" @flax.struct.dataclass class Params(flax.struct.PyTreeNode): @@ -318,12 +316,10 @@ def make_params(): test_utils.assert_tree_equal(self, params, restored) def test_empty_error(self): - """Test case.""" with self.assertRaises(ValueError): self.handler.save(self.directory, args=self.save_args_cls({})) def test_empty_dict_node(self): - """Test case.""" item = {'a': {}, 'b': 3} self.handler.save(self.directory, args=self.save_args_cls(item)) restored = self.handler.restore( @@ -332,7 +328,6 @@ def test_empty_dict_node(self): self.assertDictEqual(restored, item) def test_empty_none_node(self): - """Test case.""" item = {'c': None, 'd': 2} self.handler.save(self.directory, args=self.save_args_cls(item)) restored = self.handler.restore( @@ -341,7 +336,6 @@ def test_empty_none_node(self): self.assertDictEqual(restored, item) def test_none_node_in_restore_args(self): - """Test case.""" devices = np.asarray(jax.devices()) mesh = jax.sharding.Mesh(devices, ('x',)) mesh_axes = jax.sharding.PartitionSpec( @@ -358,7 +352,6 @@ def test_none_node_in_restore_args(self): test_utils.assert_tree_equal(self, restored, {'b': None}) def test_masked_shape_dtype_struct(self): - """Test case.""" def _should_mask(keypath): return keypath[0].key == 'a' or ( @@ -398,3 +391,12 @@ def _none(keypath, x): # Restore it without any item. restored = self.handler.restore(self.directory) test_utils.assert_tree_equal(self, expected, restored) + + def test_custom_metadata(self): + custom_metadata = {'foo': 1} + self.handler.save( + self.directory, + args=self.save_args_cls(self.pytree, custom_metadata=custom_metadata), + ) + metadata = self.handler.metadata(self.directory) + self.assertEqual(metadata.custom_metadata, custom_metadata) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/BUILD b/checkpoint/orbax/checkpoint/_src/metadata/BUILD index f2386850..be709c4e 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/BUILD +++ b/checkpoint/orbax/checkpoint/_src/metadata/BUILD @@ -34,6 +34,7 @@ py_library( "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", "//checkpoint/orbax/checkpoint/_src/serialization:types", + "//checkpoint/orbax/checkpoint/_src/tree:types", "//checkpoint/orbax/checkpoint/_src/tree:utils", ], ) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py b/checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py index 4d4506df..379e23d5 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py @@ -78,7 +78,9 @@ class StepMetadata: Specified as nano seconds since epoch. default=None. commit_timestamp_nsecs: commit timestamp of a checkpoint, specified as nano seconds since epoch. default=None. - custom: User-provided custom metadata. + custom_metadata: User-provided custom metadata. An arbitrary + JSON-serializable dictionary the user can use to store additional + information. The field is treated as opaque by Orbax. """ item_handlers: ( @@ -91,7 +93,7 @@ class StepMetadata: ) init_timestamp_nsecs: int | None = None commit_timestamp_nsecs: int | None = None - custom: dict[str, Any] = dataclasses.field(default_factory=dict) + custom_metadata: dict[str, Any] = dataclasses.field(default_factory=dict) @dataclasses.dataclass @@ -99,10 +101,14 @@ class RootMetadata: """Metadata of a checkpoint at root level (contains all steps). Attributes: - custom: User-provided custom metadata. + custom_metadata: User-provided custom metadata. An arbitrary + JSON-serializable dictionary the user can use to store additional + information. The field is treated as opaque by Orbax. """ - custom: dict[str, Any] | None = dataclasses.field(default_factory=dict) + custom_metadata: dict[str, Any] | None = dataclasses.field( + default_factory=dict + ) class MetadataStore(Protocol): diff --git a/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py b/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py index a9689ebc..5b85a2ee 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py @@ -93,11 +93,11 @@ def serialize_metadata( def get_metadata( self, metadata_type: type[StepMetadata] | type[RootMetadata], - custom: dict[str, Any] = None, + custom_metadata: dict[str, Any] = None, ) -> StepMetadata | RootMetadata: """Returns a sample metadata object of `metadata_class`.""" - if custom is None: - custom = {'a': 1} + if custom_metadata is None: + custom_metadata = {'a': 1} if metadata_type == StepMetadata: return StepMetadata( item_handlers={'a': 'b'}, @@ -111,11 +111,11 @@ def get_metadata( ), init_timestamp_nsecs=1, commit_timestamp_nsecs=1, - custom=custom, + custom_metadata=custom_metadata, ) elif metadata_type == RootMetadata: return RootMetadata( - custom=custom, + custom_metadata=custom_metadata, ) def get_metadata_filename( @@ -146,9 +146,9 @@ def assertMetadataEqual( self.assertEqual(a.performance_metrics, b.performance_metrics) self.assertEqual(a.init_timestamp_nsecs, b.init_timestamp_nsecs) self.assertEqual(a.commit_timestamp_nsecs, b.commit_timestamp_nsecs) - self.assertEqual(a.custom, b.custom) + self.assertEqual(a.custom_metadata, b.custom_metadata) elif isinstance(a, RootMetadata): - self.assertEqual(a.custom, b.custom) + self.assertEqual(a.custom_metadata, b.custom_metadata) @parameterized.parameters(True, False) def test_read_unknown_path(self, blocking_write: bool): @@ -308,7 +308,7 @@ def test_update_without_prior_data( self.assertMetadataEqual( self.deserialize_metadata(metadata_class, serialized_metadata), metadata_class( - custom={'a': 1}, + custom_metadata={'a': 1}, ), ) @@ -341,7 +341,7 @@ def test_update_with_prior_data( self.deserialize_metadata(StepMetadata, serialized_metadata), StepMetadata( init_timestamp_nsecs=1, - custom={'a': 1}, + custom_metadata={'a': 1}, ), ) @@ -374,7 +374,7 @@ def test_update_with_unknown_kwargs( self.assertMetadataEqual( self.deserialize_metadata(metadata_class, serialized_metadata), metadata_class( - custom={'blah': 2}, + custom_metadata={'blah': 2}, ), ) @@ -428,7 +428,7 @@ def test_non_blocking_write_request_enables_writes( # write validations serialized_metadata = self.serialize_metadata( - self.get_metadata(metadata_class, custom={'a': 2}) + self.get_metadata(metadata_class, custom_metadata={'a': 2}) ) self.write_metadata_store(blocking_write=False).write( file_path=self.get_metadata_file_path(metadata_class), @@ -442,14 +442,14 @@ def test_non_blocking_write_request_enables_writes( ) self.assertMetadataEqual( self.deserialize_metadata(metadata_class, serialized_metadata), - self.get_metadata(metadata_class, custom={'a': 2}), + self.get_metadata(metadata_class, custom_metadata={'a': 2}), ) serialized_metadata = self.write_metadata_store(blocking_write=False).read( file_path=self.get_metadata_file_path(metadata_class) ) self.assertMetadataEqual( self.deserialize_metadata(metadata_class, serialized_metadata), - self.get_metadata(metadata_class, custom={'a': 2}), + self.get_metadata(metadata_class, custom_metadata={'a': 2}), ) # update validations @@ -463,14 +463,14 @@ def test_non_blocking_write_request_enables_writes( ) self.assertMetadataEqual( self.deserialize_metadata(metadata_class, serialized_metadata), - self.get_metadata(metadata_class, custom={'a': 3}), + self.get_metadata(metadata_class, custom_metadata={'a': 3}), ) serialized_metadata = self.write_metadata_store(blocking_write=False).read( file_path=self.get_metadata_file_path(metadata_class) ) self.assertMetadataEqual( self.deserialize_metadata(metadata_class, serialized_metadata), - self.get_metadata(metadata_class, custom={'a': 3}), + self.get_metadata(metadata_class, custom_metadata={'a': 3}), ) @parameterized.parameters(StepMetadata, RootMetadata) @@ -533,7 +533,7 @@ def test_unknown_key_in_metadata( self, metadata_class: type[StepMetadata] | type[RootMetadata], ): metadata = metadata_class( - custom={'a': 1}, + custom_metadata={'a': 1}, ) serialized_metadata = self.serialize_metadata(metadata) serialized_metadata['blah'] = 2 @@ -635,14 +635,17 @@ def test_validate_dict_entry_wrong_value_type( @parameterized.parameters( ({'item_handlers': {'a': 'b'}},), ({'performance_metrics': {'a': 1.0}},), - ({'custom': {'a': 1}, 'init_timestamp_nsecs': 1},), + ({'custom_metadata': {'a': 1}, 'init_timestamp_nsecs': 1},), ) def test_serialize_for_update_valid_kwargs( self, kwargs: dict[str, Any] ): + expected_kwargs = kwargs.copy() + if 'custom_metadata' in expected_kwargs: + expected_kwargs['custom'] = expected_kwargs.pop('custom_metadata') self.assertEqual( step_metadata_serialization.serialize_for_update(**kwargs), - kwargs, + expected_kwargs, ) @parameterized.parameters( @@ -653,8 +656,8 @@ def test_serialize_for_update_valid_kwargs( ({'performance_metrics': list()},), ({'init_timestamp_nsecs': float()},), ({'commit_timestamp_nsecs': float()},), - ({'custom': list()},), - ({'custom': {int(): None}},), + ({'custom_metadata': list()},), + ({'custom_metadata': {int(): None}},), ) def test_serialize_for_update_wrong_types( self, kwargs: dict[str, Any] @@ -667,10 +670,17 @@ def test_serialize_for_update_with_unknown_kwargs(self): ValueError, 'Provided metadata contains unknown key blah' ): step_metadata_serialization.serialize_for_update( - custom={'a': 1}, + custom_metadata={'a': 1}, blah=123, ) + with self.assertRaisesRegex( + ValueError, 'Provided metadata contains unknown key custom' + ): + step_metadata_serialization.serialize_for_update( + custom={'a': 1}, + ) + def test_serialize_for_update_performance_metrics_only_float(self): self.assertEqual( step_metadata_serialization.serialize_for_update( diff --git a/checkpoint/orbax/checkpoint/_src/metadata/root_metadata_serialization.py b/checkpoint/orbax/checkpoint/_src/metadata/root_metadata_serialization.py index 2fa755b5..6252c2da 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/root_metadata_serialization.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/root_metadata_serialization.py @@ -25,7 +25,7 @@ def serialize(metadata: RootMetadata) -> SerializedMetadata: """Serializes `metadata` to a dictionary.""" return { - 'custom': metadata.custom, + 'custom': metadata.custom_metadata, } @@ -51,4 +51,8 @@ def deserialize(metadata_dict: SerializedMetadata) -> RootMetadata: ) validated_metadata_dict['custom'][k] = metadata_dict[k] + # Rename to `custom_metadata`. + validated_metadata_dict['custom_metadata'] = validated_metadata_dict.pop( + 'custom' + ) return RootMetadata(**validated_metadata_dict) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/step_metadata_serialization.py b/checkpoint/orbax/checkpoint/_src/metadata/step_metadata_serialization.py index 1fbc62c3..7489097e 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/step_metadata_serialization.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/step_metadata_serialization.py @@ -53,7 +53,8 @@ def serialize(metadata: StepMetadata) -> SerializedMetadata: 'performance_metrics': float_metrics, 'init_timestamp_nsecs': metadata.init_timestamp_nsecs, 'commit_timestamp_nsecs': metadata.commit_timestamp_nsecs, - 'custom': metadata.custom, + # Change name from `custom_metadata` to `custom` for serialization. + 'custom': metadata.custom_metadata, } @@ -111,19 +112,23 @@ def serialize_for_update(**kwargs) -> SerializedMetadata: kwargs.get('commit_timestamp_nsecs', None) ) - if 'custom' in kwargs: - if kwargs['custom'] is None: - validated_kwargs['custom'] = {} + if 'custom_metadata' in kwargs: + if kwargs['custom_metadata'] is None: + validated_kwargs['custom_metadata'] = {} else: - utils.validate_type(kwargs['custom'], dict) - for k in kwargs.get('custom', {}) or {}: + utils.validate_type(kwargs['custom_metadata'], dict) + for k in kwargs.get('custom_metadata', {}) or {}: utils.validate_type(k, str) - validated_kwargs['custom'] = kwargs.get('custom', {}) + validated_kwargs['custom_metadata'] = kwargs.get('custom_metadata', {}) for k in kwargs: if k not in validated_kwargs: raise ValueError('Provided metadata contains unknown key %s.' % k) + # Change name from `custom_metadata` to `custom` for serialization. + if 'custom_metadata' in validated_kwargs: + validated_kwargs['custom'] = validated_kwargs.pop('custom_metadata') + return validated_kwargs @@ -193,7 +198,6 @@ def deserialize( for k in custom: utils.validate_type(k, str) validated_metadata_dict['custom'] = custom or {} - for k in metadata_dict: if k not in validated_metadata_dict: if 'custom' in metadata_dict and metadata_dict['custom']: @@ -205,5 +209,9 @@ def deserialize( 'Provided metadata contains unknown key %s. Adding it to custom.', k ) validated_metadata_dict['custom'][k] = metadata_dict[k] + # Change name from `custom` to `custom_metadata` for deserialization. + validated_metadata_dict['custom_metadata'] = validated_metadata_dict.pop( + 'custom' + ) return StepMetadata(**validated_metadata_dict) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree.py b/checkpoint/orbax/checkpoint/_src/metadata/tree.py index ea11f617..e3a906d8 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree.py @@ -22,6 +22,7 @@ import enum import functools import inspect +import json import operator import typing from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple, TypeAlias, TypeVar, Union @@ -37,6 +38,7 @@ from orbax.checkpoint._src.metadata import value_metadata_entry from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import types +from orbax.checkpoint._src.tree import types as tree_types from orbax.checkpoint._src.tree import utils as tree_utils @@ -58,6 +60,7 @@ _USE_ZARR3 = 'use_zarr3' _STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE = 'store_array_data_equal_to_fill_value' _VALUE_METADATA_TREE = 'value_metadata_tree' +_CUSTOM_METADATA = 'custom_metadata' class KeyType(enum.Enum): @@ -200,12 +203,21 @@ def jax_keypath(self) -> KeyPath: return tuple(keypath) -@dataclasses.dataclass +@dataclasses.dataclass(kw_only=True) class InternalTreeMetadata: - """Metadata representation of a PyTree.""" + """Metadata representation of a PyTree. + + Corresponds to the metadata of a PyTree checkpoint (e.g. that saved by + `PyTreeCheckpointHandler`, `StandardCheckpointHandler`, + `StandardCheckpointer`, etc.). + + This class is the internal / on-disk representation of metadata that is + presented to the user as a `TreeMetadata` object. + """ tree_metadata_entries: List[InternalTreeMetadataEntry] use_zarr3: bool + custom_metadata: tree_types.JsonType | None store_array_data_equal_to_fill_value: bool pytree_metadata_options: PyTreeMetadataOptions value_metadata_tree: PyTree | None = None @@ -219,6 +231,14 @@ def __post_init__(self): len(self.tree_metadata_entries), self.value_metadata_tree is not None, ) + # Validate JSON-serializability of custom_metadata. + try: + json.dumps(self.custom_metadata) + except TypeError as e: + raise TypeError( + 'Failed to encode `custom_metadata` metadata as JSON object. Please' + ' ensure your `custom_metadata` is JSON-serializable.' + ) from e @classmethod def build( @@ -227,6 +247,7 @@ def build( *, save_args: Optional[PyTree] = None, use_zarr3: bool = False, + custom_metadata: tree_types.JsonType | None = None, pytree_metadata_options: PyTreeMetadataOptions = ( PYTREE_METADATA_OPTIONS ), @@ -267,6 +288,7 @@ def build( return InternalTreeMetadata( tree_metadata_entries=tree_metadata_entries, use_zarr3=use_zarr3, + custom_metadata=custom_metadata, store_array_data_equal_to_fill_value=ts_utils.STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE, pytree_metadata_options=pytree_metadata_options, value_metadata_tree=value_metadata_tree, @@ -295,6 +317,7 @@ def to_json(self) -> Dict[str, Any]: }, _USE_ZARR3: True/False, _STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE: True, + _CUSTOM_METADATA: ..., _VALUE_METADATA_TREE: '{ "mu_nu": { "category": "namedtuple", @@ -353,6 +376,7 @@ def to_json(self) -> Dict[str, Any]: _STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE: ( self.store_array_data_equal_to_fill_value ), + _CUSTOM_METADATA: self.custom_metadata, } # TODO: b/365169723 - Support versioned evolution of metadata storage. if ( @@ -379,6 +403,7 @@ def from_json( ) -> InternalTreeMetadata: """Returns an InternalTreeMetadata instance from its JSON representation.""" use_zarr3 = json_dict.get(_USE_ZARR3, False) + custom_metadata = json_dict.get(_CUSTOM_METADATA, None) store_array_data_equal_to_fill_value = json_dict.get( _STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE, False ) @@ -405,6 +430,7 @@ def from_json( return InternalTreeMetadata( tree_metadata_entries=tree_metadata_entries, use_zarr3=use_zarr3, + custom_metadata=custom_metadata, pytree_metadata_options=pytree_metadata_options, value_metadata_tree=value_metadata_tree, store_array_data_equal_to_fill_value=store_array_data_equal_to_fill_value, @@ -427,7 +453,7 @@ def as_nested_tree(self) -> Dict[str, Any]: for entry in self.tree_metadata_entries ]) - def as_user_metadata( + def as_custom_metadata( self, directory: epath.Path, type_handler_registry: types.TypeHandlerRegistry, @@ -543,22 +569,28 @@ def serialize_tree( class TreeMetadata(Protocol): """User-facing metadata representation of a PyTree. + Corresponds to the metadata of a PyTree checkpoint (e.g. that saved by + `PyTreeCheckpointHandler`, `StandardCheckpointHandler`, + `StandardCheckpointer`, etc.). + Implementations must register themselves as PyTrees, e.g. with `jax.tree_util.register_pytree_with_keys_class`. - The object should be treated as a regular PyTree that can be mapped over. + The object should be treated as a regular PyTree that can be mapped over. Leaf values are typically of type `ocp.metadata.value.Metadata` (when the object is obtained from a `metadata()` function call). Note that the user may subsequently modify these leaves to be of any type. Additional properties - (e.g. `custom`) may be accessed directly, and are independent from the tree - structure. To directly access the underlying PyTree, which matches the + (e.g. `custom_metadata`) may be accessed directly, and are independent from + the + tree structure. To directly access the underlying PyTree, which matches the checkpoint structure, use the `tree` property. Here is a typical example usage:: with ocp.StandardCheckpointer() as ckptr: # `metadata` is a `TreeMetadata` object, but can be treated as a regular # PyTree. In this case, it corresponds to a "serialized" representation of - # the checkpoint tree. This means that all custom nodes are converted to + # the checkpoint tree. This means that all custom_metadata nodes are + converted to # standardized containers like list, tuple, and dict. (See also # `support_rich_types` for further details on how other types are # handled.) @@ -611,11 +643,12 @@ class TreeMetadata(Protocol): jax.tree.map(lambda x, y: foo(x, y), metadata.tree, tree) - Properties of the `TreeMetadata` object, such as `custom` and `tree`, can be + Properties of the `TreeMetadata` object, such as `custom_metadata` and `tree`, + can be accessed directly:: with ocp.StandardCheckpointer() as ckptr: metadata = ckptr.metadata('/path/to/existing/checkpoint') - metadata.custom + metadata.custom_metadata metadata.tree The metadata can be used directly to restore a checkpoint. Restoration code @@ -636,7 +669,7 @@ def tree(self) -> PyTree: ... @property - def custom(self) -> PyTree | None: + def custom_metadata(self) -> PyTree | None: ... @@ -701,7 +734,7 @@ def build( cls, tree: PyTree, *, - custom: PyTree | None = None, + custom_metadata: PyTree | None = None, ) -> TreeMetadata: """Builds the TreeMetadata.""" ... @@ -718,10 +751,10 @@ def __init__( self, *, tree: PyTree, - custom: PyTree | None = None, + custom_metadata: PyTree | None = None, ): self._tree = tree - self._custom = custom + self._custom_metadata = custom_metadata self._validate_tree_type(tree) def _validate_tree_type(self, tree: PyTree): @@ -740,8 +773,8 @@ def tree(self) -> PyTree: return self._tree @property - def custom(self) -> PyTree | None: - return self._custom + def custom_metadata(self) -> PyTree | None: + return self._custom_metadata def tree_flatten(self): @@ -866,22 +899,22 @@ def build( cls, tree: PyTree, *, - custom: PyTree | None = None, + custom_metadata: PyTree | None = None, ) -> TreeMetadata: """Builds the TreeMetadata.""" return cls( tree=tree, - custom=custom, + custom_metadata=custom_metadata, ) def build_default_tree_metadata( tree: PyTree, *, - custom: PyTree | None = None, + custom_metadata: PyTree | None = None, ) -> TreeMetadata: """Builds the TreeMetadata using a default implementation.""" return _TreeMetadataImpl.build( tree, - custom=custom, + custom_metadata=custom_metadata, ) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py b/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py index ea0a3846..cc965d4a 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py @@ -27,8 +27,13 @@ def _to_param_infos( tree: Any, - pytree_metadata_options: tree_metadata_lib.PyTreeMetadataOptions, + pytree_metadata_options: ( + tree_metadata_lib.PyTreeMetadataOptions | None + ) = None, ): + pytree_metadata_options = pytree_metadata_options or ( + tree_metadata_lib.PYTREE_METADATA_OPTIONS + ) return jax.tree.map( # Other properties are not relevant. lambda x: types.ParamInfo( @@ -132,6 +137,29 @@ def test_switching_between_support_rich_types( restored_tree_metadata = restored_internal_tree_metadata.as_nested_tree() chex.assert_trees_all_equal(restored_tree_metadata, expected_tree_metadata) + @parameterized.parameters( + (test_tree_utils.MyDataClass(),), + ([test_tree_utils.MyDataClass()],), + ({'a': test_tree_utils.MyFlax()},), + ) + def test_invalid_custom_metadata(self, custom_metadata): + tree = {'scalar_param': 1} + with self.assertRaisesRegex(TypeError, 'Failed to encode'): + tree_metadata_lib.InternalTreeMetadata.build( + param_infos=_to_param_infos(tree), custom_metadata=custom_metadata + ) + + @parameterized.parameters( + ({'a': 1, 'b': [{'c': 2}, 1]},), + ([1, [{'c': 2}, 1]],), + ) + def test_custom_metadata(self, custom_metadata): + tree = {'scalar_param': 1} + internal_tree_metadata = tree_metadata_lib.InternalTreeMetadata.build( + param_infos=_to_param_infos(tree), custom_metadata=custom_metadata + ) + self.assertEqual(internal_tree_metadata.custom_metadata, custom_metadata) + class NestedNamedTuple(NamedTuple): a: int @@ -157,9 +185,11 @@ def _check_tree_property( @parameterized.parameters(({'a': 1, 'b': 2},), ([1, 2],), ((1, 2),)) def test_properties(self, tree): - custom = {'foo': 1} - metadata = tree_metadata_lib._TreeMetadataImpl(tree=tree, custom=custom) - self.assertDictEqual(metadata.custom, custom) + custom_metadata = {'foo': 1} + metadata = tree_metadata_lib._TreeMetadataImpl( + tree=tree, custom_metadata=custom_metadata + ) + self.assertDictEqual(metadata.custom_metadata, custom_metadata) self._check_tree_property(tree, metadata) @parameterized.parameters( @@ -229,10 +259,12 @@ def test_sequence_accessors(self, tree): (test_tree_utils.EmptyNamedTuple(),), ) def test_tree_map(self, tree): - custom = {'foo': 1} - metadata = tree_metadata_lib._TreeMetadataImpl(tree=tree, custom=custom) + custom_metadata = {'foo': 1} + metadata = tree_metadata_lib._TreeMetadataImpl( + tree=tree, custom_metadata=custom_metadata + ) metadata = jax.tree.map(lambda x: x + 1, metadata) - self.assertDictEqual(metadata.custom, custom) + self.assertDictEqual(metadata.custom_metadata, custom_metadata) self._check_tree_property(jax.tree.map(lambda x: x + 1, tree), metadata) @parameterized.parameters( diff --git a/checkpoint/orbax/checkpoint/_src/tree/types.py b/checkpoint/orbax/checkpoint/_src/tree/types.py index f49fb776..bfb2a07b 100644 --- a/checkpoint/orbax/checkpoint/_src/tree/types.py +++ b/checkpoint/orbax/checkpoint/_src/tree/types.py @@ -14,6 +14,8 @@ """Public common types to work with pytrees.""" +from __future__ import annotations + from typing import Any, TypeVar, Union from jax import tree_util as jtu @@ -30,3 +32,6 @@ jtu.SequenceKey, jtu.DictKey, jtu.GetAttrKey, jtu.FlattenedIndexKey ] PyTreePath = tuple[PyTreeKey, ...] + +JsonType = list['JsonValue'] | dict[str, 'JsonValue'] +JsonValue = str | int | float | bool | None | JsonType diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index 606435fe..4da56152 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -1517,7 +1517,7 @@ def _maybe_save_metadata(self, metadata: Mapping[str, Any]): if not file_path.exists(): # May have been created by a previous run. logging.info('Writing root metadata') metadata_to_save = checkpoint.RootMetadata( - custom=dict(metadata), + custom_metadata=dict(metadata), ) self._blocking_metadata_store.write( file_path, serialize_root_metadata(metadata_to_save) @@ -1542,7 +1542,9 @@ def metadata(self) -> Mapping[str, Any]: self._metadata_dir) file_path = self._metadata_file_path(legacy=True) serialized_metadata = self._blocking_metadata_store.read(file_path) - self._metadata = deserialize_root_metadata(serialized_metadata).custom + self._metadata = deserialize_root_metadata( + serialized_metadata + ).custom_metadata if self._metadata is None: raise FileNotFoundError( f'Failed to read metadata from {file_path}.' diff --git a/checkpoint/orbax/checkpoint/checkpoint_utils.py b/checkpoint/orbax/checkpoint/checkpoint_utils.py index b6201d8a..1308ae65 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_utils.py +++ b/checkpoint/orbax/checkpoint/checkpoint_utils.py @@ -491,7 +491,7 @@ def _get_sharding_or_layout(value): if isinstance(target, tree_metadata.TreeMetadata): return tree_metadata.build_default_tree_metadata( jax.tree.map(_restore_args, target.tree, sharding_tree.tree), - custom=target.custom, + custom_metadata=target.custom_metadata, ) else: return jax.tree.map(_restore_args, target, sharding_tree)