From 23f522dde5943058242e6d3de004c3be3c9cf8ea Mon Sep 17 00:00:00 2001 From: ouyangyouzhang Date: Thu, 10 Jul 2025 08:36:12 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96DifyQueryRetrieval=E7=B1=BB?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9Etop=5Fk=E5=8F=82=E6=95=B0=E4=BB=A5?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=A3=80=E7=B4=A2=E7=BB=93=E6=9E=9C=E7=9A=84?= =?UTF-8?q?=E6=95=B0=E9=87=8F=E6=8E=A7=E5=88=B6=EF=BC=8C=E5=90=8C=E6=97=B6?= =?UTF-8?q?=E9=87=8D=E6=9E=84=E7=9B=B8=E5=85=B3=E6=96=B9=E6=B3=95=E4=BB=A5?= =?UTF-8?q?=E6=8F=90=E5=8D=87=E6=80=A7=E8=83=BD=E5=92=8C=E5=8F=AF=E7=BB=B4?= =?UTF-8?q?=E6=8A=A4=E6=80=A7=E3=80=82=E6=9B=B4=E6=96=B0DifyTool=E7=B1=BB?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9E=E8=8E=B7=E5=8F=96=E5=BA=94=E7=94=A8?= =?UTF-8?q?=E4=BC=9A=E8=AF=9D=E5=92=8C=E6=B6=88=E6=81=AF=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E7=9A=84=E6=96=B9=E6=B3=95=E3=80=82=E4=BF=AE=E5=A4=8DDifyExpor?= =?UTF-8?q?ter=E7=B1=BB=E4=B8=AD=E7=9A=84=E4=BB=A3=E7=A0=81=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E9=97=AE=E9=A2=98=EF=BC=8C=E8=B0=83=E6=95=B4=E6=97=A5?= =?UTF-8?q?=E6=9C=9F=E5=8F=82=E6=95=B0=E7=9A=84=E9=BB=98=E8=AE=A4=E5=80=BC?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rag2_0/dify/DifyQueryRetrieval.py | 16 ++++++++-------- rag2_0/dify/DifyQueryRetrieval_api.py | 3 ++- rag2_0/dify/dify_tool.py | 18 ++++++++++++++++-- rag2_0/dify/export_new_dify.py | 6 ++---- 4 files changed, 28 insertions(+), 15 deletions(-) diff --git a/rag2_0/dify/DifyQueryRetrieval.py b/rag2_0/dify/DifyQueryRetrieval.py index e2593d7..53d5aea 100644 --- a/rag2_0/dify/DifyQueryRetrieval.py +++ b/rag2_0/dify/DifyQueryRetrieval.py @@ -102,7 +102,7 @@ class DifyQueryRetrieval: return await self.retrieve_api_async(original_query, query_list, datasets) - def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]: + def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str], top_k: 
int = 5)->List[Dict[str, Any]]: all_documents=[] # 使用线程池替代无限制创建线程 # 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制 @@ -140,13 +140,13 @@ class DifyQueryRetrieval: # 对所有检索出来的文档进行重排序 time_start = time.time() - processed_documents = self.data_post_processor(original_query, deduplicated_documents) + processed_documents = self.data_post_processor(original_query, deduplicated_documents, top_k) time_end = time.time() logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒") return processed_documents - async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str])->List[Dict[str, Any]]: + async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]: """ 异步版本的retrieve_api方法,使用asyncio代替线程池 @@ -199,16 +199,16 @@ class DifyQueryRetrieval: # 对所有检索出来的文档进行重排序 time_start = time.time() - processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents) + processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents, top_k) time_end = time.time() logging.info(f"异步检索后重排序耗时: {time_end - time_start:.2f}秒") return processed_documents - def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]: reranker_model = XinferenceReRankerModel() documents = [document['segment']['content'] for document in all_documents] - reranked_documents = reranker_model.rerank(query, documents, top_k=5) + reranked_documents = reranker_model.rerank(query, documents, top_k=top_k) new_all_documents = [] def to_dify_document_format(document: dict)->dict: @@ -240,7 +240,7 @@ class DifyQueryRetrieval: new_all_documents.append(to_dify_document_format(cur_doc_info)) return new_all_documents - async def data_post_processor_async(self, query: str, 
all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]: """ 异步版本的data_post_processor方法 @@ -254,7 +254,7 @@ class DifyQueryRetrieval: reranker_model = XinferenceReRankerModel() documents = [document['segment']['content'] for document in all_documents] # 使用异步重排序方法 - reranked_documents = await reranker_model.rerank_async(query, documents, top_k=5) + reranked_documents = await reranker_model.rerank_async(query, documents, top_k=top_k) new_all_documents = [] def to_dify_document_format(document: dict)->dict: diff --git a/rag2_0/dify/DifyQueryRetrieval_api.py b/rag2_0/dify/DifyQueryRetrieval_api.py index cda3127..4e6c508 100644 --- a/rag2_0/dify/DifyQueryRetrieval_api.py +++ b/rag2_0/dify/DifyQueryRetrieval_api.py @@ -93,7 +93,8 @@ async def retrieve(request: RetrieveRequest): results = await dify_query_retrieval.retrieve_api_async( request.original_query, query_list, - data_set_list + data_set_list, + top_k=3 ) end_time = time.time() diff --git a/rag2_0/dify/dify_tool.py b/rag2_0/dify/dify_tool.py index fbb3400..f5f8541 100755 --- a/rag2_0/dify/dify_tool.py +++ b/rag2_0/dify/dify_tool.py @@ -68,7 +68,6 @@ class PgSql: self.connection.close() self.connection = None - def get_appinfo(self, appid:str)->dict | None: """ 根据应用 ID 从 'apps' 表中获取应用信息。 @@ -97,7 +96,6 @@ class PgSql: except (Exception, psycopg2.Error) as error: raise Exception(f"Error while getting tenant_id by appid: {error}") - def get_messages_info(self, appid:str, query:str)->dict | None: """ 根据应用 ID 和查询内容从 'messages' 表中获取消息信息。 @@ -246,6 +244,7 @@ class PgSql: raise Exception(f"Error while getting conversation_messages: {error}") return rating + class DifyTool: """ 提供用于获取 Dify 应用调试信息的工具类。 @@ -322,6 +321,21 @@ class DifyTool: } except Exception as e: raise Exception(f"Error in get_message_debug_info_by_query: {e}") + + def get_app_conversations(self, appid:str)->list[str] | 
None: + """ + Fetch the application's conversation records from the 'conversations' table by application ID. + """ + return self.dify_pgsql.get_app_conversations(appid) + + def get_conversation_messages(self, conversation_id:str): + """ + Fetch the conversation's message records from the 'messages' table by conversation ID. + """ + return self.dify_pgsql.get_conversation_messages(conversation_id) + + def get_message_rating(self, msg_id): + return self.dify_pgsql.get_message_rating(msg_id) class BaseWorkflowChat: """ diff --git a/rag2_0/dify/export_new_dify.py b/rag2_0/dify/export_new_dify.py index cf9e999..c95f856 100644 --- a/rag2_0/dify/export_new_dify.py +++ b/rag2_0/dify/export_new_dify.py @@ -174,8 +174,6 @@ class DifyExporter: """ conversations = self.dify_pgsql.get_app_conversations(appid=self.app_id) for conversation in conversations: - if conversation['conversation_id'] == '10d04219-0359-42f7-b9da-2ba039bf87a2': - breakpoint() messages = self.dify_pgsql.get_conversation_messages(conversation_id=conversation['conversation_id']) message_chain_new = self.process_message_chain(messages) if len(message_chain_new) != len(messages): @@ -198,7 +196,7 @@ class DifyExporter: message_info = self.extract_message_info(message) if message_info: self.message_info_list.append(message_info) - + return self.message_info_list def save_to_excel(self, message_info_list, output_file): @@ -323,7 +321,7 @@ if __name__ == "__main__": help='Dify应用ID') parser.add_argument('--query_log_file', '-q', type=str, default="data/query_logs/answer_type_logs.json", help='查询日志文件路径') - parser.add_argument('--start_date', '-s', type=str, default=None, + parser.add_argument('--start_date', '-s', type=str, default="2025-07-09 13", help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)') parser.add_argument('--end_date', '-e', type=str, default=None, help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')