import sys
import os
import json
from typing import List, Dict, Any, Optional
import logging
import time
import asyncio
import httpx

sys.path.append(os.getcwd())

from rag2_0.dify.dify_client.client import DifyClient, KnowledgeBaseClient
from rag2_0.tool.ModelTool import XinferenceReRankerModel


class DifyQueryRetrieval:
    """Retrieve segments from Dify knowledge-base datasets and rerank them.

    On construction the datasets visible to ``dify_dataset_key`` are fetched
    once and cached by dataset name; all later retrievals look datasets up in
    that cache.
    """

    def __init__(self, dify_dataset_key: str, dify_base_url: str):
        """Create the retriever and eagerly load the dataset list.

        Args:
            dify_dataset_key: Dify dataset API key.
            dify_base_url: Base URL of the Dify API (e.g. ``http://host/v1``).
        """
        self._dify_dataset_key = dify_dataset_key
        self._dify_base_url = dify_base_url
        # Performs a network request during construction.
        self._datasets_list = self.get_datasets_list()

    def get_datasets_list(self) -> Dict[str, Dict[str, Any]]:
        """Return the available datasets keyed by dataset name.

        Each value is the raw dataset record from the ``list_datasets``
        response; this class reads its ``"id"`` and ``"retrieval_model_dict"``.
        """
        client = KnowledgeBaseClient(api_key=self._dify_dataset_key, base_url=self._dify_base_url)
        # NOTE(review): only one request with page_size=50 is made and no
        # further pages are requested, so datasets beyond that page are
        # presumably missing. TODO: paginate if more than 50 datasets exist.
        datasets = client.list_datasets(page_size=50)
        datasets_json = datasets.json()
        return {dataset["name"]: dataset for dataset in datasets_json["data"]}

    def retrieve_by_dataset(self, query: str, dataset_name: str,
                            metadata_filtering_conditions: Optional[dict] = None) -> Dict[str, Any]:
        """Retrieve segments for ``query`` from a single dataset.

        The dataset's own ``retrieval_model_dict`` is used as the retrieval
        model. When ``metadata_filtering_conditions`` is non-empty it is merged
        into a *new* dict, so the cached dataset configuration is never
        modified and filters cannot leak into later calls.

        Args:
            query: Query text.
            dataset_name: Name of a dataset returned by ``get_datasets_list``.
            metadata_filtering_conditions: Optional metadata filter; ``None`` or
                an empty dict means no filtering.

        Returns:
            ``{"query": ..., "dataset_name": ..., "documents": [...]}`` where each
            document record is tagged with ``dataset_id`` and ``dataset_name``.
            Any error is logged and results in an empty ``documents`` list.
        """
        try:
            dataset_info = self._datasets_list[dataset_name]
            dataset_id = dataset_info["id"]
            retrieval_model = dataset_info["retrieval_model_dict"]
            if metadata_filtering_conditions:
                # Build a fresh dict: this method runs concurrently (via
                # asyncio.to_thread) against the same cached config.
                retrieval_model = {
                    **retrieval_model,
                    "metadata_filtering_conditions": metadata_filtering_conditions,
                }
            knowledge_base_client = KnowledgeBaseClient(api_key=self._dify_dataset_key,
                                                        base_url=self._dify_base_url,
                                                        dataset_id=dataset_id)
            documents = knowledge_base_client.retrieve(query, retrieval_model=retrieval_model, timeout=300)
            retrieved_documents = documents.json().get("records", [])
            # Tag each record with the dataset it came from.
            for retrieved_document in retrieved_documents:
                retrieved_document["dataset_id"] = dataset_id
                retrieved_document["dataset_name"] = dataset_name
            return {
                "query": query,
                "dataset_name": dataset_name,
                "documents": retrieved_documents,
            }
        except Exception as e:
            # Best-effort: one failing dataset must not abort a multi-dataset search.
            logging.error(f"检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
            return {
                "query": query,
                "dataset_name": dataset_name,
                "documents": [],
            }

    async def retrieve_by_dataset_async(self, query: str, dataset_name: str,
                                        metadata_filtering_conditions: Optional[dict] = None) -> Dict[str, Any]:
        """Async wrapper that runs ``retrieve_by_dataset`` in a worker thread.

        Args:
            query: Query text.
            dataset_name: Dataset name.
            metadata_filtering_conditions: Optional metadata filter.

        Returns:
            Same shape as ``retrieve_by_dataset``; errors yield empty ``documents``.
        """
        try:
            return await asyncio.to_thread(
                self.retrieve_by_dataset, query, dataset_name, metadata_filtering_conditions
            )
        except Exception as e:
            logging.error(f"异步检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
            return {
                "query": query,
                "dataset_name": dataset_name,
                "documents": [],
            }

    async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str],
                                 query_expand_dict: Any, top_k: int = 5,
                                 metadata_filtering_conditions: Optional[dict] = None) -> Dict[str, Any]:
        """Retrieve every (query, dataset) pair concurrently, dedupe, rerank and report hits.

        Args:
            original_query: Query used for reranking the merged results.
            query_list: Queries to run against every dataset.
            data_set_list: Dataset names; unknown names are logged and skipped.
            query_expand_dict: Mapping of query type -> query or list of queries
                (or its JSON string form), used only for ``query_hit_stats``.
            top_k: Number of documents requested from the reranker.
            metadata_filtering_conditions: Optional metadata filter applied to
                every retrieval.

        Returns:
            ``{"documents": [...], "query_hit_stats": {query_type: [title, ...]}}``.
        """
        time_start = time.perf_counter()

        tasks = []
        for query in query_list:
            for dataset in data_set_list:
                if dataset not in self._datasets_list:
                    logging.error(f"dataset {dataset} not in datasets_list")
                    continue
                tasks.append(self.retrieve_by_dataset_async(query, dataset, metadata_filtering_conditions))

        results = await asyncio.gather(*tasks, return_exceptions=True)

        all_documents: List[Dict[str, Any]] = []
        # query -> segment ids it retrieved (across all datasets)
        query_document_mapping: Dict[str, List[Any]] = {}
        for result in results:
            if isinstance(result, BaseException):
                # Pass the exception itself: exc_info=True only works inside an except block.
                logging.error(f"异步检索过程中发生错误: {str(result)}", exc_info=result)
                continue
            documents = result["documents"]
            query_document_mapping.setdefault(result["query"], []).extend(
                document['segment']['id'] for document in documents
            )
            all_documents.extend(documents)

        logging.info(f"异步检索耗时: {time.perf_counter() - time_start:.2f}秒")

        # Deduplicate by segment id, keeping the first occurrence.
        unique_documents: Dict[Any, Dict[str, Any]] = {}
        for document in all_documents:
            unique_documents.setdefault(document['segment']['id'], document)
        deduplicated_documents = list(unique_documents.values())

        if not deduplicated_documents:
            return {
                "documents": [],
                "query_hit_stats": {},
            }

        time_start = time.perf_counter()
        processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents, top_k)
        logging.info(f"异步检索后重排序耗时: {time.perf_counter() - time_start:.2f}秒")

        query_hit_stats = self._build_query_hit_stats(processed_documents, query_expand_dict, query_document_mapping)
        return {
            "documents": processed_documents,
            "query_hit_stats": query_hit_stats,
        }

    @staticmethod
    def _build_query_hit_stats(processed_documents: List[Dict[str, Any]], query_expand_dict: Any,
                               query_document_mapping: Dict[str, List[Any]]) -> Dict[str, List[str]]:
        """Map each query type to the unique titles of reranked documents its queries retrieved.

        Titles are the last ``/``-separated part of the document title and are
        listed in reranked order. ``query_expand_dict`` may be a dict, its JSON
        string form, or ``None``; anything that is not a dict yields no stats.
        """
        if query_expand_dict is None:
            query_expand_dict = {}
        elif isinstance(query_expand_dict, str):
            try:
                query_expand_dict = json.loads(query_expand_dict)
            except ValueError as e:
                logging.error(f"解析query_expand_dict失败: {str(e)}")
                query_expand_dict = {}
        if not isinstance(query_expand_dict, dict):
            logging.error(f"query_expand_dict is not a dict: {type(query_expand_dict).__name__}")
            query_expand_dict = {}

        # (segment_id, short title) pairs in reranked order.
        reranked = [
            (doc["metadata"]["segment_id"], doc["title"].split("/")[-1])
            for doc in processed_documents
        ]

        query_hit_stats: Dict[str, List[str]] = {}
        for query_type, queries in query_expand_dict.items():
            if not isinstance(queries, list):
                queries = [queries]
            hit_doc_ids = set()
            for query in queries:
                hit_doc_ids.update(query_document_mapping.get(query, ()))
            titles: List[str] = []
            seen_titles = set()
            for doc_id, title in reranked:
                if doc_id in hit_doc_ids and title not in seen_titles:
                    seen_titles.add(title)
                    titles.append(title)
            query_hit_stats[query_type] = titles
        return query_hit_stats

    async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]],
                                        top_k: int = 5) -> List[Dict[str, Any]]:
        """Rerank ``all_documents`` against ``query`` and convert them to Dify's document format.

        Args:
            query: Text to rerank against.
            all_documents: Retrieved records (with ``segment``, ``dataset_id`` and
                ``dataset_name``). Each selected record's ``score`` is overwritten
                with the reranker score.
            top_k: Passed to the reranker; presumably caps the number of results.

        Returns:
            Formatted documents in the order returned by the reranker.
        """
        if not all_documents:
            return []
        reranker_model = XinferenceReRankerModel()
        contents = [document['segment']['content'] for document in all_documents]
        reranked_documents = await reranker_model.rerank_async(query, contents, top_k=top_k)

        new_all_documents = []
        for reranked_document in reranked_documents:
            # "index" points back into all_documents / contents.
            cur_doc_info = all_documents[reranked_document["index"]]
            cur_doc_info["score"] = reranked_document["score"]
            new_all_documents.append(self._to_dify_document_format(cur_doc_info))
        return new_all_documents

    @staticmethod
    def _to_dify_document_format(document: dict) -> dict:
        """Convert one retrieved record into a ``{"metadata", "title", "content"}`` document."""
        return {
            "metadata": {
                "_source": "knowledge",
                "dataset_id": document["dataset_id"],
                "dataset_name": document["dataset_name"],
                "document_id": document['segment']['document_id'],
                "document_name": document["segment"]["document"]["name"],
                "data_source_type": document["segment"]["document"]["data_source_type"],
                "segment_id": document["segment"]["id"],
                "retriever_from": "api",
                "score": document.get("score", 0),
                "segment_hit_count": document.get("segment", {}).get("hit_count", 0),
                "segment_word_count": document.get("segment", {}).get("word_count", 0),
                "segment_position": document.get("segment", {}).get("position", 0),
                "segment_index_node_hash": document.get("segment", {}).get("index_node_hash", ""),
                "doc_metadata": document.get("segment", {}).get("document", {}).get("doc_metadata", None),
                "position": document["segment"].get("position", 0),
            },
            "title": document["segment"]["document"]["name"],
            "content": document["segment"]["content"],
        }


if __name__ == "__main__":
    # Read credentials from the environment rather than hard-coding them in source.
    dataset_key = os.environ.get("DIFY_DATASET_KEY")
    if not dataset_key:
        sys.exit("Please set the DIFY_DATASET_KEY environment variable")
    base_url = os.environ.get("DIFY_BASE_URL", "http://10.1.16.39/v1")
    dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=dataset_key, dify_base_url=base_url)

    async def test_async_api():
        """Smoke-test retrieve_api_async against a single dataset."""
        query = "电力建设计价通软件如何批量修改设备价格?"
        datasets = await dify_query_retrieval.retrieve_api_async(
            query,
            [query],
            ["电力建设计价通(2018)软件知识(new)"],
            {"original": [query]},  # query_expand_dict is a required argument
        )
        print("异步API测试结果:")
        print(json.dumps(datasets, ensure_ascii=False, indent=2))

    # Run the async smoke test.
    asyncio.run(test_async_api())