优化Dify工具逻辑,调整知识提取和重排序流程,增强API调用的重试机制,更新意图识别API以支持更好的错误处理和日志记录,改进多线程检索功能
This commit is contained in:
@@ -219,12 +219,23 @@ reason: 简明扼要的理由(中文)
|
|||||||
|
|
||||||
prompt = self.create_correctness_prompt(standard_answer, answer)
|
prompt = self.create_correctness_prompt(standard_answer, answer)
|
||||||
llm = self.get_llm(response_format={"type": "json_object"})
|
llm = self.get_llm(response_format={"type": "json_object"})
|
||||||
try:
|
|
||||||
response = llm.invoke(user_prompt=prompt, need_retry=True)
|
max_retries = 3
|
||||||
response_json = json.loads(response.content)
|
retry_count = 0
|
||||||
return response_json["result"]
|
|
||||||
except Exception as e:
|
while retry_count < max_retries:
|
||||||
return None
|
try:
|
||||||
|
response = llm.invoke(user_prompt=prompt, need_retry=True)
|
||||||
|
response_json = json.loads(response.content)
|
||||||
|
return response_json["result"]
|
||||||
|
except Exception as e:
|
||||||
|
retry_count += 1
|
||||||
|
if retry_count >= max_retries:
|
||||||
|
logging.error(f"判断答案失败,已重试{max_retries}次: {str(e)}")
|
||||||
|
return False
|
||||||
|
# 指数退避策略,每次重试等待时间增加
|
||||||
|
import time
|
||||||
|
time.sleep(1 * (2 ** (retry_count - 1))) # 1秒, 2秒, 4秒...
|
||||||
|
|
||||||
def judge_by_standard_answer(self, standard_answer: str, old_answer: str, new_answer: str) -> str | None:
|
def judge_by_standard_answer(self, standard_answer: str, old_answer: str, new_answer: str) -> str | None:
|
||||||
"""
|
"""
|
||||||
@@ -689,7 +700,7 @@ content: "{content}"
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 创建命令行参数解析器
|
# 创建命令行参数解析器
|
||||||
os.environ["DIFY_BASEURL"] = "http://10.1.16.39/v1"
|
os.environ["DIFY_BASEURL"] = "http://10.1.16.39/v1"
|
||||||
os.environ["DIFY_NEW_API_KEY"] = "app-qxsSybCs7ABiKlC1JabTYVn6"
|
os.environ["DIFY_NEW_API_KEY"] = "app-rv6ie73Ufoa3nRYCMiJx3a8K"
|
||||||
os.environ["DIFY_OLD_API_KEY"] = "app-wUdkWJx5zeOvmvBUZizMoSw3"
|
os.environ["DIFY_OLD_API_KEY"] = "app-wUdkWJx5zeOvmvBUZizMoSw3"
|
||||||
|
|
||||||
os.environ["DIFY_PG_HOST"] = "10.1.16.39"
|
os.environ["DIFY_PG_HOST"] = "10.1.16.39"
|
||||||
|
|||||||
@@ -100,6 +100,50 @@ class DifyQueryRetrieval:
|
|||||||
|
|
||||||
return processed_documents
|
return processed_documents
|
||||||
|
|
||||||
|
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]:
|
||||||
|
all_documents=[]
|
||||||
|
# 使用线程池替代无限制创建线程
|
||||||
|
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(data_set_list))来限制
|
||||||
|
time_start = time.time()
|
||||||
|
max_workers = min(os.cpu_count() * 2, len(query_list) * len(data_set_list))
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
futures = []
|
||||||
|
for query in query_list:
|
||||||
|
for dataset in data_set_list:
|
||||||
|
if dataset not in self._datasets_list:
|
||||||
|
raise ValueError(f"dataset {dataset} not in datasets_list")
|
||||||
|
|
||||||
|
futures.append(executor.submit(self.retrieve_by_dataset, query, dataset))
|
||||||
|
|
||||||
|
# 等待所有任务完成
|
||||||
|
for future in as_completed(futures):
|
||||||
|
# 处理可能的异常
|
||||||
|
try:
|
||||||
|
retrieved_documents = future.result()
|
||||||
|
all_documents.extend(retrieved_documents)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"检索过程中发生错误: {str(e)}", exc_info=True)
|
||||||
|
time_end = time.time()
|
||||||
|
|
||||||
|
logging.info(f"检索耗时: {time_end - time_start:.2f}秒")
|
||||||
|
# 根据segment_id对文档进行去重
|
||||||
|
unique_documents = {}
|
||||||
|
for document in all_documents:
|
||||||
|
segment_id = document['segment']['id']
|
||||||
|
if segment_id not in unique_documents:
|
||||||
|
unique_documents[segment_id] = document
|
||||||
|
|
||||||
|
# 将去重后的文档转换为列表
|
||||||
|
deduplicated_documents = list(unique_documents.values())
|
||||||
|
|
||||||
|
# 对所有检索出来的文档进行重排序
|
||||||
|
time_start = time.time()
|
||||||
|
processed_documents = self.data_post_processor(original_query, deduplicated_documents)
|
||||||
|
time_end = time.time()
|
||||||
|
logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒")
|
||||||
|
|
||||||
|
return processed_documents
|
||||||
|
|
||||||
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
reranker_model = XinferenceReRankerModel()
|
reranker_model = XinferenceReRankerModel()
|
||||||
documents = [document['segment']['content'] for document in all_documents]
|
documents = [document['segment']['content'] for document in all_documents]
|
||||||
@@ -114,7 +158,7 @@ class DifyQueryRetrieval:
|
|||||||
"dataset_name": document["dataset_name"],
|
"dataset_name": document["dataset_name"],
|
||||||
"document_id": document['segment']['document_id'],
|
"document_id": document['segment']['document_id'],
|
||||||
"document_name": document["segment"]["document"]["name"],
|
"document_name": document["segment"]["document"]["name"],
|
||||||
"document_data_source_type": document["segment"]["document"]["data_source_type"],
|
"data_source_type": document["segment"]["document"]["data_source_type"],
|
||||||
"segment_id": document["segment"]["id"],
|
"segment_id": document["segment"]["id"],
|
||||||
"retriever_from": "api",
|
"retriever_from": "api",
|
||||||
"score": document.get("score", 0),
|
"score": document.get("score", 0),
|
||||||
@@ -122,6 +166,7 @@ class DifyQueryRetrieval:
|
|||||||
"segment_word_count": document.get("segment", {}).get("word_count", 0),
|
"segment_word_count": document.get("segment", {}).get("word_count", 0),
|
||||||
"segment_position": document.get("segment", {}).get("position", 0),
|
"segment_position": document.get("segment", {}).get("position", 0),
|
||||||
"segment_index_node_hash": document.get("segment", {}).get("index_node_hash", ""),
|
"segment_index_node_hash": document.get("segment", {}).get("index_node_hash", ""),
|
||||||
|
"doc_metadata": document.get("segment", {}).get("document", {}).get("doc_metadata", None),
|
||||||
"position": document["segment"].get("position", 0)
|
"position": document["segment"].get("position", 0)
|
||||||
},
|
},
|
||||||
"title": document["segment"]["document"]["name"],
|
"title": document["segment"]["document"]["name"],
|
||||||
@@ -159,6 +204,7 @@ class DifyQueryRetrieval:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="https://172.20.0.145/v1")
|
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
|
||||||
datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3")
|
# datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3")
|
||||||
|
datasets = dify_query_retrieval.retrieve_api("配网工程计价通D3软件如何新建工程?", ["配网工程计价通D3软件如何新建工程?"], ["流式输出缺失"])
|
||||||
print(datasets)
|
print(datasets)
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
# from gevent import monkey
|
||||||
|
# monkey.patch_all()
|
||||||
|
|
||||||
|
from flask import Flask, request, Response
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from gevent.lock import RLock
|
||||||
|
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
from DifyQueryRetrieval import DifyQueryRetrieval
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||||
|
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
|
||||||
|
@app.route('/retrieve', methods=['POST'])
|
||||||
|
def retrieve():
|
||||||
|
data = request.get_json(force=True)
|
||||||
|
original_query_str = data.get('original_query')
|
||||||
|
query_list_str = data.get('query_list')
|
||||||
|
data_set_list_str = data.get('data_set_list')
|
||||||
|
try:
|
||||||
|
query_list = query_list_str.split("<sub_query>")
|
||||||
|
data_set_list = data_set_list_str.split("<dataset>")
|
||||||
|
results = dify_query_retrieval.retrieve_api(original_query_str, query_list, data_set_list)
|
||||||
|
return Response(json.dumps(results, ensure_ascii=False), content_type='application/json; charset=utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检索出错: {str(e)}", exc_info=True)
|
||||||
|
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 开发环境使用Flask内置服务器
|
||||||
|
# 生产环境使用gunicorn支持高并发 uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app
|
||||||
|
# uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.DifyQueryRetrieval_api:app
|
||||||
|
app.run(host="0.0.0.0", port=8002, threaded=True)
|
||||||
@@ -453,16 +453,13 @@ class NewWorkflowChat(BaseWorkflowChat):
|
|||||||
# 先取出重排得分
|
# 先取出重排得分
|
||||||
message_info = self.dify_tool.get_message_debug_info_by_id(message_id=message_id)
|
message_info = self.dify_tool.get_message_debug_info_by_id(message_id=message_id)
|
||||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||||
if workflow_node["title"] == "软件知识检索聚合":
|
if workflow_node["title"] == "提取处理后的知识":
|
||||||
retrieve_outputs = json.loads(workflow_node["inputs"])["result"]
|
retrieve_outputs = json.loads(workflow_node["outputs"])["source_kno"]
|
||||||
reranker_sorce = [{"score":result["metadata"]["score"], "segment_id":result["metadata"]["segment_id"]} for result in retrieve_outputs]
|
reranker_sorce = [{"score":result["metadata"]["score"], "segment_id":result["metadata"]["segment_id"]} for result in retrieve_outputs]
|
||||||
|
break
|
||||||
|
|
||||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||||
if workflow_node["title"] == "软件知识检索聚合":
|
if workflow_node["title"] == "提取处理后的知识":
|
||||||
retrieve_outputs = json.loads(workflow_node["inputs"])["result"]
|
|
||||||
reranker_sorce = [{"score":result["metadata"]["score"], "segment_id":result["metadata"]["segment_id"]} for result in retrieve_outputs]
|
|
||||||
elif workflow_node["title"] == "提取处理后的知识":
|
|
||||||
outputs = json.loads(workflow_node["outputs"])["knowledge_list"]
|
outputs = json.loads(workflow_node["outputs"])["knowledge_list"]
|
||||||
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs, reranker_sorce_info=reranker_sorce)
|
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs, reranker_sorce_info=reranker_sorce)
|
||||||
elif workflow_node["title"] == "意图识别结果解析":
|
elif workflow_node["title"] == "意图识别结果解析":
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
|
from gevent import monkey
|
||||||
|
monkey.patch_all()
|
||||||
|
|
||||||
from flask import Flask, request, Response
|
from flask import Flask, request, Response
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import threading
|
from gevent.lock import RLock
|
||||||
|
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
@@ -29,12 +34,12 @@ logger = logging.getLogger(__name__)
|
|||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
# 创建线程锁,用于保护共享资源
|
# 创建线程锁,用于保护共享资源
|
||||||
recognizer_lock = threading.Lock()
|
recognizer_lock = RLock()
|
||||||
|
|
||||||
# 使用单例模式创建意图识别器
|
# 使用单例模式创建意图识别器
|
||||||
class RecognizerSingleton:
|
class RecognizerSingleton:
|
||||||
_instance = None
|
_instance = None
|
||||||
_lock = threading.Lock()
|
_lock = RLock()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls):
|
def get_instance(cls):
|
||||||
@@ -105,14 +110,16 @@ def intent_recognize():
|
|||||||
"is_complete": slot_filling.get("is_complete", False),
|
"is_complete": slot_filling.get("is_complete", False),
|
||||||
"missing_slots": slot_filling.get("missing_slots", {}),
|
"missing_slots": slot_filling.get("missing_slots", {}),
|
||||||
"filled_data": slot_filling.get("filled_data", {})
|
"filled_data": slot_filling.get("filled_data", {})
|
||||||
}
|
},
|
||||||
|
"query_expand": result["query_expand"]
|
||||||
}
|
}
|
||||||
return Response(json.dumps(response_result, ensure_ascii=False), content_type='application/json; charset=utf-8')
|
return Response(json.dumps(response_result, ensure_ascii=False), content_type='application/json; charset=utf-8')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"意图识别出错: {str(e)}")
|
logger.error(f"意图识别出错: {str(e)}",exc_info=True)
|
||||||
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500)
|
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 开发环境使用Flask内置服务器
|
# 开发环境使用Flask内置服务器
|
||||||
# 生产环境使用gunicorn支持高并发 poetry run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app
|
# 生产环境使用gunicorn支持高并发 uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app
|
||||||
|
# uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app
|
||||||
app.run(host="0.0.0.0", port=8001, threaded=True)
|
app.run(host="0.0.0.0", port=8001, threaded=True)
|
||||||
@@ -9,7 +9,6 @@ Description: 意图分类、改写核心逻辑
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
|
||||||
from langchain.output_parsers import PydanticOutputParser
|
from langchain.output_parsers import PydanticOutputParser
|
||||||
import json
|
import json
|
||||||
from typing import List, Tuple, Dict, Any, Optional
|
from typing import List, Tuple, Dict, Any, Optional
|
||||||
@@ -449,10 +448,10 @@ class IntentRecognizer:
|
|||||||
"slot_filling": slot_filling_result
|
"slot_filling": slot_filling_result
|
||||||
}
|
}
|
||||||
|
|
||||||
# 等待所有线程完成
|
# 等待所有greenlet完成
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for thread, _ in threads_and_results:
|
for greenlet, _ in threads_and_results:
|
||||||
thread.join()
|
greenlet.join()
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
logging.info(f"问题扩展环节耗时统计 - 总耗时: {end_time - start_time:.2f}秒")
|
logging.info(f"问题扩展环节耗时统计 - 总耗时: {end_time - start_time:.2f}秒")
|
||||||
|
|
||||||
@@ -751,7 +750,7 @@ class IntentRecognizer:
|
|||||||
|
|
||||||
def _run_in_thread(self, func, args=(), kwargs={}):
|
def _run_in_thread(self, func, args=(), kwargs={}):
|
||||||
"""
|
"""
|
||||||
在线程中执行函数并返回结果
|
在greenlet中执行函数并返回结果
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func: 要执行的函数
|
func: 要执行的函数
|
||||||
@@ -759,21 +758,22 @@ class IntentRecognizer:
|
|||||||
kwargs: 函数的关键字参数
|
kwargs: 函数的关键字参数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(thread, result_container): 线程对象和存放结果的容器
|
(greenlet, result_container): greenlet对象和存放结果的容器
|
||||||
"""
|
"""
|
||||||
|
from gevent import Greenlet
|
||||||
result_container = []
|
result_container = []
|
||||||
|
|
||||||
def thread_target():
|
def greenlet_target():
|
||||||
try:
|
try:
|
||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
result_container.append(result)
|
result_container.append(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"线程执行函数 {func.__name__} 时出错: {e}", exc_info=True)
|
logging.error(f"greenlet执行函数 {func.__name__} 时出错: {e}", exc_info=True)
|
||||||
result_container.append(None)
|
result_container.append(None)
|
||||||
|
|
||||||
thread = threading.Thread(target=thread_target)
|
greenlet = Greenlet(greenlet_target)
|
||||||
thread.start()
|
greenlet.start()
|
||||||
return thread, result_container
|
return greenlet, result_container
|
||||||
|
|
||||||
|
|
||||||
def _process_intent_and_slot(self, user_input: str, conversation_context: str = "",
|
def _process_intent_and_slot(self, user_input: str, conversation_context: str = "",
|
||||||
|
|||||||
Reference in New Issue
Block a user