Skip to content

Commit

Permalink
internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700087886
  • Loading branch information
niketkumar authored and Orbax Authors committed Nov 27, 2024
1 parent 8a34b74 commit a3472af
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.tree import utils as tree_utils


Expand Down Expand Up @@ -68,6 +69,9 @@ def __init__(
save_concurrent_gb: int = 96,
restore_concurrent_gb: int = 96,
multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(),
pytree_metadata_options: pytree_metadata_options_lib.PyTreeMetadataOptions = (
pytree_metadata_options_lib.PYTREE_METADATA_OPTIONS
),
):
"""Creates StandardCheckpointHandler.
Expand All @@ -79,12 +83,15 @@ def __init__(
Can help to reduce the possibility of OOM's when large checkpoints are
restored.
multiprocessing_options: See orbax.checkpoint.options.
pytree_metadata_options: Options to control types like tuple and
namedtuple in pytree metadata.
"""
self._supported_types = checkpoint_utils.STANDARD_ARRAY_TYPES
self._impl = pytree_checkpoint_handler.PyTreeCheckpointHandler(
save_concurrent_gb=save_concurrent_gb,
restore_concurrent_gb=restore_concurrent_gb,
multiprocessing_options=multiprocessing_options,
pytree_metadata_options=pytree_metadata_options,
)

def _validate_save_state(
Expand Down
13 changes: 7 additions & 6 deletions checkpoint/orbax/checkpoint/_src/testing/test_tree_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
PyTree = Any


def _build_namedtuple(
def create_namedtuple(
cls,
field_value_tuples: list[tuple[str, tree_metadata.ValueMetadataEntry]],
):
) -> type[tuple[Any, ...]]:
"""Returns instance of a new namedtuple type structurally identical to `cls`."""
fields, values = zip(*field_value_tuples)
module_name, class_name = tree_rich_types._module_and_class_name(cls) # pylint: disable=protected-access
new_type = tree_rich_types._new_namedtuple_type(module_name, class_name, fields) # pylint: disable=protected-access
Expand Down Expand Up @@ -491,7 +492,7 @@ def __repr__(self):
}
},
expected_nested_tree_metadata_with_rich_types={
'mu_nu': _build_namedtuple(
'mu_nu': create_namedtuple(
MuNu,
[
(
Expand Down Expand Up @@ -789,7 +790,7 @@ def __repr__(self):
}
},
expected_nested_tree_metadata_with_rich_types={
'default_named_tuple_with_nested_attrs': _build_namedtuple(
'default_named_tuple_with_nested_attrs': create_namedtuple(
NamedTupleWithNestedAttributes,
[
(
Expand Down Expand Up @@ -885,12 +886,12 @@ def __repr__(self):
}
},
expected_nested_tree_metadata_with_rich_types={
'named_tuple_with_nested_attrs': _build_namedtuple(
'named_tuple_with_nested_attrs': create_namedtuple(
NamedTupleWithNestedAttributes,
[
(
'nested_mu_nu',
_build_namedtuple(
create_namedtuple(
MuNu,
[
(
Expand Down

0 comments on commit a3472af

Please sign in to comment.