Commit 7a36af6

refactor and add streaming functions
1 parent 406eb24 commit 7a36af6

File tree

2 files changed: +286 −59 lines


src/backend/fastapi_app/rag_advanced.py

Lines changed: 136 additions & 34 deletions
@@ -1,19 +1,17 @@
-import pathlib
 from collections.abc import AsyncGenerator
-from typing import (
-    Any,
-)
+from typing import Any
 
-from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
 from openai_messages_token_helper import build_messages, get_token_limit
 
-from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
-from .postgres_searcher import PostgresSearcher
-from .query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
+from fastapi_app.postgres_searcher import PostgresSearcher
+from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.rag_simple import RAGChatBase
 
 
-class AdvancedRAGChat:
+class AdvancedRAGChat(RAGChatBase):
     def __init__(
         self,
         *,
@@ -27,29 +25,21 @@ def __init__(
         self.chat_model = chat_model
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
-        current_dir = pathlib.Path(__file__).parent
-        self.query_prompt_template = open(current_dir / "prompts/query.txt").read()
-        self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
 
     async def run(
-        self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}
-    ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]:
-        text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
-        vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
-        top = overrides.get("top", 3)
-
-        original_user_query = messages[-1]["content"]
-        if not isinstance(original_user_query, str):
-            raise ValueError("The most recent message content must be a string.")
-        past_messages = messages[:-1]
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> RetrievalResponse:
+        chat_params = self.get_params(messages, overrides)
 
         # Generate an optimized keyword search query based on the chat history and the last question
         query_response_token_limit = 500
         query_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=self.query_prompt_template,
-            new_user_content=original_user_query,
-            past_messages=past_messages,
+            new_user_content=chat_params.original_user_query,
+            past_messages=chat_params.past_messages,
             max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
             fallback_to_default=True,
         )
@@ -65,14 +55,14 @@ async def run(
             tool_choice="auto",
         )
 
-        query_text, filters = extract_search_arguments(original_user_query, chat_completion)
+        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
 
         # Retrieve relevant items from the database with the GPT optimized query
        results = await self.searcher.search_and_embed(
             query_text,
-            top=top,
-            enable_vector_search=vector_search,
-            enable_text_search=text_search,
+            top=chat_params.top,
+            enable_vector_search=chat_params.enable_vector_search,
+            enable_text_search=chat_params.enable_text_search,
             filters=filters,
         )
 
@@ -84,8 +74,8 @@ async def run(
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=original_user_query + "\n\nSources:\n" + content,
-            past_messages=past_messages,
+            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
+            past_messages=chat_params.past_messages,
             max_tokens=self.chat_token_limit - response_token_limit,
             fallback_to_default=True,
         )
@@ -99,6 +89,7 @@ async def run(
             n=1,
             stream=False,
         )
+
         first_choice_message = chat_completion_response.choices[0].message
 
         return RetrievalResponse(
@@ -119,9 +110,9 @@ async def run(
                     title="Search using generated search arguments",
                     description=query_text,
                     props={
-                        "top": top,
-                        "vector_search": vector_search,
-                        "text_search": text_search,
+                        "top": chat_params.top,
+                        "vector_search": chat_params.enable_vector_search,
+                        "text_search": chat_params.enable_text_search,
                         "filters": filters,
                     },
                 ),
@@ -141,3 +132,114 @@ async def run(
                 ],
             ),
         )
+
+    async def run_stream(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> AsyncGenerator[RetrievalResponse | Message, None]:
+        chat_params = self.get_params(messages, overrides)
+
+        # Generate an optimized keyword search query based on the chat history and the last question
+        query_response_token_limit = 500
+        query_messages: list[ChatCompletionMessageParam] = build_messages(
+            model=self.chat_model,
+            system_prompt=self.query_prompt_template,
+            new_user_content=chat_params.original_user_query,
+            past_messages=chat_params.past_messages,
+            max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
+            fallback_to_default=True,
+        )
+
+        chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
+            messages=query_messages,
+            # Azure OpenAI takes the deployment name as the model name
+            model=self.chat_deployment if self.chat_deployment else self.chat_model,
+            temperature=0.0,  # Minimize creativity for search query generation
+            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, too high risks performance
+            n=1,
+            tools=build_search_function(),
+            tool_choice="auto",
+        )
+
+        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
+
+        # Retrieve relevant items from the database with the GPT optimized query
+        results = await self.searcher.search_and_embed(
+            query_text,
+            top=chat_params.top,
+            enable_vector_search=chat_params.enable_vector_search,
+            enable_text_search=chat_params.enable_text_search,
+            filters=filters,
+        )
+
+        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
+        content = "\n".join(sources_content)
+
+        # Generate a contextual and content specific answer using the search results and chat history
+        response_token_limit = 1024
+        contextual_messages: list[ChatCompletionMessageParam] = build_messages(
+            model=self.chat_model,
+            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
+            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
+            past_messages=chat_params.past_messages,
+            max_tokens=self.chat_token_limit - response_token_limit,
+            fallback_to_default=True,
+        )
+
+        chat_completion_async_stream: AsyncStream[
+            ChatCompletionChunk
+        ] = await self.openai_chat_client.chat.completions.create(
+            # Azure OpenAI takes the deployment name as the model name
+            model=self.chat_deployment if self.chat_deployment else self.chat_model,
+            messages=contextual_messages,
+            temperature=overrides.get("temperature", 0.3),
+            max_tokens=response_token_limit,
+            n=1,
+            stream=True,
+        )
+
+        yield RetrievalResponse(
+            message=Message(content="", role="assistant"),
+            context=RAGContext(
+                data_points={item.id: item.to_dict() for item in results},
+                thoughts=[
+                    ThoughtStep(
+                        title="Prompt to generate search arguments",
+                        description=[str(message) for message in query_messages],
+                        props=(
+                            {"model": self.chat_model, "deployment": self.chat_deployment}
+                            if self.chat_deployment
+                            else {"model": self.chat_model}
+                        ),
+                    ),
+                    ThoughtStep(
+                        title="Search using generated search arguments",
+                        description=query_text,
+                        props={
+                            "top": chat_params.top,
+                            "vector_search": chat_params.enable_vector_search,
+                            "text_search": chat_params.enable_text_search,
+                            "filters": filters,
+                        },
+                    ),
+                    ThoughtStep(
+                        title="Search results",
+                        description=[result.to_dict() for result in results],
+                    ),
+                    ThoughtStep(
+                        title="Prompt to generate answer",
+                        description=[str(message) for message in contextual_messages],
+                        props=(
+                            {"model": self.chat_model, "deployment": self.chat_deployment}
+                            if self.chat_deployment
+                            else {"model": self.chat_model}
+                        ),
+                    ),
+                ],
+            ),
+        )
+
+        async for response_chunk in chat_completion_async_stream:
+            yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant")
+        return
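
For reference, a minimal consumer sketch (not part of this commit) showing one way a caller might iterate the new run_stream generator. It assumes an already-configured AdvancedRAGChat instance named chat and relies only on the shapes visible in the diff above: a first RetrievalResponse carrying the RAG context, followed by incremental Message chunks. The example question is hypothetical.

# Sketch only: `chat` is assumed to be an AdvancedRAGChat wired to a searcher and OpenAI client.
import asyncio

from fastapi_app.api_models import Message, RetrievalResponse


async def print_streamed_answer(chat, messages, overrides=None):
    # run_stream yields a RetrievalResponse first (empty message plus context),
    # then one Message per ChatCompletionChunk with partial answer text.
    async for event in chat.run_stream(messages, overrides or {}):
        if isinstance(event, RetrievalResponse):
            print("sources:", list(event.context.data_points.keys()))
        elif isinstance(event, Message):
            print(event.content, end="", flush=True)


# Hypothetical invocation:
# asyncio.run(print_streamed_answer(chat, [{"role": "user", "content": "What tents do you sell?"}]))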

0 commit comments

