Skip to content

Add new strategy register_extra_types #645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ for line in sys.stdin:
endef
export PRINT_HELP_PYSCRIPT
BROWSER := python -c "$$BROWSER_PYSCRIPT"
TESTS ?= tests

help:
@python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)

clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts


clean-build: ## remove build artifacts
rm -fr build/
rm -fr dist/
Expand All @@ -51,9 +51,8 @@ lint: ## check style with ruff and black
pdm run ruff check src/ tests bench
pdm run black --check src tests docs/conf.py

test: ## run tests quickly with the default Python
pdm run pytest -x --ff -n auto tests

test: ## run tests quickly with the default Python; pass TESTS= for specific path
pdm run pytest -x --ff $(if $(filter $(TESTS),tests),-n auto ,)$(TESTS)

test-all: ## run tests on every Python version with tox
tox
Expand Down
23 changes: 12 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,18 @@ select = [
"I", # isort
]
ignore = [
"E501", # line length is handled by black
"RUF001", # leave my smart characters alone
"S101", # assert
"S307", # hands off my eval
"SIM300", # Yoda rocks in asserts
"PGH003", # leave my type: ignores alone
"B006", # mutable argument defaults
"DTZ001", # datetimes in tests
"DTZ006", # datetimes in tests
"UP006", # We support old typing constructs at runtime
"UP035", # We support old typing constructs at runtime
"B006", # mutable argument defaults
"DTZ001", # datetimes in tests
"DTZ006", # datetimes in tests
"E501", # line length is handled by black
"PGH003", # leave my type: ignores alone
"PLC0414", # redundant import aliases indicate exported names
"RUF001", # leave my smart characters alone
"S101", # assert
"S307", # hands off my eval
"SIM300", # Yoda rocks in asserts
"UP006", # We support old typing constructs at runtime
"UP035", # We support old typing constructs at runtime
]

[tool.ruff.lint.pyupgrade]
Expand Down
3 changes: 3 additions & 0 deletions src/cattrs/preconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from .._compat import is_subclass
from ..converters import Converter, UnstructureHook
from ..fns import identity
from ._all import ConverterFormat as ConverterFormat
from ._all import PreconfiguredConverter as PreconfiguredConverter
from ._all import has_format as has_format

if sys.version_info[:2] < (3, 10):
from typing_extensions import ParamSpec
Expand Down
151 changes: 151 additions & 0 deletions src/cattrs/preconf/_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal, TypeAlias, TypeIs, Union, overload

from ..converters import Converter
from ..types import Unavailable

if TYPE_CHECKING:
try:
from cattrs.preconf.bson import BsonConverter
except ModuleNotFoundError:
BsonConverter = Unavailable

try:
from cattrs.preconf.cbor2 import Cbor2Converter
except ModuleNotFoundError:
Cbor2Converter = Unavailable

from cattrs.preconf.json import JsonConverter

try:
from cattrs.preconf.msgpack import MsgpackConverter
except ModuleNotFoundError:
MsgpackConverter = Unavailable

try:
from cattrs.preconf.msgspec import MsgspecJsonConverter
except ModuleNotFoundError:
MsgspecJsonConverter = Unavailable

try:
from cattrs.preconf.orjson import OrjsonConverter
except ModuleNotFoundError:
OrjsonConverter = Unavailable

try:
from cattrs.preconf.pyyaml import PyyamlConverter
except ModuleNotFoundError:
PyyamlConverter = Unavailable

try:
from cattrs.preconf.tomlkit import TomlkitConverter
except ModuleNotFoundError:
TomlkitConverter = Unavailable

try:
from cattrs.preconf.ujson import UjsonConverter
except ModuleNotFoundError:
UjsonConverter = Unavailable

PreconfiguredConverter: TypeAlias = Union[
BsonConverter,
Cbor2Converter,
JsonConverter,
MsgpackConverter,
MsgspecJsonConverter,
OrjsonConverter,
PyyamlConverter,
TomlkitConverter,
UjsonConverter,
]

else:
PreconfiguredConverter: TypeAlias = Converter

ConverterFormat: TypeAlias = Literal[
"bson",
"cbor2",
"json",
"msgpack",
"msgspec-json",
"orjson",
"pyyaml",
"tomlkit",
"ujson",
]

C: TypeAlias = Converter | Unavailable


@overload
def has_format(converter: C, fmt: Literal["bson"]) -> TypeIs["BsonConverter"]: ...
@overload
def has_format(converter: C, fmt: Literal["cbor2"]) -> TypeIs["Cbor2Converter"]: ...
@overload
def has_format(converter: C, fmt: Literal["json"]) -> TypeIs["JsonConverter"]: ...
@overload
def has_format(converter: C, fmt: Literal["msgpack"]) -> TypeIs["MsgpackConverter"]: ...
@overload
def has_format(
converter: C, fmt: Literal["msgspec-json"]
) -> TypeIs["MsgspecJsonConverter"]: ...
@overload
def has_format(converter: C, fmt: Literal["orjson"]) -> TypeIs["OrjsonConverter"]: ...
@overload
def has_format(converter: C, fmt: Literal["pyyaml"]) -> TypeIs["PyyamlConverter"]: ...
@overload
def has_format(converter: C, fmt: Literal["tomlkit"]) -> TypeIs["TomlkitConverter"]: ...
@overload
def has_format(converter: C, fmt: Literal["ujson"]) -> TypeIs["UjsonConverter"]: ...
def has_format(
converter: C, fmt: ConverterFormat | str | Sequence[ConverterFormat]
) -> bool:
if isinstance(fmt, str):
fmt = (fmt,)

if "bson" in fmt and converter.__class__.__name__ == "BsonConverter":
from .bson import BsonConverter

return isinstance(converter, BsonConverter)

if "cbor2" in fmt and converter.__class__.__name__ == "Cbor2Converter":
from .cbor2 import Cbor2Converter

return isinstance(converter, Cbor2Converter)

if "json" in fmt and converter.__class__.__name__ == "JsonConverter":
from .json import JsonConverter

return isinstance(converter, JsonConverter)

if "msgpack" in fmt and converter.__class__.__name__ == "MsgpackConverter":
from .msgpack import MsgpackConverter

return isinstance(converter, MsgpackConverter)

if "msgspec-json" in fmt and converter.__class__.__name__ == "MsgspecJsonConverter":
from .msgspec import MsgspecJsonConverter

return isinstance(converter, MsgspecJsonConverter)

if "orjson" in fmt and converter.__class__.__name__ == "OrjsonConverter":
from .orjson import OrjsonConverter

return isinstance(converter, OrjsonConverter)

if "pyyaml" in fmt and converter.__class__.__name__ == "PyyamlConverter":
from .pyyaml import PyyamlConverter

return isinstance(converter, PyyamlConverter)

if "tomlkit" in fmt and converter.__class__.__name__ == "TomlkitConverter":
from .tomlkit import TomlkitConverter

return isinstance(converter, TomlkitConverter)

if "ujson" in fmt and converter.__class__.__name__ == "UjsonConverter":
from .ujson import UjsonConverter

return isinstance(converter, UjsonConverter)

return False
2 changes: 2 additions & 0 deletions src/cattrs/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""High level strategies for converters."""

from ._class_methods import use_class_methods
from ._extra_types import register_extra_types
from ._subclasses import include_subclasses
from ._unions import configure_tagged_union, configure_union_passthrough

__all__ = [
"configure_tagged_union",
"configure_union_passthrough",
"include_subclasses",
"register_extra_types",
"use_class_methods",
]
53 changes: 53 additions & 0 deletions src/cattrs/strategies/_extra_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from functools import cache, wraps
from importlib import import_module
from types import ModuleType
from typing import NoReturn

from ...converters import Converter
from ...dispatch import StructuredValue, StructureHook, TargetType, UnstructuredValue


def register_extra_types(converter: Converter, *classes: type) -> None:
"""
TODO: Add docs
"""
for cl in classes:
if not isinstance(cl, type):
raise TypeError("Type required instead of object")

struct_hook = get_module(cl).gen_structure_hook(cl, converter)
if struct_hook is None:
raise_unsupported(cl)
converter.register_structure_hook(cl, bypass(cl, struct_hook))

unstruct_hook = get_module(cl).gen_unstructure_hook(cl, converter)
if unstruct_hook is None:
raise_unsupported(cl)
converter.register_unstructure_hook(cl, unstruct_hook)


def bypass(target: type, structure_hook: StructureHook) -> StructureHook:
"""Bypass structure hook when given object of target type."""

@wraps(structure_hook)
def wrapper(obj: UnstructuredValue, cl: TargetType) -> StructuredValue:
return obj if type(obj) is target else structure_hook(obj, cl)

return wrapper


@cache
def get_module(cl: type) -> ModuleType:
modname = getattr(cl, "__module__", "builtins")
try:
return import_module(f"cattrs.strategies._extra_types._{modname}")
except ModuleNotFoundError:
raise_unsupported(cl)


def raise_unexpected_structure(target: type, cl: type) -> NoReturn:
raise TypeError(f"Unable to structure registered extra type {target} from {cl}")


def raise_unsupported(cl: type) -> NoReturn:
raise ValueError(f"Type {cl} is not supported by register_extra_types strategy")
54 changes: 54 additions & 0 deletions src/cattrs/strategies/_extra_types/_builtins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from collections.abc import Sequence
from contextlib import suppress
from functools import cache, partial
from numbers import Real

from ...converters import Converter
from ...dispatch import StructureHook, UnstructureHook
from ...preconf import has_format
from . import raise_unexpected_structure

MISSING_SPECIAL_FLOATS = ("msgspec-json", "orjson")

SPECIAL = (float("inf"), float("-inf"), float("nan"))
SPECIAL_STR = ("inf", "+inf", "-inf", "infinity", "+infinity", "-infinity", "nan")


@cache
def gen_structure_hook(cl: type, _) -> StructureHook | None:
if cl is complex:
return structure_complex
return None


@cache
def gen_unstructure_hook(cl: type, converter: Converter) -> UnstructureHook | None:
if cl is complex:
if has_format(converter, MISSING_SPECIAL_FLOATS):
return partial(unstructure_complex, special_as_string=True)
return unstructure_complex
return None


def structure_complex(obj: object, _) -> complex:
if (
isinstance(obj, Sequence)
and len(obj) == 2
and all(isinstance(x, (Real, str)) for x in obj)
):
with suppress(ValueError):
obj = [ # for all converters, string inf and nan are allowed
float(x) if (isinstance(x, str) and x.lower() in SPECIAL_STR) else x
for x in obj
]
return complex(*obj)
raise_unexpected_structure(complex, type(obj)) # noqa: RET503 # NoReturn


def unstructure_complex(
value: complex, special_as_string: bool = False
) -> list[float | str]:
return [
str(x) if (x in SPECIAL and special_as_string) else x
for x in [value.real, value.imag]
]
34 changes: 34 additions & 0 deletions src/cattrs/strategies/_extra_types/_uuid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from functools import cache
from uuid import UUID

from ...converters import Converter
from ...dispatch import StructureHook, UnstructureHook
from ...fns import identity
from ...preconf import has_format
from . import raise_unexpected_structure

SUPPORTS_UUID = ("bson", "cbor", "msgspec-json", "orjson")


@cache
def gen_structure_hook(cl: type, _) -> StructureHook | None:
if issubclass(cl, UUID):
return structure_uuid
return None


@cache
def gen_unstructure_hook(cl: type, converter: Converter) -> UnstructureHook | None:
if issubclass(cl, UUID):
return identity if has_format(converter, SUPPORTS_UUID) else lambda v: str(v)
return None


def structure_uuid(value: bytes | int | str, _) -> UUID:
if isinstance(value, bytes):
return UUID(bytes=value)
if isinstance(value, int):
return UUID(int=value)
if isinstance(value, str):
return UUID(value)
raise_unexpected_structure(UUID, type(value)) # noqa: RET503 # NoReturn
18 changes: 18 additions & 0 deletions src/cattrs/strategies/_extra_types/_zoneinfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from functools import cache
from zoneinfo import ZoneInfo

from ...dispatch import StructureHook, UnstructureHook


@cache
def gen_structure_hook(cl: type, _) -> StructureHook | None:
if issubclass(cl, ZoneInfo):
return lambda v, _: ZoneInfo(v)
return None


@cache
def gen_unstructure_hook(cl: type, _) -> UnstructureHook | None:
if issubclass(cl, ZoneInfo):
return lambda v: str(v)
return None
Loading