1
1
from enum import Enum
2
- from typing import Any , Optional
2
+ from typing import Any , Optional , Union
3
3
4
4
from openai .types .chat import ChatCompletionMessageParam
5
- from pydantic import BaseModel
5
+ from pydantic import BaseModel , Field
6
+ from pydantic_ai .messages import ModelRequest , ModelResponse
6
7
7
8
8
9
class AIChatRoles (str , Enum ):
@@ -41,14 +42,34 @@ class ChatRequest(BaseModel):
41
42
sessionState : Optional [Any ] = None
42
43
43
44
45
+ class ItemPublic (BaseModel ):
46
+ id : int
47
+ type : str
48
+ brand : str
49
+ name : str
50
+ description : str
51
+ price : float
52
+
53
+ def to_str_for_rag (self ):
54
+ return f"Name:{ self .name } Description:{ self .description } Price:{ self .price } Brand:{ self .brand } Type:{ self .type } "
55
+
56
+
57
+ class ItemWithDistance (ItemPublic ):
58
+ distance : float
59
+
60
+ def __init__ (self , ** data ):
61
+ super ().__init__ (** data )
62
+ self .distance = round (self .distance , 2 )
63
+
64
+
44
65
class ThoughtStep (BaseModel ):
45
66
title : str
46
67
description : Any
47
68
props : dict = {}
48
69
49
70
50
71
class RAGContext (BaseModel ):
51
- data_points : dict [int , dict [ str , Any ] ]
72
+ data_points : dict [int , ItemPublic ]
52
73
thoughts : list [ThoughtStep ]
53
74
followup_questions : Optional [list [str ]] = None
54
75
@@ -69,27 +90,39 @@ class RetrievalResponseDelta(BaseModel):
69
90
sessionState : Optional [Any ] = None
70
91
71
92
72
- class ItemPublic (BaseModel ):
73
- id : int
74
- type : str
75
- brand : str
76
- name : str
77
- description : str
78
- price : float
79
-
80
-
81
- class ItemWithDistance (ItemPublic ):
82
- distance : float
83
-
84
- def __init__ (self , ** data ):
85
- super ().__init__ (** data )
86
- self .distance = round (self .distance , 2 )
87
-
88
-
89
93
class ChatParams (ChatRequestOverrides ):
90
94
prompt_template : str
91
95
response_token_limit : int = 1024
92
96
enable_text_search : bool
93
97
enable_vector_search : bool
94
98
original_user_query : str
95
- past_messages : list [ChatCompletionMessageParam ]
99
+ past_messages : list [Union [ModelRequest , ModelResponse ]]
100
+
101
+
102
+ class Filter (BaseModel ):
103
+ column : str
104
+ comparison_operator : str
105
+ value : Any
106
+
107
+
108
+ class PriceFilter (Filter ):
109
+ column : str = Field (default = "price" , description = "The column to filter on (always 'price' for this filter)" )
110
+ comparison_operator : str = Field (description = "The operator for price comparison ('>', '<', '>=', '<=', '=')" )
111
+ value : float = Field (description = "The price value to compare against (e.g., 30.00)" )
112
+
113
+
114
+ class BrandFilter (Filter ):
115
+ column : str = Field (default = "brand" , description = "The column to filter on (always 'brand' for this filter)" )
116
+ comparison_operator : str = Field (description = "The operator for brand comparison ('=' or '!=')" )
117
+ value : str = Field (description = "The brand name to compare against (e.g., 'AirStrider')" )
118
+
119
+
120
+ class SearchResults (BaseModel ):
121
+ query : str
122
+ """The original search query"""
123
+
124
+ items : list [ItemPublic ]
125
+ """List of items that match the search query and filters"""
126
+
127
+ filters : list [Filter ]
128
+ """List of filters applied to the search results"""
0 commit comments