调试后整体可以使用的版本。

This commit is contained in:
2025-06-26 12:57:21 +08:00
parent b792f9acfa
commit 9b3d5b7ef1
10 changed files with 1388 additions and 63 deletions
+7 -6
View File
@@ -587,7 +587,7 @@ class ProjectBuilder:
pass
@classmethod
def register(toolkit_class: type, config: dict):
def register(cls, toolkit_class: type, config: dict):
"""
注册工具类到工厂
@@ -597,10 +597,11 @@ class ProjectBuilder:
if not issubclass(toolkit_class, ProjectToolkit):
raise TypeError(f"{toolkit_class.__name__} 必须继承自 ProjectToolkit")
_config = config
_registry = toolkit_class
cls._config = config
cls._registry = toolkit_class
def build(self) -> ProjectToolkit:
@classmethod
def build(cls) -> ProjectToolkit:
"""
创建工具实例
@@ -609,7 +610,7 @@ class ProjectBuilder:
返回:
实例化的工具对象
"""
if _registry is None:
if cls._registry is None:
raise KeyError(f"未注册的类,请先注册类")
return _registry(_config)
return cls._registry(cls._config)
+5 -3
View File
@@ -56,11 +56,11 @@ from src.code_executor import CodeExecutor
from src.dialog_manager import DialogManager
from src.neo4j_raw_retriever import Neo4jRawRetriever
from src.embedding_client import EmbeddingClient
from project import ProjectBuilder
from project import ProjectBuilder, ProjectToolkit
from project_implementation import ProjectToolkitNeo4j
def main():
config = Config()
# 根据配置设置环境变量
@@ -87,7 +87,9 @@ def main():
# 创建Neo4j检索器
knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
ProjectBuilder.register(ProjectToolkitNeo4j, neo4j_conf)
ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)
code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder)
+7 -6
View File
@@ -587,7 +587,7 @@ class ProjectBuilder:
pass
@classmethod
def register(toolkit_class: type, config: dict):
def register(cls, toolkit_class: type, config: dict):
"""
注册工具类到工厂
@@ -597,10 +597,11 @@ class ProjectBuilder:
if not issubclass(toolkit_class, ProjectToolkit):
raise TypeError(f"{toolkit_class.__name__} 必须继承自 ProjectToolkit")
_config = config
_registry = toolkit_class
cls._config = config
cls._registry = toolkit_class
def build(self) -> ProjectToolkit:
@classmethod
def build(cls) -> ProjectToolkit:
"""
创建工具实例
@@ -609,7 +610,7 @@ class ProjectBuilder:
返回:
实例化的工具对象
"""
if _registry is None:
if cls._registry is None:
raise KeyError(f"未注册的类,请先注册类")
return _registry(_config)
return cls._registry(cls._config)
+1
View File
@@ -2215,3 +2215,4 @@ class ProjectToolkitNeo4j(ProjectToolkit):
error = f"查询失败: {str(e)}"
return status, data, error, helper_info
+2 -1
View File
@@ -3,7 +3,7 @@ from langchain_core.output_parsers import StrOutputParser
from langchain_experimental.utilities import PythonREPL
from langchain_core.tools import Tool
from langchain_experimental.tools import PythonREPLTool
from project import ProjectBuilder
from project import ProjectBuilder, ProjectToolkit
import sys
import io
import traceback
@@ -43,6 +43,7 @@ class CodeExecutor:
try:
namespace = {
"project": __import__("project"),
"ProjectBuilder": ProjectBuilder,
}
old_stdout = sys.stdout
+39 -47
View File
@@ -112,12 +112,9 @@ class DialogManager:
rewritten_list.append((rewritten, doc.page_content))
return rewritten_list
async def run_async(self, pre_input: str = None, automated: bool = False):
async def run_async(self, pre_input: str = None):
logger.info("启动对话管理器,等待用户输入")
if automated:
print("自动化模式已启动。")
else:
print("欢迎使用博微造价工程数据访问系统,输入 exit 退出。")
print("欢迎使用博微造价工程数据访问系统,输入 exit 退出。")
if pre_input:
user_questions = [pre_input]
@@ -127,16 +124,11 @@ class DialogManager:
while True:
if user_questions:
user_question = user_questions.pop(0)
if not automated:
print(f"预输入问题:{user_question}")
elif automated:
if not user_questions:
logger.info("自动化模式下没有更多问题,退出程序。")
break
print(f"预输入问题:{user_question}")
else:
user_question = input("请输入您的问题:")
if user_question.strip().lower() == "exit" and not automated:
if user_question.strip().lower() == "exit":
logger.info("用户退出程序")
print("退出程序。")
break
@@ -147,38 +139,38 @@ class DialogManager:
user_questions.clear()
continue
if automated:
# 自动化模式下选择第一个结果
selected_rewritten, selected_knowledge = rewritten_results[0]
logger.info(f"自动化模式选择第一个访问请求,内容:{selected_rewritten}")
result = self.code_executor.generate_and_run_code(
selected_rewritten,
context=selected_knowledge,
bowei_api_docs=self.bowei_api_docs
)
logger.info("代码执行完成,返回结果")
print("\n访问结果:\n", result)
print("-" * 50)
else:
InteractionHandler.display_rewritten_requests(rewritten_results)
choice_index = InteractionHandler.get_user_choice(rewritten_results)
if choice_index is not None:
selected_rewritten, selected_knowledge = rewritten_results[choice_index]
logger.info(f"用户选择访问请求编号 {choice_index + 1},内容:{selected_rewritten}")
print(f"\n您选择的访问请求是:\n{selected_rewritten}\n")
print(f"相关知识内容:\n{selected_knowledge}\n")
#confirm = input("请确认是否继续执行该请求(输入n取消,其他继续):").strip().lower()
confirm = ""
if confirm == "n":
#logger.info("用户取消执行访问请求")
print("取消执行,您可以重新输入问题。\n" + "-"*50)
else:
#logger.info("用户确认执行访问请求")
result = self.code_executor.generate_and_run_code(
selected_rewritten,
context=selected_knowledge,
bowei_api_docs=self.bowei_api_docs
)
logger.info("代码执行完成,返回结果")
print("\n访问结果:\n", result)
print("-" * 50)
InteractionHandler.display_rewritten_requests(rewritten_results)
choice_index = InteractionHandler.get_user_choice(rewritten_results)
if choice_index is not None:
selected_rewritten, selected_knowledge = rewritten_results[choice_index]
logger.info(f"用户选择访问请求编号 {choice_index + 1},内容:{selected_rewritten}")
print(f"\n您选择的访问请求是:\n{selected_rewritten}\n")
print(f"相关知识内容:\n{selected_knowledge}\n")
#confirm = input("请确认是否继续执行该请求(输入n取消,其他继续):").strip().lower()
confirm = ""
if confirm == "n":
#logger.info("用户取消执行访问请求")
print("取消执行,您可以重新输入问题。\n" + "-"*50)
else:
#logger.info("用户确认执行访问请求")
result = self.execute_generated_code(
selected_rewritten,
selected_knowledge
)
print("\n访问结果:\n", result)
print("-" * 50)
def execute_generated_code(self, selected_rewritten, selected_knowledge):
"""
执行生成的代码并返回结果
:param selected_rewritten: 选中的重写后的请求
:param selected_knowledge: 选中的知识
:return: 代码执行结果
"""
result = self.code_executor.generate_and_run_code(
selected_rewritten,
context=selected_knowledge,
bowei_api_docs=self.bowei_api_docs
)
return result
+3
View File
@@ -17,6 +17,9 @@ class Neo4jRawRetriever:
def close(self):
self.driver.close()
def driver(self):
return self.driver
def get_relevant_documents(self, cypher_query: str) -> list[Document]:
with self.driver.session() as session:
result = session.run(cypher_query)
+102
View File
@@ -0,0 +1,102 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
from src.dialog_manager import DialogManager
from src.llm_client import LLMClient
from src.code_executor import CodeExecutor
from src.neo4j_raw_retriever import Neo4jRawRetriever
from src.prompt_manager import PromptManager
import yaml
from src.config import Config
from src.document_loader import load_file
from src.embedding_client import EmbeddingClient
from project import ProjectBuilder, ProjectToolkit
from project_implementation import ProjectToolkitNeo4j
def main():
config = Config()
business_structure = load_file(config.business_object_structure_path)
bowei_api_docs = load_file(config.bowei_api_docs_path)
llm_client = LLMClient(config.openai)
llm_client_coder = LLMClient(config.openai_coder)
prompt_manager = PromptManager()
neo4j_conf = config.neo4j_conf
embedding_conf = config.embedding
embedding_client = EmbeddingClient(embedding_conf)
# 创建Neo4j检索器
knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)
code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder)
dialog_manager = DialogManager(
llm_client,
business_structure,
bowei_api_docs,
code_executor,
knowledge_retriever,
prompt_manager,
)
# 加载 zhibiao.json
with open('./tests/zhibiao.json', 'r', encoding='utf-8') as f:
zhibiao_data = json.load(f)
# 提取指标映射关系并批量执行
for item in zhibiao_data:
query = item['指标描述']['指标映射']
rewritten_results = dialog_manager.understand_user_question_stream(query)
if rewritten_results is None or rewritten_results == []:
print('问题: {} 没有找到符合要求的数据'.format(query))
continue
selected_rewritten, selected_knowledge = rewritten_results[0]
questions = []
success_count = 0
fail_count = 0
error_list = []
# 检查 understand_user_question_stream 方法调用结果,假设存在该方法调用
# 示例调用,实际使用时请替换为真实调用
# stream_result = dialog_manager.understand_user_question_stream(query)
# if stream_result is None or stream_result == []:
# questions.append(query)
# fail_count += 1
# error_list.append(f"问题 {query} 调用 understand_user_question_stream 返回空结果")
result = dialog_manager.execute_generated_code(selected_rewritten, selected_knowledge)
if result is None or result == []:
questions.append(selected_rewritten)
fail_count += 1
error_list.append(f"问题 {selected_rewritten} 调用 execute_generated_code 返回空结果")
else:
success_count += 1
print(result)
if __name__ == "__main__":
total = success_count + fail_count
if total > 0:
success_rate = success_count / total
fail_rate = fail_count / total
else:
success_rate = fail_rate = 0
print(f"问题总数: {total}")
print(f"成功比例: {success_rate * 100:.2f}%")
print(f"失败比例: {fail_rate * 100:.2f}%")
print("错误列表:")
for error in error_list:
print(error)
main()
+76
View File
@@ -0,0 +1,76 @@
import sys
import os
import io
import logging
import traceback
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from project import ProjectBuilder, ProjectToolkit
from project_implementation import ProjectToolkitNeo4j
from neo4j import GraphDatabase
from src.config import Config
logging.basicConfig(
level=logging.DEBUG, # 生产环境可改为 INFO 或 WARNING
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.FileHandler("test_runcode.log", encoding="utf-8"), logging.StreamHandler()],
)
logger = logging.getLogger("test_runcode")
def main():
config = Config()
neo4j_conf = config.neo4j_conf
code_str = '''
def project_get_calculate_function():
project = ProjectBuilder.build()
status, data, error, helper_info = project.get_division_item_by_path("安装/架空输电线路本体工程")
if status == 'success':
return status, data.get('单位', ''), error, helper_info
return status, None, error, helper_info
'''
neo4j_driver = GraphDatabase.driver(neo4j_conf.get("uri"), auth=(neo4j_conf.get("username"), neo4j_conf.get("password")))
old_stdout = sys.stdout
redirected_output = io.StringIO()
ProjectBuilder.register(ProjectToolkitNeo4j, neo4j_driver)
try:
namespace = {
"project": __import__("project"),
"ProjectBuilder": ProjectBuilder,
}
sys.stdout = redirected_output
exec(code_str, namespace)
# 确保neo4j_find_function存在
if "project_get_calculate_function" not in namespace:
raise ValueError("代码中未定义project_get_calculate_function函数")
result_tuple = namespace["project_get_calculate_function"]()
sys.stdout = old_stdout
output = redirected_output.getvalue().strip()
if not isinstance(result_tuple, tuple) or len(result_tuple) != 4:
raise ValueError("函数应返回包含4个元素的元组(status, data, error, helper_info)")
status, data, error, helper_info = result_tuple
logger.info(f"执行结果: status={status}, data={data}, error={error}")
except Exception as e:
# 确保恢复stdout
sys.stdout = old_stdout
logger.error(f"执行代码时出错: {e}")
logger.error(traceback.format_exc())
if __name__ == "__main__":
main()
+1146
View File
File diff suppressed because it is too large Load Diff