Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 292517b

Browse files
Refactor, please mypy
1 parent 33b5295 commit 292517b

File tree

19 files changed

+283
-93
lines changed

19 files changed

+283
-93
lines changed

‎.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@
3636
"htmlcov": true,
3737
".mypy_cache": true,
3838
".coverage": true
39-
}
39+
},
40+
"python.REPL.enableREPLSmartSend": false
4041
}

‎requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ pytest-snapshot
1414
locust
1515
psycopg2
1616
dotenv-azd
17+
freezegun

‎src/backend/fastapi_app/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[State]:
3838
if (
3939
os.getenv("OPENAI_CHAT_HOST") == "azure"
4040
or os.getenv("OPENAI_EMBED_HOST") == "azure"
41-
or os.getenv("POSTGRES_HOST").endswith(".database.azure.com")
41+
or os.getenv("POSTGRES_HOST", "").endswith(".database.azure.com")
4242
):
4343
azure_credential = await get_azure_credential()
4444
engine = await create_postgres_engine_from_env(azure_credential)

‎src/backend/fastapi_app/api_models.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from enum import Enum
2-
from typing import Any, Optional
2+
from typing import Any, Optional, Union
33

44
from openai.types.chat import ChatCompletionMessageParam
5-
from pydantic import BaseModel
5+
from pydantic import BaseModel, Field
6+
from pydantic_ai.messages import ModelRequest, ModelResponse
67

78

89
class AIChatRoles(str, Enum):
@@ -95,4 +96,33 @@ class ChatParams(ChatRequestOverrides):
9596
enable_text_search: bool
9697
enable_vector_search: bool
9798
original_user_query: str
98-
past_messages: list[ChatCompletionMessageParam]
99+
past_messages: list[Union[ModelRequest, ModelResponse]]
100+
101+
102+
class Filter(BaseModel):
103+
column: str
104+
comparison_operator: str
105+
value: Any
106+
107+
108+
class PriceFilter(Filter):
109+
column: str = Field(default="price", description="The column to filter on (always 'price' for this filter)")
110+
comparison_operator: str = Field(description="The operator for price comparison ('>', '<', '>=', '<=', '=')")
111+
value: float = Field(description="The price value to compare against (e.g., 30.00)")
112+
113+
114+
class BrandFilter(Filter):
115+
column: str = Field(default="brand", description="The column to filter on (always 'brand' for this filter)")
116+
comparison_operator: str = Field(description="The operator for brand comparison ('=' or '!=')")
117+
value: str = Field(description="The brand name to compare against (e.g., 'AirStrider')")
118+
119+
120+
class SearchResults(BaseModel):
121+
query: str
122+
"""The original search query"""
123+
124+
items: list[ItemPublic]
125+
"""List of items that match the search query and filters"""
126+
127+
filters: list[Filter]
128+
"""List of filters applied to the search results"""

‎src/backend/fastapi_app/openai_clients.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
async def create_openai_chat_client(
12-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
12+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
1313
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
1414
openai_chat_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
1515
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
@@ -29,7 +29,7 @@ async def create_openai_chat_client(
2929
azure_deployment=azure_deployment,
3030
api_key=api_key,
3131
)
32-
else:
32+
elif azure_credential:
3333
logger.info(
3434
"Setting up Azure OpenAI client for chat completions using Azure Identity, endpoint %s, deployment %s",
3535
azure_endpoint,
@@ -44,6 +44,8 @@ async def create_openai_chat_client(
4444
azure_deployment=azure_deployment,
4545
azure_ad_token_provider=token_provider,
4646
)
47+
else:
48+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
4749
elif OPENAI_CHAT_HOST == "ollama":
4850
logger.info("Setting up OpenAI client for chat completions using Ollama")
4951
openai_chat_client = openai.AsyncOpenAI(
@@ -67,7 +69,7 @@ async def create_openai_chat_client(
6769

6870

6971
async def create_openai_embed_client(
70-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
72+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
7173
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
7274
openai_embed_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
7375
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
@@ -87,7 +89,7 @@ async def create_openai_embed_client(
8789
azure_deployment=azure_deployment,
8890
api_key=api_key,
8991
)
90-
else:
92+
elif azure_credential:
9193
logger.info(
9294
"Setting up Azure OpenAI client for embeddings using Azure Identity, endpoint %s, deployment %s",
9395
azure_endpoint,
@@ -102,6 +104,8 @@ async def create_openai_embed_client(
102104
azure_deployment=azure_deployment,
103105
azure_ad_token_provider=token_provider,
104106
)
107+
else:
108+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
105109
elif OPENAI_EMBED_HOST == "ollama":
106110
logger.info("Setting up OpenAI client for embeddings using Ollama")
107111
openai_embed_client = openai.AsyncOpenAI(

‎src/backend/fastapi_app/postgres_searcher.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy import Float, Integer, column, select, text
66
from sqlalchemy.ext.asyncio import AsyncSession
77

8+
from fastapi_app.api_models import Filter
89
from fastapi_app.embeddings import compute_text_embedding
910
from fastapi_app.postgres_models import Item
1011

@@ -26,21 +27,24 @@ def __init__(
2627
self.embed_dimensions = embed_dimensions
2728
self.embedding_column = embedding_column
2829

29-
def build_filter_clause(self, filters) -> tuple[str, str]:
30+
def build_filter_clause(self, filters: Optional[list[Filter]]) -> tuple[str, str]:
3031
if filters is None:
3132
return "", ""
3233
filter_clauses = []
3334
for filter in filters:
34-
if isinstance(filter["value"], str):
35-
filter["value"] = f"'{filter['value']}'"
36-
filter_clauses.append(f"{filter['column']} {filter['comparison_operator']} {filter['value']}")
35+
filter_value = f"'{filter.value}'" if isinstance(filter.value, str) else filter.value
36+
filter_clauses.append(f"{filter.column} {filter.comparison_operator} {filter_value}")
3737
filter_clause = " AND ".join(filter_clauses)
3838
if len(filter_clause) > 0:
3939
return f"WHERE {filter_clause}", f"AND {filter_clause}"
4040
return "", ""
4141

4242
async def search(
43-
self, query_text: Optional[str], query_vector: list[float], top: int = 5, filters: Optional[list[dict]] = None
43+
self,
44+
query_text: Optional[str],
45+
query_vector: list[float],
46+
top: int = 5,
47+
filters: Optional[list[Filter]] = None,
4448
):
4549
filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
4650
table_name = Item.__tablename__
@@ -106,7 +110,7 @@ async def search_and_embed(
106110
top: int = 5,
107111
enable_vector_search: bool = False,
108112
enable_text_search: bool = False,
109-
filters: Optional[list[dict]] = None,
113+
filters: Optional[list[Filter]] = None,
110114
) -> list[Item]:
111115
"""
112116
Search rows by query text. Optionally converts the query text to a vector if enable_vector_search is True.

‎src/backend/fastapi_app/rag_advanced.py

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import AsyncGenerator
2-
from typing import Optional, TypedDict, Union
2+
from typing import Optional, Union
33

44
from openai import AsyncAzureOpenAI, AsyncOpenAI
55
from openai.types.chat import ChatCompletionMessageParam
@@ -11,51 +11,22 @@
1111

1212
from fastapi_app.api_models import (
1313
AIChatRoles,
14+
BrandFilter,
1415
ChatRequestOverrides,
16+
Filter,
1517
ItemPublic,
1618
Message,
19+
PriceFilter,
1720
RAGContext,
1821
RetrievalResponse,
1922
RetrievalResponseDelta,
23+
SearchResults,
2024
ThoughtStep,
2125
)
2226
from fastapi_app.postgres_searcher import PostgresSearcher
2327
from fastapi_app.rag_base import ChatParams, RAGChatBase
2428

2529

26-
class PriceFilter(TypedDict):
27-
column: str = "price"
28-
"""The column to filter on (always 'price' for this filter)"""
29-
30-
comparison_operator: str
31-
"""The operator for price comparison ('>', '<', '>=', '<=', '=')"""
32-
33-
value: float
34-
""" The price value to compare against (e.g., 30.00) """
35-
36-
37-
class BrandFilter(TypedDict):
38-
column: str = "brand"
39-
"""The column to filter on (always 'brand' for this filter)"""
40-
41-
comparison_operator: str
42-
"""The operator for brand comparison ('=' or '!=')"""
43-
44-
value: str
45-
"""The brand name to compare against (e.g., 'AirStrider')"""
46-
47-
48-
class SearchResults(TypedDict):
49-
query: str
50-
"""The original search query"""
51-
52-
items: list[ItemPublic]
53-
"""List of items that match the search query and filters"""
54-
55-
filters: list[Union[PriceFilter, BrandFilter]]
56-
"""List of filters applied to the search results"""
57-
58-
5930
class AdvancedRAGChat(RAGChatBase):
6031
query_prompt_template = open(RAGChatBase.prompts_dir / "query.txt").read()
6132
query_fewshots = open(RAGChatBase.prompts_dir / "query_fewshots.json").read()
@@ -79,9 +50,13 @@ def __init__(
7950
chat_model if chat_deployment is None else chat_deployment,
8051
provider=OpenAIProvider(openai_client=openai_chat_client),
8152
)
82-
self.search_agent = Agent(
53+
self.search_agent = Agent[ChatParams, SearchResults](
8354
pydantic_chat_model,
84-
model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=self.chat_params.seed),
55+
model_settings=ModelSettings(
56+
temperature=0.0,
57+
max_tokens=500,
58+
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
59+
),
8560
system_prompt=self.query_prompt_template,
8661
tools=[self.search_database],
8762
output_type=SearchResults,
@@ -92,7 +67,7 @@ def __init__(
9267
model_settings=ModelSettings(
9368
temperature=self.chat_params.temperature,
9469
max_tokens=self.chat_params.response_token_limit,
95-
seed=self.chat_params.seed,
70+
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
9671
),
9772
)
9873

@@ -115,7 +90,7 @@ async def search_database(
11590
List of formatted items that match the search query and filters
11691
"""
11792
# Only send non-None filters
118-
filters = []
93+
filters: list[Filter] = []
11994
if price_filter:
12095
filters.append(price_filter)
12196
if brand_filter:
@@ -134,12 +109,12 @@ async def search_database(
134109
async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
135110
few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
136111
user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
137-
results = await self.search_agent.run(
112+
results = await self.search_agent.run(  # type: ignore[call-overload]
138113
user_query,
139114
message_history=few_shots + self.chat_params.past_messages,
140115
deps=self.chat_params,
141116
)
142-
items = results.output["items"]
117+
items = results.output.items
143118
thoughts = [
144119
ThoughtStep(
145120
title="Prompt to generate search arguments",
@@ -148,12 +123,12 @@ async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
148123
),
149124
ThoughtStep(
150125
title="Search using generated search arguments",
151-
description=results.output["query"],
126+
description=results.output.query,
152127
props={
153128
"top": self.chat_params.top,
154129
"vector_search": self.chat_params.enable_vector_search,
155130
"text_search": self.chat_params.enable_text_search,
156-
"filters": results.output["filters"],
131+
"filters": results.output.filters,
157132
},
158133
),
159134
ThoughtStep(

‎src/backend/fastapi_app/rag_base.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import pathlib
22
from abc import ABC, abstractmethod
33
from collections.abc import AsyncGenerator
4+
from typing import Union
45

56
from openai.types.chat import ChatCompletionMessageParam
7+
from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, UserPromptPart
68

79
from fastapi_app.api_models import (
810
ChatParams,
@@ -12,7 +14,6 @@
1214
RetrievalResponseDelta,
1315
ThoughtStep,
1416
)
15-
from fastapi_app.postgres_models import Item
1617

1718

1819
class RAGChatBase(ABC):
@@ -31,7 +32,19 @@ def get_chat_params(
3132
original_user_query = messages[-1]["content"]
3233
if not isinstance(original_user_query, str):
3334
raise ValueError("The most recent message content must be a string.")
34-
past_messages = messages[:-1]
35+
36+
# Convert to PydanticAI format:
37+
past_messages: list[Union[ModelRequest, ModelResponse]] = []
38+
for message in messages[:-1]:
39+
content = message["content"]
40+
if not isinstance(content, str):
41+
raise ValueError("All messages must have string content.")
42+
if message["role"] == "user":
43+
past_messages.append(ModelRequest(parts=[UserPromptPart(content=content)]))
44+
elif message["role"] == "assistant":
45+
past_messages.append(ModelResponse(parts=[TextPart(content=content)]))
46+
else:
47+
raise ValueError(f"Cannot convert message: {message}")
3548

3649
return ChatParams(
3750
top=overrides.top,
@@ -48,9 +61,7 @@ def get_chat_params(
4861
)
4962

5063
@abstractmethod
51-
async def prepare_context(
52-
self, chat_params: ChatParams
53-
) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
64+
async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
5465
raise NotImplementedError
5566

5667
def prepare_rag_request(self, user_query, items: list[ItemPublic]) -> str:
@@ -60,19 +71,15 @@ def prepare_rag_request(self, user_query, items: list[ItemPublic]) -> str:
6071
@abstractmethod
6172
async def answer(
6273
self,
63-
chat_params: ChatParams,
64-
contextual_messages: list[ChatCompletionMessageParam],
65-
results: list[Item],
74+
items: list[ItemPublic],
6675
earlier_thoughts: list[ThoughtStep],
6776
) -> RetrievalResponse:
6877
raise NotImplementedError
6978

7079
@abstractmethod
7180
async def answer_stream(
7281
self,
73-
chat_params: ChatParams,
74-
contextual_messages: list[ChatCompletionMessageParam],
75-
results: list[Item],
82+
items: list[ItemPublic],
7683
earlier_thoughts: list[ThoughtStep],
7784
) -> AsyncGenerator[RetrievalResponseDelta, None]:
7885
raise NotImplementedError

‎src/backend/fastapi_app/rag_simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
model_settings=ModelSettings(
4949
temperature=self.chat_params.temperature,
5050
max_tokens=self.chat_params.response_token_limit,
51-
seed=self.chat_params.seed,
51+
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
5252
),
5353
)
5454

‎src/backend/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ dependencies = [
1919
"opentelemetry-instrumentation-sqlalchemy",
2020
"opentelemetry-instrumentation-aiohttp-client",
2121
"opentelemetry-instrumentation-openai",
22-
"pydantic-ai"
22+
"pydantic-ai-slim[openai]"
2323
]
2424

2525
[build-system]

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /