Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 3818653

Browse files
committed
Allow injecting custom data to custom execution context (#226)
1 parent c685d84 commit 3818653

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

‎src/graphql/execution/execute.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from asyncio import ensure_future, gather, shield, wait_for
66
from contextlib import suppress
7+
from copy import copy
78
from typing import (
89
Any,
910
AsyncGenerator,
@@ -219,6 +220,7 @@ def build(
219220
subscribe_field_resolver: GraphQLFieldResolver | None = None,
220221
middleware: Middleware | None = None,
221222
is_awaitable: Callable[[Any], bool] | None = None,
223+
**custom_args: Any,
222224
) -> list[GraphQLError] | ExecutionContext:
223225
"""Build an execution context
224226
@@ -292,24 +294,14 @@ def build(
292294
IncrementalPublisher(),
293295
middleware_manager,
294296
is_awaitable,
297+
**custom_args,
295298
)
296299

297300
def build_per_event_execution_context(self, payload: Any) -> ExecutionContext:
298301
"""Create a copy of the execution context for usage with subscribe events."""
299-
return self.__class__(
300-
self.schema,
301-
self.fragments,
302-
payload,
303-
self.context_value,
304-
self.operation,
305-
self.variable_values,
306-
self.field_resolver,
307-
self.type_resolver,
308-
self.subscribe_field_resolver,
309-
self.incremental_publisher,
310-
self.middleware_manager,
311-
self.is_awaitable,
312-
)
302+
context = copy(self)
303+
context.root_value = payload
304+
return context
313305

314306
def execute_operation(
315307
self, initial_result_record: InitialResultRecord
@@ -1709,6 +1701,7 @@ def execute(
17091701
middleware: Middleware | None = None,
17101702
execution_context_class: type[ExecutionContext] | None = None,
17111703
is_awaitable: Callable[[Any], bool] | None = None,
1704+
**custom_context_args: Any,
17121705
) -> AwaitableOrValue[ExecutionResult]:
17131706
"""Execute a GraphQL operation.
17141707
@@ -1741,6 +1734,7 @@ def execute(
17411734
middleware,
17421735
execution_context_class,
17431736
is_awaitable,
1737+
**custom_context_args,
17441738
)
17451739
if isinstance(result, ExecutionResult):
17461740
return result
@@ -1769,6 +1763,7 @@ def experimental_execute_incrementally(
17691763
middleware: Middleware | None = None,
17701764
execution_context_class: type[ExecutionContext] | None = None,
17711765
is_awaitable: Callable[[Any], bool] | None = None,
1766+
**custom_context_args: Any,
17721767
) -> AwaitableOrValue[ExecutionResult | ExperimentalIncrementalExecutionResults]:
17731768
"""Execute GraphQL operation incrementally (internal implementation).
17741769
@@ -1797,6 +1792,7 @@ def experimental_execute_incrementally(
17971792
subscribe_field_resolver,
17981793
middleware,
17991794
is_awaitable,
1795+
**custom_context_args,
18001796
)
18011797

18021798
# Return early errors if execution context failed.
@@ -2127,6 +2123,7 @@ def subscribe(
21272123
subscribe_field_resolver: GraphQLFieldResolver | None = None,
21282124
execution_context_class: type[ExecutionContext] | None = None,
21292125
middleware: MiddlewareManager | None = None,
2126+
**custom_context_args: Any,
21302127
) -> AwaitableOrValue[AsyncIterator[ExecutionResult] | ExecutionResult]:
21312128
"""Create a GraphQL subscription.
21322129
@@ -2167,6 +2164,7 @@ def subscribe(
21672164
type_resolver,
21682165
subscribe_field_resolver,
21692166
middleware=middleware,
2167+
**custom_context_args,
21702168
)
21712169

21722170
# Return early errors if execution context failed.
@@ -2202,6 +2200,7 @@ def create_source_event_stream(
22022200
type_resolver: GraphQLTypeResolver | None = None,
22032201
subscribe_field_resolver: GraphQLFieldResolver | None = None,
22042202
execution_context_class: type[ExecutionContext] | None = None,
2203+
**custom_context_args: Any,
22052204
) -> AwaitableOrValue[AsyncIterable[Any] | ExecutionResult]:
22062205
"""Create source event stream
22072206
@@ -2238,6 +2237,7 @@ def create_source_event_stream(
22382237
field_resolver,
22392238
type_resolver,
22402239
subscribe_field_resolver,
2240+
**custom_context_args,
22412241
)
22422242

22432243
# Return early errors if execution context failed.

‎tests/execution/test_customize.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def uses_a_custom_execution_context_class():
4343
)
4444

4545
class TestExecutionContext(ExecutionContext):
46+
def __init__(self, *args, **kwargs):
47+
assert kwargs.pop("custom_arg", None) == "baz"
48+
super().__init__(*args, **kwargs)
49+
4650
def execute_field(
4751
self,
4852
parent_type,
@@ -62,7 +66,12 @@ def execute_field(
6266
)
6367
return result * 2 # type: ignore
6468

65-
assert execute(schema, query, execution_context_class=TestExecutionContext) == (
69+
assert execute(
70+
schema,
71+
query,
72+
execution_context_class=TestExecutionContext,
73+
custom_arg="baz",
74+
) == (
6675
{"foo": "barbar"},
6776
None,
6877
)
@@ -101,6 +110,10 @@ async def custom_foo():
101110
@pytest.mark.asyncio
102111
async def uses_a_custom_execution_context_class():
103112
class TestExecutionContext(ExecutionContext):
113+
def __init__(self, *args, **kwargs):
114+
assert kwargs.pop("custom_arg", None) == "baz"
115+
super().__init__(*args, **kwargs)
116+
104117
def build_resolve_info(self, *args, **kwargs):
105118
resolve_info = super().build_resolve_info(*args, **kwargs)
106119
resolve_info.context["foo"] = "bar"
@@ -132,6 +145,7 @@ def resolve_foo(message, _info):
132145
document,
133146
context_value={},
134147
execution_context_class=TestExecutionContext,
148+
custom_arg="baz",
135149
)
136150
assert isasyncgen(subscription)
137151

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /