更新DifyQueryRetrieval类,启用父子级检索
This commit is contained in:
@@ -13,17 +13,17 @@ from rag2_0.dify.dify_client.client import DifyClient, KnowledgeBaseClient
|
||||
from rag2_0.tool.ModelTool import XinferenceReRankerModel
|
||||
class DifyQueryRetrieval:
|
||||
|
||||
software_to_dataset_map = {"配网工程计价通D3":["下载安装注册","配网造价知识","配网造价软件知识"],
|
||||
"新型储能电站建设计价通C1":["下载安装注册","储能C1计价通软件知识","新能源造价知识"],
|
||||
"西藏电力工程计价通Z1":["下载安装注册","西藏造价知识","西藏造价软件知识"],
|
||||
"技改检修工程计价通T1":["下载安装注册","技改检修工程计价通T1软件知识","技改造价知识"],
|
||||
"技改检修清单计价通T1":["下载安装注册","技改检修清单计价通T1软件知识","技改造价知识"],
|
||||
"电力建设计价通":["下载安装注册","主网造价知识","电力建设计价通(2018)软件知识"],
|
||||
"其他":["下载安装注册","技改检修清单计价通T1软件知识",
|
||||
"主网造价知识","西藏造价知识","技改检修工程计价通T1软件知识",
|
||||
"电力建设计价通(2018)软件知识","储能C1计价通软件知识",
|
||||
"西藏造价软件知识","新能源造价知识","配网造价知识","技改造价知识",
|
||||
"配网造价软件知识"]}
|
||||
software_to_dataset_map = {"配网工程计价通D3":["下载安装注册(new)","配网造价知识(new)","配网造价软件知识(new)"],
|
||||
"新型储能电站建设计价通C1":["下载安装注册(new)","储能C1计价通软件知识(new)","新能源造价知识(new)"],
|
||||
"西藏电力工程计价通Z1":["下载安装注册(new)","西藏造价知识(new)","西藏造价软件知识(new)"],
|
||||
"技改检修工程计价通T1":["下载安装注册(new)","技改检修工程计价通T1软件知识(new)","技改造价知识(new)"],
|
||||
"技改检修清单计价通T1":["下载安装注册(new)","技改检修清单计价通T1软件知识(new)","技改造价知识(new)"],
|
||||
"电力建设计价通":["下载安装注册(new)","主网造价知识(new)","电力建设计价通(2018)软件知识(new)"],
|
||||
"其他":["下载安装注册(new)","技改检修清单计价通T1软件知识(new)",
|
||||
"主网造价知识(new)","西藏造价知识(new)","技改检修工程计价通T1软件知识(new)",
|
||||
"电力建设计价通(2018)软件知识(new)","储能C1计价通软件知识(new)",
|
||||
"西藏造价软件知识(new)","新能源造价知识(new)","配网造价知识(new)","技改造价知识(new)",
|
||||
"配网造价软件知识(new)"]}
|
||||
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
self._api_key = api_key
|
||||
@@ -57,66 +57,25 @@ class DifyQueryRetrieval:
|
||||
if len(datasets) == 0:
|
||||
return None
|
||||
|
||||
all_documents=[]
|
||||
# 使用线程池替代无限制创建线程
|
||||
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制
|
||||
time_start = time.time()
|
||||
max_workers = min(os.cpu_count() * 2, len(query_list) * len(datasets))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = []
|
||||
for query in query_list:
|
||||
for dataset in datasets:
|
||||
if dataset not in self._datasets_list:
|
||||
raise ValueError(f"dataset {dataset} not in datasets_list")
|
||||
|
||||
futures.append(executor.submit(self.retrieve_by_dataset, query, dataset))
|
||||
|
||||
# 等待所有任务完成
|
||||
for future in as_completed(futures):
|
||||
# 处理可能的异常
|
||||
try:
|
||||
retrieved_documents = future.result()
|
||||
all_documents.extend(retrieved_documents)
|
||||
except Exception as e:
|
||||
logging.error(f"检索过程中发生错误: {str(e)}", exc_info=True)
|
||||
time_end = time.time()
|
||||
|
||||
logging.info(f"检索耗时: {time_end - time_start:.2f}秒")
|
||||
# 根据segment_id对文档进行去重
|
||||
unique_documents = {}
|
||||
for document in all_documents:
|
||||
segment_id = document['segment']['id']
|
||||
if segment_id not in unique_documents:
|
||||
unique_documents[segment_id] = document
|
||||
|
||||
# 将去重后的文档转换为列表
|
||||
deduplicated_documents = list(unique_documents.values())
|
||||
|
||||
# 对所有检索出来的文档进行重排序
|
||||
time_start = time.time()
|
||||
processed_documents = self.data_post_processor(original_query, deduplicated_documents)
|
||||
time_end = time.time()
|
||||
logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒")
|
||||
|
||||
return processed_documents
|
||||
return self.retrieve_api(original_query, query_list, datasets)
|
||||
|
||||
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str])->List[Dict[str, Any]]:
|
||||
all_documents=[]
|
||||
# 使用线程池替代无限制创建线程
|
||||
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(data_set_list))来限制
|
||||
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制
|
||||
time_start = time.time()
|
||||
max_workers = min(os.cpu_count() * 2, len(query_list) * len(data_set_list))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = []
|
||||
futures = {}
|
||||
for query in query_list:
|
||||
for dataset in data_set_list:
|
||||
if dataset not in self._datasets_list:
|
||||
raise ValueError(f"dataset {dataset} not in datasets_list")
|
||||
|
||||
futures.append(executor.submit(self.retrieve_by_dataset, query, dataset))
|
||||
futures[executor.submit(self.retrieve_by_dataset, query, dataset)] = query
|
||||
|
||||
# 等待所有任务完成
|
||||
for future in as_completed(futures):
|
||||
for future in as_completed(futures.keys()):
|
||||
# 处理可能的异常
|
||||
try:
|
||||
retrieved_documents = future.result()
|
||||
@@ -194,7 +153,7 @@ class DifyQueryRetrieval:
|
||||
|
||||
if classification.vertical_classification == "安装下载注册":
|
||||
if classification.sub_classification in ["后缀名咨询", "软件锁类"]:
|
||||
return ["下载安装注册"]
|
||||
return ["下载安装注册(new)"]
|
||||
elif classification.sub_classification == "安装下载类":
|
||||
return []
|
||||
elif classification.sub_classification == "问题排查":
|
||||
@@ -206,5 +165,7 @@ class DifyQueryRetrieval:
|
||||
if __name__ == "__main__":
|
||||
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
|
||||
# datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3")
|
||||
datasets = dify_query_retrieval.retrieve_api("配网工程计价通D3软件如何新建工程?", ["配网工程计价通D3软件如何新建工程?"], ["流式输出缺失"])
|
||||
print(datasets)
|
||||
datasets = dify_query_retrieval.retrieve_api("电力建设计价通软件如何批量修改设备价格?",
|
||||
["电力建设计价通软件如何批量修改设备价格?"],
|
||||
["电力建设计价通(2018)软件知识(new)"])
|
||||
print(json.dumps(datasets, ensure_ascii=False, indent=2))
|
||||
Reference in New Issue
Block a user