"""Retrieve knowledge-base segments from Dify datasets and rerank them."""
import sys
import os
import json
from threading import Thread
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Optional
import logging
import time

# The project root must be importable before the rag2_0 imports below run.
sys.path.append(os.getcwd())

from rag2_0.intent_recognition.DataModels import Classification
from rag2_0.dify.dify_client.client import DifyClient, KnowledgeBaseClient
from rag2_0.tool.ModelTool import XinferenceReRankerModel


class DifyQueryRetrieval:
    """Fan a set of queries out over Dify datasets, then dedupe and rerank.

    The dataset name -> dataset id mapping is fetched once, in ``__init__``,
    and cached on the instance as ``_datasets_list``.
    """

    # Software product name -> names of the datasets searched for it.
    # "其他" is the fallback entry used when no product name matches.
    software_to_dataset_map: Dict[str, List[str]] = {
        "配网工程计价通D3": [
            "下载安装注册(new)",
            "配网造价知识(new)",
            "配网造价软件知识(new)",
        ],
        "新型储能电站建设计价通C1": [
            "下载安装注册(new)",
            "储能C1计价通软件知识(new)",
            "新能源造价知识(new)",
        ],
        "西藏电力工程计价通Z1": [
            "下载安装注册(new)",
            "西藏造价知识(new)",
            "西藏造价软件知识(new)",
        ],
        "技改检修工程计价通T1": [
            "下载安装注册(new)",
            "技改检修工程计价通T1软件知识(new)",
            "技改造价知识(new)",
        ],
        "技改检修清单计价通T1": [
            "下载安装注册(new)",
            "技改检修清单计价通T1软件知识(new)",
            "技改造价知识(new)",
        ],
        "电力建设计价通": [
            "下载安装注册(new)",
            "主网造价知识(new)",
            "电力建设计价通(2018)软件知识(new)",
        ],
        "其他": [
            "下载安装注册(new)",
            "技改检修清单计价通T1软件知识(new)",
            "主网造价知识(new)",
            "西藏造价知识(new)",
            "技改检修工程计价通T1软件知识(new)",
            "电力建设计价通(2018)软件知识(new)",
            "储能C1计价通软件知识(new)",
            "西藏造价软件知识(new)",
            "新能源造价知识(new)",
            "配网造价知识(new)",
            "技改造价知识(new)",
            "配网造价软件知识(new)",
        ],
    }

    def __init__(self, api_key: str, base_url: str):
        """Store the credentials and load the dataset name -> id mapping.

        Args:
            api_key: API key passed to ``KnowledgeBaseClient``.
            base_url: Base URL passed to ``KnowledgeBaseClient``.
        """
        self._api_key = api_key
        self._base_url = base_url
        self._datasets_list = self.get_datasets_list()

    def get_datasets_list(self) -> Dict[str, str]:
        """Return a mapping of dataset name to dataset id.

        Only a single ``list_datasets(page_size=50)`` call is made.
        NOTE(review): if the account holds more datasets than one page
        returns, the remainder would be missing here - confirm against the
        client's pagination behaviour.
        """
        client = KnowledgeBaseClient(api_key=self._api_key, base_url=self._base_url)
        response = client.list_datasets(page_size=50)
        payload = response.json()
        return {dataset["name"]: dataset["id"] for dataset in payload["data"]}

    def retrieve_by_dataset(self, query: str, dataset_name: str) -> List[Dict[str, Any]]:
        """Retrieve the records matching ``query`` from a single dataset.

        Each returned record is tagged with ``dataset_id`` and
        ``dataset_name``. This is deliberately best-effort: any exception
        (including an unknown ``dataset_name``) is logged and an empty list is
        returned, so one failing dataset does not abort a multi-dataset
        retrieval.
        """
        try:
            dataset_id = self._datasets_list[dataset_name]
            knowledge_base_client = KnowledgeBaseClient(
                api_key=self._api_key,
                base_url=self._base_url,
                dataset_id=dataset_id,
            )
            response = knowledge_base_client.retrieve(query)
            retrieved_documents = response.json().get("records", [])
            # Attach the originating dataset to every record.
            for retrieved_document in retrieved_documents:
                retrieved_document["dataset_id"] = dataset_id
                retrieved_document["dataset_name"] = dataset_name
            return retrieved_documents
        except Exception as e:
            logging.error(f"检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
            return []

    def retrieve(
        self,
        original_query: str,
        query_list: List[str],
        classification: Classification,
        software_name: str,
    ) -> Optional[List[Dict[str, Any]]]:
        """Pick datasets from the classification, then retrieve and rerank.

        Returns:
            The reranked documents, or ``None`` when the classification maps
            to no dataset at all.
        """
        datasets = self.get_datasets_by_classification(classification, software_name)
        if not datasets:
            return None
        return self.retrieve_api(original_query, query_list, datasets)

    def retrieve_api(
        self,
        original_query: str,
        query_list: List[str],
        data_set_list: List[str],
    ) -> List[Dict[str, Any]]:
        """Run every query against every dataset, then dedupe and rerank.

        Args:
            original_query: Query the deduplicated documents are reranked
                against.
            query_list: Queries sent to each dataset.
            data_set_list: Names of the datasets to search.

        Returns:
            The reranked documents in Dify document format; an empty list
            when there is nothing to query.

        Raises:
            ValueError: If a name in ``data_set_list`` is not a known dataset.
        """
        # Validate up front so no request is sent when the input is invalid.
        for dataset in data_set_list:
            if dataset not in self._datasets_list:
                raise ValueError(f"dataset {dataset} not in datasets_list")

        tasks = [(query, dataset) for query in query_list for dataset in data_set_list]
        if not tasks:
            # ThreadPoolExecutor rejects max_workers=0, so stop here.
            return []

        all_documents: List[Dict[str, Any]] = []
        # Use a bounded thread pool rather than one thread per request:
        # at most 2 * CPU count workers, never more than the number of tasks.
        # os.cpu_count() may return None, hence the fallback to 1.
        time_start = time.time()
        max_workers = min((os.cpu_count() or 1) * 2, len(tasks))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(self.retrieve_by_dataset, query, dataset)
                for query, dataset in tasks
            ]
            # Collect results as they finish; a failed task is logged and skipped.
            for future in as_completed(futures):
                try:
                    all_documents.extend(future.result())
                except Exception as e:
                    logging.error(f"检索过程中发生错误: {str(e)}", exc_info=True)
        time_end = time.time()
        logging.info(f"检索耗时: {time_end - time_start:.2f}秒")

        deduplicated_documents = self._deduplicate_by_segment_id(all_documents)

        # Rerank everything that was retrieved against the original query.
        time_start = time.time()
        processed_documents = self.data_post_processor(original_query, deduplicated_documents)
        time_end = time.time()
        logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒")
        return processed_documents

    @staticmethod
    def _deduplicate_by_segment_id(documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Keep the first document seen for each ``document['segment']['id']``."""
        unique_documents: Dict[Any, Dict[str, Any]] = {}
        for document in documents:
            unique_documents.setdefault(document['segment']['id'], document)
        return list(unique_documents.values())

    @staticmethod
    def _to_dify_document_format(document: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one retrieved record into the Dify document structure."""
        segment = document["segment"]
        # Same lookup as the original ``document.get("segment", {})`` reads.
        optional_segment = document.get("segment", {})
        return {
            "metadata": {
                "_source": "knowledge",
                "dataset_id": document["dataset_id"],
                "dataset_name": document["dataset_name"],
                "document_id": segment['document_id'],
                "document_name": segment["document"]["name"],
                "data_source_type": segment["document"]["data_source_type"],
                "segment_id": segment["id"],
                "retriever_from": "api",
                "score": document.get("score", 0),
                "segment_hit_count": optional_segment.get("hit_count", 0),
                "segment_word_count": optional_segment.get("word_count", 0),
                "segment_position": optional_segment.get("position", 0),
                "segment_index_node_hash": optional_segment.get("index_node_hash", ""),
                "doc_metadata": optional_segment.get("document", {}).get("doc_metadata", None),
                "position": segment.get("position", 0),
            },
            "title": segment["document"]["name"],
            "content": segment["content"],
        }

    def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Rerank the documents against ``query`` and format the results.

        The reranker is called with ``top_k=5``. Each result it returns is
        read for an ``index`` into ``all_documents`` and a ``score``; the
        score is written back onto the source record before formatting.

        Returns:
            The reranked documents in Dify document format; an empty list
            when ``all_documents`` is empty.
        """
        if not all_documents:
            # Nothing to rerank; do not call the reranker at all.
            return []

        reranker_model = XinferenceReRankerModel()
        contents = [document['segment']['content'] for document in all_documents]
        reranked_documents = reranker_model.rerank(query, contents, top_k=5)

        new_all_documents = []
        for reranked_document in reranked_documents:
            cur_doc_info = all_documents[reranked_document["index"]]
            cur_doc_info["score"] = reranked_document["score"]
            new_all_documents.append(self._to_dify_document_format(cur_doc_info))
        return new_all_documents

    def get_datasets_by_classification(self, classification: Classification, software_name: str) -> List[str]:
        """Choose the dataset names to search for a classified question.

        For "软件问题" and "业务问题" the first key of
        ``software_to_dataset_map`` that occurs inside ``software_name``
        selects the datasets, falling back to the "其他" entry. For
        "安装下载注册" the sub-classification decides. Everything else gets
        the "其他" entry.

        Returns:
            A new list of dataset names (possibly empty). Copies are returned
            so callers cannot mutate the class-level routing table.
        """
        fallback = self.software_to_dataset_map["其他"]
        vertical = classification.vertical_classification

        if vertical == "软件问题" or vertical == "业务问题":
            # A missing software name simply matches nothing.
            name = software_name or ""
            matched = next(
                (key for key in self.software_to_dataset_map if key in name),
                "其他",
            )
            return list(self.software_to_dataset_map[matched])

        if vertical == "安装下载注册":
            sub = classification.sub_classification
            if sub in ["后缀名咨询", "软件锁类"]:
                return ["下载安装注册(new)"]
            if sub == "安装下载类":
                return []
            if sub == "问题排查":
                return list(fallback)

        return list(fallback)


if __name__ == "__main__":
    # NOTE(review): the default API key and base URL below are hard-coded in
    # source control. They are kept so the script behaves as before, but they
    # should be supplied through the environment and the key rotated.
    dify_query_retrieval = DifyQueryRetrieval(
        api_key=os.environ.get("DIFY_DATASET_API_KEY", "dataset-skLjmPVonjHo119OWNf3kAmY"),
        base_url=os.environ.get("DIFY_BASE_URL", "http://10.1.16.39/v1"),
    )
    # To exercise classification-based dataset routing instead, call
    # retrieve(original_query, query_list, classification, software_name).
    datasets = dify_query_retrieval.retrieve_api(
        "电力建设计价通软件如何批量修改设备价格?",
        ["电力建设计价通软件如何批量修改设备价格?"],
        ["电力建设计价通(2018)软件知识(new)"],
    )
    print(json.dumps(datasets, ensure_ascii=False, indent=2))