Graphene Federation v2 Support Added
arunsureshkumar authored and arun-sureshkumar committed Oct 6, 2022
2 parents 889eda8 + 4a44e2a commit bb72bbb
Showing 6 changed files with 198 additions and 80 deletions.
6 changes: 6 additions & 0 deletions graphene_mongo/advanced_types.py
@@ -9,6 +9,9 @@ class FileFieldType(graphene.ObjectType):
length = graphene.Int()
data = graphene.String()

# Support Graphene Federation v2
_shareable = True

@classmethod
def _resolve_fs_field(cls, field, name, default_value=None):
v = getattr(field.instance, field.key)
@@ -37,6 +40,9 @@ def resolve_data(self, info):
class _CoordinatesTypeField(graphene.ObjectType):
type = graphene.String()

# Support Graphene Federation v2
_shareable = True

def resolve_type(self, info):
return self["type"]

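For context, a minimal sketch (not part of the diff) of the pattern the two hunks above introduce: setting _shareable = True on a plain graphene ObjectType, presumably so that graphene-federation v2 can mark the type with the @shareable directive in the generated SDL. The type and field names below are illustrative.

import graphene

# Illustrative only: mirrors the change made to FileFieldType and
# _CoordinatesTypeField in this commit.
class ThumbnailType(graphene.ObjectType):
    content_type = graphene.String()
    length = graphene.Int()

    # Support Graphene Federation v2
    _shareable = True
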
183 changes: 143 additions & 40 deletions graphene_mongo/fields.py
@@ -3,22 +3,24 @@
from collections import OrderedDict
from functools import partial, reduce

import bson
import graphene
import mongoengine
from bson import DBRef, ObjectId
from graphene import Context
from graphene.types.utils import get_type
from graphene.utils.str_converters import to_snake_case
from graphql import GraphQLResolveInfo
from mongoengine.base import get_document
from promise import Promise
from graphql_relay import from_global_id
from graphene.relay import ConnectionField
from graphene.types.argument import to_arguments
from graphene.types.dynamic import Dynamic
from graphene.types.structures import Structure
from graphql_relay.connection.array_connection import cursor_to_offset
from graphene.types.utils import get_type
from graphene.utils.str_converters import to_snake_case
from graphql import GraphQLResolveInfo
from graphql_relay import from_global_id
from graphql_relay.connection.arrayconnection import cursor_to_offset
from mongoengine import QuerySet
from mongoengine.base import get_document
from promise import Promise
from pymongo.errors import OperationFailure

from .advanced_types import (
FileFieldType,
@@ -30,6 +32,9 @@
from .registry import get_global_registry
from .utils import get_model_reference_fields, get_query_fields, find_skip_and_limit, \
connection_from_iterables
import pymongo

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])


class MongoengineConnectionField(ConnectionField):
@@ -77,9 +82,27 @@ def registry(self):

@property
def args(self):
_field_args = self.field_args
_advance_args = self.advance_args
_filter_args = self.filter_args
_extended_args = self.extended_args
if self._type._meta.non_filter_fields:
for _field in self._type._meta.non_filter_fields:
if _field in _field_args:
_field_args.pop(_field)
if _field in _advance_args:
_advance_args.pop(_field)
if _field in _filter_args:
_filter_args.pop(_field)
if _field in _extended_args:
_extended_args.pop(_field)
extra_args = dict(dict(dict(_field_args, **_advance_args), **_filter_args), **_extended_args)

for key in list(self._base_args.keys()):
extra_args.pop(key, None)
return to_arguments(
self._base_args or OrderedDict(),
dict(dict(dict(self.field_args, **self.advance_args), **self.filter_args), **self.extended_args),
extra_args
)

@args.setter
@@ -100,6 +123,14 @@ def is_filterable(k):
return False
if not hasattr(self.model, k):
return False
else:
# The else branch is a workaround for an error raised on federated types (those whose SDL contains @key)
field_ = self.fields[k]
type_ = field_.type
while hasattr(type_, "of_type"):
type_ = type_.of_type
if hasattr(type_, "_sdl") and "@key" in type_._sdl:
return False
if isinstance(getattr(self.model, k), property):
return False
try:
@@ -128,6 +159,9 @@ def is_filterable(k):
getattr(converted, "_of_type", None), graphene.Union
):
return False
# Workaround: skip DB-filterable fields that have been redefined as a custom graphene type
if hasattr(field_, 'type') and hasattr(converted, 'type') and converted.type != field_.type:
return False
return True

def get_filter_type(_type):
@@ -150,7 +184,7 @@ def filter_args(self):
if self._type._meta.filter_fields:
for field, filter_collection in self._type._meta.filter_fields.items():
for each in filter_collection:
if str(self._type._meta.fields[field].type) == 'PointFieldType':
if str(self._type._meta.fields[field].type) in ('PointFieldType', 'PointFieldType!'):
if each == 'max_distance':
filter_type = graphene.Int
else:
@@ -279,17 +313,17 @@ def get_queryset(self, model, info, required_fields=None, skip=None, limit=None,
skip)
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)

def default_resolver(self, _root, info, required_fields=None, **args):
def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
if required_fields is None:
required_fields = list()
args = args or {}
for key, value in dict(args).items():
if value is None:
del args[key]
if _root is not None:
if _root is not None and not resolved:
field_name = to_snake_case(info.field_name)
if not hasattr(_root, "_fields_ordered"):
if getattr(_root, field_name, []) is not None:
if isinstance(getattr(_root, field_name, []), list):
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]
elif field_name in _root._fields_ordered and not (isinstance(_root._fields[field_name].field,
mongoengine.EmbeddedDocumentField) or
@@ -316,25 +350,33 @@ def default_resolver(self, _root, info, required_fields=None, **args):
before = args.pop("before", None)
if before:
before = cursor_to_offset(before)
if callable(getattr(self.model, "objects", None)):
if "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
count=count)
if limit:
if reverse:
args["pk__in"] = args["pk__in"][::-1][skip:skip + limit]
else:
args["pk__in"] = args["pk__in"][skip:skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
list_length = len(iterables)
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info = info._replace(context=Context())
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
elif _root is None or args:

if resolved is not None:
items = resolved

if isinstance(items, QuerySet):
try:
count = items.count(with_limit_and_skip=True)
except OperationFailure:
count = len(items)
else:
count = len(items)

skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
count=count)

if limit:
if reverse:
items = items[::-1][skip:skip + limit]
else:
items = items[skip:skip + limit]
elif skip:
items = items[skip:]
iterables = items
list_length = len(iterables)

elif callable(getattr(self.model, "objects", None)):
if _root is None or args or isinstance(getattr(_root, field_name, []), MongoengineConnectionField):
args_copy = args.copy()
for key in args.copy():
if key not in self.model._fields_ordered:
@@ -346,8 +388,20 @@ def default_resolver(self, _root, info, required_fields=None, **args):
mongoengine.fields.LazyReferenceField) or isinstance(getattr(self.model, key),
mongoengine.fields.CachedReferenceField):
if not isinstance(args_copy[key], ObjectId):
args_copy[key] = from_global_id(args_copy[key])[1]
count = mongoengine.get_db()[self.model._get_collection_name()].count_documents(args_copy)
_from_global_id = from_global_id(args_copy[key])[1]
if bson.objectid.ObjectId.is_valid(_from_global_id):
args_copy[key] = ObjectId(_from_global_id)
else:
args_copy[key] = _from_global_id
elif isinstance(getattr(self.model, key),
mongoengine.fields.EnumField):
if getattr(args_copy[key], "value", None):
args_copy[key] = args_copy[key].value

if PYMONGO_VERSION >= (3, 7):
count = (mongoengine.get_db()[self.model._get_collection_name()]).count_documents(args_copy)
else:
count = self.model.objects(args_copy).count()
if count != 0:
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
count=count)
@@ -358,6 +412,24 @@ def default_resolver(self, _root, info, required_fields=None, **args):
info = info._replace(context=Context())
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)

elif "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
count=count)
if limit:
if reverse:
args["pk__in"] = args["pk__in"][::-1][skip:skip + limit]
else:
args["pk__in"] = args["pk__in"][skip:skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
list_length = len(iterables)
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info = info._replace(context=Context())
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)

elif _root is not None:
field_name = to_snake_case(info.field_name)
items = getattr(_root, field_name, [])
@@ -373,6 +445,7 @@ def default_resolver(self, _root, info, required_fields=None, **args):
items = items[skip:]
iterables = items
list_length = len(iterables)

has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False
has_previous_page = True if skip else False
if reverse:
@@ -391,31 +464,42 @@ def default_resolver(self, _root, info, required_fields=None, **args):
return connection

def chained_resolver(self, resolver, is_partial, root, info, **args):

for key, value in dict(args).items():
if value is None:
del args[key]

required_fields = list()

for field in self.required_fields:
if field in self.model._fields_ordered:
required_fields.append(field)

for field in get_query_fields(info):
if to_snake_case(field) in self.model._fields_ordered:
required_fields.append(to_snake_case(field))

args_copy = args.copy()

if not bool(args) or not is_partial:
if isinstance(self.model, mongoengine.Document) or isinstance(self.model,
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):

from itertools import filterfalse
connection_fields = [field for field in self.fields if
type(self.fields[field]) == MongoengineConnectionField]
filterable_args = tuple(filterfalse(connection_fields.__contains__, list(self.model._fields_ordered)))
for arg_name, arg in args.copy().items():
if arg_name not in self.model._fields_ordered + tuple(self.filter_args.keys()):
if arg_name not in filterable_args + tuple(self.filter_args.keys()):
args_copy.pop(arg_name)
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info = info._replace(context=Context())
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args_copy)

# XXX: Filter nested args
resolved = resolver(root, info, **args)

if resolved is not None:
if isinstance(resolved, list):
if resolved == list():
@@ -428,36 +512,55 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
args.update(resolved._query)
args_copy = args.copy()
for arg_name, arg in args.copy().items():
if arg_name not in self.model._fields_ordered + ('first', 'last', 'before', 'after') + tuple(
self.filter_args.keys()):
if "." in arg_name or arg_name not in self.model._fields_ordered + (
'first', 'last', 'before', 'after') + tuple(
self.filter_args.keys()):
args_copy.pop(arg_name)
if arg_name == '_id' and isinstance(arg, dict):
operation = list(arg.keys())[0]
args_copy['pk' + operation.replace('$', '__')] = arg[operation]
if not isinstance(arg, ObjectId) and '.' in arg_name:
operation = list(arg.keys())[0]
args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[operation]
if type(arg) == dict:
operation = list(arg.keys())[0]
args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[
operation]
else:
args_copy[arg_name.replace('.', '__')] = arg
elif '.' in arg_name and isinstance(arg, ObjectId):
args_copy[arg_name.replace('.', '__')] = arg
else:
operations = ["$lte", "$gte", "$ne", "$in"]
if isinstance(arg, dict) and any(op in arg for op in operations):
operation = list(arg.keys())[0]
args_copy[arg_name + operation.replace('$', '__')] = arg[operation]
del args_copy[arg_name]
return self.default_resolver(root, info, required_fields, **args_copy)
return self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy)
elif isinstance(resolved, Promise):
return resolved.value
else:
return resolved

return self.default_resolver(root, info, required_fields, **args)

@classmethod
def connection_resolver(cls, resolver, connection_type, root, info, **args):
if root:
for key, value in root.__dict__.items():
if value:
try:
setattr(root, key, from_global_id(value)[1])
except Exception as error:
pass
iterable = resolver(root, info, **args)

if isinstance(connection_type, graphene.NonNull):
connection_type = connection_type.of_type

on_resolve = partial(cls.resolve_connection, connection_type, args)

if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)

return on_resolve(iterable)

def get_resolver(self, parent_resolver):
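To make the resolver changes above concrete, here is a hedged sketch of the usage they appear to target: a field resolver that returns a mongoengine QuerySet, which default_resolver now receives through the new resolved= parameter, counts (count_documents on pymongo 3.7+, .count() otherwise), and slices according to first/last/after/before instead of re-querying. The model and schema are illustrative, and the non_filter_fields Meta option shown assumes the companion type change (in one of the files not displayed here) accepts it.

import graphene
import mongoengine
from graphene_mongo import MongoengineConnectionField, MongoengineObjectType


class Post(mongoengine.Document):
    title = mongoengine.StringField()
    content = mongoengine.StringField()
    published = mongoengine.BooleanField(default=False)


class PostType(MongoengineObjectType):
    class Meta:
        model = Post
        interfaces = (graphene.relay.Node,)
        # Assumed Meta option read by the new args property: these fields are
        # dropped from the generated filter arguments.
        non_filter_fields = ["content"]


class Query(graphene.ObjectType):
    published_posts = MongoengineConnectionField(PostType)

    def resolve_published_posts(root, info, **kwargs):
        # Returning a QuerySet lets chained_resolver pass it to
        # default_resolver as resolved=..., where it is counted and paginated.
        return Post.objects(published=True)


schema = graphene.Schema(query=Query)
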
5 changes: 5 additions & 0 deletions graphene_mongo/registry.py
@@ -29,6 +29,11 @@ def register_enum(self, cls):
assert type(cls) == EnumMeta, 'Only EnumMeta can be registered, received "{}"'.format(
cls.__name__
)
if not cls.__name__.endswith('Enum'):
name = cls.__name__ + 'Enum'
else:
name = cls.__name__
cls.__name__ = name
self._registry_enum[cls] = Enum.from_enum(cls)

def get_type_for_model(self, model):
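An illustration, with an assumed Color enum, of the renaming behaviour register_enum now applies before converting an enum: a class whose name does not already end in "Enum" gets the suffix appended, so the resulting GraphQL enum is exposed under that suffixed name.

import enum

import graphene


class Color(enum.Enum):
    RED = 1
    BLUE = 2


# Mirrors the new register_enum logic: append "Enum" when the name lacks it.
if not Color.__name__.endswith("Enum"):
    Color.__name__ = Color.__name__ + "Enum"

ColorEnum = graphene.Enum.from_enum(Color)
# ColorEnum is now exposed to GraphQL as "ColorEnum" rather than "Color".
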
(The diffs for the remaining three changed files are not shown.)