优化DifyQueryRetrieval类,新增top_k参数以支持检索结果的数量控制,同时重构相关方法以提升性能和可维护性。更新DifyTool类,新增获取应用会话和消息信息的方法。修复DifyExporter类中的代码格式问题,调整日期参数的默认值。

This commit is contained in:
2025-07-10 08:36:12 +08:00
parent ec3db656a5
commit 23f522dde5
4 changed files with 28 additions and 15 deletions
+8 -8
View File
@@ -102,7 +102,7 @@ class DifyQueryRetrieval:
return await self.retrieve_api_async(original_query, query_list, datasets)
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]:
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
all_documents=[]
# 使用线程池替代无限制创建线程
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制
@@ -140,13 +140,13 @@ class DifyQueryRetrieval:
# 对所有检索出来的文档进行重排序
time_start = time.time()
processed_documents = self.data_post_processor(original_query, deduplicated_documents)
processed_documents = self.data_post_processor(original_query, deduplicated_documents, top_k)
time_end = time.time()
logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}")
return processed_documents
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str])->List[Dict[str, Any]]:
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
"""
异步版本的retrieve_api方法,使用asyncio代替线程池
@@ -199,16 +199,16 @@ class DifyQueryRetrieval:
# 对所有检索出来的文档进行重排序
time_start = time.time()
processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents)
processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents, top_k)
time_end = time.time()
logging.info(f"异步检索后重排序耗时: {time_end - time_start:.2f}")
return processed_documents
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
reranker_model = XinferenceReRankerModel()
documents = [document['segment']['content'] for document in all_documents]
reranked_documents = reranker_model.rerank(query, documents, top_k=5)
reranked_documents = reranker_model.rerank(query, documents, top_k=top_k)
new_all_documents = []
def to_dify_document_format(document: dict)->dict:
@@ -240,7 +240,7 @@ class DifyQueryRetrieval:
new_all_documents.append(to_dify_document_format(cur_doc_info))
return new_all_documents
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
"""
异步版本的data_post_processor方法
@@ -254,7 +254,7 @@ class DifyQueryRetrieval:
reranker_model = XinferenceReRankerModel()
documents = [document['segment']['content'] for document in all_documents]
# 使用异步重排序方法
reranked_documents = await reranker_model.rerank_async(query, documents, top_k=5)
reranked_documents = await reranker_model.rerank_async(query, documents, top_k=top_k)
new_all_documents = []
def to_dify_document_format(document: dict)->dict: