-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add create, update and delete subscriptions
Use graphene-luna for testing purposes. Also add a generic "DjangoSignalSubscription" type that allows you to subscribe to any Django signal. Signed-off-by: Tormod Haugland <[email protected]>
- Loading branch information
Showing
9 changed files
with
694 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import graphene | ||
from graphql import GraphQLError | ||
|
||
|
||
class SubscriptionField(graphene.Field): | ||
""" | ||
This is an extension of the graphene.Field class that exists | ||
to allow our DjangoCudSubscriptionBase classes to pass a subscribe | ||
method to the Field instantiation, which we use here in the | ||
`wrap_subscribe` method. `wrap_subscribe` is called internally in graphene | ||
to figure out which resolver to use for a subscription field. | ||
""" | ||
|
||
def __init__(self, *args, subscribe=None, **kwargs): | ||
self.subscribe = subscribe | ||
super().__init__(*args, **kwargs) | ||
|
||
def wrap_subscribe(self, parent_subscribe): | ||
return self.subscribe | ||
|
||
|
||
class DjangoCudSubscriptionBase(graphene.ObjectType): | ||
"""Base class for DjangoCud subscriptions""" | ||
|
||
@classmethod | ||
def get_permissions(cls, root, info, *args, **kwargs): | ||
return cls._meta.permissions | ||
|
||
@classmethod | ||
def check_permissions(cls, root, info, *args, **kwargs) -> None: | ||
get_permissions = getattr(cls, "get_permissions", None) | ||
if not callable(get_permissions): | ||
raise TypeError("The `get_permissions` attribute of a subscription must be callable.") | ||
|
||
permissions = cls.get_permissions(root, info, *args, **kwargs) | ||
|
||
if permissions and len(permissions) > 0: | ||
if not info.context.user.has_perms(permissions): | ||
raise GraphQLError("Not permitted to access this subscription.") | ||
|
||
@classmethod | ||
def Field(cls, name=None, description=None, deprecation_reason=None, required=False): | ||
"""Create a field for the subscription that automatically creates a subscription resolver""" | ||
return SubscriptionField( | ||
cls._meta.output, | ||
resolver=cls._meta.resolver, | ||
subscribe=cls._meta.subscribe, | ||
name=name, | ||
description=description or cls._meta.description, | ||
deprecation_reason=deprecation_reason, | ||
required=required, | ||
) | ||
|
||
@classmethod | ||
async def subscribe(cls, *args, **kwargs): | ||
"""Dummy subscribe method. Must be implemented by subclasses""" | ||
raise NotImplementedError("`subscribe` must be implemented by the implementing subclass. " | ||
"This is likely a bug in graphene-django-cud.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import asyncio | ||
from collections import OrderedDict | ||
from typing import Optional | ||
|
||
import graphene | ||
from asgiref.sync import async_to_sync | ||
from django.conf import settings | ||
from django.db.models.signals import post_save | ||
from django.dispatch import Signal | ||
from graphene.types.objecttype import ObjectTypeOptions | ||
from graphene_django.registry import get_global_registry | ||
|
||
from graphene_django_cud.consts import USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS_KEY | ||
from graphene_django_cud.signals import post_create_mutation | ||
from graphene_django_cud.subscriptions.core import DjangoCudSubscriptionBase | ||
from graphene_django_cud.util import to_snake_case | ||
|
||
|
||
class DjangoCreateSubscriptionOptions(ObjectTypeOptions): | ||
model = None | ||
return_field_name = None | ||
permissions = None | ||
signal: Optional[Signal] = None | ||
|
||
|
||
class DjangoCreateSubscription(DjangoCudSubscriptionBase): | ||
# All active subscriptions are stored in this centralized dictionary. | ||
# We need to do this to keep track of which subscriptions are listening to | ||
# which signals. | ||
subscribers = {} | ||
|
||
@classmethod | ||
def __init_subclass_with_meta__( | ||
cls, | ||
_meta=None, | ||
model=None, | ||
permissions=None, | ||
return_field_name=None, | ||
signal=post_create_mutation if getattr( | ||
settings, | ||
USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS_KEY, | ||
False | ||
) else post_save, | ||
**kwargs, | ||
): | ||
registry = get_global_registry() | ||
model_type = registry.get_type_for_model(model) | ||
|
||
if not _meta: | ||
_meta = DjangoCreateSubscriptionOptions(cls) | ||
|
||
if not return_field_name: | ||
return_field_name = to_snake_case(model.__name__) | ||
|
||
output_fields = OrderedDict() | ||
output_fields[return_field_name] = graphene.Field(model_type) | ||
|
||
_meta.model = model | ||
_meta.model_type = model_type | ||
_meta.fields = output_fields | ||
_meta.output = cls | ||
_meta.permissions = permissions | ||
|
||
# Importantly, this needs to be set to either nothing or the identity. | ||
# Internally in graphene it will be defaulted to the identity function. If it | ||
# isn't, graphene will try to pass the value resolve from the "subscribe" method | ||
# through this resolver. If it is also set to "subscribe", we will get an issue with | ||
# graphene trying to return an AsyncIterator. | ||
_meta.resolver = None | ||
|
||
# This is set to be the subscription resolver in the SubscriptionField class. | ||
_meta.subscribe = cls.subscribe | ||
_meta.return_field_name = return_field_name | ||
|
||
# Connect to the model's post_save (or your custom) signal | ||
signal.connect(cls._model_created_handler, sender=model) | ||
|
||
super().__init_subclass_with_meta__(_meta=_meta, **kwargs) | ||
|
||
@classmethod | ||
def _model_created_handler(cls, sender, instance, created=None, **kwargs): | ||
"""Handle model creation and notify subscribers""" | ||
if created or created is None: | ||
print(sender, instance, created, kwargs) | ||
new_instance = cls.handle_object_created(sender, instance, **kwargs) | ||
|
||
assert new_instance is None or isinstance(new_instance, cls._meta.model) | ||
|
||
if new_instance: | ||
instance = new_instance | ||
|
||
# Notify all subscribers for the model | ||
for subscriber in cls.subscribers.get(sender, []): | ||
async_to_sync(subscriber)(instance) | ||
|
||
@classmethod | ||
def handle_object_created(cls, sender, instance, **kwargs): | ||
"""Handle and modify any instance created""" | ||
pass | ||
|
||
@classmethod | ||
def check_permissions(cls, root, info, *args, **kwargs) -> None: | ||
return super().check_permissions(root, info, *args, **kwargs) | ||
|
||
@classmethod | ||
async def subscribe(cls, root, info, *args, **kwargs): | ||
"""Subscribe to the model creation events asynchronously""" | ||
|
||
cls.check_permissions(root, info, *args, **kwargs) | ||
|
||
model = cls._meta.model | ||
queue = asyncio.Queue() | ||
|
||
# Ensure there's a list of subscribers for the model | ||
if model not in cls.subscribers: | ||
cls.subscribers[model] = [] | ||
|
||
# Add the queue's put method to the subscribers for this model | ||
cls.subscribers[model].append(queue.put) | ||
|
||
try: | ||
while True: | ||
# Wait for the next model instance to be created | ||
instance = await queue.get() | ||
data = {cls._meta.return_field_name: instance} | ||
yield cls(**data) | ||
finally: | ||
# Clean up the subscriber when the subscription ends | ||
cls.subscribers[model].remove(queue.put) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
import asyncio | ||
from collections import OrderedDict | ||
from typing import Optional | ||
|
||
import graphene | ||
from asgiref.sync import async_to_sync | ||
from django.db.models.signals import post_save, post_delete | ||
from graphene.types.objecttype import ObjectTypeOptions | ||
from graphene.types.utils import yank_fields_from_attrs | ||
from graphene_django.registry import get_global_registry | ||
from requests import delete | ||
|
||
from graphene_django_cud.subscriptions.core import DjangoCudSubscriptionBase | ||
from graphene_django_cud.util import to_snake_case | ||
|
||
from graphene_django_cud.util.dict import get_any_of | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class DjangoDeleteSubscriptionOptions(ObjectTypeOptions): | ||
model = None | ||
return_field_name = None | ||
permissions = None | ||
signal = None | ||
|
||
|
||
class DjangoDeleteSubscription(DjangoCudSubscriptionBase): | ||
# All active subscriptions are stored in this centralized dictionary. | ||
# We need to do this to keep track of which subscriptions are listening to | ||
# which signals. | ||
subscribers = {} | ||
|
||
@classmethod | ||
def __init_subclass_with_meta__( | ||
cls, | ||
_meta=None, | ||
model=None, | ||
permissions=None, | ||
return_field_name=None, | ||
signal=post_delete, | ||
**kwargs, | ||
): | ||
registry = get_global_registry() | ||
model_type = registry.get_type_for_model(model) | ||
|
||
if not _meta: | ||
_meta = DjangoDeleteSubscriptionOptions(cls) | ||
|
||
if not return_field_name: | ||
return_field_name = to_snake_case(model.__name__) | ||
|
||
output_fields = OrderedDict() | ||
output_fields["id"] = graphene.String() | ||
|
||
_meta.model = model | ||
_meta.model_type = model_type | ||
_meta.fields = yank_fields_from_attrs(output_fields, _as=graphene.Field) | ||
_meta.output = cls | ||
_meta.permissions = permissions | ||
|
||
# Importantly, this needs to be set to either nothing or the identity. | ||
# Internally in graphene it will be defaulted to the identity function. | ||
_meta.resolver = None | ||
|
||
# This is set to be the subscription resolver in the SubscriptionField class. | ||
_meta.subscribe = cls.subscribe | ||
_meta.return_field_name = return_field_name | ||
|
||
# Connect to the model's post_save signal | ||
signal.connect(cls._model_deleted_handler, sender=model) | ||
|
||
super().__init_subclass_with_meta__(_meta=_meta, **kwargs) | ||
|
||
@classmethod | ||
def _model_deleted_handler(cls, sender, *args, **kwargs): | ||
"""Handle model updating and notify subscribers""" | ||
|
||
Model = cls._meta.model | ||
|
||
instance: Optional[Model] = kwargs.get("instance", None) or next(filter( | ||
lambda x: isinstance(x, Model), args | ||
), None) | ||
|
||
deleted_id = get_any_of( | ||
kwargs, | ||
[ | ||
"pk", | ||
"raw_id", | ||
"input_id", | ||
"id" | ||
] | ||
) if not instance else get_any_of( | ||
instance, | ||
[ | ||
"pk", | ||
"id", | ||
] | ||
) | ||
|
||
print(kwargs, args, deleted_id) | ||
|
||
if deleted_id is None: | ||
logger.warning("Received a delete signal for a model without an instance or an id being passed to the " | ||
"signal handler. Are you using a compatible signal? Read the documentation for " | ||
"graphene-django-cud for more information.") | ||
return | ||
|
||
new_deleted_id = cls.handle_object_deleted(sender, deleted_id, **kwargs) | ||
|
||
if new_deleted_id is not None: | ||
deleted_id = new_deleted_id | ||
|
||
# Notify all subscribers for the model | ||
for subscriber in cls.subscribers.get(sender, []): | ||
async_to_sync(subscriber)(deleted_id) | ||
|
||
@classmethod | ||
def handle_object_deleted(cls, sender, deleted_id, **kwargs): | ||
"""Handle and modify any instance created""" | ||
pass | ||
|
||
@classmethod | ||
def check_permissions(cls, root, info, *args, **kwargs) -> None: | ||
return super().check_permissions(root, info, *args, **kwargs) | ||
|
||
@classmethod | ||
async def subscribe(cls, root, info, *args, **kwargs): | ||
"""Subscribe to the model creation events asynchronously""" | ||
|
||
cls.check_permissions(root, info, *args, **kwargs) | ||
|
||
model = cls._meta.model | ||
queue = asyncio.Queue() | ||
|
||
# Ensure there's a list of subscribers for the model | ||
if model not in cls.subscribers: | ||
cls.subscribers[model] = [] | ||
|
||
# Add the queue's put method to the subscribers for this model | ||
cls.subscribers[model].append(queue.put) | ||
|
||
try: | ||
while True: | ||
# Wait for the next model instance to be deleted | ||
_id = await queue.get() | ||
|
||
yield cls(id=_id) | ||
finally: | ||
# Clean up the subscriber when the subscription ends | ||
cls.subscribers[model].remove(queue.put) |
Oops, something went wrong.