Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit b2f8fe3

Browse files
committed
Cancel remaining list items on exceptions
1 parent 57c083d commit b2f8fe3

File tree

2 files changed

+71
-14
lines changed

2 files changed

+71
-14
lines changed

‎src/graphql/execution/execute.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -466,10 +466,10 @@ async def get_results() -> dict[str, Any]:
466466
field = awaitable_fields[0]
467467
results[field] = await results[field]
468468
else:
469-
tasks = {
470-
create_task(results[field]): field # type: ignore[arg-type]
469+
tasks = [
470+
create_task(results[field]) # type: ignore[arg-type]
471471
for field in awaitable_fields
472-
}
472+
]
473473

474474
try:
475475
awaited_results = await gather(*tasks)
@@ -1014,12 +1014,21 @@ async def get_completed_results() -> list[Any]:
10141014
index = awaitable_indices[0]
10151015
completed_results[index] = await completed_results[index]
10161016
else:
1017-
for index, sub_result in zip(
1018-
awaitable_indices,
1019-
await gather(
1020-
*(completed_results[index] for index in awaitable_indices)
1021-
),
1022-
):
1017+
tasks = [
1018+
create_task(completed_results[index]) for index in awaitable_indices
1019+
]
1020+
1021+
try:
1022+
awaited_results = await gather(*tasks)
1023+
except Exception:
1024+
# Cancel unfinished tasks before raising the exception
1025+
for task in tasks:
1026+
if not task.done():
1027+
task.cancel()
1028+
await gather(*tasks, return_exceptions=True)
1029+
raise
1030+
1031+
for index, sub_result in zip(awaitable_indices, awaited_results):
10231032
completed_results[index] = sub_result
10241033
return completed_results
10251034

‎tests/execution/test_parallel.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ async def resolve(*_args):
7676
# raises TimeoutError if not parallel
7777
awaitable_result = execute(schema, ast)
7878
assert isinstance(awaitable_result, Awaitable)
79-
result = await asyncio.wait_for(awaitable_result, 1.0)
79+
result = await asyncio.wait_for(awaitable_result, 1)
8080

8181
assert result == ({"foo": True, "bar": True}, None)
8282

@@ -125,7 +125,7 @@ async def resolve_list(*args):
125125
# raises TimeoutError if not parallel
126126
awaitable_result = execute(schema, ast)
127127
assert isinstance(awaitable_result, Awaitable)
128-
result = await asyncio.wait_for(awaitable_result, 1.0)
128+
result = await asyncio.wait_for(awaitable_result, 1)
129129

130130
assert result == ({"foo": [True, True]}, None)
131131

@@ -188,15 +188,15 @@ async def is_type_of_baz(obj, *_args):
188188
# raises TimeoutError if not parallel
189189
awaitable_result = execute(schema, ast)
190190
assert isinstance(awaitable_result, Awaitable)
191-
result = await asyncio.wait_for(awaitable_result, 1.0)
191+
result = await asyncio.wait_for(awaitable_result, 1)
192192

193193
assert result == (
194194
{"foo": [{"foo": "bar", "foobar": 1}, {"foo": "baz", "foobaz": 2}]},
195195
None,
196196
)
197197

198198
@pytest.mark.asyncio
199-
async def cancel_on_exception():
199+
async def cancel_selection_sets_on_exception():
200200
barrier = Barrier(2)
201201
completed = False
202202

@@ -222,13 +222,61 @@ async def fail(*_args):
222222

223223
awaitable_result = execute(schema, ast)
224224
assert isinstance(awaitable_result, Awaitable)
225-
result = await asyncio.wait_for(awaitable_result, 1.0)
225+
result = await asyncio.wait_for(awaitable_result, 1)
226226

227227
assert result == (
228228
None,
229229
[{"message": "Oops", "locations": [(1, 2)], "path": ["foo"]}],
230230
)
231231

232+
assert not completed
233+
234+
# Unblock succeed() and check that it does not complete
235+
await barrier.wait()
236+
await asyncio.sleep(0)
237+
assert not completed
238+
239+
@pytest.mark.asyncio
240+
async def cancel_lists_on_exception():
241+
barrier = Barrier(2)
242+
completed = False
243+
244+
async def succeed(*_args):
245+
nonlocal completed
246+
await barrier.wait()
247+
completed = True # pragma: no cover
248+
249+
async def fail(*_args):
250+
raise RuntimeError("Oops")
251+
252+
async def resolve_list(*args):
253+
return [fail(*args), succeed(*args)]
254+
255+
schema = GraphQLSchema(
256+
GraphQLObjectType(
257+
"Query",
258+
{
259+
"foo": GraphQLField(
260+
GraphQLList(GraphQLNonNull(GraphQLBoolean)),
261+
resolve=resolve_list,
262+
)
263+
},
264+
)
265+
)
266+
267+
ast = parse("{foo}")
268+
269+
awaitable_result = execute(schema, ast)
270+
assert isinstance(awaitable_result, Awaitable)
271+
result = await asyncio.wait_for(awaitable_result, 1)
272+
273+
assert result == (
274+
{"foo": None},
275+
[{"message": "Oops", "locations": [(1, 2)], "path": ["foo", 0]}],
276+
)
277+
278+
assert not completed
279+
232280
# Unblock succeed() and check that it does not complete
233281
await barrier.wait()
234282
await asyncio.sleep(0)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /