From 9449383699b4e978dc00b9c952b26a71070c681d Mon Sep 17 00:00:00 2001 From: Lpepsi <846179345@qq.com> Date: 2024年2月10日 20:56:13 +0800 Subject: [PATCH] =?UTF-8?q?AI=E6=A8=A1=E5=9E=8B=E5=A2=9E=E5=8A=A0GLM?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AI.py | 35 ----------------------------------- DownloadEmail.py | 1 - ParseInvoice.py | 1 - ai/AI.py | 15 +++++++++++++++ ai/AIRegister.py | 25 +++++++++++++++++++++++++ ai/GLM.py | 36 ++++++++++++++++++++++++++++++++++++ ai/GPT.py | 37 +++++++++++++++++++++++++++++++++++++ main.py | 24 +++++++++++------------- system_prompt.py | 7 ++++++- 9 files changed, 130 insertions(+), 51 deletions(-) delete mode 100644 AI.py create mode 100644 ai/AI.py create mode 100644 ai/AIRegister.py create mode 100644 ai/GLM.py create mode 100644 ai/GPT.py diff --git a/AI.py b/AI.py deleted file mode 100644 index 4c174f2..0000000 --- a/AI.py +++ /dev/null @@ -1,35 +0,0 @@ -import json - -from openai import OpenAI - -import system_prompt - -client = OpenAI(api_key='') - - -def ai_summary(user_input): - response = client.chat.completions.create( - model="gpt-3.5-turbo-0125", - response_format={"type": "json_object"}, - messages=[ - {"role": "system", - "content": "你是一名AI助手,根据用户输入的项目名称进行总结,只能从餐饮,交通,其他三个词中,选出一个词来总结该项目名称属于哪种类型。只需要返回类型即可,不需要其他内容。结果输出为JSON:{'type':'xxx'}"}, - {"role": "user", "content": user_input}, - ] - ) - return json.loads(response.choices[0].message.content)['type'] - - -def find_text(user_input): - response = client.chat.completions.create( - model="gpt-3.5-turbo-0125", - response_format={"type": "json_object"}, - messages=[ - {"role": "system", - "content": system_prompt.prompt}, - {"role": "user", "content": user_input}, - ] - ) - result = json.loads(response.choices[0].message.content) - print(result['text']) - return result['text'] diff --git a/DownloadEmail.py b/DownloadEmail.py index 4dce638..728a23d 100644 --- a/DownloadEmail.py +++ b/DownloadEmail.py @@ -5,7 +5,6 @@ from selenium.webdriver.support import expected_conditions as EC from watchdog.observers import Observer -import AI import Constant import ParseInvoice import config diff --git a/ParseInvoice.py b/ParseInvoice.py index 59f8a8e..ef7cf7c 100644 --- a/ParseInvoice.py +++ b/ParseInvoice.py @@ -1,7 +1,6 @@ import json import os.path -import AI import OcrUtil import config import util diff --git a/ai/AI.py b/ai/AI.py new file mode 100644 index 0000000..0aede1d --- /dev/null +++ b/ai/AI.py @@ -0,0 +1,15 @@ +from abc import ABC, abstractmethod + + +class AbsAI(ABC): + + def __init__(self): + self.client = None + + @abstractmethod + def ai_summary(self, user_input): + pass + + @abstractmethod + def find_text(self, user_input): + pass diff --git a/ai/AIRegister.py b/ai/AIRegister.py new file mode 100644 index 0000000..7a1c4f7 --- /dev/null +++ b/ai/AIRegister.py @@ -0,0 +1,25 @@ +class AIRegister: + _registry = {} + + @classmethod + def register(cls, ai_type, ai_class): + if ai_type in cls._registry: + raise ValueError(f"AI type {ai_type} already registered.") + cls._registry[ai_type] = ai_class + + @classmethod + def get_ai(cls, ai_type): + if ai_type not in cls._registry: + raise ValueError(f"AI type {ai_type} not found.") + ai_class = cls._registry[ai_type] + clazz = ai_class() + print(clazz) + return clazz + + +def register_ai(ai_type): + def decorator(ai_class): + print(type(ai_class)) + AIRegister.register(ai_type, ai_class) + return ai_class + return decorator diff --git a/ai/GLM.py b/ai/GLM.py new file mode 100644 index 0000000..b44441e --- /dev/null +++ b/ai/GLM.py @@ -0,0 +1,36 @@ +from zhipuai import ZhipuAI + +import system_prompt +from ai.AI import AbsAI +from ai.AIRegister import register_ai + + +@register_ai('GLM') +class GLM(AbsAI): + + def __init__(self): + super().__init__() + self.client = ZhipuAI(api_key="") + + def ai_summary(self, user_input): + response = self.client.chat.completions.create( + model="glm-4", # 填写需要调用的模型名称 + messages=[ + {"role": "system", "content": system_prompt.summary_prompt}, + {"role": "user", "content": user_input}, + ] + ) + print(response.choices[0].message) + + def find_text(self, user_input): + response = self.client.chat.completions.create( + model="glm-4", # 填写需要调用的模型名称 + messages=[ + {"role": "system", "content": system_prompt.find_text_prompt}, + {"role": "user", "content": user_input}, + ] + ) + print(response.choices[0].message) + + + diff --git a/ai/GPT.py b/ai/GPT.py new file mode 100644 index 0000000..07c8bdd --- /dev/null +++ b/ai/GPT.py @@ -0,0 +1,37 @@ +import json +from openai import OpenAI +import system_prompt +from ai.AI import AbsAI +from ai.AIRegister import register_ai + +@register_ai('GPT') +class GPT(AbsAI): + + def __init__(self): + super().__init__() + self.client = OpenAI(api_key='') + + @register_ai('GPT') + def ai_summary(self, user_input): + response = self.client.chat.completions.create( + model="gpt-3.5-turbo-0125", + response_format={"type": "json_object"}, + messages=[ + {"role": "system", "content": system_prompt.summary_prompt}, + {"role": "user", "content": user_input}, + ] + ) + return json.loads(response.choices[0].message.content)['type'] + + def find_text(self, user_input): + response = self.client.chat.completions.create( + model="gpt-3.5-turbo-0125", + response_format={"type": "json_object"}, + messages=[ + {"role": "system", "content": system_prompt.find_text_prompt}, + {"role": "user", "content": user_input}, + ] + ) + result = json.loads(response.choices[0].message.content) + print(result['text']) + return result['text'] diff --git a/main.py b/main.py index 3485477..bf6d37c 100644 --- a/main.py +++ b/main.py @@ -1,23 +1,21 @@ -import os -import pathlib import time -import AI -import DownloadEmail -import OcrUtil -import ParseInvoice -import config -import user_input +from ai.AIRegister import AIRegister +from ai.GPT import GPT +# import DownloadEmail if __name__ == '__main__': # OcrUtil.ocr_invoice('./dzfp_24442000000000377692_20240101190922.pdf') # ParseInvoice.parse_invoice() # AI.ai_summary('汽车92号汽油费') - start = time.time() - DownloadEmail.download_email() - end = time.time() - print(f'执行时长: {end - start}') + # start = time.time() + # DownloadEmail.download_email() + # end = time.time() + # print(f'执行时长: {end - start}') # AI.find_url(user_input.input) # setting.parse_setting() # print(pathlib.Path('./dzfp_24442000000000377692_20240101190922.pdf').suffix) - # os.rename('./dzfp_24442000000000377692_20240101190922.pdf','./newfile.pdf') \ No newline at end of file + # os.rename('./dzfp_24442000000000377692_20240101190922.pdf','./newfile.pdf') + AIRegister.get_ai('GPT') + print(type(GPT())) + diff --git a/system_prompt.py b/system_prompt.py index a3d4abc..39b4030 100644 --- a/system_prompt.py +++ b/system_prompt.py @@ -1,4 +1,4 @@ -prompt = ''' +find_text_prompt = ''' 你是一名HTML解析助手,你需要解析用户上传的HTML片段。 1.解析出片段中带有发票下载链接的超链接标签文本。 2.如果有多个下载链接,则找出下载为PDF格式的超链接标签文本即可。 @@ -12,4 +12,9 @@ 输出: {"text":"HelloWorld"} 注意只需要返回对应的标签文本即可,不需要其他内容。结果输出为JSON:{'text':'xxx'},完成之后,我将会给你10美元"} +''' + +summary_prompt = ''' +你是一名AI助手,根据用户输入的项目名称进行总结。 +只能从餐饮,交通,其他三个词中,选出一个词来总结该项目名称属于哪种类型。只需要返回类型即可,不需要其他内容。结果输出为JSON:{'type':'xxx'} ''' \ No newline at end of file

AltStyle によって変換されたページ (->オリジナル) /