diff --git a/examples/cookbook/cookbook/recipes/schema.py b/examples/cookbook/cookbook/recipes/schema.py index c0cb13ae..2f2fc032 100644 --- a/examples/cookbook/cookbook/recipes/schema.py +++ b/examples/cookbook/cookbook/recipes/schema.py @@ -1,4 +1,8 @@ -from graphene import Node +import asyncio + +from asgiref.sync import sync_to_async + +from graphene import Field, Node, String from graphene_django.filter import DjangoFilterConnectionField from graphene_django.types import DjangoObjectType @@ -6,12 +10,32 @@ class RecipeNode(DjangoObjectType): + async_field = String() + class Meta: model = Recipe interfaces = (Node,) fields = "__all__" filter_fields = ["title", "amounts"] + async def resolve_async_field(self, info): + await asyncio.sleep(2) + return "success" + + +class RecipeType(DjangoObjectType): + async_field = String() + + class Meta: + model = Recipe + fields = "__all__" + filter_fields = ["title", "amounts"] + skip_registry = True + + async def resolve_async_field(self, info): + await asyncio.sleep(2) + return "success" + class RecipeIngredientNode(DjangoObjectType): class Meta: @@ -28,7 +52,13 @@ class Meta: class Query: recipe = Node.Field(RecipeNode) + raw_recipe = Field(RecipeType) all_recipes = DjangoFilterConnectionField(RecipeNode) recipeingredient = Node.Field(RecipeIngredientNode) all_recipeingredients = DjangoFilterConnectionField(RecipeIngredientNode) + + @staticmethod + @sync_to_async + def resolve_raw_recipe(self, info): + return Recipe.objects.first() diff --git a/examples/cookbook/cookbook/urls.py b/examples/cookbook/cookbook/urls.py index e72b383d..c0f6fdf6 100644 --- a/examples/cookbook/cookbook/urls.py +++ b/examples/cookbook/cookbook/urls.py @@ -1,9 +1,10 @@ -from django.conf.urls import url from django.contrib import admin +from django.urls import re_path +from django.views.decorators.csrf import csrf_exempt -from graphene_django.views import GraphQLView +from graphene_django.views import AsyncGraphQLView urlpatterns = [ - url(r"^admin/", admin.site.urls), - url(r"^graphql$", GraphQLView.as_view(graphiql=True)), + re_path(r"^admin/", admin.site.urls), + re_path(r"^graphql$", csrf_exempt(AsyncGraphQLView.as_view(graphiql=True))), ] diff --git a/graphene_django/debug/middleware.py b/graphene_django/debug/middleware.py index de0d72d1..adc8d040 100644 --- a/graphene_django/debug/middleware.py +++ b/graphene_django/debug/middleware.py @@ -1,5 +1,8 @@ +from asgiref.sync import sync_to_async from django.db import connections +from graphql.type.definition import GraphQLNonNull +from ..utils import is_running_async, is_sync_function from .exception.formating import wrap_exception from .sql.tracking import unwrap_cursor, wrap_cursor from .types import DjangoDebug @@ -67,3 +70,28 @@ def resolve(self, next, root, info, **args): return context.django_debug.on_resolve_error(e) context.django_debug.add_result(result) return result + + +class DjangoSyncRequiredMiddleware: + def resolve(self, next, root, info, **args): + parent_type = info.parent_type + return_type = info.return_type + + if isinstance(parent_type, GraphQLNonNull): + parent_type = parent_type.of_type + if isinstance(return_type, GraphQLNonNull): + return_type = return_type.of_type + + if any( + [ + hasattr(parent_type, "graphene_type") + and hasattr(parent_type.graphene_type._meta, "model"), + hasattr(return_type, "graphene_type") + and hasattr(return_type.graphene_type._meta, "model"), + info.parent_type.name == "Mutation", + ] + ): + if is_sync_function(next) and is_running_async(): + return sync_to_async(next)(root, info, **args) + + return next(root, info, **args) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index a1b9a2cc..678c871e 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -1,5 +1,6 @@ from functools import partial +from asgiref.sync import sync_to_async from django.db.models.query import QuerySet from graphql_relay import ( connection_from_array_slice, @@ -7,7 +8,6 @@ get_offset_with_default, offset_to_cursor, ) -from promise import Promise from graphene import Int, NonNull from graphene.relay import ConnectionField @@ -15,7 +15,7 @@ from graphene.types import Field, List from .settings import graphene_settings -from .utils import maybe_queryset +from .utils import is_running_async, is_sync_function, maybe_queryset class DjangoListField(Field): @@ -49,11 +49,36 @@ def model(self): def get_manager(self): return self.model._default_manager - @staticmethod + @classmethod def list_resolver( - django_object_type, resolver, default_manager, root, info, **args + cls, django_object_type, resolver, default_manager, root, info, **args ): - queryset = maybe_queryset(resolver(root, info, **args)) + if is_running_async(): + if is_sync_function(resolver): + resolver = sync_to_async(resolver) + + iterable = resolver(root, info, **args) + + if info.is_awaitable(iterable): + + async def resolve_list_async(iterable): + queryset = maybe_queryset(await iterable) + if queryset is None: + queryset = maybe_queryset(default_manager) + + if isinstance(queryset, QuerySet): + # Pass queryset to the DjangoObjectType get_queryset method + queryset = maybe_queryset( + await sync_to_async(django_object_type.get_queryset)( + queryset, info + ) + ) + + return await sync_to_async(list)(queryset) + + return resolve_list_async(iterable) + + queryset = maybe_queryset(iterable) if queryset is None: queryset = maybe_queryset(default_manager) @@ -61,7 +86,7 @@ def list_resolver( # Pass queryset to the DjangoObjectType get_queryset method queryset = maybe_queryset(django_object_type.get_queryset(queryset, info)) - return queryset + return list(queryset) def wrap_resolve(self, parent_resolver): resolver = super().wrap_resolve(parent_resolver) @@ -235,20 +260,36 @@ def connection_resolver( # eventually leads to DjangoObjectType's get_queryset (accepts queryset) # or a resolve_foo (does not accept queryset) + + if is_running_async(): + if is_sync_function(resolver): + resolver = sync_to_async(resolver) + iterable = resolver(root, info, **args) + + if info.is_awaitable(iterable): + + async def resolve_connection_async(iterable): + iterable = await iterable + if iterable is None: + iterable = default_manager + + iterable = await sync_to_async(queryset_resolver)( + connection, iterable, info, args + ) + + return await sync_to_async(cls.resolve_connection)( + connection, args, iterable, max_limit=max_limit + ) + + return resolve_connection_async(iterable) + if iterable is None: iterable = default_manager # thus the iterable gets refiltered by resolve_queryset # but iterable might be promise iterable = queryset_resolver(connection, iterable, info, args) - on_resolve = partial( - cls.resolve_connection, connection, args, max_limit=max_limit - ) - - if Promise.is_thenable(iterable): - return Promise.resolve(iterable).then(on_resolve) - - return on_resolve(iterable) + return cls.resolve_connection(connection, args, iterable, max_limit=max_limit) def wrap_resolve(self, parent_resolver): return partial( diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index 2380632d..1a6c55e9 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -1,6 +1,7 @@ from collections import OrderedDict from functools import partial +from asgiref.sync import sync_to_async from django.core.exceptions import ValidationError from graphene.types.argument import to_arguments @@ -92,6 +93,18 @@ def filter_kwargs(): qs = super().resolve_queryset(connection, iterable, info, args) + if info.is_awaitable(qs): + + async def filter_async(qs): + filterset = filterset_class( + data=filter_kwargs(), queryset=await qs, request=info.context + ) + if await sync_to_async(filterset.is_valid)(): + return filterset.qs + raise ValidationError(filterset.form.errors.as_json()) + + return filter_async(qs) + filterset = filterset_class( data=filter_kwargs(), queryset=qs, request=info.context ) diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py index f1f12678..30b9b34f 100644 --- a/graphene_django/rest_framework/mutation.py +++ b/graphene_django/rest_framework/mutation.py @@ -1,6 +1,7 @@ from collections import OrderedDict from enum import Enum +from asgiref.sync import sync_to_async from django.shortcuts import get_object_or_404 from rest_framework import serializers @@ -11,6 +12,7 @@ from graphene.types.objecttype import yank_fields_from_attrs from ..types import ErrorType +from ..utils import is_running_async from .serializer_converter import convert_serializer_field @@ -166,6 +168,17 @@ def mutate_and_get_payload(cls, root, info, **input): kwargs = cls.get_serializer_kwargs(root, info, **input) serializer = cls._meta.serializer_class(**kwargs) + if is_running_async(): + + async def perform_mutate_async(): + if await sync_to_async(serializer.is_valid)(): + return await sync_to_async(cls.perform_mutate)(serializer, info) + else: + errors = ErrorType.from_errors(serializer.errors) + return cls(errors=errors) + + return perform_mutate_async() + if serializer.is_valid(): return cls.perform_mutate(serializer, info) else: diff --git a/graphene_django/tests/async_test_helper.py b/graphene_django/tests/async_test_helper.py new file mode 100644 index 00000000..5785c6c1 --- /dev/null +++ b/graphene_django/tests/async_test_helper.py @@ -0,0 +1,6 @@ +from asgiref.sync import async_to_sync + + +def assert_async_result_equal(schema, query, result, **kwargs): + async_result = async_to_sync(schema.execute_async)(query, **kwargs) + assert async_result == result diff --git a/graphene_django/tests/test_fields.py b/graphene_django/tests/test_fields.py index caaa6ddf..3b63c3aa 100644 --- a/graphene_django/tests/test_fields.py +++ b/graphene_django/tests/test_fields.py @@ -2,6 +2,7 @@ import re import pytest +from asgiref.sync import async_to_sync from django.db.models import Count, Model, Prefetch from graphene import List, NonNull, ObjectType, Schema, String @@ -9,6 +10,7 @@ from ..fields import DjangoConnectionField, DjangoListField from ..types import DjangoObjectType +from .async_test_helper import assert_async_result_equal from .models import ( Article as ArticleModel, Film as FilmModel, @@ -82,6 +84,7 @@ class Query(ObjectType): result = schema.execute(query) + assert_async_result_equal(schema, query, result) assert not result.errors assert result.data == { "reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}] @@ -109,6 +112,7 @@ class Query(ObjectType): result = schema.execute(query) assert not result.errors assert result.data == {"reporters": []} + assert_async_result_equal(schema, query, result) ReporterModel.objects.create(first_name="Tara", last_name="West") ReporterModel.objects.create(first_name="Debra", last_name="Payne") @@ -119,6 +123,7 @@ class Query(ObjectType): assert result.data == { "reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}] } + assert_async_result_equal(schema, query, result) def test_override_resolver(self): class Reporter(DjangoObjectType): @@ -146,6 +151,35 @@ def resolve_reporters(_, info): ReporterModel.objects.create(first_name="Debra", last_name="Payne") result = schema.execute(query) + assert not result.errors + assert result.data == {"reporters": [{"firstName": "Tara"}]} + + def test_override_resolver_async_execution(self): + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name",) + + class Query(ObjectType): + reporters = DjangoListField(Reporter) + + def resolve_reporters(_, info): + return ReporterModel.objects.filter(first_name="Tara") + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + } + } + """ + + ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + result = async_to_sync(schema.execute_async)(query) assert not result.errors assert result.data == {"reporters": [{"firstName": "Tara"}]} @@ -210,6 +244,7 @@ class Query(ObjectType): {"firstName": "Debra", "articles": []}, ] } + assert_async_result_equal(schema, query, result) def test_override_resolver_nested_list_field(self): class Article(DjangoObjectType): @@ -268,6 +303,7 @@ class Query(ObjectType): {"firstName": "Debra", "articles": []}, ] } + assert_async_result_equal(schema, query, result) def test_same_type_nested_list_field(self): class Person(DjangoObjectType): @@ -376,6 +412,7 @@ def resolve_reporters(_, info): assert not result.errors assert result.data == {"reporters": [{"firstName": "Tara"}]} + assert_async_result_equal(schema, query, result) def test_resolve_list(self): """Resolving a plain list should work (and not call get_queryset)""" @@ -424,6 +461,53 @@ def resolve_reporters(_, info): assert not result.errors assert result.data == {"reporters": [{"firstName": "Debra"}]} + def test_resolve_list_async(self): + """Resolving a plain list should work (and not call get_queryset) when running under async""" + + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name", "articles") + + @classmethod + def get_queryset(cls, queryset, info): + # Only get reporters with at least 1 article + return queryset.annotate(article_count=Count("articles")).filter( + article_count__gt=0 + ) + + class Query(ObjectType): + reporters = DjangoListField(Reporter) + + def resolve_reporters(_, info): + return [ReporterModel.objects.get(first_name="Debra")] + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + } + } + """ + + r1 = ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + ArticleModel.objects.create( + headline="Amazing news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) + + result = async_to_sync(schema.execute_async)(query) + + assert not result.errors + assert result.data == {"reporters": [{"firstName": "Debra"}]} + def test_get_queryset_foreign_key(self): class Article(DjangoObjectType): class Meta: @@ -483,6 +567,7 @@ class Query(ObjectType): {"firstName": "Debra", "articles": []}, ] } + assert_async_result_equal(schema, query, result) def test_resolve_list_external_resolver(self): """Resolving a plain list from external resolver should work (and not call get_queryset)""" @@ -531,6 +616,53 @@ class Query(ObjectType): assert not result.errors assert result.data == {"reporters": [{"firstName": "Debra"}]} + def test_resolve_list_external_resolver_async(self): + """Resolving a plain list from external resolver should work (and not call get_queryset)""" + + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name", "articles") + + @classmethod + def get_queryset(cls, queryset, info): + # Only get reporters with at least 1 article + return queryset.annotate(article_count=Count("articles")).filter( + article_count__gt=0 + ) + + def resolve_reporters(_, info): + return [ReporterModel.objects.get(first_name="Debra")] + + class Query(ObjectType): + reporters = DjangoListField(Reporter, resolver=resolve_reporters) + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + } + } + """ + + r1 = ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + ArticleModel.objects.create( + headline="Amazing news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) + + result = async_to_sync(schema.execute_async)(query) + + assert not result.errors + assert result.data == {"reporters": [{"firstName": "Debra"}]} + def test_get_queryset_filter_external_resolver(self): class Reporter(DjangoObjectType): class Meta: @@ -575,6 +707,7 @@ class Query(ObjectType): assert not result.errors assert result.data == {"reporters": [{"firstName": "Tara"}]} + assert_async_result_equal(schema, query, result) def test_select_related_and_prefetch_related_are_respected( self, django_assert_num_queries @@ -717,6 +850,7 @@ def resolve_articles(root, info): r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"', captured.captured_queries[1]["sql"], ) + assert_async_result_equal(schema, query, result) class TestDjangoConnectionField: diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 42394c20..e8aa1f74 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -2,6 +2,7 @@ import datetime import pytest +from asgiref.sync import async_to_sync from django.db import models from django.db.models import Q from django.utils.functional import SimpleLazyObject @@ -15,6 +16,7 @@ from ..fields import DjangoConnectionField from ..types import DjangoObjectType from ..utils import DJANGO_FILTER_INSTALLED +from .async_test_helper import assert_async_result_equal from .models import ( APNewsReporter, Article, @@ -43,6 +45,7 @@ class Meta: """ result = schema.execute(query) assert not result.errors + assert_async_result_equal(schema, query, result) def test_should_query_simplelazy_objects(): @@ -68,6 +71,7 @@ def resolve_reporter(self, info): result = schema.execute(query) assert not result.errors assert result.data == {"reporter": {"id": "1"}} + assert_async_result_equal(schema, query, result) def test_should_query_wrapped_simplelazy_objects(): @@ -93,6 +97,7 @@ def resolve_reporter(self, info): result = schema.execute(query) assert not result.errors assert result.data == {"reporter": {"id": "1"}} + assert_async_result_equal(schema, query, result) def test_should_query_well(): @@ -121,6 +126,7 @@ def resolve_reporter(self, info): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) @pytest.mark.skipif(IntegerRangeField is MissingType, reason="RangeField should exist") @@ -175,6 +181,7 @@ def resolve_event(self, info): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_node(): @@ -256,6 +263,7 @@ def resolve_reporter(self, info): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_onetoone_fields(): @@ -314,6 +322,7 @@ def resolve_film_details(root, info): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_connectionfields(): @@ -352,6 +361,7 @@ def resolve_all_reporters(self, info, **args): "edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}], } } + assert_async_result_equal(schema, query, result) def test_should_keep_annotations(): @@ -411,6 +421,7 @@ def resolve_all_articles(self, info, **args): """ result = schema.execute(query) assert not result.errors + assert_async_result_equal(schema, query, result) @pytest.mark.skipif( @@ -492,6 +503,7 @@ class Query(graphene.ObjectType): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) @pytest.mark.skipif( @@ -537,6 +549,7 @@ def resolve_films(self, info, **args): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) @pytest.mark.skipif( @@ -626,6 +639,7 @@ class Query(graphene.ObjectType): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_enforce_first_or_last(graphene_settings): @@ -666,6 +680,7 @@ class Query(graphene.ObjectType): "paginate the `allReporters` connection.\n" ) assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_error_if_first_is_greater_than_max(graphene_settings): @@ -708,6 +723,7 @@ class Query(graphene.ObjectType): "exceeds the `first` limit of 100 records.\n" ) assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_error_if_last_is_greater_than_max(graphene_settings): @@ -750,6 +766,7 @@ class Query(graphene.ObjectType): "exceeds the `last` limit of 100 records.\n" ) assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_promise_connectionfields(): @@ -785,6 +802,7 @@ def resolve_all_reporters(self, info, **args): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_connectionfields_with_last(): @@ -822,6 +840,7 @@ def resolve_all_reporters(self, info, **args): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_connectionfields_with_manager(): @@ -863,6 +882,7 @@ def resolve_all_reporters(self, info, **args): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_dataloader_fields(): @@ -965,6 +985,106 @@ class Query(graphene.ObjectType): assert result.data == expected +def test_should_query_dataloader_fields_async(): + from promise import Promise + from promise.dataloader import DataLoader + + def article_batch_load_fn(keys): + queryset = Article.objects.filter(reporter_id__in=keys) + return Promise.resolve( + [ + [article for article in queryset if article.reporter_id == id] + for id in keys + ] + ) + + article_loader = DataLoader(article_batch_load_fn) + + class ArticleType(DjangoObjectType): + class Meta: + model = Article + interfaces = (Node,) + fields = "__all__" + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + articles = DjangoConnectionField(ArticleType) + + def resolve_articles(self, info, **args): + return article_loader.load(self.id).get() + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + r = Reporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + + Article.objects.create( + headline="Article Node 1", + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + reporter=r, + editor=r, + lang="es", + ) + Article.objects.create( + headline="Article Node 2", + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + reporter=r, + editor=r, + lang="en", + ) + + schema = graphene.Schema(query=Query) + query = """ + query ReporterPromiseConnectionQuery { + allReporters(first: 1) { + edges { + node { + id + articles(first: 2) { + edges { + node { + headline + } + } + } + } + } + } + } + """ + + expected = { + "allReporters": { + "edges": [ + { + "node": { + "id": "UmVwb3J0ZXJUeXBlOjE=", + "articles": { + "edges": [ + {"node": {"headline": "Article Node 1"}}, + {"node": {"headline": "Article Node 2"}}, + ] + }, + } + } + ] + } + } + + result = async_to_sync(schema.execute_async)(query) + assert not result.errors + assert result.data == expected + + def test_should_handle_inherited_choices(): class BaseModel(models.Model): choice_field = models.IntegerField(choices=((0, "zero"), (1, "one"))) @@ -1071,6 +1191,7 @@ class Query(graphene.ObjectType): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_model_inheritance_support_reverse_relationships(): @@ -1411,6 +1532,7 @@ class Query(graphene.ObjectType): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_connection_should_limit_after_to_list_length(): @@ -1448,6 +1570,7 @@ class Query(graphene.ObjectType): expected = {"allReporters": {"edges": []}} assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values={"after": after}) REPORTERS = [ @@ -1491,6 +1614,7 @@ class Query(graphene.ObjectType): result = schema.execute(query) assert not result.errors assert len(result.data["allReporters"]["edges"]) == 4 + assert_async_result_equal(schema, query, result) def test_should_have_next_page(graphene_settings): @@ -1529,6 +1653,7 @@ class Query(graphene.ObjectType): assert not result.errors assert len(result.data["allReporters"]["edges"]) == 4 assert result.data["allReporters"]["pageInfo"]["hasNextPage"] + assert_async_result_equal(schema, query, result, variable_values={}) last_result = result.data["allReporters"]["pageInfo"]["endCursor"] result2 = schema.execute(query, variable_values={"first": 4, "after": last_result}) @@ -1542,6 +1667,9 @@ class Query(graphene.ObjectType): assert {to_global_id("ReporterType", reporter.id) for reporter in db_reporters} == { gql_reporter["node"]["id"] for gql_reporter in gql_reporters } + assert_async_result_equal( + schema, query, result2, variable_values={"first": 4, "after": last_result} + ) @pytest.mark.parametrize("max_limit", [100, 4]) @@ -1565,7 +1693,7 @@ class Query(graphene.ObjectType): def test_query_last(self, graphene_settings, max_limit): schema = self.setup_schema(graphene_settings, max_limit=max_limit) - query_last = """ + query = """ query { allReporters(last: 3) { edges { @@ -1577,16 +1705,17 @@ def test_query_last(self, graphene_settings, max_limit): } """ - result = schema.execute(query_last) + result = schema.execute(query) assert not result.errors assert len(result.data["allReporters"]["edges"]) == 3 assert [ e["node"]["firstName"] for e in result.data["allReporters"]["edges"] ] == ["First 3", "First 4", "First 5"] + assert_async_result_equal(schema, query, result) def test_query_first_and_last(self, graphene_settings, max_limit): schema = self.setup_schema(graphene_settings, max_limit=max_limit) - query_first_and_last = """ + query = """ query { allReporters(first: 4, last: 3) { edges { @@ -1598,12 +1727,13 @@ def test_query_first_and_last(self, graphene_settings, max_limit): } """ - result = schema.execute(query_first_and_last) + result = schema.execute(query) assert not result.errors assert len(result.data["allReporters"]["edges"]) == 3 assert [ e["node"]["firstName"] for e in result.data["allReporters"]["edges"] ] == ["First 1", "First 2", "First 3"] + assert_async_result_equal(schema, query, result) def test_query_first_last_and_after(self, graphene_settings, max_limit): schema = self.setup_schema(graphene_settings, max_limit=max_limit) @@ -1629,6 +1759,9 @@ def test_query_first_last_and_after(self, graphene_settings, max_limit): assert [ e["node"]["firstName"] for e in result.data["allReporters"]["edges"] ] == ["First 2", "First 3", "First 4"] + assert_async_result_equal( + schema, query_first_last_and_after, result, variable_values={"after": after} + ) def test_query_last_and_before(self, graphene_settings, max_limit): schema = self.setup_schema(graphene_settings, max_limit=max_limit) @@ -1650,6 +1783,7 @@ def test_query_last_and_before(self, graphene_settings, max_limit): assert not result.errors assert len(result.data["allReporters"]["edges"]) == 1 assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 5" + assert_async_result_equal(schema, query_first_last_and_after, result) before = base64.b64encode(b"arrayconnection:5").decode() result = schema.execute( @@ -1659,6 +1793,12 @@ def test_query_last_and_before(self, graphene_settings, max_limit): assert not result.errors assert len(result.data["allReporters"]["edges"]) == 1 assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 4" + assert_async_result_equal( + schema, + query_first_last_and_after, + result, + variable_values={"before": before}, + ) def test_should_preserve_prefetch_related(django_assert_num_queries): @@ -1713,6 +1853,7 @@ def resolve_films(root, info, **kwargs): with django_assert_num_queries(3): result = schema.execute(query) assert not result.errors + assert_async_result_equal(schema, query, result) def test_should_preserve_annotations(): @@ -1768,6 +1909,7 @@ def resolve_films(root, info, **kwargs): } assert result.data == expected, str(result.data) assert not result.errors + assert_async_result_equal(schema, query, result) def test_connection_should_enable_offset_filtering(): @@ -1807,6 +1949,7 @@ class Query(graphene.ObjectType): } } assert result.data == expected + assert_async_result_equal(schema, query, result) def test_connection_should_enable_offset_filtering_higher_than_max_limit( @@ -1851,6 +1994,7 @@ class Query(graphene.ObjectType): } } assert result.data == expected + assert_async_result_equal(schema, query, result) def test_connection_should_forbid_offset_filtering_with_before(): @@ -1881,6 +2025,7 @@ class Query(graphene.ObjectType): expected_error = "You can't provide a `before` value at the same time as an `offset` value to properly paginate the `allReporters` connection." assert len(result.errors) == 1 assert result.errors[0].message == expected_error + assert_async_result_equal(schema, query, result, variable_values={"before": before}) def test_connection_should_allow_offset_filtering_with_after(): @@ -1923,6 +2068,7 @@ class Query(graphene.ObjectType): } } assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values={"after": after}) def test_connection_should_succeed_if_last_higher_than_number_of_objects(): @@ -1953,6 +2099,7 @@ class Query(graphene.ObjectType): assert not result.errors expected = {"allReporters": {"edges": []}} assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values={"last": 2}) Reporter.objects.create(first_name="John", last_name="Doe") Reporter.objects.create(first_name="Some", last_name="Guy") @@ -1970,6 +2117,7 @@ class Query(graphene.ObjectType): } } assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values={"last": 2}) result = schema.execute(query, variable_values={"last": 4}) assert not result.errors @@ -1984,6 +2132,7 @@ class Query(graphene.ObjectType): } } assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values={"last": 4}) result = schema.execute(query, variable_values={"last": 20}) assert not result.errors @@ -1998,6 +2147,7 @@ class Query(graphene.ObjectType): } } assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values={"last": 20}) def test_should_query_nullable_foreign_key(): diff --git a/graphene_django/types.py b/graphene_django/types.py index e310fe47..ce46b968 100644 --- a/graphene_django/types.py +++ b/graphene_django/types.py @@ -16,6 +16,7 @@ DJANGO_FILTER_INSTALLED, camelize, get_model_fields, + is_running_async, is_valid_django_model, ) @@ -288,7 +289,11 @@ def get_queryset(cls, queryset, info): def get_node(cls, info, id): queryset = cls.get_queryset(cls._meta.model.objects, info) try: + if is_running_async(): + return queryset.aget(pk=id) + return queryset.get(pk=id) + except cls._meta.model.DoesNotExist: return None diff --git a/graphene_django/utils/__init__.py b/graphene_django/utils/__init__.py index a64ee36a..609da967 100644 --- a/graphene_django/utils/__init__.py +++ b/graphene_django/utils/__init__.py @@ -5,6 +5,8 @@ camelize, get_model_fields, get_reverse_fields, + is_running_async, + is_sync_function, is_valid_django_model, maybe_queryset, ) @@ -17,5 +19,7 @@ "camelize", "is_valid_django_model", "GraphQLTestCase", + "is_sync_function", + "is_running_async", "bypass_get_queryset", ] diff --git a/graphene_django/utils/utils.py b/graphene_django/utils/utils.py index 364eff9b..edac2e33 100644 --- a/graphene_django/utils/utils.py +++ b/graphene_django/utils/utils.py @@ -1,4 +1,5 @@ import inspect +from asyncio import get_running_loop import django from django.db import connection, models, transaction @@ -139,6 +140,21 @@ def set_rollback(): transaction.set_rollback(True) +def is_running_async(): + try: + get_running_loop() + except RuntimeError: + return False + else: + return True + + +def is_sync_function(func): + return not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction( + func + ) + + def bypass_get_queryset(resolver): """ Adds a bypass_get_queryset attribute to the resolver, which is used to diff --git a/graphene_django/views.py b/graphene_django/views.py index 1ec65988..bdc483e1 100644 --- a/graphene_django/views.py +++ b/graphene_django/views.py @@ -1,12 +1,14 @@ import inspect import json import re +import traceback +from asyncio import coroutines, gather from django.db import connection, transaction from django.http import HttpResponse, HttpResponseNotAllowed from django.http.response import HttpResponseBadRequest from django.shortcuts import render -from django.utils.decorators import method_decorator +from django.utils.decorators import classonlymethod, method_decorator from django.views.decorators.csrf import ensure_csrf_cookie from django.views.generic import View from graphql import ( @@ -431,3 +433,336 @@ def get_content_type(request): meta = request.META content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", "")) return content_type.split(";", 1)[0].lower() + + +class AsyncGraphQLView(GraphQLView): + schema = None + graphiql = False + middleware = None + root_value = None + pretty = False + batch = False + subscription_path = None + execution_context_class = None + + def __init__( + self, + schema=None, + middleware=None, + root_value=None, + graphiql=False, + pretty=False, + batch=False, + subscription_path=None, + execution_context_class=None, + ): + if not schema: + schema = graphene_settings.SCHEMA + + if middleware is None: + middleware = graphene_settings.MIDDLEWARE + + self.schema = self.schema or schema + if middleware is not None: + if isinstance(middleware, MiddlewareManager): + self.middleware = middleware + else: + self.middleware = list(instantiate_middleware(middleware)) + self.root_value = root_value + self.pretty = self.pretty or pretty + self.graphiql = self.graphiql or graphiql + self.batch = self.batch or batch + self.execution_context_class = execution_context_class + if subscription_path is None: + self.subscription_path = graphene_settings.SUBSCRIPTION_PATH + + assert isinstance( + self.schema, Schema + ), "A Schema is required to be provided to GraphQLView." + assert not all((graphiql, batch)), "Use either graphiql or batch processing" + + # noinspection PyUnusedLocal + def get_root_value(self, request): + return self.root_value + + def get_middleware(self, request): + return self.middleware + + def get_context(self, request): + return request + + @classonlymethod + def as_view(cls, **initkwargs): + view = super().as_view(**initkwargs) + view._is_coroutine = coroutines._is_coroutine + return view + + async def dispatch(self, request, *args, **kwargs): + try: + if request.method.lower() not in ("get", "post"): + raise HttpError( + HttpResponseNotAllowed( + ["GET", "POST"], "GraphQL only supports GET and POST requests." + ) + ) + + data = self.parse_body(request) + show_graphiql = self.graphiql and self.can_display_graphiql(request, data) + + if show_graphiql: + return self.render_graphiql( + request, + # Dependency parameters. + whatwg_fetch_version=self.whatwg_fetch_version, + whatwg_fetch_sri=self.whatwg_fetch_sri, + react_version=self.react_version, + react_sri=self.react_sri, + react_dom_sri=self.react_dom_sri, + graphiql_version=self.graphiql_version, + graphiql_sri=self.graphiql_sri, + graphiql_css_sri=self.graphiql_css_sri, + subscriptions_transport_ws_version=self.subscriptions_transport_ws_version, + subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri, + graphiql_plugin_explorer_version=self.graphiql_plugin_explorer_version, + graphiql_plugin_explorer_sri=self.graphiql_plugin_explorer_sri, + # The SUBSCRIPTION_PATH setting. + subscription_path=self.subscription_path, + # GraphiQL headers tab, + graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED, + graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS, + ) + + if self.batch: + responses = await gather( + *[self.get_response(request, entry) for entry in data] + ) + result = "[{}]".format( + ",".join([response[0] for response in responses]) + ) + status_code = ( + responses + and max(responses, key=lambda response: response[1])[1] + or 200 + ) + else: + result, status_code = await self.get_response( + request, data, show_graphiql + ) + + return HttpResponse( + status=status_code, content=result, content_type="application/json" + ) + + except HttpError as e: + response = e.response + response["Content-Type"] = "application/json" + response.content = self.json_encode( + request, {"errors": [self.format_error(e)]} + ) + return response + + async def get_response(self, request, data, show_graphiql=False): + query, variables, operation_name, id = self.get_graphql_params(request, data) + + execution_result = await self.execute_graphql_request( + request, data, query, variables, operation_name, show_graphiql + ) + + if getattr(request, MUTATION_ERRORS_FLAG, False) is True: + set_rollback() + + status_code = 200 + if execution_result: + response = {} + + if execution_result.errors: + for e in execution_result.errors: + print(e) + traceback.print_tb(e.__traceback__) + set_rollback() + response["errors"] = [ + self.format_error(e) for e in execution_result.errors + ] + + if execution_result.errors and any( + not getattr(e, "path", None) for e in execution_result.errors + ): + status_code = 400 + else: + response["data"] = execution_result.data + + if self.batch: + response["id"] = id + response["status"] = status_code + + result = self.json_encode(request, response, pretty=show_graphiql) + else: + result = None + + return result, status_code + + def render_graphiql(self, request, **data): + return render(request, self.graphiql_template, data) + + def json_encode(self, request, d, pretty=False): + if not (self.pretty or pretty) and not request.GET.get("pretty"): + return json.dumps(d, separators=(",", ":")) + + return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": ")) + + def parse_body(self, request): + content_type = self.get_content_type(request) + + if content_type == "application/graphql": + return {"query": request.body.decode()} + + elif content_type == "application/json": + # noinspection PyBroadException + try: + body = request.body.decode("utf-8") + except Exception as e: + raise HttpError(HttpResponseBadRequest(str(e))) + + try: + request_json = json.loads(body) + if self.batch: + assert isinstance(request_json, list), ( + "Batch requests should receive a list, but received {}." + ).format(repr(request_json)) + assert ( + len(request_json) > 0 + ), "Received an empty list in the batch request." + else: + assert isinstance( + request_json, dict + ), "The received data is not a valid JSON query." + return request_json + except AssertionError as e: + raise HttpError(HttpResponseBadRequest(str(e))) + except (TypeError, ValueError): + raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON.")) + + elif content_type in [ + "application/x-www-form-urlencoded", + "multipart/form-data", + ]: + return request.POST + + return {} + + async def execute_graphql_request( + self, request, data, query, variables, operation_name, show_graphiql=False + ): + if not query: + if show_graphiql: + return None + raise HttpError(HttpResponseBadRequest("Must provide query string.")) + + try: + document = parse(query) + except Exception as e: + return ExecutionResult(errors=[e]) + + if request.method.lower() == "get": + operation_ast = get_operation_ast(document, operation_name) + if operation_ast and operation_ast.operation != OperationType.QUERY: + if show_graphiql: + return None + + raise HttpError( + HttpResponseNotAllowed( + ["POST"], + "Can only perform a {} operation from a POST request.".format( + operation_ast.operation.value + ), + ) + ) + + try: + extra_options = {} + if self.execution_context_class: + extra_options["execution_context_class"] = self.execution_context_class + + options = { + "source": query, + "root_value": self.get_root_value(request), + "variable_values": variables, + "operation_name": operation_name, + "context_value": self.get_context(request), + "middleware": self.get_middleware(request), + } + options.update(extra_options) + + operation_ast = get_operation_ast(document, operation_name) + if ( + operation_ast + and operation_ast.operation == OperationType.MUTATION + and ( + graphene_settings.ATOMIC_MUTATIONS is True + or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True + ) + ): + with transaction.atomic(): + result = await self.schema.execute_async(**options) + if getattr(request, MUTATION_ERRORS_FLAG, False) is True: + transaction.set_rollback(True) + return result + + return await self.schema.execute_async(**options) + except Exception as e: + return ExecutionResult(errors=[e]) + + @classmethod + def can_display_graphiql(cls, request, data): + raw = "raw" in request.GET or "raw" in data + return not raw and cls.request_wants_html(request) + + @classmethod + def request_wants_html(cls, request): + accepted = get_accepted_content_types(request) + accepted_length = len(accepted) + # the list will be ordered in preferred first - so we have to make + # sure the most preferred gets the highest number + html_priority = ( + accepted_length - accepted.index("text/html") + if "text/html" in accepted + else 0 + ) + json_priority = ( + accepted_length - accepted.index("application/json") + if "application/json" in accepted + else 0 + ) + + return html_priority > json_priority + + @staticmethod + def get_graphql_params(request, data): + query = request.GET.get("query") or data.get("query") + variables = request.GET.get("variables") or data.get("variables") + id = request.GET.get("id") or data.get("id") + + if variables and isinstance(variables, str): + try: + variables = json.loads(variables) + except Exception: + raise HttpError(HttpResponseBadRequest("Variables are invalid JSON.")) + + operation_name = request.GET.get("operationName") or data.get("operationName") + if operation_name == "null": + operation_name = None + + return query, variables, operation_name, id + + @staticmethod + def format_error(error): + if isinstance(error, GraphQLError): + return error.formatted + + return {"message": str(error)} + + @staticmethod + def get_content_type(request): + meta = request.META + content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", "")) + return content_type.split(";", 1)[0].lower() diff --git a/setup.py b/setup.py index 2c07aba2..05466801 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ "pytz", "django-filter>=22.1", "pytest-django>=4.5.2", + "pytest-asyncio>=0.16,<2", ] + rest_framework_require