Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit e86f811

Browse files
Rewrite MapAsyncIterable using async generator semantics
1 parent a9b9568 commit e86f811

File tree

2 files changed

+48
-149
lines changed

2 files changed

+48
-149
lines changed
Lines changed: 22 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from __future__ import annotations # Python < 3.10
22

3-
from asyncio import CancelledError, Event, Task, ensure_future, wait
4-
from concurrent.futures import FIRST_COMPLETED
5-
from inspect import isasyncgen, isawaitable
3+
from inspect import isawaitable
64
from types import TracebackType
7-
from typing import Any, AsyncIterable, Callable, Optional, Set, Type, Union, cast
5+
from typing import Any, AsyncIterable, Callable, Optional, Type, Union
86

97

108
__all__ = ["MapAsyncIterable"]
119

1210

11+
# The following is a class because its type is checked in the code.
12+
# otherwise, it could be implemented as a simple async generator function
13+
1314
# noinspection PyAttributeOutsideInit
1415
class MapAsyncIterable:
1516
"""Map an AsyncIterable over a callback function.
@@ -22,97 +23,39 @@ class MapAsyncIterable:
2223
"""
2324

2425
def __init__(self, iterable: AsyncIterable, callback: Callable) -> None:
25-
self.iterator = iterable.__aiter__()
26+
self.iterable = iterable
2627
self.callback = callback
27-
self._close_event = Event()
28+
self._ageniter = self._agen()
29+
self.is_closed = False # used by unittests
2830

2931
def __aiter__(self) -> MapAsyncIterable:
3032
"""Get the iterator object."""
3133
return self
3234

3335
async def __anext__(self) -> Any:
3436
"""Get the next value of the iterator."""
35-
if self.is_closed:
36-
if not isasyncgen(self.iterator):
37-
raise StopAsyncIteration
38-
value = await self.iterator.__anext__()
39-
else:
40-
aclose = ensure_future(self._close_event.wait())
41-
anext = ensure_future(self.iterator.__anext__())
42-
43-
try:
44-
pending: Set[Task] = (
45-
await wait([aclose, anext], return_when=FIRST_COMPLETED)
46-
)[1]
47-
except CancelledError:
48-
# cancel underlying tasks and close
49-
aclose.cancel()
50-
anext.cancel()
51-
await self.aclose()
52-
raise # re-raise the cancellation
53-
54-
for task in pending:
55-
task.cancel()
56-
57-
if aclose.done():
58-
raise StopAsyncIteration
59-
60-
error = anext.exception()
61-
if error:
62-
raise error
63-
64-
value = anext.result()
65-
66-
result = self.callback(value)
67-
68-
return await result if isawaitable(result) else result
37+
return await self._ageniter.__anext__()
38+
39+
async def _agen(self) -> Any:
40+
try:
41+
async for v in self.iterable:
42+
result = self.callback(v)
43+
yield (await result) if isawaitable(result) else result
44+
finally:
45+
self.is_closed = True
46+
if hasattr(self.iterable, "aclose"):
47+
await self.iterable.aclose()
6948

49+
# This is not a standard method and is only used in unittests. Should be removed.
7050
async def athrow(
7151
self,
7252
type_: Union[BaseException, Type[BaseException]],
7353
value: Optional[BaseException] = None,
7454
traceback: Optional[TracebackType] = None,
7555
) -> None:
7656
"""Throw an exception into the asynchronous iterator."""
77-
if self.is_closed:
78-
return
79-
athrow = getattr(self.iterator, "athrow", None)
80-
if athrow:
81-
await athrow(type_, value, traceback)
82-
else:
83-
await self.aclose()
84-
if value is None:
85-
if traceback is None:
86-
raise type_
87-
value = (
88-
type_
89-
if isinstance(value, BaseException)
90-
else cast(Type[BaseException], type_)()
91-
)
92-
if traceback is not None:
93-
value = value.with_traceback(traceback)
94-
raise value
57+
await self._ageniter.athrow(type_, value, traceback)
9558

9659
async def aclose(self) -> None:
9760
"""Close the iterator."""
98-
if not self.is_closed:
99-
aclose = getattr(self.iterator, "aclose", None)
100-
if aclose:
101-
try:
102-
await aclose()
103-
except RuntimeError:
104-
pass
105-
self.is_closed = True
106-
107-
@property
108-
def is_closed(self) -> bool:
109-
"""Check whether the iterator is closed."""
110-
return self._close_event.is_set()
111-
112-
@is_closed.setter
113-
def is_closed(self, value: bool) -> None:
114-
"""Mark the iterator as closed."""
115-
if value:
116-
self._close_event.set()
117-
else:
118-
self._close_event.clear()
61+
await self._ageniter.aclose()

‎tests/execution/test_map_async_iterable.py

Lines changed: 26 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -133,29 +133,26 @@ async def __anext__(self):
133133
with raises(StopAsyncIteration):
134134
await anext(doubles)
135135

136+
# async iterators must not yield after aclose() is called
136137
@mark.asyncio
137-
async def passes_through_early_return_from_async_values():
138+
async def ignored_generator_exit():
138139
async def source():
139140
try:
140141
yield 1
141142
yield 2
142143
yield 3 # pragma: no cover
143144
finally:
144145
yield "Done"
145-
yield "Last"
146+
yield "Last"# pragma: no cover
146147

147148
doubles = MapAsyncIterable(source(), lambda x: x + x)
148149

149150
assert await anext(doubles) == 2
150151
assert await anext(doubles) == 4
151152

152-
# Early return
153-
await doubles.aclose()
154-
155-
# Subsequent next calls may yield from finally block
156-
assert await anext(doubles) == "LastLast"
157-
with raises(GeneratorExit):
158-
assert await anext(doubles)
153+
with raises(RuntimeError) as exc_info:
154+
await doubles.aclose()
155+
assert str(exc_info.value) == "async generator ignored GeneratorExit"
159156

160157
@mark.asyncio
161158
async def allows_throwing_errors_through_async_iterable():
@@ -256,12 +253,8 @@ async def source():
256253
assert await anext(doubles) == 4
257254

258255
# Throw error
259-
await doubles.athrow(RuntimeError("ouch"))
260-
261-
with raises(StopAsyncIteration):
262-
await anext(doubles)
263-
with raises(StopAsyncIteration):
264-
await anext(doubles)
256+
with raises(RuntimeError):
257+
await doubles.athrow(RuntimeError("ouch"))
265258

266259
@mark.asyncio
267260
async def does_not_normally_map_over_thrown_errors():
@@ -394,65 +387,28 @@ async def source():
394387
await sleep(0.05)
395388
assert not doubles_future.done()
396389

397-
# Unblock and watch StopAsyncIteration propagate
398-
await doubles.aclose()
399-
await sleep(0.05)
400-
assert doubles_future.done()
401-
assert isinstance(doubles_future.exception(), StopAsyncIteration)
390+
# with python 3.8 and higher, close() cannot be used to unblock a generator.
391+
# instead, the task should be killed. AsyncGenerators are not re-entrant.
392+
if sys.version_info[:2] >= (3, 8):
393+
with raises(RuntimeError):
394+
await doubles.aclose()
395+
doubles_future.cancel()
396+
await sleep(0.05)
397+
assert doubles_future.done()
398+
with raises(CancelledError):
399+
doubles_future.exception()
400+
401+
else:
402+
# old behaviour, where aclose() could unblock a Task
403+
# Unblock and watch StopAsyncIteration propagate
404+
await doubles.aclose()
405+
await sleep(0.05)
406+
assert doubles_future.done()
407+
assert isinstance(doubles_future.exception(), StopAsyncIteration)
402408

403409
with raises(StopAsyncIteration):
404410
await anext(singles)
405411

406-
@mark.asyncio
407-
async def can_unset_closed_state_of_async_iterable():
408-
items = [1, 2, 3]
409-
410-
class Iterable:
411-
def __init__(self):
412-
self.is_closed = False
413-
414-
def __aiter__(self):
415-
return self
416-
417-
async def __anext__(self):
418-
if self.is_closed:
419-
raise StopAsyncIteration
420-
try:
421-
return items.pop(0)
422-
except IndexError:
423-
raise StopAsyncIteration
424-
425-
async def aclose(self):
426-
self.is_closed = True
427-
428-
iterable = Iterable()
429-
doubles = MapAsyncIterable(iterable, lambda x: x + x)
430-
431-
assert await anext(doubles) == 2
432-
assert await anext(doubles) == 4
433-
assert not iterable.is_closed
434-
await doubles.aclose()
435-
assert iterable.is_closed
436-
with raises(StopAsyncIteration):
437-
await anext(iterable)
438-
with raises(StopAsyncIteration):
439-
await anext(doubles)
440-
assert doubles.is_closed
441-
442-
iterable.is_closed = False
443-
doubles.is_closed = False
444-
assert not doubles.is_closed
445-
446-
assert await anext(doubles) == 6
447-
assert not doubles.is_closed
448-
assert not iterable.is_closed
449-
with raises(StopAsyncIteration):
450-
await anext(iterable)
451-
with raises(StopAsyncIteration):
452-
await anext(doubles)
453-
assert not doubles.is_closed
454-
assert not iterable.is_closed
455-
456412
@mark.asyncio
457413
async def can_cancel_async_iterable_while_waiting():
458414
class Iterable:

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /