提交全部代码

This commit is contained in:
2025-06-24 08:34:57 +08:00
commit 5a1c74b356
61 changed files with 9123341 additions and 0 deletions
+65
View File
@@ -0,0 +1,65 @@
import logging
from langchain_core.output_parsers import StrOutputParser
from langchain_experimental.utilities import PythonREPL
from langchain_core.tools import Tool
from langchain_experimental.tools import PythonREPLTool
from project_implementation import ProjectBuilder
# Module-level logger for the code-execution component, named under the
# "BoweiAgent" logger hierarchy.
logger = logging.getLogger("BoweiAgent.CodeExecutor")
class CodeExecutor:
    """Generate Python code with an LLM, run it in a Python REPL, and retry after LLM repairs.

    The instance keeps the prompt collection, the LLM client, the retry budget
    and a ``StrOutputParser`` that is applied to every LLM response.
    """

    def __init__(self, prompts, llm_client, max_retries=3):
        """Store the collaborators used by the other methods.

        Args:
            prompts: Object exposing ``code_gen_prompt`` and ``code_fix_prompt``.
            llm_client: Object exposing ``invoke(messages)``.
            max_retries: Maximum number of execution attempts.
        """
        self.llm_client = llm_client
        self.prompts = prompts
        self.max_retries = max_retries
        self.output_parser = StrOutputParser()

    @staticmethod
    def _as_source(code) -> str:
        """Return plain source text from a message-like object or a string."""
        # Accept both an object carrying ``.content`` and an already-plain
        # string, so execution does not depend on what the parser returned.
        content = getattr(code, "content", code)
        return content if isinstance(content, str) else str(content)

    def generate_code(self, user_request: str, context: str = '', bowei_api_docs: str = '') -> str:
        """Ask the LLM to produce code for *user_request*.

        Args:
            user_request: The access request to turn into code.
            context: Extra knowledge passed into the prompt.
            bowei_api_docs: API documentation text passed into the prompt.

        Returns:
            Whatever ``self.output_parser.parse`` yields for the LLM response.
            NOTE(review): the original runner read ``.content`` from this value,
            so it is presumably a chat message rather than a plain ``str`` -
            confirm against the installed LangChain version.
        """
        logger.info(f"开始生成代码,访问请求:{user_request}")
        prompt = self.prompts.code_gen_prompt.format_prompt(
            user_request=user_request,
            context=context,
            bowei_api_docs=bowei_api_docs,
        )
        response = self.llm_client.invoke(prompt.to_messages())
        code = self.output_parser.parse(response)
        logger.debug(f"生成的代码内容:\n{code}")
        return code

    def fix_code(self, code: str, error: str) -> str:
        """Ask the LLM to repair *code* given the *error* text from its execution.

        Returns:
            Whatever ``self.output_parser.parse`` yields for the LLM response.
        """
        logger.warning(f"代码执行出错,开始修复。错误信息:{error}")
        prompt = self.prompts.code_fix_prompt.format_prompt(code=code, error=error)
        response = self.llm_client.invoke(prompt.to_messages())
        fixed_code = self.output_parser.parse(response)
        logger.debug(f"修复后的代码内容:\n{fixed_code}")
        return fixed_code

    def generate_and_run_code(self, user_request: str, context: str = '', bowei_api_docs: str = '') -> str:
        """Generate code, execute it, and retry with LLM repairs on exceptions.

        Each attempt runs in a fresh ``PythonREPL`` whose globals expose
        ``ProjectBuilder`` and the ``project_implementation`` / ``project``
        modules.

        Returns:
            The REPL output of the first attempt that does not raise, otherwise
            a failure message that includes the last error text. When
            ``max_retries`` is zero or negative no attempt is made and the
            failure message carries an empty error text.
        """
        code = self.generate_code(user_request, context, bowei_api_docs)
        logger.info("开始执行生成的代码")
        # Bound up front so the failure message below is always well-defined,
        # even when the loop body never runs.
        error_msg = ""
        for attempt in range(self.max_retries):
            try:
                # Custom namespace made visible to the executed code.
                namespace = {
                    "ProjectBuilder": ProjectBuilder,
                    "project_implementation": __import__("project_implementation"),
                    "project": __import__("project"),
                }
                python_repl = PythonREPL(globals=namespace)
                # NOTE(review): PythonREPL.run may report execution errors in
                # its returned text instead of raising; if so, this retry path
                # is never taken - confirm for the installed version.
                result = python_repl.run(self._as_source(code))
                logger.info(f"代码执行成功,返回结果长度:{len(result)}")
                return result
            except Exception as e:
                error_msg = str(e)
                logger.error(f"代码执行异常,尝试第 {attempt+1} 次修复。异常信息:{error_msg}")
                # Hand the repair prompt the source text, not the message object.
                code = self.fix_code(self._as_source(code), error_msg)
        logger.error(f"代码执行失败,超过最大重试次数 {self.max_retries}")
        return f"代码执行失败,超过最大重试次数 {self.max_retries}\n最后一次错误信息:\n{error_msg}"
+29
View File
@@ -0,0 +1,29 @@
import yaml
class Config:
    """Read-only accessor for settings loaded from a YAML configuration file."""

    def __init__(self, path="config.yaml"):
        """Load the YAML document stored at *path*.

        An empty file makes ``yaml.safe_load`` return ``None``; that case is
        mapped to an empty dict so every accessor keeps working and returns
        its default.

        Args:
            path: Location of the YAML file.
        """
        with open(path, "r", encoding="utf-8") as f:
            self._config = yaml.safe_load(f) or {}

    def get(self, key, default=None):
        """Return the top-level value for *key*, or *default* when absent."""
        return self._config.get(key, default)

    @property
    def openai(self):
        """The ``openai`` section, or an empty dict."""
        return self._config.get("openai", {})

    @property
    def bowei_api_docs_path(self):
        """Configured API docs path, defaulting to ``./data/bowei_api_docs.md``."""
        return self._config.get("bowei_api_docs_path", "./data/bowei_api_docs.md")

    @property
    def business_object_structure_path(self):
        """Configured structure file path, defaulting to ``./data/business_object_structure.md``."""
        return self._config.get("business_object_structure_path", "./data/business_object_structure.md")

    @property
    def neo4j_conf(self):
        """The ``neo4j`` section, or an empty dict."""
        return self._config.get("neo4j", {})

    @property
    def embedding(self):
        """The ``embedding`` section, or an empty dict."""
        return self._config.get("embedding", {})
+152
View File
@@ -0,0 +1,152 @@
# src/dialog_manager.py
import logging
from langchain.schema import SystemMessage, HumanMessage
import asyncio
# Module-level logger for the dialog manager, named under the "BoweiAgent"
# logger hierarchy.
logger = logging.getLogger("BoweiAgent.DialogManager")
class DialogManager:
    """Console dialog loop: rewrite the user's question, let them pick a rewrite, then run generated code."""

    def __init__(
        self,
        llm_client,
        business_structure: str,
        bowei_api_docs: str,
        code_executor,
        knowledge_retriever,
        prompt_manager,
    ):
        """Keep references to the collaborators; prompts are taken from ``prompt_manager.prompts``."""
        self.prompts = prompt_manager.prompts
        self.knowledge_retriever = knowledge_retriever
        self.code_executor = code_executor
        self.bowei_api_docs = bowei_api_docs
        self.business_structure = business_structure
        self.llm_client = llm_client

    async def _stream_and_echo(self, messages) -> str:
        """Stream the LLM reply for *messages*, echo each chunk to stdout, and return the raw text."""
        collected = ""
        async for piece in self.llm_client.stream(messages):
            print(piece.content, end="", flush=True)
            collected += piece.content
        # Terminate the echoed line once the stream is exhausted.
        print()
        return collected

    def retrieve_relevant_docs(self, user_input: str):
        """Return the ``page_content`` of every document the retriever yields for *user_input*."""
        logger.debug(f"开始检索知识库,用户输入:{user_input}")
        found = self.knowledge_retriever.get_relevant_documents(user_input)
        logger.debug(f"检索到 {len(found)} 条相关文档")
        return [item.page_content for item in found]

    async def convert_question_to_cypher(self, user_input: str) -> str:
        """Have the LLM turn *user_input* into Cypher text and return it stripped."""
        messages = self.prompts.cypher_conversion_prompt.format_prompt(
            business_structure=self.business_structure,
            user_input=user_input,
        ).to_messages()
        cypher_text = (await self._stream_and_echo(messages)).strip()
        logger.debug(f"生成的Cypher查询语句:{cypher_text}")
        return cypher_text

    async def rewrite_question_with_query_result(self, user_input: str, query_result: str) -> str:
        """Rewrite *user_input* into an access request, using *query_result* as prompt context."""
        base_messages = self.prompts.rewrite_prompt_template.format_prompt(
            business_structure=self.business_structure,
            context=query_result,
        ).to_messages()
        # The user's question is sent as a trailing human message.
        messages = [*base_messages, HumanMessage(content=user_input)]
        rewritten = (await self._stream_and_echo(messages)).strip()
        logger.debug(f"重写后的用户问题:{rewritten}")
        return rewritten

    async def understand_user_question_stream(self, user_input: str):
        """Produce candidate rewrites of *user_input*.

        Returns:
            A list of ``(rewritten_text, knowledge_text)`` tuples: one per
            retrieved document (at most five), or a single tuple with empty
            knowledge when the retriever returns nothing.
        """
        logger.info(f"理解用户问题(流式):{user_input}")
        # Step 1: question -> Cypher; step 2: query, keeping at most five hits.
        cypher_text = await self.convert_question_to_cypher(user_input)
        top_docs = self.knowledge_retriever.get_relevant_documents(cypher_text)[:5]

        if top_docs:
            candidates = []
            for position, doc in enumerate(top_docs, start=1):
                print(f"\n{position}条相关文档改写结果(流式):")
                rewritten = await self.rewrite_question_with_query_result(user_input, doc.page_content)
                candidates.append((rewritten, doc.page_content))
            return candidates

        # No hits: fall back to rewriting from the business structure alone.
        logger.info("查询无结果,直接用业务结构改写")
        messages = self.prompts.understand_prompt.format_prompt(
            business_structure=self.business_structure,
            user_input=user_input,
        ).to_messages()
        fallback_text = await self._stream_and_echo(messages)
        return [(fallback_text.strip(), "")]

    async def run_async(self, pre_input: str = None):
        """Run the interactive loop until the user types ``exit``.

        Args:
            pre_input: Optional question processed before prompting on stdin.
        """
        logger.info("启动对话管理器,等待用户输入")
        print("欢迎使用博微造价工程数据访问系统,输入 exit 退出。")
        pending = [pre_input] if pre_input else []

        while True:
            if pending:
                question = pending.pop(0)
                print(f"预输入问题:{question}")
            else:
                question = input("请输入您的问题:")

            if question.strip().lower() == "exit":
                logger.info("用户退出程序")
                print("退出程序。")
                break

            candidates = await self.understand_user_question_stream(question)
            print("\n系统为您理解并改写了以下访问请求,请选择:")
            for number, (rewritten, _knowledge) in enumerate(candidates, start=1):
                print(f"{number}: {rewritten}")
            print("0: 重新输入问题")

            while True:
                # An empty answer selects the first candidate.
                choice = input("请输入编号选择(0重新输入,默认1):").strip() or "1"

                if choice == "0":
                    logger.info("用户选择重新输入问题")
                    print("请重新输入您的问题。\n" + "-"*50)
                    pending.clear()
                    break

                is_valid = choice.isdigit() and 1 <= int(choice) <= len(candidates)
                if not is_valid:
                    logger.warning(f"用户输入无效选择:{choice}")
                    print("输入无效,请输入有效编号。")
                    continue

                selected_rewritten, selected_knowledge = candidates[int(choice) - 1]
                logger.info(f"用户选择访问请求编号 {choice},内容:{selected_rewritten}")
                print(f"\n您选择的访问请求是:\n{selected_rewritten}\n")
                print(f"相关知识内容:\n{selected_knowledge}\n")

                confirm = input("请确认是否继续执行该请求(输入n取消,其他继续):").strip().lower()
                if confirm != "n":
                    logger.info("用户确认执行访问请求")
                    result = self.code_executor.generate_and_run_code(
                        selected_rewritten,
                        context=selected_knowledge,
                        bowei_api_docs=self.bowei_api_docs,
                    )
                    logger.info("代码执行完成,返回结果")
                    print("\n访问结果:\n", result)
                    print("-" * 50)
                else:
                    logger.info("用户取消执行访问请求")
                    print("取消执行,您可以重新输入问题。\n" + "-"*50)
                # Either way, go back to asking for a new question.
                break
+7
View File
@@ -0,0 +1,7 @@
def load_file(path: str) -> str:
    """Return the UTF-8 text stored at *path*, or ``""`` when reading fails.

    This is a best-effort loader: any exception raised while opening or
    reading is reported on stdout and swallowed.
    """
    text = ""
    try:
        with open(path, "r", encoding="utf-8") as source:
            text = source.read()
    except Exception as exc:
        print(f"加载文件失败 {path}: {exc}")
    return text
+25
View File
@@ -0,0 +1,25 @@
import os
from langchain_openai import OpenAIEmbeddings
class EmbeddingClient:
    """Thin wrapper around ``OpenAIEmbeddings`` configured from a settings dict."""

    def __init__(self, embedding_config: dict):
        """Export OpenAI settings to the process environment and build the client.

        ``api_key`` and ``api_base`` are exported only when present and truthy;
        ``api_type`` is always exported and defaults to ``"openai"``.
        """
        optional_settings = (
            ("OPENAI_API_KEY", "api_key"),
            ("OPENAI_API_BASE", "api_base"),
        )
        for env_name, conf_key in optional_settings:
            value = embedding_config.get(conf_key)
            if value:
                os.environ[env_name] = value

        os.environ["OPENAI_API_TYPE"] = embedding_config.get("api_type", "openai")

        # Pass the model only when configured so the library default applies otherwise.
        chosen_model = embedding_config.get("model_name")
        self.client = OpenAIEmbeddings(model=chosen_model) if chosen_model else OpenAIEmbeddings()

    def embed_documents(self, texts):
        """Delegate to the wrapped client's ``embed_documents``."""
        return self.client.embed_documents(texts)

    def embed_query(self, text):
        """Delegate to the wrapped client's ``embed_query``."""
        return self.client.embed_query(text)
+34
View File
@@ -0,0 +1,34 @@
# src/llm_client.py
import os
import getpass
from langchain_openai import ChatOpenAI
class LLMClient:
    """Wrapper around ``ChatOpenAI`` exposing a synchronous call and an async stream."""

    def __init__(self, openai_config: dict):
        """Export OpenAI settings to the process environment and build the chat model.

        The API key comes from the config, else from an existing
        ``OPENAI_API_KEY`` variable, else from an interactive ``getpass`` prompt.
        """
        configured_key = openai_config.get("api_key")
        if configured_key:
            os.environ["OPENAI_API_KEY"] = configured_key
        elif "OPENAI_API_KEY" not in os.environ:
            os.environ["OPENAI_API_KEY"] = getpass.getpass("请输入您的OpenAI API Key")

        base_url = openai_config.get("api_base")
        if base_url:
            os.environ["OPENAI_API_BASE"] = base_url

        os.environ["OPENAI_API_TYPE"] = openai_config.get("api_type", "openai")

        # Streaming is switched on for the model.
        self.llm = ChatOpenAI(
            model_name=openai_config.get("model_name", "gpt-4o-mini"),
            temperature=0,
            streaming=True,
        )

    def invoke(self, messages):
        """Synchronous call: return the wrapped model's ``invoke`` result."""
        return self.llm.invoke(messages)

    def stream(self, messages):
        """Return the wrapped model's ``astream(messages)`` result for ``async for`` consumption."""
        return self.llm.astream(messages)
+34
View File
@@ -0,0 +1,34 @@
# src/neo4j_raw_retriever.py
from typing import List
from langchain.schema import Document
from neo4j import GraphDatabase
import logging
# Module-level logger for the raw Neo4j retriever, named under the "BoweiAgent"
# logger hierarchy.
logger = logging.getLogger("BoweiAgent.Neo4jRawRetriever")
class Neo4jRawRetriever:
    """Run raw Cypher text against Neo4j and wrap each returned record as a ``Document``."""

    def __init__(self, neo4j_conf: dict):
        """Read ``uri`` / ``username`` / ``password`` from *neo4j_conf* and create the driver."""
        self.uri = neo4j_conf.get("uri")
        self.username = neo4j_conf.get("username")
        self.password = neo4j_conf.get("password")
        credentials = (self.username, self.password)
        self.driver = GraphDatabase.driver(self.uri, auth=credentials)

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

    def get_relevant_documents(self, cypher_query: str) -> list[Document]:
        """Execute *cypher_query* and return one ``Document`` per record.

        For each record the value bound to ``n`` is used; when it is missing
        or falsy, the record's first value is used instead. ``page_content``
        is ``str(value)``; ``metadata`` is ``dict(value.items())`` when the
        value has an ``items`` attribute, otherwise an empty dict.
        """
        # NOTE(review): the query text is executed verbatim; callers must make
        # sure it is trusted and read-only.
        documents = []
        with self.driver.session() as session:
            for record in session.run(cypher_query):
                value = record.get("n") or record.values()[0]
                properties = dict(value.items()) if hasattr(value, "items") else {}
                documents.append(Document(page_content=str(value), metadata=properties))
        return documents
+23
View File
@@ -0,0 +1,23 @@
from langchain_neo4j import Neo4jVector
from langchain_openai import OpenAIEmbeddings
class Neo4jKnowledgeRetriever:
    """Similarity search over an existing Neo4j vector index."""

    def __init__(self, neo4j_conf: dict, embedding_client):
        """Attach to the existing index described by *neo4j_conf*.

        Args:
            neo4j_conf: Mapping with ``uri``, ``username``, ``password`` and an
                optional ``index_name`` (defaults to ``"vector"``).
            embedding_client: Object whose ``client`` attribute is handed to
                ``Neo4jVector.from_existing_index`` as the embedding.
        """
        connection = {
            "url": neo4j_conf.get("uri"),
            "username": neo4j_conf.get("username"),
            "password": neo4j_conf.get("password"),
            "index_name": neo4j_conf.get("index_name", "vector"),
        }
        # Only the vector index is used; the keyword index and hybrid search
        # are not passed to the store.
        self.vectorstore = Neo4jVector.from_existing_index(embedding_client.client, **connection)

    def get_relevant_documents(self, query: str, k: int = 5):
        """Return the store's ``similarity_search`` result for *query*, limited to *k* hits."""
        hits = self.vectorstore.similarity_search(query, k=k)
        return hits
+111
View File
@@ -0,0 +1,111 @@
# src/prompt_manager.py
from dataclasses import dataclass
from langchain.prompts import ChatPromptTemplate
from langchain.schema import SystemMessage, HumanMessage
@dataclass
class CodeExecutorPrompts:
    """Bundle of the chat prompt templates built by ``PromptManager._init_prompts``."""

    # Rewrites a user question using only the business structure.
    understand_prompt: ChatPromptTemplate
    # Asks for Python code given the request, context and API docs.
    code_gen_prompt: ChatPromptTemplate
    # Asks for a repaired version of code given its error text.
    code_fix_prompt: ChatPromptTemplate
    # Rewrites a user question using business structure plus retrieved knowledge.
    rewrite_prompt_template: ChatPromptTemplate
    # Converts a user question into a Cypher query.
    cypher_conversion_prompt: ChatPromptTemplate
class PromptManager:
    """Build and hold the prompt templates used by the agent.

    The templates are created once in the constructor and exposed as
    ``self.prompts`` (a ``CodeExecutorPrompts`` instance).
    """

    def __init__(self):
        self.prompts = self._init_prompts()

    def _init_prompts(self) -> CodeExecutorPrompts:
        """Create every prompt template and return them as one bundle.

        Template variables:
            understand_prompt: ``business_structure``, ``user_input``
            code_gen_prompt: ``user_request``, ``context``, ``bowei_api_docs``
            code_fix_prompt: ``code``, ``error``
            rewrite_prompt_template: ``business_structure``, ``context``
            cypher_conversion_prompt: ``business_structure``, ``user_input``

        The prompt text lines sit at column 0 on purpose: they are part of the
        string literals and are sent to the model as written.
        """
        understand_prompt = ChatPromptTemplate.from_template(
            """
你是一名电力造价业务专家,请基于以下工程文件业务结构,将用户自然语言问题改写成专业查询语句:
**工程文件业务结构**
{business_structure}
**改写规则**
1. **定位目标对象**:从业务结构中识别核心对象(如 `ProjectDivisionTree`→项目划分树、`FeeScheduleItem`→费用表)。
2. **提取条件**:从用户输入中解析关键条件(如名称、量、类型),用【】标注变量。
3. **构建专业语句**:格式为:`在[目标对象]中查找【条件】的项。`
- 使用业务术语(如“项目划分项”而非“项目”)。
- 条件需明确属性(如【名称】、【量】、【类型】)。
4. **精确映射结构**:若用户查询层级(如“叶节点”),需在条件中体现。
**用户输入**{user_input}
**改写输出**:(仅输出改写后的语句)
""")

        # NOTE(review): this template asks the model to output only a function
        # definition and never to call it - confirm that whatever executes the
        # generated code actually invokes neo4j_find_function.
        code_gen_prompt = ChatPromptTemplate.from_template("""
你是一个专业的Python工程师。我会给你一个用户问题,你需要将其转换为对应的Python代码
用户问题:
{user_request}
上下文信息:
{context}
工程数据访问库:
{bowei_api_docs}
# 工作流程
1. 从用户问题中提取关键信息(节点路径、节点类型、节点名称等)
2. 根据"用户问题""上下文信息"选择最匹配的"工程数据访问库"中的方法
3. 生成可直接执行的Python代码
# 代码模板(必须严格遵循)
def neo4j_find_function():
project = ProjectBuilder.build()
status, data, error, helper_info = project.[SELECTED_METHOD]([PARAMETERS])
return status, data, error, helper_info
# 执行规则
- 参数必须从用户问题或上下文信息中提取
- 必须确保生成的代码可以直接执行
- 禁止修改代码模板结构
- 禁止添加任何注释或解释
# 输出格式
def neo4j_find_function():
project = ProjectBuilder.build()
status, data, error, helper_info = project.[SELECTED_METHOD]([PARAMETERS])
return status, data, error, helper_info
""")

        # (role, template) tuples are used instead of SystemMessage/HumanMessage
        # instances: ready-made message objects are passed through as-is, so
        # "{code}" and "{error}" would reach the model as literal text. Tuples
        # are treated as templates and get the variables substituted.
        code_fix_prompt = ChatPromptTemplate.from_messages([
            ("system", "你是一个专业的Python开发者。下面是之前生成的代码和执行时出现的错误,请修复代码中的错误,确保代码语法正确且能完成访问博微造价工程数据的任务。只输出修复后的Python代码,不要添加解释。"),
            ("human", "原始代码:\n{code}\n\n错误信息:\n{error}\n"),
        ])

        rewrite_prompt_template = ChatPromptTemplate.from_template(
            """你是一个专业的工程业务助理,结合以下工程业务结构信息和相关知识:
{business_structure}
相关知识内容:
{context}
请根据用户的问题,结合上述信息,理解并改写成一个针对工程数据的访问请求(简洁明了的描述)。
请只输出改写后的访问请求文本,不要多余解释。"""
        )

        cypher_conversion_prompt = ChatPromptTemplate.from_template(
            """
你是一个Neo4j专家。请将用户的自然语言问题转换成一个有效的Cypher查询语句,查询知识图谱中相关信息。
只返回Cypher语句,不要任何解释,最多返回5条。
业务结构信息:
{business_structure}
用户问题:
{user_input}
Cypher查询语句:
"""
        )

        return CodeExecutorPrompts(
            understand_prompt=understand_prompt,
            code_gen_prompt=code_gen_prompt,
            code_fix_prompt=code_fix_prompt,
            rewrite_prompt_template=rewrite_prompt_template,
            cypher_conversion_prompt=cypher_conversion_prompt,
        )