优化DifyQueryRetrieval类,新增top_k参数以支持检索结果的数量控制,同时重构相关方法以提升性能和可维护性。更新DifyTool类,新增获取应用会话和消息信息的方法。修复DifyExporter类中的代码格式问题,调整日期参数的默认值。
This commit is contained in:
@@ -102,7 +102,7 @@ class DifyQueryRetrieval:
|
|||||||
|
|
||||||
return await self.retrieve_api_async(original_query, query_list, datasets)
|
return await self.retrieve_api_async(original_query, query_list, datasets)
|
||||||
|
|
||||||
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]:
|
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
|
||||||
all_documents=[]
|
all_documents=[]
|
||||||
# 使用线程池替代无限制创建线程
|
# 使用线程池替代无限制创建线程
|
||||||
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制
|
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制
|
||||||
@@ -140,13 +140,13 @@ class DifyQueryRetrieval:
|
|||||||
|
|
||||||
# 对所有检索出来的文档进行重排序
|
# 对所有检索出来的文档进行重排序
|
||||||
time_start = time.time()
|
time_start = time.time()
|
||||||
processed_documents = self.data_post_processor(original_query, deduplicated_documents)
|
processed_documents = self.data_post_processor(original_query, deduplicated_documents, top_k)
|
||||||
time_end = time.time()
|
time_end = time.time()
|
||||||
logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒")
|
logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒")
|
||||||
|
|
||||||
return processed_documents
|
return processed_documents
|
||||||
|
|
||||||
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str])->List[Dict[str, Any]]:
|
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
异步版本的retrieve_api方法,使用asyncio代替线程池
|
异步版本的retrieve_api方法,使用asyncio代替线程池
|
||||||
|
|
||||||
@@ -199,16 +199,16 @@ class DifyQueryRetrieval:
|
|||||||
|
|
||||||
# 对所有检索出来的文档进行重排序
|
# 对所有检索出来的文档进行重排序
|
||||||
time_start = time.time()
|
time_start = time.time()
|
||||||
processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents)
|
processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents, top_k)
|
||||||
time_end = time.time()
|
time_end = time.time()
|
||||||
logging.info(f"异步检索后重排序耗时: {time_end - time_start:.2f}秒")
|
logging.info(f"异步检索后重排序耗时: {time_end - time_start:.2f}秒")
|
||||||
|
|
||||||
return processed_documents
|
return processed_documents
|
||||||
|
|
||||||
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
|
||||||
reranker_model = XinferenceReRankerModel()
|
reranker_model = XinferenceReRankerModel()
|
||||||
documents = [document['segment']['content'] for document in all_documents]
|
documents = [document['segment']['content'] for document in all_documents]
|
||||||
reranked_documents = reranker_model.rerank(query, documents, top_k=5)
|
reranked_documents = reranker_model.rerank(query, documents, top_k=top_k)
|
||||||
new_all_documents = []
|
new_all_documents = []
|
||||||
|
|
||||||
def to_dify_document_format(document: dict)->dict:
|
def to_dify_document_format(document: dict)->dict:
|
||||||
@@ -240,7 +240,7 @@ class DifyQueryRetrieval:
|
|||||||
new_all_documents.append(to_dify_document_format(cur_doc_info))
|
new_all_documents.append(to_dify_document_format(cur_doc_info))
|
||||||
return new_all_documents
|
return new_all_documents
|
||||||
|
|
||||||
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
异步版本的data_post_processor方法
|
异步版本的data_post_processor方法
|
||||||
|
|
||||||
@@ -254,7 +254,7 @@ class DifyQueryRetrieval:
|
|||||||
reranker_model = XinferenceReRankerModel()
|
reranker_model = XinferenceReRankerModel()
|
||||||
documents = [document['segment']['content'] for document in all_documents]
|
documents = [document['segment']['content'] for document in all_documents]
|
||||||
# 使用异步重排序方法
|
# 使用异步重排序方法
|
||||||
reranked_documents = await reranker_model.rerank_async(query, documents, top_k=5)
|
reranked_documents = await reranker_model.rerank_async(query, documents, top_k=top_k)
|
||||||
new_all_documents = []
|
new_all_documents = []
|
||||||
|
|
||||||
def to_dify_document_format(document: dict)->dict:
|
def to_dify_document_format(document: dict)->dict:
|
||||||
|
|||||||
@@ -93,7 +93,8 @@ async def retrieve(request: RetrieveRequest):
|
|||||||
results = await dify_query_retrieval.retrieve_api_async(
|
results = await dify_query_retrieval.retrieve_api_async(
|
||||||
request.original_query,
|
request.original_query,
|
||||||
query_list,
|
query_list,
|
||||||
data_set_list
|
data_set_list,
|
||||||
|
top_k=3
|
||||||
)
|
)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,6 @@ class PgSql:
|
|||||||
self.connection.close()
|
self.connection.close()
|
||||||
self.connection = None
|
self.connection = None
|
||||||
|
|
||||||
|
|
||||||
def get_appinfo(self, appid:str)->dict | None:
|
def get_appinfo(self, appid:str)->dict | None:
|
||||||
"""
|
"""
|
||||||
根据应用 ID 从 'apps' 表中获取应用信息。
|
根据应用 ID 从 'apps' 表中获取应用信息。
|
||||||
@@ -97,7 +96,6 @@ class PgSql:
|
|||||||
except (Exception, psycopg2.Error) as error:
|
except (Exception, psycopg2.Error) as error:
|
||||||
raise Exception(f"Error while getting tenant_id by appid: {error}")
|
raise Exception(f"Error while getting tenant_id by appid: {error}")
|
||||||
|
|
||||||
|
|
||||||
def get_messages_info(self, appid:str, query:str)->dict | None:
|
def get_messages_info(self, appid:str, query:str)->dict | None:
|
||||||
"""
|
"""
|
||||||
根据应用 ID 和查询内容从 'messages' 表中获取消息信息。
|
根据应用 ID 和查询内容从 'messages' 表中获取消息信息。
|
||||||
@@ -246,6 +244,7 @@ class PgSql:
|
|||||||
raise Exception(f"Error while getting conversation_messages: {error}")
|
raise Exception(f"Error while getting conversation_messages: {error}")
|
||||||
return rating
|
return rating
|
||||||
|
|
||||||
|
|
||||||
class DifyTool:
|
class DifyTool:
|
||||||
"""
|
"""
|
||||||
提供用于获取 Dify 应用调试信息的工具类。
|
提供用于获取 Dify 应用调试信息的工具类。
|
||||||
@@ -322,6 +321,21 @@ class DifyTool:
|
|||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Error in get_message_debug_info_by_query: {e}")
|
raise Exception(f"Error in get_message_debug_info_by_query: {e}")
|
||||||
|
|
||||||
|
def get_app_conversations(self, appid:str)->list[str] | None:
|
||||||
|
"""
|
||||||
|
根据应用 ID 从 'conversations' 表中获取应用会话信息。
|
||||||
|
"""
|
||||||
|
return self.dify_pgsql.get_app_conversations(appid)
|
||||||
|
|
||||||
|
def get_conversation_messages(self, conversation_id: str):
    """
    Fetch the messages of a conversation from the 'messages' table.

    This is a thin delegate to ``self.dify_pgsql.get_conversation_messages``.

    Args:
        conversation_id: ID of the conversation whose messages are requested.

    Returns:
        The value returned by the database layer, passed through unchanged.
    """
    # Must query messages by conversation ID; calling get_app_conversations
    # here would treat the conversation ID as an application ID.
    return self.dify_pgsql.get_conversation_messages(conversation_id=conversation_id)
|
||||||
|
|
||||||
|
def get_message_rating(self, msg_id):
    """
    Look up the rating of a message through the database layer.

    Args:
        msg_id: ID of the message, forwarded unchanged to
            ``self.dify_pgsql.get_message_rating``.

    Returns:
        The value returned by ``self.dify_pgsql.get_message_rating``.
    """
    rating = self.dify_pgsql.get_message_rating(msg_id)
    return rating
|
||||||
|
|
||||||
class BaseWorkflowChat:
|
class BaseWorkflowChat:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -174,8 +174,6 @@ class DifyExporter:
|
|||||||
"""
|
"""
|
||||||
conversations = self.dify_pgsql.get_app_conversations(appid=self.app_id)
|
conversations = self.dify_pgsql.get_app_conversations(appid=self.app_id)
|
||||||
for conversation in conversations:
|
for conversation in conversations:
|
||||||
if conversation['conversation_id'] == '10d04219-0359-42f7-b9da-2ba039bf87a2':
|
|
||||||
breakpoint()
|
|
||||||
messages = self.dify_pgsql.get_conversation_messages(conversation_id=conversation['conversation_id'])
|
messages = self.dify_pgsql.get_conversation_messages(conversation_id=conversation['conversation_id'])
|
||||||
message_chain_new = self.process_message_chain(messages)
|
message_chain_new = self.process_message_chain(messages)
|
||||||
if len(message_chain_new) != len(messages):
|
if len(message_chain_new) != len(messages):
|
||||||
@@ -198,7 +196,7 @@ class DifyExporter:
|
|||||||
message_info = self.extract_message_info(message)
|
message_info = self.extract_message_info(message)
|
||||||
if message_info:
|
if message_info:
|
||||||
self.message_info_list.append(message_info)
|
self.message_info_list.append(message_info)
|
||||||
|
|
||||||
return self.message_info_list
|
return self.message_info_list
|
||||||
|
|
||||||
def save_to_excel(self, message_info_list, output_file):
|
def save_to_excel(self, message_info_list, output_file):
|
||||||
@@ -323,7 +321,7 @@ if __name__ == "__main__":
|
|||||||
help='Dify应用ID')
|
help='Dify应用ID')
|
||||||
parser.add_argument('--query_log_file', '-q', type=str, default="data/query_logs/answer_type_logs.json",
|
parser.add_argument('--query_log_file', '-q', type=str, default="data/query_logs/answer_type_logs.json",
|
||||||
help='查询日志文件路径')
|
help='查询日志文件路径')
|
||||||
parser.add_argument('--start_date', '-s', type=str, default=None,
|
parser.add_argument('--start_date', '-s', type=str, default="2025-07-09 13",
|
||||||
help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
|
help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
|
||||||
parser.add_argument('--end_date', '-e', type=str, default=None,
|
parser.add_argument('--end_date', '-e', type=str, default=None,
|
||||||
help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')
|
help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')
|
||||||
|
|||||||
Reference in New Issue
Block a user