更新DifyQueryRetrieval类的初始化参数,改为使用环境变量获取API密钥和基础URL;优化意图识别示例中的参数传递;调整问题和回答的格式描述;增加请求超时设置。

This commit is contained in:
2025-07-17 09:05:05 +08:00
parent 8a58fef1a7
commit bd0f86ff61
6 changed files with 33 additions and 25 deletions
+11 -8
View File
@@ -62,8 +62,8 @@ class QueryRewriteProcessor:
api_key: str = None, api_key: str = None,
base_url: str = None, base_url: str = None,
model_name: str = None, model_name: str = None,
dify_api_key: str = "dataset-skLjmPVonjHo119OWNf3kAmY", dify_dataset_key: str = None,
dify_base_url: str = "http://172.20.0.145/v1"): dify_base_url: str = None):
""" """
初始化查询改写处理器 初始化查询改写处理器
@@ -71,13 +71,17 @@ class QueryRewriteProcessor:
api_key: API密钥,默认使用环境变量 api_key: API密钥,默认使用环境变量
base_url: API基础URL,默认使用环境变量 base_url: API基础URL,默认使用环境变量
model_name: 模型名称,默认使用环境变量或默认模型 model_name: 模型名称,默认使用环境变量或默认模型
dify_api_key: Dify API密钥 dify_dataset_key: Dify API密钥
dify_base_url: Dify API基础URL dify_base_url: Dify API基础URL
""" """
# 初始化意图识别器 # 初始化意图识别器
# 使用asyncio.run()运行异步create方法 # 使用asyncio.run()运行异步create方法
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create()) self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create())
self.dify_query_retrieval = DifyQueryRetrieval(api_key=dify_api_key, base_url=dify_base_url) if not dify_dataset_key:
dify_dataset_key = os.getenv("DIFY_DATASET_KEY")
if not dify_base_url:
dify_base_url = os.getenv("DIFY_BSAE_URL")
self.dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=dify_dataset_key, dify_base_url=dify_base_url)
def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]: def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]:
""" """
@@ -205,14 +209,13 @@ class QueryRewriteProcessor:
classification = result["classification"] classification = result["classification"]
original_query = result["rewrite"]["rewrite"] original_query = result["rewrite"]["rewrite"]
query_list = result["query_expand"]["all"] query_list = result["query_expand"]["all"]
soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","")
# 将字典转换为Classification对象 # 将字典转换为Classification对象
classification_obj = Classification(**classification) classification_obj = Classification(**classification)
# 根据enable_retrieval参数决定是否进行文档检索 # 根据enable_retrieval参数决定是否进行文档检索
retrieved_doc = None retrieved_doc = None
if enable_retrieval: if enable_retrieval:
retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name) retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, current_softname)
# 判断检索文档是否相关 # 判断检索文档是否相关
relevance_result = {} relevance_result = {}
@@ -439,9 +442,9 @@ def main():
for idx, query in enumerate(examples): for idx, query in enumerate(examples):
if query.strip() == "": if query.strip() == "":
continue continue
query="怎么把一个批次拆分成多个批次工程" query="怎么调整报表顺序"
conversation_context={ conversation_context={
"current_softname": "配网计价通D3软件" "current_softname": "储能计价通C1软件"
} }
# 在调试模式下使用完整的参数 # 在调试模式下使用完整的参数
print(json.dumps(processor.process_query( print(json.dumps(processor.process_query(
+17 -12
View File
@@ -27,26 +27,29 @@ class DifyQueryRetrieval:
"西藏造价软件知识(new)","新能源造价知识(new)","配网造价知识(new)","技改造价知识(new)", "西藏造价软件知识(new)","新能源造价知识(new)","配网造价知识(new)","技改造价知识(new)",
"配网造价软件知识(new)"]} "配网造价软件知识(new)"]}
def __init__(self, api_key: str, base_url: str): def __init__(self, dify_dataset_key: str, dify_base_url: str):
self._api_key = api_key self._dify_dataset_key = dify_dataset_key
self._base_url = base_url self._dify_base_url = dify_base_url
self._datasets_list = self.get_datasets_list() self._datasets_list = self.get_datasets_list()
def get_datasets_list(self) -> Dict[str, str]: def get_datasets_list(self) -> Dict[str, str]:
client = KnowledgeBaseClient(api_key=self._api_key, base_url=self._base_url) client = KnowledgeBaseClient(api_key=self._dify_dataset_key, base_url=self._dify_base_url)
datasets = client.list_datasets(page_size=50) datasets = client.list_datasets(page_size=50)
datasets_json = datasets.json() datasets_json = datasets.json()
return {dataset["name"]:dataset["id"] for dataset in datasets_json["data"]} return {dataset["name"]:dataset for dataset in datasets_json["data"]}
def retrieve_by_dataset(self, query: str, dataset_name: str) -> List[Dict[str, Any]]: def retrieve_by_dataset(self, query: str, dataset_name: str) -> List[Dict[str, Any]]:
try: try:
knowledge_base_client = KnowledgeBaseClient(api_key=self._api_key, base_url=self._base_url, dataset_id=self._datasets_list[dataset_name]) dataset_id = self._datasets_list[dataset_name]["id"]
documents = knowledge_base_client.retrieve(query, timeout=300) retrieval_model = self._datasets_list[dataset_name]["retrieval_model_dict"]
knowledge_base_client = KnowledgeBaseClient(api_key=self._dify_dataset_key, base_url=self._dify_base_url, dataset_id=dataset_id)
documents = knowledge_base_client.retrieve(query, retrieval_model=retrieval_model, timeout=300)
retrieved_documents = documents.json().get("records", []) retrieved_documents = documents.json().get("records", [])
# 添加数据集信息 # 添加数据集信息
for retrieved_document in retrieved_documents: for retrieved_document in retrieved_documents:
retrieved_document["dataset_id"] = self._datasets_list[dataset_name] retrieved_document["dataset_id"] = dataset_id
retrieved_document["dataset_name"] = dataset_name retrieved_document["dataset_name"] = dataset_name
return retrieved_documents return retrieved_documents
@@ -78,6 +81,7 @@ class DifyQueryRetrieval:
def retrieve(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]: def retrieve(self, original_query: str, query_list: List[str], classification: Classification, software_name: str) -> Optional[List[Dict[str, Any]]]:
datasets = self.get_datasets_by_classification(classification, software_name) datasets = self.get_datasets_by_classification(classification, software_name)
datasets=["电力建设计价通(2018)软件知识(new)", "主网造价知识(new)", "下载安装注册(new)"]
if len(datasets) == 0: if len(datasets) == 0:
return None return None
@@ -103,6 +107,7 @@ class DifyQueryRetrieval:
return await self.retrieve_api_async(original_query, query_list, datasets) return await self.retrieve_api_async(original_query, query_list, datasets)
def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]: def retrieve_api(self, original_query: str, query_list: List[str],data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
ssss = self.retrieve_by_dataset("怎么调整报表顺序", "电力建设计价通(2018)软件知识(new)")
all_documents=[] all_documents=[]
# 使用线程池替代无限制创建线程 # 使用线程池替代无限制创建线程
# 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制 # 设置合理的最大线程数,这里使用min(32, len(query_list) * len(datasets))来限制
@@ -112,7 +117,7 @@ class DifyQueryRetrieval:
futures = {} futures = {}
for query in query_list: for query in query_list:
for dataset in data_set_list: for dataset in data_set_list:
if dataset not in self._datasets_list: if dataset not in list(self._datasets_list.keys()):
raise ValueError(f"dataset {dataset} not in datasets_list") raise ValueError(f"dataset {dataset} not in datasets_list")
futures[executor.submit(self.retrieve_by_dataset, query, dataset)] = query futures[executor.submit(self.retrieve_by_dataset, query, dataset)] = query
@@ -140,10 +145,10 @@ class DifyQueryRetrieval:
# 对所有检索出来的文档进行重排序 # 对所有检索出来的文档进行重排序
time_start = time.time() time_start = time.time()
processed_documents = self.data_post_processor(original_query, deduplicated_documents, top_k) processed_documents = self.data_post_processor("怎么调整报表顺序", deduplicated_documents, top_k)
time_end = time.time() time_end = time.time()
logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}") logging.info(f"检索后重排序耗时: {time_end - time_start:.2f}")
return processed_documents return processed_documents
async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]: async def retrieve_api_async(self, original_query: str, query_list: List[str], data_set_list: List[str], top_k: int = 5)->List[Dict[str, Any]]:
@@ -166,7 +171,7 @@ class DifyQueryRetrieval:
tasks = [] tasks = []
for query in query_list: for query in query_list:
for dataset in data_set_list: for dataset in data_set_list:
if dataset not in self._datasets_list: if dataset not in list(self._datasets_list.keys()):
logging.error(f"dataset {dataset} not in datasets_list") logging.error(f"dataset {dataset} not in datasets_list")
continue continue
+1 -1
View File
@@ -64,7 +64,7 @@ dify_query_retrieval = None
async def startup_event(): async def startup_event():
global dify_query_retrieval global dify_query_retrieval
# 初始化DifyQueryRetrieval实例 # 初始化DifyQueryRetrieval实例
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://10.1.16.39/v1") dify_query_retrieval = DifyQueryRetrieval(dify_dataset_key=os.getenv("DIFY_DATASET_KEY"), dify_base_url=os.getenv("DIFY_BSAE_URL"))
logger.info("DifyQueryRetrieval初始化完成") logger.info("DifyQueryRetrieval初始化完成")
# 添加健康检查端点 # 添加健康检查端点
+1 -1
View File
@@ -19,7 +19,7 @@ for index, row in pd_data.iterrows():
if "存在抱怨" in answer: if "存在抱怨" in answer:
answer = answer.split("存在抱怨")[0] answer = answer.split("存在抱怨")[0]
content = f"问题:{query}\n解决方案{answer}" content = f"问题:{query}\n回答{answer}"
segments_list.append({ segments_list.append({
"content": str(content), "content": str(content),
"answer": "", "answer": "",
+2 -2
View File
@@ -8,7 +8,7 @@ class DifyClient:
self.api_key = api_key self.api_key = api_key
self.base_url = base_url self.base_url = base_url
def _send_request(self, method, endpoint, json=None, params=None, stream=False): def _send_request(self, method, endpoint, json=None, params=None, stream=False, timeout=300):
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json", "Content-Type": "application/json",
@@ -16,7 +16,7 @@ class DifyClient:
url = f"{self.base_url}{endpoint}" url = f"{self.base_url}{endpoint}"
response = requests.request( response = requests.request(
method, url, json=json, params=params, headers=headers, stream=stream, verify=False method, url, json=json, params=params, headers=headers, stream=stream, verify=False, timeout=timeout
) )
return response return response
+1 -1
View File
@@ -309,7 +309,7 @@ class IntentAndSlotResult(BaseModel):
class StepBackPrompt(BaseModel): class StepBackPrompt(BaseModel):
"""后退提示数据模型""" """后退提示数据模型"""
original_query: str = Field(description="原始查询") original_query: str = Field(description="原始查询")
can_use_back_prompt: bool = Field(description="原始查询是否可以进行后退提示(True/False),如果原始查询没有限定词或其他限定词语,则不能进行后退提示") can_use_back_prompt: bool = Field(description="原始查询是否可以进行后退提示(true/false),如果原始查询没有限定词或其他限定词语,则不能进行后退提示")
step_back_query: List[str] = Field(description="后退提示生成的抽象查询(多个)") step_back_query: List[str] = Field(description="后退提示生成的抽象查询(多个)")
class FollowUpQuestions(BaseModel): class FollowUpQuestions(BaseModel):