优化DifyQueryRetrieval类,新增top_k参数以支持检索结果的数量控制,同时重构相关方法以提升性能和可维护性。更新DifyTool类,新增获取应用会话和消息信息的方法。修复DifyExporter类中的代码格式问题,调整日期参数的默认值。
This commit is contained in:
@@ -102,7 +102,7 @@ class DifyQueryRetrieval:
|
||||
|
||||
return await self.retrieve_api_async(original_query, query_list, datasets)
|
||||
|
||||
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]:
|
||||
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
|
||||
all_documents=[]
|
||||
# 使用线程池替代无限制创建线程
|
||||
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制
|
||||
@@ -140,13 +140,13 @@ class DifyQueryRetrieval:
|
||||
|
||||
# 对所有检索出来的文档进行重排序
|
||||
time_start = time.time()
|
||||
processed_documents = self.data_post_processor(original_query, deduplicated_documents)
|
||||
processed_documents = self.data_post_processor(original_query, deduplicated_documents, top_k)
|
||||
time_end = time.time()
|
||||
logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒")
|
||||
|
||||
return processed_documents
|
||||
|
||||
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str])->List[Dict[str, Any]]:
|
||||
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
|
||||
"""
|
||||
异步版本的retrieve_api方法,使用asyncio代替线程池
|
||||
|
||||
@@ -199,16 +199,16 @@ class DifyQueryRetrieval:
|
||||
|
||||
# 对所有检索出来的文档进行重排序
|
||||
time_start = time.time()
|
||||
processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents)
|
||||
processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents, top_k)
|
||||
time_end = time.time()
|
||||
logging.info(f"异步检索后重排序耗时: {time_end - time_start:.2f}秒")
|
||||
|
||||
return processed_documents
|
||||
|
||||
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
reranker_model = XinferenceReRankerModel()
|
||||
documents = [document['segment']['content'] for document in all_documents]
|
||||
reranked_documents = reranker_model.rerank(query, documents, top_k=5)
|
||||
reranked_documents = reranker_model.rerank(query, documents, top_k=top_k)
|
||||
new_all_documents = []
|
||||
|
||||
def to_dify_document_format(document: dict)->dict:
|
||||
@@ -240,7 +240,7 @@ class DifyQueryRetrieval:
|
||||
new_all_documents.append(to_dify_document_format(cur_doc_info))
|
||||
return new_all_documents
|
||||
|
||||
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
异步版本的data_post_processor方法
|
||||
|
||||
@@ -254,7 +254,7 @@ class DifyQueryRetrieval:
|
||||
reranker_model = XinferenceReRankerModel()
|
||||
documents = [document['segment']['content'] for document in all_documents]
|
||||
# 使用异步重排序方法
|
||||
reranked_documents = await reranker_model.rerank_async(query, documents, top_k=5)
|
||||
reranked_documents = await reranker_model.rerank_async(query, documents, top_k=top_k)
|
||||
new_all_documents = []
|
||||
|
||||
def to_dify_document_format(document: dict)->dict:
|
||||
|
||||
@@ -93,7 +93,8 @@ async def retrieve(request: RetrieveRequest):
|
||||
results = await dify_query_retrieval.retrieve_api_async(
|
||||
request.original_query,
|
||||
query_list,
|
||||
data_set_list
|
||||
data_set_list,
|
||||
top_k=3
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
@@ -68,7 +68,6 @@ class PgSql:
|
||||
self.connection.close()
|
||||
self.connection = None
|
||||
|
||||
|
||||
def get_appinfo(self, appid:str)->dict | None:
|
||||
"""
|
||||
根据应用 ID 从 'apps' 表中获取应用信息。
|
||||
@@ -97,7 +96,6 @@ class PgSql:
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting tenant_id by appid: {error}")
|
||||
|
||||
|
||||
def get_messages_info(self, appid:str, query:str)->dict | None:
|
||||
"""
|
||||
根据应用 ID 和查询内容从 'messages' 表中获取消息信息。
|
||||
@@ -246,6 +244,7 @@ class PgSql:
|
||||
raise Exception(f"Error while getting conversation_messages: {error}")
|
||||
return rating
|
||||
|
||||
|
||||
class DifyTool:
|
||||
"""
|
||||
提供用于获取 Dify 应用调试信息的工具类。
|
||||
@@ -322,6 +321,21 @@ class DifyTool:
|
||||
}
|
||||
except Exception as e:
|
||||
raise Exception(f"Error in get_message_debug_info_by_query: {e}")
|
||||
|
||||
def get_app_conversations(self, appid:str)->list[str] | None:
|
||||
"""
|
||||
根据应用 ID 从 'conversations' 表中获取应用会话信息。
|
||||
"""
|
||||
return self.dify_pgsql.get_app_conversations(appid)
|
||||
|
||||
def get_conversation_messages(self, conversation_id:str):
|
||||
"""
|
||||
根据会话 ID 从 'messages' 表中获取会话消息信息。
|
||||
"""
|
||||
return self.dify_pgsql.get_app_conversations(conversation_id)
|
||||
|
||||
def get_message_rating(self, msg_id):
|
||||
return self.dify_pgsql.get_message_rating(msg_id)
|
||||
|
||||
class BaseWorkflowChat:
|
||||
"""
|
||||
|
||||
@@ -174,8 +174,6 @@ class DifyExporter:
|
||||
"""
|
||||
conversations = self.dify_pgsql.get_app_conversations(appid=self.app_id)
|
||||
for conversation in conversations:
|
||||
if conversation['conversation_id'] == '10d04219-0359-42f7-b9da-2ba039bf87a2':
|
||||
breakpoint()
|
||||
messages = self.dify_pgsql.get_conversation_messages(conversation_id=conversation['conversation_id'])
|
||||
message_chain_new = self.process_message_chain(messages)
|
||||
if len(message_chain_new) != len(messages):
|
||||
@@ -198,7 +196,7 @@ class DifyExporter:
|
||||
message_info = self.extract_message_info(message)
|
||||
if message_info:
|
||||
self.message_info_list.append(message_info)
|
||||
|
||||
|
||||
return self.message_info_list
|
||||
|
||||
def save_to_excel(self, message_info_list, output_file):
|
||||
@@ -323,7 +321,7 @@ if __name__ == "__main__":
|
||||
help='Dify应用ID')
|
||||
parser.add_argument('--query_log_file', '-q', type=str, default="data/query_logs/answer_type_logs.json",
|
||||
help='查询日志文件路径')
|
||||
parser.add_argument('--start_date', '-s', type=str, default=None,
|
||||
parser.add_argument('--start_date', '-s', type=str, default="2025-07-09 13",
|
||||
help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
|
||||
parser.add_argument('--end_date', '-e', type=str, default=None,
|
||||
help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')
|
||||
|
||||
Reference in New Issue
Block a user