Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit adb2b2e

Browse files
Make mypy happy
1 parent 31e7cc1 commit adb2b2e

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

‎evals/generate_ground_truth.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from azure.identity import AzureDeveloperCliCredential, get_bearer_token_provider
88
from dotenv_azd import load_azd_env
99
from openai import AzureOpenAI, OpenAI
10+
from openai.types.chat import ChatCompletionToolParam
1011
from sqlalchemy import create_engine, select
1112
from sqlalchemy.orm import Session
1213

@@ -15,7 +16,7 @@
1516
logger = logging.getLogger("ragapp")
1617

1718

18-
def qa_pairs_tool(num_questions: int = 1) -> dict:
19+
def qa_pairs_tool(num_questions: int = 1) -> ChatCompletionToolParam:
1920
return {
2021
"type": "function",
2122
"function": {
@@ -45,7 +46,7 @@ def qa_pairs_tool(num_questions: int = 1) -> dict:
4546
}
4647

4748

48-
def source_retriever() -> Generator[dict, None, None]:
49+
def source_retriever() -> Generator[str, None, None]:
4950
# Connect to the database
5051
DBHOST = os.environ["POSTGRES_HOST"]
5152
DBUSER = os.environ["POSTGRES_USERNAME"]
@@ -76,8 +77,9 @@ def answer_formatter(answer, source) -> str:
7677
return f"{answer} [{source['id']}]"
7778

7879

79-
def get_openai_client() -> AzureOpenAI | OpenAI:
80+
def get_openai_client() -> tuple[AzureOpenAI | OpenAI, str]:
8081
"""Return an OpenAI client based on the environment variables"""
82+
openai_client: AzureOpenAI | OpenAI
8183
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
8284
if OPENAI_CHAT_HOST == "azure":
8385
if api_key := os.getenv("AZURE_OPENAI_KEY"):
@@ -101,8 +103,7 @@ def get_openai_client() -> AzureOpenAI | OpenAI:
101103
raise NotImplementedError("Ollama OpenAI Service is not supported. Switch to Azure or OpenAI.com")
102104
else:
103105
logger.info("Using OpenAI Service with API Key from OPENAICOM_KEY")
104-
openai_config = {"api_type": "openai", "api_key": os.environ["OPENAICOM_KEY"]}
105-
openai_client = OpenAI(**openai_config)
106+
openai_client = OpenAI(api_key=os.environ["OPENAICOM_KEY"])
106107
model = os.environ["OPENAICOM_CHAT_MODEL"]
107108
return openai_client, model
108109

@@ -127,6 +128,9 @@ def generate_ground_truth_data(num_questions_total: int, num_questions_per_sourc
127128
],
128129
tools=[qa_pairs_tool(num_questions=2)],
129130
)
131+
if not result.choices[0].message.tool_calls:
132+
logger.warning("No tool calls found in response, skipping")
133+
continue
130134
qa_pairs = json.loads(result.choices[0].message.tool_calls[0].function.arguments)["qa_list"]
131135
qa_pairs = [{"question": qa_pair["question"], "truth": qa_pair["answer"]} for qa_pair in qa_pairs]
132136
qa.extend(qa_pairs)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /