Skip to content

Simplify synchronizers #1153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/+b22c903a.fixed.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
An error that could cause duplicate warnings to be issued
124 changes: 28 additions & 96 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import pluggy
import pytest
from _pytest.fixtures import resolve_fixture_function
from _pytest.scope import Scope
from pytest import (
Config,
Expand All @@ -41,7 +42,6 @@
Function,
Item,
Mark,
Metafunc,
MonkeyPatch,
Parser,
PytestCollectionWarning,
Expand All @@ -50,17 +50,16 @@
)

if sys.version_info >= (3, 10):
from typing import Concatenate, ParamSpec
from typing import ParamSpec
else:
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec

if sys.version_info >= (3, 11):
from asyncio import Runner
else:
from backports.asyncio.runner import Runner

_ScopeName = Literal["session", "package", "module", "class", "function"]
_T = TypeVar("_T")
_R = TypeVar("_R", bound=Union[Awaitable[Any], AsyncIterator[Any]])
_P = ParamSpec("_P")
FixtureFunction = Callable[_P, _R]
Expand Down Expand Up @@ -234,44 +233,19 @@ def pytest_report_header(config: Config) -> list[str]:
]


def _fixture_synchronizer(fixturedef: FixtureDef, runner: Runner) -> Callable:
def _fixture_synchronizer(
fixturedef: FixtureDef, runner: Runner, request: FixtureRequest
) -> Callable:
"""Returns a synchronous function evaluating the specified fixture."""
fixture_function = resolve_fixture_function(fixturedef, request)
if inspect.isasyncgenfunction(fixturedef.func):
return _wrap_asyncgen_fixture(fixturedef.func, runner)
return _wrap_asyncgen_fixture(fixture_function, runner, request) # type: ignore[arg-type]
elif inspect.iscoroutinefunction(fixturedef.func):
return _wrap_async_fixture(fixturedef.func, runner)
return _wrap_async_fixture(fixture_function, runner, request) # type: ignore[arg-type]
else:
return fixturedef.func


def _add_kwargs(
func: Callable[..., Any],
kwargs: dict[str, Any],
request: FixtureRequest,
) -> dict[str, Any]:
sig = inspect.signature(func)
ret = kwargs.copy()
if "request" in sig.parameters:
ret["request"] = request
return ret


def _perhaps_rebind_fixture_func(func: _T, instance: Any | None) -> _T:
if instance is not None:
# The fixture needs to be bound to the actual request.instance
# so it is bound to the same object as the test method.
unbound, cls = func, None
try:
unbound, cls = func.__func__, type(func.__self__) # type: ignore
except AttributeError:
pass
# Only if the fixture was bound before to an instance of
# the same type.
if cls is not None and isinstance(instance, cls):
func = unbound.__get__(instance) # type: ignore
return func


AsyncGenFixtureParams = ParamSpec("AsyncGenFixtureParams")
AsyncGenFixtureYieldType = TypeVar("AsyncGenFixtureYieldType")

Expand All @@ -281,17 +255,14 @@ def _wrap_asyncgen_fixture(
AsyncGenFixtureParams, AsyncGeneratorType[AsyncGenFixtureYieldType, Any]
],
runner: Runner,
) -> Callable[
Concatenate[FixtureRequest, AsyncGenFixtureParams], AsyncGenFixtureYieldType
]:
request: FixtureRequest,
) -> Callable[AsyncGenFixtureParams, AsyncGenFixtureYieldType]:
@functools.wraps(fixture_function)
def _asyncgen_fixture_wrapper(
request: FixtureRequest,
*args: AsyncGenFixtureParams.args,
**kwargs: AsyncGenFixtureParams.kwargs,
):
func = _perhaps_rebind_fixture_func(fixture_function, request.instance)
gen_obj = func(*args, **_add_kwargs(func, kwargs, request))
gen_obj = fixture_function(*args, **kwargs)

async def setup():
res = await gen_obj.__anext__() # type: ignore[union-attr]
Expand Down Expand Up @@ -334,18 +305,16 @@ def _wrap_async_fixture(
AsyncFixtureParams, CoroutineType[Any, Any, AsyncFixtureReturnType]
],
runner: Runner,
) -> Callable[Concatenate[FixtureRequest, AsyncFixtureParams], AsyncFixtureReturnType]:
request: FixtureRequest,
) -> Callable[AsyncFixtureParams, AsyncFixtureReturnType]:

@functools.wraps(fixture_function) # type: ignore[arg-type]
def _async_fixture_wrapper(
request: FixtureRequest,
*args: AsyncFixtureParams.args,
**kwargs: AsyncFixtureParams.kwargs,
):
func = _perhaps_rebind_fixture_func(fixture_function, request.instance)

async def setup():
res = await func(*args, **_add_kwargs(func, kwargs, request))
res = await fixture_function(*args, **kwargs)
return res

context = contextvars.copy_context()
Expand Down Expand Up @@ -451,11 +420,10 @@ def _can_substitute(item: Function) -> bool:
return inspect.iscoroutinefunction(func)

def runtest(self) -> None:
self.obj = wrap_in_sync(
# https://github.com/pytest-dev/pytest-asyncio/issues/596
self.obj, # type: ignore[has-type]
)
super().runtest()
synchronized_obj = wrap_in_sync(self.obj)
with MonkeyPatch.context() as c:
c.setattr(self, "obj", synchronized_obj)
super().runtest()


class AsyncGenerator(PytestAsyncioFunction):
Expand Down Expand Up @@ -494,11 +462,10 @@ def _can_substitute(item: Function) -> bool:
)

def runtest(self) -> None:
self.obj = wrap_in_sync(
# https://github.com/pytest-dev/pytest-asyncio/issues/596
self.obj, # type: ignore[has-type]
)
super().runtest()
synchronized_obj = wrap_in_sync(self.obj)
with MonkeyPatch.context() as c:
c.setattr(self, "obj", synchronized_obj)
super().runtest()


class AsyncHypothesisTest(PytestAsyncioFunction):
Expand All @@ -517,10 +484,10 @@ def _can_substitute(item: Function) -> bool:
)

def runtest(self) -> None:
self.obj.hypothesis.inner_test = wrap_in_sync(
self.obj.hypothesis.inner_test,
)
super().runtest()
synchronized_obj = wrap_in_sync(self.obj.hypothesis.inner_test)
with MonkeyPatch.context() as c:
c.setattr(self.obj.hypothesis, "inner_test", synchronized_obj)
super().runtest()


# The function name needs to start with "pytest_"
Expand Down Expand Up @@ -579,32 +546,6 @@ def _temporary_event_loop_policy(policy: AbstractEventLoopPolicy) -> Iterator[No
_set_event_loop(old_loop)


@pytest.hookimpl(tryfirst=True)
def pytest_generate_tests(metafunc: Metafunc) -> None:
marker = metafunc.definition.get_closest_marker("asyncio")
if not marker:
return
default_loop_scope = _get_default_test_loop_scope(metafunc.config)
loop_scope = _get_marked_loop_scope(marker, default_loop_scope)
runner_fixture_id = f"_{loop_scope}_scoped_runner"
# This specific fixture name may already be in metafunc.argnames, if this
# test indirectly depends on the fixture. For example, this is the case
# when the test depends on an async fixture, both of which share the same
# event loop fixture mark.
if runner_fixture_id in metafunc.fixturenames:
return
fixturemanager = metafunc.config.pluginmanager.get_plugin("funcmanage")
assert fixturemanager is not None
# Add the scoped event loop fixture to Metafunc's list of fixture names and
# fixturedefs and leave the actual parametrization to pytest
# The fixture needs to be appended to avoid messing up the fixture evaluation
# order
metafunc.fixturenames.append(runner_fixture_id)
metafunc._arg2fixturedefs[runner_fixture_id] = fixturemanager._arg2fixturedefs[
runner_fixture_id
]


def _get_event_loop_policy() -> AbstractEventLoopPolicy:
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
Expand Down Expand Up @@ -691,12 +632,6 @@ def wrap_in_sync(
Return a sync wrapper around an async function executing it in the
current event loop.
"""
# if the function is already wrapped, we rewrap using the original one
# not using __wrapped__ because the original function may already be
# a wrapped one
raw_func = getattr(func, "_raw_test_func", None)
if raw_func is not None:
func = raw_func

@functools.wraps(func)
def inner(*args, **kwargs):
Expand All @@ -713,7 +648,6 @@ def inner(*args, **kwargs):
task.exception()
raise

inner._raw_test_func = func # type: ignore[attr-defined]
return inner


Expand Down Expand Up @@ -755,11 +689,9 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
)
runner_fixture_id = f"_{loop_scope}_scoped_runner"
runner = request.getfixturevalue(runner_fixture_id)
synchronizer = _fixture_synchronizer(fixturedef, runner)
synchronizer = _fixture_synchronizer(fixturedef, runner, request)
_make_asyncio_fixture_function(synchronizer, loop_scope)
with MonkeyPatch.context() as c:
if "request" not in fixturedef.argnames:
c.setattr(fixturedef, "argnames", (*fixturedef.argnames, "request"))
c.setattr(fixturedef, "func", synchronizer)
hook_result = yield
return hook_result
Expand Down
2 changes: 1 addition & 1 deletion tests/markers/test_function_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def test_warns():
)
)
result = pytester.runpytest_subprocess("--asyncio-mode=strict")
result.assert_outcomes(passed=1, warnings=2)
result.assert_outcomes(passed=1, warnings=1)
result.stdout.fnmatch_lines("*DeprecationWarning*")


Expand Down
Loading