Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use RootMetadata and StepMetadata in CheckpointManager. #1511

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
packages, namely `orbax-checkpoint` and `orbax-export`. Imports are unchanged,
and still of the form `import orbax.checkpoint` or `import orbax.export`.
- Finer scoped jax.monitoring calls on the save path.
- `CheckpointManager.metadata()` now accepts a `step` parameter. If provided, it will return `StepMetadata`, and will otherwise return `RootMetadata`.
- `CompositeCheckpointHandler.metadata()` now returns `StepMetadata`.

## [0.1.7] - 2022-03-29

Expand Down
19 changes: 17 additions & 2 deletions checkpoint/orbax/checkpoint/abstract_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from etils import epath
from orbax.checkpoint import args as args_lib
from orbax.checkpoint._src.metadata import checkpoint

PyTree = Any
SaveParams = Mapping[str, Any]
Expand Down Expand Up @@ -290,8 +291,22 @@ def item_metadata(
"""

@abc.abstractmethod
def metadata(self) -> Mapping[str, Any]:
"""Returns CheckpointManager level metadata if present, empty otherwise."""
def metadata(
self, step: int | None = None,
) -> checkpoint.StepMetadata | checkpoint.RootMetadata:
"""Returns `StepMetadata` for the specified step, or `RootMetadata` all.

If step is specified, only return `StepMetadata` for that step.
Otherwise, return `RootMetadata`.

Args:
step: Step for which to retrieve `StepMetadata`. If None, returns
`RootMetadata`.

Returns:
Metadata for the specified step (`StepMetadata`), or all steps
(`RootMetadata`).
"""

@abc.abstractmethod
def metrics(self, step: int) -> Optional[PyTree]:
Expand Down
97 changes: 67 additions & 30 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import threading
import time
import typing
from typing import Any, Callable, Container, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Container, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union, overload

from absl import logging
from etils import epath
Expand All @@ -44,6 +44,7 @@
from orbax.checkpoint._src.handlers import proto_checkpoint_handler
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.metadata import root_metadata_serialization
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import deleter
Expand All @@ -64,6 +65,9 @@
AbstractCheckpointManager = (
abstract_checkpoint_manager.AbstractCheckpointManager
)
StepMetadata = checkpoint.StepMetadata
RootMetadata = checkpoint.RootMetadata
ItemMetadata = checkpoint.CompositeItemMetadata | checkpoint.SingleItemMetadata
AsyncCheckpointer = async_checkpointer.AsyncCheckpointer
Checkpointer = checkpointer_lib.Checkpointer
JsonCheckpointHandler = json_checkpoint_handler.JsonCheckpointHandler
Expand Down Expand Up @@ -709,11 +713,14 @@ def __init__(

self._metadata_dir = self.directory / METADATA_ITEM_NAME
if self._options.read_only and not self._metadata_dir.exists():
self._metadata = {} if metadata is None else metadata
custom_metadata = {} if metadata is None else dict(metadata)
else:
self._metadata = None
custom_metadata = None
self._root_metadata = RootMetadata(
custom=custom_metadata,
)

self._maybe_save_metadata(metadata)
self._maybe_save_root_metadata(metadata)

# TODO: b/359854428 - Move Finalize biz logic to a separate class/module.
self._finalize_thread_lock = threading.Lock()
Expand Down Expand Up @@ -1298,7 +1305,9 @@ def save(
self._logger.log_entry(dataclasses.asdict(step_stats))
return True

def _maybe_get_default_item(self, composite_result: args_lib.Composite):
def _maybe_get_default_item(
self, composite_result: args_lib.Composite
) -> Union[Any, args_lib.Composite]:
if self._default_item:
if DEFAULT_ITEM_NAME not in composite_result:
raise ValueError(
Expand Down Expand Up @@ -1379,7 +1388,9 @@ def restore(

return self._maybe_get_default_item(restored)

def item_metadata(self, step: int) -> Union[Any, args_lib.Composite]:
def item_metadata(
self, step: int
) -> Union[Any, args_lib.Composite, ItemMetadata]:
"""Retrieves metadata for all known items.

Note that metadata will only be returned for items that can actually be
Expand All @@ -1394,18 +1405,14 @@ def item_metadata(self, step: int) -> Union[Any, args_lib.Composite]:
Either metadata for the item itself, if in default-item mode, or a
Composite of metadata for each item.
"""
assert isinstance(self._checkpointer.handler, CompositeCheckpointHandler)
read_step_directory = self._get_read_step_directory(step, self.directory)

result = self._checkpointer.metadata(read_step_directory)
if isinstance(result, checkpoint.StepMetadata):
result = result.item_metadata
if self._default_item is None:
self._default_item = _determine_default_item_mode_from_directory(
read_step_directory
)
return self._maybe_get_default_item(result)
return self._maybe_get_default_item(self.metadata(step).item_metadata)

# TODO(b/370812224): Deprecate in favor of StepMetadata.metrics
def metrics(self, step: int) -> Optional[PyTree]:
if self._track_best:
try:
Expand Down Expand Up @@ -1504,21 +1511,18 @@ def _metadata_file_path(self, legacy: bool = False) -> epath.Path:
self._metadata_dir, legacy=legacy
)

def _maybe_save_metadata(self, metadata: Mapping[str, Any]):
def _maybe_save_root_metadata(self, custom_metadata: Mapping[str, Any]):
"""Saves CheckpointManager level metadata, skips if already present."""
if self._options.save_root_metadata:
logging.info('Saving root metadata')
if (metadata is not None and
if (custom_metadata is not None and
not self._options.read_only and
utils.is_primary_host(self._multiprocessing_options.primary_host)):
logging.info('Creating metadata directory')
self._metadata_dir.mkdir(parents=True, exist_ok=True)
file_path = self._metadata_file_path()
if not file_path.exists(): # May have been created by a previous run.
logging.info('Writing root metadata')
metadata_to_save = checkpoint.RootMetadata(
custom=dict(metadata),
)
metadata_to_save = self._root_metadata
if custom_metadata is not None:
metadata_to_save.custom = dict(custom_metadata)
self._blocking_metadata_store.write(
file_path, serialize_root_metadata(metadata_to_save)
)
Expand All @@ -1531,9 +1535,28 @@ def _maybe_save_metadata(self, metadata: Mapping[str, Any]):
processes=self._multiprocessing_options.active_processes,
)

def metadata(self) -> Mapping[str, Any]:
"""See superclass documentation."""
if self._metadata is None:
def _get_step_metadata(self, step: int) -> StepMetadata:
infos = [info for info in self._checkpoints if info.step == step]
if not infos or len(infos) > 1:
metrics = None
else:
metrics = infos[0].metrics

step_metadata = self._checkpointer.metadata(
self._get_read_step_directory(step, self.directory),
)
if metrics is not None:
validated_metrics = step_metadata_serialization.deserialize(
{}, metrics=dict(metrics)
).metrics
step_metadata = dataclasses.replace(
step_metadata,
metrics=validated_metrics,
)
return step_metadata

def _get_root_metadata(self) -> RootMetadata:
if self._root_metadata.custom is None:
if self._metadata_dir.exists():
file_path = self._metadata_file_path()
if not file_path.exists():
Expand All @@ -1542,14 +1565,28 @@ def metadata(self) -> Mapping[str, Any]:
self._metadata_dir)
file_path = self._metadata_file_path(legacy=True)
serialized_metadata = self._blocking_metadata_store.read(file_path)
self._metadata = deserialize_root_metadata(serialized_metadata).custom
if self._metadata is None:
raise FileNotFoundError(
f'Failed to read metadata from {file_path}.'
)
if serialized_metadata is None:
raise IOError(f'Failed to read metadata from {file_path}')
self._root_metadata = root_metadata_serialization.deserialize(
serialized_metadata
)
else:
self._metadata = {}
return self._metadata
self._root_metadata.custom = {}
return self._root_metadata

@overload
def metadata(self, step: None = None) -> RootMetadata:
...

@overload
def metadata(self, step: int) -> StepMetadata:
...

def metadata(self, step: int | None = None) -> RootMetadata | StepMetadata:
"""See superclass documentation."""
if step is not None:
return self._get_step_metadata(step)
return self._get_root_metadata()

def _sort_checkpoints_by_metrics(
self, checkpoints: List[CheckpointInfo]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
get_present_and_missing_chunks = (
local_checkpoint_data_debugging.get_present_and_missing_chunks
)
RootMetadata = checkpoint_manager.RootMetadata
StepMetadata = checkpoint_manager.StepMetadata

_PRIMARY_REPLICA_ID = 0
_SECONDARY_REPLICA_ID = 1
Expand Down Expand Up @@ -1303,7 +1305,7 @@ def item_metadata(self, step: int) -> Any:
'Item metadata not yet implemented for emergency.CheckpointManager.'
)

def metadata(self) -> dict[str, Any]:
def metadata(self, step: int | None = None) -> RootMetadata | StepMetadata:
"""Returns CheckpointManager level metadata if present, empty otherwise."""
raise NotImplementedError(
'Metadata not yet implemented for emergency.CheckpointManager.'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
subject to change without notice.
"""

from collections.abc import Mapping
import dataclasses
from typing import Any, Callable, Iterable, Sequence
from absl import logging
Expand All @@ -42,6 +41,8 @@
handler_registration.DefaultCheckpointHandlerRegistry
)
PyTreeCheckpointHandler = pytree_checkpoint_handler.PyTreeCheckpointHandler
RootMetadata = checkpoint_manager.RootMetadata
StepMetadata = checkpoint_manager.StepMetadata


_UNNAMED_ITEM_NAME = 'state'
Expand Down Expand Up @@ -319,8 +320,8 @@ def restore(
def item_metadata(self, step: int) -> Any:
return self._impl.item_metadata(step)

def metadata(self) -> Mapping[str, Any]:
return self._impl.metadata()
def metadata(self, step: int | None = None) -> RootMetadata | StepMetadata:
return self._impl.metadata(step)

def metrics(self, step: int) -> PyTree | None:
raise NotImplementedError()
Expand Down
Loading