@@ -331,6 +331,23 @@ def build_response(
331
331
)
332
332
return ExecutionResult (data , errors )
333
333
334
+ def build_per_event_execution_context (self , payload : Any ) -> ExecutionContext :
335
+ """Create a copy of the execution context for usage with subscribe events."""
336
+ return self .__class__ (
337
+ self .schema ,
338
+ self .fragments ,
339
+ payload ,
340
+ self .context_value ,
341
+ self .operation ,
342
+ self .variable_values ,
343
+ self .field_resolver ,
344
+ self .type_resolver ,
345
+ self .subscribe_field_resolver ,
346
+ [],
347
+ self .middleware_manager ,
348
+ self .is_awaitable ,
349
+ )
350
+
334
351
def execute_operation (self ) -> AwaitableOrValue [Any ]:
335
352
"""Execute an operation.
336
353
@@ -1003,7 +1020,7 @@ def execute(
1003
1020
1004
1021
# If a valid execution context cannot be created due to incorrect arguments,
1005
1022
# a "Response" with only errors is returned.
1006
- exe_context = execution_context_class .build (
1023
+ context = execution_context_class .build (
1007
1024
schema ,
1008
1025
document ,
1009
1026
root_value ,
@@ -1018,9 +1035,14 @@ def execute(
1018
1035
)
1019
1036
1020
1037
# Return early errors if execution context failed.
1021
- if isinstance (exe_context , list ):
1022
- return ExecutionResult (data = None , errors = exe_context )
1038
+ if isinstance (context , list ):
1039
+ return ExecutionResult (data = None , errors = context )
1040
+
1041
+ return execute_impl (context )
1042
+
1023
1043
1044
+ def execute_impl (context : ExecutionContext ) -> AwaitableOrValue [ExecutionResult ]:
1045
+ """Execute GraphQL operation (internal implementation)."""
1024
1046
# Return a possible coroutine object that will eventually yield the data described
1025
1047
# by the "Response" section of the GraphQL specification.
1026
1048
#
@@ -1032,12 +1054,12 @@ def execute(
1032
1054
# Errors from sub-fields of a NonNull type may propagate to the top level,
1033
1055
# at which point we still log the error and null the parent field, which
1034
1056
# in this case is the entire response.
1035
- errors = exe_context .errors
1036
- build_response = exe_context .build_response
1057
+ errors = context .errors
1058
+ build_response = context .build_response
1037
1059
try :
1038
- result = exe_context .execute_operation ()
1060
+ result = context .execute_operation ()
1039
1061
1040
- if exe_context .is_awaitable (result ):
1062
+ if context .is_awaitable (result ):
1041
1063
# noinspection PyShadowingNames
1042
1064
async def await_result () -> Any :
1043
1065
try :
@@ -1215,6 +1237,7 @@ def subscribe(
1215
1237
variable_values : Optional [Dict [str , Any ]] = None ,
1216
1238
operation_name : Optional [str ] = None ,
1217
1239
field_resolver : Optional [GraphQLFieldResolver ] = None ,
1240
+ type_resolver : Optional [GraphQLTypeResolver ] = None ,
1218
1241
subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
1219
1242
execution_context_class : Optional [Type [ExecutionContext ]] = None ,
1220
1243
) -> AwaitableOrValue [Union [AsyncIterator [ExecutionResult ], ExecutionResult ]]:
@@ -1237,17 +1260,31 @@ def subscribe(
1237
1260
If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
1238
1261
a stream of ExecutionResults representing the response stream.
1239
1262
"""
1240
- result_or_stream = create_source_event_stream (
1263
+ if execution_context_class is None :
1264
+ execution_context_class = ExecutionContext
1265
+
1266
+ # If a valid context cannot be created due to incorrect arguments,
1267
+ # a "Response" with only errors is returned.
1268
+ context = execution_context_class .build (
1241
1269
schema ,
1242
1270
document ,
1243
1271
root_value ,
1244
1272
context_value ,
1245
1273
variable_values ,
1246
1274
operation_name ,
1275
+ field_resolver ,
1276
+ type_resolver ,
1247
1277
subscribe_field_resolver ,
1248
- execution_context_class ,
1249
1278
)
1250
1279
1280
+ # Return early errors if execution context failed.
1281
+ if isinstance (context , list ):
1282
+ return ExecutionResult (data = None , errors = context )
1283
+
1284
+ result_or_stream = create_source_event_stream_impl (context )
1285
+
1286
+ build_context = context .build_per_event_execution_context
1287
+
1251
1288
async def map_source_to_response (payload : Any ) -> ExecutionResult :
1252
1289
"""Map source to response.
1253
1290
@@ -1258,19 +1295,10 @@ async def map_source_to_response(payload: Any) -> ExecutionResult:
1258
1295
"ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
1259
1296
"ExecuteQuery" algorithm, for which :func:`~graphql.execute` is also used.
1260
1297
"""
1261
- result = execute (
1262
- schema ,
1263
- document ,
1264
- payload ,
1265
- context_value ,
1266
- variable_values ,
1267
- operation_name ,
1268
- field_resolver ,
1269
- execution_context_class = execution_context_class ,
1270
- )
1298
+ result = execute_impl (build_context (payload ))
1271
1299
return await result if isawaitable (result ) else result
1272
1300
1273
- if ( execution_context_class or ExecutionContext ) .is_awaitable (result_or_stream ):
1301
+ if execution_context_class .is_awaitable (result_or_stream ):
1274
1302
awaitable_result_or_stream = cast (Awaitable , result_or_stream )
1275
1303
1276
1304
# noinspection PyShadowingNames
@@ -1298,6 +1326,8 @@ def create_source_event_stream(
1298
1326
context_value : Any = None ,
1299
1327
variable_values : Optional [Dict [str , Any ]] = None ,
1300
1328
operation_name : Optional [str ] = None ,
1329
+ field_resolver : Optional [GraphQLFieldResolver ] = None ,
1330
+ type_resolver : Optional [GraphQLTypeResolver ] = None ,
1301
1331
subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
1302
1332
execution_context_class : Optional [Type [ExecutionContext ]] = None ,
1303
1333
) -> AwaitableOrValue [Union [AsyncIterable [Any ], ExecutionResult ]]:
@@ -1336,13 +1366,22 @@ def create_source_event_stream(
1336
1366
context_value ,
1337
1367
variable_values ,
1338
1368
operation_name ,
1339
- subscribe_field_resolver = subscribe_field_resolver ,
1369
+ field_resolver ,
1370
+ type_resolver ,
1371
+ subscribe_field_resolver ,
1340
1372
)
1341
1373
1342
1374
# Return early errors if execution context failed.
1343
1375
if isinstance (context , list ):
1344
1376
return ExecutionResult (data = None , errors = context )
1345
1377
1378
+ return create_source_event_stream_impl (context )
1379
+
1380
+
1381
+ def create_source_event_stream_impl (
1382
+ context : ExecutionContext ,
1383
+ ) -> AwaitableOrValue [Union [AsyncIterable [Any ], ExecutionResult ]]:
1384
+ """Create source event stream (internal implementation)."""
1346
1385
try :
1347
1386
event_stream = execute_subscription (context )
1348
1387
except GraphQLError as error :
0 commit comments