Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit e30f166

Browse files
committed
subscribe: stay synchronous when possible
This (breaking!) change aligns the return types of `execute` and `subscribe` (as well as `create_source_event_stream`) with respect to returning values or awaitables. Replicates graphql/graphql-js@6d42ced
1 parent 5950470 commit e30f166

File tree

4 files changed

+213
-72
lines changed

4 files changed

+213
-72
lines changed

‎src/graphql/execution/subscribe.py‎

Lines changed: 79 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
from inspect import isawaitable
2-
from typing import Any, AsyncIterable, AsyncIterator, Dict, Optional, Type, Union
2+
from typing import (
3+
Any,
4+
AsyncIterable,
5+
AsyncIterator,
6+
Awaitable,
7+
Dict,
8+
Optional,
9+
Type,
10+
Union,
11+
cast,
12+
)
313

414
from ..error import GraphQLError, located_error
515
from ..execution.collect_fields import collect_fields
@@ -11,15 +21,15 @@
1121
)
1222
from ..execution.values import get_argument_values
1323
from ..language import DocumentNode
14-
from ..pyutils import Path, inspect
24+
from ..pyutils import AwaitableOrValue, Path, inspect
1525
from ..type import GraphQLFieldResolver, GraphQLSchema
1626
from .map_async_iterator import MapAsyncIterator
1727

1828

1929
__all__ = ["subscribe", "create_source_event_stream"]
2030

2131

22-
asyncdef subscribe(
32+
def subscribe(
2333
schema: GraphQLSchema,
2434
document: DocumentNode,
2535
root_value: Any = None,
@@ -29,7 +39,7 @@ async def subscribe(
2939
field_resolver: Optional[GraphQLFieldResolver] = None,
3040
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
3141
execution_context_class: Optional[Type[ExecutionContext]] = None,
32-
) -> Union[AsyncIterator[ExecutionResult], ExecutionResult]:
42+
) -> AwaitableOrValue[Union[AsyncIterator[ExecutionResult], ExecutionResult]]:
3343
"""Create a GraphQL subscription.
3444
3545
Implements the "Subscribe" algorithm described in the GraphQL spec.
@@ -49,7 +59,7 @@ async def subscribe(
4959
If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
5060
a stream of ExecutionResults representing the response stream.
5161
"""
52-
result_or_stream = awaitcreate_source_event_stream(
62+
result_or_stream = create_source_event_stream(
5363
schema,
5464
document,
5565
root_value,
@@ -59,8 +69,6 @@ async def subscribe(
5969
subscribe_field_resolver,
6070
execution_context_class,
6171
)
62-
if isinstance(result_or_stream, ExecutionResult):
63-
return result_or_stream
6472

6573
async def map_source_to_response(payload: Any) -> ExecutionResult:
6674
"""Map source to response.
@@ -84,11 +92,28 @@ async def map_source_to_response(payload: Any) -> ExecutionResult:
8492
)
8593
return await result if isawaitable(result) else result
8694

95+
if (execution_context_class or ExecutionContext).is_awaitable(result_or_stream):
96+
awaitable_result_or_stream = cast(Awaitable, result_or_stream)
97+
98+
# noinspection PyShadowingNames
99+
async def await_result() -> Any:
100+
result_or_stream = await awaitable_result_or_stream
101+
if isinstance(result_or_stream, ExecutionResult):
102+
return result_or_stream
103+
return MapAsyncIterator(result_or_stream, map_source_to_response)
104+
105+
return await_result()
106+
107+
if isinstance(result_or_stream, ExecutionResult):
108+
return result_or_stream
109+
87110
# Map every source value to a ExecutionResult value as described above.
88-
return MapAsyncIterator(result_or_stream, map_source_to_response)
111+
return MapAsyncIterator(
112+
cast(AsyncIterable[Any], result_or_stream), map_source_to_response
113+
)
89114

90115

91-
asyncdef create_source_event_stream(
116+
def create_source_event_stream(
92117
schema: GraphQLSchema,
93118
document: DocumentNode,
94119
root_value: Any = None,
@@ -97,7 +122,7 @@ async def create_source_event_stream(
97122
operation_name: Optional[str] = None,
98123
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
99124
execution_context_class: Optional[Type[ExecutionContext]] = None,
100-
) -> Union[AsyncIterable[Any], ExecutionResult]:
125+
) -> AwaitableOrValue[Union[AsyncIterable[Any], ExecutionResult]]:
101126
"""Create source event stream
102127
103128
Implements the "CreateSourceEventStream" algorithm described in the GraphQL
@@ -145,12 +170,28 @@ async def create_source_event_stream(
145170
return ExecutionResult(data=None, errors=context)
146171

147172
try:
148-
returnawait execute_subscription(context)
173+
event_stream= execute_subscription(context)
149174
except GraphQLError as error:
150175
return ExecutionResult(data=None, errors=[error])
151176

177+
if context.is_awaitable(event_stream):
178+
awaitable_event_stream = cast(Awaitable, event_stream)
179+
180+
# noinspection PyShadowingNames
181+
async def await_event_stream() -> Union[AsyncIterable[Any], ExecutionResult]:
182+
try:
183+
return await awaitable_event_stream
184+
except GraphQLError as error:
185+
return ExecutionResult(data=None, errors=[error])
152186

153-
async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
187+
return await_event_stream()
188+
189+
return event_stream
190+
191+
192+
def execute_subscription(
193+
context: ExecutionContext,
194+
) -> AwaitableOrValue[AsyncIterable[Any]]:
154195
schema = context.schema
155196

156197
root_type = schema.subscription_type
@@ -191,19 +232,33 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
191232
# AsyncIterable yielding raw payloads.
192233
resolve_fn = field_def.subscribe or context.subscribe_field_resolver
193234

194-
event_stream = resolve_fn(context.root_value, info, **args)
195-
if context.is_awaitable(event_stream):
196-
event_stream = await event_stream
197-
if isinstance(event_stream, Exception):
198-
raise event_stream
235+
result = resolve_fn(context.root_value, info, **args)
236+
if context.is_awaitable(result):
199237

200-
# Assert field returned an event stream, otherwise yield an error.
201-
if not isinstance(event_stream, AsyncIterable):
202-
raise GraphQLError(
203-
"Subscription field must return AsyncIterable."
204-
f" Received: {inspect(event_stream)}."
205-
)
238+
# noinspection PyShadowingNames
239+
async def await_result() -> AsyncIterable[Any]:
240+
try:
241+
return assert_event_stream(await result)
242+
except Exception as error:
243+
raise located_error(error, field_nodes, path.as_list())
244+
245+
return await_result()
246+
247+
return assert_event_stream(result)
206248

207-
return event_stream
208249
except Exception as error:
209250
raise located_error(error, field_nodes, path.as_list())
251+
252+
253+
def assert_event_stream(result: Any) -> AsyncIterable:
254+
if isinstance(result, Exception):
255+
raise result
256+
257+
# Assert field returned an event stream, otherwise yield an error.
258+
if not isinstance(result, AsyncIterable):
259+
raise GraphQLError(
260+
"Subscription field must return AsyncIterable."
261+
f" Received: {inspect(result)}."
262+
)
263+
264+
return result

‎tests/execution/test_customize.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class Root:
6767
async def custom_foo():
6868
yield {"foo": "FooValue"}
6969

70-
subscription = awaitsubscribe(
70+
subscription = subscribe(
7171
schema,
7272
document=parse("subscription { foo }"),
7373
root_value=Root(),
@@ -111,7 +111,7 @@ def resolve_foo(message, _info):
111111
)
112112

113113
document = parse("subscription { foo }")
114-
subscription = awaitsubscribe(
114+
subscription = subscribe(
115115
schema,
116116
document,
117117
context_value={},

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /