From f76f44640a2e0b012203e1d21bfae4d12ce8579f Mon Sep 17 00:00:00 2001 From: ouyangyouzhang Date: Tue, 1 Jul 2025 18:56:10 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96Dify=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E8=B0=83=E6=95=B4=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E6=8F=90=E5=8F=96=E5=92=8C=E9=87=8D=E6=8E=92=E5=BA=8F=E6=B5=81?= =?UTF-8?q?=E7=A8=8B=EF=BC=8C=E5=A2=9E=E5=BC=BAAPI=E8=B0=83=E7=94=A8?= =?UTF-8?q?=E7=9A=84=E9=87=8D=E8=AF=95=E6=9C=BA=E5=88=B6=EF=BC=8C=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E6=84=8F=E5=9B=BE=E8=AF=86=E5=88=ABAPI=E4=BB=A5?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=9B=B4=E5=A5=BD=E7=9A=84=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E5=A4=84=E7=90=86=E5=92=8C=E6=97=A5=E5=BF=97=E8=AE=B0=E5=BD=95?= =?UTF-8?q?=EF=BC=8C=E6=94=B9=E8=BF=9B=E5=A4=9A=E7=BA=BF=E7=A8=8B=E6=A3=80?= =?UTF-8?q?=E7=B4=A2=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rag2_0/dify/DifyCompareTest.py | 25 ++++++--- rag2_0/dify/DifyQueryRetrieval.py | 52 ++++++++++++++++-- rag2_0/dify/DifyQueryRetrieval_api.py | 54 +++++++++++++++++++ rag2_0/dify/dify_tool.py | 11 ++-- rag2_0/dify/intent_recognition_api.py | 19 ++++--- .../intent_recognition/IntentRecognition.py | 22 ++++---- 6 files changed, 149 insertions(+), 34 deletions(-) create mode 100644 rag2_0/dify/DifyQueryRetrieval_api.py diff --git a/rag2_0/dify/DifyCompareTest.py b/rag2_0/dify/DifyCompareTest.py index 519b516..b0c1070 100755 --- a/rag2_0/dify/DifyCompareTest.py +++ b/rag2_0/dify/DifyCompareTest.py @@ -219,12 +219,23 @@ reason: 简明扼要的理由(中文) prompt = self.create_correctness_prompt(standard_answer, answer) llm = self.get_llm(response_format={"type": "json_object"}) - try: - response = llm.invoke(user_prompt=prompt, need_retry=True) - response_json = json.loads(response.content) - return response_json["result"] - except Exception as e: - return None + + max_retries = 3 + retry_count = 0 + + while retry_count < max_retries: + try: + response = llm.invoke(user_prompt=prompt, need_retry=True) + response_json = json.loads(response.content) + return response_json["result"] + except Exception as e: + retry_count += 1 + if retry_count >= max_retries: + logging.error(f"判断答案失败,已重试{max_retries}次: {str(e)}") + return False + # 指数退避策略,每次重试等待时间增加 + import time + time.sleep(1 * (2 ** (retry_count - 1))) # 1秒, 2秒, 4秒... def judge_by_standard_answer(self, standard_answer: str, old_answer: str, new_answer: str) -> str | None: """ @@ -689,7 +700,7 @@ content: "{content}" if __name__ == "__main__": # 创建命令行参数解析器 os.environ["DIFY_BASEURL"] = "http://10.1.16.39/v1" - os.environ["DIFY_NEW_API_KEY"] = "app-qxsSybCs7ABiKlC1JabTYVn6" + os.environ["DIFY_NEW_API_KEY"] = "app-rv6ie73Ufoa3nRYCMiJx3a8K" os.environ["DIFY_OLD_API_KEY"] = "app-wUdkWJx5zeOvmvBUZizMoSw3" os.environ["DIFY_PG_HOST"] = "10.1.16.39" diff --git a/rag2_0/dify/DifyQueryRetrieval.py b/rag2_0/dify/DifyQueryRetrieval.py index 6046917..d0f912b 100644 --- a/rag2_0/dify/DifyQueryRetrieval.py +++ b/rag2_0/dify/DifyQueryRetrieval.py @@ -100,6 +100,50 @@ class DifyQueryRetrieval: return processed_documents + def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]: + all_documents=[] + # 使用线程池替代无限制创建线程 + # 设置合理的最大线程数,这里使用min(32, len(query_list) * len(data_set_list))来限制 + time_start = time.time() + max_workers = min(os.cpu_count() * 2, len(query_list) * len(data_set_list)) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for query in query_list: + for dataset in data_set_list: + if dataset not in self._datasets_list: + raise ValueError(f"dataset {dataset} not in datasets_list") + + futures.append(executor.submit(self.retrieve_by_dataset, query, dataset)) + + # 等待所有任务完成 + for future in as_completed(futures): + # 处理可能的异常 + try: + retrieved_documents = future.result() + all_documents.extend(retrieved_documents) + except Exception as e: + logging.error(f"检索过程中发生错误: {str(e)}", exc_info=True) + time_end = time.time() + + logging.info(f"检索耗时: {time_end - time_start:.2f}秒") + # 根据segment_id对文档进行去重 + unique_documents = {} + for document in all_documents: + segment_id = document['segment']['id'] + if segment_id not in unique_documents: + unique_documents[segment_id] = document + + # 将去重后的文档转换为列表 + deduplicated_documents = list(unique_documents.values()) + + # 对所有检索出来的文档进行重排序 + time_start = time.time() + processed_documents = self.data_post_processor(original_query, deduplicated_documents) + time_end = time.time() + logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒") + + return processed_documents + def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]: reranker_model = XinferenceReRankerModel() documents = [document['segment']['content'] for document in all_documents] @@ -114,7 +158,7 @@ class DifyQueryRetrieval: "dataset_name": document["dataset_name"], "document_id": document['segment']['document_id'], "document_name": document["segment"]["document"]["name"], - "document_data_source_type": document["segment"]["document"]["data_source_type"], + "data_source_type": document["segment"]["document"]["data_source_type"], "segment_id": document["segment"]["id"], "retriever_from": "api", "score": document.get("score", 0), @@ -122,6 +166,7 @@ class DifyQueryRetrieval: "segment_word_count": document.get("segment", {}).get("word_count", 0), "segment_position": document.get("segment", {}).get("position", 0), "segment_index_node_hash": document.get("segment", {}).get("index_node_hash", ""), + "doc_metadata": document.get("segment", {}).get("document", {}).get("doc_metadata", None), "position": document["segment"].get("position", 0) }, "title": document["segment"]["document"]["name"], @@ -159,6 +204,7 @@ class DifyQueryRetrieval: if __name__ == "__main__": - dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="https://172.20.0.145/v1") - datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3") + dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1") + # datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3") + datasets = dify_query_retrieval.retrieve_api("配网工程计价通D3软件如何新建工程?", ["配网工程计价通D3软件如何新建工程?"], ["流式输出缺失"]) print(datasets) \ No newline at end of file diff --git a/rag2_0/dify/DifyQueryRetrieval_api.py b/rag2_0/dify/DifyQueryRetrieval_api.py new file mode 100644 index 0000000..a5f47fa --- /dev/null +++ b/rag2_0/dify/DifyQueryRetrieval_api.py @@ -0,0 +1,54 @@ +# from gevent import monkey +# monkey.patch_all() + +from flask import Flask, request, Response +import os +from dotenv import load_dotenv +import json +import time +from gevent.lock import RLock + + +import datetime +import logging +# 加载环境变量 +load_dotenv() +from DifyQueryRetrieval import DifyQueryRetrieval + + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] +) +logging.getLogger('httpx').setLevel(logging.WARNING) +logging.getLogger('openai').setLevel(logging.WARNING) + +logger = logging.getLogger(__name__) + +app = Flask(__name__) + +dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1") +@app.route('/retrieve', methods=['POST']) +def retrieve(): + data = request.get_json(force=True) + original_query_str = data.get('original_query') + query_list_str = data.get('query_list') + data_set_list_str = data.get('data_set_list') + try: + query_list = query_list_str.split("") + data_set_list = data_set_list_str.split("") + results = dify_query_retrieval.retrieve_api(original_query_str, query_list, data_set_list) + return Response(json.dumps(results, ensure_ascii=False), content_type='application/json; charset=utf-8') + except Exception as e: + logger.error(f"检索出错: {str(e)}", exc_info=True) + return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500) + + +if __name__ == "__main__": + # 开发环境使用Flask内置服务器 + # 生产环境使用gunicorn支持高并发 uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app + # uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app + app.run(host="0.0.0.0", port=8002, threaded=True) \ No newline at end of file diff --git a/rag2_0/dify/dify_tool.py b/rag2_0/dify/dify_tool.py index b734f0c..2d1cbaa 100755 --- a/rag2_0/dify/dify_tool.py +++ b/rag2_0/dify/dify_tool.py @@ -453,16 +453,13 @@ class NewWorkflowChat(BaseWorkflowChat): # 先取出重排得分 message_info = self.dify_tool.get_message_debug_info_by_id(message_id=message_id) for workflow_node in message_info["workflow_node_executions_info"]: - if workflow_node["title"] == "软件知识检索聚合": - retrieve_outputs = json.loads(workflow_node["inputs"])["result"] + if workflow_node["title"] == "提取处理后的知识": + retrieve_outputs = json.loads(workflow_node["outputs"])["source_kno"] reranker_sorce = [{"score":result["metadata"]["score"], "segment_id":result["metadata"]["segment_id"]} for result in retrieve_outputs] - + break for workflow_node in message_info["workflow_node_executions_info"]: - if workflow_node["title"] == "软件知识检索聚合": - retrieve_outputs = json.loads(workflow_node["inputs"])["result"] - reranker_sorce = [{"score":result["metadata"]["score"], "segment_id":result["metadata"]["segment_id"]} for result in retrieve_outputs] - elif workflow_node["title"] == "提取处理后的知识": + if workflow_node["title"] == "提取处理后的知识": outputs = json.loads(workflow_node["outputs"])["knowledge_list"] retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs, reranker_sorce_info=reranker_sorce) elif workflow_node["title"] == "意图识别结果解析": diff --git a/rag2_0/dify/intent_recognition_api.py b/rag2_0/dify/intent_recognition_api.py index 282a734..d2e5772 100755 --- a/rag2_0/dify/intent_recognition_api.py +++ b/rag2_0/dify/intent_recognition_api.py @@ -1,9 +1,14 @@ +from gevent import monkey +monkey.patch_all() + from flask import Flask, request, Response import os from dotenv import load_dotenv import json import time -import threading +from gevent.lock import RLock + + import datetime import logging # 加载环境变量 @@ -29,12 +34,12 @@ logger = logging.getLogger(__name__) app = Flask(__name__) # 创建线程锁,用于保护共享资源 -recognizer_lock = threading.Lock() +recognizer_lock = RLock() # 使用单例模式创建意图识别器 class RecognizerSingleton: _instance = None - _lock = threading.Lock() + _lock = RLock() @classmethod def get_instance(cls): @@ -105,14 +110,16 @@ def intent_recognize(): "is_complete": slot_filling.get("is_complete", False), "missing_slots": slot_filling.get("missing_slots", {}), "filled_data": slot_filling.get("filled_data", {}) - } + }, + "query_expand": result["query_expand"] } return Response(json.dumps(response_result, ensure_ascii=False), content_type='application/json; charset=utf-8') except Exception as e: - print(f"意图识别出错: {str(e)}") + logger.error(f"意图识别出错: {str(e)}",exc_info=True) return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500) if __name__ == "__main__": # 开发环境使用Flask内置服务器 - # 生产环境使用gunicorn支持高并发 poetry run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app + # 生产环境使用gunicorn支持高并发 uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app + # uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app app.run(host="0.0.0.0", port=8001, threaded=True) \ No newline at end of file diff --git a/rag2_0/intent_recognition/IntentRecognition.py b/rag2_0/intent_recognition/IntentRecognition.py index 860297b..da55435 100755 --- a/rag2_0/intent_recognition/IntentRecognition.py +++ b/rag2_0/intent_recognition/IntentRecognition.py @@ -9,7 +9,6 @@ Description: 意图分类、改写核心逻辑 import logging import os -import threading from langchain.output_parsers import PydanticOutputParser import json from typing import List, Tuple, Dict, Any, Optional @@ -449,10 +448,10 @@ class IntentRecognizer: "slot_filling": slot_filling_result } - # 等待所有线程完成 + # 等待所有greenlet完成 start_time = time.time() - for thread, _ in threads_and_results: - thread.join() + for greenlet, _ in threads_and_results: + greenlet.join() end_time = time.time() logging.info(f"问题扩展环节耗时统计 - 总耗时: {end_time - start_time:.2f}秒") @@ -751,7 +750,7 @@ class IntentRecognizer: def _run_in_thread(self, func, args=(), kwargs={}): """ - 在线程中执行函数并返回结果 + 在greenlet中执行函数并返回结果 Args: func: 要执行的函数 @@ -759,21 +758,22 @@ class IntentRecognizer: kwargs: 函数的关键字参数 Returns: - (thread, result_container): 线程对象和存放结果的容器 + (greenlet, result_container): greenlet对象和存放结果的容器 """ + from gevent import Greenlet result_container = [] - def thread_target(): + def greenlet_target(): try: result = func(*args, **kwargs) result_container.append(result) except Exception as e: - logging.error(f"线程执行函数 {func.__name__} 时出错: {e}", exc_info=True) + logging.error(f"greenlet执行函数 {func.__name__} 时出错: {e}", exc_info=True) result_container.append(None) - thread = threading.Thread(target=thread_target) - thread.start() - return thread, result_container + greenlet = Greenlet(greenlet_target) + greenlet.start() + return greenlet, result_container def _process_intent_and_slot(self, user_input: str, conversation_context: str = "",