
Commit c5ce017

Merge pull request #61 from john0isaac/add-streaming
Add streaming
2 parents f76747a + 8b32800 commit c5ce017

File tree

13 files changed: +692 additions, -368 deletions


requirements-dev.txt

Lines changed: 2 additions & 0 deletions
@@ -1,10 +1,12 @@
 -r src/backend/requirements.txt
 ruff
+mypy
 pre-commit
 pip-tools
 pip-compile-cross-platform
 pytest
 pytest-cov
 pytest-asyncio
+pytest-snapshot
 mypy
 locust

src/backend/fastapi_app/api_models.py

Lines changed: 42 additions & 2 deletions
@@ -1,17 +1,42 @@
+from enum import Enum
 from typing import Any
 
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
 
+class AIChatRoles(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
 class Message(BaseModel):
     content: str
-    role: str = "user"
+    role: AIChatRoles = AIChatRoles.USER
+
+
+class RetrievalMode(str, Enum):
+    TEXT = "text"
+    VECTORS = "vectors"
+    HYBRID = "hybrid"
+
+
+class ChatRequestOverrides(BaseModel):
+    top: int = 3
+    temperature: float = 0.3
+    retrieval_mode: RetrievalMode = RetrievalMode.HYBRID
+    use_advanced_flow: bool = True
+    prompt_template: str | None = None
+
+
+class ChatRequestContext(BaseModel):
+    overrides: ChatRequestOverrides
 
 
 class ChatRequest(BaseModel):
     messages: list[ChatCompletionMessageParam]
-    context: dict = {}
+    context: ChatRequestContext
 
 
 class ThoughtStep(BaseModel):
@@ -32,6 +57,12 @@ class RetrievalResponse(BaseModel):
     session_state: Any | None = None
 
 
+class RetrievalResponseDelta(BaseModel):
+    delta: Message | None = None
+    context: RAGContext | None = None
+    session_state: Any | None = None
+
+
 class ItemPublic(BaseModel):
     id: int
     type: str
@@ -43,3 +74,12 @@ class ItemPublic(BaseModel):
 
 class ItemWithDistance(ItemPublic):
     distance: float
+
+
+class ChatParams(ChatRequestOverrides):
+    prompt_template: str
+    response_token_limit: int = 1024
+    enable_text_search: bool
+    enable_vector_search: bool
+    original_user_query: str
+    past_messages: list[ChatCompletionMessageParam]
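For orientation (not part of this diff), a request body matching the new models might be assembled as below. The query text and override values are illustrative only, and the serialization call assumes Pydantic v2:

# Hypothetical usage sketch of the new request models; values are examples, not from the PR.
from fastapi_app.api_models import (
    ChatRequest,
    ChatRequestContext,
    ChatRequestOverrides,
    RetrievalMode,
)

request = ChatRequest(
    messages=[{"role": "user", "content": "What climbing gear do you have?"}],
    context=ChatRequestContext(
        overrides=ChatRequestOverrides(top=3, temperature=0.3, retrieval_mode=RetrievalMode.HYBRID)
    ),
)
print(request.model_dump_json(indent=2))  # model_dump_json assumes Pydantic v2; use .json() on v1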

src/backend/fastapi_app/rag_advanced.py

Lines changed: 118 additions & 57 deletions
@@ -1,19 +1,25 @@
-import pathlib
 from collections.abc import AsyncGenerator
-from typing import (
-    Any,
-)
+from typing import Any
 
-from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
 from openai_messages_token_helper import build_messages, get_token_limit
 
-from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
-from .postgres_searcher import PostgresSearcher
-from .query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.api_models import (
+    AIChatRoles,
+    Message,
+    RAGContext,
+    RetrievalResponse,
+    RetrievalResponseDelta,
+    ThoughtStep,
+)
+from fastapi_app.postgres_models import Item
+from fastapi_app.postgres_searcher import PostgresSearcher
+from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.rag_base import ChatParams, RAGChatBase
 
 
-class AdvancedRAGChat:
+class AdvancedRAGChat(RAGChatBase):
     def __init__(
         self,
         *,
@@ -27,24 +33,11 @@ def __init__(
         self.chat_model = chat_model
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
-        current_dir = pathlib.Path(__file__).parent
-        self.query_prompt_template = open(current_dir / "prompts/query.txt").read()
-        self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
-
-    async def run(
-        self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}
-    ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]:
-        text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
-        vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
-        top = overrides.get("top", 3)
-
-        original_user_query = messages[-1]["content"]
-        if not isinstance(original_user_query, str):
-            raise ValueError("The most recent message content must be a string.")
-        past_messages = messages[:-1]
-
-        # Generate an optimized keyword search query based on the chat history and the last question
-        query_response_token_limit = 500
+
+    async def generate_search_query(
+        self, original_user_query: str, past_messages: list[ChatCompletionMessageParam], query_response_token_limit: int
+    ) -> tuple[list[ChatCompletionMessageParam], Any | str | None, list]:
+        """Generate an optimized keyword search query based on the chat history and the last question"""
         query_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=self.query_prompt_template,
@@ -67,68 +60,128 @@ async def run(
 
         query_text, filters = extract_search_arguments(original_user_query, chat_completion)
 
+        return query_messages, query_text, filters
+
+    async def prepare_context(
+        self, chat_params: ChatParams
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
+        query_messages, query_text, filters = await self.generate_search_query(
+            original_user_query=chat_params.original_user_query,
+            past_messages=chat_params.past_messages,
+            query_response_token_limit=500,
+        )
+
         # Retrieve relevant items from the database with the GPT optimized query
         results = await self.searcher.search_and_embed(
             query_text,
-            top=top,
-            enable_vector_search=vector_search,
-            enable_text_search=text_search,
+            top=chat_params.top,
+            enable_vector_search=chat_params.enable_vector_search,
+            enable_text_search=chat_params.enable_text_search,
             filters=filters,
         )
 
         sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
         content = "\n".join(sources_content)
 
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=original_user_query + "\n\nSources:\n" + content,
-            past_messages=past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
+            system_prompt=chat_params.prompt_template,
+            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
+            past_messages=chat_params.past_messages,
+            max_tokens=self.chat_token_limit - chat_params.response_token_limit,
             fallback_to_default=True,
         )
 
+        thoughts = [
+            ThoughtStep(
+                title="Prompt to generate search arguments",
+                description=[str(message) for message in query_messages],
+                props=(
+                    {"model": self.chat_model, "deployment": self.chat_deployment}
+                    if self.chat_deployment
+                    else {"model": self.chat_model}
+                ),
+            ),
+            ThoughtStep(
+                title="Search using generated search arguments",
+                description=query_text,
+                props={
+                    "top": chat_params.top,
+                    "vector_search": chat_params.enable_vector_search,
+                    "text_search": chat_params.enable_text_search,
+                    "filters": filters,
+                },
+            ),
+            ThoughtStep(
+                title="Search results",
+                description=[result.to_dict() for result in results],
+            ),
+        ]
+        return contextual_messages, results, thoughts
+
+    async def answer(
+        self,
+        chat_params: ChatParams,
+        contextual_messages: list[ChatCompletionMessageParam],
+        results: list[Item],
+        earlier_thoughts: list[ThoughtStep],
+    ) -> RetrievalResponse:
         chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=False,
         )
-        first_choice_message = chat_completion_response.choices[0].message
 
         return RetrievalResponse(
-            message=Message(content=str(first_choice_message.content), role=first_choice_message.role),
+            message=Message(
+                content=str(chat_completion_response.choices[0].message.content), role=AIChatRoles.ASSISTANT
+            ),
             context=RAGContext(
                 data_points={item.id: item.to_dict() for item in results},
-                thoughts=[
+                thoughts=earlier_thoughts
+                + [
                     ThoughtStep(
-                        title="Prompt to generate search arguments",
-                        description=[str(message) for message in query_messages],
+                        title="Prompt to generate answer",
+                        description=[str(message) for message in contextual_messages],
                         props=(
                             {"model": self.chat_model, "deployment": self.chat_deployment}
                             if self.chat_deployment
                             else {"model": self.chat_model}
                         ),
                     ),
-                    ThoughtStep(
-                        title="Search using generated search arguments",
-                        description=query_text,
-                        props={
-                            "top": top,
-                            "vector_search": vector_search,
-                            "text_search": text_search,
-                            "filters": filters,
-                        },
-                    ),
-                    ThoughtStep(
-                        title="Search results",
-                        description=[result.to_dict() for result in results],
-                    ),
+                ],
+            ),
+        )
+
+    async def answer_stream(
+        self,
+        chat_params: ChatParams,
+        contextual_messages: list[ChatCompletionMessageParam],
+        results: list[Item],
+        earlier_thoughts: list[ThoughtStep],
+    ) -> AsyncGenerator[RetrievalResponseDelta, None]:
+        chat_completion_async_stream: AsyncStream[
+            ChatCompletionChunk
+        ] = await self.openai_chat_client.chat.completions.create(
+            # Azure OpenAI takes the deployment name as the model name
+            model=self.chat_deployment if self.chat_deployment else self.chat_model,
+            messages=contextual_messages,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
+            n=1,
+            stream=True,
+        )
+
+        yield RetrievalResponseDelta(
+            context=RAGContext(
+                data_points={item.id: item.to_dict() for item in results},
+                thoughts=earlier_thoughts
+                + [
                     ThoughtStep(
                         title="Prompt to generate answer",
                         description=[str(message) for message in contextual_messages],
@@ -141,3 +194,11 @@ async def run(
                 ],
             ),
        )
+
+        async for response_chunk in chat_completion_async_stream:
+            # first response has empty choices and last response has empty content
+            if response_chunk.choices and response_chunk.choices[0].delta.content:
+                yield RetrievalResponseDelta(
+                    delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT)
+                )
+        return
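For context, this refactor splits the old run() into get_params (inherited from RAGChatBase), prepare_context, and answer / answer_stream. A minimal sketch of how a caller might drive the streaming path; run_streaming_chat, rag_flow, messages, and overrides are assumed names (an already constructed AdvancedRAGChat plus request data from the API layer), not part of this diff:

# Illustrative sketch only; rag_flow, messages, and overrides are assumed to exist.
async def run_streaming_chat(rag_flow, messages, overrides):
    chat_params = rag_flow.get_params(messages, overrides)  # inherited from RAGChatBase
    contextual_messages, results, thoughts = await rag_flow.prepare_context(chat_params)
    async for delta in rag_flow.answer_stream(chat_params, contextual_messages, results, thoughts):
        if delta.context is not None:
            # the first yielded delta carries the RAGContext (data points + thoughts)
            print(f"context: {len(delta.context.data_points)} data points")
        if delta.delta is not None:
            # subsequent deltas carry incremental assistant content
            print(delta.delta.content, end="", flush=True)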

src/backend/fastapi_app/rag_base.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+import pathlib
+from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
+
+from openai.types.chat import ChatCompletionMessageParam
+
+from fastapi_app.api_models import (
+    ChatParams,
+    ChatRequestOverrides,
+    RetrievalResponse,
+    RetrievalResponseDelta,
+    ThoughtStep,
+)
+from fastapi_app.postgres_models import Item
+
+
+class RAGChatBase(ABC):
+    current_dir = pathlib.Path(__file__).parent
+    query_prompt_template = open(current_dir / "prompts/query.txt").read()
+    answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
+
+    def get_params(self, messages: list[ChatCompletionMessageParam], overrides: ChatRequestOverrides) -> ChatParams:
+        response_token_limit = 1024
+        prompt_template = overrides.prompt_template or self.answer_prompt_template
+
+        enable_text_search = overrides.retrieval_mode in ["text", "hybrid", None]
+        enable_vector_search = overrides.retrieval_mode in ["vectors", "hybrid", None]
+
+        original_user_query = messages[-1]["content"]
+        if not isinstance(original_user_query, str):
+            raise ValueError("The most recent message content must be a string.")
+        past_messages = messages[:-1]
+
+        return ChatParams(
+            top=overrides.top,
+            temperature=overrides.temperature,
+            retrieval_mode=overrides.retrieval_mode,
+            use_advanced_flow=overrides.use_advanced_flow,
+            response_token_limit=response_token_limit,
+            prompt_template=prompt_template,
+            enable_text_search=enable_text_search,
+            enable_vector_search=enable_vector_search,
+            original_user_query=original_user_query,
+            past_messages=past_messages,
+        )
+
+    @abstractmethod
+    async def prepare_context(
+        self, chat_params: ChatParams
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
+        raise NotImplementedError
+
+    @abstractmethod
+    async def answer(
+        self,
+        chat_params: ChatParams,
+        contextual_messages: list[ChatCompletionMessageParam],
+        results: list[Item],
+        earlier_thoughts: list[ThoughtStep],
+    ) -> RetrievalResponse:
+        raise NotImplementedError
+
+    @abstractmethod
+    async def answer_stream(
+        self,
+        chat_params: ChatParams,
+        contextual_messages: list[ChatCompletionMessageParam],
+        results: list[Item],
+        earlier_thoughts: list[ThoughtStep],
+    ) -> AsyncGenerator[RetrievalResponseDelta, None]:
+        raise NotImplementedError
+        if False:
+            yield 0
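Side note on the unreachable "if False: yield 0" after raise NotImplementedError: the presence of a yield is what makes the abstract method an async generator function, so the declared AsyncGenerator return type of answer_stream lines up for type checkers. The same idiom in isolation; this is a standalone illustration, not code from this repo:

from collections.abc import AsyncGenerator


class Base:
    async def stream(self) -> AsyncGenerator[int, None]:
        raise NotImplementedError
        if False:  # never executes, but the yield turns this into an async generator function
            yield 0


class Impl(Base):
    async def stream(self) -> AsyncGenerator[int, None]:
        for i in range(3):
            yield i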

0 commit comments
