实现完整功能

This commit is contained in:
2025-07-07 08:23:02 +08:00
parent 35d50305c8
commit d1c129c691
20 changed files with 504 additions and 469 deletions
+67 -37
View File
@@ -12,23 +12,11 @@ from src.project import ProjectBuilder, ProjectToolkit
import sys
import io
import traceback
import importlib
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
class CodeExecutor:
def __init__(self, prompts, llm_client, max_retries):
self.llm_client = llm_client
@@ -36,45 +24,87 @@ class CodeExecutor:
self.max_retries = max_retries if max_retries >= 1 else 1
self.output_parser = StrOutputParser()
def generate_code(self, user_request: str, context: str = "", bowei_api_docs: str = "") -> str:
def generate_code(self, user_request: str, context: str = "", bowei_api_docs: str = "") -> dict:
logger.info(f"开始生成代码,访问请求:{user_request}")
prompt = self.prompts.code_gen_prompt.format_prompt(
user_request=user_request, context=context, bowei_api_docs=bowei_api_docs
)
response = self.llm_client.invoke(prompt.to_messages())
code = self.output_parser.parse(response)
logger.debug(f"生成的代码内容:\n{code}")
try:
response = self.llm_client.invoke(prompt.to_messages())
parsed_response = self.output_parser.parse(response)
# 处理 AIMessage 类型的返回值
if hasattr(parsed_response, 'content'):
code = parsed_response.content
else:
code = str(parsed_response)
logger.debug(f"生成的代码内容:\n{code}")
return {
"code": 20000,
"message": 'ok',
"status": True,
"data": code
}
except Exception as e:
logger.error(f"大模型调用失败: {str(e)}", exc_info=True)
return {
"code": 50000,
"message": f'大模型调用失败: {str(e)}',
"status": False,
"data": None
}
return {
"code": 20000,
"message": 'ok',
"status": True,
"data": code.content
}
def fix_code(self, code: str, error: str) -> str:
def fix_code(self, code: str, error: str) -> dict:
logger.warning(f"代码执行出错,开始修复。错误信息:{error}")
prompt = self.prompts.code_fix_prompt.format_prompt(code=code, error=error)
response = self.llm_client.invoke(prompt.to_messages())
fixed_code = self.output_parser.parse(response)
logger.debug(f"修复后的代码内容:\n{fixed_code}")
return {
"code": 20000,
"message": 'ok',
"status": True,
"data": fixed_code.content
}
try:
response = self.llm_client.invoke(prompt.to_messages())
parsed_response = self.output_parser.parse(response)
# 处理 AIMessage 类型的返回值
if hasattr(parsed_response, 'content'):
fixed_code = parsed_response.content
else:
fixed_code = str(parsed_response)
logger.debug(f"修复后的代码内容:\n{fixed_code}")
return {
"code": 20000,
"message": 'ok',
"status": True,
"data": fixed_code
}
except Exception as e:
logger.error(f"代码修复时大模型调用失败: {str(e)}", exc_info=True)
return {
"code": 50001,
"message": f'代码修复失败: {str(e)}',
"status": False,
"data": None
}
def execute_code(self, code_str) -> dict:
"""封装代码执行逻辑"""
logger.debug(f"开始执行代码:\n {code_str}")
try:
import re
pattern = r'```python(.*?)```'
match = re.search(pattern, code_str, re.DOTALL)
if match:
code_str = match.group(1).strip()
except Exception as e:
logger.warning(f"解析生成代码格式时发生异常: {str(e)}")
old_stdout = None
try:
namespace = {
"project": __import__("src.project"),
"project": importlib.import_module("src.project"),
"Material": getattr(importlib.import_module("src.project"), "Material", None),
"Ration": getattr(importlib.import_module("src.project"), "Ration", None),
"Equipment": getattr(importlib.import_module("src.project"), "Equipment", None),
"MaterialOrEquipment": getattr(importlib.import_module("src.project"), "MaterialOrEquipment", None),
"ProjectBuilder": ProjectBuilder,
}
-15
View File
@@ -5,21 +5,6 @@ import os
import logging
from datetime import datetime
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
class Config:
def __init__(self, path="config.yaml"):
with open(path, "r", encoding="utf-8") as f:
-12
View File
@@ -8,18 +8,6 @@ from langchain.schema import SystemMessage, HumanMessage
import asyncio
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
class QuestionProcessor:
-16
View File
@@ -1,23 +1,7 @@
# src/document_loader.py
import os
import logging
from datetime import datetime
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
def load_file(path: str) -> str:
try:
+1 -12
View File
@@ -6,20 +6,9 @@ from datetime import datetime
from langchain_openai import OpenAIEmbeddings
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
class EmbeddingClient:
def __init__(self, embedding_config: dict):
api_key = embedding_config.get("api_key")
+2 -12
View File
@@ -8,21 +8,11 @@ import getpass
from langchain_openai import ChatOpenAI
from langchain_core.rate_limiters import InMemoryRateLimiter
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
class LLMClient:
def __init__(self, openai_config: dict):
api_key = openai_config.get("api_key")
+2 -12
View File
@@ -9,19 +9,8 @@ import itertools
from langchain_openai import ChatOpenAI
from langchain_core.rate_limiters import InMemoryRateLimiter
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
class MultiAPIKeyChatOpenAI:
@@ -63,6 +52,7 @@ class MultiAPIKeyChatOpenAI:
# 轮询器,用于循环调用不同的 llm 实例
self._llm_cycle = itertools.cycle(self.llms)
self.llm = next(self._llm_cycle)
def invoke(self, messages):
llm = next(self._llm_cycle)
+2 -12
View File
@@ -7,21 +7,11 @@ from typing import List
from langchain.schema import Document
from neo4j import GraphDatabase
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
class Neo4jRawRetriever:
def __init__(self, neo4j_conf: dict):
self.uri = neo4j_conf.get("uri")
+2 -12
View File
@@ -6,21 +6,11 @@ from datetime import datetime
from langchain_neo4j import Neo4jVector
from langchain_openai import OpenAIEmbeddings
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
class Neo4jKnowledgeRetriever:
def __init__(self, neo4j_conf: dict, embedding_client):
neo4j_uri = neo4j_conf.get("uri")
+22 -22
View File
@@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
import json
from typing import Any, Type
class ProjectToolkit(ABC):
@@ -13,13 +14,13 @@ class ProjectToolkit(ABC):
描述: 代表整个项目结构的顶层容器
"""
def __init__(self):
self.project_division_set = ProjectDivisionItem() # 项目划分集对象
def __init__(self, config: Any):
pass
# 项目划分查询方法
@abstractmethod
def get_division_by_name(self, name):
def get_division_by_name(self, name_part):
"""
通过名称获取项目划分对象
@@ -89,7 +90,7 @@ class ProjectToolkit(ABC):
pass
@abstractmethod
def get_quantities_node_by_parent_and_code(self, parent_path, quantity_type=None, code=None):
def get_quantities_node_by_parent_and_code(self, parent_path, quantity_type, code):
"""
通过父节点路径和编码获取工程量对象(定额、主材或设备),包括子节点
@@ -108,14 +109,15 @@ class ProjectToolkit(ABC):
pass
@abstractmethod
def get_quantities_node_by_parent_and_name(self, parent_path, partial_name, quantity_type=None):
def get_quantities_node_by_parent_and_name(self, parent_path, quantity_type, partial_name):
"""
通过父节点路径、模糊节点名称和类型获取工程量对象(主材或者设备),包括子节点
通过父节点路径、类型和模糊节点名称获取工程量对象(主材或者设备),包括子节点
Args:
parent_path (str): 父节点的路径,以'/'分隔的多级节点路径
partial_name (str): 目标节点的模糊或不完整名称
quantity_type (str): 工程量类型('定额''主材''设备')
partial_name (str): 目标节点的模糊或不完整名称
Returns:
dict: 返回字典,字段包括:
@@ -202,7 +204,7 @@ class ProjectToolkit(ABC):
# 费用表查询方法
@abstractmethod
def get_fee_schedule_on_auxiliary_expense_table(self, table_name, fee_name, fee: str):
def get_fee_schedule_on_auxiliary_expense_table(self, table_name, fee_name, fee_attribute: str):
"""
在辅助费用表中查找费用
@@ -221,7 +223,7 @@ class ProjectToolkit(ABC):
pass
@abstractmethod
def get_fee_schedule_on_other_expense_table(self, table_name, fee_name, fee):
def get_fee_schedule_on_other_expense_table(self, table_name, fee_name, fee_attribute):
"""
在其它费用表中查找费用
@@ -240,7 +242,7 @@ class ProjectToolkit(ABC):
pass
@abstractmethod
def get_fee_schedule_on_land_acquisition_fee_table_table(self, table_name, fee_name, fee):
def get_fee_schedule_on_land_acquisition_fee_table_table(self, table_name, fee_name, fee_attribute):
"""
在其中:场地征用费用表中查找费用
@@ -259,7 +261,7 @@ class ProjectToolkit(ABC):
pass
@abstractmethod
def get_fee_schedule_on_installation_price_difference_table(self, table_name, fee_name, fee):
def get_fee_schedule_on_installation_price_difference_table(self, table_name, fee_name, fee_attribute):
"""
在安装价差费用表中查找费用
@@ -278,7 +280,7 @@ class ProjectToolkit(ABC):
pass
@abstractmethod
def get_fee_schedule_on_Engineering_Cost_table(self, table_name, fee_name, fee):
def get_fee_schedule_on_Engineering_Cost_table(self, table_name, fee_name, fee_attribute):
"""
在工程费用表中查找费用
@@ -687,18 +689,19 @@ class Fee:
self.施工费 = None # xsd:string (可选)
self.单位投资 = None # xsd:string (可选)
class ProjectBuilder:
# 存储注册的工具类
_registry = None
_config = {}
class ProjectBuilder:
"""项目工具工厂类"""
# 存储注册的工具类
_registry: Type[ProjectToolkit] | None = None
_config: Any = None
def __init__(self):
pass
@classmethod
def register(cls, toolkit_class: type, config: dict):
def register(cls, toolkit_class: Type[ProjectToolkit], config: Any):
"""
注册工具类到工厂
@@ -716,13 +719,10 @@ class ProjectBuilder:
"""
创建工具实例
参数:
返回:
实例化的工具对象
"""
if cls._registry is None:
raise KeyError(f"未注册的类,请先注册类")
raise KeyError("未注册的类,请先注册类")
return cls._registry(cls._config)
return cls._registry(cls._config)
+173 -83
View File
@@ -25,7 +25,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
if neo4j_driver is None:
raise ValueError("必须提供Neo4j驱动实例")
super().__init__()
super().__init__(neo4j_driver)
# 保存驱动实例
self.driver = neo4j_driver
@@ -164,7 +164,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
if not node_data:
code = 201
status_flag = False
error = f"找不到路径: {path} 上的ProjectDivisionItem节点"
error = f"错误信息:找不到路径: {path} 上的ProjectDivisionItem节点"
# 提取父路径
path_parts = path.split("/")
@@ -190,7 +190,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
if record["name"]:
helper_info.append(record["name"])
except Exception as e:
helper_info = [f"查询父节点下的子节点时出错: {str(e)}"]
helper_info = [f"错误信息:查询父节点下的子节点时出错: {str(e)}"]
# 拼接 message
message = f"错误信息:{error}; 辅助信息: {helper_info}"
@@ -246,7 +246,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
if not partial_name:
code = 201
status_flag = False
message = "节点名称不能为空"
message = "错误信息:partial_name参数错误,参数不能为空"
else:
try:
# 第一步:找到父节点
@@ -385,7 +385,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
if not name_part or name_part.strip() == "":
code = 201
status_flag = False
message = "输入的名称部分不能为空"
message = "错误信息:name_part参数错误,参数不能为空"
else:
try:
# 直接查询所有类型为ProjectDivisionItem且name包含输入名称的节点
@@ -482,7 +482,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
if not paths_str:
code = 201
status_flag = False
message = "路径不能为空"
message = "错误信息:paths_str参数错误,参数不能为空"
else:
try:
# 使用通用方法获取节点,考虑所有可能的工程量类型
@@ -547,7 +547,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
# 最终统一返回格式包装成列表
return {"code": code, "message": message, "status": status_flag, "data": data}
def get_quantities_node_by_parent_and_code(self, parent_path, quantity_type=None, code=None) -> dict:
def get_quantities_node_by_parent_and_code(self, parent_path, quantity_type, code) -> dict:
"""
通过父节点路径和编码获取工程量对象(定额、主材或设备),包括子节点
@@ -579,13 +579,13 @@ class ProjectToolkitNeo4j(ProjectToolkit):
if not code or code.strip() == "":
code_status = 201
status_flag = False
message = "编码不能为空"
message = "错误信息:code参数错误,参数不能为空"
valid_types = ["定额", "主材", "设备", None]
if quantity_type not in valid_types:
code_status = 201
status_flag = False
message = f"无效的工程量类型: '{quantity_type}';有效类型为 {valid_types}"
message = f"错误信息:quantity_type参数错误,有效类型 {valid_types},当前值是'{quantity_type}'"
elif code_status == 200:
try:
@@ -701,7 +701,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
def get_quantities_node_by_parent_and_name(self, parent_path, quantity_type, partial_name) -> dict:
"""
通过父节点路径、模糊节点名称和类型获取工程量对象(主材或者设备),包括子节点
通过父节点路径、类型和模糊节点名称获取工程量对象(主材或者设备),包括子节点
执行三步查询:
1. 找到对应路径的父节点
@@ -714,7 +714,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
quantity_type (str): 工程量类型('定额''主材''设备')
Returns:
dict: 返回字典,字段包括:
dict: 包含一个字典,字段包括:
- code (int): 状态码,固定为 200(成功)或 201(失败)
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
- status (bool): 成功为 True,失败为 False
@@ -825,30 +825,59 @@ class ProjectToolkitNeo4j(ProjectToolkit):
matching_nodes = []
available_names = []
# 统一处理 partial_name 为列表形式
if isinstance(partial_name, str):
if partial_name.strip() == "":
code_status = 201
status_flag = False
message = "partial_name 不能为空"
else:
# 支持中文逗号、英文逗号、空格分隔
import re
keywords = re.split(r"[,,、\s]+", partial_name.strip())
else:
keywords = partial_name or []
if not keywords:
code_status = 201
status_flag = False
message = "未提供有效的关键词"
if code_status != 200:
return {"code": code_status, "message": message, "status": status_flag, "data": data}
# 尝试用 Python 端过滤
for node in child_nodes:
node_name = node.get("name", "")
if node_name:
available_names.append(node_name)
available_names.append(node_name)
if partial_name in str(node_name):
matching_nodes.append(node)
for keyword in keywords:
if keyword in str(node_name):
matching_nodes.append(node)
break
# 如果 Python 端没找到,尝试数据库端模糊搜索
if not matching_nodes:
# 构建动态查询语句
where_clauses = " OR ".join([f"q.name CONTAINS '{keyword}'" for keyword in keywords])
direct_query = f"""
MATCH (q)
WHERE q.name CONTAINS $partial_name AND {type_condition}
WHERE ({where_clauses}) AND {type_condition}
RETURN q LIMIT 20
"""
direct_params = {"partial_name": partial_name}
direct_result = self.session.run(direct_query, **direct_params)
matching_nodes = [record["q"] for record in direct_result]
try:
direct_result = self.session.run(direct_query)
matching_nodes = [record["q"] for record in direct_result]
except Exception as e:
code_status = 201
status_flag = False
message = f"数据库模糊查询失败: {str(e)}"
if not matching_nodes:
code_status = 201
status_flag = False
message = f"错误信息:在父节点路径'{parent_path}' 下找不到名称包含 '{partial_name}' 的节点;辅助信息:{available_names}"
message = f"错误信息:在父节点路径'{parent_path}' 下找不到包含关键词 {keywords} 的节点;辅助信息:{available_names}"
else:
result_data = []
for node in matching_nodes:
@@ -1008,7 +1037,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
if not parent_path or not code:
code_status = 201
status_flag = False
message = "父节点路径或要查找的编码不能为空"
message = "错误信息:parent_path或code参数错误,参数不能为空"
else:
try:
@@ -1323,7 +1352,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
fee_attribute (str): 费用值属性名
Returns:
dict: 返回字典,字段包括:
dict: 包含一个字典,字段包括:
- code (int): 状态码,固定为 200(成功)或 201(失败)
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
- status (bool): 成功为 True,失败为 False
@@ -1344,6 +1373,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
else:
# 第一步:查找父节点(费用表节点)
table_name = table_name.replace("", "")
parent_path = f"工程/工程费用/{table_name}"
parent_node_data = self.get_node_by_path(parent_path)
@@ -1417,17 +1447,28 @@ class ProjectToolkitNeo4j(ProjectToolkit):
else:
# 第三步:获取费用节点的属性值
try:
if fee_node and hasattr(fee_node, "get"):
fee_value = fee_node.get(fee_attribute)
elif fee_node and isinstance(fee_node, dict):
fee_value = fee_node.get(fee_attribute)
if hasattr(fee_node, "keys"):
all_attrs = list(fee_node.keys())
elif isinstance(fee_node, dict):
all_attrs = list(fee_node.keys())
else:
# 如果fee_node是Neo4j Node对象,尝试直接访问属性
fee_value = (
getattr(fee_node, fee_attribute, None)
if hasattr(fee_node, fee_attribute)
else fee_node.get(fee_attribute) if hasattr(fee_node, "get") else None
)
all_attrs = list(fee_node.keys()) if hasattr(fee_node, "keys") else []
# 过滤掉私有属性
all_attrs = [attr for attr in all_attrs if not attr.startswith("_")]
# 使用“包含”逻辑模糊匹配属性名
matched_attrs = []
for attr in all_attrs:
if fee_attribute.lower() in attr.lower() or attr.lower() in fee_attribute.lower():
matched_attrs.append(attr)
if matched_attrs:
# 优先返回第一个匹配项的值
best_match = matched_attrs[0]
fee_value = fee_node.get(best_match)
else:
fee_value = None
if fee_value is None:
code_status = 201
@@ -1487,7 +1528,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
fee_attribute (str): 费用值属性名
Returns:
dict: 返回字典,字段包括:
dict: 包含一个字典,字段包括:
- code (int): 状态码,固定为 200(成功)或 201(失败)
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
- status (bool): 成功为 True,失败为 False
@@ -1508,6 +1549,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
else:
# 第一步:查找父节点(费用表节点)
table_name = table_name.replace("", "")
parent_path = f"工程/工程费用/{table_name}"
parent_node_data = self.get_node_by_path(parent_path)
@@ -1541,7 +1583,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
# 递归查找费用节点,最多查找3级深度
query = """
MATCH (p)-[*1..3]->(f)
MATCH (p)-[*1..]->(f)
WHERE p.name = $table_name AND f.name = $fee_name
RETURN f LIMIT 1
"""
@@ -1581,17 +1623,28 @@ class ProjectToolkitNeo4j(ProjectToolkit):
else:
# 第三步:获取费用节点的属性值
try:
if fee_node and hasattr(fee_node, "get"):
fee_value = fee_node.get(fee_attribute)
elif fee_node and isinstance(fee_node, dict):
fee_value = fee_node.get(fee_attribute)
if hasattr(fee_node, "keys"):
all_attrs = list(fee_node.keys())
elif isinstance(fee_node, dict):
all_attrs = list(fee_node.keys())
else:
# 如果fee_node是Neo4j Node对象,尝试直接访问属性
fee_value = (
getattr(fee_node, fee_attribute, None)
if hasattr(fee_node, fee_attribute)
else fee_node.get(fee_attribute) if hasattr(fee_node, "get") else None
)
all_attrs = list(fee_node.keys()) if hasattr(fee_node, "keys") else []
# 过滤掉私有属性
all_attrs = [attr for attr in all_attrs if not attr.startswith("_")]
# 使用“包含”逻辑模糊匹配属性名
matched_attrs = []
for attr in all_attrs:
if fee_attribute.lower() in attr.lower() or attr.lower() in fee_attribute.lower():
matched_attrs.append(attr)
if matched_attrs:
# 优先返回第一个匹配项的值
best_match = matched_attrs[0]
fee_value = fee_node.get(best_match)
else:
fee_value = None
if fee_value is None:
code_status = 201
@@ -1653,7 +1706,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
fee_attribute (str): 费用值属性名
Returns:
dict: 返回字典,字段包括:
dict: 包含一个字典,字段包括:
- code (int): 状态码,固定为 200(成功)或 201(失败)
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
- status (bool): 成功为 True,失败为 False
@@ -1674,6 +1727,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
else:
# 第一步:查找父节点(费用表节点)
table_name = table_name.replace("", "")
parent_path = f"工程/工程费用/{table_name}"
parent_node_data = self.get_node_by_path(parent_path)
@@ -1747,17 +1801,28 @@ class ProjectToolkitNeo4j(ProjectToolkit):
else:
# 第三步:获取费用节点的属性值
try:
if fee_node and hasattr(fee_node, "get"):
fee_value = fee_node.get(fee_attribute)
elif fee_node and isinstance(fee_node, dict):
fee_value = fee_node.get(fee_attribute)
if hasattr(fee_node, "keys"):
all_attrs = list(fee_node.keys())
elif isinstance(fee_node, dict):
all_attrs = list(fee_node.keys())
else:
# 如果fee_node是Neo4j Node对象,尝试直接访问属性
fee_value = (
getattr(fee_node, fee_attribute, None)
if hasattr(fee_node, fee_attribute)
else fee_node.get(fee_attribute) if hasattr(fee_node, "get") else None
)
all_attrs = list(fee_node.keys()) if hasattr(fee_node, "keys") else []
# 过滤掉私有属性
all_attrs = [attr for attr in all_attrs if not attr.startswith("_")]
# 使用“包含”逻辑模糊匹配属性名
matched_attrs = []
for attr in all_attrs:
if fee_attribute.lower() in attr.lower() or attr.lower() in fee_attribute.lower():
matched_attrs.append(attr)
if matched_attrs:
# 优先返回第一个匹配项的值
best_match = matched_attrs[0]
fee_value = fee_node.get(best_match)
else:
fee_value = None
if fee_value is None:
code_status = 201
@@ -1819,7 +1884,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
fee_attribute (str): 费用值属性名
Returns:
dict: 返回字典,字段包括:
dict: 包含一个字典,字段包括:
- code (int): 状态码,固定为 200(成功)或 201(失败)
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
- status (bool): 成功为 True,失败为 False
@@ -1840,6 +1905,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
else:
# 第一步:查找父节点(费用表节点)
table_name = table_name.replace("", "")
parent_path = f"工程/工程费用/{table_name}"
parent_node_data = self.get_node_by_path(parent_path)
@@ -1913,17 +1979,28 @@ class ProjectToolkitNeo4j(ProjectToolkit):
else:
# 第三步:获取费用节点的属性值
try:
if fee_node and hasattr(fee_node, "get"):
fee_value = fee_node.get(fee_attribute)
elif fee_node and isinstance(fee_node, dict):
fee_value = fee_node.get(fee_attribute)
if hasattr(fee_node, "keys"):
all_attrs = list(fee_node.keys())
elif isinstance(fee_node, dict):
all_attrs = list(fee_node.keys())
else:
# 如果fee_node是Neo4j Node对象,尝试直接访问属性
fee_value = (
getattr(fee_node, fee_attribute, None)
if hasattr(fee_node, fee_attribute)
else fee_node.get(fee_attribute) if hasattr(fee_node, "get") else None
)
all_attrs = list(fee_node.keys()) if hasattr(fee_node, "keys") else []
# 过滤掉私有属性
all_attrs = [attr for attr in all_attrs if not attr.startswith("_")]
# 使用“包含”逻辑模糊匹配属性名
matched_attrs = []
for attr in all_attrs:
if fee_attribute.lower() in attr.lower() or attr.lower() in fee_attribute.lower():
matched_attrs.append(attr)
if matched_attrs:
# 优先返回第一个匹配项的值
best_match = matched_attrs[0]
fee_value = fee_node.get(best_match)
else:
fee_value = None
if fee_value is None:
code_status = 201
@@ -1983,7 +2060,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
fee_attribute (str): 费用值属性名
Returns:
dict: 返回字典,字段包括:
dict: 包含一个字典,字段包括:
- code (int): 状态码,固定为 200(成功)或 201(失败)
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
- status (bool): 成功为 True,失败为 False
@@ -2004,6 +2081,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
else:
# 第一步:查找父节点(费用表节点)
table_name = table_name.replace("", "")
parent_path = f"工程/工程费用/{table_name}"
parent_node_data = self.get_node_by_path(parent_path)
@@ -2077,17 +2155,28 @@ class ProjectToolkitNeo4j(ProjectToolkit):
else:
# 第三步:获取费用节点的属性值
try:
if fee_node and hasattr(fee_node, "get"):
fee_value = fee_node.get(fee_attribute)
elif fee_node and isinstance(fee_node, dict):
fee_value = fee_node.get(fee_attribute)
if hasattr(fee_node, "keys"):
all_attrs = list(fee_node.keys())
elif isinstance(fee_node, dict):
all_attrs = list(fee_node.keys())
else:
# 如果fee_node是Neo4j Node对象,尝试直接访问属性
fee_value = (
getattr(fee_node, fee_attribute, None)
if hasattr(fee_node, fee_attribute)
else fee_node.get(fee_attribute) if hasattr(fee_node, "get") else None
)
all_attrs = list(fee_node.keys()) if hasattr(fee_node, "keys") else []
# 过滤掉私有属性
all_attrs = [attr for attr in all_attrs if not attr.startswith("_")]
# 使用“包含”逻辑模糊匹配属性名
matched_attrs = []
for attr in all_attrs:
if fee_attribute.lower() in attr.lower() or attr.lower() in fee_attribute.lower():
matched_attrs.append(attr)
if matched_attrs:
# 优先返回第一个匹配项的值
best_match = matched_attrs[0]
fee_value = fee_node.get(best_match)
else:
fee_value = None
if fee_value is None:
code_status = 201
@@ -2223,7 +2312,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
fee_name (str): 取费名称
Returns:
dict: 返回字典,字段包括:
dict: 包含一个字典,字段包括:
- code (int): 状态码,固定为 200(成功)或 201(失败)
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
- status (bool): 成功为 True,失败为 False
@@ -2308,10 +2397,10 @@ class ProjectToolkitNeo4j(ProjectToolkit):
else:
cost_set_node = record["c"]
# 第三步:在CostSet节点的子节点中查找名称为fee_name的CostItem节点
# 第三步:在CostSet节点的子节点中查找名称为fee_name的CostItem节点(模糊匹配)
query = """
MATCH (c:CostSet)-[*1..1]->(i:CostItem)
WHERE id(c) = $cost_set_id AND i.name = $fee_name
WHERE id(c) = $cost_set_id AND (i.name CONTAINS $fee_name OR $fee_name CONTAINS i.name)
RETURN i
LIMIT 1
"""
@@ -2324,7 +2413,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
if not record:
code_status = 201
status_flag = False
error = f"在CostSet节点下找不到名称为 {fee_name} 的CostItem节点"
error = f"在CostSet节点下找不到 {fee_name} 模糊匹配的CostItem节点"
# 查询该CostSet下所有CostItem节点名称作为辅助信息
helper_query = """
@@ -3033,3 +3122,4 @@ class ProjectToolkitNeo4j(ProjectToolkit):
# 统一返回格式包装成列表
return {"code": code_status, "message": message, "status": status_flag, "data": data}
+2 -17
View File
@@ -75,9 +75,9 @@ def project_get_calculate_function():
# 执行规则
- 参数必须从用户问题或上下文信息中提取
- 禁止在代码函数范围外添加任何注释或解释或非代码内容
- 输出代码中必须以def project_get_calculate_function() -> dict函数作为入口函数
- 必须确保生成的代码可以直接执行,代码要注意进行各类错误检查,出错采用抛出异常方式,说明详细信息
- 禁止添加任何注释或解释
- 必须确保生成的代码可以直接执行,如果函数功能求取数值,project的函数返回结果为空或出错则算成功,data为0,并在message说明错误原因,代码要注意进行各类容错检查
- ProjectToolkit 类中涉及项目划分的函数已考虑在其及其子孙项目划分下查找,所以无需生成递归子项目划分的代码
- 如果文本中包含范围编码格式则需要进行编码展开,如'YX2-1~7'展开为‘YX2-1/YX2-2/YX2-3/YX2-4/YX2-5/YX2-6/YX2-7
"""
@@ -172,21 +172,6 @@ Cypher查询语句:MATCH (item:ProjectDivisionItem)\nWHERE item.name CONTAINS
cypher_conversion_prompt=cypher_conversion_prompt,
)
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
@dataclass
class CodeExecutorPrompts:
understand_prompt: ChatPromptTemplate
-12
View File
@@ -12,18 +12,6 @@ from typing import List, Dict, Any
import asyncio
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
class BusinessObject(BaseModel):