From 5395f04da121106ab85521aeb586e6bee65235d0 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Mon, 9 Dec 2024 14:49:01 -0800 Subject: [PATCH] Support custom PyTree metadata. PiperOrigin-RevId: 704424560 --- checkpoint/CHANGELOG.md | 3 ++ .../base_pytree_checkpoint_handler.py | 22 ++++++++++++-- .../handlers/pytree_checkpoint_handler.py | 9 ++++++ .../handlers/standard_checkpoint_handler.py | 14 ++++++++- .../standard_checkpoint_handler_test_utils.py | 24 ++++++++------- .../orbax/checkpoint/_src/metadata/tree.py | 20 ++++++++++++- .../checkpoint/_src/metadata/tree_test.py | 30 ++++++++++++++++++- .../orbax/checkpoint/_src/tree/types.py | 5 ++++ 8 files changed, 112 insertions(+), 15 deletions(-) diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 5e031e62e..2e8475eec 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 make the change unnoticeable to most users, but also has additional accessible properties not included in any tree mapping operations. +### Added +- User-provided custom PyTree metadata. + ## [0.11.0] - 2024-12-30 diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index 2d3b6f9b2..4b2a45f11 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -49,6 +49,7 @@ from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import type_handlers from orbax.checkpoint._src.serialization import types +from orbax.checkpoint._src.tree import types as tree_types from orbax.checkpoint._src.tree import utils as tree_utils import tensorstore as ts @@ -436,6 +437,7 @@ async def async_save( raise ValueError('Found empty item.') save_args = args.save_args ocdbt_target_data_file_size = args.ocdbt_target_data_file_size + custom_metadata = args.custom_metadata save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save') byte_limiter = serialization.get_byte_limiter(self._save_concurrent_bytes) @@ -476,7 +478,11 @@ async def async_save( if multihost.is_primary_host(self._primary_host): commit_futures.append( self._write_metadata_file( - directory, param_infos, save_args, self._use_zarr3 + directory, + param_infos=param_infos, + save_args=save_args, + custom_metadata=custom_metadata, + use_zarr3=self._use_zarr3, ) ) @@ -728,8 +734,10 @@ class TrainState: def _write_metadata_file( self, directory: epath.Path, + *, param_infos: PyTree, save_args: PyTree, + custom_metadata: tree_types.JsonType | None = None, use_zarr3: bool = False, ) -> future.Future: def _save_fn(): @@ -740,6 +748,7 @@ def _save_fn(): param_infos, save_args=save_args, use_zarr3=use_zarr3, + custom=custom_metadata, pytree_metadata_options=self._pytree_metadata_options, ) logging.vlog( @@ -816,12 +825,14 @@ def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata: tree containing metadata. """ is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory) + internal_tree_metadata = self._read_metadata_file(directory) return tree_metadata.build_default_tree_metadata( - self._read_metadata_file(directory).as_user_metadata( + internal_tree_metadata.as_user_metadata( directory, self._type_handler_registry, use_ocdbt=is_ocdbt_checkpoint, ), + custom=internal_tree_metadata.custom, ) def finalize(self, directory: epath.Path) -> None: @@ -873,12 +884,19 @@ class BasePyTreeSaveArgs(CheckpointArgs): enable_pinned_host_transfer: True by default. If False, disables transfer to pinned host when copying from device to host, regardless of the presence of pinned host memory. + custom_metadata: A JSON-serializable object (typically just a nested + dictionary containing string keys and basic type values) that stores user- + specified metadata. This metadata is stored along with the Orbax-internal + PyTree metadata. This can be used to supplement information about the + PyTree checkpoint with information about e.g. the model used to generate + the checkpoint. """ item: PyTree save_args: Optional[PyTree] = None ocdbt_target_data_file_size: Optional[int] = None enable_pinned_host_transfer: bool = True + custom_metadata: tree_types.JsonType | None = None @register_with_handler(BasePyTreeCheckpointHandler, for_restore=True) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py index a682bd10f..edb8ab51b 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py @@ -46,6 +46,7 @@ from orbax.checkpoint._src.serialization import serialization from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import type_handlers +from orbax.checkpoint._src.tree import types as tree_types from orbax.checkpoint._src.tree import utils as tree_utils import tensorstore as ts @@ -428,6 +429,7 @@ def _get_impl_save_args( save_args=args.save_args, ocdbt_target_data_file_size=args.ocdbt_target_data_file_size, enable_pinned_host_transfer=args.enable_pinned_host_transfer, + custom_metadata=args.custom_metadata, ) @@ -1046,12 +1048,19 @@ class PyTreeSaveArgs(CheckpointArgs): enable_pinned_host_transfer: True by default. If False, disables transfer to pinned host when copying from device to host, regardless of the presence of pinned host memory. + custom_metadata: A JSON-serializable object (typically just a nested + dictionary containing string keys and basic type values) that stores user- + specified metadata. This metadata is stored along with the Orbax-internal + PyTree metadata. This can be used to supplement information about the + PyTree checkpoint with information about e.g. the model used to generate + the checkpoint. """ item: PyTree save_args: Optional[PyTree] = None ocdbt_target_data_file_size: Optional[int] = None enable_pinned_host_transfer: bool = True + custom_metadata: tree_types.JsonType | None = None def __post_init__(self): if isinstance(self.item, tree_metadata.TreeMetadata): diff --git a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py index 17fc7d4ea..d94cd0209 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py @@ -32,6 +32,7 @@ from orbax.checkpoint._src.handlers import pytree_checkpoint_handler from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib from orbax.checkpoint._src.metadata import tree as tree_metadata +from orbax.checkpoint._src.tree import types as tree_types from orbax.checkpoint._src.tree import utils as tree_utils @@ -144,15 +145,19 @@ async def async_save( 'Make sure to specify kwarg name `args=` when providing' ' `StandardSaveArgs`.' ) + custom_metadata = None if args is not None: item = args.item save_args = args.save_args + custom_metadata = args.custom_metadata self._validate_save_state(item, save_args=save_args) return await self._impl.async_save( directory, args=pytree_checkpoint_handler.PyTreeSaveArgs( - item=item, save_args=save_args + item=item, + save_args=save_args, + custom_metadata=custom_metadata, ), ) @@ -266,10 +271,17 @@ class StandardSaveArgs(CheckpointArgs): save_args: a PyTree with the same structure of `item`, which consists of `ocp.SaveArgs` objects as values. `None` can be used for values where no `SaveArgs` are specified. + custom_metadata: A JSON-serializable object (typically just a nested + dictionary containing string keys and basic type values) that stores user- + specified metadata. This metadata is stored along with the Orbax-internal + PyTree metadata. This can be used to supplement information about the + PyTree checkpoint with information about e.g. the model used to generate + the checkpoint. """ item: PyTree save_args: Optional[PyTree] = None + custom_metadata: tree_types.JsonType | None = None def __post_init__(self): if isinstance(self.item, tree_metadata.TreeMetadata): diff --git a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py index 86a92d12a..652e2904d 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py @@ -14,6 +14,8 @@ """Test for standard_checkpoint_handler.py.""" +# pylint: disable=protected-access, missing-function-docstring + import functools from typing import Any @@ -127,7 +129,6 @@ def test_basic_no_item_arg(self): test_utils.assert_tree_equal(self, self.pytree, restored) def test_shape_dtype_struct(self): - """Test case.""" self.handler.save( self.directory, args=self.save_args_cls(self.mixed_pytree) ) @@ -162,7 +163,7 @@ def test_custom_layout(self): custom_layout = Layout( device_local_layout=DLL( major_to_minor=arr.layout.device_local_layout.major_to_minor[::-1], # pytype: disable=attribute-error - _tiling=arr.layout.device_local_layout._tiling, # pylint: disable=protected-access # pytype: disable=attribute-error + _tiling=arr.layout.device_local_layout._tiling, # pytype: disable=attribute-error ), sharding=arr.sharding, ) @@ -210,7 +211,6 @@ def test_custom_layout(self): @parameterized.parameters((True,), (False,)) def test_change_shape(self, strict: bool): - """Test case.""" if not hasattr(self.restore_args_cls, 'strict'): self.skipTest('strict option not supported for this handler') mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',)) @@ -255,7 +255,6 @@ def test_restore_unsupported_type(self): self.handler.restore(self.directory, args=self.restore_args_cls(pytree)) def test_cast(self): - """Test case.""" # TODO(dicentra): casting from int dtypes currently doesn't work # in the model surgery context. save_args = jax.tree.map( @@ -289,7 +288,6 @@ def check_dtype(x, dtype): jax.tree.map(lambda x: check_dtype(x, jnp.bfloat16), restored) def test_flax_model(self): - """Test case.""" @flax.struct.dataclass class Params(flax.struct.PyTreeNode): @@ -318,12 +316,10 @@ def make_params(): test_utils.assert_tree_equal(self, params, restored) def test_empty_error(self): - """Test case.""" with self.assertRaises(ValueError): self.handler.save(self.directory, args=self.save_args_cls({})) def test_empty_dict_node(self): - """Test case.""" item = {'a': {}, 'b': 3} self.handler.save(self.directory, args=self.save_args_cls(item)) restored = self.handler.restore( @@ -332,7 +328,6 @@ def test_empty_dict_node(self): self.assertDictEqual(restored, item) def test_empty_none_node(self): - """Test case.""" item = {'c': None, 'd': 2} self.handler.save(self.directory, args=self.save_args_cls(item)) restored = self.handler.restore( @@ -341,7 +336,6 @@ def test_empty_none_node(self): self.assertDictEqual(restored, item) def test_none_node_in_restore_args(self): - """Test case.""" devices = np.asarray(jax.devices()) mesh = jax.sharding.Mesh(devices, ('x',)) mesh_axes = jax.sharding.PartitionSpec( @@ -358,7 +352,6 @@ def test_none_node_in_restore_args(self): test_utils.assert_tree_equal(self, restored, {'b': None}) def test_masked_shape_dtype_struct(self): - """Test case.""" def _should_mask(keypath): return keypath[0].key == 'a' or ( @@ -398,3 +391,14 @@ def _none(keypath, x): # Restore it without any item. restored = self.handler.restore(self.directory) test_utils.assert_tree_equal(self, expected, restored) + + def test_custom_metadata(self): + custom_metadata = {'foo': 1} + self.handler.save( + self.directory, + args=self.save_args_cls( + self.pytree, custom_metadata=custom_metadata + ), + ) + metadata = self.handler.metadata(self.directory) + self.assertEqual(metadata.custom, custom_metadata) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree.py b/checkpoint/orbax/checkpoint/_src/metadata/tree.py index 0bf35fe42..8ec591beb 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree.py @@ -22,6 +22,7 @@ import enum import functools import inspect +import json import operator import typing from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple, TypeAlias, TypeVar, Union @@ -37,6 +38,7 @@ from orbax.checkpoint._src.metadata import value_metadata_entry from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import types +from orbax.checkpoint._src.tree import types as tree_types from orbax.checkpoint._src.tree import utils as tree_utils @@ -58,6 +60,7 @@ _USE_ZARR3 = 'use_zarr3' _STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE = 'store_array_data_equal_to_fill_value' _VALUE_METADATA_TREE = 'value_metadata_tree' +_CUSTOM_FIELD = 'custom' class KeyType(enum.Enum): @@ -200,12 +203,13 @@ def jax_keypath(self) -> KeyPath: return tuple(keypath) -@dataclasses.dataclass +@dataclasses.dataclass(kw_only=True) class InternalTreeMetadata: """Metadata representation of a PyTree.""" tree_metadata_entries: List[InternalTreeMetadataEntry] use_zarr3: bool + custom: tree_types.JsonType | None store_array_data_equal_to_fill_value: bool pytree_metadata_options: PyTreeMetadataOptions value_metadata_tree: PyTree | None = None @@ -219,6 +223,14 @@ def __post_init__(self): len(self.tree_metadata_entries), self.value_metadata_tree is not None, ) + # Validate JSON-serializability of custom metadata. + try: + json.dumps(self.custom) + except TypeError as e: + raise TypeError( + 'Failed to encode `custom` metadata as JSON object. Please ensure' + ' your custom metadata is JSON-serializable.' + ) from e @classmethod def build( @@ -227,6 +239,7 @@ def build( *, save_args: Optional[PyTree] = None, use_zarr3: bool = False, + custom: tree_types.JsonType | None = None, pytree_metadata_options: PyTreeMetadataOptions = ( PYTREE_METADATA_OPTIONS ), @@ -267,6 +280,7 @@ def build( return InternalTreeMetadata( tree_metadata_entries=tree_metadata_entries, use_zarr3=use_zarr3, + custom=custom, store_array_data_equal_to_fill_value=ts_utils.STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE, pytree_metadata_options=pytree_metadata_options, value_metadata_tree=value_metadata_tree, @@ -295,6 +309,7 @@ def to_json(self) -> Dict[str, Any]: }, _USE_ZARR3: True/False, _STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE: True, + _CUSTOM_FIELD: ..., _VALUE_METADATA_TREE: '{ "mu_nu": { "category": "namedtuple", @@ -353,6 +368,7 @@ def to_json(self) -> Dict[str, Any]: _STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE: ( self.store_array_data_equal_to_fill_value ), + _CUSTOM_FIELD: self.custom, } # TODO: b/365169723 - Support versioned evolution of metadata storage. if ( @@ -379,6 +395,7 @@ def from_json( ) -> InternalTreeMetadata: """Returns an InternalTreeMetadata instance from its JSON representation.""" use_zarr3 = json_dict.get(_USE_ZARR3, False) + custom = json_dict.get(_CUSTOM_FIELD, None) store_array_data_equal_to_fill_value = json_dict.get( _STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE, False ) @@ -405,6 +422,7 @@ def from_json( return InternalTreeMetadata( tree_metadata_entries=tree_metadata_entries, use_zarr3=use_zarr3, + custom=custom, pytree_metadata_options=pytree_metadata_options, value_metadata_tree=value_metadata_tree, store_array_data_equal_to_fill_value=store_array_data_equal_to_fill_value, diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py b/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py index ea0a38461..9da4f869b 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py @@ -27,8 +27,13 @@ def _to_param_infos( tree: Any, - pytree_metadata_options: tree_metadata_lib.PyTreeMetadataOptions, + pytree_metadata_options: ( + tree_metadata_lib.PyTreeMetadataOptions | None + ) = None, ): + pytree_metadata_options = pytree_metadata_options or ( + tree_metadata_lib.PYTREE_METADATA_OPTIONS + ) return jax.tree.map( # Other properties are not relevant. lambda x: types.ParamInfo( @@ -132,6 +137,29 @@ def test_switching_between_support_rich_types( restored_tree_metadata = restored_internal_tree_metadata.as_nested_tree() chex.assert_trees_all_equal(restored_tree_metadata, expected_tree_metadata) + @parameterized.parameters( + (test_tree_utils.MyDataClass(),), + ([test_tree_utils.MyDataClass()],), + ({'a': test_tree_utils.MyFlax()},), + ) + def test_invalid_custom_metadata(self, custom): + tree = {'scalar_param': 1} + with self.assertRaisesRegex(TypeError, 'Failed to encode'): + tree_metadata_lib.InternalTreeMetadata.build( + param_infos=_to_param_infos(tree), custom=custom + ) + + @parameterized.parameters( + ({'a': 1, 'b': [{'c': 2}, 1]},), + ([1, [{'c': 2}, 1]],), + ) + def test_custom_metadata(self, custom): + tree = {'scalar_param': 1} + internal_tree_metadata = tree_metadata_lib.InternalTreeMetadata.build( + param_infos=_to_param_infos(tree), custom=custom + ) + self.assertEqual(internal_tree_metadata.custom, custom) + class NestedNamedTuple(NamedTuple): a: int diff --git a/checkpoint/orbax/checkpoint/_src/tree/types.py b/checkpoint/orbax/checkpoint/_src/tree/types.py index f49fb7763..bfb2a07bb 100644 --- a/checkpoint/orbax/checkpoint/_src/tree/types.py +++ b/checkpoint/orbax/checkpoint/_src/tree/types.py @@ -14,6 +14,8 @@ """Public common types to work with pytrees.""" +from __future__ import annotations + from typing import Any, TypeVar, Union from jax import tree_util as jtu @@ -30,3 +32,6 @@ jtu.SequenceKey, jtu.DictKey, jtu.GetAttrKey, jtu.FlattenedIndexKey ] PyTreePath = tuple[PyTreeKey, ...] + +JsonType = list['JsonValue'] | dict[str, 'JsonValue'] +JsonValue = str | int | float | bool | None | JsonType