From 714beff8adb5a9ea2a787bfc4d9f39b6c93f0d27 Mon Sep 17 00:00:00 2001 From: chentianrui Date: Tue, 24 Jun 2025 14:46:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=96=E6=B6=88=E4=B8=A4=E6=AC=A1=E5=88=9B?= =?UTF-8?q?=E5=BB=BA=E6=95=B0=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/code_executor.py | 126 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 src/code_executor.py diff --git a/src/code_executor.py b/src/code_executor.py new file mode 100644 index 0000000..d363dd5 --- /dev/null +++ b/src/code_executor.py @@ -0,0 +1,126 @@ +import logging +from langchain_core.output_parsers import StrOutputParser +from langchain_experimental.utilities import PythonREPL +from langchain_core.tools import Tool +from langchain_experimental.tools import PythonREPLTool +from project_implementation import ProjectBuilder +import sys +import io +import traceback + +logger = logging.getLogger("BoweiAgent.CodeExecutor") + + +class CodeExecutor: + def __init__(self, prompts, llm_client, max_retries=3, neo4j_driver=None): + self.llm_client = llm_client + self.prompts = prompts + self.max_retries = max_retries + self.output_parser = StrOutputParser() + self.neo4j_driver = neo4j_driver # 存储驱动实例 + + def generate_code(self, user_request: str, context: str = "", bowei_api_docs: str = "") -> str: + logger.info(f"开始生成代码,访问请求:{user_request}") + prompt = self.prompts.code_gen_prompt.format_prompt( + user_request=user_request, context=context, bowei_api_docs=bowei_api_docs + ) + + response = self.llm_client.invoke(prompt.to_messages()) + code = self.output_parser.parse(response) + logger.debug(f"生成的代码内容:\n{code}") + return code + + def fix_code(self, code: str, error: str) -> str: + logger.warning(f"代码执行出错,开始修复。错误信息:{error}") + prompt = self.prompts.code_fix_prompt.format_prompt(code=code, error=error) + response = self.llm_client.invoke(prompt.to_messages()) + fixed_code = self.output_parser.parse(response) + logger.debug(f"修复后的代码内容:\n{fixed_code}") + return fixed_code + + def execute_code(self, code_str): + """封装代码执行逻辑""" + logger.debug(f"开始执行代码: {code_str}") + try: + namespace = { + "ProjectBuilder": ProjectBuilder, + "project_implementation": __import__("project_implementation"), + "project": __import__("project"), + } + + # 如果有驱动实例,先创建项目实例并传入驱动 + if self.neo4j_driver: + namespace["project_instance"] = ProjectBuilder.build(self.neo4j_driver) + + old_stdout = sys.stdout + redirected_output = io.StringIO() + sys.stdout = redirected_output + + exec(code_str, namespace) + + # 确保neo4j_find_function存在 + if "project_get_calculate_function" not in namespace: + raise ValueError("代码中未定义project_get_calculate_function函数") + + result_tuple = namespace["project_get_calculate_function"]() + + sys.stdout = old_stdout + output = redirected_output.getvalue().strip() + + if not isinstance(result_tuple, tuple) or len(result_tuple) != 4: + raise ValueError("函数应返回包含4个元素的元组(status, data, error, helper_info)") + + status, data, error, helper_info = result_tuple + + logger.info(f"执行结果: status={status}, data={data}, error={error}") + + return { + "status": status, + "data": data, + "error": error, + "helper_info": helper_info, + "output": output, + } + + except Exception as e: + # 确保恢复stdout + sys.stdout = old_stdout + logger.error(f"执行代码时出错: {e}") + logger.error(traceback.format_exc()) + + return { + "status": "error", + "error": str(e), + "helper_info": [], + "traceback": traceback.format_exc(), + } + + def generate_and_run_code(self, user_request: str, context: str = "", bowei_api_docs: str = "") -> str: + code = self.generate_code(user_request, context, bowei_api_docs) + logger.info("开始执行生成的代码") + + pre_code = "" + error_msg = "" + prev_happend_error = False + for attempt in range(self.max_retries): + try: + if prev_happend_error: + logger.error(f"代码执行失败,尝试第 {attempt+1} 次修复。错误信息:{error_msg}") + code = self.fix_code(pre_code, error_msg) + + import re + + pre_code = re.sub(r"^```python\s*|\s*```$", "", code.content, flags=re.MULTILINE) + result = self.execute_code(pre_code) + if result["status"] == "success": + logger.info(f"代码执行成功,结果: {result['data']}") + return result["data"] + else: + error_msg = result.get("error", "未知错误") + prev_happend_error = True + except Exception as e: + error_msg = str(e) + prev_happend_error = True + + logger.error(f"代码执行失败,超过最大重试次数 {self.max_retries}") + return f"代码执行失败,超过最大重试次数 {self.max_retries}。\n最后一次错误信息:\n{error_msg}"