提交全部代码
This commit is contained in:
@@ -0,0 +1,65 @@
|
||||
import logging
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_experimental.utilities import PythonREPL
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_experimental.tools import PythonREPLTool
|
||||
from project_implementation import ProjectBuilder
|
||||
|
||||
logger = logging.getLogger("BoweiAgent.CodeExecutor")
|
||||
|
||||
class CodeExecutor:
|
||||
def __init__(self, prompts, llm_client, max_retries=3):
|
||||
self.llm_client = llm_client
|
||||
self.prompts = prompts
|
||||
self.max_retries = max_retries
|
||||
self.output_parser = StrOutputParser()
|
||||
|
||||
def generate_code(self, user_request: str, context: str = '', bowei_api_docs: str = '') -> str:
|
||||
logger.info(f"开始生成代码,访问请求:{user_request}")
|
||||
prompt = self.prompts.code_gen_prompt.format_prompt(user_request=user_request, context=context, bowei_api_docs=bowei_api_docs)
|
||||
|
||||
response = self.llm_client.invoke(prompt.to_messages())
|
||||
code = self.output_parser.parse(response)
|
||||
logger.debug(f"生成的代码内容:\n{code}")
|
||||
return code
|
||||
|
||||
def fix_code(self, code: str, error: str) -> str:
|
||||
logger.warning(f"代码执行出错,开始修复。错误信息:{error}")
|
||||
prompt = self.prompts.code_fix_prompt.format_prompt(code=code, error=error)
|
||||
response = self.llm_client.invoke(prompt.to_messages())
|
||||
fixed_code = self.output_parser.parse(response)
|
||||
logger.debug(f"修复后的代码内容:\n{fixed_code}")
|
||||
return fixed_code
|
||||
|
||||
def generate_and_run_code(self, user_request: str, context: str = '', bowei_api_docs: str = '') -> str:
|
||||
code = self.generate_code(user_request, context, bowei_api_docs)
|
||||
logger.info("开始执行生成的代码")
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# 自定义命名空间
|
||||
namespace = {
|
||||
"ProjectBuilder": ProjectBuilder,
|
||||
"project_implementation": __import__("project_implementation"),
|
||||
"project": __import__("project"),
|
||||
}
|
||||
|
||||
# 初始化 REPL 并传入命名空间
|
||||
python_repl = PythonREPL(globals=namespace)
|
||||
repl_tool = Tool(
|
||||
name="python_repl",
|
||||
description="...",
|
||||
func=python_repl.run,
|
||||
)
|
||||
|
||||
result = repl_tool.func(code.content) # 代码将在自定义命名空间中执行
|
||||
logger.info(f"代码执行成功,返回结果长度:{len(result)}")
|
||||
return result
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"代码执行异常,尝试第 {attempt+1} 次修复。异常信息:{error_msg}")
|
||||
|
||||
code = self.fix_code(code, error_msg)
|
||||
|
||||
logger.error(f"代码执行失败,超过最大重试次数 {self.max_retries}")
|
||||
return f"代码执行失败,超过最大重试次数 {self.max_retries}。\n最后一次错误信息:\n{error_msg}"
|
||||
@@ -0,0 +1,29 @@
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
def __init__(self, path="config.yaml"):
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
self._config = yaml.safe_load(f)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._config.get(key, default)
|
||||
|
||||
@property
|
||||
def openai(self):
|
||||
return self._config.get("openai", {})
|
||||
|
||||
@property
|
||||
def bowei_api_docs_path(self):
|
||||
return self._config.get("bowei_api_docs_path", "./data/bowei_api_docs.md")
|
||||
|
||||
@property
|
||||
def business_object_structure_path(self):
|
||||
return self._config.get("business_object_structure_path", "./data/business_object_structure.md")
|
||||
|
||||
@property
|
||||
def neo4j_conf(self):
|
||||
return self._config.get("neo4j", {})
|
||||
|
||||
@property
|
||||
def embedding(self):
|
||||
return self._config.get("embedding", {})
|
||||
@@ -0,0 +1,152 @@
|
||||
# src/dialog_manager.py
|
||||
|
||||
import logging
|
||||
from langchain.schema import SystemMessage, HumanMessage
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger("BoweiAgent.DialogManager")
|
||||
|
||||
class DialogManager:
|
||||
def __init__(
|
||||
self,
|
||||
llm_client,
|
||||
business_structure: str,
|
||||
bowei_api_docs: str,
|
||||
code_executor,
|
||||
knowledge_retriever,
|
||||
prompt_manager,
|
||||
):
|
||||
self.llm_client = llm_client
|
||||
self.business_structure = business_structure
|
||||
self.bowei_api_docs = bowei_api_docs
|
||||
self.code_executor = code_executor
|
||||
self.knowledge_retriever = knowledge_retriever
|
||||
self.prompts = prompt_manager.prompts
|
||||
|
||||
def retrieve_relevant_docs(self, user_input: str):
|
||||
logger.debug(f"开始检索知识库,用户输入:{user_input}")
|
||||
docs = self.knowledge_retriever.get_relevant_documents(user_input)
|
||||
logger.debug(f"检索到 {len(docs)} 条相关文档")
|
||||
return [doc.page_content for doc in docs]
|
||||
|
||||
async def convert_question_to_cypher(self, user_input: str) -> str:
|
||||
prompt = self.prompts.cypher_conversion_prompt.format_prompt(
|
||||
business_structure=self.business_structure,
|
||||
user_input=user_input
|
||||
)
|
||||
messages = prompt.to_messages()
|
||||
cypher_query = ""
|
||||
async for chunk in self.llm_client.stream(messages):
|
||||
print(chunk.content, end="", flush=True)
|
||||
cypher_query += chunk.content
|
||||
print()
|
||||
logger.debug(f"生成的Cypher查询语句:{cypher_query.strip()}")
|
||||
return cypher_query.strip()
|
||||
|
||||
async def rewrite_question_with_query_result(self, user_input: str, query_result: str) -> str:
|
||||
prompt = self.prompts.rewrite_prompt_template.format_prompt(
|
||||
business_structure=self.business_structure,
|
||||
context=query_result
|
||||
)
|
||||
messages = prompt.to_messages()
|
||||
messages.append(HumanMessage(content=user_input))
|
||||
|
||||
result = ""
|
||||
async for chunk in self.llm_client.stream(messages):
|
||||
print(chunk.content, end="", flush=True)
|
||||
result += chunk.content
|
||||
print()
|
||||
logger.debug(f"重写后的用户问题:{result.strip()}")
|
||||
return result.strip()
|
||||
|
||||
async def understand_user_question_stream(self, user_input: str):
|
||||
logger.info(f"理解用户问题(流式):{user_input}")
|
||||
|
||||
# 1. 转换用户问题为Cypher查询
|
||||
cypher_query = await self.convert_question_to_cypher(user_input)
|
||||
|
||||
# 2. 执行Cypher查询,最多返回5条
|
||||
docs = self.knowledge_retriever.get_relevant_documents(cypher_query)
|
||||
docs = docs[:5] # 限制最多5条
|
||||
|
||||
if not docs:
|
||||
logger.info("查询无结果,直接用业务结构改写")
|
||||
prompt = self.prompts.understand_prompt.format_prompt(
|
||||
business_structure=self.business_structure,
|
||||
user_input=user_input
|
||||
)
|
||||
messages = prompt.to_messages()
|
||||
result = ""
|
||||
async for chunk in self.llm_client.stream(messages):
|
||||
print(chunk.content, end="", flush=True)
|
||||
result += chunk.content
|
||||
print()
|
||||
return [(result.strip(), "")]
|
||||
|
||||
rewritten_list = []
|
||||
for idx, doc in enumerate(docs, start=1):
|
||||
print(f"\n第{idx}条相关文档改写结果(流式):")
|
||||
rewritten = await self.rewrite_question_with_query_result(user_input, doc.page_content)
|
||||
rewritten_list.append((rewritten, doc.page_content))
|
||||
return rewritten_list
|
||||
|
||||
async def run_async(self, pre_input: str = None):
|
||||
logger.info("启动对话管理器,等待用户输入")
|
||||
print("欢迎使用博微造价工程数据访问系统,输入 exit 退出。")
|
||||
|
||||
if pre_input:
|
||||
user_questions = [pre_input]
|
||||
else:
|
||||
user_questions = []
|
||||
|
||||
while True:
|
||||
if user_questions:
|
||||
user_question = user_questions.pop(0)
|
||||
print(f"预输入问题:{user_question}")
|
||||
else:
|
||||
user_question = input("请输入您的问题:")
|
||||
|
||||
if user_question.strip().lower() == "exit":
|
||||
logger.info("用户退出程序")
|
||||
print("退出程序。")
|
||||
break
|
||||
|
||||
rewritten_results = await self.understand_user_question_stream(user_question)
|
||||
|
||||
print("\n系统为您理解并改写了以下访问请求,请选择:")
|
||||
for idx, (rewritten, knowledge) in enumerate(rewritten_results, start=1):
|
||||
print(f"{idx}: {rewritten}")
|
||||
print("0: 重新输入问题")
|
||||
|
||||
while True:
|
||||
choice = input("请输入编号选择(0重新输入,默认1):").strip()
|
||||
if choice == "":
|
||||
choice = "1"
|
||||
if choice == "0":
|
||||
logger.info("用户选择重新输入问题")
|
||||
print("请重新输入您的问题。\n" + "-"*50)
|
||||
user_questions.clear()
|
||||
break
|
||||
if choice.isdigit() and 1 <= int(choice) <= len(rewritten_results):
|
||||
selected_rewritten, selected_knowledge = rewritten_results[int(choice) - 1]
|
||||
logger.info(f"用户选择访问请求编号 {choice},内容:{selected_rewritten}")
|
||||
print(f"\n您选择的访问请求是:\n{selected_rewritten}\n")
|
||||
print(f"相关知识内容:\n{selected_knowledge}\n")
|
||||
confirm = input("请确认是否继续执行该请求(输入n取消,其他继续):").strip().lower()
|
||||
if confirm == "n":
|
||||
logger.info("用户取消执行访问请求")
|
||||
print("取消执行,您可以重新输入问题。\n" + "-"*50)
|
||||
else:
|
||||
logger.info("用户确认执行访问请求")
|
||||
result = self.code_executor.generate_and_run_code(
|
||||
selected_rewritten,
|
||||
context=selected_knowledge,
|
||||
bowei_api_docs=self.bowei_api_docs
|
||||
)
|
||||
logger.info("代码执行完成,返回结果")
|
||||
print("\n访问结果:\n", result)
|
||||
print("-" * 50)
|
||||
break
|
||||
else:
|
||||
logger.warning(f"用户输入无效选择:{choice}")
|
||||
print("输入无效,请输入有效编号。")
|
||||
@@ -0,0 +1,7 @@
|
||||
def load_file(path: str) -> str:
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
print(f"加载文件失败 {path}: {e}")
|
||||
return ""
|
||||
@@ -0,0 +1,25 @@
|
||||
import os
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
class EmbeddingClient:
|
||||
def __init__(self, embedding_config: dict):
|
||||
api_key = embedding_config.get("api_key")
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
api_base = embedding_config.get("api_base")
|
||||
if api_base:
|
||||
os.environ["OPENAI_API_BASE"] = api_base
|
||||
api_type = embedding_config.get("api_type", "openai")
|
||||
os.environ["OPENAI_API_TYPE"] = api_type
|
||||
|
||||
model_name = embedding_config.get("model_name")
|
||||
if model_name:
|
||||
self.client = OpenAIEmbeddings(model=model_name)
|
||||
else:
|
||||
self.client = OpenAIEmbeddings()
|
||||
|
||||
def embed_documents(self, texts):
|
||||
return self.client.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text):
|
||||
return self.client.embed_query(text)
|
||||
@@ -0,0 +1,34 @@
|
||||
# src/llm_client.py
|
||||
|
||||
import os
|
||||
import getpass
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self, openai_config: dict):
|
||||
api_key = openai_config.get("api_key")
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
else:
|
||||
if "OPENAI_API_KEY" not in os.environ:
|
||||
os.environ["OPENAI_API_KEY"] = getpass.getpass("请输入您的OpenAI API Key:")
|
||||
|
||||
api_base = openai_config.get("api_base")
|
||||
if api_base:
|
||||
os.environ["OPENAI_API_BASE"] = api_base
|
||||
|
||||
api_type = openai_config.get("api_type", "openai")
|
||||
os.environ["OPENAI_API_TYPE"] = api_type
|
||||
|
||||
model_name = openai_config.get("model_name", "gpt-4o-mini")
|
||||
|
||||
# 开启流式
|
||||
self.llm = ChatOpenAI(model_name=model_name, temperature=0, streaming=True)
|
||||
|
||||
def invoke(self, messages):
|
||||
# 同步调用,返回完整响应
|
||||
return self.llm.invoke(messages)
|
||||
|
||||
def stream(self, messages):
|
||||
# 异步流式调用,返回异步生成器
|
||||
return self.llm.astream(messages)
|
||||
@@ -0,0 +1,34 @@
|
||||
# src/neo4j_raw_retriever.py
|
||||
|
||||
from typing import List
|
||||
from langchain.schema import Document
|
||||
from neo4j import GraphDatabase
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("BoweiAgent.Neo4jRawRetriever")
|
||||
|
||||
class Neo4jRawRetriever:
|
||||
def __init__(self, neo4j_conf: dict):
|
||||
self.uri = neo4j_conf.get("uri")
|
||||
self.username = neo4j_conf.get("username")
|
||||
self.password = neo4j_conf.get("password")
|
||||
self.driver = GraphDatabase.driver(self.uri, auth=(self.username, self.password))
|
||||
|
||||
def close(self):
|
||||
self.driver.close()
|
||||
|
||||
def get_relevant_documents(self, cypher_query: str) -> list[Document]:
|
||||
with self.driver.session() as session:
|
||||
result = session.run(cypher_query)
|
||||
documents = []
|
||||
for record in result:
|
||||
node = record.get("n") or record.values()[0] # 根据返回字段调整
|
||||
content = ""
|
||||
metadata = {}
|
||||
if hasattr(node, "items"):
|
||||
metadata = dict(node.items())
|
||||
content = str(node)
|
||||
else:
|
||||
content = str(node)
|
||||
documents.append(Document(page_content=content, metadata=metadata))
|
||||
return documents
|
||||
@@ -0,0 +1,23 @@
|
||||
from langchain_neo4j import Neo4jVector
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
class Neo4jKnowledgeRetriever:
|
||||
def __init__(self, neo4j_conf: dict, embedding_client):
|
||||
neo4j_uri = neo4j_conf.get("uri")
|
||||
neo4j_username = neo4j_conf.get("username")
|
||||
neo4j_password = neo4j_conf.get("password")
|
||||
index_name = neo4j_conf.get("index_name", "vector") # 默认向量索引名
|
||||
keyword_index_name = neo4j_conf.get("keyword_index_name", "keyword") # 默认关键词索引名
|
||||
|
||||
self.vectorstore = Neo4jVector.from_existing_index(
|
||||
embedding_client.client,
|
||||
url=neo4j_uri,
|
||||
username=neo4j_username,
|
||||
password=neo4j_password,
|
||||
index_name=index_name,
|
||||
#keyword_index_name=keyword_index_name,
|
||||
#search_type="hybrid",
|
||||
)
|
||||
|
||||
def get_relevant_documents(self, query: str, k: int = 5):
|
||||
return self.vectorstore.similarity_search(query, k=k)
|
||||
@@ -0,0 +1,111 @@
|
||||
# src/prompt_manager.py
|
||||
|
||||
from dataclasses import dataclass
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.schema import SystemMessage, HumanMessage
|
||||
|
||||
@dataclass
|
||||
class CodeExecutorPrompts:
|
||||
understand_prompt: ChatPromptTemplate
|
||||
code_gen_prompt: ChatPromptTemplate
|
||||
code_fix_prompt: ChatPromptTemplate
|
||||
rewrite_prompt_template: ChatPromptTemplate
|
||||
cypher_conversion_prompt: ChatPromptTemplate # 新增Cypher转换提示模板
|
||||
|
||||
class PromptManager:
|
||||
def __init__(self):
|
||||
self.prompts = self._init_prompts()
|
||||
|
||||
def _init_prompts(self) -> CodeExecutorPrompts:
|
||||
understand_prompt = ChatPromptTemplate.from_template(
|
||||
"""
|
||||
你是一名电力造价业务专家,请基于以下工程文件业务结构,将用户自然语言问题改写成专业查询语句:
|
||||
|
||||
**工程文件业务结构**:
|
||||
{business_structure}
|
||||
|
||||
**改写规则**:
|
||||
1. **定位目标对象**:从业务结构中识别核心对象(如 `ProjectDivisionTree`→项目划分树、`FeeScheduleItem`→费用表)。
|
||||
2. **提取条件**:从用户输入中解析关键条件(如名称、量、类型),用【】标注变量。
|
||||
3. **构建专业语句**:格式为:`在[目标对象]中查找【条件】的项。`
|
||||
- 使用业务术语(如“项目划分项”而非“项目”)。
|
||||
- 条件需明确属性(如【名称】、【量】、【类型】)。
|
||||
4. **精确映射结构**:若用户查询层级(如“叶节点”),需在条件中体现。
|
||||
|
||||
**用户输入**:{user_input}
|
||||
**改写输出**:(仅输出改写后的语句)
|
||||
""")
|
||||
|
||||
code_gen_prompt = ChatPromptTemplate.from_template("""
|
||||
你是一个专业的Python工程师。我会给你一个用户问题,你需要将其转换为对应的Python代码
|
||||
|
||||
用户问题:
|
||||
{user_request}
|
||||
|
||||
上下文信息:
|
||||
{context}
|
||||
|
||||
工程数据访问库:
|
||||
{bowei_api_docs}
|
||||
|
||||
# 工作流程
|
||||
1. 从用户问题中提取关键信息(节点路径、节点类型、节点名称等)
|
||||
2. 根据"用户问题"和"上下文信息"选择最匹配的"工程数据访问库"中的方法
|
||||
3. 生成可直接执行的Python代码
|
||||
|
||||
# 代码模板(必须严格遵循)
|
||||
def neo4j_find_function():
|
||||
project = ProjectBuilder.build()
|
||||
status, data, error, helper_info = project.[SELECTED_METHOD]([PARAMETERS])
|
||||
return status, data, error, helper_info
|
||||
|
||||
# 执行规则
|
||||
- 参数必须从用户问题或上下文信息中提取
|
||||
- 必须确保生成的代码可以直接执行
|
||||
- 禁止修改代码模板结构
|
||||
- 禁止添加任何注释或解释
|
||||
|
||||
# 输出格式
|
||||
def neo4j_find_function():
|
||||
project = ProjectBuilder.build()
|
||||
status, data, error, helper_info = project.[SELECTED_METHOD]([PARAMETERS])
|
||||
return status, data, error, helper_info
|
||||
""")
|
||||
|
||||
code_fix_prompt = ChatPromptTemplate.from_messages([
|
||||
SystemMessage(content="你是一个专业的Python开发者。下面是之前生成的代码和执行时出现的错误,请修复代码中的错误,确保代码语法正确且能完成访问博微造价工程数据的任务。只输出修复后的Python代码,不要添加解释。"),
|
||||
HumanMessage(content="原始代码:\n{code}\n\n错误信息:\n{error}\n")
|
||||
])
|
||||
|
||||
rewrite_prompt_template = ChatPromptTemplate.from_template(
|
||||
"""你是一个专业的工程业务助理,结合以下工程业务结构信息和相关知识:
|
||||
{business_structure}
|
||||
|
||||
相关知识内容:
|
||||
{context}
|
||||
|
||||
请根据用户的问题,结合上述信息,理解并改写成一个针对工程数据的访问请求(简洁明了的描述)。
|
||||
请只输出改写后的访问请求文本,不要多余解释。"""
|
||||
)
|
||||
|
||||
cypher_conversion_prompt = ChatPromptTemplate.from_template(
|
||||
"""
|
||||
你是一个Neo4j专家。请将用户的自然语言问题转换成一个有效的Cypher查询语句,查询知识图谱中相关信息。
|
||||
只返回Cypher语句,不要任何解释,最多返回5条。
|
||||
|
||||
业务结构信息:
|
||||
{business_structure}
|
||||
|
||||
用户问题:
|
||||
{user_input}
|
||||
Cypher查询语句:
|
||||
"""
|
||||
)
|
||||
|
||||
return CodeExecutorPrompts(
|
||||
understand_prompt=understand_prompt,
|
||||
code_gen_prompt=code_gen_prompt,
|
||||
code_fix_prompt=code_fix_prompt,
|
||||
rewrite_prompt_template=rewrite_prompt_template,
|
||||
cypher_conversion_prompt=cypher_conversion_prompt,
|
||||
)
|
||||
Reference in New Issue
Block a user