From 5a8d042360d6e986a4c855d3ce3f5e87fbc87746 Mon Sep 17 00:00:00 2001 From: ouyangyouzhang Date: Tue, 8 Jul 2025 08:38:14 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84DifyQueryRetrieval=5Fapi.py?= =?UTF-8?q?=E4=B8=BAFastAPI=E5=BA=94=E7=94=A8=EF=BC=8C=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E5=BC=82=E6=AD=A5=E6=A3=80=E7=B4=A2API=E5=92=8C=E5=81=A5?= =?UTF-8?q?=E5=BA=B7=E6=A3=80=E6=9F=A5=E7=AB=AF=E7=82=B9=EF=BC=8C=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86=E5=92=8C=E6=97=A5?= =?UTF-8?q?=E5=BF=97=E8=AE=B0=E5=BD=95=EF=BC=8C=E5=90=8C=E6=97=B6=E6=9B=B4?= =?UTF-8?q?=E6=96=B0DifyQueryRetrieval=E7=B1=BB=E4=BB=A5=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=BC=82=E6=AD=A5=E6=A3=80=E7=B4=A2=E5=8A=9F=E8=83=BD=EF=BC=8C?= =?UTF-8?q?=E6=8F=90=E5=8D=87=E6=95=B4=E4=BD=93=E6=80=A7=E8=83=BD=E5=92=8C?= =?UTF-8?q?=E5=8F=AF=E7=BB=B4=E6=8A=A4=E6=80=A7=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rag2_0/dify/DifyQueryRetrieval.py | 170 +++++++++++++++++++++++++- rag2_0/dify/DifyQueryRetrieval_api.py | 114 +++++++++++++---- 2 files changed, 256 insertions(+), 28 deletions(-) diff --git a/rag2_0/dify/DifyQueryRetrieval.py b/rag2_0/dify/DifyQueryRetrieval.py index 20bf6c7..e2593d7 100644 --- a/rag2_0/dify/DifyQueryRetrieval.py +++ b/rag2_0/dify/DifyQueryRetrieval.py @@ -6,6 +6,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Dict, Any, Optional import logging import time +import asyncio +import httpx sys.path.append(os.getcwd()) from rag2_0.intent_recognition.DataModels import Classification @@ -52,6 +54,28 @@ class DifyQueryRetrieval: logging.error(f"检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True) return [] + async def retrieve_by_dataset_async(self, query: str, dataset_name: str) -> List[Dict[str, Any]]: + """ + 异步版本的retrieve_by_dataset方法 + + Args: + query: 查询字符串 + dataset_name: 数据集名称 + + Returns: + 检索到的文档列表 + """ + try: + # 使用asyncio.to_thread包装同步方法 + return await asyncio.to_thread( + self.retrieve_by_dataset, + query, + dataset_name + ) + except Exception as e: + logging.error(f"异步检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True) + return [] + def retrieve(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]: datasets = self.get_datasets_by_classification(classification, software_name) if len(datasets) == 0: @@ -59,6 +83,25 @@ class DifyQueryRetrieval: return self.retrieve_api(original_query, query_list, datasets) + async def retrieve_async(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]: + """ + 异步版本的retrieve方法 + + Args: + original_query: 原始查询 + query_list: 查询列表 + classification: 分类信息 + software_name: 软件名称 + + Returns: + 检索到的文档列表 + """ + datasets = self.get_datasets_by_classification(classification, software_name) + if len(datasets) == 0: + return None + + return await self.retrieve_api_async(original_query, query_list, datasets) + def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]: all_documents=[] # 使用线程池替代无限制创建线程 @@ -103,6 +146,65 @@ class DifyQueryRetrieval: return processed_documents + async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str])->List[Dict[str, Any]]: + """ + 异步版本的retrieve_api方法,使用asyncio代替线程池 + + Args: + original_query: 原始查询 + query_list: 查询列表 + data_set_list: 数据集列表 + + Returns: + 检索并重排序后的文档列表 + """ + all_documents = [] + # 记录开始时间 + time_start = time.time() + + # 创建异步任务列表 + tasks = [] + for query in query_list: + for dataset in data_set_list: + if dataset not in self._datasets_list: + logging.error(f"dataset {dataset} not in datasets_list") + continue + + # 创建异步任务 + task = self.retrieve_by_dataset_async(query, dataset) + tasks.append(task) + + # 并发执行所有异步任务 + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 处理结果 + for result in results: + if isinstance(result, Exception): + logging.error(f"异步检索过程中发生错误: {str(result)}", exc_info=True) + else: + all_documents.extend(result) + + time_end = time.time() + logging.info(f"异步检索耗时: {time_end - time_start:.2f}秒") + + # 根据segment_id对文档进行去重 + unique_documents = {} + for document in all_documents: + segment_id = document['segment']['id'] + if segment_id not in unique_documents: + unique_documents[segment_id] = document + + # 将去重后的文档转换为列表 + deduplicated_documents = list(unique_documents.values()) + + # 对所有检索出来的文档进行重排序 + time_start = time.time() + processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents) + time_end = time.time() + logging.info(f"异步检索后重排序耗时: {time_end - time_start:.2f}秒") + + return processed_documents + def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]: reranker_model = XinferenceReRankerModel() documents = [document['segment']['content'] for document in all_documents] @@ -138,6 +240,52 @@ class DifyQueryRetrieval: new_all_documents.append(to_dify_document_format(cur_doc_info)) return new_all_documents + async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 异步版本的data_post_processor方法 + + Args: + query: 查询字符串 + all_documents: 待处理的文档列表 + + Returns: + 处理后的文档列表 + """ + reranker_model = XinferenceReRankerModel() + documents = [document['segment']['content'] for document in all_documents] + # 使用异步重排序方法 + reranked_documents = await reranker_model.rerank_async(query, documents, top_k=5) + new_all_documents = [] + + def to_dify_document_format(document: dict)->dict: + return { + "metadata": { + "_source": "knowledge", + "dataset_id": document["dataset_id"], + "dataset_name": document["dataset_name"], + "document_id": document['segment']['document_id'], + "document_name": document["segment"]["document"]["name"], + "data_source_type": document["segment"]["document"]["data_source_type"], + "segment_id": document["segment"]["id"], + "retriever_from": "api", + "score": document.get("score", 0), + "segment_hit_count": document.get("segment", {}).get("hit_count", 0), + "segment_word_count": document.get("segment", {}).get("word_count", 0), + "segment_position": document.get("segment", {}).get("position", 0), + "segment_index_node_hash": document.get("segment", {}).get("index_node_hash", ""), + "doc_metadata": document.get("segment", {}).get("document", {}).get("doc_metadata", None), + "position": document["segment"].get("position", 0) + }, + "title": document["segment"]["document"]["name"], + "content": document["segment"]["content"] + } + + for reranked_document in reranked_documents: + cur_doc_info = all_documents[reranked_document["index"]] + cur_doc_info["score"] = reranked_document["score"] + new_all_documents.append(to_dify_document_format(cur_doc_info)) + return new_all_documents + def get_datasets_by_classification(self, classification: Classification, software_name: str) -> List[str]: if classification.vertical_classification=="软件问题" or classification.vertical_classification=="业务问题": software_name_list = self.software_to_dataset_map.keys() @@ -165,7 +313,21 @@ class DifyQueryRetrieval: if __name__ == "__main__": dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1") # datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3") - datasets = dify_query_retrieval.retrieve_api("电力建设计价通软件如何批量修改设备价格?", - ["电力建设计价通软件如何批量修改设备价格?"], - ["电力建设计价通(2018)软件知识(new)"]) - print(json.dumps(datasets, ensure_ascii=False, indent=2)) \ No newline at end of file + # datasets = dify_query_retrieval.retrieve_api("电力建设计价通软件如何批量修改设备价格?", + # ["电力建设计价通软件如何批量修改设备价格?"], + # ["电力建设计价通(2018)软件知识(new)"]) + # print(json.dumps(datasets, ensure_ascii=False, indent=2)) + + # 测试异步API + async def test_async_api(): + datasets = await dify_query_retrieval.retrieve_api_async( + "电力建设计价通软件如何批量修改设备价格?", + ["电力建设计价通软件如何批量修改设备价格?"], + ["电力建设计价通(2018)软件知识(new)"] + ) + print("异步API测试结果:") + print(json.dumps(datasets, ensure_ascii=False, indent=2)) + + # 如果需要测试异步API,取消下面的注释 + import asyncio + asyncio.run(test_async_api()) \ No newline at end of file diff --git a/rag2_0/dify/DifyQueryRetrieval_api.py b/rag2_0/dify/DifyQueryRetrieval_api.py index a5f47fa..cda3127 100644 --- a/rag2_0/dify/DifyQueryRetrieval_api.py +++ b/rag2_0/dify/DifyQueryRetrieval_api.py @@ -1,19 +1,25 @@ # from gevent import monkey # monkey.patch_all() -from flask import Flask, request, Response import os +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from typing import Dict, List, Any, Optional +import asyncio + from dotenv import load_dotenv import json import time -from gevent.lock import RLock - - import datetime import logging # 加载环境变量 load_dotenv() -from DifyQueryRetrieval import DifyQueryRetrieval + +import sys +sys.path.append(os.getcwd()) +from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval logging.basicConfig( @@ -28,27 +34,87 @@ logging.getLogger('openai').setLevel(logging.WARNING) logger = logging.getLogger(__name__) -app = Flask(__name__) +# 定义请求模型 +class RetrieveRequest(BaseModel): + original_query: str + query_list: str + data_set_list: str -dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1") -@app.route('/retrieve', methods=['POST']) -def retrieve(): - data = request.get_json(force=True) - original_query_str = data.get('original_query') - query_list_str = data.get('query_list') - data_set_list_str = data.get('data_set_list') +# 创建FastAPI应用 +app = FastAPI( + title="Dify查询检索服务", + description="基于Dify的异步查询检索服务", + version="1.0" +) + +# 添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 全局变量存储DifyQueryRetrieval实例 +dify_query_retrieval = None + +# 应用启动事件 +@app.on_event("startup") +async def startup_event(): + global dify_query_retrieval + # 初始化DifyQueryRetrieval实例 + dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1") + logger.info("DifyQueryRetrieval初始化完成") + +# 添加健康检查端点 +@app.get("/health", summary="健康检查") +async def health_check(): + return {"status": "ok"} + +@app.post("/retrieve", summary="异步检索API") +async def retrieve(request: RetrieveRequest): + """ + 异步检索API + + Args: + request: 包含原始查询、查询列表和数据集列表的请求对象 + + Returns: + 检索结果 + """ try: - query_list = query_list_str.split("") - data_set_list = data_set_list_str.split("") - results = dify_query_retrieval.retrieve_api(original_query_str, query_list, data_set_list) - return Response(json.dumps(results, ensure_ascii=False), content_type='application/json; charset=utf-8') + # 解析查询列表和数据集列表 + query_list = request.query_list.split("") + data_set_list = request.data_set_list.split("") + + # 调用异步检索方法 + start_time = time.time() + results = await dify_query_retrieval.retrieve_api_async( + request.original_query, + query_list, + data_set_list + ) + end_time = time.time() + + logger.info(f"异步检索总耗时: {end_time - start_time:.2f}秒") + return results except Exception as e: - logger.error(f"检索出错: {str(e)}", exc_info=True) - return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500) - + logger.error(f"异步检索出错: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": - # 开发环境使用Flask内置服务器 - # 生产环境使用gunicorn支持高并发 uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app - # uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app - app.run(host="0.0.0.0", port=8002, threaded=True) \ No newline at end of file + # 使用Uvicorn运行FastAPI应用 + import uvicorn + uvicorn.run("rag2_0.dify.DifyQueryRetrieval_api:app", host="0.0.0.0", port=8002, reload=False, workers=10, log_level="info") + # # 使用uvicorn启动服务 + # import uvicorn + # uvicorn.run( + # "rag2_0.dify.intent_recognition_api:app", + # host="0.0.0.0", + # port=8001, + # reload=False, # 开发环境启用热重载 + # workers=1 # 生产环境可以增加worker数量 + # ) + # 生产环境可以使用以下命令启动: + # uvicorn rag2_0.dify.DifyQueryRetrieval_api:app --host 0.0.0.0 --port 8002 --workers 10 \ No newline at end of file