Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Add local ollama models support #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
leichangqing wants to merge 4 commits into microsoft:dev
base: dev
Choose a base branch
Loading
from leichangqing:main
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions data_process/open_benchmarks/main.py → data_process/main.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

import yaml

from data_process.reformat_dataset import reformat_dataset
from data_process.sample_dataset import sample_datasets
from data_process.utils.filepaths import get_dataset_dir, get_document_dir, get_split_filepath
from data_process.utils.stats import check_dataset_split
from open_benchmarks.reformat_dataset import reformat_dataset
from open_benchmarks.sample_dataset import sample_datasets
from open_benchmarks.utils.filepaths import get_dataset_dir, get_document_dir, get_split_filepath
from open_benchmarks.utils.stats import check_dataset_split


def load_yaml_config(config_path: str, args: argparse.Namespace) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion data_process/open_benchmarks/README.md
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Assume that you are in the directory `data_process/`:

```sh
python main.py config/datasets.yaml
python main.py open_benchmarks/config/datasets.yaml
```

## QA Protocol Overview
Expand Down
2 changes: 1 addition & 1 deletion data_process/open_benchmarks/dataset_utils/hotpotqa.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import uuid

from data_process.utils.question_type import infer_question_type
from open_benchmarks.utils.question_type import infer_question_type


split2url: Dict[str, str] = {
Expand Down
2 changes: 1 addition & 1 deletion data_process/open_benchmarks/dataset_utils/musique.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import uuid
import jsonlines

from data_process.utils.question_type import infer_question_type
from open_benchmarks.utils.question_type import infer_question_type


zipfile_id = "1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h"
Expand Down
2 changes: 1 addition & 1 deletion data_process/open_benchmarks/dataset_utils/nq.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bs4 import BeautifulSoup
from datasets import Dataset, load_dataset

from data_process.utils.question_type import infer_nq_question_type
from open_benchmarks.utils.question_type import infer_nq_question_type


def clean_text(text: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion data_process/open_benchmarks/dataset_utils/popqa.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datasets import Dataset, load_dataset
from tqdm import tqdm

from data_process.utils.question_type import infer_question_type
from open_benchmarks.utils.question_type import infer_question_type


def load_raw_data(dataset_dir: str, split: str) -> Dataset:
Expand Down
2 changes: 1 addition & 1 deletion data_process/open_benchmarks/dataset_utils/triviaqa.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from datasets import Dataset, load_dataset

from data_process.utils.question_type import infer_question_type
from open_benchmarks.utils.question_type import infer_question_type


def load_raw_data(dataset_dir: str, split: str) -> Dataset:
Expand Down
4 changes: 2 additions & 2 deletions data_process/open_benchmarks/dataset_utils/two_wiki.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import uuid
import jsonlines

from data_process.dataset_utils.hotpotqa import get_supporting_facts
from data_process.utils.question_type import infer_question_type
from open_benchmarks.dataset_utils.hotpotqa import get_supporting_facts
from open_benchmarks.utils.question_type import infer_question_type


default_name: str = "data_ids_april7.zip?rlkey=u868q6h0jojw4djjg7ea65j46"
Expand Down
2 changes: 1 addition & 1 deletion data_process/open_benchmarks/dataset_utils/webqa.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from datasets import Dataset, load_dataset

from data_process.utils.question_type import infer_question_type
from open_benchmarks.utils.question_type import infer_question_type


def load_raw_data(dataset_dir: str, split: str) -> Dataset:
Expand Down
2 changes: 1 addition & 1 deletion data_process/open_benchmarks/reformat_dataset.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def get_dataset_utils_module(dataset: str):
module = importlib.import_module(f"data_process.dataset_utils.{dataset}")
module = importlib.import_module(f"open_benchmarks.dataset_utils.{dataset}")
return module


Expand Down
8 changes: 4 additions & 4 deletions data_process/open_benchmarks/sample_dataset.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import numpy as np
from tqdm import tqdm

from data_process.utils.filepaths import get_doc_location_filepath, get_download_filepaths, get_title_status_filepath
from data_process.utils.io import dump_to_json_file, load_from_json_file, dump_to_jsonlines, load_from_jsonlines
from data_process.utils.stats import SOURCE_TYPES_TO_DOWNLOAD
from data_process.utils import wikidata, wikipedia
from open_benchmarks.utils.filepaths import get_doc_location_filepath, get_download_filepaths, get_title_status_filepath
from open_benchmarks.utils.io import dump_to_json_file, load_from_json_file, dump_to_jsonlines, load_from_jsonlines
from open_benchmarks.utils.stats import SOURCE_TYPES_TO_DOWNLOAD
from open_benchmarks.utils import wikidata, wikipedia


def load_caches(document_dir: str) -> Tuple[Dict[str, Dict[str, Dict[str, str]]], Dict[str, Dict[str, bool]]]:
Expand Down
2 changes: 1 addition & 1 deletion data_process/open_benchmarks/utils/filepaths.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from typing import Dict, Literal, Optional

from data_process.utils.stats import FILE_TYPES_TO_DOWNLOAD, SOURCE_TYPES_TO_DOWNLOAD
from open_benchmarks.utils.stats import FILE_TYPES_TO_DOWNLOAD, SOURCE_TYPES_TO_DOWNLOAD


def get_dataset_dir(root_dir: str, dataset: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion data_process/open_benchmarks/utils/wikidata.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import aiohttp
from bs4 import BeautifulSoup

from data_process.utils.io import dump_bytes_to_file, dump_texts_to_file, async_dump_bytes_to_file
from open_benchmarks.utils.io import dump_bytes_to_file, dump_texts_to_file, async_dump_bytes_to_file


def parse_contents(html_content: str) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion data_process/open_benchmarks/utils/wikipedia.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import wikipediaapi
from wikipediaapi import WikipediaPage, WikipediaPageSection

from data_process.utils.io import dump_bytes_to_file, dump_texts_to_file, async_dump_bytes_to_file
from open_benchmarks.utils.io import dump_bytes_to_file, dump_texts_to_file, async_dump_bytes_to_file


WIKI_WIKI = wikipediaapi.Wikipedia('Microsoft Research Asia PIKE-RAG', 'en')
Expand Down
3 changes: 2 additions & 1 deletion pikerag/llm_client/__init__.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pikerag.llm_client.azure_open_ai_client import AzureOpenAIClient
from pikerag.llm_client.base import BaseLLMClient
from pikerag.llm_client.hf_meta_llama_client import HFMetaLlamaClient
from pikerag.llm_client.ollama_llm_client import OllamaLLMClient


__all__ = ["AzureMetaLlamaClient", "AzureOpenAIClient", "BaseLLMClient", "HFMetaLlamaClient"]
__all__ = ["AzureMetaLlamaClient", "AzureOpenAIClient", "BaseLLMClient", "HFMetaLlamaClient", "OllamaLLMClient"]
9 changes: 6 additions & 3 deletions pikerag/llm_client/azure_open_ai_client.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,13 @@ def __init__(
super().__init__(location, auto_dump, logger, max_attempt, exponential_backoff_factor, unit_wait_time, **kwargs)

client_configs = kwargs.get("client_config", {})
if client_configs.get("api_key", None) is None and os.environ.get("AZURE_OPENAI_API_KEY", None) is None:
client_configs["azure_ad_token_provider"] = get_azure_active_directory_token_provider()
# masked for local Ollama
#if client_configs.get("api_key", None) is None and os.environ.get("AZURE_OPENAI_API_KEY", None) is None:
# client_configs["azure_ad_token_provider"] = get_azure_active_directory_token_provider()

self._client = AzureOpenAI(**client_configs)
#self._client = AzureOpenAI(**client_configs)
from openai import OpenAI
self._client = OpenAI(**client_configs)

def _get_response_with_messages(self, messages: List[dict], **llm_config) -> ChatCompletion:
response: ChatCompletion = None
Expand Down
111 changes: 111 additions & 0 deletions pikerag/llm_client/ollama_llm_client.py
View file Open in desktop
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import json
import logging
import os
import re
import time
from typing import Callable, List, Literal, Optional, Union

from langchain_core.embeddings import Embeddings
from ollama import Client as OllamaClient  # Assuming an equivalent Ollama client exists
from pickledb import PickleDB

from pikerag.llm_client.base import BaseLLMClient
from pikerag.utils.logger import Logger

def parse_wait_time_from_error(error: Exception) -> Optional[int]:
    """Extract a server-suggested retry delay, in seconds, from an error message.

    Searches the error text for the pattern "Try again in N seconds" and, when
    found, returns N plus a 3-second safety margin. Returns None when no such
    hint is present or the message cannot be inspected.
    """
    try:
        match = re.search(r"Try again in (\d+) seconds", str(error))
        if match is not None:
            # Pad the suggested wait by 3 seconds to avoid retrying too early.
            return int(match.group(1)) + 3
    except Exception:
        # Best-effort parsing: any failure simply means "no hint available".
        pass
    return None

class OllamaLLMClient(BaseLLMClient):
    """LLM client that sends chat requests to a local Ollama server.

    Failed requests are retried up to ``max_attempt`` times, honoring any
    server-suggested wait time found in the error message and otherwise
    backing off linearly by ``unit_wait_time``.
    """
    NAME = "OllamaLLMClient"

    def __init__(
        self, location: str = None, auto_dump: bool = True, logger: Logger = None,
        max_attempt: int = 5, exponential_backoff_factor: int = None, unit_wait_time: int = 60, **kwargs,
    ) -> None:
        """
        Args:
            location: cache location forwarded to the base client.
            auto_dump: whether the base client auto-dumps its cache.
            logger: optional logger for warnings.
            max_attempt: maximum number of request attempts.
            exponential_backoff_factor: forwarded to the base client.
            unit_wait_time: base wait (seconds) multiplied by the attempt
                number when the error carries no suggested delay.
            **kwargs: may contain ``client_config``, a mapping of keyword
                arguments passed straight to ``ollama.Client`` (e.g. ``host``).
                NOTE(review): when no host is configured, the ollama client
                presumably falls back to the OLLAMA_HOST environment variable
                — confirm against the installed client version.
        """
        super().__init__(location, auto_dump, logger, max_attempt, exponential_backoff_factor, unit_wait_time, **kwargs)

        self._client = OllamaClient(**kwargs.get("client_config", {}))

    def _get_response_with_messages(self, messages: List[dict], **llm_config) -> dict:
        """Send a chat request, retrying up to ``self._max_attempt`` times.

        Returns:
            The raw response returned by ``ollama.Client.chat``, or None if
            every attempt failed.
        """
        response = None
        num_attempt = 0
        while num_attempt < self._max_attempt:
            try:
                response = self._client.chat(messages=messages, **llm_config)
                break
            except Exception as e:
                self.warning(f"Request failed due to: {e}")
                num_attempt += 1
                # Prefer the delay suggested by the server; otherwise back off
                # linearly with the attempt count.
                wait_time = parse_wait_time_from_error(e) or (self._unit_wait_time * num_attempt)
                time.sleep(wait_time)
                self.warning("Retrying...")
        return response

    def _get_content_from_response(self, response: dict, messages: List[dict] = None) -> str:
        """Extract the assistant message text from an Ollama chat response.

        Returns:
            The message content string, or "" when the response is malformed
            (e.g. None, or missing the expected keys).
        """
        try:
            return response.get("message", {}).get("content", "")
        except Exception as e:
            self.warning(f"Error extracting content: {e}")
            return ""

    def close(self):
        """Release resources held by the base client (e.g. cache dump)."""
        super().close()

class OllamaEmbedding(Embeddings):
    """LangChain-compatible embedding wrapper around a local Ollama server.

    Supports an optional on-disk cache (PickleDB) keyed by the raw query text,
    so repeated embedding requests for the same string are served locally.
    """

    def __init__(self, **kwargs) -> None:
        """
        Args:
            **kwargs: may contain:
                client_config (dict): "OLLAMA_URL" — the server base URL;
                    "OLLAMA_EMBED_MODEL" — the embedding model name.
                model (str): fallback model name used when OLLAMA_EMBED_MODEL
                    is not configured. Defaults to "nomic-embed-text:latest".
                cache_config (dict): "location" — path of the PickleDB cache
                    file; caching is disabled when absent.
        """
        client_configs = kwargs.get("client_config", {})
        base_url = client_configs.get("OLLAMA_URL")
        embed_model = client_configs.get("OLLAMA_EMBED_MODEL")

        # ollama.Client takes the server address as ``host`` and has no
        # ``model`` constructor argument — the model is passed per request.
        # (The previous base_url=/model= keywords would have raised TypeError.)
        self._client = OllamaClient(host=base_url)
        # Prefer the explicitly configured embed model; fall back to "model".
        self._model = embed_model or kwargs.get("model", "nomic-embed-text:latest")
        cache_config = kwargs.get("cache_config", {})
        self._cache = PickleDB(location=cache_config.get("location")) if cache_config.get("location") else None

    def _save_cache(self, query: str, embedding: List[float]) -> None:
        """Persist one query -> embedding pair when caching is enabled."""
        if self._cache:
            self._cache.set(query, embedding)

    def _get_cache(self, query: str) -> Union[List[float], Literal[False]]:
        """Return the cached embedding for ``query``, or False on a miss or
        when caching is disabled."""
        return self._cache.get(query) if self._cache else False

    def _get_response(self, texts: Union[str, List[str]]) -> dict:
        """Call the Ollama embedding endpoint, retrying indefinitely on failure.

        NOTE(review): uses the batched ``embed`` API (ollama>=0.3), whose
        response carries an "embeddings" list — confirm the installed client
        version supports it. The previous code called ``embeddings(input=...)``
        (an unsupported keyword) and parsed an OpenAI-shaped ``data`` payload,
        which could never succeed against Ollama.
        """
        while True:
            try:
                return self._client.embed(input=texts, model=self._model)
            except Exception as e:
                wait_time = parse_wait_time_from_error(e) or 30
                # ``Embeddings`` has no ``warning`` method (the old
                # ``self.warning`` call raised AttributeError inside this
                # handler), so log through the standard logging module.
                logging.getLogger(__name__).warning(
                    "Embedding request failed: %s, waiting %s seconds...", e, wait_time)
                time.sleep(wait_time)

    def embed_documents(self, texts: List[str], batch_call: bool = False) -> List[List[float]]:
        """Embed a list of documents.

        With ``batch_call=True`` a single batched request is issued; otherwise
        each text is embedded individually, which benefits from the cache.
        """
        if batch_call:
            response = self._get_response(texts)
            return list(response["embeddings"])
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single string, consulting the cache first."""
        embedding = self._get_cache(text)
        if embedding is False:
            response = self._get_response(text)
            embedding = response["embeddings"][0]
            self._save_cache(text, embedding)
        return embedding

AltStyle によって変換されたページ (->オリジナル) /