210 lines
11 KiB
Python
210 lines
11 KiB
Python
import sys
|
|
import os
|
|
import json
|
|
from threading import Thread
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from typing import List, Dict, Any, Optional
|
|
import logging
|
|
import time
|
|
sys.path.append(os.getcwd())
|
|
|
|
from rag2_0.intent_recognition.DataModels import Classification
|
|
from rag2_0.dify.dify_client.client import DifyClient, KnowledgeBaseClient
|
|
from rag2_0.tool.ModelTool import XinferenceReRankerModel
|
|
class DifyQueryRetrieval:
    """Query Dify knowledge-base datasets in parallel and rerank the hits.

    On construction the instance fetches the dataset list once and keeps a
    name -> id mapping. ``retrieve`` picks the datasets from a classification
    result, ``retrieve_api`` takes an explicit list of dataset names; both
    fan the queries out over a thread pool, de-duplicate the returned
    segments by segment id and rerank them against the original query.
    """

    # Dataset names searched for each software product. The "其他" entry is
    # the fallback used when no product name matches; it lists every dataset.
    software_to_dataset_map = {
        "配网工程计价通D3": ["下载安装注册", "配网造价知识", "配网造价软件知识"],
        "新型储能电站建设计价通C1": ["下载安装注册", "储能C1计价通软件知识", "新能源造价知识"],
        "西藏电力工程计价通Z1": ["下载安装注册", "西藏造价知识", "西藏造价软件知识"],
        "技改检修工程计价通T1": ["下载安装注册", "技改检修工程计价通T1软件知识", "技改造价知识"],
        "技改检修清单计价通T1": ["下载安装注册", "技改检修清单计价通T1软件知识", "技改造价知识"],
        "电力建设计价通": ["下载安装注册", "主网造价知识", "电力建设计价通(2018)软件知识"],
        "其他": [
            "下载安装注册",
            "技改检修清单计价通T1软件知识",
            "主网造价知识",
            "西藏造价知识",
            "技改检修工程计价通T1软件知识",
            "电力建设计价通(2018)软件知识",
            "储能C1计价通软件知识",
            "西藏造价软件知识",
            "新能源造价知识",
            "配网造价知识",
            "技改造价知识",
            "配网造价软件知识",
        ],
    }

    def __init__(self, api_key: str, base_url: str):
        """Store the credentials and load the dataset name -> id mapping.

        Args:
            api_key: API key passed to ``KnowledgeBaseClient``.
            base_url: Base URL passed to ``KnowledgeBaseClient``.
        """
        self._api_key = api_key
        self._base_url = base_url
        # Fetched once here; every later lookup uses this cached mapping.
        self._datasets_list = self.get_datasets_list()

    def get_datasets_list(self) -> Dict[str, str]:
        """Return a mapping of dataset name to dataset id.

        Only a single page of up to 50 datasets is requested.
        NOTE(review): datasets beyond the first page would be missing from
        the mapping - confirm the deployment never exceeds 50 datasets.
        """
        client = KnowledgeBaseClient(api_key=self._api_key, base_url=self._base_url)
        datasets = client.list_datasets(page_size=50)
        datasets_json = datasets.json()
        return {dataset["name"]: dataset["id"] for dataset in datasets_json["data"]}

    def retrieve_by_dataset(self, query: str, dataset_name: str) -> List[Dict[str, Any]]:
        """Run one query against one dataset.

        Args:
            query: Query text sent to the dataset.
            dataset_name: Name of the dataset; must be a key of the cached
                dataset mapping.

        Returns:
            The ``records`` list of the response, each record annotated with
            ``dataset_id`` and ``dataset_name``. Any failure is logged and an
            empty list is returned, so one failing dataset does not abort
            the whole fan-out.
        """
        try:
            dataset_id = self._datasets_list[dataset_name]
            knowledge_base_client = KnowledgeBaseClient(
                api_key=self._api_key,
                base_url=self._base_url,
                dataset_id=dataset_id,
            )
            response = knowledge_base_client.retrieve(query)
            retrieved_documents = response.json().get("records", [])

            # Tag every record with the dataset it came from.
            for retrieved_document in retrieved_documents:
                retrieved_document["dataset_id"] = dataset_id
                retrieved_document["dataset_name"] = dataset_name

            return retrieved_documents
        except Exception as e:
            # Deliberate best-effort boundary: log and carry on.
            logging.error(f"检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
            return []

    def retrieve(
        self,
        original_query: str,
        query_list: List[str],
        classification: "Classification",
        software_name: str,
    ) -> Optional[List[Dict[str, Any]]]:
        """Retrieve documents from the datasets chosen by classification.

        Args:
            original_query: Query the final reranking is scored against.
            query_list: Queries sent to every selected dataset.
            classification: Object exposing ``vertical_classification`` and
                ``sub_classification``.
            software_name: Software name used to select the datasets.

        Returns:
            ``None`` when the classification selects no dataset, otherwise
            the reranked documents in Dify document format.

        Raises:
            ValueError: If a selected dataset is not in the dataset mapping.
        """
        datasets = self.get_datasets_by_classification(classification, software_name)
        if not datasets:
            return None
        return self._retrieve_and_rerank(original_query, query_list, datasets)

    def retrieve_api(
        self,
        original_query: str,
        query_list: List[str],
        data_set_list: List[str],
    ) -> List[Dict[str, Any]]:
        """Retrieve documents from an explicit list of dataset names.

        Args:
            original_query: Query the final reranking is scored against.
            query_list: Queries sent to every dataset in ``data_set_list``.
            data_set_list: Names of the datasets to search.

        Returns:
            The reranked documents in Dify document format; an empty list
            when there is nothing to query or nothing was found.

        Raises:
            ValueError: If a dataset name is not in the dataset mapping.
        """
        return self._retrieve_and_rerank(original_query, query_list, data_set_list)

    def _retrieve_and_rerank(
        self,
        original_query: str,
        query_list: List[str],
        dataset_names: List[str],
    ) -> List[Dict[str, Any]]:
        """Fan out every (query, dataset) pair, de-duplicate, then rerank."""
        # Validate before submitting anything so that no request is sent
        # for a call that is going to fail anyway.
        for dataset_name in dataset_names:
            if dataset_name not in self._datasets_list:
                raise ValueError(f"dataset {dataset_name} not in datasets_list")

        tasks = [
            (query, dataset_name)
            for query in query_list
            for dataset_name in dataset_names
        ]

        all_documents: List[Dict[str, Any]] = []
        time_start = time.time()
        if tasks:
            # os.cpu_count() may return None, and ThreadPoolExecutor rejects
            # max_workers <= 0, so the bound is clamped to at least one.
            max_workers = max(1, min((os.cpu_count() or 1) * 2, len(tasks)))
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [
                    executor.submit(self.retrieve_by_dataset, query, dataset_name)
                    for query, dataset_name in tasks
                ]
                # Collect results as each task finishes.
                for future in as_completed(futures):
                    try:
                        all_documents.extend(future.result())
                    except Exception as e:
                        logging.error(f"检索过程中发生错误: {str(e)}", exc_info=True)
        time_end = time.time()
        logging.info(f"检索耗时: {time_end - time_start:.2f}秒")

        deduplicated_documents = self._deduplicate_by_segment_id(all_documents)

        # Rerank everything that was retrieved against the original query.
        time_start = time.time()
        processed_documents = self.data_post_processor(original_query, deduplicated_documents)
        time_end = time.time()
        logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒")

        return processed_documents

    @staticmethod
    def _deduplicate_by_segment_id(documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Keep the first document seen for each ``segment.id``."""
        unique_documents: Dict[Any, Dict[str, Any]] = {}
        for document in documents:
            unique_documents.setdefault(document["segment"]["id"], document)
        return list(unique_documents.values())

    @staticmethod
    def _to_dify_document_format(document: dict) -> dict:
        """Convert one retrieved record into the Dify document layout.

        ``dataset_id``, ``dataset_name``, ``segment.id``, ``segment.content``,
        ``segment.document_id``, ``segment.document.name`` and
        ``segment.document.data_source_type`` are required (``KeyError`` if
        absent); the remaining fields fall back to defaults.
        """
        segment = document["segment"]
        source_document = segment["document"]
        return {
            "metadata": {
                "_source": "knowledge",
                "dataset_id": document["dataset_id"],
                "dataset_name": document["dataset_name"],
                "document_id": segment["document_id"],
                "document_name": source_document["name"],
                "data_source_type": source_document["data_source_type"],
                "segment_id": segment["id"],
                "retriever_from": "api",
                "score": document.get("score", 0),
                "segment_hit_count": segment.get("hit_count", 0),
                "segment_word_count": segment.get("word_count", 0),
                "segment_position": segment.get("position", 0),
                "segment_index_node_hash": segment.get("index_node_hash", ""),
                "doc_metadata": source_document.get("doc_metadata", None),
                "position": segment.get("position", 0),
            },
            "title": source_document["name"],
            "content": segment["content"],
        }

    def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Rerank the retrieved documents and convert them to Dify format.

        Args:
            query: Query the documents are reranked against.
            all_documents: Retrieved records; each needs ``segment.content``.

        Returns:
            The records selected by the reranker (``top_k=5`` is requested),
            in the order the reranker returned them, with ``score`` replaced
            by the reranker score. Empty input yields an empty list without
            calling the reranker.
        """
        if not all_documents:
            return []

        reranker_model = XinferenceReRankerModel()
        contents = [document["segment"]["content"] for document in all_documents]
        reranked_documents = reranker_model.rerank(query, contents, top_k=5)

        new_all_documents = []
        for reranked_document in reranked_documents:
            # "index" points back into the list handed to the reranker.
            cur_doc_info = all_documents[reranked_document["index"]]
            cur_doc_info["score"] = reranked_document["score"]
            new_all_documents.append(self._to_dify_document_format(cur_doc_info))
        return new_all_documents

    def get_datasets_by_classification(self, classification: "Classification", software_name: str) -> List[str]:
        """Choose the dataset names to search for a classified question.

        Args:
            classification: Object exposing ``vertical_classification`` and
                ``sub_classification``.
            software_name: Free-text software name; the first key of
                ``software_to_dataset_map`` contained in it wins. ``None``
                is treated as an empty string.

        Returns:
            A new list of dataset names (safe for the caller to mutate). It
            is empty only for the "安装下载类" sub-classification.
        """
        fallback = self.software_to_dataset_map["其他"]
        vertical = classification.vertical_classification

        if vertical in ("软件问题", "业务问题"):
            software_name = software_name or ""
            for known_name, dataset_names in self.software_to_dataset_map.items():
                if known_name in software_name:
                    return list(dataset_names)
            return list(fallback)

        if vertical == "安装下载注册":
            sub_classification = classification.sub_classification
            if sub_classification in ("后缀名咨询", "软件锁类"):
                return ["下载安装注册"]
            if sub_classification == "安装下载类":
                return []
            # "问题排查" and any other sub-classification use the fallback.

        return list(fallback)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: query one dataset through retrieve_api and print
    # the reranked result.
    #
    # The key and endpoint can be overridden through the environment so the
    # script runs against other deployments without editing the source.
    # NOTE(review): the fallback key below is a credential committed to
    # source control - rotate it and drop the literal once the environment
    # variable is set everywhere this script is used.
    api_key = os.environ.get("DIFY_DATASET_API_KEY", "dataset-skLjmPVonjHo119OWNf3kAmY")
    base_url = os.environ.get("DIFY_BASE_URL", "http://10.1.16.39/v1")

    dify_query_retrieval = DifyQueryRetrieval(api_key=api_key, base_url=base_url)
    # To exercise classification-based routing instead, call
    # dify_query_retrieval.retrieve(...) with a Classification instance.
    datasets = dify_query_retrieval.retrieve_api("配网工程计价通D3软件如何新建工程?", ["配网工程计价通D3软件如何新建工程?"], ["流式输出缺失"])
    print(datasets)