Commit 62502b1

Finish refactoring of rag flows

1 parent 076f367 commit 62502b1

File tree

4 files changed: +164 -209 lines changed


src/backend/fastapi_app/rag_advanced.py

Lines changed: 65 additions & 96 deletions
@@ -1,10 +1,8 @@
-import os
 from collections.abc import AsyncGenerator
 from typing import Optional, TypedDict, Union
 
-from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
-from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
-from openai_messages_token_helper import get_token_limit
+from openai import AsyncAzureOpenAI, AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam
 from pydantic_ai import Agent, RunContext
 from pydantic_ai.messages import ModelMessagesTypeAdapter
 from pydantic_ai.models.openai import OpenAIModel
@@ -13,22 +11,17 @@
 
 from fastapi_app.api_models import (
     AIChatRoles,
+    ChatRequestOverrides,
     ItemPublic,
     Message,
     RAGContext,
     RetrievalResponse,
     RetrievalResponseDelta,
     ThoughtStep,
 )
-from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher
 from fastapi_app.rag_base import ChatParams, RAGChatBase
 
-# Experiment #1: Annotated did not work!
-# Experiment #2: Function-level docstring, Inline docstrings next to attributes
-# Function -level docstring leads to XML like this: <summary>Search ...
-# Experiment #3: Move the docstrings below the attributes in triple-quoted strings - SUCCESS!!!
-
 
 class PriceFilter(TypedDict):
     column: str = "price"
@@ -64,19 +57,44 @@ class SearchResults(TypedDict):
 
 
 class AdvancedRAGChat(RAGChatBase):
+    query_prompt_template = open(RAGChatBase.prompts_dir / "query.txt").read()
+    query_fewshots = open(RAGChatBase.prompts_dir / "query_fewshots.json").read()
+
     def __init__(
         self,
         *,
+        messages: list[ChatCompletionMessageParam],
+        overrides: ChatRequestOverrides,
         searcher: PostgresSearcher,
         openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
         chat_model: str,
         chat_deployment: Optional[str],  # Not needed for non-Azure OpenAI
     ):
         self.searcher = searcher
-        self.openai_chat_client = openai_chat_client
-        self.chat_model = chat_model
-        self.chat_deployment = chat_deployment
-        self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
+        self.chat_params = self.get_chat_params(messages, overrides)
+        self.model_for_thoughts = (
+            {"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
+        )
+        pydantic_chat_model = OpenAIModel(
+            chat_model if chat_deployment is None else chat_deployment,
+            provider=OpenAIProvider(openai_client=openai_chat_client),
+        )
+        self.search_agent = Agent(
+            pydantic_chat_model,
+            model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=self.chat_params.seed),
+            system_prompt=self.query_prompt_template,
+            tools=[self.search_database],
+            output_type=SearchResults,
+        )
+        self.answer_agent = Agent(
+            pydantic_chat_model,
+            system_prompt=self.answer_prompt_template,
+            model_settings=ModelSettings(
+                temperature=self.chat_params.temperature,
+                max_tokens=self.chat_params.response_token_limit,
+                seed=self.chat_params.seed,
+            ),
+        )
 
     async def search_database(
         self,
@@ -113,42 +131,28 @@ async def search_database(
             query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
         )
 
-    async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPublic], list[ThoughtStep]]:
-        model = OpenAIModel(
-            os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=self.openai_chat_client)
-        )
-        agent = Agent(
-            model,
-            model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=chat_params.seed),
-            system_prompt=self.query_prompt_template,
-            tools=[self.search_database],
-            output_type=SearchResults,
-        )
+    async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
         few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
-        user_query = f"Find search results for user query: {chat_params.original_user_query}"
-        results = await agent.run(
+        user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
+        results = await self.search_agent.run(
            user_query,
-            message_history=few_shots + chat_params.past_messages,
-            deps=chat_params,
+            message_history=few_shots + self.chat_params.past_messages,
+            deps=self.chat_params,
         )
         items = results.output["items"]
         thoughts = [
             ThoughtStep(
                 title="Prompt to generate search arguments",
                 description=results.all_messages(),
-                props=(
-                    {"model": self.chat_model, "deployment": self.chat_deployment}
-                    if self.chat_deployment
-                    else {"model": self.chat_model}  # TODO
-                ),
+                props=self.model_for_thoughts,
             ),
             ThoughtStep(
                 title="Search using generated search arguments",
                 description=results.output["query"],
                 props={
-                    "top": chat_params.top,
-                    "vector_search": chat_params.enable_vector_search,
-                    "text_search": chat_params.enable_text_search,
+                    "top": self.chat_params.top,
+                    "vector_search": self.chat_params.enable_vector_search,
+                    "text_search": self.chat_params.enable_text_search,
                     "filters": results.output["filters"],
                 },
             ),
@@ -161,25 +165,12 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPubli
 
     async def answer(
         self,
-        chat_params: ChatParams,
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> RetrievalResponse:
-        agent = Agent(
-            OpenAIModel(
-                os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
-                provider=OpenAIProvider(openai_client=self.openai_chat_client),
-            ),
-            system_prompt=self.answer_prompt_template,
-            model_settings=ModelSettings(
-                temperature=chat_params.temperature, max_tokens=chat_params.response_token_limit, seed=chat_params.seed
-            ),
-        )
-
-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in items]
-        response = await agent.run(
-            user_prompt=chat_params.original_user_query + "Sources:\n" + "\n".join(sources_content),
-            message_history=chat_params.past_messages,
+        response = await self.answer_agent.run(
+            user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
+            message_history=self.chat_params.past_messages,
         )
 
         return RetrievalResponse(
@@ -191,57 +182,35 @@ async def answer(
                     ThoughtStep(
                         title="Prompt to generate answer",
                         description=response.all_messages(),
-                        props=(
-                            {"model": self.chat_model, "deployment": self.chat_deployment}
-                            if self.chat_deployment
-                            else {"model": self.chat_model}
-                        ),
+                        props=self.model_for_thoughts,
                     ),
                 ],
             ),
         )
 
     async def answer_stream(
         self,
-        chat_params: ChatParams,
-        contextual_messages: list[ChatCompletionMessageParam],
-        results: list[Item],
+        items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> AsyncGenerator[RetrievalResponseDelta, None]:
-        chat_completion_async_stream: AsyncStream[
-            ChatCompletionChunk
-        ] = await self.openai_chat_client.chat.completions.create(
-            # Azure OpenAI takes the deployment name as the model name
-            model=self.chat_deployment if self.chat_deployment else self.chat_model,
-            messages=contextual_messages,
-            temperature=chat_params.temperature,
-            max_tokens=chat_params.response_token_limit,
-            n=1,
-            stream=True,
-        )
-
-        yield RetrievalResponseDelta(
-            context=RAGContext(
-                data_points={item.id: item.to_dict() for item in results},
-                thoughts=earlier_thoughts
-                + [
-                    ThoughtStep(
-                        title="Prompt to generate answer",
-                        description=contextual_messages,
-                        props=(
-                            {"model": self.chat_model, "deployment": self.chat_deployment}
-                            if self.chat_deployment
-                            else {"model": self.chat_model}
+        async with self.answer_agent.run_stream(
+            self.prepare_rag_request(self.chat_params.original_user_query, items),
+            message_history=self.chat_params.past_messages,
+        ) as agent_stream_runner:
+            yield RetrievalResponseDelta(
+                context=RAGContext(
+                    data_points={item.id: item for item in items},
+                    thoughts=earlier_thoughts
+                    + [
+                        ThoughtStep(
+                            title="Prompt to generate answer",
+                            description=agent_stream_runner.all_messages(),
+                            props=self.model_for_thoughts,
                         ),
-                    ),
-                ],
-            ),
-        )
+                    ],
+                ),
+            )
 
-        async for response_chunk in chat_completion_async_stream:
-            # first response has empty choices and last response has empty content
-            if response_chunk.choices and response_chunk.choices[0].delta.content:
-                yield RetrievalResponseDelta(
-                    delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT)
-                )
-        return
+            async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
+                yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
+            return

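For orientation, a minimal caller-side sketch of the refactored flow: the two agents are now built once in __init__, and prepare_context/answer/answer_stream read self.chat_params instead of taking a chat_params argument. The searcher and openai_chat_client setup, the model name, and the ChatRequestOverrides() defaults below are assumptions for illustration, not part of this commit.

from fastapi_app.api_models import ChatRequestOverrides
from fastapi_app.rag_advanced import AdvancedRAGChat


async def run_flow(searcher, openai_chat_client) -> None:
    # Conversation and overrides are bound at construction time, so the
    # flow methods below take no chat_params argument.
    rag_flow = AdvancedRAGChat(
        messages=[{"role": "user", "content": "Find me a waterproof tent"}],
        overrides=ChatRequestOverrides(),  # assumed defaults for illustration
        searcher=searcher,  # an already-configured PostgresSearcher
        openai_chat_client=openai_chat_client,
        chat_model="gpt-4o-mini",  # placeholder model name
        chat_deployment=None,  # only set for Azure OpenAI deployments
    )
    items, thoughts = await rag_flow.prepare_context()

    # Non-streaming: one RetrievalResponse carrying the answer and thought steps.
    response = await rag_flow.answer(items, earlier_thoughts=thoughts)
    print(response)

    # Streaming: a context delta first, then assistant text deltas.
    async for delta in rag_flow.answer_stream(items, earlier_thoughts=thoughts):
        if delta.delta:
            print(delta.delta.content, end="", flush=True)

Constructing the agents once per chat object also replaces the deleted per-call os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"] lookup with the chat_model/chat_deployment passed to the constructor.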
src/backend/fastapi_app/rag_base.py

Lines changed: 10 additions & 5 deletions
@@ -7,6 +7,7 @@
 from fastapi_app.api_models import (
     ChatParams,
     ChatRequestOverrides,
+    ItemPublic,
     RetrievalResponse,
     RetrievalResponseDelta,
     ThoughtStep,
@@ -15,12 +16,12 @@
 
 
 class RAGChatBase(ABC):
-    current_dir = pathlib.Path(__file__).parent
-    query_prompt_template = open(current_dir / "prompts/query.txt").read()
-    query_fewshots = open(current_dir / "prompts/query_fewshots.json").read()
-    answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
+    prompts_dir = pathlib.Path(__file__).parent / "prompts/"
+    answer_prompt_template = open(prompts_dir / "answer.txt").read()
 
-    def get_params(self, messages: list[ChatCompletionMessageParam], overrides: ChatRequestOverrides) -> ChatParams:
+    def get_chat_params(
+        self, messages: list[ChatCompletionMessageParam], overrides: ChatRequestOverrides
+    ) -> ChatParams:
         response_token_limit = 1024
         prompt_template = overrides.prompt_template or self.answer_prompt_template
 
@@ -52,6 +53,10 @@ async def prepare_context(
     ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
         raise NotImplementedError
 
+    def prepare_rag_request(self, user_query, items: list[ItemPublic]) -> str:
+        sources_str = "\n".join([f"[{item.id}]:{item.to_str_for_rag()}" for item in items])
+        return f"{user_query}Sources:\n{sources_str}"
+
     @abstractmethod
     async def answer(
         self,

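To make the new shared prepare_rag_request helper concrete, here is a self-contained sketch of the prompt string it assembles. FakeItem is a hypothetical stand-in for ItemPublic, whose to_str_for_rag() flattens an item into a single text block:

from dataclasses import dataclass


@dataclass
class FakeItem:
    id: int
    text: str

    def to_str_for_rag(self) -> str:
        return self.text


def prepare_rag_request(user_query, items) -> str:
    # Same formatting logic as the new base-class helper.
    sources_str = "\n".join([f"[{item.id}]:{item.to_str_for_rag()}" for item in items])
    return f"{user_query}Sources:\n{sources_str}"


items = [FakeItem(1, "Name: Tent Price: 99.99"), FakeItem(2, "Name: Stove Price: 49.99")]
print(prepare_rag_request("Find me camping gear. ", items))
# Find me camping gear. Sources:
# [1]:Name: Tent Price: 99.99
# [2]:Name: Stove Price: 49.99

Note that the helper joins sources with single newlines, dropping the trailing blank lines that the deleted inline sources_content list appended after each item.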
0 commit comments