Skip to content

Commit cefbdf1

Browse files
committed
refactor: subscribe: introduce build_per_event_execution_context
Replicates graphql/graphql-js@c1fe951
1 parent 51b93b8 commit cefbdf1

File tree

1 file changed

+60
-21
lines changed

1 file changed

+60
-21
lines changed

src/graphql/execution/execute.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,23 @@ def build_response(
331331
)
332332
return ExecutionResult(data, errors)
333333

334+
def build_per_event_execution_context(self, payload: Any) -> ExecutionContext:
335+
"""Create a copy of the execution context for usage with subscribe events."""
336+
return self.__class__(
337+
self.schema,
338+
self.fragments,
339+
payload,
340+
self.context_value,
341+
self.operation,
342+
self.variable_values,
343+
self.field_resolver,
344+
self.type_resolver,
345+
self.subscribe_field_resolver,
346+
[],
347+
self.middleware_manager,
348+
self.is_awaitable,
349+
)
350+
334351
def execute_operation(self) -> AwaitableOrValue[Any]:
335352
"""Execute an operation.
336353
@@ -1003,7 +1020,7 @@ def execute(
10031020

10041021
# If a valid execution context cannot be created due to incorrect arguments,
10051022
# a "Response" with only errors is returned.
1006-
exe_context = execution_context_class.build(
1023+
context = execution_context_class.build(
10071024
schema,
10081025
document,
10091026
root_value,
@@ -1018,9 +1035,14 @@ def execute(
10181035
)
10191036

10201037
# Return early errors if execution context failed.
1021-
if isinstance(exe_context, list):
1022-
return ExecutionResult(data=None, errors=exe_context)
1038+
if isinstance(context, list):
1039+
return ExecutionResult(data=None, errors=context)
1040+
1041+
return execute_impl(context)
1042+
10231043

1044+
def execute_impl(context: ExecutionContext) -> AwaitableOrValue[ExecutionResult]:
1045+
"""Execute GraphQL operation (internal implementation)."""
10241046
# Return a possible coroutine object that will eventually yield the data described
10251047
# by the "Response" section of the GraphQL specification.
10261048
#
@@ -1032,12 +1054,12 @@ def execute(
10321054
# Errors from sub-fields of a NonNull type may propagate to the top level,
10331055
# at which point we still log the error and null the parent field, which
10341056
# in this case is the entire response.
1035-
errors = exe_context.errors
1036-
build_response = exe_context.build_response
1057+
errors = context.errors
1058+
build_response = context.build_response
10371059
try:
1038-
result = exe_context.execute_operation()
1060+
result = context.execute_operation()
10391061

1040-
if exe_context.is_awaitable(result):
1062+
if context.is_awaitable(result):
10411063
# noinspection PyShadowingNames
10421064
async def await_result() -> Any:
10431065
try:
@@ -1215,6 +1237,7 @@ def subscribe(
12151237
variable_values: Optional[Dict[str, Any]] = None,
12161238
operation_name: Optional[str] = None,
12171239
field_resolver: Optional[GraphQLFieldResolver] = None,
1240+
type_resolver: Optional[GraphQLTypeResolver] = None,
12181241
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
12191242
execution_context_class: Optional[Type[ExecutionContext]] = None,
12201243
) -> AwaitableOrValue[Union[AsyncIterator[ExecutionResult], ExecutionResult]]:
@@ -1237,17 +1260,31 @@ def subscribe(
12371260
If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
12381261
a stream of ExecutionResults representing the response stream.
12391262
"""
1240-
result_or_stream = create_source_event_stream(
1263+
if execution_context_class is None:
1264+
execution_context_class = ExecutionContext
1265+
1266+
# If a valid context cannot be created due to incorrect arguments,
1267+
# a "Response" with only errors is returned.
1268+
context = execution_context_class.build(
12411269
schema,
12421270
document,
12431271
root_value,
12441272
context_value,
12451273
variable_values,
12461274
operation_name,
1275+
field_resolver,
1276+
type_resolver,
12471277
subscribe_field_resolver,
1248-
execution_context_class,
12491278
)
12501279

1280+
# Return early errors if execution context failed.
1281+
if isinstance(context, list):
1282+
return ExecutionResult(data=None, errors=context)
1283+
1284+
result_or_stream = create_source_event_stream_impl(context)
1285+
1286+
build_context = context.build_per_event_execution_context
1287+
12511288
async def map_source_to_response(payload: Any) -> ExecutionResult:
12521289
"""Map source to response.
12531290
@@ -1258,19 +1295,10 @@ async def map_source_to_response(payload: Any) -> ExecutionResult:
12581295
"ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
12591296
"ExecuteQuery" algorithm, for which :func:`~graphql.execute` is also used.
12601297
"""
1261-
result = execute(
1262-
schema,
1263-
document,
1264-
payload,
1265-
context_value,
1266-
variable_values,
1267-
operation_name,
1268-
field_resolver,
1269-
execution_context_class=execution_context_class,
1270-
)
1298+
result = execute_impl(build_context(payload))
12711299
return await result if isawaitable(result) else result
12721300

1273-
if (execution_context_class or ExecutionContext).is_awaitable(result_or_stream):
1301+
if execution_context_class.is_awaitable(result_or_stream):
12741302
awaitable_result_or_stream = cast(Awaitable, result_or_stream)
12751303

12761304
# noinspection PyShadowingNames
@@ -1298,6 +1326,8 @@ def create_source_event_stream(
12981326
context_value: Any = None,
12991327
variable_values: Optional[Dict[str, Any]] = None,
13001328
operation_name: Optional[str] = None,
1329+
field_resolver: Optional[GraphQLFieldResolver] = None,
1330+
type_resolver: Optional[GraphQLTypeResolver] = None,
13011331
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
13021332
execution_context_class: Optional[Type[ExecutionContext]] = None,
13031333
) -> AwaitableOrValue[Union[AsyncIterable[Any], ExecutionResult]]:
@@ -1336,13 +1366,22 @@ def create_source_event_stream(
13361366
context_value,
13371367
variable_values,
13381368
operation_name,
1339-
subscribe_field_resolver=subscribe_field_resolver,
1369+
field_resolver,
1370+
type_resolver,
1371+
subscribe_field_resolver,
13401372
)
13411373

13421374
# Return early errors if execution context failed.
13431375
if isinstance(context, list):
13441376
return ExecutionResult(data=None, errors=context)
13451377

1378+
return create_source_event_stream_impl(context)
1379+
1380+
1381+
def create_source_event_stream_impl(
1382+
context: ExecutionContext,
1383+
) -> AwaitableOrValue[Union[AsyncIterable[Any], ExecutionResult]]:
1384+
"""Create source event stream (internal implementation)."""
13461385
try:
13471386
event_stream = execute_subscription(context)
13481387
except GraphQLError as error:

0 commit comments

Comments
 (0)