diff --git a/changelog/1128.feature.rst b/changelog/1128.feature.rst new file mode 100644 index 0000000000..66c35b1935 --- /dev/null +++ b/changelog/1128.feature.rst @@ -0,0 +1 @@ +|commands| Support Python 3.12's ``type`` statement and :class:`py:typing.TypeAliasType` annotations in command signatures. diff --git a/disnake/utils.py b/disnake/utils.py index 6fa6ae82d7..9061cd0f61 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1122,6 +1122,24 @@ def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: return tuple(p for p in parameters if p is not none_cls) + (none_cls,) +def _resolve_typealiastype( + tp: Any, globals: Dict[str, Any], locals: Dict[str, Any], cache: Dict[str, Any] +): + # Use __module__ to get the (global) namespace in which the type alias was defined. + if mod := sys.modules.get(tp.__module__): + mod_globals = mod.__dict__ + if mod_globals is not globals or mod_globals is not locals: + # if the namespace changed (usually when a TypeAliasType was imported from a different module), + # drop the cache since names can resolve differently now + cache = {} + globals = locals = mod_globals + + # Accessing `__value__` automatically evaluates the type alias in the annotation scope. + # (recurse to resolve possible forwardrefs, aliases, etc.) + return evaluate_annotation(tp.__value__, globals, locals, cache) + + +# FIXME: this should be split up into smaller functions for clarity and easier maintenance def evaluate_annotation( tp: Any, globals: Dict[str, Any], @@ -1147,23 +1165,31 @@ def evaluate_annotation( cache[tp] = evaluated return evaluated + # GenericAlias / UnionType if hasattr(tp, "__args__"): - implicit_str = True - is_literal = False - orig_args = args = tp.__args__ if not hasattr(tp, "__origin__"): if tp.__class__ is UnionType: - converted = Union[args] # type: ignore + converted = Union[tp.__args__] # type: ignore return evaluate_annotation(converted, globals, locals, cache) return tp - if tp.__origin__ is Union: + + implicit_str = True + is_literal = False + orig_args = args = tp.__args__ + orig_origin = origin = tp.__origin__ + + # origin can be a TypeAliasType too, resolve it and continue + if hasattr(origin, "__value__"): + origin = _resolve_typealiastype(origin, globals, locals, cache) + + if origin is Union: try: if args.index(type(None)) != len(args) - 1: args = normalise_optional_params(tp.__args__) except ValueError: pass - if tp.__origin__ is Literal: + if origin is Literal: if not PY_310: args = flatten_literal_params(tp.__args__) implicit_str = False @@ -1179,13 +1205,21 @@ def evaluate_annotation( ): raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") + if origin != orig_origin: + # we can't use `copy_with` in this case, so just skip all of the following logic + return origin[evaluated_args] + if evaluated_args == orig_args: return tp try: return tp.copy_with(evaluated_args) except AttributeError: - return tp.__origin__[evaluated_args] + return origin[evaluated_args] + + # TypeAliasType, 3.12+ + if hasattr(tp, "__value__"): + return _resolve_typealiastype(tp, globals, locals, cache) return tp diff --git a/tests/test_utils.py b/tests/test_utils.py index d767264a95..46237c2019 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,7 +9,7 @@ import warnings from dataclasses import dataclass from datetime import timedelta, timezone -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union from unittest import mock import pytest @@ -18,7 +18,13 @@ import disnake from disnake import utils -from . import helpers +from . import helpers, utils_helper_module + +if TYPE_CHECKING: + from typing_extensions import TypeAliasType +elif sys.version_info >= (3, 12): + # non-3.12 tests shouldn't be using this + from typing import TypeAliasType def test_missing() -> None: @@ -785,6 +791,65 @@ def test_resolve_annotation_literal() -> None: utils.resolve_annotation(Literal[timezone.utc, 3], globals(), locals(), {}) # type: ignore +@pytest.mark.skipif(sys.version_info < (3, 12), reason="syntax requires py3.12") +class TestResolveAnnotationTypeAliasType: + def test_simple(self) -> None: + # this is equivalent to `type CoolList = List[int]` + CoolList = TypeAliasType("CoolList", List[int]) + assert utils.resolve_annotation(CoolList, globals(), locals(), {}) == List[int] + + def test_generic(self) -> None: + # this is equivalent to `type CoolList[T] = List[T]; CoolList[int]` + T = TypeVar("T") + CoolList = TypeAliasType("CoolList", List[T], type_params=(T,)) + + annotation = CoolList[int] + assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[int] + + # alias and arg in local scope + def test_forwardref_local(self) -> None: + T = TypeVar("T") + IntOrStr = Union[int, str] + CoolList = TypeAliasType("CoolList", List[T], type_params=(T,)) + + annotation = CoolList["IntOrStr"] + assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[IntOrStr] + + # alias and arg in other module scope + def test_forwardref_module(self) -> None: + resolved = utils.resolve_annotation( + utils_helper_module.ListWithForwardRefAlias, globals(), locals(), {} + ) + assert resolved == List[Union[int, str]] + + # combination of the previous two, alias in other module scope and arg in local scope + def test_forwardref_mixed(self) -> None: + LocalIntOrStr = Union[int, str] + + annotation = utils_helper_module.GenericListAlias["LocalIntOrStr"] + assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[LocalIntOrStr] + + # two different forwardrefs with same name + def test_forwardref_duplicate(self) -> None: + DuplicateAlias = int + + # first, resolve an annotation where `DuplicateAlias` resolves to the local int + cache = {} + assert ( + utils.resolve_annotation(List["DuplicateAlias"], globals(), locals(), cache) + == List[int] + ) + + # then, resolve an annotation where the globalns changes and `DuplicateAlias` resolves to something else + # (i.e. this should not resolve to `List[int]` despite {"DuplicateAlias": int} in the cache) + assert ( + utils.resolve_annotation( + utils_helper_module.ListWithDuplicateAlias, globals(), locals(), cache + ) + == List[str] + ) + + @pytest.mark.parametrize( ("dt", "style", "expected"), [ diff --git a/tests/utils_helper_module.py b/tests/utils_helper_module.py new file mode 100644 index 0000000000..7711e861b8 --- /dev/null +++ b/tests/utils_helper_module.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: MIT + +"""Separate module file for some test_utils.py type annotation tests.""" + +import sys +from typing import TYPE_CHECKING, List, TypeVar, Union + +version = sys.version_info # assign to variable to trick pyright + +if TYPE_CHECKING: + from typing_extensions import TypeAliasType +elif version >= (3, 12): + # non-3.12 tests shouldn't be using this + from typing import TypeAliasType + +if version >= (3, 12): + CoolUniqueIntOrStrAlias = Union[int, str] + ListWithForwardRefAlias = TypeAliasType( + "ListWithForwardRefAlias", List["CoolUniqueIntOrStrAlias"] + ) + + T = TypeVar("T") + GenericListAlias = TypeAliasType("GenericListAlias", List[T], type_params=(T,)) + + DuplicateAlias = str + ListWithDuplicateAlias = TypeAliasType("ListWithDuplicateAlias", List["DuplicateAlias"])