diff --git a/CHANGELOG.md b/CHANGELOG.md index 73ea6e4c9..b84c358a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 packages, namely `orbax-checkpoint` and `orbax-export`. Imports are unchanged, and still of the form `import orbax.checkpoint` or `import orbax.export`. - Finer scoped jax.monitoring calls on the save path. +- `CheckpointManager.metadata()` now accepts a `step` parameter. If provided, it +will return `StepMetadata`, and will otherwise return `RootMetadata`. +- `CheckpointManager.restore()` will now attempt to initialize checkpoint +handlers using `StepMetadata.item_handlers` and the global `HandlerTypeRegistry` +if no args are provided. +- `CompositeCheckpointHandler.metadata()` now returns `StepMetadata`. ## [0.1.7] - 2022-03-29 diff --git a/checkpoint/orbax/checkpoint/BUILD b/checkpoint/orbax/checkpoint/BUILD index ada9307c9..4f3d5b40b 100644 --- a/checkpoint/orbax/checkpoint/BUILD +++ b/checkpoint/orbax/checkpoint/BUILD @@ -101,7 +101,10 @@ py_library( py_library( name = "abstract_checkpoint_manager", srcs = ["abstract_checkpoint_manager.py"], - deps = [":args"], + deps = [ + ":args", + "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", + ], ) py_library( @@ -125,6 +128,7 @@ py_library( "//checkpoint/orbax/checkpoint/_src/handlers:proto_checkpoint_handler", "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", "//checkpoint/orbax/checkpoint/_src/metadata:root_metadata_serialization", + "//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization", "//checkpoint/orbax/checkpoint/_src/multihost", "//checkpoint/orbax/checkpoint/_src/path:atomicity_types", "//checkpoint/orbax/checkpoint/_src/path:deleter", diff --git a/checkpoint/orbax/checkpoint/_src/handlers/BUILD b/checkpoint/orbax/checkpoint/_src/handlers/BUILD index 50866213e..cc29cfc15 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/BUILD +++ b/checkpoint/orbax/checkpoint/_src/handlers/BUILD @@ -25,6 +25,7 @@ py_library( "//checkpoint/orbax/checkpoint/_src/path:atomicity", "//checkpoint/orbax/checkpoint/_src/path:atomicity_defaults", "//checkpoint/orbax/checkpoint/_src/path:atomicity_types", + "//checkpoint/orbax/checkpoint/_src/path:step", ], ) diff --git a/checkpoint/orbax/checkpoint/abstract_checkpoint_manager.py b/checkpoint/orbax/checkpoint/abstract_checkpoint_manager.py index e493f30c5..1d5ac2627 100644 --- a/checkpoint/orbax/checkpoint/abstract_checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/abstract_checkpoint_manager.py @@ -19,6 +19,7 @@ from etils import epath from orbax.checkpoint import args as args_lib +from orbax.checkpoint._src.metadata import checkpoint PyTree = Any SaveParams = Mapping[str, Any] @@ -290,8 +291,22 @@ def item_metadata( """ @abc.abstractmethod - def metadata(self) -> Mapping[str, Any]: - """Returns CheckpointManager level metadata if present, empty otherwise.""" + def metadata( + self, step: int | None = None, + ) -> checkpoint.StepMetadata | checkpoint.RootMetadata: + """Returns `StepMetadata` for the specified step, or `RootMetadata` all. + + If step is specified, only return `StepMetadata` for that step. + Otherwise, return `RootMetadata`. + + Args: + step: Step for which to retrieve `StepMetadata`. If None, returns + `RootMetadata`. + + Returns: + Metadata for the specified step (`StepMetadata`), or all steps + (`RootMetadata`). + """ @abc.abstractmethod def metrics(self, step: int) -> Optional[PyTree]: diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index 3e7bb05d3..451033072 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -22,7 +22,7 @@ import threading import time import typing -from typing import Any, Callable, Container, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Container, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union, overload from absl import logging from etils import epath @@ -45,6 +45,7 @@ from orbax.checkpoint._src.handlers import proto_checkpoint_handler from orbax.checkpoint._src.metadata import checkpoint from orbax.checkpoint._src.metadata import root_metadata_serialization +from orbax.checkpoint._src.metadata import step_metadata_serialization from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import atomicity_types from orbax.checkpoint._src.path import deleter @@ -65,6 +66,9 @@ AbstractCheckpointManager = ( abstract_checkpoint_manager.AbstractCheckpointManager ) +StepMetadata = checkpoint.StepMetadata +RootMetadata = checkpoint.RootMetadata +ItemMetadata = checkpoint.CompositeItemMetadata | checkpoint.SingleItemMetadata AsyncCheckpointer = async_checkpointer.AsyncCheckpointer Checkpointer = checkpointer_lib.Checkpointer JsonCheckpointHandler = json_checkpoint_handler.JsonCheckpointHandler @@ -769,11 +773,14 @@ def __init__( self._metadata_dir = self.directory / METADATA_ITEM_NAME if self._options.read_only and not self._metadata_dir.exists(): - self._metadata = {} if metadata is None else metadata + custom_metadata = {} if metadata is None else dict(metadata) else: - self._metadata = None + custom_metadata = None + self._root_metadata = RootMetadata( + custom_metadata=custom_metadata, + ) - self._maybe_save_metadata(metadata) + self._maybe_save_root_metadata(metadata) # TODO: b/359854428 - Move Finalize biz logic to a separate class/module. self._finalize_thread_lock = threading.Lock() @@ -1357,7 +1364,9 @@ def save( self._logger.log_entry(dataclasses.asdict(step_stats)) return True - def _maybe_get_default_item(self, composite_result: args_lib.Composite): + def _maybe_get_default_item( + self, composite_result: args_lib.Composite + ) -> Union[Any, args_lib.Composite]: if self._default_item: if DEFAULT_ITEM_NAME not in composite_result: raise ValueError( @@ -1438,9 +1447,14 @@ def restore( return self._maybe_get_default_item(restored) - def item_metadata(self, step: int) -> Union[Any, args_lib.Composite]: + def item_metadata( + self, step: int + ) -> Union[Any, args_lib.Composite, ItemMetadata]: """Retrieves metadata for all known items. + Important note: This method will soon be deprecated in favor of + `metadata().item_metadata`. Please use that method instead. + Note that metadata will only be returned for items that can actually be interpreted. If an item is present in the checkpoint but not registered (using a prior save or restore, or with `handler_registry` at init), the @@ -1453,18 +1467,9 @@ def item_metadata(self, step: int) -> Union[Any, args_lib.Composite]: Either metadata for the item itself, if in default-item mode, or a Composite of metadata for each item. """ - assert isinstance(self._checkpointer.handler, CompositeCheckpointHandler) - read_step_directory = self._get_read_step_directory(step, self.directory) - - result = self._checkpointer.metadata(read_step_directory) - if isinstance(result, checkpoint.StepMetadata): - result = result.item_metadata - if self._default_item is None: - self._default_item = _determine_default_item_mode_from_directory( - read_step_directory - ) - return self._maybe_get_default_item(result) + return self.metadata(step).item_metadata + # TODO(b/370812224): Deprecate in favor of StepMetadata.metrics def metrics(self, step: int) -> Optional[PyTree]: if self._track_best: try: @@ -1563,21 +1568,22 @@ def _metadata_file_path(self, legacy: bool = False) -> epath.Path: self._metadata_dir, legacy=legacy ) - def _maybe_save_metadata(self, metadata: Mapping[str, Any]): + def _maybe_save_root_metadata( + self, custom_metadata: Mapping[str, Any] | None + ): """Saves CheckpointManager level metadata, skips if already present.""" if self._options.save_root_metadata: - logging.info('Saving root metadata') - if (metadata is not None and - not self._options.read_only and - utils.is_primary_host(self._multiprocessing_options.primary_host)): - logging.info('Creating metadata directory') + if ( + custom_metadata is not None + and not self._options.read_only + and utils.is_primary_host(self._multiprocessing_options.primary_host) + ): self._metadata_dir.mkdir(parents=True, exist_ok=True) file_path = self._metadata_file_path() if not file_path.exists(): # May have been created by a previous run. - logging.info('Writing root metadata') - metadata_to_save = checkpoint.RootMetadata( - custom_metadata=dict(metadata), - ) + metadata_to_save = self._root_metadata + if custom_metadata is not None: + metadata_to_save.custom_metadata = dict(custom_metadata) self._blocking_metadata_store.write( file_path, serialize_root_metadata(metadata_to_save) ) @@ -1590,9 +1596,43 @@ def _maybe_save_metadata(self, metadata: Mapping[str, Any]): processes=self._multiprocessing_options.active_processes, ) - def metadata(self) -> Mapping[str, Any]: - """See superclass documentation.""" - if self._metadata is None: + def _get_step_metadata(self, step: int) -> StepMetadata: + infos = [info for info in self._checkpoints if info.step == step] + if not infos: + metrics = None + else: + if len(infos) > 1: + logging.warning( + 'Multiple CheckpointInfos found for step %d. Using the first one.', + step, + ) + metrics = infos[0].metrics + + step_metadata = self._checkpointer.metadata( + self._get_read_step_directory(step, self.directory), + ) + + if self._default_item is None: + self._default_item = _determine_default_item_mode_from_directory( + self._get_read_step_directory(step, self.directory) + ) + step_metadata.item_metadata = self._maybe_get_default_item( + step_metadata.item_metadata + ) + + if metrics is not None: + validated_metrics = step_metadata_serialization.deserialize( + {}, metrics=dict(metrics) + ).metrics + step_metadata = dataclasses.replace( + step_metadata, + metrics=validated_metrics, + ) + + return step_metadata + + def _get_root_metadata(self) -> RootMetadata: + if self._root_metadata.custom_metadata is None: if self._metadata_dir.exists(): file_path = self._metadata_file_path() if not file_path.exists(): @@ -1601,16 +1641,28 @@ def metadata(self) -> Mapping[str, Any]: self._metadata_dir) file_path = self._metadata_file_path(legacy=True) serialized_metadata = self._blocking_metadata_store.read(file_path) - self._metadata = deserialize_root_metadata( + if serialized_metadata is None: + raise IOError(f'Failed to read metadata from {file_path}') + self._root_metadata = root_metadata_serialization.deserialize( serialized_metadata - ).custom_metadata - if self._metadata is None: - raise FileNotFoundError( - f'Failed to read metadata from {file_path}.' - ) + ) else: - self._metadata = {} - return self._metadata + self._root_metadata.custom_metadata = {} + return self._root_metadata + + @overload + def metadata(self, step: None = None) -> RootMetadata: + ... + + @overload + def metadata(self, step: int) -> StepMetadata: + ... + + def metadata(self, step: int | None = None) -> RootMetadata | StepMetadata: + """See superclass documentation.""" + if step is not None: + return self._get_step_metadata(step) + return self._get_root_metadata() def _sort_checkpoints_by_metrics( self, checkpoints: List[CheckpointInfo] diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py index 697cbbd2b..7bbc1d4ed 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py @@ -67,6 +67,8 @@ get_present_and_missing_chunks = ( local_checkpoint_data_debugging.get_present_and_missing_chunks ) +RootMetadata = checkpoint_manager.RootMetadata +StepMetadata = checkpoint_manager.StepMetadata _PRIMARY_REPLICA_ID = 0 _SECONDARY_REPLICA_ID = 1 @@ -1303,7 +1305,7 @@ def item_metadata(self, step: int) -> Any: 'Item metadata not yet implemented for emergency.CheckpointManager.' ) - def metadata(self) -> dict[str, Any]: + def metadata(self, step: int | None = None) -> RootMetadata | StepMetadata: """Returns CheckpointManager level metadata if present, empty otherwise.""" raise NotImplementedError( 'Metadata not yet implemented for emergency.CheckpointManager.' diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager.py index f91d77b6e..9ddae2c02 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager.py @@ -18,7 +18,6 @@ subject to change without notice. """ -from collections.abc import Mapping import dataclasses from typing import Any, Callable, Iterable, Sequence from absl import logging @@ -42,6 +41,8 @@ handler_registration.DefaultCheckpointHandlerRegistry ) PyTreeCheckpointHandler = pytree_checkpoint_handler.PyTreeCheckpointHandler +RootMetadata = checkpoint_manager.RootMetadata +StepMetadata = checkpoint_manager.StepMetadata _UNNAMED_ITEM_NAME = 'state' @@ -319,8 +320,8 @@ def restore( def item_metadata(self, step: int) -> Any: return self._impl.item_metadata(step) - def metadata(self) -> Mapping[str, Any]: - return self._impl.metadata() + def metadata(self, step: int | None = None) -> RootMetadata | StepMetadata: + return self._impl.metadata(step) def metrics(self, step: int) -> PyTree | None: raise NotImplementedError() diff --git a/docs/guides/checkpoint/orbax_checkpoint_announcements.md b/docs/guides/checkpoint/orbax_checkpoint_announcements.md index 7b97742b1..674347a3e 100644 --- a/docs/guides/checkpoint/orbax_checkpoint_announcements.md +++ b/docs/guides/checkpoint/orbax_checkpoint_announcements.md @@ -1,5 +1,10 @@ # Announcements +## 2025-01-28 +`CheckpointManager.metadata()` now accepts a `step` parameter. If provided, it +will return `StepMetadata`, and will otherwise return `RootMetadata`. Subclasses +of `AbstractCheckpointManager` should be updated to incorporate this new kwarg. + ## 2024-12-30 orbax-checkpoint version `0.10.3` and [grain](https://pypi.org/project/grain/) version `0.2.2` are not compatible.