实现完整功能
This commit is contained in:
+67
-37
@@ -12,23 +12,11 @@ from src.project import ProjectBuilder, ProjectToolkit
|
||||
import sys
|
||||
import io
|
||||
import traceback
|
||||
import importlib
|
||||
|
||||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
log_filename = f"{current_file}_{now_str}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
def __init__(self, prompts, llm_client, max_retries):
|
||||
self.llm_client = llm_client
|
||||
@@ -36,45 +24,87 @@ class CodeExecutor:
|
||||
self.max_retries = max_retries if max_retries >= 1 else 1
|
||||
self.output_parser = StrOutputParser()
|
||||
|
||||
def generate_code(self, user_request: str, context: str = "", bowei_api_docs: str = "") -> str:
|
||||
def generate_code(self, user_request: str, context: str = "", bowei_api_docs: str = "") -> dict:
|
||||
logger.info(f"开始生成代码,访问请求:{user_request}")
|
||||
prompt = self.prompts.code_gen_prompt.format_prompt(
|
||||
user_request=user_request, context=context, bowei_api_docs=bowei_api_docs
|
||||
)
|
||||
|
||||
response = self.llm_client.invoke(prompt.to_messages())
|
||||
code = self.output_parser.parse(response)
|
||||
logger.debug(f"生成的代码内容:\n{code}")
|
||||
try:
|
||||
response = self.llm_client.invoke(prompt.to_messages())
|
||||
parsed_response = self.output_parser.parse(response)
|
||||
|
||||
# 处理 AIMessage 类型的返回值
|
||||
if hasattr(parsed_response, 'content'):
|
||||
code = parsed_response.content
|
||||
else:
|
||||
code = str(parsed_response)
|
||||
|
||||
logger.debug(f"生成的代码内容:\n{code}")
|
||||
return {
|
||||
"code": 20000,
|
||||
"message": 'ok',
|
||||
"status": True,
|
||||
"data": code
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"大模型调用失败: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"code": 50000,
|
||||
"message": f'大模型调用失败: {str(e)}',
|
||||
"status": False,
|
||||
"data": None
|
||||
}
|
||||
|
||||
return {
|
||||
"code": 20000,
|
||||
"message": 'ok',
|
||||
"status": True,
|
||||
"data": code.content
|
||||
}
|
||||
|
||||
def fix_code(self, code: str, error: str) -> str:
|
||||
def fix_code(self, code: str, error: str) -> dict:
|
||||
logger.warning(f"代码执行出错,开始修复。错误信息:{error}")
|
||||
prompt = self.prompts.code_fix_prompt.format_prompt(code=code, error=error)
|
||||
response = self.llm_client.invoke(prompt.to_messages())
|
||||
fixed_code = self.output_parser.parse(response)
|
||||
logger.debug(f"修复后的代码内容:\n{fixed_code}")
|
||||
|
||||
return {
|
||||
"code": 20000,
|
||||
"message": 'ok',
|
||||
"status": True,
|
||||
"data": fixed_code.content
|
||||
}
|
||||
try:
|
||||
response = self.llm_client.invoke(prompt.to_messages())
|
||||
parsed_response = self.output_parser.parse(response)
|
||||
|
||||
# 处理 AIMessage 类型的返回值
|
||||
if hasattr(parsed_response, 'content'):
|
||||
fixed_code = parsed_response.content
|
||||
else:
|
||||
fixed_code = str(parsed_response)
|
||||
|
||||
logger.debug(f"修复后的代码内容:\n{fixed_code}")
|
||||
return {
|
||||
"code": 20000,
|
||||
"message": 'ok',
|
||||
"status": True,
|
||||
"data": fixed_code
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"代码修复时大模型调用失败: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"code": 50001,
|
||||
"message": f'代码修复失败: {str(e)}',
|
||||
"status": False,
|
||||
"data": None
|
||||
}
|
||||
|
||||
def execute_code(self, code_str) -> dict:
|
||||
"""封装代码执行逻辑"""
|
||||
logger.debug(f"开始执行代码:\n {code_str}")
|
||||
|
||||
try:
|
||||
import re
|
||||
pattern = r'```python(.*?)```'
|
||||
match = re.search(pattern, code_str, re.DOTALL)
|
||||
if match:
|
||||
code_str = match.group(1).strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"解析生成代码格式时发生异常: {str(e)}")
|
||||
|
||||
old_stdout = None
|
||||
try:
|
||||
namespace = {
|
||||
"project": __import__("src.project"),
|
||||
"project": importlib.import_module("src.project"),
|
||||
"Material": getattr(importlib.import_module("src.project"), "Material", None),
|
||||
"Ration": getattr(importlib.import_module("src.project"), "Ration", None),
|
||||
"Equipment": getattr(importlib.import_module("src.project"), "Equipment", None),
|
||||
"MaterialOrEquipment": getattr(importlib.import_module("src.project"), "MaterialOrEquipment", None),
|
||||
"ProjectBuilder": ProjectBuilder,
|
||||
}
|
||||
|
||||
|
||||
@@ -5,21 +5,6 @@ import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
log_filename = f"{current_file}_{now_str}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
class Config:
|
||||
def __init__(self, path="config.yaml"):
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
|
||||
@@ -8,18 +8,6 @@ from langchain.schema import SystemMessage, HumanMessage
|
||||
import asyncio
|
||||
|
||||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
log_filename = f"{current_file}_{now_str}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
class QuestionProcessor:
|
||||
|
||||
@@ -1,23 +1,7 @@
|
||||
# src/document_loader.py
|
||||
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
log_filename = f"{current_file}_{now_str}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
def load_file(path: str) -> str:
|
||||
try:
|
||||
|
||||
+1
-12
@@ -6,20 +6,9 @@ from datetime import datetime
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
log_filename = f"{current_file}_{now_str}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
|
||||
class EmbeddingClient:
|
||||
def __init__(self, embedding_config: dict):
|
||||
api_key = embedding_config.get("api_key")
|
||||
|
||||
+2
-12
@@ -8,21 +8,11 @@ import getpass
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||
|
||||
|
||||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
log_filename = f"{current_file}_{now_str}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self, openai_config: dict):
|
||||
api_key = openai_config.get("api_key")
|
||||
|
||||
+2
-12
@@ -9,19 +9,8 @@ import itertools
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||
|
||||
|
||||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
log_filename = f"{current_file}_{now_str}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
class MultiAPIKeyChatOpenAI:
|
||||
@@ -63,6 +52,7 @@ class MultiAPIKeyChatOpenAI:
|
||||
|
||||
# 轮询器,用于循环调用不同的 llm 实例
|
||||
self._llm_cycle = itertools.cycle(self.llms)
|
||||
self.llm = next(self._llm_cycle)
|
||||
|
||||
def invoke(self, messages):
|
||||
llm = next(self._llm_cycle)
|
||||
|
||||
@@ -7,21 +7,11 @@ from typing import List
|
||||
from langchain.schema import Document
|
||||
from neo4j import GraphDatabase
|
||||
|
||||
|
||||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
log_filename = f"{current_file}_{now_str}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
|
||||
class Neo4jRawRetriever:
|
||||
def __init__(self, neo4j_conf: dict):
|
||||
self.uri = neo4j_conf.get("uri")
|
||||
|
||||
+2
-12
@@ -6,21 +6,11 @@ from datetime import datetime
|
||||
from langchain_neo4j import Neo4jVector
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
|
||||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
log_filename = f"{current_file}_{now_str}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
|
||||
class Neo4jKnowledgeRetriever:
|
||||
def __init__(self, neo4j_conf: dict, embedding_client):
|
||||
neo4j_uri = neo4j_conf.get("uri")
|
||||
|
||||
+22
-22
@@ -5,6 +5,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
from typing import Any, Type
|
||||
|
||||
|
||||
class ProjectToolkit(ABC):
|
||||
@@ -13,13 +14,13 @@ class ProjectToolkit(ABC):
|
||||
描述: 代表整个项目结构的顶层容器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.project_division_set = ProjectDivisionItem() # 项目划分集对象
|
||||
def __init__(self, config: Any):
|
||||
pass
|
||||
|
||||
# 项目划分查询方法
|
||||
|
||||
@abstractmethod
|
||||
def get_division_by_name(self, name):
|
||||
def get_division_by_name(self, name_part):
|
||||
"""
|
||||
通过名称获取项目划分对象
|
||||
|
||||
@@ -89,7 +90,7 @@ class ProjectToolkit(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_quantities_node_by_parent_and_code(self, parent_path, quantity_type=None, code=None):
|
||||
def get_quantities_node_by_parent_and_code(self, parent_path, quantity_type, code):
|
||||
"""
|
||||
通过父节点路径和编码获取工程量对象(定额、主材或设备),包括子节点
|
||||
|
||||
@@ -108,14 +109,15 @@ class ProjectToolkit(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_quantities_node_by_parent_and_name(self, parent_path, partial_name, quantity_type=None):
|
||||
def get_quantities_node_by_parent_and_name(self, parent_path, quantity_type, partial_name):
|
||||
"""
|
||||
通过父节点路径、模糊节点名称和类型获取工程量对象(主材或者设备),包括子节点
|
||||
通过父节点路径、类型和模糊节点名称获取工程量对象(主材或者设备),包括子节点
|
||||
|
||||
Args:
|
||||
parent_path (str): 父节点的路径,以'/'分隔的多级节点路径
|
||||
partial_name (str): 目标节点的模糊或不完整名称
|
||||
quantity_type (str): 工程量类型('定额'、'主材'、'设备')
|
||||
partial_name (str): 目标节点的模糊或不完整名称
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 返回字典,字段包括:
|
||||
@@ -202,7 +204,7 @@ class ProjectToolkit(ABC):
|
||||
|
||||
# 费用表查询方法
|
||||
@abstractmethod
|
||||
def get_fee_schedule_on_auxiliary_expense_table(self, table_name, fee_name, fee: str):
|
||||
def get_fee_schedule_on_auxiliary_expense_table(self, table_name, fee_name, fee_attribute: str):
|
||||
"""
|
||||
在辅助费用表中查找费用
|
||||
|
||||
@@ -221,7 +223,7 @@ class ProjectToolkit(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_fee_schedule_on_other_expense_table(self, table_name, fee_name, fee):
|
||||
def get_fee_schedule_on_other_expense_table(self, table_name, fee_name, fee_attribute):
|
||||
"""
|
||||
在其它费用表中查找费用
|
||||
|
||||
@@ -240,7 +242,7 @@ class ProjectToolkit(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_fee_schedule_on_land_acquisition_fee_table_table(self, table_name, fee_name, fee):
|
||||
def get_fee_schedule_on_land_acquisition_fee_table_table(self, table_name, fee_name, fee_attribute):
|
||||
"""
|
||||
在其中:场地征用费用表中查找费用
|
||||
|
||||
@@ -259,7 +261,7 @@ class ProjectToolkit(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_fee_schedule_on_installation_price_difference_table(self, table_name, fee_name, fee):
|
||||
def get_fee_schedule_on_installation_price_difference_table(self, table_name, fee_name, fee_attribute):
|
||||
"""
|
||||
在安装价差费用表中查找费用
|
||||
|
||||
@@ -278,7 +280,7 @@ class ProjectToolkit(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_fee_schedule_on_Engineering_Cost_table(self, table_name, fee_name, fee):
|
||||
def get_fee_schedule_on_Engineering_Cost_table(self, table_name, fee_name, fee_attribute):
|
||||
"""
|
||||
在工程费用表中查找费用
|
||||
|
||||
@@ -687,18 +689,19 @@ class Fee:
|
||||
self.施工费 = None # xsd:string (可选)
|
||||
self.单位投资 = None # xsd:string (可选)
|
||||
|
||||
class ProjectBuilder:
|
||||
# 存储注册的工具类
|
||||
_registry = None
|
||||
_config = {}
|
||||
|
||||
class ProjectBuilder:
|
||||
"""项目工具工厂类"""
|
||||
|
||||
# 存储注册的工具类
|
||||
_registry: Type[ProjectToolkit] | None = None
|
||||
_config: Any = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def register(cls, toolkit_class: type, config: dict):
|
||||
def register(cls, toolkit_class: Type[ProjectToolkit], config: Any):
|
||||
"""
|
||||
注册工具类到工厂
|
||||
|
||||
@@ -716,13 +719,10 @@ class ProjectBuilder:
|
||||
"""
|
||||
创建工具实例
|
||||
|
||||
参数:
|
||||
|
||||
返回:
|
||||
实例化的工具对象
|
||||
"""
|
||||
if cls._registry is None:
|
||||
raise KeyError(f"未注册的类,请先注册类")
|
||||
raise KeyError("未注册的类,请先注册类")
|
||||
|
||||
|
||||
return cls._registry(cls._config)
|
||||
return cls._registry(cls._config)
|
||||
+173
-83
@@ -25,7 +25,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
if neo4j_driver is None:
|
||||
raise ValueError("必须提供Neo4j驱动实例")
|
||||
|
||||
super().__init__()
|
||||
super().__init__(neo4j_driver)
|
||||
|
||||
# 保存驱动实例
|
||||
self.driver = neo4j_driver
|
||||
@@ -164,7 +164,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
if not node_data:
|
||||
code = 201
|
||||
status_flag = False
|
||||
error = f"找不到路径: {path} 上的ProjectDivisionItem节点"
|
||||
error = f"错误信息:找不到路径: {path} 上的ProjectDivisionItem节点"
|
||||
|
||||
# 提取父路径
|
||||
path_parts = path.split("/")
|
||||
@@ -190,7 +190,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
if record["name"]:
|
||||
helper_info.append(record["name"])
|
||||
except Exception as e:
|
||||
helper_info = [f"查询父节点下的子节点时出错: {str(e)}"]
|
||||
helper_info = [f"错误信息:查询父节点下的子节点时出错: {str(e)}"]
|
||||
|
||||
# 拼接 message
|
||||
message = f"错误信息:{error}; 辅助信息: {helper_info}"
|
||||
@@ -246,7 +246,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
if not partial_name:
|
||||
code = 201
|
||||
status_flag = False
|
||||
message = "节点名称不能为空"
|
||||
message = "错误信息:partial_name参数错误,参数不能为空"
|
||||
else:
|
||||
try:
|
||||
# 第一步:找到父节点
|
||||
@@ -385,7 +385,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
if not name_part or name_part.strip() == "":
|
||||
code = 201
|
||||
status_flag = False
|
||||
message = "输入的名称部分不能为空"
|
||||
message = "错误信息:name_part参数错误,参数不能为空"
|
||||
else:
|
||||
try:
|
||||
# 直接查询所有类型为ProjectDivisionItem且name包含输入名称的节点
|
||||
@@ -482,7 +482,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
if not paths_str:
|
||||
code = 201
|
||||
status_flag = False
|
||||
message = "路径不能为空"
|
||||
message = "错误信息:paths_str参数错误,参数不能为空"
|
||||
else:
|
||||
try:
|
||||
# 使用通用方法获取节点,考虑所有可能的工程量类型
|
||||
@@ -547,7 +547,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
# 最终统一返回格式包装成列表
|
||||
return {"code": code, "message": message, "status": status_flag, "data": data}
|
||||
|
||||
def get_quantities_node_by_parent_and_code(self, parent_path, quantity_type=None, code=None) -> dict:
|
||||
def get_quantities_node_by_parent_and_code(self, parent_path, quantity_type, code) -> dict:
|
||||
"""
|
||||
通过父节点路径和编码获取工程量对象(定额、主材或设备),包括子节点
|
||||
|
||||
@@ -579,13 +579,13 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
if not code or code.strip() == "":
|
||||
code_status = 201
|
||||
status_flag = False
|
||||
message = "编码不能为空"
|
||||
message = "错误信息:code参数错误,参数不能为空"
|
||||
|
||||
valid_types = ["定额", "主材", "设备", None]
|
||||
if quantity_type not in valid_types:
|
||||
code_status = 201
|
||||
status_flag = False
|
||||
message = f"无效的工程量类型: '{quantity_type}';有效类型为 {valid_types}"
|
||||
message = f"错误信息:quantity_type参数错误,有效类型为 {valid_types},当前值是'{quantity_type}'"
|
||||
|
||||
elif code_status == 200:
|
||||
try:
|
||||
@@ -701,7 +701,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
|
||||
def get_quantities_node_by_parent_and_name(self, parent_path, quantity_type, partial_name) -> dict:
|
||||
"""
|
||||
通过父节点路径、模糊节点名称和类型获取工程量对象(主材或者设备),包括子节点
|
||||
通过父节点路径、类型和模糊节点名称获取工程量对象(主材或者设备),包括子节点
|
||||
|
||||
执行三步查询:
|
||||
1. 找到对应路径的父节点
|
||||
@@ -714,7 +714,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
quantity_type (str): 工程量类型('定额'、'主材'、'设备')
|
||||
|
||||
Returns:
|
||||
dict: 返回字典,字段包括:
|
||||
dict: 包含一个字典,字段包括:
|
||||
- code (int): 状态码,固定为 200(成功)或 201(失败)
|
||||
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
|
||||
- status (bool): 成功为 True,失败为 False
|
||||
@@ -825,30 +825,59 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
matching_nodes = []
|
||||
available_names = []
|
||||
|
||||
# 统一处理 partial_name 为列表形式
|
||||
if isinstance(partial_name, str):
|
||||
if partial_name.strip() == "":
|
||||
code_status = 201
|
||||
status_flag = False
|
||||
message = "partial_name 不能为空"
|
||||
else:
|
||||
# 支持中文逗号、英文逗号、空格分隔
|
||||
import re
|
||||
|
||||
keywords = re.split(r"[,,、\s]+", partial_name.strip())
|
||||
else:
|
||||
keywords = partial_name or []
|
||||
|
||||
if not keywords:
|
||||
code_status = 201
|
||||
status_flag = False
|
||||
message = "未提供有效的关键词"
|
||||
|
||||
if code_status != 200:
|
||||
return {"code": code_status, "message": message, "status": status_flag, "data": data}
|
||||
|
||||
# 尝试用 Python 端过滤
|
||||
for node in child_nodes:
|
||||
node_name = node.get("name", "")
|
||||
if node_name:
|
||||
available_names.append(node_name)
|
||||
available_names.append(node_name)
|
||||
|
||||
if partial_name in str(node_name):
|
||||
matching_nodes.append(node)
|
||||
for keyword in keywords:
|
||||
if keyword in str(node_name):
|
||||
matching_nodes.append(node)
|
||||
break
|
||||
|
||||
# 如果 Python 端没找到,尝试数据库端模糊搜索
|
||||
if not matching_nodes:
|
||||
# 构建动态查询语句
|
||||
where_clauses = " OR ".join([f"q.name CONTAINS '{keyword}'" for keyword in keywords])
|
||||
direct_query = f"""
|
||||
MATCH (q)
|
||||
WHERE q.name CONTAINS $partial_name AND {type_condition}
|
||||
WHERE ({where_clauses}) AND {type_condition}
|
||||
RETURN q LIMIT 20
|
||||
"""
|
||||
direct_params = {"partial_name": partial_name}
|
||||
|
||||
direct_result = self.session.run(direct_query, **direct_params)
|
||||
matching_nodes = [record["q"] for record in direct_result]
|
||||
try:
|
||||
direct_result = self.session.run(direct_query)
|
||||
matching_nodes = [record["q"] for record in direct_result]
|
||||
except Exception as e:
|
||||
code_status = 201
|
||||
status_flag = False
|
||||
message = f"数据库模糊查询失败: {str(e)}"
|
||||
|
||||
if not matching_nodes:
|
||||
code_status = 201
|
||||
status_flag = False
|
||||
message = f"错误信息:在父节点路径'{parent_path}' 下找不到名称包含 '{partial_name}' 的节点;辅助信息:{available_names}"
|
||||
|
||||
message = f"错误信息:在父节点路径'{parent_path}' 下找不到包含关键词 {keywords} 的节点;辅助信息:{available_names}"
|
||||
else:
|
||||
result_data = []
|
||||
for node in matching_nodes:
|
||||
@@ -1008,7 +1037,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
if not parent_path or not code:
|
||||
code_status = 201
|
||||
status_flag = False
|
||||
message = "父节点路径或要查找的编码不能为空"
|
||||
message = "错误信息:parent_path或code参数错误,参数不能为空"
|
||||
|
||||
else:
|
||||
try:
|
||||
@@ -1323,7 +1352,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
fee_attribute (str): 费用值属性名
|
||||
|
||||
Returns:
|
||||
dict: 返回字典,字段包括:
|
||||
dict: 包含一个字典,字段包括:
|
||||
- code (int): 状态码,固定为 200(成功)或 201(失败)
|
||||
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
|
||||
- status (bool): 成功为 True,失败为 False
|
||||
@@ -1344,6 +1373,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
|
||||
else:
|
||||
# 第一步:查找父节点(费用表节点)
|
||||
table_name = table_name.replace("表", "")
|
||||
parent_path = f"工程/工程费用/{table_name}"
|
||||
parent_node_data = self.get_node_by_path(parent_path)
|
||||
|
||||
@@ -1417,17 +1447,28 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
else:
|
||||
# 第三步:获取费用节点的属性值
|
||||
try:
|
||||
if fee_node and hasattr(fee_node, "get"):
|
||||
fee_value = fee_node.get(fee_attribute)
|
||||
elif fee_node and isinstance(fee_node, dict):
|
||||
fee_value = fee_node.get(fee_attribute)
|
||||
if hasattr(fee_node, "keys"):
|
||||
all_attrs = list(fee_node.keys())
|
||||
elif isinstance(fee_node, dict):
|
||||
all_attrs = list(fee_node.keys())
|
||||
else:
|
||||
# 如果fee_node是Neo4j Node对象,尝试直接访问属性
|
||||
fee_value = (
|
||||
getattr(fee_node, fee_attribute, None)
|
||||
if hasattr(fee_node, fee_attribute)
|
||||
else fee_node.get(fee_attribute) if hasattr(fee_node, "get") else None
|
||||
)
|
||||
all_attrs = list(fee_node.keys()) if hasattr(fee_node, "keys") else []
|
||||
|
||||
# 过滤掉私有属性
|
||||
all_attrs = [attr for attr in all_attrs if not attr.startswith("_")]
|
||||
|
||||
# 使用“包含”逻辑模糊匹配属性名
|
||||
matched_attrs = []
|
||||
for attr in all_attrs:
|
||||
if fee_attribute.lower() in attr.lower() or attr.lower() in fee_attribute.lower():
|
||||
matched_attrs.append(attr)
|
||||
|
||||
if matched_attrs:
|
||||
# 优先返回第一个匹配项的值
|
||||
best_match = matched_attrs[0]
|
||||
fee_value = fee_node.get(best_match)
|
||||
else:
|
||||
fee_value = None
|
||||
|
||||
if fee_value is None:
|
||||
code_status = 201
|
||||
@@ -1487,7 +1528,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
fee_attribute (str): 费用值属性名
|
||||
|
||||
Returns:
|
||||
dict: 返回字典,字段包括:
|
||||
dict: 包含一个字典,字段包括:
|
||||
- code (int): 状态码,固定为 200(成功)或 201(失败)
|
||||
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
|
||||
- status (bool): 成功为 True,失败为 False
|
||||
@@ -1508,6 +1549,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
|
||||
else:
|
||||
# 第一步:查找父节点(费用表节点)
|
||||
table_name = table_name.replace("表", "")
|
||||
parent_path = f"工程/工程费用/{table_name}"
|
||||
parent_node_data = self.get_node_by_path(parent_path)
|
||||
|
||||
@@ -1541,7 +1583,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
|
||||
# 递归查找费用节点,最多查找3级深度
|
||||
query = """
|
||||
MATCH (p)-[*1..3]->(f)
|
||||
MATCH (p)-[*1..]->(f)
|
||||
WHERE p.name = $table_name AND f.name = $fee_name
|
||||
RETURN f LIMIT 1
|
||||
"""
|
||||
@@ -1581,17 +1623,28 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
else:
|
||||
# 第三步:获取费用节点的属性值
|
||||
try:
|
||||
if fee_node and hasattr(fee_node, "get"):
|
||||
fee_value = fee_node.get(fee_attribute)
|
||||
elif fee_node and isinstance(fee_node, dict):
|
||||
fee_value = fee_node.get(fee_attribute)
|
||||
if hasattr(fee_node, "keys"):
|
||||
all_attrs = list(fee_node.keys())
|
||||
elif isinstance(fee_node, dict):
|
||||
all_attrs = list(fee_node.keys())
|
||||
else:
|
||||
# 如果fee_node是Neo4j Node对象,尝试直接访问属性
|
||||
fee_value = (
|
||||
getattr(fee_node, fee_attribute, None)
|
||||
if hasattr(fee_node, fee_attribute)
|
||||
else fee_node.get(fee_attribute) if hasattr(fee_node, "get") else None
|
||||
)
|
||||
all_attrs = list(fee_node.keys()) if hasattr(fee_node, "keys") else []
|
||||
|
||||
# 过滤掉私有属性
|
||||
all_attrs = [attr for attr in all_attrs if not attr.startswith("_")]
|
||||
|
||||
# 使用“包含”逻辑模糊匹配属性名
|
||||
matched_attrs = []
|
||||
for attr in all_attrs:
|
||||
if fee_attribute.lower() in attr.lower() or attr.lower() in fee_attribute.lower():
|
||||
matched_attrs.append(attr)
|
||||
|
||||
if matched_attrs:
|
||||
# 优先返回第一个匹配项的值
|
||||
best_match = matched_attrs[0]
|
||||
fee_value = fee_node.get(best_match)
|
||||
else:
|
||||
fee_value = None
|
||||
|
||||
if fee_value is None:
|
||||
code_status = 201
|
||||
@@ -1653,7 +1706,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
fee_attribute (str): 费用值属性名
|
||||
|
||||
Returns:
|
||||
dict: 返回字典,字段包括:
|
||||
dict: 包含一个字典,字段包括:
|
||||
- code (int): 状态码,固定为 200(成功)或 201(失败)
|
||||
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
|
||||
- status (bool): 成功为 True,失败为 False
|
||||
@@ -1674,6 +1727,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
|
||||
else:
|
||||
# 第一步:查找父节点(费用表节点)
|
||||
table_name = table_name.replace("表", "")
|
||||
parent_path = f"工程/工程费用/{table_name}"
|
||||
parent_node_data = self.get_node_by_path(parent_path)
|
||||
|
||||
@@ -1747,17 +1801,28 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
else:
|
||||
# 第三步:获取费用节点的属性值
|
||||
try:
|
||||
if fee_node and hasattr(fee_node, "get"):
|
||||
fee_value = fee_node.get(fee_attribute)
|
||||
elif fee_node and isinstance(fee_node, dict):
|
||||
fee_value = fee_node.get(fee_attribute)
|
||||
if hasattr(fee_node, "keys"):
|
||||
all_attrs = list(fee_node.keys())
|
||||
elif isinstance(fee_node, dict):
|
||||
all_attrs = list(fee_node.keys())
|
||||
else:
|
||||
# 如果fee_node是Neo4j Node对象,尝试直接访问属性
|
||||
fee_value = (
|
||||
getattr(fee_node, fee_attribute, None)
|
||||
if hasattr(fee_node, fee_attribute)
|
||||
else fee_node.get(fee_attribute) if hasattr(fee_node, "get") else None
|
||||
)
|
||||
all_attrs = list(fee_node.keys()) if hasattr(fee_node, "keys") else []
|
||||
|
||||
# 过滤掉私有属性
|
||||
all_attrs = [attr for attr in all_attrs if not attr.startswith("_")]
|
||||
|
||||
# 使用“包含”逻辑模糊匹配属性名
|
||||
matched_attrs = []
|
||||
for attr in all_attrs:
|
||||
if fee_attribute.lower() in attr.lower() or attr.lower() in fee_attribute.lower():
|
||||
matched_attrs.append(attr)
|
||||
|
||||
if matched_attrs:
|
||||
# 优先返回第一个匹配项的值
|
||||
best_match = matched_attrs[0]
|
||||
fee_value = fee_node.get(best_match)
|
||||
else:
|
||||
fee_value = None
|
||||
|
||||
if fee_value is None:
|
||||
code_status = 201
|
||||
@@ -1819,7 +1884,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
fee_attribute (str): 费用值属性名
|
||||
|
||||
Returns:
|
||||
dict: 返回字典,字段包括:
|
||||
dict: 包含一个字典,字段包括:
|
||||
- code (int): 状态码,固定为 200(成功)或 201(失败)
|
||||
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
|
||||
- status (bool): 成功为 True,失败为 False
|
||||
@@ -1840,6 +1905,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
|
||||
else:
|
||||
# 第一步:查找父节点(费用表节点)
|
||||
table_name = table_name.replace("表", "")
|
||||
parent_path = f"工程/工程费用/{table_name}"
|
||||
parent_node_data = self.get_node_by_path(parent_path)
|
||||
|
||||
@@ -1913,17 +1979,28 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
else:
|
||||
# 第三步:获取费用节点的属性值
|
||||
try:
|
||||
if fee_node and hasattr(fee_node, "get"):
|
||||
fee_value = fee_node.get(fee_attribute)
|
||||
elif fee_node and isinstance(fee_node, dict):
|
||||
fee_value = fee_node.get(fee_attribute)
|
||||
if hasattr(fee_node, "keys"):
|
||||
all_attrs = list(fee_node.keys())
|
||||
elif isinstance(fee_node, dict):
|
||||
all_attrs = list(fee_node.keys())
|
||||
else:
|
||||
# 如果fee_node是Neo4j Node对象,尝试直接访问属性
|
||||
fee_value = (
|
||||
getattr(fee_node, fee_attribute, None)
|
||||
if hasattr(fee_node, fee_attribute)
|
||||
else fee_node.get(fee_attribute) if hasattr(fee_node, "get") else None
|
||||
)
|
||||
all_attrs = list(fee_node.keys()) if hasattr(fee_node, "keys") else []
|
||||
|
||||
# 过滤掉私有属性
|
||||
all_attrs = [attr for attr in all_attrs if not attr.startswith("_")]
|
||||
|
||||
# 使用“包含”逻辑模糊匹配属性名
|
||||
matched_attrs = []
|
||||
for attr in all_attrs:
|
||||
if fee_attribute.lower() in attr.lower() or attr.lower() in fee_attribute.lower():
|
||||
matched_attrs.append(attr)
|
||||
|
||||
if matched_attrs:
|
||||
# 优先返回第一个匹配项的值
|
||||
best_match = matched_attrs[0]
|
||||
fee_value = fee_node.get(best_match)
|
||||
else:
|
||||
fee_value = None
|
||||
|
||||
if fee_value is None:
|
||||
code_status = 201
|
||||
@@ -1983,7 +2060,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
fee_attribute (str): 费用值属性名
|
||||
|
||||
Returns:
|
||||
dict: 返回字典,字段包括:
|
||||
dict: 包含一个字典,字段包括:
|
||||
- code (int): 状态码,固定为 200(成功)或 201(失败)
|
||||
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
|
||||
- status (bool): 成功为 True,失败为 False
|
||||
@@ -2004,6 +2081,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
|
||||
else:
|
||||
# 第一步:查找父节点(费用表节点)
|
||||
table_name = table_name.replace("表", "")
|
||||
parent_path = f"工程/工程费用/{table_name}"
|
||||
parent_node_data = self.get_node_by_path(parent_path)
|
||||
|
||||
@@ -2077,17 +2155,28 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
else:
|
||||
# 第三步:获取费用节点的属性值
|
||||
try:
|
||||
if fee_node and hasattr(fee_node, "get"):
|
||||
fee_value = fee_node.get(fee_attribute)
|
||||
elif fee_node and isinstance(fee_node, dict):
|
||||
fee_value = fee_node.get(fee_attribute)
|
||||
if hasattr(fee_node, "keys"):
|
||||
all_attrs = list(fee_node.keys())
|
||||
elif isinstance(fee_node, dict):
|
||||
all_attrs = list(fee_node.keys())
|
||||
else:
|
||||
# 如果fee_node是Neo4j Node对象,尝试直接访问属性
|
||||
fee_value = (
|
||||
getattr(fee_node, fee_attribute, None)
|
||||
if hasattr(fee_node, fee_attribute)
|
||||
else fee_node.get(fee_attribute) if hasattr(fee_node, "get") else None
|
||||
)
|
||||
all_attrs = list(fee_node.keys()) if hasattr(fee_node, "keys") else []
|
||||
|
||||
# 过滤掉私有属性
|
||||
all_attrs = [attr for attr in all_attrs if not attr.startswith("_")]
|
||||
|
||||
# 使用“包含”逻辑模糊匹配属性名
|
||||
matched_attrs = []
|
||||
for attr in all_attrs:
|
||||
if fee_attribute.lower() in attr.lower() or attr.lower() in fee_attribute.lower():
|
||||
matched_attrs.append(attr)
|
||||
|
||||
if matched_attrs:
|
||||
# 优先返回第一个匹配项的值
|
||||
best_match = matched_attrs[0]
|
||||
fee_value = fee_node.get(best_match)
|
||||
else:
|
||||
fee_value = None
|
||||
|
||||
if fee_value is None:
|
||||
code_status = 201
|
||||
@@ -2223,7 +2312,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
fee_name (str): 取费名称
|
||||
|
||||
Returns:
|
||||
dict: 返回字典,字段包括:
|
||||
dict: 包含一个字典,字段包括:
|
||||
- code (int): 状态码,固定为 200(成功)或 201(失败)
|
||||
- message (str): 成功时为 "Ok",失败时包含错误信息和辅助信息
|
||||
- status (bool): 成功为 True,失败为 False
|
||||
@@ -2308,10 +2397,10 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
else:
|
||||
cost_set_node = record["c"]
|
||||
|
||||
# 第三步:在CostSet节点的子节点中查找名称为fee_name的CostItem节点
|
||||
# 第三步:在CostSet节点的子节点中查找名称为fee_name的CostItem节点(模糊匹配)
|
||||
query = """
|
||||
MATCH (c:CostSet)-[*1..1]->(i:CostItem)
|
||||
WHERE id(c) = $cost_set_id AND i.name = $fee_name
|
||||
WHERE id(c) = $cost_set_id AND (i.name CONTAINS $fee_name OR $fee_name CONTAINS i.name)
|
||||
RETURN i
|
||||
LIMIT 1
|
||||
"""
|
||||
@@ -2324,7 +2413,7 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
if not record:
|
||||
code_status = 201
|
||||
status_flag = False
|
||||
error = f"在CostSet节点下找不到名称为 {fee_name} 的CostItem节点"
|
||||
error = f"在CostSet节点下找不到与 {fee_name} 模糊匹配的CostItem节点"
|
||||
|
||||
# 查询该CostSet下所有CostItem节点名称作为辅助信息
|
||||
helper_query = """
|
||||
@@ -3033,3 +3122,4 @@ class ProjectToolkitNeo4j(ProjectToolkit):
|
||||
|
||||
# 统一返回格式包装成列表
|
||||
return {"code": code_status, "message": message, "status": status_flag, "data": data}
|
||||
|
||||
|
||||
+2
-17
@@ -75,9 +75,9 @@ def project_get_calculate_function():
|
||||
|
||||
# 执行规则
|
||||
- 参数必须从用户问题或上下文信息中提取
|
||||
- 禁止在代码函数范围外添加任何注释或解释或非代码内容
|
||||
- 输出代码中必须以def project_get_calculate_function() -> dict函数作为入口函数
|
||||
- 必须确保生成的代码可以直接执行,代码要注意进行各类错误检查,出错采用抛出异常方式,说明详细信息
|
||||
- 禁止添加任何注释或解释
|
||||
- 必须确保生成的代码可以直接执行,如果函数功能求取数值,project的函数返回结果为空或出错则算成功,data为0,并在message说明错误原因,代码要注意进行各类容错检查
|
||||
- ProjectToolkit 类中涉及项目划分的函数已考虑在其及其子孙项目划分下查找,所以无需生成递归子项目划分的代码
|
||||
- 如果文本中包含范围编码格式则需要进行编码展开,如'YX2-1~7'展开为‘YX2-1/YX2-2/YX2-3/YX2-4/YX2-5/YX2-6/YX2-7’
|
||||
"""
|
||||
@@ -172,21 +172,6 @@ Cypher查询语句:MATCH (item:ProjectDivisionItem)\nWHERE item.name CONTAINS
|
||||
cypher_conversion_prompt=cypher_conversion_prompt,
|
||||
)
|
||||
|
||||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
log_filename = f"{current_file}_{now_str}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
@dataclass
|
||||
class CodeExecutorPrompts:
|
||||
understand_prompt: ChatPromptTemplate
|
||||
|
||||
@@ -12,18 +12,6 @@ from typing import List, Dict, Any
|
||||
import asyncio
|
||||
|
||||
current_file = os.path.splitext(os.path.basename(__file__))[0]
|
||||
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
log_filename = f"{current_file}_{now_str}.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
class BusinessObject(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user