Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 48ed790

Browse files
committed
Cancel remaining iterator items on exceptions
1 parent b2f8fe3 commit 48ed790

File tree

2 files changed

+63
-7
lines changed

2 files changed

+63
-7
lines changed

‎src/graphql/execution/execute.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -911,13 +911,22 @@ async def complete_async_iterator_value(
911911
index = awaitable_indices[0]
912912
completed_results[index] = await completed_results[index]
913913
else:
914-
for index, result in zip(
915-
awaitable_indices,
916-
await gather(
917-
*(completed_results[index] for index in awaitable_indices)
918-
),
919-
):
920-
completed_results[index] = result
914+
tasks = [
915+
create_task(completed_results[index]) for index in awaitable_indices
916+
]
917+
918+
try:
919+
awaited_results = await gather(*tasks)
920+
except Exception:
921+
# Cancel unfinished tasks before raising the exception
922+
for task in tasks:
923+
if not task.done():
924+
task.cancel()
925+
await gather(*tasks, return_exceptions=True)
926+
raise
927+
928+
for index, sub_result in zip(awaitable_indices, awaited_results):
929+
completed_results[index] = sub_result
921930
return completed_results
922931

923932
def complete_list_value(

‎tests/execution/test_parallel.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,50 @@ async def resolve_list(*args):
281281
await barrier.wait()
282282
await asyncio.sleep(0)
283283
assert not completed
284+
285+
@pytest.mark.asyncio
286+
async def cancel_async_iterator_on_exception():
287+
barrier = Barrier(2)
288+
completed = False
289+
290+
async def succeed(*_args):
291+
nonlocal completed
292+
await barrier.wait()
293+
completed = True # pragma: no cover
294+
295+
async def fail(*_args):
296+
raise RuntimeError("Oops")
297+
298+
async def resolve_iterator(*args):
299+
yield fail(*args)
300+
yield succeed(*args)
301+
302+
schema = GraphQLSchema(
303+
GraphQLObjectType(
304+
"Query",
305+
{
306+
"foo": GraphQLField(
307+
GraphQLList(GraphQLNonNull(GraphQLBoolean)),
308+
resolve=resolve_iterator,
309+
)
310+
},
311+
)
312+
)
313+
314+
ast = parse("{foo}")
315+
316+
awaitable_result = execute(schema, ast)
317+
assert isinstance(awaitable_result, Awaitable)
318+
result = await asyncio.wait_for(awaitable_result, 1)
319+
320+
assert result == (
321+
{"foo": None},
322+
[{"message": "Oops", "locations": [(1, 2)], "path": ["foo", 0]}],
323+
)
324+
325+
assert not completed
326+
327+
# Unblock succeed() and check that it does not complete
328+
await barrier.wait()
329+
await asyncio.sleep(0)
330+
assert not completed

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /