优化DifyQueryRetrieval类,新增top_k参数以支持检索结果的数量控制,同时重构相关方法以提升性能和可维护性。更新DifyTool类,新增获取应用会话和消息信息的方法。修复DifyExporter类中的代码格式问题,调整日期参数的默认值。

This commit is contained in:
2025-07-10 08:36:12 +08:00
parent ec3db656a5
commit 23f522dde5
4 changed files with 28 additions and 15 deletions
+8 -8
View File
@@ -102,7 +102,7 @@ class DifyQueryRetrieval:
return await self.retrieve_api_async(original_query, query_list, datasets)
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]:
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
all_documents=[]
# 使用线程池替代无限制创建线程
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制
@@ -140,13 +140,13 @@ class DifyQueryRetrieval:
# 对所有检索出来的文档进行重排序
time_start = time.time()
processed_documents = self.data_post_processor(original_query, deduplicated_documents)
processed_documents = self.data_post_processor(original_query, deduplicated_documents, top_k)
time_end = time.time()
logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}")
return processed_documents
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str])->List[Dict[str, Any]]:
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
"""
异步版本的retrieve_api方法,使用asyncio代替线程池
@@ -199,16 +199,16 @@ class DifyQueryRetrieval:
# 对所有检索出来的文档进行重排序
time_start = time.time()
processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents)
processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents, top_k)
time_end = time.time()
logging.info(f"异步检索后重排序耗时: {time_end - time_start:.2f}")
return processed_documents
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
reranker_model = XinferenceReRankerModel()
documents = [document['segment']['content'] for document in all_documents]
reranked_documents = reranker_model.rerank(query, documents, top_k=5)
reranked_documents = reranker_model.rerank(query, documents, top_k=top_k)
new_all_documents = []
def to_dify_document_format(document: dict)->dict:
@@ -240,7 +240,7 @@ class DifyQueryRetrieval:
new_all_documents.append(to_dify_document_format(cur_doc_info))
return new_all_documents
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
"""
异步版本的data_post_processor方法
@@ -254,7 +254,7 @@ class DifyQueryRetrieval:
reranker_model = XinferenceReRankerModel()
documents = [document['segment']['content'] for document in all_documents]
# 使用异步重排序方法
reranked_documents = await reranker_model.rerank_async(query, documents, top_k=5)
reranked_documents = await reranker_model.rerank_async(query, documents, top_k=top_k)
new_all_documents = []
def to_dify_document_format(document: dict)->dict:
+2 -1
View File
@@ -93,7 +93,8 @@ async def retrieve(request: RetrieveRequest):
results = await dify_query_retrieval.retrieve_api_async(
request.original_query,
query_list,
data_set_list
data_set_list,
top_k=3
)
end_time = time.time()
+16 -2
View File
@@ -68,7 +68,6 @@ class PgSql:
self.connection.close()
self.connection = None
def get_appinfo(self, appid:str)->dict | None:
"""
根据应用 ID 从 'apps' 表中获取应用信息。
@@ -97,7 +96,6 @@ class PgSql:
except (Exception, psycopg2.Error) as error:
raise Exception(f"Error while getting tenant_id by appid: {error}")
def get_messages_info(self, appid:str, query:str)->dict | None:
"""
根据应用 ID 和查询内容从 'messages' 表中获取消息信息。
@@ -246,6 +244,7 @@ class PgSql:
raise Exception(f"Error while getting conversation_messages: {error}")
return rating
class DifyTool:
"""
提供用于获取 Dify 应用调试信息的工具类。
@@ -323,6 +322,21 @@ class DifyTool:
except Exception as e:
raise Exception(f"Error in get_message_debug_info_by_query: {e}")
def get_app_conversations(self, appid:str)->list[str] | None:
"""
根据应用 ID 从 'conversations' 表中获取应用会话信息。
"""
return self.dify_pgsql.get_app_conversations(appid)
def get_conversation_messages(self, conversation_id:str):
"""
根据会话 ID 从 'messages' 表中获取会话消息信息。
"""
return self.dify_pgsql.get_app_conversations(conversation_id)
def get_message_rating(self, msg_id):
return self.dify_pgsql.get_message_rating(msg_id)
class BaseWorkflowChat:
"""
工作流对话基类,封装了与Dify API交互的基本功能
+1 -3
View File
@@ -174,8 +174,6 @@ class DifyExporter:
"""
conversations = self.dify_pgsql.get_app_conversations(appid=self.app_id)
for conversation in conversations:
if conversation['conversation_id'] == '10d04219-0359-42f7-b9da-2ba039bf87a2':
breakpoint()
messages = self.dify_pgsql.get_conversation_messages(conversation_id=conversation['conversation_id'])
message_chain_new = self.process_message_chain(messages)
if len(message_chain_new) != len(messages):
@@ -323,7 +321,7 @@ if __name__ == "__main__":
help='Dify应用ID')
parser.add_argument('--query_log_file', '-q', type=str, default="data/query_logs/answer_type_logs.json",
help='查询日志文件路径')
parser.add_argument('--start_date', '-s', type=str, default=None,
parser.add_argument('--start_date', '-s', type=str, default="2025-07-09 13",
help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
parser.add_argument('--end_date', '-e', type=str, default=None,
help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')