From 1bec14b678618e1671f6c0ac4916eb9ba01c57d4 Mon Sep 17 00:00:00 2001 From: jiangmin Date: 2025年3月13日 18:51:31 +0800 Subject: [PATCH 1/3] =?UTF-8?q?bug=E4=BF=AE=E5=A4=8D:=E4=BF=AE=E8=AE=A2?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=92=8C=E5=AF=BC=E5=8C=85=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_process/{open_benchmarks => }/main.py | 8 ++++---- data_process/open_benchmarks/README.md | 2 +- data_process/open_benchmarks/dataset_utils/hotpotqa.py | 2 +- data_process/open_benchmarks/dataset_utils/musique.py | 2 +- data_process/open_benchmarks/dataset_utils/nq.py | 2 +- data_process/open_benchmarks/dataset_utils/popqa.py | 2 +- data_process/open_benchmarks/dataset_utils/triviaqa.py | 2 +- data_process/open_benchmarks/dataset_utils/two_wiki.py | 4 ++-- data_process/open_benchmarks/dataset_utils/webqa.py | 2 +- data_process/open_benchmarks/reformat_dataset.py | 2 +- data_process/open_benchmarks/sample_dataset.py | 8 ++++---- data_process/open_benchmarks/utils/filepaths.py | 2 +- data_process/open_benchmarks/utils/wikidata.py | 2 +- data_process/open_benchmarks/utils/wikipedia.py | 2 +- 14 files changed, 21 insertions(+), 21 deletions(-) rename data_process/{open_benchmarks => }/main.py (91%) diff --git a/data_process/open_benchmarks/main.py b/data_process/main.py similarity index 91% rename from data_process/open_benchmarks/main.py rename to data_process/main.py index 9e56a35..c7671d3 100644 --- a/data_process/open_benchmarks/main.py +++ b/data_process/main.py @@ -8,10 +8,10 @@ import yaml -from data_process.reformat_dataset import reformat_dataset -from data_process.sample_dataset import sample_datasets -from data_process.utils.filepaths import get_dataset_dir, get_document_dir, get_split_filepath -from data_process.utils.stats import check_dataset_split +from open_benchmarks.reformat_dataset import reformat_dataset +from open_benchmarks.sample_dataset import sample_datasets +from 
open_benchmarks.utils.filepaths import get_dataset_dir, get_document_dir, get_split_filepath +from open_benchmarks.utils.stats import check_dataset_split def load_yaml_config(config_path: str, args: argparse.Namespace) -> dict: diff --git a/data_process/open_benchmarks/README.md b/data_process/open_benchmarks/README.md index 3af91fb..c5050b2 100644 --- a/data_process/open_benchmarks/README.md +++ b/data_process/open_benchmarks/README.md @@ -5,7 +5,7 @@ Assume that you are in the directory `data_process/`: ```sh -python main.py config/datasets.yaml +python main.py open_benchmarks/config/datasets.yaml ``` ## QA Protocol Overview diff --git a/data_process/open_benchmarks/dataset_utils/hotpotqa.py b/data_process/open_benchmarks/dataset_utils/hotpotqa.py index 227b5ec..f8471a1 100644 --- a/data_process/open_benchmarks/dataset_utils/hotpotqa.py +++ b/data_process/open_benchmarks/dataset_utils/hotpotqa.py @@ -8,7 +8,7 @@ import uuid -from data_process.utils.question_type import infer_question_type +from open_benchmarks.utils.question_type import infer_question_type split2url: Dict[str, str] = { diff --git a/data_process/open_benchmarks/dataset_utils/musique.py b/data_process/open_benchmarks/dataset_utils/musique.py index 8537518..237ba17 100644 --- a/data_process/open_benchmarks/dataset_utils/musique.py +++ b/data_process/open_benchmarks/dataset_utils/musique.py @@ -9,7 +9,7 @@ import uuid import jsonlines -from data_process.utils.question_type import infer_question_type +from open_benchmarks.utils.question_type import infer_question_type zipfile_id = "1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h" diff --git a/data_process/open_benchmarks/dataset_utils/nq.py b/data_process/open_benchmarks/dataset_utils/nq.py index a353633..b225c49 100644 --- a/data_process/open_benchmarks/dataset_utils/nq.py +++ b/data_process/open_benchmarks/dataset_utils/nq.py @@ -8,7 +8,7 @@ from bs4 import BeautifulSoup from datasets import Dataset, load_dataset -from data_process.utils.question_type import 
infer_nq_question_type +from open_benchmarks.utils.question_type import infer_nq_question_type def clean_text(text: str) -> str: diff --git a/data_process/open_benchmarks/dataset_utils/popqa.py b/data_process/open_benchmarks/dataset_utils/popqa.py index 2600d66..be00388 100644 --- a/data_process/open_benchmarks/dataset_utils/popqa.py +++ b/data_process/open_benchmarks/dataset_utils/popqa.py @@ -8,7 +8,7 @@ from datasets import Dataset, load_dataset from tqdm import tqdm -from data_process.utils.question_type import infer_question_type +from open_benchmarks.utils.question_type import infer_question_type def load_raw_data(dataset_dir: str, split: str) -> Dataset: diff --git a/data_process/open_benchmarks/dataset_utils/triviaqa.py b/data_process/open_benchmarks/dataset_utils/triviaqa.py index 5c76efe..025199e 100644 --- a/data_process/open_benchmarks/dataset_utils/triviaqa.py +++ b/data_process/open_benchmarks/dataset_utils/triviaqa.py @@ -6,7 +6,7 @@ import uuid from datasets import Dataset, load_dataset -from data_process.utils.question_type import infer_question_type +from open_benchmarks.utils.question_type import infer_question_type def load_raw_data(dataset_dir: str, split: str) -> Dataset: diff --git a/data_process/open_benchmarks/dataset_utils/two_wiki.py b/data_process/open_benchmarks/dataset_utils/two_wiki.py index e988ab0..821c4ec 100644 --- a/data_process/open_benchmarks/dataset_utils/two_wiki.py +++ b/data_process/open_benchmarks/dataset_utils/two_wiki.py @@ -10,8 +10,8 @@ import uuid import jsonlines -from data_process.dataset_utils.hotpotqa import get_supporting_facts -from data_process.utils.question_type import infer_question_type +from open_benchmarks.dataset_utils.hotpotqa import get_supporting_facts +from open_benchmarks.utils.question_type import infer_question_type default_name: str = "data_ids_april7.zip?rlkey=u868q6h0jojw4djjg7ea65j46" diff --git a/data_process/open_benchmarks/dataset_utils/webqa.py 
b/data_process/open_benchmarks/dataset_utils/webqa.py index 541721e..95467ed 100644 --- a/data_process/open_benchmarks/dataset_utils/webqa.py +++ b/data_process/open_benchmarks/dataset_utils/webqa.py @@ -6,7 +6,7 @@ import uuid from datasets import Dataset, load_dataset -from data_process.utils.question_type import infer_question_type +from open_benchmarks.utils.question_type import infer_question_type def load_raw_data(dataset_dir: str, split: str) -> Dataset: diff --git a/data_process/open_benchmarks/reformat_dataset.py b/data_process/open_benchmarks/reformat_dataset.py index 6ebde11..6d3f297 100644 --- a/data_process/open_benchmarks/reformat_dataset.py +++ b/data_process/open_benchmarks/reformat_dataset.py @@ -9,7 +9,7 @@ def get_dataset_utils_module(dataset: str): - module = importlib.import_module(f"data_process.dataset_utils.{dataset}") + module = importlib.import_module(f"open_benchmarks.dataset_utils.{dataset}") return module diff --git a/data_process/open_benchmarks/sample_dataset.py b/data_process/open_benchmarks/sample_dataset.py index d282b0a..6cd77b9 100644 --- a/data_process/open_benchmarks/sample_dataset.py +++ b/data_process/open_benchmarks/sample_dataset.py @@ -8,10 +8,10 @@ import numpy as np from tqdm import tqdm -from data_process.utils.filepaths import get_doc_location_filepath, get_download_filepaths, get_title_status_filepath -from data_process.utils.io import dump_to_json_file, load_from_json_file, dump_to_jsonlines, load_from_jsonlines -from data_process.utils.stats import SOURCE_TYPES_TO_DOWNLOAD -from data_process.utils import wikidata, wikipedia +from open_benchmarks.utils.filepaths import get_doc_location_filepath, get_download_filepaths, get_title_status_filepath +from open_benchmarks.utils.io import dump_to_json_file, load_from_json_file, dump_to_jsonlines, load_from_jsonlines +from open_benchmarks.utils.stats import SOURCE_TYPES_TO_DOWNLOAD +from open_benchmarks.utils import wikidata, wikipedia def load_caches(document_dir: str) -> 
Tuple[Dict[str, Dict[str, Dict[str, str]]], Dict[str, Dict[str, bool]]]: diff --git a/data_process/open_benchmarks/utils/filepaths.py b/data_process/open_benchmarks/utils/filepaths.py index 0f9f9bd..b0214b2 100644 --- a/data_process/open_benchmarks/utils/filepaths.py +++ b/data_process/open_benchmarks/utils/filepaths.py @@ -4,7 +4,7 @@ import os from typing import Dict, Literal, Optional -from data_process.utils.stats import FILE_TYPES_TO_DOWNLOAD, SOURCE_TYPES_TO_DOWNLOAD +from open_benchmarks.utils.stats import FILE_TYPES_TO_DOWNLOAD, SOURCE_TYPES_TO_DOWNLOAD def get_dataset_dir(root_dir: str, dataset: str) -> str: diff --git a/data_process/open_benchmarks/utils/wikidata.py b/data_process/open_benchmarks/utils/wikidata.py index 1a95175..a5a4a1b 100644 --- a/data_process/open_benchmarks/utils/wikidata.py +++ b/data_process/open_benchmarks/utils/wikidata.py @@ -9,7 +9,7 @@ import aiohttp from bs4 import BeautifulSoup -from data_process.utils.io import dump_bytes_to_file, dump_texts_to_file, async_dump_bytes_to_file +from open_benchmarks.utils.io import dump_bytes_to_file, dump_texts_to_file, async_dump_bytes_to_file def parse_contents(html_content: str) -> dict: diff --git a/data_process/open_benchmarks/utils/wikipedia.py b/data_process/open_benchmarks/utils/wikipedia.py index dd1921a..194ea2e 100644 --- a/data_process/open_benchmarks/utils/wikipedia.py +++ b/data_process/open_benchmarks/utils/wikipedia.py @@ -10,7 +10,7 @@ import wikipediaapi from wikipediaapi import WikipediaPage, WikipediaPageSection -from data_process.utils.io import dump_bytes_to_file, dump_texts_to_file, async_dump_bytes_to_file +from open_benchmarks.utils.io import dump_bytes_to_file, dump_texts_to_file, async_dump_bytes_to_file WIKI_WIKI = wikipediaapi.Wikipedia('Microsoft Research Asia PIKE-RAG', 'en') From 9b0f4c21b8ca67c7dbc6c436ea4baf31a1b11b15 Mon Sep 17 00:00:00 2001 From: Thunder <1639908@qq.com> Date: 2025年3月27日 12:07:28 +0800 Subject: [PATCH 2/3] add ollama LLM client --- 
pikerag/llm_client/__init__.py | 3 +- pikerag/llm_client/ollama_llm_client.py | 111 ++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 pikerag/llm_client/ollama_llm_client.py diff --git a/pikerag/llm_client/__init__.py b/pikerag/llm_client/__init__.py index e09d590..83e81fb 100644 --- a/pikerag/llm_client/__init__.py +++ b/pikerag/llm_client/__init__.py @@ -5,6 +5,7 @@ from pikerag.llm_client.azure_open_ai_client import AzureOpenAIClient from pikerag.llm_client.base import BaseLLMClient from pikerag.llm_client.hf_meta_llama_client import HFMetaLlamaClient +from pikerag.llm_client.ollama_llm_client import OllamaLLMClient -__all__ = ["AzureMetaLlamaClient", "AzureOpenAIClient", "BaseLLMClient", "HFMetaLlamaClient"] +__all__ = ["AzureMetaLlamaClient", "AzureOpenAIClient", "BaseLLMClient", "HFMetaLlamaClient", "OllamaLLMClient"] diff --git a/pikerag/llm_client/ollama_llm_client.py b/pikerag/llm_client/ollama_llm_client.py new file mode 100644 index 0000000..5a3569b --- /dev/null +++ b/pikerag/llm_client/ollama_llm_client.py @@ -0,0 +1,111 @@ +import json +import os +import re +import time +from typing import Callable, List, Literal, Optional, Union + +from langchain_core.embeddings import Embeddings +from ollama import Client as OllamaClient # Assuming an equivalent Ollama client exists +from pickledb import PickleDB + +from pikerag.llm_client.base import BaseLLMClient +from pikerag.utils.logger import Logger + +def parse_wait_time_from_error(error: Exception) -> Optional[int]: + try: + info_str: str = str(error) + matches = re.search(r"Try again in (\d+) seconds", info_str) + if matches: + return int(matches.group(1)) + 3 # Wait an additional 3 seconds + except Exception: + pass + return None + +class OllamaLLMClient(BaseLLMClient): + NAME = "OllamaLLMClient" + + def __init__( + self, location: str = None, auto_dump: bool = True, logger: Logger = None, + max_attempt: int = 5, exponential_backoff_factor: int = None, 
unit_wait_time: int = 60, **kwargs, + ) -> None: + super().__init__(location, auto_dump, logger, max_attempt, exponential_backoff_factor, unit_wait_time, **kwargs) + + # client_configs = kwargs.get("client_config", {}) + # print(client_configs.get("OLLAMA_HOST", None) ) + # if client_configs.get("OLLAMA_HOST", None) is None and os.environ.get("OLLAMA_HOST", None) is None: + # client_configs["base_url"] = "http://10.5.108.210:11434" + base_url = os.environ.get("OLLAMA_HOST") + + self._client = OllamaClient(**kwargs.get("client_config", {})) + + #self._client = OllamaClient(**kwargs.get("client_config", {})) + + def _get_response_with_messages(self, messages: List[dict], **llm_config) -> dict: + response = None + num_attempt = 0 + while num_attempt < self._max_attempt: + try: + response = self._client.chat(messages=messages, **llm_config) + break + except Exception as e: + self.warning(f"Request failed due to: {e}") + num_attempt += 1 + wait_time = parse_wait_time_from_error(e) or (self._unit_wait_time * num_attempt) + time.sleep(wait_time) + self.warning("Retrying...") + return response + + def _get_content_from_response(self, response: dict, messages: List[dict] = None) -> str: + try: + resp = response.get("message", {}).get("content", "") + print(resp) + #resp = response.get("message", {}).get("content", "")[0] + return resp + except Exception as e: + self.warning(f"Error extracting content: {e}") + return "" + + def close(self): + super().close() + +class OllamaEmbedding(Embeddings): + def __init__(self, **kwargs) -> None: + client_configs = kwargs.get("client_config", {}) + base_url = client_configs.get("OLLAMA_URL") + #model = client_configs.get("OLLAMA_MODEL") + embed_model = client_configs.get("OLLAMA_EMBED_MODEL") + + self._client = OllamaClient(base_url=base_url, model=embed_model) + self._model = kwargs.get("model", "nomic-embed-text:latest") + cache_config = kwargs.get("cache_config", {}) + self._cache = PickleDB(location=cache_config.get("location")) if 
cache_config.get("location") else None + + def _save_cache(self, query: str, embedding: List[float]) -> None: + if self._cache: + self._cache.set(query, embedding) + + def _get_cache(self, query: str) -> Union[List[float], Literal[False]]: + return self._cache.get(query) if self._cache else False + + def _get_response(self, texts: Union[str, List[str]]) -> dict: + while True: + try: + return self._client.embeddings(input=texts, model=self._model) + except Exception as e: + wait_time = parse_wait_time_from_error(e) or 30 + self.warning(f"Embedding request failed: {e}, waiting {wait_time} seconds...") + time.sleep(wait_time) + + def embed_documents(self, texts: List[str], batch_call: bool = False) -> List[List[float]]: + if batch_call: + response = self._get_response(texts) + return [res["embedding"] for res in response["data"]] + return [self.embed_query(text) for text in texts] + + def embed_query(self, text: str) -> List[float]: + embedding = self._get_cache(text) + if embedding is False: + response = self._get_response(text) + embedding = response["data"][0]["embedding"] + self._save_cache(text, embedding) + return embedding From e80af9cd69584717390b3fcfa719232a6447c655 Mon Sep 17 00:00:00 2001 From: Thunder <1639908@qq.com> Date: 2025年3月27日 15:07:59 +0800 Subject: [PATCH 3/3] 20250327 --- pikerag/llm_client/azure_open_ai_client.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pikerag/llm_client/azure_open_ai_client.py b/pikerag/llm_client/azure_open_ai_client.py index a920fb1..80b0350 100644 --- a/pikerag/llm_client/azure_open_ai_client.py +++ b/pikerag/llm_client/azure_open_ai_client.py @@ -66,10 +66,13 @@ def __init__( super().__init__(location, auto_dump, logger, max_attempt, exponential_backoff_factor, unit_wait_time, **kwargs) client_configs = kwargs.get("client_config", {}) - if client_configs.get("api_key", None) is None and os.environ.get("AZURE_OPENAI_API_KEY", None) is None: - client_configs["azure_ad_token_provider"] = 
get_azure_active_directory_token_provider() + # masked for local Ollama + #if client_configs.get("api_key", None) is None and os.environ.get("AZURE_OPENAI_API_KEY", None) is None: + # client_configs["azure_ad_token_provider"] = get_azure_active_directory_token_provider() - self._client = AzureOpenAI(**client_configs) + #self._client = AzureOpenAI(**client_configs) + from openai import OpenAI + self._client = OpenAI(**client_configs) def _get_response_with_messages(self, messages: List[dict], **llm_config) -> ChatCompletion: response: ChatCompletion = None
