feat: 添加清单定额查询API并优化意图识别模块

新增清单定额查询API服务,支持通过名称和编码查询定额及清单信息
在意图识别模块中添加定额清单信息提取功能,并记录各步骤耗时
将SiliconFlowEmbeddings替换为XinferenceEmbeddings并添加sqlite-vss依赖
优化shell脚本的screen会话检测逻辑
This commit is contained in:
2025-08-20 19:08:29 +08:00
parent db84105abf
commit 1a3fa44522
8 changed files with 1244 additions and 53 deletions
+67 -46
View File
@@ -188,6 +188,8 @@ class AsyncIntentRecognizer:
Returns:
分类结果
"""
start_time = time.time() # 记录开始时间
classification_parser = PydanticOutputParser(pydantic_object=Classification)
formatted_prompt = classification_prompt.format(user_input=query,
classification_info=classification_info,
@@ -203,6 +205,11 @@ class AsyncIntentRecognizer:
response.content = response.content.strip()
clean_output = re.sub(r'<think>.*?</think>', '', response.content, flags=re.DOTALL)
parsed_output = classification_parser.parse(clean_output)
# 计算并打印耗时
end_time = time.time()
logging.info(f"意图分类耗时: {end_time - start_time:.2f}")
return parsed_output
except Exception as e:
raise RuntimeError(f"解析分类结果时出错: {e}") from e
@@ -268,43 +275,6 @@ class AsyncIntentRecognizer:
parsed_output = terms_list_parser.parse(clean_output)
return parsed_output.terms
async def _rerank_matched_terms_async(self, query_key: str, matched_terms: set, top_k: int = 2, rerank_score:float = 0.6) -> List[Term]:
"""
异步对召回的专业术语进行重排序,按与用户查询的相关性排序
Args:
query: 用户查询
matched_terms: 匹配到的专业术语集合
query_keys: 用户查询中提取的关键词列表
Returns:
重排序后的专业术语列表
"""
if not matched_terms:
return []
if len(matched_terms) <= top_k:
return list(matched_terms)
try:
# 将每个术语转换为可用于重排序的文本表示
term_texts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) for term in matched_terms]
# 使用异步重排序模型
rerank_results = await XinferenceReRankerModel.rerank_async(query_key, term_texts, top_k=top_k)
# 将matched_terms转换为列表以便按索引访问
matched_terms_list = list(matched_terms)
# 根据重排序结果获取排序后的术语列表
reranked_terms = [matched_terms_list[result["index"]] for result in rerank_results if result["score"] >= rerank_score]
return reranked_terms
except Exception as e:
raise RuntimeError(f"异步_rerank_matched_terms重排失败:{e}") from e
async def _match_keywords_async(self, query: str, use_jieba: bool = False) -> Tuple[TermList, List[str]]:
"""
异步从用户问题中匹配关键词,结合LLM提取和向量检索
@@ -345,10 +315,56 @@ class AsyncIntentRecognizer:
total_time = end_time - start_time
# 输出整合的时间日志
logging.info(f"异步关键词匹配耗时统计 - 总耗时: {total_time:.2f}")
# logging.info(f"异步关键词匹配耗时统计 - 总耗时: {total_time:.2f}秒")
return term_list, query_keys
async def _get_dinge_qingdan_info(self, query: str, chat_history: List[Dict[str, str]] = None) -> dict:
"""
获取问题中定额、清单相关信息
Args:
query: 用户查询
Returns:
指令详情字典,包含定额、清单相关信息
"""
start_time = time.time() # 记录开始时间
prompt=f"""
当前提问内容:
<query>{query}</query>
对话上下文:
<chat_history>
{json.dumps(chat_history, ensure_ascii=False)}
</chat_history>
1、请从当前提问内容中提取电力造价行中定额编码、定额名称、清单编码、清单名称
2、请勿随机编造,如果没有提取到,返回空内容
3、返回结果为json格式
{{
"dinge_info_list":{{"dinge_code_list":["xxxx","xxxx"], "dinge_name_list":["xxxx","xxxx"]}},
"qingdan_info":{{"qingdan_code_list":["xxxx","xxxx"], "qingdan_name_list":["xxxx","xxxx"]}}
}}
"""
try:
response = await self._llm.invoke_async(prompt, False, response_format={"type": "json_object"})
response.content = response.content.strip()
clean_output = re.sub(r'<think>.*?</think>', '', response.content, flags=re.DOTALL)
parsed_output = JsonOutputParser().parse(clean_output)
# 计算并打印耗时
end_time = time.time()
logging.info(f"获取定额清单信息耗时: {end_time - start_time:.2f}")
return parsed_output
except Exception as e:
# 发生异常时也记录耗时
logging.error(f"获取问题定额清单详情失败: {e}", exc_info=True)
parsed_output = {"dinge_info_list": [], "qingdan_info": []}
return parsed_output
async def _rewrite_query_async(self, query: str, keywords: TermList, query_keys:List[str], chat_history: List[Dict[str, str]] = None, context: str = "") -> QueryRewrite:
"""
异步对用户问题进行改写
@@ -361,7 +377,7 @@ class AsyncIntentRecognizer:
Returns:
改写结果
"""
start_time = time.time()
# 准备问题改写提示
terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms]
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
@@ -378,6 +394,9 @@ class AsyncIntentRecognizer:
response.content = response.content.strip()
clean_output = re.sub(r'<think>.*?</think>', '', response.content, flags=re.DOTALL)
parsed_output = query_rewrite_parser.parse(clean_output)
end_time = time.time()
process_time=end_time-start_time
logging.info(f"异步问题改写耗时 - 耗时: {process_time:.2f}")
return parsed_output
except Exception as e:
raise RuntimeError(f"解析问题改写结果时出错: {e}") from e
@@ -447,12 +466,12 @@ class AsyncIntentRecognizer:
context=conversation_context
)
classification_task = self._classify_intent_async(query, conversation_context, chat_history, previous_slots)
# 定额清单信息
dinge_qingdan_info_task = self._get_dinge_qingdan_info(query, chat_history)
# 并行等待问题改写和意图分类完成
start_time = time.time()
rewrite, classification = await asyncio.gather(rewrite_task, classification_task)
end_time = time.time()
logging.info(f"意图分类耗时统计 - 总耗时: {end_time - start_time:.2f}")
rewrite, classification, dinge_qingdan_info = await asyncio.gather(rewrite_task, classification_task, dinge_qingdan_info_task)
# 特殊处理 锁相关咨询
if classification.vertical_classification == "安装下载注册" and classification.sub_classification == "软件锁类":
@@ -470,7 +489,8 @@ class AsyncIntentRecognizer:
"keywords": keywords_terms.model_dump(),
"rewrite": rewrite.model_dump(),
"query_keys": query_keys,
"slot_filling": slot_filling_result
"slot_filling": slot_filling_result,
"dinge_qingdan_info": dinge_qingdan_info
}
# 等待所有query_expand_tasks完成
@@ -505,7 +525,8 @@ class AsyncIntentRecognizer:
"rewrite": rewrite.model_dump(),
"query_keys": query_keys,
"slot_filling": slot_filling_result,
"query_expand": query_expand
"query_expand": query_expand,
"dinge_qingdan_info": dinge_qingdan_info
}
async def _fill_slots_async(self, query: str, classification: Classification, conversation_context: str = "",
@@ -14,7 +14,7 @@ import asyncio
from typing import List, Dict, Any, Tuple, Optional
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import FAISS
from rag2_0.tool.ModelTool import SiliconFlowEmbeddings
from rag2_0.tool.ModelTool import XinferenceEmbeddings
import logging
import httpx
@@ -28,7 +28,7 @@ def get_embedding_model(api_key: str = None) -> Embeddings:
Returns:
嵌入模型实例
"""
return SiliconFlowEmbeddings(api_key=api_key)
return XinferenceEmbeddings(api_key=api_key)
class ProfessionalNounVectorizer: