Files
GraphRAG/chains_lab/unified_chain.py
T
2025-03-31 17:28:23 +08:00

227 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from chains_lab import qwen_llm
from chains_lab import StrOutputParser
from chains_lab import ChatPromptTemplate
from chains_lab import PromptTemplate
from chains_lab import JsonOutputParser
from langchain.chains.base import Chain
from typing import Dict, Any, List, Optional
from manager.manager_intention import ParserIntentJson
from utils.utils import define_suffix, extract_values, check_and_return, extract_names_from_json, output_suffix
from utils.utils import str_to_pydantic, parse_pydantic_fields
from pydantic import BaseModel, Field
class UnifiedNLUChain(Chain):
"""
统一的NLU处理链,整合意图识别、槽位提取和结果改写功能
"""
def __init__(self, suffix_fields: List[str], parser: ParserIntentJson):
super().__init__()
self._suffix_fields = suffix_fields
self._parser = parser
self._lv1_intent = parser.lv1_intent_names_str
self._lv2_intent = parser.intent_lv2_str
@property
def input_keys(self) -> List[str]:
return ["query"]
@property
def output_keys(self) -> List[str]:
return ["intention_result", "slot_result", "rewrite_result", "full_response"]
def _intention_chain(self):
"""实现意图识别链"""
PromptTemplate1 = """
你是博微公司的电力造价员专家,需要将后续用户输入的对博微公司多款软件产品使用和业务方面的咨询问题进行意图分类。
不要假设上下文,更不要尝试回答问题。
# 用户输入
{query}
# 一级意图
{lv1_intent}
# 二级意图
{lv2_intent}
# 注意:
1. 请按json格式返回
2. 对于二级意图,如果一级意图是'操作指南',则二级意图必选['下载安装注册', '软件使用操作', '数据管理']之一;如果是一级意图是'规范解读',则二级意图必选['国家规范', '行业标准', '地方政策', '行业知识查询']之一, 以此类推
3. json的keys,一定含有'一级意图'、'二级意图',且无论用户输入上下文多少,输出json只有一个
4. 如果你认为用户输入内容不属于一级意图里面任意一个,则输出的json里面内容为:"一级意图""其他""二级意图""未知"
5. 当用户输入内容含有费用等咨询的信息,则输出的json里面内容为:"一级意图""其他""二级意图""未知"
"""
preset_variables = {"lv1_intent": self._lv1_intent, "lv2_intent": self._lv2_intent}
Prompt = ChatPromptTemplate.from_template(PromptTemplate1).partial(**preset_variables)
Chain = Prompt | qwen_llm | StrOutputParser()
return Chain
def _suffix_result(self, input_str):
"""实现后缀处理链"""
kesword_suffix = output_suffix(input_str, self._suffix_fields)
return f"[{kesword_suffix}]需要用什么软件?"
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
query = inputs["query"]
# 初始化返回结果
output = {"intention_result": None, "slot_result": None, "rewrite_result": None, "full_response": None}
# 处理后缀名问题
if define_suffix(query, self._suffix_fields) == ["后缀名问题"]:
intention_result = ["操作指南", "下载安装注册"]
rewrite_result = self._suffix_result(query)
def _result_chain(self):
"""实现普通结果改写链"""
PromptTemplate1 = """
输入:
"{finall}"
要求:
根据输入信息重新改写问题。
注意事项:
不假设上下文信息,仅根据提供的内容进行调整。
对于未知的自动忽略。
对于礼貌用词、语气助词等无任何意义词汇自动忽略
不要写思考过程,不要对词汇进行解释,不要进行追问,不要回答问题,仅进行问题改写。
"""
Prompt = ChatPromptTemplate.from_template(PromptTemplate1)
Chain = Prompt | qwen_llm | StrOutputParser()
return Chain
def _result_softwared_chain(self):
"""实现软件相关结果改写链"""
PromptTemplate1 = """
# 输入:
"{finall}"
# 要求:
根据输入信息重新改写问题。
# 输出格式要求:
"[software_name][interface(非必需)][How]functionality"
输出格式里的'interface'的含义是相关操作界面,如果没有则自动忽略
输出格式里的'How'的含义有很多,如怎么设置、怎么调整、怎么选择、等等,需要你自己根据整体语义进行替换
# 注意事项:
不假设上下文信息,仅根据提供的内容进行调整。
对于未知的自动忽略。
不要写思考过程,不要对词汇进行解释,不要进行追问,不要回答问题,仅进行问题改写。
"""
Prompt = ChatPromptTemplate.from_template(PromptTemplate1)
Chain = Prompt | qwen_llm | StrOutputParser()
return Chain
def _slot_chain(self, slot_structs):
"""实现槽位提取链"""
if slot_structs is not None:
if slot_structs[1] is None:
return lambda query: {}
pydantic_lv2 = str_to_pydantic(slot_structs[1])
pydantic_lv2_str = parse_pydantic_fields(slot_structs[1])
software_extraction_prompt = """
输出应格式化为符合以下 JSON 结构的 JSON 实例。
# 用户输入
{query}
# 注意
1.你要在电力造价的领域,结合日常咨询知识来入将用户的问题理解意图后采用以下指定槽位结构填充
2.如果问题中没有给出对应槽位的值则为未知
3.对于'software_name'的值,如果有相关软件的说明,请补充软件名字的全称,全称如下:
['西藏造价软件Z1', '技改检修计价通T1', '储能电站建设计价通C1', '配网计价通D3', '主网电力建设计价通']
4.如果输出结构里面没有keys为'software_name',则忽略第3条
5.不要回答问题,仅生成 JSON 实例
# 输出
```
{parse_pydantic}
```
"""
preset_variables = {"parse_pydantic": pydantic_lv2_str}
Prompt = ChatPromptTemplate.from_template(software_extraction_prompt).partial(**preset_variables)
Chain = Prompt | qwen_llm | JsonOutputParser(pydantic_object=pydantic_lv2)
return Chain
else:
return lambda query: {}
def _suffix_result(self, input_str):
"""实现后缀处理链"""
kesword_suffix = output_suffix(input_str, self._suffix_fields)
return f"[{kesword_suffix}]需要用什么软件?"
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
query = inputs["query"]
# 初始化返回结果
output = {"intention_result": None, "slot_result": None, "rewrite_result": None, "full_response": None}
# 处理后缀名问题
if define_suffix(query, self._suffix_fields) == ["后缀名问题"]:
intention_result = ["操作指南", "下载安装注册"]
rewrite_result = self._suffix_result(query)
output["intention_result"] = intention_result
output["slot_result"] = {}
output["rewrite_result"] = rewrite_result
full_response = {
"用户输入": query,
"意图": {
"一级意图": {"name": intention_result[0], "slot_lv1": {}},
"二级意图": {"name": intention_result[1], "slot_lv2": {}},
},
"改写结果": rewrite_result,
}
else:
# 1. 意图识别
intention_chain = self._intention_chain()
intention_result_temp = intention_chain.invoke(query)
intention_result = extract_values(intention_result_temp)
output["intention_result"] = intention_result
# 2. 槽位提取
slot_structs = check_and_return(intention_result, self._parser)
slot_chain = self._slot_chain(slot_structs)
# 检查slot_chain是否为函数(处理二级意图为None的情况)
if callable(slot_chain) and not isinstance(slot_chain, Chain):
slot_result = slot_chain(query)
else:
slot_result = slot_chain.invoke(query)
output["slot_result"] = slot_result
# 3. 结果改写
if "software_name" in slot_result and slot_result["software_name"]:
rewrite_chain = self._result_softwared_chain()
rewrite_result = rewrite_chain.invoke({"finall": slot_result})
else:
rewrite_chain = self._result_chain()
rewrite_result = rewrite_chain.invoke({"finall": slot_result})
output["rewrite_result"] = rewrite_result
# 构建完整响应
full_response = {
"用户输入": query,
"意图": {
"一级意图": {"name": intention_result[0] if intention_result else None, "slot_lv1": {}},
"二级意图": {
"name": intention_result[1] if len(intention_result) > 1 else None,
"slot_lv2": slot_result,
},
},
"改写结果": rewrite_result,
}
output["full_response"] = full_response
return output