Skip to content

Commit

Permalink
fix: fix warning regression during import when launch with strict war…
Browse files Browse the repository at this point in the history
…ning filters (#149)
  • Loading branch information
XuehaiPan authored Jul 6, 2024
1 parent a0d1d62 commit 83e6eff
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 122 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fix warning regression during import when launch with strict warning filters by [@XuehaiPan](https://github.com/XuehaiPan) in [#149](https://github.com/metaopt/optree/pull/149).

### Removed

Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ addlicense-install: go-install

pytest: pytest-install
$(PYTHON) -m pytest --version
cd tests && $(PYTHON) -X dev -c 'import $(PROJECT_PATH)' && \
$(PYTHON) -X dev -c 'import $(PROJECT_PATH)._C; print(f"GLIBCXX_USE_CXX11_ABI={$(PROJECT_PATH)._C.GLIBCXX_USE_CXX11_ABI}")' && \
cd tests && $(PYTHON) -X dev -W 'always' -W 'error' -c 'import $(PROJECT_PATH)' && \
$(PYTHON) -X dev -W 'always' -W 'error' -c 'import $(PROJECT_PATH)._C; print(f"GLIBCXX_USE_CXX11_ABI={$(PROJECT_PATH)._C.GLIBCXX_USE_CXX11_ABI}")' && \
$(PYTHON) -X dev -m pytest --verbose --color=yes --durations=0 --showlocals \
--cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \
$(PYTESTOPTS) .
Expand Down
228 changes: 110 additions & 118 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,132 +671,124 @@ def dict_insertion_ordered(mode: bool, *, namespace: str) -> Generator[None, Non

####################################################################################################

warnings.filterwarnings('ignore', category=FutureWarning, module=__name__, append=True)


@deprecated(
'The function `_sorted_keys` is deprecated and will be removed in a future version.',
category=FutureWarning,
)
def _sorted_keys(dct: dict[KT, VT]) -> list[KT]:
return total_order_sorted(dct)


@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
class KeyPathEntry(NamedTuple):
key: Any

def __add__(self, other: object) -> KeyPath:
if isinstance(other, KeyPathEntry):
return KeyPath((self, other))
if isinstance(other, KeyPath):
return KeyPath((self, *other.keys))
return NotImplemented

def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.key == other.key

def pprint(self) -> str:
"""Pretty name of the key path entry."""
raise NotImplementedError


@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
class KeyPath(NamedTuple):
keys: tuple[KeyPathEntry, ...] = ()

def __add__(self, other: object) -> KeyPath:
if isinstance(other, KeyPathEntry):
return KeyPath((*self.keys, other))
if isinstance(other, KeyPath):
return KeyPath(self.keys + other.keys)
return NotImplemented

def __eq__(self, other: object) -> bool:
return isinstance(other, KeyPath) and self.keys == other.keys

def pprint(self) -> str:
"""Pretty name of the key path."""
if not self.keys:
return ' tree root'
return ''.join(k.pprint() for k in self.keys)


@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
class GetitemKeyPathEntry(KeyPathEntry):
"""The key path entry class for sequences and dictionaries."""

def pprint(self) -> str:
"""Pretty name of the key path entry."""
return f'[{self.key!r}]'


@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
class AttributeKeyPathEntry(KeyPathEntry):
"""The key path entry class for namedtuples."""

def pprint(self) -> str:
"""Pretty name of the key path entry."""
return f'.{self.key}'
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=FutureWarning, module=__name__, append=False)

@deprecated(
'The function `_sorted_keys` is deprecated and will be removed in a future version.',
category=FutureWarning,
)
def _sorted_keys(dct: dict[KT, VT]) -> list[KT]:
return total_order_sorted(dct)

@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
class FlattenedKeyPathEntry(KeyPathEntry): # fallback
"""The fallback key path entry class."""
@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
class KeyPathEntry(NamedTuple):
key: Any

def __add__(self, other: object) -> KeyPath:
if isinstance(other, KeyPathEntry):
return KeyPath((self, other))
if isinstance(other, KeyPath):
return KeyPath((self, *other.keys))
return NotImplemented

def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.key == other.key

def pprint(self) -> str:
"""Pretty name of the key path entry."""
raise NotImplementedError

@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
class KeyPath(NamedTuple):
keys: tuple[KeyPathEntry, ...] = ()

def __add__(self, other: object) -> KeyPath:
if isinstance(other, KeyPathEntry):
return KeyPath((*self.keys, other))
if isinstance(other, KeyPath):
return KeyPath(self.keys + other.keys)
return NotImplemented

def __eq__(self, other: object) -> bool:
return isinstance(other, KeyPath) and self.keys == other.keys

def pprint(self) -> str:
"""Pretty name of the key path."""
if not self.keys:
return ' tree root'
return ''.join(k.pprint() for k in self.keys)

@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
class GetitemKeyPathEntry(KeyPathEntry):
"""The key path entry class for sequences and dictionaries."""

def pprint(self) -> str:
"""Pretty name of the key path entry."""
return f'[<flat index {self.key}>]'
def pprint(self) -> str:
"""Pretty name of the key path entry."""
return f'[{self.key!r}]'

@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
class AttributeKeyPathEntry(KeyPathEntry):
"""The key path entry class for namedtuples."""

KeyPathHandler = Callable[[PyTree], Sequence[KeyPathEntry]]
_KEYPATH_REGISTRY: dict[type[CustomTreeNode], KeyPathHandler] = {}
def pprint(self) -> str:
"""Pretty name of the key path entry."""
return f'.{self.key}'

@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
class FlattenedKeyPathEntry(KeyPathEntry): # fallback
"""The fallback key path entry class."""

@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
def register_keypaths(
cls: type[CustomTreeNode[T]],
handler: KeyPathHandler,
) -> KeyPathHandler:
"""Register a key path handler for a custom pytree node type."""
if not inspect.isclass(cls):
raise TypeError(f'Expected a class, got {cls!r}.')
if cls in _KEYPATH_REGISTRY:
raise ValueError(f'Key path handler for {cls!r} has already been registered.')
def pprint(self) -> str:
"""Pretty name of the key path entry."""
return f'[<flat index {self.key}>]'

_KEYPATH_REGISTRY[cls] = handler
return handler
KeyPathHandler = Callable[[PyTree], Sequence[KeyPathEntry]]
_KEYPATH_REGISTRY: dict[type[CustomTreeNode], KeyPathHandler] = {}

@deprecated(
'The key path API is deprecated and will be removed in a future version. '
'Please use the accessor API instead.',
category=FutureWarning,
)
def register_keypaths(
cls: type[CustomTreeNode[T]],
handler: KeyPathHandler,
) -> KeyPathHandler:
"""Register a key path handler for a custom pytree node type."""
if not inspect.isclass(cls):
raise TypeError(f'Expected a class, got {cls!r}.')
if cls in _KEYPATH_REGISTRY:
raise ValueError(f'Key path handler for {cls!r} has already been registered.')

_KEYPATH_REGISTRY[cls] = handler
return handler

register_keypaths(tuple, lambda tup: list(map(GetitemKeyPathEntry, range(len(tup))))) # type: ignore[arg-type]
register_keypaths(list, lambda lst: list(map(GetitemKeyPathEntry, range(len(lst))))) # type: ignore[arg-type]
register_keypaths(dict, lambda dct: list(map(GetitemKeyPathEntry, _sorted_keys(dct)))) # type: ignore[arg-type]
register_keypaths(OrderedDict, lambda odct: list(map(GetitemKeyPathEntry, odct))) # type: ignore[arg-type,call-overload]
register_keypaths(defaultdict, lambda ddct: list(map(GetitemKeyPathEntry, _sorted_keys(ddct)))) # type: ignore[arg-type]
register_keypaths(deque, lambda dq: list(map(GetitemKeyPathEntry, range(len(dq))))) # type: ignore[arg-type]
register_keypaths(tuple, lambda tup: list(map(GetitemKeyPathEntry, range(len(tup))))) # type: ignore[arg-type]
register_keypaths(list, lambda lst: list(map(GetitemKeyPathEntry, range(len(lst))))) # type: ignore[arg-type]
register_keypaths(dict, lambda dct: list(map(GetitemKeyPathEntry, _sorted_keys(dct)))) # type: ignore[arg-type]
register_keypaths(OrderedDict, lambda odct: list(map(GetitemKeyPathEntry, odct))) # type: ignore[arg-type,call-overload]
register_keypaths(defaultdict, lambda ddct: list(map(GetitemKeyPathEntry, _sorted_keys(ddct)))) # type: ignore[arg-type]
register_keypaths(deque, lambda dq: list(map(GetitemKeyPathEntry, range(len(dq))))) # type: ignore[arg-type]

register_keypaths.get = _KEYPATH_REGISTRY.get # type: ignore[attr-defined]
register_keypaths.get = _KEYPATH_REGISTRY.get # type: ignore[attr-defined]
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ build-frontend = "build"
build-verbosity = 3
container-engine = "docker"
test-extras = ["test"]
test-command = '''make -C "{project}" test PYTHON=python PYTESTOPTS="--quiet --no-showlocals"'''
test-command = '''make -C "{project}" test PYTHON=python PYTESTOPTS="--quiet --exitfirst --no-showlocals"'''

# Linter tools #################################################################

Expand Down Expand Up @@ -263,6 +263,7 @@ ban-relative-imports = "all"
[tool.pytest.ini_options]
filterwarnings = [
"error",
"always",
"ignore:The class `optree.Partial` is deprecated and will be removed in a future version. Please use `optree.functools.partial` instead.:FutureWarning",
"ignore:The key path API is deprecated and will be removed in a future version. Please use the accessor API instead.:FutureWarning",
]
21 changes: 21 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import operator
import pickle
import re
import subprocess
import sys
from collections import OrderedDict, defaultdict, deque

import pytest
Expand All @@ -46,6 +48,25 @@
)


def test_import_no_warnings():
assert (
subprocess.check_output(
[
sys.executable,
'-W',
'always',
'-W',
'error',
'-c',
'import optree',
],
stderr=subprocess.STDOUT,
text=True,
)
== ''
)


def test_max_depth():
lst = [1]
for _ in range(optree.MAX_RECURSION_DEPTH - 1):
Expand Down

0 comments on commit 83e6eff

Please sign in to comment.