增加图QA问答

This commit is contained in:
2025-07-14 15:36:15 +08:00
parent 46552a536f
commit 32136be5db
6 changed files with 283 additions and 215 deletions
+4
View File
@@ -21,6 +21,10 @@ class Config:
def openai_coder(self):
return self._config.get("openai_coder", {})
@property
def openai_qa(self):
return self._config.get("openai_qa", {})
@property
def bowei_api_docs_path(self):
return self._config.get("bowei_api_docs_path", "./data/bowei_api_docs.md")
+169
View File
@@ -0,0 +1,169 @@
from langchain_neo4j import GraphCypherQAChain, Neo4jGraph
from langchain.chains import RetrievalQA, LLMChain
from langchain_core.runnables import RunnableSequence
from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from typing import Any
class EngineeringQAAgent:
def __init__(self, llm_client, embedding_client, graph: Neo4jGraph, business_retriever, callbacks=None):
"""
:param llm_client: 聊天大模型实例,如 ChatOpenAI
:param embedding_client: 嵌入模型实例,如 OpenAIEmbeddings(本类未直接使用,但保留)
:param graph: Neo4jGraph实例,连接工程知识图谱
:param business_retriever: 通用业务知识库的Retriever对象
"""
self.llm = llm_client.llm
self.embedding = embedding_client
self.business_retriever = business_retriever
self.callbacks = callbacks
self.kwargs = {}
if self.callbacks:
self.kwargs = {"callbacks": self.callbacks }
CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
{schema}
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.
The question is:
{question}"""
CYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
)
# 工程知识图谱查询链
self.engineering_qa = GraphCypherQAChain.from_llm(
llm=self.llm,
graph=graph,
verbose=True,
allow_dangerous_requests=True,
cypher_prompt=CYPHER_GENERATION_PROMPT,
**self.kwargs
)
# Chain of Thought提示模板和链
cot_prompt = PromptTemplate(
input_variables=["question"],
template=(
"你是一个专业的电力造价工程助理。请先将用户的问题拆分成多个子问题,"
"分别检索相关的通用业务知识和工程数据知识库,最后综合回答用户。\n\n"
"用户问题:{question}\n\n"
"请给出拆分的子问题列表和每个子问题的检索计划,"
"然后给出最终综合回答。\n\n"
"思考过程:"
)
)
self.cot_chain = RunnableSequence(
cot_prompt | self.llm
).with_config(
verbose=True,
**self.kwargs
)
# 定义Agent工具
self.tools = [
Tool(
name="EngineeringData",
func=self.engineering_qa.invoke,
description="用于查询具体工程数据相关问题"
),
Tool(
name="BusinessKnowledge",
func=self._business_qa,
description="用于查询电力造价行业通用业务知识"
),
]
# 多轮对话记忆
message_history = ChatMessageHistory()
self.memory = ConversationBufferMemory(
chat_memory=message_history, memory_key="chat_history", return_messages=True)
# 初始化Agent,支持多轮对话和工具调用
self.agent = initialize_agent(
self.tools,
self.llm,
agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
memory=self.memory,
verbose=True,
**self.kwargs
)
def _business_qa(self, query: str) -> str:
if not self.business_retriever:
return "通用业务知识库未配置。"
business_qa = RetrievalQA(
llm=self.llm,
retriever=self.business_retriever,
return_source_documents=False,
**self.kwargs
)
return business_qa.invoke(query)
def ask(self, question: str) -> str:
# 1. 使用CoT链拆分问题和规划检索
cot_output = self.cot_chain.invoke(question)
# 2. 简单示例:先调用业务知识库,再调用工程数据知识库
business_answer = self._business_qa(question)
engineering_answer = self.engineering_qa.run(question)
# 3. 综合回答
final_answer = (
"根据通用业务知识库,得到的信息是:\n"
f"{business_answer}\n\n"
"根据工程数据知识库,得到的信息是:\n"
f"{engineering_answer}\n\n"
"综合以上信息,回答用户问题如下:\n"
f"{cot_output}"
)
return final_answer
# ------------------ 使用示例 ------------------
if __name__ == "__main__":
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.config import Config
from src.multi_llm_client import MultiAPIKeyChatOpenAI
from src.embedding_client import EmbeddingClient
config = Config()
# 初始化图数据库连接
graph = Neo4jGraph(
url=config.neo4j_conf.get("uri"),
username=config.neo4j_conf.get("username"),
password=config.neo4j_conf.get("password")
)
# 初始化LLM和Embedding
llm = MultiAPIKeyChatOpenAI(config.openai_qa)
embedding = EmbeddingClient(config.embedding)
# 初始化通用业务知识向量库Retriever(示例中未配置)
business_vectorstore = None # 例如 Chroma(collection_name="business_knowledge")
business_retriever = None # business_vectorstore.as_retriever() if business_vectorstore else None
# 创建Agent实例
agent = EngineeringQAAgent(llm, embedding, graph, business_retriever)
# 交互示例
question = "工程里有几个项目划分,每个项目划分的名字分别是什么"
answer = agent.ask(question)
print("回答:", answer)
+3 -178
View File
@@ -71,23 +71,13 @@ def project_get_calculate_function():
quantity = result_dict['数量']
if status:
return {
"code": 200,
"message": 'ok',
"status": True,
"data": quantity
}
return result_dict
else:
return {
"code": 500,
"message": message,
"status": False,
"data": quantity
}
return result_dict
# 执行规则
- 参数必须从用户问题或上下文信息中提取。
- 在代码函数内部生成功能说明,在函数外禁止生成任何注释或解释或非代码内容。
- 函数内部代码生成流程注释,并使用logger进行日志输出,在函数外禁止生成任何注释或解释或非代码内容。
- 输出代码中必须以def project_get_calculate_function() -> dict函数作为入口函数,该函数返回字典包含:'code''status''message''data'四个字段。
- 必须确保生成的代码可以直接执行,代码要注意进行各类容错检查。
- 'data'字段通常要求是浮点或整型值,除非用户要求返回其他类型,同时函数执行过程中发生错误,'data'字段也必须为0,并在message说明错误原因。
@@ -184,168 +174,3 @@ Cypher查询语句:MATCH (item:ProjectDivisionItem)\nWHERE item.name CONTAINS
rewrite_prompt_template=rewrite_prompt_template,
cypher_conversion_prompt=cypher_conversion_prompt,
)
@dataclass
class CodeExecutorPrompts:
understand_prompt: ChatPromptTemplate
code_gen_prompt: ChatPromptTemplate
code_fix_prompt: ChatPromptTemplate
rewrite_prompt_template: ChatPromptTemplate
cypher_conversion_prompt: ChatPromptTemplate # 新增Cypher转换提示模板
class PromptManager:
def __init__(self):
self.prompts = self._init_prompts()
def _init_prompts(self) -> CodeExecutorPrompts:
understand_prompt = ChatPromptTemplate.from_template(
"""
你是一名电力造价业务专家,请基于以下示意工程文件业务结构,将用户自然语言问题改写成专业查询语句:
**示意工程文件业务结构**
{business_structure}
**改写规则**
1. **定位目标对象**:仅从示意工程文件业务结构中识别核心对象(如 `ProjectDivisionTree`→项目划分树、`FeeScheduleItem`→费用表)。
2. **提取条件**:从用户输入中解析关键条件(如名称、量、类型),用【】标注变量。
3. **构建专业语句**:格式为:`在[目标对象]中查找【条件】的项。`
- 使用业务术语(如“项目划分项”而非“项目”)。
- 条件需明确属性(如【名称】、【量】、【类型】)。
4. **精确映射结构**:若用户查询层级(如“叶节点”),需在条件中体现。
**用户输入**{user_input}
**改写输出**:(仅输出改写后的语句)
"""
)
code_gen_prompt = ChatPromptTemplate.from_template(
"""
你是一个专业的Python工程师。我会给你一个用户问题,你需要将其转换为对应的Python代码
用户问题:
{user_request}
上下文信息:
{context}
工程数据访问库:
{bowei_api_docs}
# 工作流程
1. 从用户问题中提取关键信息(节点路径、节点类型、节点名称等)
2. 根据"用户问题""上下文信息"选择最匹配的"工程数据访问库"中的函数和对象属性
3. 生成可直接执行的完全满足用户输入问题要求功能效果的Python函数代码
# 输出格式(必须严格遵循)
def project_get_calculate_function():
project = ProjectBuilder.build()
result_dict = project.[SELECTED_METHOD]([PARAMETERS])
status = result_dict.get('status', False)
message = result_dict.get('message', '')
code = result_dict.get('data', '')
data = result_dict.get('data', [])
logger.info(f"status {{status}} message: {{message}}")
if status:
return result_dict
else:
return result_dict
# 执行规则
- 参数必须从用户问题或上下文信息中提取
- 输出代码中必须以def project_get_calculate_function() -> dict函数作为入口函数
- 必须确保生成的代码可以直接执行,代码要注意进行各类错误检查,出错采用抛出异常方式,说明详细信息
- 为函数内部代码生成流程注释,并使用logger进行日志输出
- ProjectToolkit 类中涉及项目划分的函数已考虑在其及其子孙项目划分下查找,所以无需生成递归子项目划分的代码
- 如果文本中包含范围编码格式则需要进行编码展开,如'YX2-1~7'展开为‘YX2-1/YX2-2/YX2-3/YX2-4/YX2-5/YX2-6/YX2-7
"""
)
code_fix_prompt = ChatPromptTemplate.from_template(
"""
你是一个专业的Python工程师。我会给你一段错误python代码和错误信息,你需要帮我修复这段出错的代码
已执行代码:
{code}
代码执行报错信息:
{error}
你的任务是:
1. 根据"已执行代码""代码执行报错信息"来对“已执行代码”和函数调用参数进行修改,修复执行错误
2. 如果错误信息中是代码的逻辑出现错误,那么就需要对代码本身整体结构进行修改
3. 如果是代码中参数出现问题了,那么就需要结合错误信息中的帮助信息(helper_info)来对代码总的参数进行修改
4. 修复后的代码应该完整,可以直接执行,并且能够返回查询结果
注意:
- 如果文本中包含范围编码格式则需要进行编码展开,如'YX2-1~7'展开为‘YX2-1/YX2-2/YX2-3/YX2-4/YX2-5/YX2-6/YX2-7
- 必须只输出最终的Python代码,不要添加任何解释、注释、推理过程或自然语言描述。
- 不要以“以下是修正后的代码”、“修改如下”等语句开头。
- 不要输出任何其他无关的内容。
- 输出格式必须完全符合指定的函数模板。
- 如果无法根据已有信息进行修改,请原样返回原始代码。
- 禁止在代码前加上```python字样
- 禁止在代码后加上```字样
请输出你修补后的代码:
""")
rewrite_prompt_template = ChatPromptTemplate.from_template(
"""
您是一个AI查询改写助手。基于给定的原始查询和上下文知识,生成一个精确的改写查询。步骤:
1. 从上下文知识的`labels`提取对象类型,翻译为中文。
2. 从`properties`选择对象标识:优先用`path`值,若无则用`name`值。
3. 智能映射原始查询的属性名称:
- 如果属性名称是上下文属性的缩写、省略或同义词,映射到实际属性名称(如“人工费”可能映射到“费率”或“合价含税”)。
- 如果无法映射,保留原始名称。
4. 保留原始查询的额外操作(如计算指令)。
5. 输出格式:“获取[对象标识][对象类型]的[属性]属性,[额外操作]”。
示例参考:
- 输入:原始问题="查找名称中包含“工程”的项目划分项,并返回其人工费乘以1000的值。", 上下文知识=...
- 输出="获取[安装/架空输电线路本体工程/基础工程/基础工程材料工地运输]项目划分项的人工费,并乘以1000的值"
现在,处理以下输入:
- 原始问题:{user_input}
- 上下文知识:{context}
""")
cypher_conversion_prompt = ChatPromptTemplate.from_template(
"""
你是一名电力造价业务专家,负责将用户自然语言问题中需要访问的对象识别出来,并生成针对该对象的NEO4J知识图谱的Cypher查询语句,获取该对象的全部信息。知识图谱基于从文件“获取[安装/架空输电线路本体工程]项目划分项的单位.”中读取的层级关联结构构建。该文件包含用反斜杠分割的多个字符串(如“安装/架空输电线路本体工程”),每个字符串表示一个完整的层级路径,路径部分用斜杠分隔,对应于知识图谱中ProjectDivisionItem节点的层级关系。路径映射规则:每个路径部分(如“安装”)是一个ProjectDivisionItem节点,父子关系通过关系类型`:CHILD_OF`连接,形成从根节点到叶节点的层级结构。
示例业务结构:
{business_structure}
用户问题:
{user_input}
改写规则:
识别目标层级: 从用户输入中解析层级路径,例如字符串“安装/架空输电线路本体工程”映射为:根节点(name: '安装') -[:CHILD_OF]-> 子节点(name: '架空输电线路本体工程')。
识别目标对象:仅从业务结构中识别如下类型:ProjectDivisionItem、ProjectAttributeSet、FeeSchedule、FeeItem的核心对象。对象类型必须精确映射到结构中的节点标签(例如,使用 ProjectDivisionItem 而非“项目划分项”)。
提取查询条件:从用户输入中解析关键条件(如名称、量、类型、值、层级等),条件应基于对象属性(如 name、quantity、type),注意区分查找子串和相等的区别。如果条件涉及层级路径(如路径中包含特定部分),需提取路径部分作为条件(例如,用户提到“架空输电线路”时,解析为路径匹配)。
如果用户指定层级(如“叶节点”),需在条件中体现(例如,添加 WHERE item.isLeaf = true)。
忽略任何计算、转换或后处理要求(如“乘以1000”),只关注获取原始数据对象或属性。
构建Cypher查询:生成一个Cypher查询语句,格式为:
MATCH 子句:匹配目标对象节点,并应用条件。如果涉及层级路径,使用路径匹配(如 MATCH path = (item1:ProjectDivisionItem)-[:CHILD_OF*]->(item2:ProjectDivisionItem) WHERE item1.name = '安装' AND item2.name = '架空输电线路本体工程')。变量仅用于目标节点和必要路径节点。
WHERE 子句:包含提取的条件(使用变量或具体值)。
RETURN 子句:必须返回对象(如 RETURN item),不能包含对象属性、函数(如乘法、SUM)。
LIMIT 子句: 最多返回5条。
使用业务术语在节点标签和属性中(例如,ProjectDivisionItem 而不是“项目”)。
查询应简洁,只获取数据,不执行计算。
输出格式:直接输出最终Cypher查询语句,不添加解释或额外文本。
**示例**:
用户问题:查找一下名称中包含工程的项目划分
Cypher查询语句:MATCH (item:ProjectDivisionItem)\nWHERE item.name CONTAINS '工程'\nRETURN item\nLIMIT 5
""")
return CodeExecutorPrompts(
understand_prompt=understand_prompt,
code_gen_prompt=code_gen_prompt,
code_fix_prompt=code_fix_prompt,
rewrite_prompt_template=rewrite_prompt_template,
cypher_conversion_prompt=cypher_conversion_prompt,
)