Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 8c052dc

Browse files
fix type for streaming to conform with Microsoft Chat Protocol
1 parent 9af3296 commit 8c052dc

File tree

5 files changed

+138
-105
lines changed

5 files changed

+138
-105
lines changed

src/backend/fastapi_app/api_models.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
from enum import Enum
12
from typing import Any
23

34
from openai.types.chat import ChatCompletionMessageParam
45
from pydantic import BaseModel
56

67

8+
class AIChatRoles(str, Enum):
9+
USER = "user"
10+
ASSISTANT = "assistant"
11+
SYSTEM = "system"
12+
13+
714
class Message(BaseModel):
815
content: str
9-
role: str = "user"
16+
role: AIChatRoles = AIChatRoles.USER
1017

1118

1219
class ChatRequest(BaseModel):
@@ -32,6 +39,12 @@ class RetrievalResponse(BaseModel):
3239
session_state: Any | None = None
3340

3441

42+
class RetrievalResponseDelta(BaseModel):
43+
delta: Message | None = None
44+
context: RAGContext | None = None
45+
session_state: Any | None = None
46+
47+
3548
class ItemPublic(BaseModel):
3649
id: int
3750
type: str

src/backend/fastapi_app/rag_advanced.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,14 @@
55
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
66
from openai_messages_token_helper import build_messages, get_token_limit
77

8-
from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
8+
from fastapi_app.api_models import (
9+
AIChatRoles,
10+
Message,
11+
RAGContext,
12+
RetrievalResponse,
13+
RetrievalResponseDelta,
14+
ThoughtStep,
15+
)
916
from fastapi_app.postgres_models import Item
1017
from fastapi_app.postgres_searcher import PostgresSearcher
1118
from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
@@ -110,10 +117,10 @@ async def run(
110117
stream=False,
111118
)
112119

113-
first_choice_message = chat_completion_response.choices[0].message
114-
115120
return RetrievalResponse(
116-
message=Message(content=str(first_choice_message.content), role=first_choice_message.role),
121+
message=Message(
122+
content=str(chat_completion_response.choices[0].message.content), role=AIChatRoles.ASSISTANT
123+
),
117124
context=RAGContext(
118125
data_points={item.id: item.to_dict() for item in results},
119126
thoughts=[
@@ -157,7 +164,7 @@ async def run_stream(
157164
self,
158165
messages: list[ChatCompletionMessageParam],
159166
overrides: dict[str, Any] = {},
160-
) -> AsyncGenerator[RetrievalResponse|Message, None]:
167+
) -> AsyncGenerator[RetrievalResponseDelta, None]:
161168
chat_params = self.get_params(messages, overrides)
162169

163170
# Generate an optimized keyword search query based on the chat history and the last question
@@ -188,8 +195,7 @@ async def run_stream(
188195
# The connection closes when it returns back to the context manger in the dependencies
189196
await self.searcher.db_session.close()
190197

191-
yield RetrievalResponse(
192-
message=Message(content="", role="assistant"),
198+
yield RetrievalResponseDelta(
193199
context=RAGContext(
194200
data_points={item.id: item.to_dict() for item in results},
195201
thoughts=[
@@ -230,7 +236,9 @@ async def run_stream(
230236
)
231237

232238
async for response_chunk in chat_completion_async_stream:
233-
# first response has empty choices
234-
if response_chunk.choices:
235-
yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant")
239+
# first response has empty choices and last response has empty content
240+
if response_chunk.choices and response_chunk.choices[0].delta.content:
241+
yield RetrievalResponseDelta(
242+
delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT)
243+
)
236244
return

src/backend/fastapi_app/rag_simple.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@
88
from openai_messages_token_helper import build_messages, get_token_limit
99
from pydantic import BaseModel
1010

11-
from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
11+
from fastapi_app.api_models import (
12+
AIChatRoles,
13+
Message,
14+
RAGContext,
15+
RetrievalResponse,
16+
RetrievalResponseDelta,
17+
ThoughtStep,
18+
)
1219
from fastapi_app.postgres_models import Item
1320
from fastapi_app.postgres_searcher import PostgresSearcher
1421

@@ -76,7 +83,7 @@ async def run_stream(
7683
self,
7784
messages: list[ChatCompletionMessageParam],
7885
overrides: dict[str, Any] = {},
79-
) -> AsyncGenerator[RetrievalResponse|Message, None]:
86+
) -> AsyncGenerator[RetrievalResponseDelta, None]:
8087
raise NotImplementedError
8188
if False:
8289
yield 0
@@ -145,10 +152,10 @@ async def run(
145152
stream=False,
146153
)
147154

148-
first_choice_message = chat_completion_response.choices[0].message
149-
150155
return RetrievalResponse(
151-
message=Message(content=str(first_choice_message.content), role=first_choice_message.role),
156+
message=Message(
157+
content=str(chat_completion_response.choices[0].message.content), role=AIChatRoles.ASSISTANT
158+
),
152159
context=RAGContext(
153160
data_points={item.id: item.to_dict() for item in results},
154161
thoughts=[
@@ -182,7 +189,7 @@ async def run_stream(
182189
self,
183190
messages: list[ChatCompletionMessageParam],
184191
overrides: dict[str, Any] = {},
185-
) -> AsyncGenerator[RetrievalResponse|Message, None]:
192+
) -> AsyncGenerator[RetrievalResponseDelta, None]:
186193
chat_params = self.get_params(messages, overrides)
187194

188195
# Retrieve relevant items from the database
@@ -206,8 +213,7 @@ async def run_stream(
206213
# The connection closes when it returns to the context manager in the dependencies
207214
await self.searcher.db_session.close()
208215

209-
yield RetrievalResponse(
210-
message=Message(content="", role="assistant"),
216+
yield RetrievalResponseDelta(
211217
context=RAGContext(
212218
data_points={item.id: item.to_dict() for item in results},
213219
thoughts=[
@@ -237,7 +243,9 @@ async def run_stream(
237243
),
238244
)
239245
async for response_chunk in chat_completion_async_stream:
240-
# first response has empty choices
241-
if response_chunk.choices:
242-
yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant")
246+
# first response has empty choices and last response has empty content
247+
if response_chunk.choices and response_chunk.choices[0].delta.content:
248+
yield RetrievalResponseDelta(
249+
delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT)
250+
)
243251
return

src/backend/fastapi_app/routes/api_routes.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
from fastapi.responses import StreamingResponse
88
from sqlalchemy import select
99

10-
from fastapi_app.api_models import ChatRequest, ItemPublic, ItemWithDistance, Message, RetrievalResponse
10+
from fastapi_app.api_models import (
11+
ChatRequest,
12+
ItemPublic,
13+
ItemWithDistance,
14+
RetrievalResponse,
15+
RetrievalResponseDelta,
16+
)
1117
from fastapi_app.dependencies import ChatClient, CommonDeps, DBSession, EmbeddingsClient
1218
from fastapi_app.postgres_models import Item
1319
from fastapi_app.postgres_searcher import PostgresSearcher
@@ -17,13 +23,13 @@
1723
router = fastapi.APIRouter()
1824

1925

20-
async def format_as_ndjson(r: AsyncGenerator[RetrievalResponse|Message, None]) -> AsyncGenerator[str, None]:
26+
async def format_as_ndjson(r: AsyncGenerator[RetrievalResponseDelta, None]) -> AsyncGenerator[str, None]:
2127
"""
2228
Format the response as NDJSON
2329
"""
2430
try:
2531
async for event in r:
26-
yield json.dumps(event.model_dump(), ensure_ascii=False) + "\n"
32+
yield event.model_dump_json() + "\n"
2733
except Exception as error:
2834
logging.exception("Exception while generating response stream: %s", error)
2935
yield json.dumps({"error": str(error)}, ensure_ascii=False) + "\n"

0 commit comments

Comments
(0)

Page converted by AltStyle (-> original) /