实现完整功能
This commit is contained in:
+88
-74
@@ -5,6 +5,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# 获取当前时间,格式化为字符串
|
||||
@@ -21,14 +22,12 @@ log_filename = f"{current_file}_{now_str}.log"
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
handlers=[logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"), logging.StreamHandler()],
|
||||
)
|
||||
|
||||
logger = logging.getLogger(current_file)
|
||||
|
||||
|
||||
def setup_logger(logger_name):
|
||||
"""
|
||||
设置指定名称的logger,将其级别设置为WARNING并禁用传播
|
||||
@@ -55,34 +54,45 @@ from src.project import ProjectBuilder, ProjectToolkit
|
||||
from src.project_implementation import ProjectToolkitNeo4j
|
||||
from src.code_executor import CodeExecutor
|
||||
|
||||
config = Config()
|
||||
|
||||
business_structure = load_file(config.business_object_structure_path)
|
||||
bowei_api_docs = load_file(config.bowei_api_docs_path)
|
||||
# 初始化资源和客户端的函数,使用缓存避免重复加载
|
||||
@st.cache_resource
|
||||
def initialize_resources():
|
||||
config = Config()
|
||||
business_structure = load_file(config.business_object_structure_path)
|
||||
bowei_api_docs = load_file(config.bowei_api_docs_path)
|
||||
llm_client_coder = MultiAPIKeyChatOpenAI(config.openai_coder)
|
||||
prompt_manager = PromptManager()
|
||||
|
||||
#llm_client = MultiAPIKeyChatOpenAI(config.openai)
|
||||
# 创建Neo4j检索器
|
||||
neo4j_conf = config.neo4j_conf
|
||||
embedding_conf = config.embedding
|
||||
embedding_client = EmbeddingClient(embedding_conf)
|
||||
knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
|
||||
|
||||
llm_client_coder = MultiAPIKeyChatOpenAI(config.openai_coder)
|
||||
ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder, config.max_retries)
|
||||
|
||||
neo4j_conf = config.neo4j_conf
|
||||
embedding_conf = config.embedding
|
||||
return {
|
||||
"config": config,
|
||||
"business_structure": business_structure,
|
||||
"bowei_api_docs": bowei_api_docs,
|
||||
"llm_client_coder": llm_client_coder,
|
||||
"prompt_manager": prompt_manager,
|
||||
"knowledge_retriever": knowledge_retriever,
|
||||
"embedding_client": embedding_client,
|
||||
"code_executor": code_executor,
|
||||
}
|
||||
|
||||
embedding_client = EmbeddingClient(embedding_conf)
|
||||
|
||||
# 创建Neo4j检索器
|
||||
knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
|
||||
|
||||
ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)
|
||||
|
||||
code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder, config.max_retries)
|
||||
|
||||
# 使用缓存加载JSONL文件
|
||||
@st.cache_data
|
||||
def load_jsonl(file_path):
|
||||
"""加载JSONL文件并返回JSON记录列表"""
|
||||
"""加载JSONL文件并返回JSON指标列表"""
|
||||
records = []
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
for line in file:
|
||||
if line.strip():
|
||||
records.append(json.loads(line))
|
||||
@@ -91,82 +101,86 @@ def load_jsonl(file_path):
|
||||
st.error(f"加载文件失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def run_code(data):
|
||||
global code_executor
|
||||
|
||||
"""运行JSON记录中的代码并返回结果"""
|
||||
if not data or 'code' not in data:
|
||||
return {
|
||||
"code": 40000,
|
||||
"message": "没有可执行的代码",
|
||||
"status": False,
|
||||
"data": None
|
||||
}
|
||||
|
||||
return code_executor.execute_code(data['code'])
|
||||
def run_code(data, code_executor):
    """Run the code stored in a JSONL indicator record and return the result.

    Args:
        data: One parsed JSONL record. It is expected to be a dict carrying
            the source to run under the ``"code"`` key.
        code_executor: Object exposing ``execute_code(code)``; its return
            value is passed back to the caller unchanged.

    Returns:
        Whatever ``code_executor.execute_code`` returns, or an error payload
        with ``code`` 40000 and ``status`` False when the record holds no
        runnable code.
    """
    no_code_result = {"code": 40000, "message": "没有可执行的代码", "status": False, "data": None}

    # A JSONL line may decode to any JSON value (number, string, list, ...).
    # Only a dict can carry a "code" entry; anything else would raise on the
    # membership test or on the subscript below.
    if not isinstance(data, dict):
        return no_code_result

    code = data.get("code")
    # A missing, None, empty or whitespace-only entry means nothing to run.
    if not code or (isinstance(code, str) and not code.strip()):
        return no_code_result

    return code_executor.execute_code(code)
|
||||
|
||||
|
||||
def main():
|
||||
st.set_page_config(layout="wide", page_title="JSONL查看器")
|
||||
|
||||
st.title("JSONL文件查看器")
|
||||
|
||||
st.set_page_config(layout="wide", page_title="工程指标查看器")
|
||||
|
||||
st.title("工程指标查看器")
|
||||
|
||||
# 初始化资源
|
||||
resources = initialize_resources()
|
||||
code_executor = resources["code_executor"]
|
||||
|
||||
# 设置默认文件路径
|
||||
default_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tests", "code.jsonl")
|
||||
|
||||
|
||||
# 文件路径输入
|
||||
file_path = st.text_input("JSONL文件路径", value=default_file_path)
|
||||
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
st.warning(f"文件不存在: {file_path}")
|
||||
return
|
||||
|
||||
|
||||
# 加载JSONL文件
|
||||
records = load_jsonl(file_path)
|
||||
|
||||
|
||||
if not records:
|
||||
st.warning("没有找到有效的记录")
|
||||
st.warning("没有找到有效的指标")
|
||||
return
|
||||
|
||||
|
||||
# 初始化会话状态
|
||||
if "selected_index" not in st.session_state:
|
||||
st.session_state.selected_index = 0
|
||||
|
||||
if "execution_results" not in st.session_state:
|
||||
st.session_state.execution_results = {}
|
||||
|
||||
# 创建两列布局
|
||||
col1, col2 = st.columns([1, 3])
|
||||
|
||||
|
||||
# 左侧列表
|
||||
with col1:
|
||||
st.subheader("记录列表")
|
||||
selected_index = None
|
||||
|
||||
st.subheader("指标列表")
|
||||
|
||||
for i, record in enumerate(records):
|
||||
if st.button(record.get('name', f"记录 {i+1}"), key=f"btn_{i}"):
|
||||
selected_index = i
|
||||
|
||||
if st.button(record.get("name", f"指标 {i+1}"), key=f"btn_{i}"):
|
||||
st.session_state.selected_index = i
|
||||
|
||||
# 右侧详细信息
|
||||
with col2:
|
||||
if 'selected_index' not in st.session_state:
|
||||
st.session_state.selected_index = 0
|
||||
|
||||
if selected_index is not None:
|
||||
st.session_state.selected_index = selected_index
|
||||
|
||||
if st.session_state.selected_index < len(records):
|
||||
selected_record = records[st.session_state.selected_index]
|
||||
|
||||
st.subheader(f"查询问题: {selected_record.get('name', '无名称')}")
|
||||
record_id = selected_record.get("id", str(st.session_state.selected_index))
|
||||
|
||||
st.subheader(f"指标名称: {selected_record.get('name', '无名称指标')}")
|
||||
st.info(selected_record.get('query', '无查询信息'))
|
||||
|
||||
st.subheader("代码")
|
||||
|
||||
# 将指标代码标题、运行代码按钮、执行结果放在同一行
|
||||
run_btn_col, title_col, result_col = st.columns([1, 1, 4])
|
||||
with title_col:
|
||||
st.subheader("执行结果:")
|
||||
with result_col:
|
||||
title_code_result = st.subheader("")
|
||||
with run_btn_col:
|
||||
if st.button("执行指标代码"):
|
||||
with st.spinner('正在执行代码...'):
|
||||
result = run_code(selected_record, code_executor)
|
||||
if result.get('status'):
|
||||
#st.success("执行成功")
|
||||
title_code_result.info(result.get('data'))
|
||||
else:
|
||||
title_code_result.error(f"执行失败: {result.get('message', '未知错误')}")
|
||||
|
||||
st.code(selected_record.get('code', '无代码'), language='python')
|
||||
|
||||
# 运行代码按钮
|
||||
if st.button("运行代码"):
|
||||
with st.spinner('正在执行代码...'):
|
||||
result = run_code(selected_record)
|
||||
|
||||
st.subheader("运行结果")
|
||||
if result.get('status'):
|
||||
st.success("执行成功")
|
||||
st.info(result.get('data'))
|
||||
else:
|
||||
st.error(f"执行失败: {result.get('message', '未知错误')}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
|
||||
+22
-2
@@ -3,6 +3,26 @@ import xml.etree.ElementTree as ET
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
def transform_string(input_str: str) -> str:
    """Expand a dotted second 【】 block into a chain of 【】 blocks.

    The string is rewritten only when it contains at least two 【...】 blocks
    and the second one contains a dot. The result then has the form
    ``从【first】中获取【a】的【b】...<suffix>``, where ``a``, ``b``, ... are the
    dot-separated parts of the second block and ``<suffix>`` is everything
    that followed the second block in the input. Any other input is returned
    unchanged.

    Args:
        input_str: Mapping description to rewrite.

    Returns:
        The rewritten description, or ``input_str`` itself when it does not
        match the expected shape.
    """
    # Collect every 【...】 block together with its position in the string.
    blocks = list(re.finditer(r'【(.*?)】', input_str))
    if len(blocks) < 2:
        return input_str  # Not in the expected format; leave untouched.

    first_block = blocks[0].group(1)
    second_block = blocks[1].group(1)
    if '.' not in second_block:
        return input_str

    # Keep everything after the second block verbatim. Splitting on the last
    # '】' instead would silently drop any further bracketed blocks.
    suffix = input_str[blocks[1].end():]

    nested_str = '的'.join(f'【{part}】' for part in second_block.split('.'))
    return f'从【{first_block}】中获取{nested_str}{suffix}'
|
||||
|
||||
def clean_bracketed_strings(input_str: str) -> str:
|
||||
# 替换【'xxx'】为【xxx】
|
||||
result = re.sub(r"【'([^']+)'】", r"【\1】", input_str)
|
||||
@@ -125,7 +145,7 @@ def xml_to_json(xml_content, output_path):
|
||||
result.append(base_item)
|
||||
|
||||
elif data_sources in project_division:
|
||||
mapping_desc = f"从【{index_extraction_scope}】项目划分中获取名称属于【{indicator_name}】的费用"
|
||||
mapping_desc = f"从【{index_extraction_scope}】项目划分中获取取费费用名称属于【{keyword}】的费用"
|
||||
base_item["指标描述"] = {
|
||||
"指标映射": mapping_desc,
|
||||
"映射规则": parsed["映射规则"]
|
||||
@@ -191,6 +211,7 @@ def xml_to_json(xml_content, output_path):
|
||||
|
||||
if isinstance(indicator_map, str):
|
||||
new_mapping = clean_bracketed_strings(indicator_map)
|
||||
new_mapping = transform_string(new_mapping)
|
||||
item["指标描述"]["指标映射"] = new_mapping
|
||||
|
||||
# 保存为 JSON 文件
|
||||
@@ -199,7 +220,6 @@ def xml_to_json(xml_content, output_path):
|
||||
|
||||
return "结果已保存"
|
||||
|
||||
|
||||
xml_content = read_xml_as_string('dataset/主网架空线路造价分析指标.xml')
|
||||
json_output = xml_to_json(xml_content, output_path= "./tests/zhibiao.json")
|
||||
print("转换完毕!")
|
||||
|
||||
Reference in New Issue
Block a user