diff --git a/CHANGELOG.md b/CHANGELOG.md index 7aa6c159..7c6d0cfd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fix warning regression during import when launch with strict warning filters by [@XuehaiPan](https://github.com/XuehaiPan) in [#149](https://github.com/metaopt/optree/pull/149). ### Removed diff --git a/Makefile b/Makefile index fd909c36..cffbe4f7 100644 --- a/Makefile +++ b/Makefile @@ -114,8 +114,8 @@ addlicense-install: go-install pytest: pytest-install $(PYTHON) -m pytest --version - cd tests && $(PYTHON) -X dev -c 'import $(PROJECT_PATH)' && \ - $(PYTHON) -X dev -c 'import $(PROJECT_PATH)._C; print(f"GLIBCXX_USE_CXX11_ABI={$(PROJECT_PATH)._C.GLIBCXX_USE_CXX11_ABI}")' && \ + cd tests && $(PYTHON) -X dev -W 'always' -W 'error' -c 'import $(PROJECT_PATH)' && \ + $(PYTHON) -X dev -W 'always' -W 'error' -c 'import $(PROJECT_PATH)._C; print(f"GLIBCXX_USE_CXX11_ABI={$(PROJECT_PATH)._C.GLIBCXX_USE_CXX11_ABI}")' && \ $(PYTHON) -X dev -m pytest --verbose --color=yes --durations=0 --showlocals \ --cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ $(PYTESTOPTS) . diff --git a/optree/registry.py b/optree/registry.py index ab7a38ea..1e0d9825 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -671,132 +671,124 @@ def dict_insertion_ordered(mode: bool, *, namespace: str) -> Generator[None, Non #################################################################################################### -warnings.filterwarnings('ignore', category=FutureWarning, module=__name__, append=True) - - -@deprecated( - 'The function `_sorted_keys` is deprecated and will be removed in a future version.', - category=FutureWarning, -) -def _sorted_keys(dct: dict[KT, VT]) -> list[KT]: - return total_order_sorted(dct) - - -@deprecated( - 'The key path API is deprecated and will be removed in a future version. ' - 'Please use the accessor API instead.', - category=FutureWarning, -) -class KeyPathEntry(NamedTuple): - key: Any - - def __add__(self, other: object) -> KeyPath: - if isinstance(other, KeyPathEntry): - return KeyPath((self, other)) - if isinstance(other, KeyPath): - return KeyPath((self, *other.keys)) - return NotImplemented - - def __eq__(self, other: object) -> bool: - return isinstance(other, self.__class__) and self.key == other.key - - def pprint(self) -> str: - """Pretty name of the key path entry.""" - raise NotImplementedError - - -@deprecated( - 'The key path API is deprecated and will be removed in a future version. ' - 'Please use the accessor API instead.', - category=FutureWarning, -) -class KeyPath(NamedTuple): - keys: tuple[KeyPathEntry, ...] = () - - def __add__(self, other: object) -> KeyPath: - if isinstance(other, KeyPathEntry): - return KeyPath((*self.keys, other)) - if isinstance(other, KeyPath): - return KeyPath(self.keys + other.keys) - return NotImplemented - - def __eq__(self, other: object) -> bool: - return isinstance(other, KeyPath) and self.keys == other.keys - - def pprint(self) -> str: - """Pretty name of the key path.""" - if not self.keys: - return ' tree root' - return ''.join(k.pprint() for k in self.keys) - - -@deprecated( - 'The key path API is deprecated and will be removed in a future version. ' - 'Please use the accessor API instead.', - category=FutureWarning, -) -class GetitemKeyPathEntry(KeyPathEntry): - """The key path entry class for sequences and dictionaries.""" - - def pprint(self) -> str: - """Pretty name of the key path entry.""" - return f'[{self.key!r}]' - - -@deprecated( - 'The key path API is deprecated and will be removed in a future version. ' - 'Please use the accessor API instead.', - category=FutureWarning, -) -class AttributeKeyPathEntry(KeyPathEntry): - """The key path entry class for namedtuples.""" - - def pprint(self) -> str: - """Pretty name of the key path entry.""" - return f'.{self.key}' +with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=FutureWarning, module=__name__, append=False) + @deprecated( + 'The function `_sorted_keys` is deprecated and will be removed in a future version.', + category=FutureWarning, + ) + def _sorted_keys(dct: dict[KT, VT]) -> list[KT]: + return total_order_sorted(dct) -@deprecated( - 'The key path API is deprecated and will be removed in a future version. ' - 'Please use the accessor API instead.', - category=FutureWarning, -) -class FlattenedKeyPathEntry(KeyPathEntry): # fallback - """The fallback key path entry class.""" + @deprecated( + 'The key path API is deprecated and will be removed in a future version. ' + 'Please use the accessor API instead.', + category=FutureWarning, + ) + class KeyPathEntry(NamedTuple): + key: Any + + def __add__(self, other: object) -> KeyPath: + if isinstance(other, KeyPathEntry): + return KeyPath((self, other)) + if isinstance(other, KeyPath): + return KeyPath((self, *other.keys)) + return NotImplemented + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) and self.key == other.key + + def pprint(self) -> str: + """Pretty name of the key path entry.""" + raise NotImplementedError + + @deprecated( + 'The key path API is deprecated and will be removed in a future version. ' + 'Please use the accessor API instead.', + category=FutureWarning, + ) + class KeyPath(NamedTuple): + keys: tuple[KeyPathEntry, ...] = () + + def __add__(self, other: object) -> KeyPath: + if isinstance(other, KeyPathEntry): + return KeyPath((*self.keys, other)) + if isinstance(other, KeyPath): + return KeyPath(self.keys + other.keys) + return NotImplemented + + def __eq__(self, other: object) -> bool: + return isinstance(other, KeyPath) and self.keys == other.keys + + def pprint(self) -> str: + """Pretty name of the key path.""" + if not self.keys: + return ' tree root' + return ''.join(k.pprint() for k in self.keys) + + @deprecated( + 'The key path API is deprecated and will be removed in a future version. ' + 'Please use the accessor API instead.', + category=FutureWarning, + ) + class GetitemKeyPathEntry(KeyPathEntry): + """The key path entry class for sequences and dictionaries.""" - def pprint(self) -> str: - """Pretty name of the key path entry.""" - return f'[]' + def pprint(self) -> str: + """Pretty name of the key path entry.""" + return f'[{self.key!r}]' + @deprecated( + 'The key path API is deprecated and will be removed in a future version. ' + 'Please use the accessor API instead.', + category=FutureWarning, + ) + class AttributeKeyPathEntry(KeyPathEntry): + """The key path entry class for namedtuples.""" -KeyPathHandler = Callable[[PyTree], Sequence[KeyPathEntry]] -_KEYPATH_REGISTRY: dict[type[CustomTreeNode], KeyPathHandler] = {} + def pprint(self) -> str: + """Pretty name of the key path entry.""" + return f'.{self.key}' + @deprecated( + 'The key path API is deprecated and will be removed in a future version. ' + 'Please use the accessor API instead.', + category=FutureWarning, + ) + class FlattenedKeyPathEntry(KeyPathEntry): # fallback + """The fallback key path entry class.""" -@deprecated( - 'The key path API is deprecated and will be removed in a future version. ' - 'Please use the accessor API instead.', - category=FutureWarning, -) -def register_keypaths( - cls: type[CustomTreeNode[T]], - handler: KeyPathHandler, -) -> KeyPathHandler: - """Register a key path handler for a custom pytree node type.""" - if not inspect.isclass(cls): - raise TypeError(f'Expected a class, got {cls!r}.') - if cls in _KEYPATH_REGISTRY: - raise ValueError(f'Key path handler for {cls!r} has already been registered.') + def pprint(self) -> str: + """Pretty name of the key path entry.""" + return f'[]' - _KEYPATH_REGISTRY[cls] = handler - return handler + KeyPathHandler = Callable[[PyTree], Sequence[KeyPathEntry]] + _KEYPATH_REGISTRY: dict[type[CustomTreeNode], KeyPathHandler] = {} + @deprecated( + 'The key path API is deprecated and will be removed in a future version. ' + 'Please use the accessor API instead.', + category=FutureWarning, + ) + def register_keypaths( + cls: type[CustomTreeNode[T]], + handler: KeyPathHandler, + ) -> KeyPathHandler: + """Register a key path handler for a custom pytree node type.""" + if not inspect.isclass(cls): + raise TypeError(f'Expected a class, got {cls!r}.') + if cls in _KEYPATH_REGISTRY: + raise ValueError(f'Key path handler for {cls!r} has already been registered.') + + _KEYPATH_REGISTRY[cls] = handler + return handler -register_keypaths(tuple, lambda tup: list(map(GetitemKeyPathEntry, range(len(tup))))) # type: ignore[arg-type] -register_keypaths(list, lambda lst: list(map(GetitemKeyPathEntry, range(len(lst))))) # type: ignore[arg-type] -register_keypaths(dict, lambda dct: list(map(GetitemKeyPathEntry, _sorted_keys(dct)))) # type: ignore[arg-type] -register_keypaths(OrderedDict, lambda odct: list(map(GetitemKeyPathEntry, odct))) # type: ignore[arg-type,call-overload] -register_keypaths(defaultdict, lambda ddct: list(map(GetitemKeyPathEntry, _sorted_keys(ddct)))) # type: ignore[arg-type] -register_keypaths(deque, lambda dq: list(map(GetitemKeyPathEntry, range(len(dq))))) # type: ignore[arg-type] + register_keypaths(tuple, lambda tup: list(map(GetitemKeyPathEntry, range(len(tup))))) # type: ignore[arg-type] + register_keypaths(list, lambda lst: list(map(GetitemKeyPathEntry, range(len(lst))))) # type: ignore[arg-type] + register_keypaths(dict, lambda dct: list(map(GetitemKeyPathEntry, _sorted_keys(dct)))) # type: ignore[arg-type] + register_keypaths(OrderedDict, lambda odct: list(map(GetitemKeyPathEntry, odct))) # type: ignore[arg-type,call-overload] + register_keypaths(defaultdict, lambda ddct: list(map(GetitemKeyPathEntry, _sorted_keys(ddct)))) # type: ignore[arg-type] + register_keypaths(deque, lambda dq: list(map(GetitemKeyPathEntry, range(len(dq))))) # type: ignore[arg-type] -register_keypaths.get = _KEYPATH_REGISTRY.get # type: ignore[attr-defined] + register_keypaths.get = _KEYPATH_REGISTRY.get # type: ignore[attr-defined] diff --git a/pyproject.toml b/pyproject.toml index d0253f5d..a19c911c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,7 +116,7 @@ build-frontend = "build" build-verbosity = 3 container-engine = "docker" test-extras = ["test"] -test-command = '''make -C "{project}" test PYTHON=python PYTESTOPTS="--quiet --no-showlocals"''' +test-command = '''make -C "{project}" test PYTHON=python PYTESTOPTS="--quiet --exitfirst --no-showlocals"''' # Linter tools ################################################################# @@ -263,6 +263,7 @@ ban-relative-imports = "all" [tool.pytest.ini_options] filterwarnings = [ "error", + "always", "ignore:The class `optree.Partial` is deprecated and will be removed in a future version. Please use `optree.functools.partial` instead.:FutureWarning", "ignore:The key path API is deprecated and will be removed in a future version. Please use the accessor API instead.:FutureWarning", ] diff --git a/tests/test_ops.py b/tests/test_ops.py index 421b3eac..fc66aec0 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -21,6 +21,8 @@ import operator import pickle import re +import subprocess +import sys from collections import OrderedDict, defaultdict, deque import pytest @@ -46,6 +48,25 @@ ) +def test_import_no_warnings(): + assert ( + subprocess.check_output( + [ + sys.executable, + '-W', + 'always', + '-W', + 'error', + '-c', + 'import optree', + ], + stderr=subprocess.STDOUT, + text=True, + ) + == '' + ) + + def test_max_depth(): lst = [1] for _ in range(optree.MAX_RECURSION_DEPTH - 1):