Skip to content

Commit

Permalink
Fix namedtuple empty value typestr when experimental support_rich_typ…
Browse files Browse the repository at this point in the history
…es is disabled again after enabling it.

PiperOrigin-RevId: 700012734
  • Loading branch information
niketkumar authored and Orbax Authors committed Nov 25, 2024
1 parent e4a1239 commit 3a32a23
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 22 deletions.
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- [Experimental Feature] Support `empty NamedTuple` leaf when
`PyTreeMetadataOptions.support_rich_types=true`.

### Fixed
- Fix namedtuple empty value typestr when experimental support_rich_types is
disabled again after enabling it.


## [0.10.1] - 2024-11-22

Expand Down
10 changes: 10 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/empty_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def get_empty_value_typestr(
)


def override_empty_value_typestr(
typestr: str, pytree_metadata_options: PyTreeMetadataOptions
) -> str:
"""Returns updated typestr based on pytree_metadata_options."""
if not pytree_metadata_options.support_rich_types:
if typestr == RESTORE_TYPE_NAMED_TUPLE:
return RESTORE_TYPE_NONE
return typestr


def is_empty_typestr(typestr: str) -> bool:
return (
typestr == RESTORE_TYPE_LIST
Expand Down
7 changes: 6 additions & 1 deletion checkpoint/orbax/checkpoint/_src/metadata/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,12 @@ def from_json(
return InternalTreeMetadataEntry(
keypath,
KeyMetadataEntry.from_json(json_dict[_KEY_METADATA_KEY]),
ValueMetadataEntry.from_json(json_dict[_VALUE_METADATA_KEY]),
ValueMetadataEntry.from_json(
json_dict[_VALUE_METADATA_KEY],
pytree_metadata_options=PyTreeMetadataOptions(
support_rich_types=False # Always in legacy mode.
),
),
)

@classmethod
Expand Down
8 changes: 7 additions & 1 deletion checkpoint/orbax/checkpoint/_src/metadata/tree_rich_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import functools
from typing import Any, Iterable, Mapping, Sequence, Type, TypeAlias

from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.metadata import value_metadata_entry
from orbax.checkpoint._src.tree import utils as tree_utils
import simplejson
Expand Down Expand Up @@ -134,7 +135,12 @@ def _value_metadata_tree_for_json_loads(obj):
if 'category' in obj:
if obj['category'] == 'custom':
if obj['clazz'] == _VALUE_METADATA_ENTRY_CLAZZ:
return value_metadata_entry.ValueMetadataEntry.from_json(obj['data'])
return value_metadata_entry.ValueMetadataEntry.from_json(
obj['data'],
pytree_metadata_options_lib.PyTreeMetadataOptions(
support_rich_types=True, # Always in rich types mode.
),
)
if obj['clazz'] == 'tuple':
return tuple(
[(_value_metadata_tree_for_json_loads(v)) for v in obj['entries']]
Expand Down
19 changes: 3 additions & 16 deletions checkpoint/orbax/checkpoint/_src/metadata/tree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.serialization import type_handlers
Expand Down Expand Up @@ -75,14 +76,7 @@ def test_as_nested_tree(
else:
expected_tree_metadata = test_pytree.expected_nested_tree_metadata
restored_tree_metadata = restored_internal_tree_metadata.as_nested_tree()
self.assertEqual(
jax.tree.structure(
expected_tree_metadata, is_leaf=tree_utils.is_empty_or_leaf
),
jax.tree.structure(
restored_tree_metadata, is_leaf=tree_utils.is_empty_or_leaf
),
)
chex.assert_trees_all_equal(restored_tree_metadata, expected_tree_metadata)

@parameterized.product(
test_pytree=test_tree_utils.TEST_PYTREES,
Expand Down Expand Up @@ -132,14 +126,7 @@ def test_switching_between_support_rich_types(
)

restored_tree_metadata = restored_internal_tree_metadata.as_nested_tree()
self.assertEqual(
jax.tree.structure(
expected_tree_metadata, is_leaf=tree_utils.is_empty_or_leaf
),
jax.tree.structure(
restored_tree_metadata, is_leaf=tree_utils.is_empty_or_leaf
),
)
chex.assert_trees_all_equal(restored_tree_metadata, expected_tree_metadata)


if __name__ == '__main__':
Expand Down
12 changes: 10 additions & 2 deletions checkpoint/orbax/checkpoint/_src/metadata/value_metadata_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Any, Dict

from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.serialization import types


Expand Down Expand Up @@ -48,9 +49,16 @@ def to_json(self) -> Dict[str, Any]:
}

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> ValueMetadataEntry:
def from_json(
cls,
json_dict: Dict[str, Any],
pytree_metadata_options: pytree_metadata_options_lib.PyTreeMetadataOptions,
) -> ValueMetadataEntry:
return ValueMetadataEntry(
value_type=json_dict[_VALUE_TYPE],
value_type=empty_values.override_empty_value_typestr(
json_dict[_VALUE_TYPE],
pytree_metadata_options,
),
skip_deserialize=json_dict[_SKIP_DESERIALIZE],
)

Expand Down
4 changes: 2 additions & 2 deletions checkpoint/orbax/checkpoint/_src/testing/test_tree_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def __repr__(self):
skip_deserialize=False,
),
'b': tree_metadata.ValueMetadataEntry(
value_type='None',
value_type='Dict',
skip_deserialize=True,
),
},
Expand Down Expand Up @@ -395,7 +395,7 @@ def __repr__(self):
skip_deserialize=False,
),
'b': tree_metadata.ValueMetadataEntry(
value_type='None',
value_type='Dict',
skip_deserialize=True,
),
},
Expand Down

0 comments on commit 3a32a23

Please sign in to comment.