
Commit 172104c

refactor code
1 parent 7a36af6 commit 172104c

2 files changed: +86 -92 lines changed


src/backend/fastapi_app/rag_advanced.py

Lines changed: 41 additions & 57 deletions
@@ -6,9 +6,10 @@
 from openai_messages_token_helper import build_messages, get_token_limit
 
 from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
+from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher
 from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
-from fastapi_app.rag_simple import RAGChatBase
+from fastapi_app.rag_simple import ChatParams, RAGChatBase
 
 
 class AdvancedRAGChat(RAGChatBase):
@@ -26,15 +27,10 @@ def __init__(
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
 
-    async def run(
-        self,
-        messages: list[ChatCompletionMessageParam],
-        overrides: dict[str, Any] = {},
-    ) -> RetrievalResponse:
-        chat_params = self.get_params(messages, overrides)
-
-        # Generate an optimized keyword search query based on the chat history and the last question
-        query_response_token_limit = 500
+    async def generate_search_query(
+        self, chat_params: ChatParams, query_response_token_limit: int
+    ) -> tuple[list[ChatCompletionMessageParam], Any | str | None, list]:
+        """Generate an optimized keyword search query based on the chat history and the last question"""
         query_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=self.query_prompt_template,
@@ -57,6 +53,12 @@ async def run
 
         query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
 
+        return query_messages, query_text, filters
+
+    async def retreive_and_build_context(
+        self, chat_params: ChatParams, query_text: str | Any | None, filters: list
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item]]:
+        """Retrieve relevant items from the database and build a context for the chat model."""
         # Retrieve relevant items from the database with the GPT optimized query
         results = await self.searcher.search_and_embed(
             query_text,
@@ -70,22 +72,40 @@ async def run
         content = "\n".join(sources_content)
 
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
+            system_prompt=chat_params.prompt_template,
             new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
             past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
+            max_tokens=self.chat_token_limit - chat_params.response_token_limit,
             fallback_to_default=True,
         )
+        return contextual_messages, results
+
+    async def run(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> RetrievalResponse:
+        chat_params = self.get_params(messages, overrides)
+
+        # Generate an optimized keyword search query based on the chat history and the last question
+        query_messages, query_text, filters = await self.generate_search_query(
+            chat_params=chat_params, query_response_token_limit=500
+        )
+
+        # Retrieve relevant items from the database with the GPT optimized query
+        # Generate a contextual and content specific answer using the search results and chat history
+        contextual_messages, results = await self.retreive_and_build_context(
+            chat_params=chat_params, query_text=query_text, filters=filters
+        )
 
         chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=False,
         )
@@ -141,50 +161,14 @@ async def run_stream(
         chat_params = self.get_params(messages, overrides)
 
         # Generate an optimized keyword search query based on the chat history and the last question
-        query_response_token_limit = 500
-        query_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=self.query_prompt_template,
-            new_user_content=chat_params.original_user_query,
-            past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
-            fallback_to_default=True,
+        query_messages, query_text, filters = await self.generate_search_query(
+            chat_params=chat_params, query_response_token_limit=500
         )
 
-        chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
-            messages=query_messages,
-            # Azure OpenAI takes the deployment name as the model name
-            model=self.chat_deployment if self.chat_deployment else self.chat_model,
-            temperature=0.0,  # Minimize creativity for search query generation
-            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, too high risks performance
-            n=1,
-            tools=build_search_function(),
-            tool_choice="auto",
-        )
-
-        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
-
         # Retrieve relevant items from the database with the GPT optimized query
-        results = await self.searcher.search_and_embed(
-            query_text,
-            top=chat_params.top,
-            enable_vector_search=chat_params.enable_vector_search,
-            enable_text_search=chat_params.enable_text_search,
-            filters=filters,
-        )
-
-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
-        content = "\n".join(sources_content)
-
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
-        contextual_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
-            past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
-            fallback_to_default=True,
+        contextual_messages, results = await self.retreive_and_build_context(
+            chat_params=chat_params, query_text=query_text, filters=filters
         )
 
         chat_completion_async_stream: AsyncStream[
@@ -193,8 +177,8 @@ async def run_stream(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=True,
         )
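
Taken together, the hunks above reduce run and run_stream to the same two-step pipeline; the two methods now differ only in the final completion call. A condensed outline of the class after this commit (method names from the diff; bodies elided, so this is a sketch rather than the full implementation):

# Condensed outline of AdvancedRAGChat after this commit (bodies elided with
# "..."; the real class inherits from RAGChatBase and is shown in full above).
class AdvancedRAGChat:
    async def generate_search_query(self, chat_params, query_response_token_limit):
        """Ask the model, with the search tool attached, for an optimized keyword query."""
        ...

    async def retreive_and_build_context(self, chat_params, query_text, filters):
        """Search the database and assemble the grounded prompt messages."""
        ...

    async def run(self, messages, overrides={}):
        chat_params = self.get_params(messages, overrides)
        query_messages, query_text, filters = await self.generate_search_query(
            chat_params=chat_params, query_response_token_limit=500
        )
        contextual_messages, results = await self.retreive_and_build_context(
            chat_params=chat_params, query_text=query_text, filters=filters
        )
        ...  # non-streaming completion call, stream=False

    async def run_stream(self, messages, overrides={}):
        ...  # identical pipeline; the final completion call uses stream=True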

src/backend/fastapi_app/rag_simple.py

Lines changed: 45 additions & 35 deletions
@@ -9,16 +9,19 @@
 from pydantic import BaseModel
 
 from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
+from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher
 
 
 class ChatParams(BaseModel):
-    top: int
-    temperature: float
+    top: int = 3
+    temperature: float = 0.3
+    response_token_limit: int = 1024
     enable_text_search: bool
     enable_vector_search: bool
     original_user_query: str
     past_messages: list[ChatCompletionMessageParam]
+    prompt_template: str
 
 
 class RAGChatBase(ABC):
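
The defaults added to ChatParams mirror the values previously hard-coded in the run() methods, so pydantic now supplies them whenever overrides omit them. A minimal standalone sketch of that behavior (field types simplified; the example values are illustrative):

from pydantic import BaseModel

class ChatParams(BaseModel):
    # Defaults match the former hard-coded values in the run() methods.
    top: int = 3
    temperature: float = 0.3
    response_token_limit: int = 1024
    enable_text_search: bool
    enable_vector_search: bool
    original_user_query: str
    past_messages: list  # list[ChatCompletionMessageParam] in the real code
    prompt_template: str

params = ChatParams(
    enable_text_search=True,
    enable_vector_search=True,
    original_user_query="What are the best shoes for hiking?",  # illustrative
    past_messages=[],
    prompt_template="Answer using only the provided sources.",  # illustrative
)
assert (params.top, params.temperature, params.response_token_limit) == (3, 0.3, 1024)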
@@ -27,17 +30,24 @@ class RAGChatBase(ABC):
     answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
 
     def get_params(self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any]) -> ChatParams:
-        enable_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
-        enable_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
         top: int = overrides.get("top", 3)
         temperature: float = overrides.get("temperature", 0.3)
+        response_token_limit = 1024
+        prompt_template = overrides.get("prompt_template") or self.answer_prompt_template
+
+        enable_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
+        enable_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
+
         original_user_query = messages[-1]["content"]
         if not isinstance(original_user_query, str):
            raise ValueError("The most recent message content must be a string.")
         past_messages = messages[:-1]
+
         return ChatParams(
             top=top,
             temperature=temperature,
+            response_token_limit=response_token_limit,
+            prompt_template=prompt_template,
             enable_text_search=enable_text_search,
             enable_vector_search=enable_vector_search,
             original_user_query=original_user_query,
@@ -52,6 +62,15 @@ async def run(
     ) -> RetrievalResponse:
         raise NotImplementedError
 
+    @abstractmethod
+    async def retreive_and_build_context(
+        self,
+        chat_params: ChatParams,
+        *args,
+        **kwargs,
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item]]:
+        raise NotImplementedError
+
     @abstractmethod
     async def run_stream(
         self,
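
Declaring the abstract method with *args/**kwargs lets each subclass accept different extra arguments while still satisfying the base contract: the simple implementation needs only chat_params, while the advanced one also takes query_text and filters. A minimal sketch of the pattern (class names shortened for illustration):

from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    async def retreive_and_build_context(self, chat_params, *args, **kwargs):
        raise NotImplementedError

class Simple(Base):
    async def retreive_and_build_context(self, chat_params):
        ...  # plain retrieval: only the parsed chat parameters are needed

class Advanced(Base):
    async def retreive_and_build_context(self, chat_params, query_text, filters):
        ...  # query rewriting adds a search query and filters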
@@ -78,12 +97,10 @@ def __init__(
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
 
-    async def run(
-        self,
-        messages: list[ChatCompletionMessageParam],
-        overrides: dict[str, Any] = {},
-    ) -> RetrievalResponse:
-        chat_params = self.get_params(messages, overrides)
+    async def retreive_and_build_context(
+        self, chat_params: ChatParams
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item]]:
+        """Retrieve relevant items from the database and build a context for the chat model."""
 
         # Retrieve relevant items from the database
         results = await self.searcher.search_and_embed(
@@ -97,22 +114,33 @@
         content = "\n".join(sources_content)
 
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
+            system_prompt=chat_params.prompt_template,
             new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
             past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
+            max_tokens=self.chat_token_limit - chat_params.response_token_limit,
             fallback_to_default=True,
         )
+        return contextual_messages, results
+
+    async def run(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> RetrievalResponse:
+        chat_params = self.get_params(messages, overrides)
+
+        # Retrieve relevant items from the database
+        # Generate a contextual and content specific answer using the search results and chat history
+        contextual_messages, results = await self.retreive_and_build_context(chat_params=chat_params)
 
         chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
             temperature=chat_params.temperature,
-            max_tokens=response_token_limit,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=False,
         )
@@ -158,35 +186,17 @@ async def run_stream(
         chat_params = self.get_params(messages, overrides)
 
         # Retrieve relevant items from the database
-        results = await self.searcher.search_and_embed(
-            chat_params.original_user_query,
-            top=chat_params.top,
-            enable_vector_search=chat_params.enable_vector_search,
-            enable_text_search=chat_params.enable_text_search,
-        )
-
-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
-        content = "\n".join(sources_content)
-
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
-        contextual_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
-            past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
-            fallback_to_default=True,
-        )
+        contextual_messages, results = await self.retreive_and_build_context(chat_params=chat_params)
 
         chat_completion_async_stream: AsyncStream[
             ChatCompletionChunk
         ] = await self.openai_chat_client.chat.completions.create(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=True,
         )
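
Because run and run_stream keep their signatures, callers are unaffected: overrides now flow through get_params into a single ChatParams object instead of being read ad hoc inside each method. A hedged usage sketch (the concrete class name and constructor arguments are assumed from context, not shown in this commit):

# Hypothetical caller; "SimpleRAGChat" and its constructor arguments are
# assumptions based on the fields referenced in the diff.
chat = SimpleRAGChat(
    searcher=searcher,                 # a PostgresSearcher instance (assumed)
    openai_chat_client=openai_client,  # an async OpenAI client (assumed)
    chat_model="gpt-4o-mini",          # illustrative model name
    chat_deployment=None,
)

response = await chat.run(
    messages=[{"role": "user", "content": "What are the best shoes for hiking?"}],
    overrides={"temperature": 0.5, "retrieval_mode": "hybrid"},  # optional
)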
