From 8b9ea73b3bdd185aa8d0880b932f66b1ab9339ba Mon Sep 17 00:00:00 2001 From: ouyangyouzhang Date: Fri, 15 Aug 2025 10:34:30 +0800 Subject: [PATCH] =?UTF-8?q?refactor(embedding/reranker):=20=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E6=A8=A1=E5=9E=8B=E5=B7=A5=E5=85=B7=E7=B1=BB=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E7=8E=AF=E5=A2=83=E5=8F=98=E9=87=8F=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 更新.gitignore文件,添加新的数据库文件 在.env中添加EMBEDDING_MODEL_NAME和XINFERENCE_URL配置 重构SiliconFlowEmbeddings和XinferenceReRankerModel使用环境变量 优化DifyCompareTest异常处理和输入验证 修改测试文件路径和并发工作数 --- .env | 3 + .gitignore | 4 +- rag2_0/dify/DifyCompareTest.py | 36 ++++++++---- rag2_0/tool/ModelTool.py | 101 +++++---------------------------- 4 files changed, 43 insertions(+), 101 deletions(-) diff --git a/.env b/.env index daaf080..a4f0aa4 100644 --- a/.env +++ b/.env @@ -3,6 +3,9 @@ OPENAI_API_BASE=https://api.siliconflow.cn/v1/ MODEL_NAME=deepseek-ai/DeepSeek-V3 MINI_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct-128K RERANKER_MODEL_NAME=bge-reranker-v2-m3 +EMBEDDING_MODEL_NAME=bge-m3 + +XINFERENCE_URL=http://10.1.16.39:9995 DIFY_BSAE_URL=http://10.1.16.39/v1 DIFY_APP_KEY=app-CPoOMaGDsLRPAe9TW7Xjhszy diff --git a/.gitignore b/.gitignore index c324184..ed766f7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,10 +8,10 @@ data/excel/* rag2_0/demo/Test* rag2_0/demo/test* data/excel/*.xlsx -rag2_0/demo/ProfessionalTermAnalyzer.py data/logs/* rag2_0/dify/Test.py data/query_logs/* data/conversations/* data/test* -data/temp* \ No newline at end of file +data/temp* +data/db/answer_logs.db diff --git a/rag2_0/dify/DifyCompareTest.py b/rag2_0/dify/DifyCompareTest.py index cf3cd0d..d60b0f2 100755 --- a/rag2_0/dify/DifyCompareTest.py +++ b/rag2_0/dify/DifyCompareTest.py @@ -121,26 +121,38 @@ class DifyCompareTest: time.sleep(10) # 等待1秒后重试 def get_wiki_list_by_msgid(self,msg_id): - if msg_id is None or pd.isna(msg_id): + try: + if msg_id is None or pd.isna(msg_id): + return "" + msg_debug_info = self.exporter.dify_tool.get_message_debug_info_by_id(msg_id) + if not msg_debug_info: + return "" + wiki_list = self.exporter.get_wiki_list(msg_debug_info) + if len(wiki_list) == 0: + return "" + else: + return "\n".join(list(set(wiki_list))) + except Exception as e: + logging.error(f"获取词条列表失败: {e}") return "" - msg_debug_info = self.exporter.dify_tool.get_message_debug_info_by_id(msg_id) - if not msg_debug_info: - return "" - wiki_list = self.exporter.get_wiki_list(msg_debug_info) - if len(wiki_list) == 0: - return "" - else: - return "\n".join(list(set(wiki_list))) def process_single_row(self, index, row): """处理单行数据的方法""" try: query = row["提问"] + current_software = row["当前软件"] + if pd.isna(query) or len(query) == 0 or pd.isna(current_software) or len(current_software) == 0: + result_row = row.copy() + result_row["message_id"] = '' + result_row["本次回答"] = '' + result_row["回答对比"] = '' + result_row["检索到的词条"] = '' + return index, result_row + if "回答" in row: old_answer = row["回答"] else: old_answer = "" - current_software = row["当前软件"] inputs = { "current_softname": current_software, @@ -247,14 +259,14 @@ if __name__ == "__main__": # 处理第一个文件 excel_files = [ - ("data/excel/300专业提问.xlsx", "data/excel/300专业提问_问答测试.xlsx"), + ("data/excel/第一轮的专业问题.xlsx", "data/excel/第一轮的专业问题_dify.xlsx"), # ("data/excel/有知识的.xlsx", "data/excel/有知识的_问答测试.xlsx") ] for excel_path, save_path in excel_files: logging.info(f"开始处理文件: {excel_path}") try: - dify_compare_test.run(excel_path=excel_path, save_path=save_path, max_workers=10) + dify_compare_test.run(excel_path=excel_path, save_path=save_path, max_workers=5) logging.info(f"文件处理完成: {excel_path}") except Exception as e: logging.error(f"处理文件 {excel_path} 时出错: {e}") diff --git a/rag2_0/tool/ModelTool.py b/rag2_0/tool/ModelTool.py index c197f6b..0d0c665 100755 --- a/rag2_0/tool/ModelTool.py +++ b/rag2_0/tool/ModelTool.py @@ -23,10 +23,11 @@ from rag2_0.tool.APIKeyManager import APIKeyManager class SiliconFlowEmbeddings(Embeddings): """SiliconFlow嵌入模型封装""" - def __init__(self, api_key: str, model: str = "bge-m3"): + def __init__(self, api_key: str, model: str = os.getenv("EMBEDDING_MODEL_NAME", "bge-m3")): self.api_key = api_key self.model = model - self.url = "http://10.1.16.39:9995/v1/embeddings" + base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995") + self.url = urljoin(base_url.rstrip('/') + '/', 'v1/embeddings') self.headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" @@ -70,83 +71,7 @@ class SiliconFlowEmbeddings(Embeddings): """异步嵌入单个查询""" result = await self._embed_async([text]) return result[0] - -class SiliconFlowReRankerModel: - @staticmethod - def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]: - """ - 使用硅流重排模型对文档进行重新排序 - - Args: - query: 用户查询文本 - documents: 需要重新排序的文档列表 - top_k: 返回排序后的前k个文档 - - Returns: - List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引 - """ - url = "https://api.siliconflow.cn/v1/rerank" - payload = { - "model": "BAAI/bge-reranker-v2-m3", - "query": query, - "documents": documents, - "top_n": top_k, - "max_chunks_per_doc": 1024, - "overlap_tokens": 80, - "return_documents": True - } - api_key = APIKeyManager.get_api_key() - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json" - } - try: - response = requests.post(url, json=payload, headers=headers, timeout=300) - response.raise_for_status() - results = response.json() - return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]] - except requests.exceptions.RequestException as e: - logging.error(f"重排序请求失败: {str(e)}", exc_info=True) - return [] - - @staticmethod - async def rerank_async(query: str, documents: List[str], top_k: int = 10) -> List[str]: - """ - 使用硅流重排模型对文档进行异步重新排序 - - Args: - query: 用户查询文本 - documents: 需要重新排序的文档列表 - top_k: 返回排序后的前k个文档 - - Returns: - List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引 - """ - url = "https://api.siliconflow.cn/v1/rerank" - payload = { - "model": "BAAI/bge-reranker-v2-m3", - "query": query, - "documents": documents, - "top_n": top_k, - "max_chunks_per_doc": 1024, - "overlap_tokens": 80, - "return_documents": True - } - api_key = APIKeyManager.get_api_key() - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json" - } - try: - async with httpx.AsyncClient(timeout=300) as client: - response = await client.post(url, json=payload, headers=headers) - response.raise_for_status() - results = response.json() - return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]] - except httpx.RequestError as e: - logging.error(f"异步重排序请求失败: {str(e)}", exc_info=True) - return [] - + class XinferenceReRankerModel: """重排模型封装""" @@ -163,17 +88,18 @@ class XinferenceReRankerModel: Returns: List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引 """ - url = "http://10.1.16.39:9995/v1/rerank" - - params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": "bge-reranker-v2-m3"} + base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995") + model_name = os.getenv("RERANKER_MODEL_NAME", "bge-reranker-v2-m3") + rerank_url = urljoin(base_url.rstrip('/') + '/', 'v1/rerank') + params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": model_name} headers = { "Authorization": "Bearer ", # 这里需要替换为实际的token "Content-Type": "application/json" } try: - response = requests.post(url, json=params, headers=headers, timeout=300) + response = requests.post(rerank_url, json=params, headers=headers, timeout=300) response.raise_for_status() # 检查响应状态 results = response.json() @@ -197,9 +123,10 @@ class XinferenceReRankerModel: Returns: List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引 """ - url = "http://10.1.16.39:9995/v1/rerank" - - params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": "bge-reranker-v2-m3"} + base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995") + rerank_url = urljoin(base_url.rstrip('/') + '/', 'v1/rerank') + model_name = os.getenv("RERANKER_MODEL_NAME", "bge-reranker-v2-m3") + params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": model_name} headers = { "Authorization": "Bearer ", # 这里需要替换为实际的token "Content-Type": "application/json" @@ -207,7 +134,7 @@ class XinferenceReRankerModel: try: async with httpx.AsyncClient(timeout=300) as client: - response = await client.post(url, json=params, headers=headers) + response = await client.post(rerank_url, json=params, headers=headers) response.raise_for_status() # 检查响应状态 results = response.json()