Skip to content
104 changes: 60 additions & 44 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import partial

import pytest
from django.utils.module_loading import import_string

from . import live_server_helper
from .django_compat import is_django_unittest
Expand All @@ -19,7 +20,6 @@
_DjangoDbDatabases = Optional[Union["Literal['__all__']", Iterable[str]]]
_DjangoDb = Tuple[bool, bool, _DjangoDbDatabases]


__all__ = [
"django_db_setup",
"db",
Expand All @@ -42,6 +42,18 @@
]


def import_from_string(val, setting_name):
"""
Attempt to import a class from a string representation.
"""
try:
return import_string(val)
except ImportError as e:
msg = "Could not import '%s' for API setting '%s'. %s: %s." \
% (val, setting_name, e.__class__.__name__, e)
raise ImportError(msg)


@pytest.fixture(scope="session")
def django_db_modify_db_settings_tox_suffix() -> None:
skip_if_no_django()
Expand All @@ -64,15 +76,15 @@ def django_db_modify_db_settings_xdist_suffix(request) -> None:

@pytest.fixture(scope="session")
def django_db_modify_db_settings_parallel_suffix(
django_db_modify_db_settings_tox_suffix: None,
django_db_modify_db_settings_xdist_suffix: None,
django_db_modify_db_settings_tox_suffix: None,
django_db_modify_db_settings_xdist_suffix: None,
) -> None:
skip_if_no_django()


@pytest.fixture(scope="session")
def django_db_modify_db_settings(
django_db_modify_db_settings_parallel_suffix: None,
django_db_modify_db_settings_parallel_suffix: None,
) -> None:
skip_if_no_django()

Expand All @@ -94,13 +106,13 @@ def django_db_createdb(request) -> bool:

@pytest.fixture(scope="session")
def django_db_setup(
request,
django_test_environment: None,
django_db_blocker,
django_db_use_migrations: bool,
django_db_keepdb: bool,
django_db_createdb: bool,
django_db_modify_db_settings: None,
request,
django_test_environment: None,
django_db_blocker,
django_db_use_migrations: bool,
django_db_keepdb: bool,
django_db_createdb: bool,
django_db_modify_db_settings: None,
) -> None:
"""Top level fixture to ensure test databases are available"""
from django.test.utils import setup_databases, teardown_databases
Expand Down Expand Up @@ -136,11 +148,12 @@ def teardown_database() -> None:


def _django_db_fixture_helper(
request,
django_db_blocker,
transactional: bool = False,
reset_sequences: bool = False,
request,
django_db_blocker,
transactional: bool = False,
reset_sequences: bool = False,
) -> None:

if is_django_unittest(request):
return

Expand All @@ -155,13 +168,16 @@ def _django_db_fixture_helper(
django_db_blocker.unblock()
request.addfinalizer(django_db_blocker.restore)

import django.test
import django.db

if transactional:
test_case_class = django.test.TransactionTestCase
test_case_classname = request.config.getvalue("transaction_testcase_class") or os.getenv(
"DJANGO_TRANSACTION_TEST_CASE_CLASS"
) or "django.test.TransactionTestCase"
else:
test_case_class = django.test.TestCase
test_case_classname = request.config.getvalue("testcase_class") or os.getenv(
"DJANGO_TEST_CASE_CLASS"
) or "django.test.TestCase"

test_case_class = import_string(test_case_classname)

_reset_sequences = reset_sequences

Expand Down Expand Up @@ -223,9 +239,9 @@ def _set_suffix_to_test_databases(suffix: str) -> None:

@pytest.fixture(scope="function")
def db(
request,
django_db_setup: None,
django_db_blocker,
request,
django_db_setup: None,
django_db_blocker,
) -> None:
"""Require a django test database.

Expand All @@ -243,8 +259,8 @@ def db(
if "django_db_reset_sequences" in request.fixturenames:
request.getfixturevalue("django_db_reset_sequences")
if (
"transactional_db" in request.fixturenames
or "live_server" in request.fixturenames
"transactional_db" in request.fixturenames
or "live_server" in request.fixturenames
):
request.getfixturevalue("transactional_db")
else:
Expand All @@ -253,9 +269,9 @@ def db(

@pytest.fixture(scope="function")
def transactional_db(
request,
django_db_setup: None,
django_db_blocker,
request,
django_db_setup: None,
django_db_blocker,
) -> None:
"""Require a django test database with transaction support.

Expand All @@ -276,9 +292,9 @@ def transactional_db(

@pytest.fixture(scope="function")
def django_db_reset_sequences(
request,
django_db_setup: None,
django_db_blocker,
request,
django_db_setup: None,
django_db_blocker,
) -> None:
"""Require a transactional test database with sequence reset support.

Expand Down Expand Up @@ -332,9 +348,9 @@ def django_username_field(django_user_model) -> str:

@pytest.fixture()
def admin_user(
db: None,
django_user_model,
django_username_field: str,
db: None,
django_user_model,
django_username_field: str,
):
"""A Django admin user.

Expand Down Expand Up @@ -363,8 +379,8 @@ def admin_user(

@pytest.fixture()
def admin_client(
db: None,
admin_user,
db: None,
admin_user,
) -> "django.test.client.Client":
"""A Django test client logged in as an admin user."""
from django.test.client import Client
Expand Down Expand Up @@ -496,11 +512,11 @@ def _live_server_helper(request) -> None:

@contextmanager
def _assert_num_queries(
config,
num: int,
exact: bool = True,
connection=None,
info=None,
config,
num: int,
exact: bool = True,
connection=None,
info=None,
) -> Generator["django.test.utils.CaptureQueriesContext", None, None]:
from django.test.utils import CaptureQueriesContext

Expand Down Expand Up @@ -547,9 +563,9 @@ def django_assert_max_num_queries(pytestconfig):

@contextmanager
def _capture_on_commit_callbacks(
*,
using: Optional[str] = None,
execute: bool = False
*,
using: Optional[str] = None,
execute: bool = False
):
from django.db import DEFAULT_DB_ALIAS, connections
from django.test import TestCase
Expand Down
11 changes: 11 additions & 0 deletions pytest_django/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ def pytest_addoption(parser) -> None:
default=None,
help="Address and port for the live_server fixture.",
)
group.addoption(
"--testcase-class",
default=None,
help="The base TestCase class to patch for use with django. Useful for hypothesis users",
)
group.addoption(
"--transaction-testcase-class",
default=None,
help="The base TransactionTestCase class to patch for use with django. "
"Useful for hypothesis users",
)
parser.addini(
SETTINGS_MODULE_ENV, "Django settings module to use by pytest-django."
)
Expand Down