更新.gitignore以忽略临时文件,修改api_key文件,重构合并名词的逻辑,删除不再使用的脚本,优化对话到工单的处理流程,添加会话结果保存为JSON的功能,调整API调用参数,修复部分代码中的错误。
This commit is contained in:
@@ -1,32 +1,17 @@
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
from threading import Thread
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Dict, Any
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
import httpx
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from rag2_0.intent_recognition.DataModels import Classification
|
||||
from rag2_0.dify.dify_client.client import DifyClient, KnowledgeBaseClient
|
||||
from rag2_0.tool.ModelTool import XinferenceReRankerModel
|
||||
class DifyQueryRetrieval:
|
||||
|
||||
software_to_dataset_map = {"配网工程计价通D3":["下载安装注册(new)","配网造价知识(new)","配网造价软件知识(new)"],
|
||||
"新型储能电站建设计价通C1":["下载安装注册(new)","储能C1计价通软件知识(new)","新能源造价知识(new)"],
|
||||
"西藏电力工程计价通Z1":["下载安装注册(new)","西藏造价知识(new)","西藏造价软件知识(new)"],
|
||||
"技改检修工程计价通T1":["下载安装注册(new)","技改检修工程计价通T1软件知识(new)","技改造价知识(new)"],
|
||||
"技改检修清单计价通T1":["下载安装注册(new)","技改检修清单计价通T1软件知识(new)","技改造价知识(new)"],
|
||||
"电力建设计价通":["下载安装注册(new)","主网造价知识(new)","电力建设计价通(2018)软件知识(new)"],
|
||||
"其他":["下载安装注册(new)","技改检修清单计价通T1软件知识(new)",
|
||||
"主网造价知识(new)","西藏造价知识(new)","技改检修工程计价通T1软件知识(new)",
|
||||
"电力建设计价通(2018)软件知识(new)","储能C1计价通软件知识(new)",
|
||||
"西藏造价软件知识(new)","新能源造价知识(new)","配网造价知识(new)","技改造价知识(new)",
|
||||
"配网造价软件知识(new)"]}
|
||||
|
||||
def __init__(self, dify_dataset_key: str, dify_base_url: str):
|
||||
self._dify_dataset_key = dify_dataset_key
|
||||
self._dify_base_url = dify_base_url
|
||||
@@ -38,7 +23,7 @@ class DifyQueryRetrieval:
|
||||
datasets_json = datasets.json()
|
||||
return {dataset["name"]:dataset for dataset in datasets_json["data"]}
|
||||
|
||||
def retrieve_by_dataset(self, query: str, dataset_name: str) -> List[Dict[str, Any]]:
|
||||
def retrieve_by_dataset(self, query: str, dataset_name: str) -> Dict[str, Any]:
|
||||
try:
|
||||
dataset_id = self._datasets_list[dataset_name]["id"]
|
||||
retrieval_model = self._datasets_list[dataset_name]["retrieval_model_dict"]
|
||||
@@ -52,12 +37,21 @@ class DifyQueryRetrieval:
|
||||
retrieved_document["dataset_id"] = dataset_id
|
||||
retrieved_document["dataset_name"] = dataset_name
|
||||
|
||||
return retrieved_documents
|
||||
# 返回包含查询和文档的字典
|
||||
return {
|
||||
"query": query,
|
||||
"dataset_name": dataset_name,
|
||||
"documents": retrieved_documents
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
|
||||
return []
|
||||
return {
|
||||
"query": query,
|
||||
"dataset_name": dataset_name,
|
||||
"documents": []
|
||||
}
|
||||
|
||||
async def retrieve_by_dataset_async(self, query: str, dataset_name: str) -> List[Dict[str, Any]]:
|
||||
async def retrieve_by_dataset_async(self, query: str, dataset_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
异步版本的retrieve_by_dataset方法
|
||||
|
||||
@@ -66,7 +60,7 @@ class DifyQueryRetrieval:
|
||||
dataset_name: 数据集名称
|
||||
|
||||
Returns:
|
||||
检索到的文档列表
|
||||
包含查询、数据集名称和检索到的文档的字典
|
||||
"""
|
||||
try:
|
||||
# 使用asyncio.to_thread包装同步方法
|
||||
@@ -77,81 +71,13 @@ class DifyQueryRetrieval:
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"异步检索数据集 {dataset_name} 时出错: {str(e)}", exc_info=True)
|
||||
return []
|
||||
return {
|
||||
"query": query,
|
||||
"dataset_name": dataset_name,
|
||||
"documents": []
|
||||
}
|
||||
|
||||
def retrieve(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]:
|
||||
datasets = self.get_datasets_by_classification(classification, software_name)
|
||||
datasets=["电力建设计价通(2018)软件知识(new)", "主网造价知识(new)", "下载安装注册(new)"]
|
||||
if len(datasets) == 0:
|
||||
return None
|
||||
|
||||
return self.retrieve_api(original_query, query_list, datasets)
|
||||
|
||||
async def retrieve_async(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
异步版本的retrieve方法
|
||||
|
||||
Args:
|
||||
original_query: 原始查询
|
||||
query_list: 查询列表
|
||||
classification: 分类信息
|
||||
software_name: 软件名称
|
||||
|
||||
Returns:
|
||||
检索到的文档列表
|
||||
"""
|
||||
datasets = self.get_datasets_by_classification(classification, software_name)
|
||||
if len(datasets) == 0:
|
||||
return None
|
||||
|
||||
return await self.retrieve_api_async(original_query, query_list, datasets)
|
||||
|
||||
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
|
||||
ssss = self.retrieve_by_dataset("怎么调整报表顺序", "电力建设计价通(2018)软件知识(new)")
|
||||
all_documents=[]
|
||||
# 使用线程池替代无限制创建线程
|
||||
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制
|
||||
time_start = time.time()
|
||||
max_workers = min(os.cpu_count() * 2, len(query_list) * len(data_set_list))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {}
|
||||
for query in query_list:
|
||||
for dataset in data_set_list:
|
||||
if dataset not in list(self._datasets_list.keys()):
|
||||
raise ValueError(f"dataset {dataset} not in datasets_list")
|
||||
|
||||
futures[executor.submit(self.retrieve_by_dataset, query, dataset)] = query
|
||||
|
||||
# 等待所有任务完成
|
||||
for future in as_completed(futures.keys()):
|
||||
# 处理可能的异常
|
||||
try:
|
||||
retrieved_documents = future.result()
|
||||
all_documents.extend(retrieved_documents)
|
||||
except Exception as e:
|
||||
logging.error(f"检索过程中发生错误: {str(e)}", exc_info=True)
|
||||
time_end = time.time()
|
||||
|
||||
logging.info(f"检索耗时: {time_end - time_start:.2f}秒")
|
||||
# 根据segment_id对文档进行去重
|
||||
unique_documents = {}
|
||||
for document in all_documents:
|
||||
segment_id = document['segment']['id']
|
||||
if segment_id not in unique_documents:
|
||||
unique_documents[segment_id] = document
|
||||
|
||||
# 将去重后的文档转换为列表
|
||||
deduplicated_documents = list(unique_documents.values())
|
||||
|
||||
# 对所有检索出来的文档进行重排序
|
||||
time_start = time.time()
|
||||
processed_documents = self.data_post_processor("怎么调整报表顺序", deduplicated_documents, top_k)
|
||||
time_end = time.time()
|
||||
logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}秒")
|
||||
|
||||
return processed_documents
|
||||
|
||||
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
|
||||
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], query_expand_dict: dict, top_k: int = 5)->Dict[str, Any]:
|
||||
"""
|
||||
异步版本的retrieve_api方法,使用asyncio代替线程池
|
||||
|
||||
@@ -159,11 +85,14 @@ class DifyQueryRetrieval:
|
||||
original_query: 原始查询
|
||||
query_list: 查询列表
|
||||
data_set_list: 数据集列表
|
||||
query_expand_dict: 查询扩展字典,包含不同类型的查询
|
||||
top_k: 返回的文档数量
|
||||
|
||||
Returns:
|
||||
检索并重排序后的文档列表
|
||||
包含检索结果和查询命中统计的字典
|
||||
"""
|
||||
all_documents = []
|
||||
query_document_mapping = {} # 用于存储查询和文档的映射关系,键为查询,值为文档列表
|
||||
# 记录开始时间
|
||||
time_start = time.time()
|
||||
|
||||
@@ -187,7 +116,16 @@ class DifyQueryRetrieval:
|
||||
if isinstance(result, Exception):
|
||||
logging.error(f"异步检索过程中发生错误: {str(result)}", exc_info=True)
|
||||
else:
|
||||
all_documents.extend(result)
|
||||
# 添加查询和文档的映射关系
|
||||
query = result["query"]
|
||||
if query not in query_document_mapping:
|
||||
query_document_mapping[query] = []
|
||||
|
||||
# 将文档添加到对应查询的文档列表中
|
||||
query_document_mapping[query].extend([item['segment']['id'] for item in result["documents"]])
|
||||
|
||||
# 将文档添加到总文档列表
|
||||
all_documents.extend(result["documents"])
|
||||
|
||||
time_end = time.time()
|
||||
logging.info(f"异步检索耗时: {time_end - time_start:.2f}秒")
|
||||
@@ -203,49 +141,61 @@ class DifyQueryRetrieval:
|
||||
deduplicated_documents = list(unique_documents.values())
|
||||
|
||||
if len(deduplicated_documents) == 0:
|
||||
return []
|
||||
return {
|
||||
"documents": [],
|
||||
"query_hit_stats": {}
|
||||
}
|
||||
|
||||
# 对所有检索出来的文档进行重排序
|
||||
time_start = time.time()
|
||||
processed_documents = await self.data_post_processor_async(original_query, deduplicated_documents, top_k)
|
||||
time_end = time.time()
|
||||
logging.info(f"异步检索后重排序耗时: {time_end - time_start:.2f}秒")
|
||||
|
||||
return processed_documents
|
||||
|
||||
def data_post_processor(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
reranker_model = XinferenceReRankerModel()
|
||||
documents = [document['segment']['content'] for document in all_documents]
|
||||
reranked_documents = reranker_model.rerank(query, documents, top_k=top_k)
|
||||
new_all_documents = []
|
||||
|
||||
def to_dify_document_format(document: dict)->dict:
|
||||
return {
|
||||
"metadata": {
|
||||
"_source": "knowledge",
|
||||
"dataset_id": document["dataset_id"],
|
||||
"dataset_name": document["dataset_name"],
|
||||
"document_id": document['segment']['document_id'],
|
||||
"document_name": document["segment"]["document"]["name"],
|
||||
"data_source_type": document["segment"]["document"]["data_source_type"],
|
||||
"segment_id": document["segment"]["id"],
|
||||
"retriever_from": "api",
|
||||
"score": document.get("score", 0),
|
||||
"segment_hit_count": document.get("segment", {}).get("hit_count", 0),
|
||||
"segment_word_count": document.get("segment", {}).get("word_count", 0),
|
||||
"segment_position": document.get("segment", {}).get("position", 0),
|
||||
"segment_index_node_hash": document.get("segment", {}).get("index_node_hash", ""),
|
||||
"doc_metadata": document.get("segment", {}).get("document", {}).get("doc_metadata", None),
|
||||
"position": document["segment"].get("position", 0)
|
||||
},
|
||||
"title": document["segment"]["document"]["name"],
|
||||
"content": document["segment"]["content"]
|
||||
}
|
||||
|
||||
for reranked_document in reranked_documents:
|
||||
cur_doc_info = all_documents[reranked_document["index"]]
|
||||
cur_doc_info["score"] = reranked_document["score"]
|
||||
new_all_documents.append(to_dify_document_format(cur_doc_info))
|
||||
return new_all_documents
|
||||
|
||||
# 统计不同类型查询命中的文档
|
||||
query_hit_stats = {}
|
||||
|
||||
# 获取重排序后的文档ID列表
|
||||
reranked_doc_ids = [doc["metadata"]["segment_id"] for doc in processed_documents]
|
||||
reranked_doc_titles = [doc["title"].split("/")[-1] for doc in processed_documents]
|
||||
|
||||
# 解析query_expand_dict(如果是字符串)
|
||||
if isinstance(query_expand_dict, str):
|
||||
try:
|
||||
query_expand_dict = json.loads(query_expand_dict)
|
||||
except Exception as e:
|
||||
logging.error(f"解析query_expand_dict失败: {str(e)}")
|
||||
query_expand_dict = {}
|
||||
|
||||
# 统计各类型查询命中的文档
|
||||
for query_type, queries in query_expand_dict.items():
|
||||
if not isinstance(queries, list):
|
||||
queries = [queries]
|
||||
|
||||
# 初始化该查询类型命中的文档列表
|
||||
query_hit_stats[query_type] = []
|
||||
|
||||
# 合并所有该类型查询命中的文档
|
||||
hit_doc_ids_set = set()
|
||||
for query in queries:
|
||||
if query in query_document_mapping:
|
||||
hit_doc_ids_set.update(set(query_document_mapping[query]))
|
||||
|
||||
# 找出在重排序结果中的文档
|
||||
hit_doc_titles_set = set() # 用于去重
|
||||
for i, doc_id in enumerate(reranked_doc_ids):
|
||||
if doc_id in hit_doc_ids_set:
|
||||
doc_title = reranked_doc_titles[i]
|
||||
if doc_title not in hit_doc_titles_set: # 确保不添加重复的标题
|
||||
hit_doc_titles_set.add(doc_title)
|
||||
query_hit_stats[query_type].append(doc_title)
|
||||
|
||||
logging.info(f"查询命中统计: {json.dumps(query_hit_stats, ensure_ascii=False)}")
|
||||
|
||||
return {
|
||||
"documents": processed_documents,
|
||||
"query_hit_stats": query_hit_stats
|
||||
}
|
||||
|
||||
async def data_post_processor_async(self, query: str, all_documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -293,37 +243,9 @@ class DifyQueryRetrieval:
|
||||
new_all_documents.append(to_dify_document_format(cur_doc_info))
|
||||
return new_all_documents
|
||||
|
||||
def get_datasets_by_classification(self, classification: Classification, software_name: str) -> List[str]:
|
||||
if classification.vertical_classification=="软件问题" or classification.vertical_classification=="业务问题":
|
||||
software_name_list = self.software_to_dataset_map.keys()
|
||||
cur_software_name = ""
|
||||
for software_name_info in software_name_list:
|
||||
if software_name_info in software_name:
|
||||
cur_software_name = software_name_info
|
||||
break
|
||||
if cur_software_name == "":
|
||||
return self.software_to_dataset_map["其他"]
|
||||
else:
|
||||
return self.software_to_dataset_map[cur_software_name]
|
||||
|
||||
if classification.vertical_classification == "安装下载注册":
|
||||
if classification.sub_classification in ["后缀名咨询", "软件锁类"]:
|
||||
return ["下载安装注册(new)"]
|
||||
elif classification.sub_classification == "安装下载类":
|
||||
return []
|
||||
elif classification.sub_classification == "问题排查":
|
||||
return self.software_to_dataset_map["其他"]
|
||||
|
||||
return self.software_to_dataset_map["其他"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1")
|
||||
# datasets = dify_query_retrieval.retrieve("配网工程计价通D3软件如何新建工程?", Classification(vertical_classification="软件问题", sub_classification="软件功能"), "配网工程计价通D3")
|
||||
# datasets = dify_query_retrieval.retrieve_api("电力建设计价通软件如何批量修改设备价格?",
|
||||
# ["电力建设计价通软件如何批量修改设备价格?"],
|
||||
# ["电力建设计价通(2018)软件知识(new)"])
|
||||
# print(json.dumps(datasets, ensure_ascii=False, indent=2))
|
||||
|
||||
# 测试异步API
|
||||
async def test_async_api():
|
||||
|
||||
Reference in New Issue
Block a user