diff --git a/rag2_0/dify/DifyQueryRetrieval.py b/rag2_0/dify/DifyQueryRetrieval.py index d0f912b..20bf6c7 100644 --- a/rag2_0/dify/DifyQueryRetrieval.py +++ b/rag2_0/dify/DifyQueryRetrieval.py @@ -13,17 +13,17 @@ from rag2_0.dify.dify_client.client import DifyClient, KnowledgeBaseClient from rag2_0.tool.ModelTool import XinferenceReRankerModel class DifyQueryRetrieval: - software_to_dataset_map = {"配网工程计价通D3":["下载安装注册","配网造价知识","配网造价软件知识"], - "新型储能电站建设计价通C1":["下载安装注册","储能C1计价通软件知识","新能源造价知识"], - "西藏电力工程计价通Z1":["下载安装注册","西藏造价知识","西藏造价软件知识"], - "技改检修工程计价通T1":["下载安装注册","技改检修工程计价通T1软件知识","技改造价知识"], - "技改检修清单计价通T1":["下载安装注册","技改检修清单计价通T1软件知识","技改造价知识"], - "电力建设计价通":["下载安装注册","主网造价知识","电力建设计价通(2018)软件知识"], - "其他":["下载安装注册","技改检修清单计价通T1软件知识", - "主网造价知识","西藏造价知识","技改检修工程计价通T1软件知识", - "电力建设计价通(2018)软件知识","储能C1计价通软件知识", - "西藏造价软件知识","新能源造价知识","配网造价知识","技改造价知识", - "配网造价软件知识"]} + software_to_dataset_map = {"配网工程计价通D3":["下载安装注册(new)","配网造价知识(new)","配网造价软件知识(new)"], + "新型储能电站建设计价通C1":["下载安装注册(new)","储能C1计价通软件知识(new)","新能源造价知识(new)"], + "西藏电力工程计价通Z1":["下载安装注册(new)","西藏造价知识(new)","西藏造价软件知识(new)"], + "技改检修工程计价通T1":["下载安装注册(new)","技改检修工程计价通T1软件知识(new)","技改造价知识(new)"], + "技改检修清单计价通T1":["下载安装注册(new)","技改检修清单计价通T1软件知识(new)","技改造价知识(new)"], + "电力建设计价通":["下载安装注册(new)","主网造价知识(new)","电力建设计价通(2018)软件知识(new)"], + "其他":["下载安装注册(new)","技改检修清单计价通T1软件知识(new)", + "主网造价知识(new)","西藏造价知识(new)","技改检修工程计价通T1软件知识(new)", + "电力建设计价通(2018)软件知识(new)","储能C1计价通软件知识(new)", + "西藏造价软件知识(new)","新能源造价知识(new)","配网造价知识(new)","技改造价知识(new)", + "配网造价软件知识(new)"]} def __init__(self, api_key: str, base_url: str): self._api_key = api_key @@ -57,66 +57,25 @@ class DifyQueryRetrieval: if len(datasets) == 0: return None - all_documents=[] - # 使用线程池替代无限制创建线程 - # 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制 - time_start = time.time() - max_workers = min(os.cpu_count() * 2, len(query_list) * len(datasets)) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for query in query_list: - for dataset in datasets: - if dataset not in self._datasets_list: - raise ValueError(f"dataset {dataset} not in datasets_list") - - futures.append(executor.submit(self.retrieve_by_dataset, query, dataset)) - - # 等待所有任务完成 - for future in as_completed(futures): - # 处理可能的异常 - try: - retrieved_documents = future.result() - all_documents.extend(retrieved_documents) - except Exception as e: - logging.error(f"检索过程中发生错误: {str(e)}", exc_info=True) - time_end = time.time() - - logging.info(f"检索耗时: {time_end - time_start:.2f}秒") - # 根据segment_id对文档进行去重 - unique_documents = {} - for document in all_documents: - segment_id = document['segment']['id'] - if segment_id not in unique_documents: - unique_documents[segment_id] = document - - # 将去重后的文档转换为列表 - deduplicated_documents = list(unique_documents.values()) - - # 对所有检索出来的文档进行重排序 - time_start = time.time() - processed_documents = self.data_post_processor(original_query, deduplicated_documents) - time_end = time.time() - logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒") - - return processed_documents + return self.retrieve_api(original_query, query_list, datasets) def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]: all_documents=[] # 使用线程池替代无限制创建线程 - # 设置合理的最大线程数,这里使用min(32, len(query_list) * len(data_set_list))来限制 + # 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制 time_start = time.time() max_workers = min(os.cpu_count() * 2, len(query_list) * len(data_set_list)) with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] + futures = {} for query in query_list: for dataset in data_set_list: if dataset not in self._datasets_list: raise ValueError(f"dataset {dataset} not in datasets_list") - futures.append(executor.submit(self.retrieve_by_dataset, query, dataset)) + futures[executor.submit(self.retrieve_by_dataset, query, dataset)] = query # 等待所有任务完成 - for future in as_completed(futures): + for future in as_completed(futures.keys()): # 处理可能的异常 try: retrieved_documents = future.result() @@ -194,7 +153,7 @@ class DifyQueryRetrieval: if classification.vertical_classification == "安装下载注册": if classification.sub_classification in ["后缀名咨询", "软件锁类"]: - return ["下载安装注册"] + return ["下载安装注册(new)"] elif classification.sub_classification == "安装下载类": return [] elif classification.sub_classification == "问题排查": @@ -206,5 +165,7 @@ class DifyQueryRetrieval: if __name__ == "__main__": dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1") # datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3") - datasets = dify_query_retrieval.retrieve_api("配网工程计价通D3软件如何新建工程?", ["配网工程计价通D3软件如何新建工程?"], ["流式输出缺失"]) - print(datasets) \ No newline at end of file + datasets = dify_query_retrieval.retrieve_api("电力建设计价通软件如何批量修改设备价格?", + ["电力建设计价通软件如何批量修改设备价格?"], + ["电力建设计价通(2018)软件知识(new)"]) + print(json.dumps(datasets, ensure_ascii=False, indent=2)) \ No newline at end of file