提交最新代码。 代码动态执行存在一些问题,无法执行。
This commit is contained in:
@@ -2,3 +2,6 @@ BoweiAgent.log
|
||||
__pycache__/*
|
||||
.vscode/launch.json
|
||||
src/__pycache__/*
|
||||
.cursor/rules/use.mdc
|
||||
test_runcode.log
|
||||
BoweiAgent1.log
|
||||
|
||||
+2
-1
@@ -8,7 +8,8 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 连接到Neo4j数据库
|
||||
uri = "bolt://172.20.0.145:7687"
|
||||
uri = "bolt://localhost:7487"
|
||||
#uri = "bolt://172.20.0.145:7687"
|
||||
user = "neo4j"
|
||||
password = "password"
|
||||
|
||||
|
||||
+16
-3
@@ -10,21 +10,34 @@ business_object_structure_path: ./data/business_object_structure.md
|
||||
bowei_api_docs_path: ./data/bowei_api_docs.md
|
||||
|
||||
openai:
|
||||
api_key: sk-bbeamiumkouptsrueilgufqqyuumelcsivxwjbdugqwsqhwj
|
||||
api_key: sk-xlrnesfcuwrpevdwbuhthivpygwyzwbxxsyvhzzwrkpzjduk
|
||||
api_base: https://api.siliconflow.cn/v1
|
||||
#api_version: "" # 可选,某些API版本需要指定
|
||||
#organization: your_organization_id # 可选
|
||||
api_type: openai # 可选,默认为 openai;如果用 Azure 则为 azure
|
||||
model_name: Qwen/Qwen2.5-72B-Instruct
|
||||
#model_name: Qwen/Qwen3-8B
|
||||
#model_name: deepseek-ai/DeepSeek-V3
|
||||
|
||||
openai_coder:
|
||||
api_key: sk-xlrnesfcuwrpevdwbuhthivpygwyzwbxxsyvhzzwrkpzjduk
|
||||
api_base: https://api.siliconflow.cn/v1
|
||||
#api_version: "" # 可选,某些API版本需要指定
|
||||
#organization: your_organization_id # 可选
|
||||
api_type: openai # 可选,默认为 openai;如果用 Azure 则为 azure
|
||||
#model_name: Qwen/Qwen2.5-Coder-7B-Instruct
|
||||
#model_name: Qwen/Qwen3-8B
|
||||
model_name: deepseek-ai/DeepSeek-V3
|
||||
|
||||
embedding:
|
||||
api_base: https://api.siliconflow.cn/v1
|
||||
model_name: BAAI/bge-m3
|
||||
api_key: sk-bbeamiumkouptsrueilgufqqyuumelcsivxwjbdugqwsqhwj
|
||||
api_key: sk-xlrnesfcuwrpevdwbuhthivpygwyzwbxxsyvhzzwrkpzjduk
|
||||
|
||||
log_level: INFO
|
||||
|
||||
langsmith:
|
||||
api_key: your_langsmith_api_key
|
||||
api_key: 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
|
||||
api_url: https://api.smith.langchain.com/api/v1
|
||||
project: your_project_name
|
||||
tracing_enabled: true
|
||||
|
||||
+37
-1
@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
||||
import json
|
||||
|
||||
|
||||
class ProjectTookiIt(ABC):
|
||||
class ProjectToolkit(ABC):
|
||||
"""
|
||||
项目类(抽象基类)
|
||||
描述: 代表整个项目结构的顶层容器
|
||||
@@ -577,3 +577,39 @@ class Fee:
|
||||
self.施工费 = None # xsd:string (可选)
|
||||
self.单位投资 = None # xsd:string (可选)
|
||||
|
||||
class ProjectBuilder:
|
||||
# 存储注册的工具类
|
||||
_registry = None
|
||||
_config = {}
|
||||
|
||||
"""项目工具工厂类"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def register(toolkit_class: type, config: dict):
|
||||
"""
|
||||
注册工具类到工厂
|
||||
|
||||
参数:
|
||||
toolkit_class: 继承自ProjectToolkit的具体工具类
|
||||
"""
|
||||
if not issubclass(toolkit_class, ProjectToolkit):
|
||||
raise TypeError(f"{toolkit_class.__name__} 必须继承自 ProjectToolkit")
|
||||
|
||||
_config = config
|
||||
_registry = toolkit_class
|
||||
|
||||
def build(self) -> ProjectToolkit:
|
||||
"""
|
||||
创建工具实例
|
||||
|
||||
参数:
|
||||
|
||||
返回:
|
||||
实例化的工具对象
|
||||
"""
|
||||
if _registry is None:
|
||||
raise KeyError(f"未注册的类,请先注册类")
|
||||
|
||||
return _registry(_config)
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, # 生产环境可改为 INFO 或 WARNING
|
||||
@@ -17,19 +18,30 @@ httpx_logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别
|
||||
# 可选:禁用传播(防止被根logger处理)
|
||||
httpx_logger.propagate = False
|
||||
|
||||
|
||||
# 获取logger并设置级别
|
||||
openai_logger = logging.getLogger("openai")
|
||||
openai_logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别
|
||||
# 可选:禁用传播(防止被根logger处理)
|
||||
openai_logger.propagate = False
|
||||
|
||||
# 获取logger并设置级别
|
||||
langsmith_logger = logging.getLogger("langsmith.client")
|
||||
langsmith_logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别
|
||||
# 可选:禁用传播(防止被根logger处理)
|
||||
langsmith_logger.propagate = False
|
||||
|
||||
# 获取logger并设置级别
|
||||
neo4j_logger = logging.getLogger("neo4j")
|
||||
neo4j_logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别
|
||||
# 可选:禁用传播(防止被根logger处理)
|
||||
neo4j_logger.propagate = False
|
||||
|
||||
# 获取logger并设置级别
|
||||
urllib3_logger = logging.getLogger("urllib3")
|
||||
urllib3_logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别
|
||||
# 可选:禁用传播(防止被根logger处理)
|
||||
urllib3_logger.propagate = False
|
||||
|
||||
# 获取logger并设置级别
|
||||
httpcore_logger = logging.getLogger("httpcore")
|
||||
httpcore_logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别
|
||||
@@ -45,17 +57,26 @@ from src.dialog_manager import DialogManager
|
||||
from src.neo4j_raw_retriever import Neo4jRawRetriever
|
||||
from src.embedding_client import EmbeddingClient
|
||||
|
||||
from project_implementation import ProjectBuilder
|
||||
from project import ProjectBuilder
|
||||
from project_implementation import ProjectToolkitNeo4j
|
||||
|
||||
|
||||
def main():
|
||||
config = Config()
|
||||
|
||||
# 根据配置设置环境变量
|
||||
tracing_enabled = config.langsmith.get("tracing_enabled", False)
|
||||
os.environ["LANGSMITH_PROJECT"] = config.langsmith.get("project")
|
||||
os.environ["LANGSMITH_TRACING"] = "true" if tracing_enabled else "false"
|
||||
os.environ["LANGSMITH_API_KEY"] = config.langsmith.get("api_key")
|
||||
#os.environ["LANGSMITH_API_KEY"] = config.langsmith.get("api_url")
|
||||
|
||||
business_structure = load_file(config.business_object_structure_path)
|
||||
bowei_api_docs = load_file(config.bowei_api_docs_path)
|
||||
|
||||
llm_client = LLMClient(config.openai)
|
||||
|
||||
llm_client_coder = LLMClient(config.openai_coder)
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
|
||||
neo4j_conf = config.neo4j_conf
|
||||
@@ -66,9 +87,9 @@ def main():
|
||||
# 创建Neo4j检索器
|
||||
knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
|
||||
|
||||
ProjectBuilder.init_driver(neo4j_conf)
|
||||
ProjectBuilder.register(ProjectToolkitNeo4j, neo4j_conf)
|
||||
|
||||
code_executor = CodeExecutor(prompt_manager.prompts, llm_client)
|
||||
code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder)
|
||||
|
||||
dialog_manager = DialogManager(
|
||||
llm_client,
|
||||
@@ -80,6 +101,7 @@ def main():
|
||||
)
|
||||
|
||||
pre_input_question = "查找名称中包含“工程”的项目划分项,并返回其人工费乘以1000的值。"
|
||||
pre_input_question = "查找名称中包含“工程”的项目划分项,并返回单位。"
|
||||
|
||||
asyncio.run(dialog_manager.run_async(pre_input=pre_input_question))
|
||||
|
||||
|
||||
+38
-1
@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
||||
import json
|
||||
|
||||
|
||||
class ProjectTookiIt(ABC):
|
||||
class ProjectToolkit(ABC):
|
||||
"""
|
||||
项目类(抽象基类)
|
||||
描述: 代表整个项目结构的顶层容器
|
||||
@@ -576,3 +576,40 @@ class Fee:
|
||||
self.其他费 = None # xsd:string (可选)
|
||||
self.施工费 = None # xsd:string (可选)
|
||||
self.单位投资 = None # xsd:string (可选)
|
||||
|
||||
class ProjectBuilder:
|
||||
# 存储注册的工具类
|
||||
_registry = None
|
||||
_config = {}
|
||||
|
||||
"""项目工具工厂类"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def register(toolkit_class: type, config: dict):
|
||||
"""
|
||||
注册工具类到工厂
|
||||
|
||||
参数:
|
||||
toolkit_class: 继承自ProjectToolkit的具体工具类
|
||||
"""
|
||||
if not issubclass(toolkit_class, ProjectToolkit):
|
||||
raise TypeError(f"{toolkit_class.__name__} 必须继承自 ProjectToolkit")
|
||||
|
||||
_config = config
|
||||
_registry = toolkit_class
|
||||
|
||||
def build(self) -> ProjectToolkit:
|
||||
"""
|
||||
创建工具实例
|
||||
|
||||
参数:
|
||||
|
||||
返回:
|
||||
实例化的工具对象
|
||||
"""
|
||||
if _registry is None:
|
||||
raise KeyError(f"未注册的类,请先注册类")
|
||||
|
||||
return _registry(_config)
|
||||
|
||||
@@ -9,7 +9,7 @@ logger = logging.getLogger("project_implementation")
|
||||
config = Config()
|
||||
|
||||
|
||||
class ProjectTookiItNeo4j(ProjectTookiIt):
|
||||
class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
"""
|
||||
基于Neo4j数据库的项目类实现
|
||||
"""
|
||||
@@ -2215,70 +2215,3 @@ class ProjectTookiItNeo4j(ProjectTookiIt):
|
||||
error = f"查询失败: {str(e)}"
|
||||
return status, data, error, helper_info
|
||||
|
||||
|
||||
class ProjectBuilder:
|
||||
"""
|
||||
项目构建器
|
||||
描述: 用于构建项目对象的构建器
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_driver = None # 存储Neo4j驱动实例
|
||||
|
||||
@staticmethod
|
||||
def init_driver(neo4j_conf):
|
||||
"""
|
||||
初始化Neo4j驱动
|
||||
|
||||
Args:
|
||||
neo4j_conf (dict): Neo4j配置信息,包含uri、username和password
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if ProjectBuilder._driver is not None:
|
||||
# 如果已经有驱动实例,先关闭它
|
||||
ProjectBuilder._driver.close()
|
||||
|
||||
# 创建新的驱动实例
|
||||
uri = neo4j_conf.get("uri")
|
||||
username = neo4j_conf.get("username")
|
||||
password = neo4j_conf.get("password")
|
||||
ProjectBuilder._driver = GraphDatabase.driver(uri, auth=(username, password))
|
||||
|
||||
@staticmethod
|
||||
def build():
|
||||
"""
|
||||
构建并返回项目实例,使用已初始化的驱动
|
||||
|
||||
Returns:
|
||||
ProjectTookiItNeo4j: 创建的项目实例
|
||||
"""
|
||||
# 如果已经有实例,直接返回
|
||||
if ProjectBuilder._instance is not None:
|
||||
return ProjectBuilder._instance
|
||||
|
||||
# 检查驱动是否已初始化
|
||||
if ProjectBuilder._driver is None:
|
||||
raise ValueError("必须先调用ProjectBuilder.init_driver初始化Neo4j驱动")
|
||||
|
||||
# 创建新实例,使用已初始化的驱动
|
||||
ProjectBuilder._instance = ProjectTookiItNeo4j(ProjectBuilder._driver)
|
||||
return ProjectBuilder._instance
|
||||
|
||||
@staticmethod
|
||||
def close():
|
||||
"""
|
||||
关闭当前项目实例和驱动的连接
|
||||
"""
|
||||
if ProjectBuilder._instance is not None:
|
||||
ProjectBuilder._instance.close()
|
||||
ProjectBuilder._instance = None
|
||||
|
||||
if ProjectBuilder._driver is not None:
|
||||
ProjectBuilder._driver.close()
|
||||
ProjectBuilder._driver = None
|
||||
|
||||
|
||||
# 注册退出处理函数,确保程序退出时自动关闭连接
|
||||
atexit.register(ProjectBuilder.close)
|
||||
|
||||
@@ -3,7 +3,7 @@ from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_experimental.utilities import PythonREPL
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_experimental.tools import PythonREPLTool
|
||||
from project_implementation import ProjectBuilder
|
||||
from project import ProjectBuilder
|
||||
import sys
|
||||
import io
|
||||
import traceback
|
||||
@@ -42,8 +42,6 @@ class CodeExecutor:
|
||||
logger.debug(f"开始执行代码: {code_str}")
|
||||
try:
|
||||
namespace = {
|
||||
"ProjectBuilder": ProjectBuilder,
|
||||
"project_implementation": __import__("project_implementation"),
|
||||
"project": __import__("project"),
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,10 @@ class Config:
|
||||
def openai(self):
|
||||
return self._config.get("openai", {})
|
||||
|
||||
@property
|
||||
def openai_coder(self):
|
||||
return self._config.get("openai_coder", {})
|
||||
|
||||
@property
|
||||
def bowei_api_docs_path(self):
|
||||
return self._config.get("bowei_api_docs_path", "./data/bowei_api_docs.md")
|
||||
@@ -27,3 +31,7 @@ class Config:
|
||||
@property
|
||||
def embedding(self):
|
||||
return self._config.get("embedding", {})
|
||||
|
||||
@property
|
||||
def langsmith(self):
|
||||
return self._config.get("langsmith", {})
|
||||
|
||||
+105
-75
@@ -6,6 +6,68 @@ import asyncio
|
||||
|
||||
logger = logging.getLogger("BoweiAgent.DialogManager")
|
||||
|
||||
class QuestionProcessor:
|
||||
def __init__(self, llm_client, business_structure, prompts):
|
||||
self.llm_client = llm_client
|
||||
self.business_structure = business_structure
|
||||
self.prompts = prompts
|
||||
|
||||
async def convert_question_to_cypher(self, user_input: str) -> str:
|
||||
prompt = self.prompts.cypher_conversion_prompt.format_prompt(
|
||||
business_structure=self.business_structure,
|
||||
user_input=user_input
|
||||
)
|
||||
messages = prompt.to_messages()
|
||||
cypher_query = ""
|
||||
async for chunk in self.llm_client.stream(messages):
|
||||
print(chunk.content, end="", flush=True)
|
||||
cypher_query += chunk.content
|
||||
print()
|
||||
logger.debug(f"生成的Cypher查询语句:{cypher_query.strip()}")
|
||||
return cypher_query.strip()
|
||||
|
||||
async def rewrite_question_with_query_result(self, user_input: str, context: str) -> str:
|
||||
prompt = self.prompts.rewrite_prompt_template.format_prompt(
|
||||
user_input=user_input,
|
||||
context=context
|
||||
)
|
||||
messages = prompt.to_messages()
|
||||
messages.append(HumanMessage(content=user_input))
|
||||
|
||||
logger.debug(f"重写提示词:{messages}")
|
||||
|
||||
result = ""
|
||||
async for chunk in self.llm_client.stream(messages):
|
||||
print(chunk.content, end="", flush=True)
|
||||
result += chunk.content
|
||||
print()
|
||||
logger.debug(f"重写后的用户问题:{result.strip()}")
|
||||
return result.strip()
|
||||
|
||||
class InteractionHandler:
|
||||
@staticmethod
|
||||
def display_rewritten_requests(rewritten_results):
|
||||
print("\n系统为您理解并改写了以下访问请求,请选择:")
|
||||
for idx, (rewritten, knowledge) in enumerate(rewritten_results, start=1):
|
||||
print(f"{idx}: {rewritten}")
|
||||
print("0: 重新输入问题")
|
||||
|
||||
@staticmethod
|
||||
def get_user_choice(rewritten_results):
|
||||
while True:
|
||||
choice = input("请输入编号选择(0重新输入,默认1):").strip()
|
||||
if choice == "":
|
||||
choice = "1"
|
||||
if choice == "0":
|
||||
logger.info("用户选择重新输入问题")
|
||||
print("请重新输入您的问题。\n" + "-"*50)
|
||||
return None
|
||||
if choice.isdigit() and 1 <= int(choice) <= len(rewritten_results):
|
||||
return int(choice) - 1
|
||||
else:
|
||||
logger.warning(f"用户输入无效选择:{choice}")
|
||||
print("输入无效,请输入有效编号。")
|
||||
|
||||
class DialogManager:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -21,7 +83,7 @@ class DialogManager:
|
||||
self.bowei_api_docs = bowei_api_docs
|
||||
self.code_executor = code_executor
|
||||
self.knowledge_retriever = knowledge_retriever
|
||||
self.prompts = prompt_manager.prompts
|
||||
self.question_processor = QuestionProcessor(llm_client, business_structure, prompt_manager.prompts)
|
||||
|
||||
def retrieve_relevant_docs(self, user_input: str):
|
||||
logger.debug(f"开始检索知识库,用户输入:{user_input}")
|
||||
@@ -29,72 +91,33 @@ class DialogManager:
|
||||
logger.debug(f"检索到 {len(docs)} 条相关文档")
|
||||
return [doc.page_content for doc in docs]
|
||||
|
||||
async def convert_question_to_cypher(self, user_input: str) -> str:
|
||||
prompt = self.prompts.cypher_conversion_prompt.format_prompt(
|
||||
business_structure=self.business_structure,
|
||||
user_input=user_input
|
||||
)
|
||||
messages = prompt.to_messages()
|
||||
cypher_query = ""
|
||||
async for chunk in self.llm_client.stream(messages):
|
||||
print(chunk.content, end="", flush=True)
|
||||
cypher_query += chunk.content
|
||||
print()
|
||||
logger.debug(f"生成的Cypher查询语句:{cypher_query.strip()}")
|
||||
return cypher_query.strip()
|
||||
|
||||
async def rewrite_question_with_query_result(self, user_input: str, query_result: str) -> str:
|
||||
prompt = self.prompts.rewrite_prompt_template.format_prompt(
|
||||
business_structure=self.business_structure,
|
||||
context=query_result
|
||||
)
|
||||
messages = prompt.to_messages()
|
||||
messages.append(HumanMessage(content=user_input))
|
||||
|
||||
logger.debug(f"重写提示词:{messages}")
|
||||
|
||||
result = ""
|
||||
async for chunk in self.llm_client.stream(messages):
|
||||
print(chunk.content, end="", flush=True)
|
||||
result += chunk.content
|
||||
print()
|
||||
logger.debug(f"重写后的用户问题:{result.strip()}")
|
||||
return result.strip()
|
||||
|
||||
async def understand_user_question_stream(self, user_input: str):
|
||||
logger.info(f"理解用户问题(流式):{user_input}")
|
||||
|
||||
# 1. 转换用户问题为Cypher查询
|
||||
cypher_query = await self.convert_question_to_cypher(user_input)
|
||||
cypher_query = await self.question_processor.convert_question_to_cypher(user_input)
|
||||
|
||||
# 2. 执行Cypher查询,最多返回5条
|
||||
docs = self.knowledge_retriever.get_relevant_documents(cypher_query)
|
||||
docs = docs[:5] # 限制最多5条
|
||||
|
||||
if not docs:
|
||||
logger.info("查询无结果,直接用业务结构改写")
|
||||
prompt = self.prompts.understand_prompt.format_prompt(
|
||||
business_structure=self.business_structure,
|
||||
user_input=user_input
|
||||
)
|
||||
messages = prompt.to_messages()
|
||||
result = ""
|
||||
async for chunk in self.llm_client.stream(messages):
|
||||
#print(chunk.content, end="", flush=True)
|
||||
result += chunk.content
|
||||
#print()
|
||||
return [(result.strip(), "")]
|
||||
logger.debug("查询无结果")
|
||||
return []
|
||||
|
||||
rewritten_list = []
|
||||
for idx, doc in enumerate(docs, start=1):
|
||||
logger.debug(f"\n第{idx}条相关文档改写结果(流式):")
|
||||
rewritten = await self.rewrite_question_with_query_result(user_input, doc.page_content)
|
||||
rewritten = await self.question_processor.rewrite_question_with_query_result(user_input, doc.page_content)
|
||||
rewritten_list.append((rewritten, doc.page_content))
|
||||
return rewritten_list
|
||||
|
||||
async def run_async(self, pre_input: str = None):
|
||||
async def run_async(self, pre_input: str = None, automated: bool = False):
|
||||
logger.info("启动对话管理器,等待用户输入")
|
||||
print("欢迎使用博微造价工程数据访问系统,输入 exit 退出。")
|
||||
if automated:
|
||||
print("自动化模式已启动。")
|
||||
else:
|
||||
print("欢迎使用博微造价工程数据访问系统,输入 exit 退出。")
|
||||
|
||||
if pre_input:
|
||||
user_questions = [pre_input]
|
||||
@@ -104,42 +127,53 @@ class DialogManager:
|
||||
while True:
|
||||
if user_questions:
|
||||
user_question = user_questions.pop(0)
|
||||
print(f"预输入问题:{user_question}")
|
||||
if not automated:
|
||||
print(f"预输入问题:{user_question}")
|
||||
elif automated:
|
||||
if not user_questions:
|
||||
logger.info("自动化模式下没有更多问题,退出程序。")
|
||||
break
|
||||
else:
|
||||
user_question = input("请输入您的问题:")
|
||||
|
||||
if user_question.strip().lower() == "exit":
|
||||
if user_question.strip().lower() == "exit" and not automated:
|
||||
logger.info("用户退出程序")
|
||||
print("退出程序。")
|
||||
break
|
||||
|
||||
rewritten_results = await self.understand_user_question_stream(user_question)
|
||||
if rewritten_results is None or rewritten_results == []:
|
||||
print('没有找到符合要求的数据,请继续提问')
|
||||
user_questions.clear()
|
||||
continue
|
||||
|
||||
print("\n系统为您理解并改写了以下访问请求,请选择:")
|
||||
for idx, (rewritten, knowledge) in enumerate(rewritten_results, start=1):
|
||||
print(f"{idx}: {rewritten}")
|
||||
print("0: 重新输入问题")
|
||||
|
||||
while True:
|
||||
choice = input("请输入编号选择(0重新输入,默认1):").strip()
|
||||
if choice == "":
|
||||
choice = "1"
|
||||
if choice == "0":
|
||||
logger.info("用户选择重新输入问题")
|
||||
print("请重新输入您的问题。\n" + "-"*50)
|
||||
user_questions.clear()
|
||||
break
|
||||
if choice.isdigit() and 1 <= int(choice) <= len(rewritten_results):
|
||||
selected_rewritten, selected_knowledge = rewritten_results[int(choice) - 1]
|
||||
logger.info(f"用户选择访问请求编号 {choice},内容:{selected_rewritten}")
|
||||
if automated:
|
||||
# 自动化模式下选择第一个结果
|
||||
selected_rewritten, selected_knowledge = rewritten_results[0]
|
||||
logger.info(f"自动化模式选择第一个访问请求,内容:{selected_rewritten}")
|
||||
result = self.code_executor.generate_and_run_code(
|
||||
selected_rewritten,
|
||||
context=selected_knowledge,
|
||||
bowei_api_docs=self.bowei_api_docs
|
||||
)
|
||||
logger.info("代码执行完成,返回结果")
|
||||
print("\n访问结果:\n", result)
|
||||
print("-" * 50)
|
||||
else:
|
||||
InteractionHandler.display_rewritten_requests(rewritten_results)
|
||||
choice_index = InteractionHandler.get_user_choice(rewritten_results)
|
||||
if choice_index is not None:
|
||||
selected_rewritten, selected_knowledge = rewritten_results[choice_index]
|
||||
logger.info(f"用户选择访问请求编号 {choice_index + 1},内容:{selected_rewritten}")
|
||||
print(f"\n您选择的访问请求是:\n{selected_rewritten}\n")
|
||||
print(f"相关知识内容:\n{selected_knowledge}\n")
|
||||
confirm = input("请确认是否继续执行该请求(输入n取消,其他继续):").strip().lower()
|
||||
#confirm = input("请确认是否继续执行该请求(输入n取消,其他继续):").strip().lower()
|
||||
confirm = ""
|
||||
if confirm == "n":
|
||||
logger.info("用户取消执行访问请求")
|
||||
#logger.info("用户取消执行访问请求")
|
||||
print("取消执行,您可以重新输入问题。\n" + "-"*50)
|
||||
else:
|
||||
logger.info("用户确认执行访问请求")
|
||||
#logger.info("用户确认执行访问请求")
|
||||
result = self.code_executor.generate_and_run_code(
|
||||
selected_rewritten,
|
||||
context=selected_knowledge,
|
||||
@@ -148,7 +182,3 @@ class DialogManager:
|
||||
logger.info("代码执行完成,返回结果")
|
||||
print("\n访问结果:\n", result)
|
||||
print("-" * 50)
|
||||
break
|
||||
else:
|
||||
logger.warning(f"用户输入无效选择:{choice}")
|
||||
print("输入无效,请输入有效编号。")
|
||||
|
||||
+35
-42
@@ -21,13 +21,13 @@ class PromptManager:
|
||||
def _init_prompts(self) -> CodeExecutorPrompts:
|
||||
understand_prompt = ChatPromptTemplate.from_template(
|
||||
"""
|
||||
你是一名电力造价业务专家,请基于以下工程文件业务结构,将用户自然语言问题改写成专业查询语句:
|
||||
你是一名电力造价业务专家,请基于以下示意工程文件业务结构,将用户自然语言问题改写成专业查询语句:
|
||||
|
||||
**工程文件业务结构(示例,请勿当真实数据)**:
|
||||
**示意工程文件业务结构**:
|
||||
{business_structure}
|
||||
|
||||
**改写规则**:
|
||||
1. **定位目标对象**:从业务结构中识别核心对象(如 `ProjectDivisionTree`→项目划分树、`FeeScheduleItem`→费用表)。
|
||||
1. **定位目标对象**:仅从示意工程文件业务结构中识别核心对象(如 `ProjectDivisionTree`→项目划分树、`FeeScheduleItem`→费用表)。
|
||||
2. **提取条件**:从用户输入中解析关键条件(如名称、量、类型),用【】标注变量。
|
||||
3. **构建专业语句**:格式为:`在[目标对象]中查找【条件】的项。`
|
||||
- 使用业务术语(如“项目划分项”而非“项目”)。
|
||||
@@ -54,8 +54,8 @@ class PromptManager:
|
||||
|
||||
# 工作流程
|
||||
1. 从用户问题中提取关键信息(节点路径、节点类型、节点名称等)
|
||||
2. 根据"用户问题"和"上下文信息"选择最匹配的"工程数据访问库"中的方法
|
||||
3. 只能生成一个可直接执行的Python函数代码
|
||||
2. 根据"用户问题"和"上下文信息"选择最匹配的"工程数据访问库"中的函数和对象属性
|
||||
3. 生成可直接执行的完全满足用户输入问题要求功能效果的Python函数代码
|
||||
|
||||
# 输出格式(必须严格遵循)
|
||||
def project_get_calculate_function():
|
||||
@@ -65,7 +65,8 @@ def project_get_calculate_function():
|
||||
|
||||
# 执行规则
|
||||
- 参数必须从用户问题或上下文信息中提取
|
||||
- 必须确保生成的代码可以直接执行
|
||||
- 输出代码中必须以def project_get_calculate_function() -> Tuple[bool, Any, Optional[str], Dict[str, Any]]函数作为入口函数
|
||||
- 必须确保生成的代码可以直接执行,代码要注意进行各类错误检查,出错采用抛出异常方式,说明详细信息
|
||||
- 禁止添加任何注释或解释
|
||||
"""
|
||||
)
|
||||
@@ -97,55 +98,47 @@ def project_get_calculate_function():
|
||||
- 禁止在代码后加上```字样
|
||||
|
||||
请输出你修补后的代码:
|
||||
"""
|
||||
)
|
||||
|
||||
""")
|
||||
|
||||
rewrite_prompt_template = ChatPromptTemplate.from_template(
|
||||
"""
|
||||
你是一个AI助手,负责将模糊的用户问题改写成明确的查询。给定一个模糊的用户问题、一条从知识库获取的具体上下文知识以及相关业务背景信息,你需要执行以下任务:
|
||||
您是一个AI查询改写助手。基于给定的原始查询和上下文知识,生成一个精确的改写查询。步骤:
|
||||
1. 从上下文知识的`labels`提取对象类型,翻译为中文。
|
||||
2. 从`properties`选择对象标识:优先用`path`值,若无则用`name`值。
|
||||
3. 智能映射原始查询的属性名称:
|
||||
- 如果属性名称是上下文属性的缩写、省略或同义词,映射到实际属性名称(如“人工费”可能映射到“费率”或“合价含税”)。
|
||||
- 如果无法映射,保留原始名称。
|
||||
4. 保留原始查询的额外操作(如计算指令)。
|
||||
5. 输出格式:“获取[对象标识][对象类型]的[属性]属性,[额外操作]”。
|
||||
|
||||
1. **理解输入**:
|
||||
- 原始问题:用户输入的模糊性问题(字符串)
|
||||
- 上下文知识:从知识库查询获取的一条具体知识(字符串)
|
||||
- 业务背景:当前问题所属的业务场景信息(字符串,如行业、产品类型、业务规则等)
|
||||
示例参考:
|
||||
- 输入:原始问题="查找名称中包含“工程”的项目划分项,并返回其人工费乘以1000的值。", 上下文知识=...
|
||||
- 输出="获取[安装/架空输电线路本体工程/基础工程/基础工程材料工地运输]项目划分项的人工费,并乘以1000的值"
|
||||
|
||||
2. **改写要求**:
|
||||
- 基于**上下文知识**和**业务背景**,将原始问题改写成针对该知识的明确性问题
|
||||
- 严格保留原始问题的核心语义(意图和关键信息不变)
|
||||
- 输出应是一个完整的自然语言问题,需同时满足:
|
||||
✅ 如果是一个查找语句则改写成一个获取语句
|
||||
✅ 直接关联上下文知识的具体内容
|
||||
✅ 符合业务背景的专业表述(如使用行业术语)
|
||||
- 禁止添加原始问题未提及的额外假设
|
||||
|
||||
3. **输出格式**:
|
||||
- 仅输出改写后的明确性问题(单个字符串)
|
||||
|
||||
4. **示例**:
|
||||
用户输入:查找名称中包含“工程”的项目划分项,并返回其人工费乘以1000的值。
|
||||
上下文知识: <Node element_id='4:f0ca7d86-42db-43e9-87d5-8b942b8972a9:1080' labels=frozenset({{'ProjectDivisionItem'}}) properties={{'序号': '1.1', '专业类型': '线路', 'GUID': '{{3CA66E10-EBB1-41E4-8B9B-23EFEA1B0E98}}', '取费表': '线路取费表', '颜色标记': '标记:16777215;', '资源库名称': '预算 第四册 架空输电线路工程(2018年版)', 'type': '项目划分', '最小资源库编码': '10', 'path': '安装/架空输电线路本体工程/基础工程/基础工程材料工地运输', '合价含税': '250340.309025', '取费表id': '3_1', 'name': '基础工程材料工地运输', 'notCheck': '1', '费率': '0'}}>
|
||||
改写输出:获取[安装/架空输电线路本体工程/基础工程/基础工程材料工地运输]项目划分项的人工费,并乘以1000的值
|
||||
|
||||
现在请处理以下输入:
|
||||
- 上下文知识:"{context}"
|
||||
- 业务背景(示例,请勿当真实数据):"{business_structure}"
|
||||
现在,处理以下输入:
|
||||
- 原始问题:{user_input}
|
||||
- 上下文知识:{context}
|
||||
""")
|
||||
|
||||
cypher_conversion_prompt = ChatPromptTemplate.from_template(
|
||||
"""
|
||||
你是一名电力造价业务专家,负责将用户自然语言问题中需要访问的对象识别出来,并生成针对该对象的NEO4J知识图谱的Cypher查询语句,获取该对象的全部信息。知识图谱基于以下工程文件业务结构构建。
|
||||
你是一名电力造价业务专家,负责将用户自然语言问题中需要访问的对象识别出来,并生成针对该对象的NEO4J知识图谱的Cypher查询语句,获取该对象的全部信息。知识图谱基于从文件“获取[安装/架空输电线路本体工程]项目划分项的单位.”中读取的层级关联结构构建。该文件包含用反斜杠分割的多个字符串(如“安装/架空输电线路本体工程”),每个字符串表示一个完整的层级路径,路径部分用斜杠分隔,对应于知识图谱中ProjectDivisionItem节点的层级关系。路径映射规则:每个路径部分(如“安装”)是一个ProjectDivisionItem节点,父子关系通过关系类型`:CHILD_OF`连接,形成从根节点到叶节点的层级结构。
|
||||
|
||||
业务结构(示例,请勿当真实数据):
|
||||
示例业务结构:
|
||||
{business_structure}
|
||||
|
||||
用户问题:
|
||||
{user_input}
|
||||
|
||||
改写规则:
|
||||
识别目标对象:从业务结构中识别用户问题中需要获取的核心对象类型(如 ProjectDivisionItem、ProjectQuantity、Fee、FeeCollection 等)。对象类型必须精确映射到结构中的节点标签(例如,使用 ProjectDivisionItem 而非“项目划分项”)。
|
||||
提取查询条件:从用户输入中解析关键条件(如名称、量、类型、值、层级等),条件应基于对象属性(如 name、quantity、type)。
|
||||
识别目标层级: 从用户输入中解析层级路径,例如字符串“安装/架空输电线路本体工程”映射为:根节点(name: '安装') -[:CHILD_OF]-> 子节点(name: '架空输电线路本体工程')。知识图谱节点标签包括ProjectDivisionItem、ProjectQuantity、Fee、FeeCollection等。
|
||||
识别目标对象:仅从业务结构中识别用户问题中需要获取的核心对象类型(如 ProjectDivisionItem、ProjectQuantity、Fee、FeeCollection 等)。对象类型必须精确映射到结构中的节点标签(例如,使用 ProjectDivisionItem 而非“项目划分项”)。
|
||||
提取查询条件:从用户输入中解析关键条件(如名称、量、类型、值、层级等),条件应基于对象属性(如 name、quantity、type),注意区分查找子串和相等的区别。如果条件涉及层级路径(如路径中包含特定部分),需提取路径部分作为条件(例如,用户提到“架空输电线路”时,解析为路径匹配)。
|
||||
如果用户指定层级(如“叶节点”),需在条件中体现(例如,添加 WHERE item.isLeaf = true)。
|
||||
忽略任何计算、转换或后处理要求(如“乘以1000”),只关注获取原始数据对象或属性。
|
||||
构建Cypher查询:生成一个Cypher查询语句,格式为:
|
||||
|
||||
MATCH 子句:匹配目标对象节点,并应用条件,不带变量。
|
||||
MATCH 子句:匹配目标对象节点,并应用条件。如果涉及层级路径,使用路径匹配(如 MATCH path = (item1:ProjectDivisionItem)-[:CHILD_OF*]->(item2:ProjectDivisionItem) WHERE item1.name = '安装' AND item2.name = '架空输电线路本体工程')。变量仅用于目标节点和必要路径节点。
|
||||
WHERE 子句:包含提取的条件(使用变量或具体值)。
|
||||
RETURN 子句:必须返回对象(如 RETURN item),不能包含对象属性、函数(如乘法、SUM)。
|
||||
LIMIT 子句: 最多返回5条。
|
||||
@@ -153,9 +146,9 @@ LIMIT 子句: 最多返回5条。
|
||||
查询应简洁,只获取数据,不执行计算。
|
||||
输出格式:直接输出最终Cypher查询语句,不添加解释或额外文本。
|
||||
|
||||
用户问题:
|
||||
{user_input}
|
||||
Cypher查询语句:
|
||||
**示例**:
|
||||
用户问题:查找一下名称中包含工程的项目划分
|
||||
Cypher查询语句:MATCH (item:ProjectDivisionItem)\nWHERE item.name CONTAINS '工程'\nRETURN item\nLIMIT 5
|
||||
""")
|
||||
|
||||
return CodeExecutorPrompts(
|
||||
|
||||
Reference in New Issue
Block a user