Skip to content

Commit

Permalink
Support custom PyTree metadata. Standardize naming of the "custom met…
Browse files Browse the repository at this point in the history
…adata" field (user-supplied metadata) as `custom_metadata`.

PiperOrigin-RevId: 718050751
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Jan 21, 2025
1 parent 3a8acf7 commit 9dedd19
Show file tree
Hide file tree
Showing 19 changed files with 230 additions and 86 deletions.
9 changes: 6 additions & 3 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
make the change unnoticeable to most users, but also has additional accessible
properties not included in any tree mapping operations.
- `Checkpointer.save()`, `AsyncCheckpointer.save()` also saves `StepMetadata`.
- Added github actions CI testing using Python versions 3.10-3.13
- Added github actions CI testing using Python versions 3.10-3.13.
- Standardize naming of the "custom metadata" field (user-supplied metadata) as
`custom_metadata`.

### Added
- The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`.
- `CommitFuture` and `HandlerAwaitableSignal` for signalling between Checkpointing layers to enable async
directory creation.
- `CommitFuture` and `HandlerAwaitableSignal` for signalling between
Checkpointing layers to enable async directory creation.
- User-provided custom PyTree metadata.

### Fixed
- Fix a bug where snapshots are not released by `wait_for_new_checkpoint`
Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ py_library(
name = "tree",
srcs = ["tree.py"],
deps = [
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
"//orbax/checkpoint/_src/tree:types",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _save_step_metadata(
):
"""Saves StepMetadata to the checkpoint directory."""
update_dict = {
'custom': custom_metadata,
'custom_metadata': custom_metadata,
}
if isinstance(
self._handler, composite_checkpoint_handler.CompositeCheckpointHandler
Expand Down
3 changes: 3 additions & 0 deletions checkpoint/orbax/checkpoint/_src/handlers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/serialization",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
"//orbax/checkpoint:utils",
],
Expand All @@ -91,6 +92,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
"//orbax/checkpoint:utils",
"//orbax/checkpoint/_src/metadata:array_metadata_store",
Expand Down Expand Up @@ -170,6 +172,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/metadata:tree",
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts

Expand Down Expand Up @@ -444,6 +445,7 @@ async def async_save(
raise ValueError('Found empty item.')
save_args = args.save_args
ocdbt_target_data_file_size = args.ocdbt_target_data_file_size
custom_metadata = args.custom_metadata

save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save')
byte_limiter = serialization.get_byte_limiter(self._save_concurrent_bytes)
Expand Down Expand Up @@ -491,6 +493,7 @@ async def async_save(
checkpoint_dir=directory,
param_infos=param_infos,
save_args=save_args,
custom_metadata=custom_metadata,
use_zarr3=self._use_zarr3,
)
)
Expand Down Expand Up @@ -799,8 +802,10 @@ def update_param_info(param_info: types.ParamInfo) -> types.ParamInfo:
def _write_metadata_file(
self,
directory: epath.Path,
*,
param_infos: PyTree,
save_args: PyTree,
custom_metadata: tree_types.JsonType | None = None,
use_zarr3: bool = False,
) -> future.Future:
def _save_fn(param_infos):
Expand All @@ -811,6 +816,7 @@ def _save_fn(param_infos):
param_infos,
save_args=save_args,
use_zarr3=use_zarr3,
custom_metadata=custom_metadata,
pytree_metadata_options=self._pytree_metadata_options,
)
logging.vlog(
Expand All @@ -832,8 +838,10 @@ def _write_metadata_after_commits(
self,
commit_futures: List[future.Future],
checkpoint_dir: epath.Path,
*,
param_infos: PyTree,
save_args: PyTree,
custom_metadata: tree_types.JsonType | None = None,
use_zarr3: bool,
) -> None:
if not utils.is_primary_host(self._primary_host):
Expand All @@ -853,7 +861,11 @@ def _write_metadata_after_commits(
param_infos, checkpoint_dir, self._array_metadata_store
)
self._write_metadata_file(
checkpoint_dir, param_infos, save_args, use_zarr3
checkpoint_dir,
param_infos=param_infos,
save_args=save_args,
custom_metadata=custom_metadata,
use_zarr3=use_zarr3,
).result()

def _read_metadata_file(
Expand Down Expand Up @@ -915,12 +927,14 @@ def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata:
tree containing metadata.
"""
is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory)
internal_tree_metadata = self._read_metadata_file(directory)
return tree_metadata.build_default_tree_metadata(
self._read_metadata_file(directory).as_user_metadata(
internal_tree_metadata.as_custom_metadata(
directory,
self._type_handler_registry,
use_ocdbt=is_ocdbt_checkpoint,
),
custom_metadata=internal_tree_metadata.custom_metadata,
)

def finalize(self, directory: epath.Path) -> None:
Expand Down Expand Up @@ -972,12 +986,16 @@ class BasePyTreeSaveArgs(CheckpointArgs):
enable_pinned_host_transfer: True by default. If False, disables transfer to
pinned host when copying from device to host, regardless of the presence
of pinned host memory.
custom_metadata: User-provided custom metadata. An arbitrary
JSON-serializable dictionary the user can use to store additional
information. The field is treated as opaque by Orbax.
"""

item: PyTree
save_args: Optional[PyTree] = None
ocdbt_target_data_file_size: Optional[int] = None
enable_pinned_host_transfer: bool = True
custom_metadata: tree_types.JsonType | None = None


@register_with_handler(BasePyTreeCheckpointHandler, for_restore=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def test_metadata_no_save(self, use_handler_registry):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.custom_metadata)

def test_metadata_handler_registry(self):
registry = handler_registration.DefaultCheckpointHandlerRegistry()
Expand Down Expand Up @@ -779,7 +779,7 @@ def test_metadata_handler_registry(self):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.custom_metadata)

def test_metadata_after_step_metadata_write(self):
handler = CompositeCheckpointHandler(
Expand All @@ -795,7 +795,7 @@ def test_metadata_after_step_metadata_write(self):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.custom_metadata)

metadata_to_write = checkpoint.StepMetadata(
item_handlers={
Expand All @@ -813,7 +813,7 @@ def test_metadata_after_step_metadata_write(self):
),
init_timestamp_nsecs=1000,
commit_timestamp_nsecs=2000,
custom={
custom_metadata={
'custom_key': 'custom_value',
},
)
Expand All @@ -837,7 +837,9 @@ def test_metadata_after_step_metadata_write(self):
)
self.assertEqual(step_metadata.init_timestamp_nsecs, 1000)
self.assertEqual(step_metadata.commit_timestamp_nsecs, 2000)
self.assertEqual(step_metadata.custom, {'custom_key': 'custom_value'})
self.assertEqual(
step_metadata.custom_metadata, {'custom_key': 'custom_value'}
)

def test_metadata_existing_items_updates_step_metadata(self):
handler = CompositeCheckpointHandler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from orbax.checkpoint._src.serialization import serialization
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts

Expand Down Expand Up @@ -431,6 +432,7 @@ def _get_impl_save_args(
save_args=args.save_args,
ocdbt_target_data_file_size=args.ocdbt_target_data_file_size,
enable_pinned_host_transfer=args.enable_pinned_host_transfer,
custom_metadata=args.custom_metadata,
)


Expand Down Expand Up @@ -1052,12 +1054,16 @@ class PyTreeSaveArgs(CheckpointArgs):
enable_pinned_host_transfer: True by default. If False, disables transfer to
pinned host when copying from device to host, regardless of the presence
of pinned host memory.
custom_metadata: User-provided custom metadata. An arbitrary
JSON-serializable dictionary the user can use to store additional
information. The field is treated as opaque by Orbax.
"""

item: PyTree
save_args: Optional[PyTree] = None
ocdbt_target_data_file_size: Optional[int] = None
enable_pinned_host_transfer: bool = True
custom_metadata: tree_types.JsonType | None = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils


Expand Down Expand Up @@ -144,15 +145,19 @@ async def async_save(
'Make sure to specify kwarg name `args=` when providing'
' `StandardSaveArgs`.'
)
custom_metadata = None
if args is not None:
item = args.item
save_args = args.save_args
custom_metadata = args.custom_metadata

self._validate_save_state(item, save_args=save_args)
return await self._impl.async_save(
directory,
args=pytree_checkpoint_handler.PyTreeSaveArgs(
item=item, save_args=save_args
item=item,
save_args=save_args,
custom_metadata=custom_metadata,
),
)

Expand Down Expand Up @@ -266,10 +271,14 @@ class StandardSaveArgs(CheckpointArgs):
save_args: a PyTree with the same structure of `item`, which consists of
`ocp.SaveArgs` objects as values. `None` can be used for values where no
`SaveArgs` are specified.
custom_metadata: User-provided custom metadata. An arbitrary
JSON-serializable dictionary the user can use to store additional
information. The field is treated as opaque by Orbax.
"""

item: PyTree
save_args: Optional[PyTree] = None
custom_metadata: tree_types.JsonType | None = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Test for standard_checkpoint_handler.py."""

# pylint: disable=protected-access, missing-function-docstring

import functools
from typing import Any

Expand Down Expand Up @@ -127,7 +129,6 @@ def test_basic_no_item_arg(self):
test_utils.assert_tree_equal(self, self.pytree, restored)

def test_shape_dtype_struct(self):
"""Test case."""
self.handler.save(
self.directory, args=self.save_args_cls(self.mixed_pytree)
)
Expand Down Expand Up @@ -162,7 +163,7 @@ def test_custom_layout(self):
custom_layout = Layout(
device_local_layout=DLL(
major_to_minor=arr.layout.device_local_layout.major_to_minor[::-1], # pytype: disable=attribute-error
_tiling=arr.layout.device_local_layout._tiling, # pylint: disable=protected-access # pytype: disable=attribute-error
_tiling=arr.layout.device_local_layout._tiling, # pytype: disable=attribute-error
),
sharding=arr.sharding,
)
Expand Down Expand Up @@ -210,7 +211,6 @@ def test_custom_layout(self):

@parameterized.parameters((True,), (False,))
def test_change_shape(self, strict: bool):
"""Test case."""
if not hasattr(self.restore_args_cls, 'strict'):
self.skipTest('strict option not supported for this handler')
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',))
Expand Down Expand Up @@ -255,7 +255,6 @@ def test_restore_unsupported_type(self):
self.handler.restore(self.directory, args=self.restore_args_cls(pytree))

def test_cast(self):
"""Test case."""
# TODO(dicentra): casting from int dtypes currently doesn't work
# in the model surgery context.
save_args = jax.tree.map(
Expand Down Expand Up @@ -289,7 +288,6 @@ def check_dtype(x, dtype):
jax.tree.map(lambda x: check_dtype(x, jnp.bfloat16), restored)

def test_flax_model(self):
"""Test case."""

@flax.struct.dataclass
class Params(flax.struct.PyTreeNode):
Expand Down Expand Up @@ -318,12 +316,10 @@ def make_params():
test_utils.assert_tree_equal(self, params, restored)

def test_empty_error(self):
"""Test case."""
with self.assertRaises(ValueError):
self.handler.save(self.directory, args=self.save_args_cls({}))

def test_empty_dict_node(self):
"""Test case."""
item = {'a': {}, 'b': 3}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand All @@ -332,7 +328,6 @@ def test_empty_dict_node(self):
self.assertDictEqual(restored, item)

def test_empty_none_node(self):
"""Test case."""
item = {'c': None, 'd': 2}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand All @@ -341,7 +336,6 @@ def test_empty_none_node(self):
self.assertDictEqual(restored, item)

def test_none_node_in_restore_args(self):
"""Test case."""
devices = np.asarray(jax.devices())
mesh = jax.sharding.Mesh(devices, ('x',))
mesh_axes = jax.sharding.PartitionSpec(
Expand All @@ -358,7 +352,6 @@ def test_none_node_in_restore_args(self):
test_utils.assert_tree_equal(self, restored, {'b': None})

def test_masked_shape_dtype_struct(self):
"""Test case."""

def _should_mask(keypath):
return keypath[0].key == 'a' or (
Expand Down Expand Up @@ -398,3 +391,12 @@ def _none(keypath, x):
# Restore it without any item.
restored = self.handler.restore(self.directory)
test_utils.assert_tree_equal(self, expected, restored)

def test_custom_metadata(self):
custom_metadata = {'foo': 1}
self.handler.save(
self.directory,
args=self.save_args_cls(self.pytree, custom_metadata=custom_metadata),
)
metadata = self.handler.metadata(self.directory)
self.assertEqual(metadata.custom_metadata, custom_metadata)
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
],
)
Expand Down
Loading

0 comments on commit 9dedd19

Please sign in to comment.