实现完整功能

This commit is contained in:
2025-07-07 08:23:02 +08:00
parent 35d50305c8
commit d1c129c691
20 changed files with 504 additions and 469 deletions
+88 -74
View File
@@ -5,6 +5,7 @@ import json
import logging
import os
from datetime import datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 获取当前时间,格式化为字符串
@@ -21,14 +22,12 @@ log_filename = f"{current_file}_{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
handlers=[logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"), logging.StreamHandler()],
)
logger = logging.getLogger(current_file)
def setup_logger(logger_name):
"""
设置指定名称的logger,将其级别设置为WARNING并禁用传播
@@ -55,34 +54,45 @@ from src.project import ProjectBuilder, ProjectToolkit
from src.project_implementation import ProjectToolkitNeo4j
from src.code_executor import CodeExecutor
config = Config()
business_structure = load_file(config.business_object_structure_path)
bowei_api_docs = load_file(config.bowei_api_docs_path)
# 初始化资源和客户端的函数,使用缓存避免重复加载
@st.cache_resource
def initialize_resources():
    """Build the clients and shared resources used by the app.

    The function is wrapped in ``st.cache_resource`` so the objects are
    created once and reused instead of being rebuilt on every call.
    Each client is constructed exactly once.

    Returns:
        dict: The created objects, keyed by ``"config"``,
        ``"business_structure"``, ``"bowei_api_docs"``,
        ``"llm_client_coder"``, ``"prompt_manager"``,
        ``"knowledge_retriever"``, ``"embedding_client"`` and
        ``"code_executor"``.
    """
    config = Config()
    business_structure = load_file(config.business_object_structure_path)
    bowei_api_docs = load_file(config.bowei_api_docs_path)

    llm_client_coder = MultiAPIKeyChatOpenAI(config.openai_coder)
    prompt_manager = PromptManager()

    neo4j_conf = config.neo4j_conf
    embedding_conf = config.embedding
    embedding_client = EmbeddingClient(embedding_conf)

    # Create the Neo4j retriever and register the Neo4j-backed toolkit
    # with the retriever's driver.
    knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
    ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)

    code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder, config.max_retries)

    return {
        "config": config,
        "business_structure": business_structure,
        "bowei_api_docs": bowei_api_docs,
        "llm_client_coder": llm_client_coder,
        "prompt_manager": prompt_manager,
        "knowledge_retriever": knowledge_retriever,
        "embedding_client": embedding_client,
        "code_executor": code_executor,
    }
embedding_client = EmbeddingClient(embedding_conf)
# 创建Neo4j检索器
knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)
code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder, config.max_retries)
# 使用缓存加载JSONL文件
@st.cache_data
def load_jsonl(file_path):
"""加载JSONL文件并返回JSON记录列表"""
"""加载JSONL文件并返回JSON指标列表"""
records = []
try:
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
if line.strip():
records.append(json.loads(line))
@@ -91,82 +101,86 @@ def load_jsonl(file_path):
st.error(f"加载文件失败: {str(e)}")
return []
def run_code(data):
    """Run the code carried by a JSON record and return the result.

    Args:
        data: A record (mapping). Its ``"code"`` entry is handed to the
            module-level ``code_executor``.

    Returns:
        The value returned by ``code_executor.execute_code``, or an error
        dict with code ``40000`` when ``data`` is empty or has no
        ``"code"`` key.
    """
    if not data or 'code' not in data:
        return {
            "code": 40000,
            "message": "没有可执行的代码",
            "status": False,
            "data": None
        }

    # ``code_executor`` is only read here, so no ``global`` declaration
    # is required to reach the module-level object.
    return code_executor.execute_code(data['code'])
def run_code(data, code_executor):
    """Execute the code stored in an indicator record.

    Args:
        data: The indicator record (mapping) that may hold a ``"code"`` entry.
        code_executor: Object exposing ``execute_code(code)``.

    Returns:
        The executor's result when the record carries code; otherwise an
        error dict with code ``40000``.
    """
    has_code = bool(data) and "code" in data
    if has_code:
        return code_executor.execute_code(data["code"])

    return {"code": 40000, "message": "没有可执行的代码", "status": False, "data": None}
# Streamlit page entry point: configures the page, obtains the shared
# resources, loads records from a JSONL file, lists them as buttons in the
# left column and shows the selected record (with a button that runs its
# code) in the right column.
# NOTE(review): this span appears to interleave pre-change and post-change
# lines of a diff (e.g. two set_page_config/title pairs, and both
# run_code(selected_record, code_executor) and run_code(selected_record)),
# and its indentation is missing — confirm against the actual file.
def main():
st.set_page_config(layout="wide", page_title="JSONL查看器")
st.title("JSONL文件查看器")
st.set_page_config(layout="wide", page_title="工程指标查看器")
st.title("工程指标查看器")
# Initialize resources
resources = initialize_resources()
code_executor = resources["code_executor"]
# Set the default file path
default_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tests", "code.jsonl")
# File path input
file_path = st.text_input("JSONL文件路径", value=default_file_path)
if not os.path.exists(file_path):
st.warning(f"文件不存在: {file_path}")
return
# Load the JSONL file
records = load_jsonl(file_path)
if not records:
st.warning("没有找到有效的记录")
st.warning("没有找到有效的指标")
return
# Initialize session state
if "selected_index" not in st.session_state:
st.session_state.selected_index = 0
if "execution_results" not in st.session_state:
st.session_state.execution_results = {}
# Create a two-column layout
col1, col2 = st.columns([1, 3])
# Left-hand list
with col1:
st.subheader("记录列表")
selected_index = None
st.subheader("指标列表")
for i, record in enumerate(records):
if st.button(record.get('name', f"记录 {i+1}"), key=f"btn_{i}"):
selected_index = i
if st.button(record.get("name", f"指标 {i+1}"), key=f"btn_{i}"):
st.session_state.selected_index = i
# Right-hand details
with col2:
if 'selected_index' not in st.session_state:
st.session_state.selected_index = 0
if selected_index is not None:
st.session_state.selected_index = selected_index
if st.session_state.selected_index < len(records):
selected_record = records[st.session_state.selected_index]
st.subheader(f"查询问题: {selected_record.get('name', '无名称')}")
record_id = selected_record.get("id", str(st.session_state.selected_index))
st.subheader(f"指标名称: {selected_record.get('name', '无名称指标')}")
st.info(selected_record.get('query', '无查询信息'))
st.subheader("代码")
# Put the indicator code title, the run-code button and the execution result on one row
run_btn_col, title_col, result_col = st.columns([1, 1, 4])
with title_col:
st.subheader("执行结果:")
with result_col:
title_code_result = st.subheader("")
with run_btn_col:
if st.button("执行指标代码"):
with st.spinner('正在执行代码...'):
result = run_code(selected_record, code_executor)
if result.get('status'):
#st.success("执行成功")
title_code_result.info(result.get('data'))
else:
title_code_result.error(f"执行失败: {result.get('message', '未知错误')}")
st.code(selected_record.get('code', '无代码'), language='python')
# Run-code button
if st.button("运行代码"):
with st.spinner('正在执行代码...'):
result = run_code(selected_record)
st.subheader("运行结果")
if result.get('status'):
st.success("执行成功")
st.info(result.get('data'))
else:
st.error(f"执行失败: {result.get('message', '未知错误')}")
if __name__ == "__main__":
main()
main()
+22 -2
View File
@@ -3,6 +3,26 @@ import xml.etree.ElementTree as ET
import json
import re
def transform_string(input_str: str) -> str:
    """Expand a dotted second 【…】 block into a nested "of" chain.

    The first two 【…】 blocks of ``input_str`` are inspected. When the
    second one contains ``.`` (e.g. ``基本预备费.合计费``) the text is
    rebuilt as ``从【first】中获取【a】的【b】<suffix>``, where ``<suffix>``
    is whatever follows the last ``】`` of the input (e.g. ``的属性``).

    Args:
        input_str: The mapping description to rewrite.

    Returns:
        The rewritten description, or ``input_str`` unchanged when it has
        fewer than two bracketed blocks or the second block has no dot.
    """
    # Collect the content of every 【…】 block.
    blocks = re.findall(r'【(.*?)】', input_str)
    if len(blocks) < 2:
        return input_str  # Not in the expected format; leave untouched.

    first_block, second_block = blocks[0], blocks[1]
    if '.' not in second_block:
        return input_str

    # Trailing text after the final closing bracket. ``str.split('')``
    # raises ValueError (empty separator), so partition on the bracket.
    suffix = input_str.rpartition('】')[2]

    # Joining with "的" puts the separator only between parts, so no
    # trailing "的" has to be stripped afterwards.
    nested_str = '的'.join(f'【{part}】' for part in second_block.split('.'))
    return f'从【{first_block}】中获取{nested_str}{suffix}'
def clean_bracketed_strings(input_str: str) -> str:
# 替换【'xxx'】为【xxx】
result = re.sub(r"'([^']+)'", r"\1】", input_str)
@@ -125,7 +145,7 @@ def xml_to_json(xml_content, output_path):
result.append(base_item)
elif data_sources in project_division:
mapping_desc = f"从【{index_extraction_scope}】项目划分中获取名称属于【{indicator_name}】的费用"
mapping_desc = f"从【{index_extraction_scope}】项目划分中获取取费费用名称属于【{keyword}】的费用"
base_item["指标描述"] = {
"指标映射": mapping_desc,
"映射规则": parsed["映射规则"]
@@ -191,6 +211,7 @@ def xml_to_json(xml_content, output_path):
if isinstance(indicator_map, str):
new_mapping = clean_bracketed_strings(indicator_map)
new_mapping = transform_string(new_mapping)
item["指标描述"]["指标映射"] = new_mapping
# 保存为 JSON 文件
@@ -199,7 +220,6 @@ def xml_to_json(xml_content, output_path):
return "结果已保存"
xml_content = read_xml_as_string('dataset/主网架空线路造价分析指标.xml')
json_output = xml_to_json(xml_content, output_path= "./tests/zhibiao.json")
print("转换完毕!")