diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 367ad6397..c2a2a8fd4 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -6,6 +6,7 @@ from graphene.relay import ConnectionField, PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice +from .settings import graphene_settings from .utils import DJANGO_FILTER_INSTALLED, maybe_queryset @@ -30,6 +31,14 @@ class DjangoConnectionField(ConnectionField): def __init__(self, *args, **kwargs): self.on = kwargs.pop('on', False) + self.max_limit = kwargs.pop( + 'max_limit', + graphene_settings.RELAY_CONNECTION_MAX_LIMIT + ) + self.enforce_first_or_last = kwargs.pop( + 'enforce_first_or_last', + graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST + ) super(DjangoConnectionField, self).__init__(*args, **kwargs) @property @@ -51,7 +60,29 @@ def merge_querysets(cls, default_queryset, queryset): return default_queryset & queryset @classmethod - def connection_resolver(cls, resolver, connection, default_manager, root, args, context, info): + def connection_resolver(cls, resolver, connection, default_manager, max_limit, + enforce_first_or_last, root, args, context, info): + first = args.get('first') + last = args.get('last') + + if enforce_first_or_last: + assert first or last, ( + 'You must provide a `first` or `last` value to properly paginate the `{}` connection.' + ).format(info.field_name) + + if max_limit: + if first: + assert first <= max_limit, ( + 'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' + ).format(first, info.field_name, max_limit) + args['first'] = min(first, max_limit) + + if last: + assert last <= max_limit, ( + 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' + ).format(first, info.field_name, max_limit) + args['last'] = min(last, max_limit) + iterable = resolver(root, args, context, info) if iterable is None: iterable = default_manager @@ -78,7 +109,14 @@ def connection_resolver(cls, resolver, connection, default_manager, root, args, return connection def get_resolver(self, parent_resolver): - return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager()) + return partial( + self.connection_resolver, + parent_resolver, + self.type, + self.get_manager(), + self.max_limit, + self.enforce_first_or_last + ) def get_connection_field(*args, **kwargs): diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index 061b2c652..fc414bf51 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -67,16 +67,35 @@ def merge_querysets(default_queryset, queryset): return queryset @classmethod - def connection_resolver(cls, resolver, connection, default_manager, filterset_class, filtering_args, + def connection_resolver(cls, resolver, connection, default_manager, max_limit, + enforce_first_or_last, filterset_class, filtering_args, root, args, context, info): filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} qs = filterset_class( data=filter_kwargs, queryset=default_manager.get_queryset() ).qs + return super(DjangoFilterConnectionField, cls).connection_resolver( - resolver, connection, qs, root, args, context, info) + resolver, + connection, + qs, + max_limit, + enforce_first_or_last, + root, + args, + context, + info + ) def get_resolver(self, parent_resolver): - return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager(), - self.filterset_class, self.filtering_args) + return partial( + self.connection_resolver, + parent_resolver, + self.type, + self.get_manager(), + self.max_limit, + self.enforce_first_or_last, + self.filterset_class, + self.filtering_args + ) diff --git a/graphene_django/settings.py b/graphene_django/settings.py index d83642a7b..46d70ee15 100644 --- a/graphene_django/settings.py +++ b/graphene_django/settings.py @@ -30,6 +30,11 @@ 'SCHEMA_OUTPUT': 'schema.json', 'SCHEMA_INDENT': None, 'MIDDLEWARE': (), + # Set to True if the connection fields must have + # either the first or last argument + 'RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST': False, + # Max items returned in ConnectionFields / FilterConnectionFields + 'RELAY_CONNECTION_MAX_LIMIT': 100, } if settings.DEBUG: diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 3594000a9..c1deebbcd 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -12,6 +12,7 @@ from ..compat import MissingType, JSONField from ..fields import DjangoConnectionField from ..types import DjangoObjectType +from ..settings import graphene_settings from .models import Article, Reporter pytestmark = pytest.mark.django_db @@ -452,3 +453,95 @@ class Query(graphene.ObjectType): result = schema.execute(query) assert not result.errors assert result.data == expected + + +def test_should_enforce_first_or_last(): + graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = True + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + r = Reporter.objects.create( + first_name='John', + last_name='Doe', + email='johndoe@example.com', + a_choice=1 + ) + + schema = graphene.Schema(query=Query) + query = ''' + query NodeFilteringQuery { + allReporters { + edges { + node { + id + } + } + } + } + ''' + + expected = { + 'allReporters': None + } + + result = schema.execute(query) + assert len(result.errors) == 1 + assert str(result.errors[0]) == ( + 'You must provide a `first` or `last` value to properly ' + 'paginate the `allReporters` connection.' + ) + assert result.data == expected + + +def test_should_error_if_first_is_greater_than_max(): + graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 100 + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + r = Reporter.objects.create( + first_name='John', + last_name='Doe', + email='johndoe@example.com', + a_choice=1 + ) + + schema = graphene.Schema(query=Query) + query = ''' + query NodeFilteringQuery { + allReporters(first: 101) { + edges { + node { + id + } + } + } + } + ''' + + expected = { + 'allReporters': None + } + + result = schema.execute(query) + assert len(result.errors) == 1 + assert str(result.errors[0]) == ( + 'Requesting 101 records on the `allReporters` connection ' + 'exceeds the `first` limit of 100 records.' + ) + assert result.data == expected + + graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False