88 TYPE_CHECKING ,
99 Any ,
1010 AsyncGenerator ,
11+ NamedTuple ,
1112 Sequence ,
1213 cast ,
1314)
3738 DeferredFragmentRecord ,
3839 DeferredGroupedFieldSetResult ,
3940 IncrementalDataRecord ,
41+ IncrementalDataRecordResult ,
4042 IncrementalResult ,
4143 ReconcilableDeferredGroupedFieldSetResult ,
4244 StreamItemsResult ,
@@ -60,6 +62,14 @@ class IncrementalPublisherContext(Protocol):
6062 cancellable_streams : set [CancellableStreamRecord ] | None
6163
6264
65+ class SubsequentIncrementalExecutionResultContext (NamedTuple ):
66+ """The context for subsequent incremental execution results."""
67+ 68+ pending : list [PendingResult ]
69+ incremental : list [IncrementalResult ]
70+ completed : list [CompletedResult ]
71+ 72+ 6373class IncrementalPublisher :
6474 """Publish incremental results.
6575
@@ -72,15 +82,11 @@ class IncrementalPublisher:
7282 _context : IncrementalPublisherContext
7383 _next_id : int
7484 _incremental_graph : IncrementalGraph
75- _incremental : list [IncrementalResult ]
76- _completed : list [CompletedResult ]
7785
7886 def __init__ (self , context : IncrementalPublisherContext ) -> None :
7987 self ._context = context
8088 self ._next_id = 0
8189 self ._incremental_graph = IncrementalGraph ()
82- self ._incremental = []
83- self ._completed = []
8490
8591 def build_response (
8692 self ,
@@ -131,36 +137,26 @@ async def _subscribe(
131137 self ,
132138 ) -> AsyncGenerator [SubsequentIncrementalExecutionResult , None ]:
133139 """Subscribe to the incremental results."""
140+ incremental_graph = self ._incremental_graph
141+ check_has_next = incremental_graph .has_next
142+ handle_completed_incremental_data = self ._handle_completed_incremental_data
143+ completed_incremental_data = incremental_graph .completed_incremental_data ()
144+ # use the raw iterator rather than 'async for' so as not to end the iterator
145+ # when exiting the loop with the next value
146+ get_next_results = completed_incremental_data .__aiter__ ().__anext__
147+ is_done = False
134148 try :
135- incremental_graph = self ._incremental_graph
136- get_new_pending = incremental_graph .get_new_pending
137- check_has_next = incremental_graph .has_next
138- pending_sources_to_results = self ._pending_sources_to_results
139- completed_incremental_data = incremental_graph .completed_incremental_data ()
140- # use the raw iterator rather than 'async for' so as not to end the iterator
141- # when exiting the loop with the next value
142- get_next_results = completed_incremental_data .__aiter__ ().__anext__
143- is_done = False
144149 while not is_done :
145150 try :
146151 completed_results = await get_next_results ()
147152 except StopAsyncIteration : # pragma: no cover
148153 break
149- pending : list [PendingResult ] = []
150154
155+ context = SubsequentIncrementalExecutionResultContext ([], [], [])
151156 for completed_result in completed_results :
152- if is_deferred_grouped_field_set_result (completed_result ):
153- self ._handle_completed_deferred_grouped_field_set (
154- completed_result
155- )
156- else :
157- completed_result = cast ("StreamItemsResult" , completed_result )
158- await self ._handle_completed_stream_items (completed_result )
157+ await handle_completed_incremental_data (completed_result , context )
159158
160- new_pending = get_new_pending ()
161- pending .extend (pending_sources_to_results (new_pending ))
162- 163- if self ._incremental or self ._completed :
159+ if context .incremental or context .completed :
164160 has_next = check_has_next ()
165161
166162 if not has_next :
@@ -169,15 +165,12 @@ async def _subscribe(
169165 subsequent_incremental_execution_result = (
170166 SubsequentIncrementalExecutionResult (
171167 has_next = has_next ,
172- pending = pending or None ,
173- incremental = self . _incremental or None ,
174- completed = self . _completed or None ,
168+ pending = context . pending or None ,
169+ incremental = context . incremental or None ,
170+ completed = context . completed or None ,
175171 )
176172 )
177173
178- self ._incremental = []
179- self ._completed = []
180- 181174 yield subsequent_incremental_execution_result
182175 finally :
183176 await self ._stop_async_iterators ()
@@ -194,12 +187,34 @@ async def _stop_async_iterators(self) -> None:
194187 if early_returns :
195188 await gather (* early_returns , return_exceptions = True )
196189
190+ async def _handle_completed_incremental_data (
191+ self ,
192+ completed_incremental_data : IncrementalDataRecordResult ,
193+ context : SubsequentIncrementalExecutionResultContext ,
194+ ) -> None :
195+ if is_deferred_grouped_field_set_result (completed_incremental_data ):
196+ self ._handle_completed_deferred_grouped_field_set (
197+ completed_incremental_data , context
198+ )
199+ else :
200+ completed_incremental_data = cast (
201+ "StreamItemsResult" , completed_incremental_data
202+ )
203+ await self ._handle_completed_stream_items (
204+ completed_incremental_data , context
205+ )
206+ 207+ new_pending = self ._incremental_graph .get_new_pending ()
208+ context .pending .extend (self ._pending_sources_to_results (new_pending ))
209+ 197210 def _handle_completed_deferred_grouped_field_set (
198- self , deferred_grouped_field_set_result : DeferredGroupedFieldSetResult
211+ self ,
212+ deferred_grouped_field_set_result : DeferredGroupedFieldSetResult ,
213+ context : SubsequentIncrementalExecutionResultContext ,
199214 ) -> None :
200215 """Handle completed deferred grouped field set result."""
201- append_completed = self . _completed .append
202- append_incremental = self . _incremental .append
216+ append_completed = context . completed .append
217+ append_incremental = context . incremental .append
203218 if is_non_reconcilable_deferred_grouped_field_set_result (
204219 deferred_grouped_field_set_result
205220 ):
@@ -260,7 +275,9 @@ def _handle_completed_deferred_grouped_field_set(
260275 append_completed (CompletedResult (id_ ))
261276
262277 async def _handle_completed_stream_items (
263- self , stream_items_result : StreamItemsResult
278+ self ,
279+ stream_items_result : StreamItemsResult ,
280+ context : SubsequentIncrementalExecutionResultContext ,
264281 ) -> None :
265282 """Handle completed stream."""
266283 stream_record = stream_items_result .stream_record
@@ -269,7 +286,7 @@ async def _handle_completed_stream_items(
269286 return # pragma: no cover
270287 incremental_graph = self ._incremental_graph
271288 if stream_items_result .errors is not None :
272- self . _completed .append (CompletedResult (id_ , stream_items_result .errors ))
289+ context . completed .append (CompletedResult (id_ , stream_items_result .errors ))
273290 incremental_graph .remove_subsequent_result_record (stream_record )
274291 if is_cancellable_stream_record (stream_record ):
275292 cancellable_streams = self ._context .cancellable_streams
@@ -278,7 +295,7 @@ async def _handle_completed_stream_items(
278295 with suppress (Exception ):
279296 await stream_record .early_return
280297 elif stream_items_result .result is None :
281- self . _completed .append (CompletedResult (id_ ))
298+ context . completed .append (CompletedResult (id_ ))
282299 incremental_graph .remove_subsequent_result_record (stream_record )
283300 if is_cancellable_stream_record (stream_record ):
284301 cancellable_streams = self ._context .cancellable_streams
@@ -289,7 +306,7 @@ async def _handle_completed_stream_items(
289306 incremental_entry = IncrementalStreamResult (
290307 items = result .items , id = id_ , errors = result .errors
291308 )
292- self . _incremental .append (incremental_entry )
309+ context . incremental .append (incremental_entry )
293310 if stream_items_result .incremental_data_records : # pragma: no branch
294311 incremental_graph .add_incremental_data_records (
295312 stream_items_result .incremental_data_records
0 commit comments