import streamlit as st import json import os import uuid import time from typing import Dict, List, Optional, Any import logging import sys from datetime import datetime from textcomplete import textcomplete, StrategyProps from streamlit.components.v1 import html # 添加项目根目录到路径 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # 确保必要的目录存在 CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache") DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") os.makedirs(CACHE_DIR, exist_ok=True) os.makedirs(DATA_DIR, exist_ok=True) # 配置日志 current_file = os.path.splitext(os.path.basename(__file__))[0] now_str = datetime.now().strftime("%Y%m%d%H%M%S") log_filename = f"{current_file}_{now_str}.log" # 确保日志目录存在并设置在logs目录下 log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs") os.makedirs(log_dir, exist_ok=True) logging.basicConfig( level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", handlers=[ logging.FileHandler(os.path.join(log_dir, log_filename), encoding="utf-8"), logging.StreamHandler() ], ) logger = logging.getLogger(current_file) def setup_logger(logger_name): """ 设置指定名称的logger,将其级别设置为WARNING并禁用传播 :param logger_name: logger的名称 """ logger = logging.getLogger(logger_name) logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别 logger.propagate = False # 可选:禁用传播(防止被根logger处理) return logger logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"] for name in logger_names: setup_logger(name) # 延迟导入依赖,只有在需要时才加载 @st.cache_resource def load_dependencies(): """ 延迟加载依赖项,并缓存资源以提高性能 """ from src.config import Config from src.document_loader import load_file from src.multi_llm_client import MultiAPIKeyChatOpenAI from src.user_interaction import UserInteraction from src.dialog_manager import DialogManager from src.code_executor import CodeExecutor from src.neo4j_raw_retriever import Neo4jRawRetriever from src.prompt_manager import PromptManager from src.embedding_client import EmbeddingClient from src.project import ProjectBuilder, ProjectToolkit from src.project_implementation import ProjectToolkitNeo4j return { "Config": Config, "load_file": load_file, "MultiAPIKeyChatOpenAI": MultiAPIKeyChatOpenAI, "UserInteraction": UserInteraction, "DialogManager": DialogManager, "CodeExecutor": CodeExecutor, "Neo4jRawRetriever": Neo4jRawRetriever, "PromptManager": PromptManager, "EmbeddingClient": EmbeddingClient, "ProjectBuilder": ProjectBuilder, "ProjectToolkitNeo4j": ProjectToolkitNeo4j } # 后台数据结构 class Indicator: def __init__(self, name: str, query: str = "", code: str = "", id: Optional[str] = None): self.id = id if id else str(uuid.uuid4()) self.name = name self.query = query self.code = code self.created_at = time.time() self.updated_at = time.time() def to_dict(self) -> Dict[str, Any]: return { "id": self.id, "name": self.name, "query": self.query, "code": self.code, "created_at": self.created_at, "updated_at": self.updated_at } @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'Indicator': indicator = cls( name=data["name"], query=data["query"], code=data.get("code", ""), # 使用 get 方法处理可能不存在的 code 字段 id=data["id"] ) indicator.created_at = data.get("created_at", time.time()) indicator.updated_at = data.get("updated_at", time.time()) return indicator class IndicatorManager: def __init__(self, save_path: Optional[str] = None): if save_path is None: save_path = os.path.join(CACHE_DIR, "indicator_library.json") self.indicators: Dict[str, Indicator] = {} self.save_path = save_path self.load_indicators() def load_indicators(self) -> None: """从文件加载指标库""" try: if os.path.exists(self.save_path): with open(self.save_path, "r", encoding="utf-8") as f: data = json.load(f) for indicator_data in data: indicator = Indicator.from_dict(indicator_data) self.indicators[indicator.id] = indicator except Exception as e: st.error(f"加载指标库失败: {str(e)}") def save_indicators(self) -> None: """保存指标库到文件""" try: with open(self.save_path, "w", encoding="utf-8") as f: json.dump([ind.to_dict() for ind in self.indicators.values()], f, ensure_ascii=False, indent=2) except Exception as e: st.error(f"保存指标库失败: {str(e)}") def add_indicator(self, name: str, query: str = "") -> Indicator: """添加新指标""" if not name.strip(): raise ValueError("指标名称不能为空") indicator = Indicator(name=name, query=query) self.indicators[indicator.id] = indicator self.save_indicators() return indicator def update_indicator(self, id: str, name: Optional[str] = None, query: Optional[str] = None, code: Optional[str] = None) -> Optional[Indicator]: """更新指标""" if id not in self.indicators: return None if name is not None and not name.strip(): raise ValueError("指标名称不能为空") indicator = self.indicators[id] if name is not None: indicator.name = name if query is not None: indicator.query = query if code is not None: indicator.code = code indicator.updated_at = time.time() self.save_indicators() return indicator def delete_indicator(self, id: str) -> bool: """删除指标""" if id in self.indicators: del self.indicators[id] self.save_indicators() return True return False def clear_all_indicators(self) -> None: """清除所有指标""" self.indicators.clear() self.save_indicators() def get_indicator(self, id: str) -> Optional[Indicator]: """获取指标""" return self.indicators.get(id) def get_all_indicators(self) -> List[Indicator]: """获取所有指标""" return list(self.indicators.values()) # 保存会话状态 @st.cache_data(ttl=3600) def save_session_state(): """保存会话状态到文件""" try: state = { "current_indicator_id": st.session_state.get("current_indicator_id", None) } state_file = os.path.join(CACHE_DIR, "indicator_creator_state.json") with open(state_file, "w", encoding="utf-8") as f: json.dump(state, f, ensure_ascii=False, indent=2) except Exception as e: st.error(f"保存会话状态失败: {str(e)}") # 加载会话状态 @st.cache_data(ttl=3600) def load_session_state(): """从文件加载会话状态""" try: state_file = os.path.join(CACHE_DIR, "indicator_creator_state.json") if os.path.exists(state_file): with open(state_file, "r", encoding="utf-8") as f: state = json.load(f) for key, value in state.items(): if key not in st.session_state: st.session_state[key] = value except Exception as e: st.error(f"加载会话状态失败: {str(e)}") # 测试指标查询 @st.cache_data(ttl=60) # 缓存结果1分钟 def generate_indicator_code(query: str) -> str: """生成指标代码""" if not query.strip(): return "查询语句为空,无法生成代码" result = "" try: # 步骤1:通过调用user_interaction.understand(query)实现 result = "正在理解查询语句..." if "user_interaction" not in st.session_state: st.session_state.user_interaction = st.session_state.dialog_manager.user_interaction understanding_result = st.session_state.user_interaction.understand(query) if not understanding_result: return "❌ 无法理解查询语句" # 步骤2:验证步骤一返回值 result = "正在验证理解结果..." selected_knowledge = [] entity_info = [] for item in understanding_result: entity_info.append(f"识别到实体: {item.get('name')}, 约束条件: {item.get('constraints')}") selected_knowledge.append(item) if not selected_knowledge: return "❌ 没有找到相关知识" # 显示识别到的实体信息 result = "\n".join(entity_info) # 步骤3:dialog_manager.generated_code(query, selected_knowledge) 实现 result = "正在生成指标代码..." if "dialog_manager" not in st.session_state: return "❌ 对话管理器未初始化" generated_result = st.session_state.dialog_manager.generated_code(query, selected_knowledge) # 处理返回结构 if not generated_result: return "❌ 代码生成失败" # 检查返回的是否是字典结构 if isinstance(generated_result, dict): if not generated_result.get('status', False): return f"❌ 代码生成失败: {generated_result.get('message', '未知错误')}" # 提取代码内容 code = generated_result.get('data', '') if not code: return "❌ 生成的代码为空" # 将生成的代码保存到session_state中 st.session_state.generated_code = code else: # 如果直接返回的是代码字符串 st.session_state.generated_code = generated_result if "current_indicator_id" in st.session_state: st.session_state.current_query_id = st.session_state.current_indicator_id return "✅ 指标代码生成成功,可以在下方查看并执行" except Exception as e: import traceback error_details = traceback.format_exc() return f"❌ 指标代码生成失败: {str(e)}" return result @st.cache_resource def init_app_data(): """ 初始化应用数据,使用缓存资源以避免重复加载 """ # 加载依赖 deps = load_dependencies() config = deps["Config"]() business_structure = deps["load_file"](config.business_object_structure_path) bowei_api_docs = deps["load_file"](config.bowei_api_docs_path) llm_client = deps["MultiAPIKeyChatOpenAI"](config.openai) user_interaction = deps["UserInteraction"](llm_client.llm, business_structure) llm_client_coder = deps["MultiAPIKeyChatOpenAI"](config.openai_coder) prompt_manager = deps["PromptManager"]() neo4j_conf = config.neo4j_conf embedding_conf = config.embedding embedding_client = deps["EmbeddingClient"](embedding_conf) # 创建Neo4j检索器 knowledge_retriever = deps["Neo4jRawRetriever"](neo4j_conf) deps["ProjectBuilder"].register(deps["ProjectToolkitNeo4j"], knowledge_retriever.driver) code_executor = deps["CodeExecutor"](prompt_manager.prompts, llm_client_coder, config.max_retries) dialog_manager = deps["DialogManager"]( llm_client, business_structure, bowei_api_docs, code_executor, knowledge_retriever, prompt_manager, ) # 将必要的对象存储在session_state中 return { "dialog_manager": dialog_manager, "user_interaction": user_interaction, "code_executor": code_executor, "knowledge_retriever": knowledge_retriever } # 自动补全策略相关函数 def create_project_division_strategy(): """创建项目划分自动补全策略""" return StrategyProps( id="projectDivision", match="从项目划分【(.*)】", search="""async (term, callback) => { const divisions = [ "架空输电线路本体工程/基础工程", "架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立", "架空输电线路本体工程/架线工程", "架空输电线路本体工程/附件安装工程" ]; const matches = divisions.filter(div => div.toLowerCase().includes(term.toLowerCase())); callback(matches); }""", replace="(division) => `从项目划分【${division}】`", template="(division) => `📁 ${division}`", ) def create_search_type_strategy(): """创建搜索类型自动补全策略""" return StrategyProps( id="searchType", match="中查找(名称|编码)(包含|等于)【(.*)】", search="""async (term, callback) => { const types = ["名称包含", "编码包含", "名称等于", "编码等于"]; const matches = types.filter(type => type.toLowerCase().includes(term.toLowerCase())); callback(matches); }""", replace="(type) => `中查找${type}【`", template="(type) => `🔍 ${type}`", ) def create_item_type_strategy(): """创建项目类型自动补全策略""" return StrategyProps( id="itemType", match="的所有【(.*)】", search="""async (term, callback) => { const types = ["定额", "主材", "取费名称"]; const matches = types.filter(type => type.toLowerCase().includes(term.toLowerCase())); callback(matches); }""", replace="(type) => `的所有【${type}】`", template="(type) => `📋 ${type}`", ) def create_attribute_strategy(): """创建属性自动补全策略""" return StrategyProps( id="attribute", match="的【(.*)】", search="""async (term, callback) => { const attributes = ["数量", "单价", "合计费", "属性"]; const matches = attributes.filter(attr => attr.toLowerCase().includes(term.toLowerCase())); callback(matches); }""", replace="(attr) => `的【${attr}】`", template="(attr) => `📊 ${attr}`", ) def create_fee_type_strategy(): """创建费用类型自动补全策略""" return StrategyProps( id="feeType", match="从【(.*)】中", search="""async (term, callback) => { const types = ["工程费用", "其他费用"]; const matches = types.filter(type => type.toLowerCase().includes(term.toLowerCase())); callback(matches); }""", replace="(type) => `从【${type}】中`", template="(type) => `💰 ${type}`", ) def create_project_name_strategy(): """创建项目名称自动补全策略""" return StrategyProps( id="projectName", match="获取【(.*)】的", search="""async (term, callback) => { const names = [ "架空输电线路本体工程", "建设场地征用及清理费" ]; const matches = names.filter(name => name.toLowerCase().includes(term.toLowerCase())); callback(matches); }""", replace="(name) => `获取【${name}】的`", template="(name) => `🏗️ ${name}`", ) def create_code_value_strategy(): """创建编码值自动补全策略""" return StrategyProps( id="codeValue", match="编码包含【(.*)】", search="""async (term, callback) => { const codes = ["YX2-1~7", "YX5-9"]; const matches = codes.filter(code => code.toLowerCase().includes(term.toLowerCase())); callback(matches); }""", replace="(code) => `编码包含【${code}】`", template="(code) => `🔢 ${code}`", ) def create_name_value_strategy(): """创建名称值自动补全策略""" return StrategyProps( id="nameValue", match="名称包含【(.*)】", search="""async (term, callback) => { const names = ["角钢", "钢管杆"]; const matches = names.filter(name => name.toLowerCase().includes(term.toLowerCase())); callback(matches); }""", replace="(name) => `名称包含【${name}】`", template="(name) => `📝 ${name}`", ) def create_name_equal_strategy(): """创建名称等于自动补全策略""" return StrategyProps( id="nameEqual", match="取费名称等于【(.*)】", search="""async (term, callback) => { const names = ["合计"]; const matches = names.filter(name => name.toLowerCase().includes(term.toLowerCase())); callback(matches); }""", replace="(name) => `取费名称等于【${name}】`", template="(name) => `📌 ${name}`", ) def create_query_template_strategy(): """创建查询模板自动补全策略""" return StrategyProps( id="queryTemplate", match="^$", search="""async (term, callback) => { const templates = [ "从项目划分【架空输电线路本体工程/基础工程】中查找编码包含【YX2-1~7】的所有【定额】的【数量】之和", "从项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】中查找名称包含【角钢】的所有【主材】的【数量】之和", "从项目划分【架空输电线路本体工程/杆塔工程/杆塔组立/铁塔、钢管杆组立】中查找名称包含【钢管杆】的所有【主材】的【单价】之和", "从项目划分【架空输电线路本体工程/架线工程】中查找编码中包含【YX5-9】的所有【定额】的【数量】之和", "从【工程费用】中获取【架空输电线路本体工程】的【合计费】属性", "从【其他费用】中获取【建设场地征用及清理费】的属性", "从项目划分【架空输电线路本体工程/附件安装工程】中获取取费名称等于【合计】的费用" ]; const matches = templates.filter(template => template.toLowerCase().includes(term.toLowerCase())); callback(matches); }""", replace="(template) => template", template="(template) => `📋 ${template}`", ) def apply_autocomplete_to_query_input(area_label="输入查询语句", max_count=5): """应用自动补全到查询输入框""" # 初始化textcomplete组件 textcomplete( area_label=area_label, strategies=[ create_query_template_strategy(), create_project_division_strategy(), create_search_type_strategy(), create_item_type_strategy(), create_attribute_strategy(), create_fee_type_strategy(), create_project_name_strategy(), create_code_value_strategy(), create_name_value_strategy(), create_name_equal_strategy() ], max_count=max_count ) # Streamlit 应用界面 def main(): st.set_page_config(page_title="造价工程指标管理器", layout="wide") # 使用进度条显示加载过程 with st.spinner("正在初始化应用..."): # 初始化应用数据 app_data = init_app_data() # 将必要的对象存储在session_state中 for key, value in app_data.items(): st.session_state[key] = value # 初始化会话状态 initialize_session_state() # 侧边栏 render_sidebar() # 主区域 render_main_area() def initialize_session_state(): """初始化会话状态""" if "indicator_manager" not in st.session_state: st.session_state.indicator_manager = IndicatorManager() if "current_indicator_id" not in st.session_state: st.session_state.current_indicator_id = None # 加载保存的会话状态 load_session_state() def render_sidebar(): """渲染侧边栏内容""" with st.sidebar: # 第一个expander:指标管理 render_indicator_management_expander() # 第二个expander:指标库浏览 render_indicator_library_expander() def render_indicator_management_expander(): """渲染指标管理扩展面板""" with st.expander("新建指标工具", expanded=True): # 显示清除指标链接 col1, col2, col3 = st.columns([1, 1, 1]) with col1: st.subheader("指标列表") with col2: render_new_indicator_button() with col3: render_clear_indicators_button() # 显示指标列表 render_indicator_list() def render_new_indicator_button(): """渲染新建指标按钮""" if st.button("新建"): try: # 生成新指标名称 indicator_count = len(st.session_state.indicator_manager.get_all_indicators()) new_name = f"指标 {indicator_count + 1}" # 创建新指标 new_indicator = st.session_state.indicator_manager.add_indicator(name=new_name) st.session_state.current_indicator_id = new_indicator.id # 清除已生成的代码 if "generated_code" in st.session_state: del st.session_state.generated_code # 清除第二个expander的状态 st.session_state.view_indicator_detail = False save_session_state() st.rerun() except ValueError as e: st.error(str(e)) def render_clear_indicators_button(): """渲染清空指标按钮""" if st.button("清空", help="删除所有指标", type="secondary"): st.session_state.indicator_manager.clear_all_indicators() st.session_state.current_indicator_id = None # 清除已生成的代码 if "generated_code" in st.session_state: del st.session_state.generated_code # 清除第二个expander的状态 st.session_state.view_indicator_detail = False save_session_state() st.rerun() def render_indicator_list(): """渲染指标列表""" indicators = st.session_state.indicator_manager.get_all_indicators() if not indicators: st.info("暂无指标,请点击\"新建指标\"按钮创建") else: for indicator in sorted(indicators, key=lambda x: x.created_at): if st.button( f"📄 {indicator.name}", key=f"btn_{indicator.id}", use_container_width=True, type="primary" if st.session_state.current_indicator_id == indicator.id else "secondary" ): st.session_state.current_indicator_id = indicator.id # 更新显示的代码 if indicator.code: st.session_state.generated_code = indicator.code elif "generated_code" in st.session_state: # 如果当前指标没有代码,清除session中的代码 del st.session_state.generated_code # 清除第二个expander的状态 st.session_state.view_indicator_detail = False save_session_state() st.rerun() def render_indicator_library_expander(): """渲染指标库浏览扩展面板""" with st.expander("查看指标库", expanded=False): # 设置默认文件路径 file_path = os.path.join(DATA_DIR, "code.jsonl") # 检查文件是否存在 if not os.path.exists(file_path): st.warning(f"文件不存在: {file_path}") else: # 加载JSONL文件 records = load_jsonl(file_path) if not records: st.warning("没有找到有效的指标") else: # 初始化会话状态 if "selected_index" not in st.session_state: st.session_state.selected_index = 0 # 显示指标列表 render_library_indicator_list(records) @st.cache_data def load_jsonl(file_path): """加载JSONL文件并返回JSON指标列表""" records = [] try: with open(file_path, "r", encoding="utf-8") as file: for line in file: if line.strip(): records.append(json.loads(line)) return records except Exception as e: st.error(f"加载文件失败: {str(e)}") return [] def render_library_indicator_list(records): """渲染指标库列表""" st.subheader("指标列表") for i, record in enumerate(records): if st.button(f"📚 {record.get('name', f'指标 {i+1}')}", key=f"lib_btn_{i}", use_container_width=True): st.session_state.selected_index = i # 清除第一个expander的选中状态 st.session_state.current_indicator_id = None # 设置查看指标详情的逻辑 st.session_state.view_indicator_detail = True st.session_state.current_library_indicator = record save_session_state() st.rerun() def render_main_area(): """渲染主区域内容""" if "view_indicator_detail" in st.session_state and st.session_state.view_indicator_detail and "current_library_indicator" in st.session_state: render_library_indicator_detail() elif st.session_state.current_indicator_id is not None: render_current_indicator_detail() else: st.info("请在左侧创建或选择一个指标") # HTML/JS 实现可拖动面板 panel_js = """ """ # 渲染面板 html(panel_js, height=200) # 动态更新面板内容 if st.button("生成日志"): st.write('
新的日志消息...
', unsafe_allow_html=True) def render_library_indicator_detail(): """渲染指标库详情""" library_indicator = st.session_state.current_library_indicator # 指标名称和操作按钮 col1, col2, col3 = st.columns([1, 4, 2]) with col1: st.markdown('
指标名称:
', unsafe_allow_html=True) with col2: st.markdown(f'
{library_indicator.get("name", "未命名指标")}
', unsafe_allow_html=True) # 查询语句部分 title_col, run_btn_col, result_col = st.columns([1, 1, 4]) with title_col: st.subheader("查询语句") with result_col: title_query_result = st.subheader("") # 显示查询语句 st.text_area( "查询语句", value=library_indicator.get('query', '无查询语句'), height=70, key="lib_query_display", disabled=True, label_visibility="collapsed" ) # 指标代码部分 render_library_indicator_code(library_indicator) def render_library_indicator_code(library_indicator): """渲染指标库代码部分""" title_col, run_btn_col, result_col = st.columns([1, 1, 4]) with title_col: st.subheader("指标代码") with result_col: title_code_result = st.subheader("") with run_btn_col: if st.button("▶️ 执行代码", key="execute_lib_indicator", use_container_width=True): execute_indicator_code(library_indicator.get('code', ''), title_code_result) # 显示代码 st.code(library_indicator.get('code', '无代码'), language='python') def render_current_indicator_detail(): """渲染当前指标详情""" current_indicator = st.session_state.indicator_manager.get_indicator(st.session_state.current_indicator_id) if current_indicator: # 指标名称和保存按钮 render_indicator_header(current_indicator) # 查询语句 render_query_input_section(current_indicator) # 代码执行部分 render_code_execution_section(current_indicator) else: st.warning("找不到当前指标,可能已被删除") def render_indicator_header(current_indicator): """渲染指标头部(名称和保存按钮)""" col1, col2, col3 = st.columns([1, 4, 2]) with col1: st.markdown('
指标名称:
', unsafe_allow_html=True) with col2: new_name = st.text_input("指标名称", value=current_indicator.name, key="indicator_name", label_visibility="collapsed") if new_name != current_indicator.name: if new_name is None or not new_name.strip(): st.error("指标名称不能为空") else: try: st.session_state.indicator_manager.update_indicator( current_indicator.id, name=new_name ) except ValueError as e: st.error(str(e)) with col3: render_save_indicator_button(current_indicator, new_name) def render_save_indicator_button(current_indicator, new_name): """渲染保存指标按钮""" if st.button("💾 保存指标", use_container_width=True): if new_name is None or not new_name.strip(): st.error("指标名称不能为空") elif not current_indicator.query.strip(): st.error("请先输入查询语句") elif not current_indicator.code: st.error("请先生成指标代码") else: try: save_indicator_to_library(current_indicator) st.toast("指标已保存") st.rerun() # 刷新页面以更新指标库列表 except Exception as e: st.error(f"保存失败: {str(e)}") def save_indicator_to_library(current_indicator): """保存指标到指标库""" # 保存到指标管理器 st.session_state.indicator_manager.save_indicators() # 同时保存到指标库 library_path = os.path.join(DATA_DIR, "code.jsonl") new_record = { "name": current_indicator.name, "query": current_indicator.query, "code": current_indicator.code, "created_at": time.time() } # 追加到文件 with open(library_path, "a", encoding="utf-8") as f: f.write(json.dumps(new_record, ensure_ascii=False) + "\n") # 清除指标库的缓存,以便重新加载 load_jsonl.clear() def render_query_input_section(current_indicator): """渲染查询输入部分""" title_col, test_btn_col, result_col = st.columns([1, 1, 4]) with title_col: st.subheader("查询语句") with result_col: title_query_result = st.subheader("") # 查询输入框 query = st.text_area( "输入查询语句", value=current_indicator.query, height=70, key="query_input", placeholder="在此输入查询语句...", label_visibility="collapsed" ) # 应用自动补全 apply_autocomplete_to_query_input() if query != current_indicator.query: st.session_state.indicator_manager.update_indicator( current_indicator.id, query=query ) # 生成代码按钮 with test_btn_col: render_generate_code_button(query, title_query_result, current_indicator) def render_generate_code_button(query, title_query_result, current_indicator): """渲染生成代码按钮""" if st.button("🔄 生成代码", key="generate_code_btn", use_container_width=True): with st.spinner("正在生成指标代码..."): # 使用缓存的测试函数 result = generate_indicator_code(query) # 根据结果显示不同的状态 if result.startswith("✅"): title_query_result.success(result) # 生成成功后,立即保存代码到指标对象中 if "generated_code" in st.session_state: st.session_state.indicator_manager.update_indicator( current_indicator.id, code=st.session_state.generated_code ) elif result.startswith("❌"): title_query_result.error(result) else: # 显示实体识别结果或其他信息 title_query_result.info(result) def render_code_execution_section(current_indicator): """渲染代码执行部分""" title_col, run_btn_col, result_col = st.columns([1, 1, 4]) with title_col: st.subheader("指标代码") with result_col: title_code_result = st.subheader("") with run_btn_col: if st.button("▶️ 执行代码", key="execute_code_btn", use_container_width=True): with st.spinner("正在执行代码..."): if "code_executor" in st.session_state and "generated_code" in st.session_state: execute_and_update_code(current_indicator, title_code_result) else: title_code_result.error("代码执行器未初始化或没有生成的代码") # 显示代码 st.code(st.session_state.get("generated_code", "# 暂无代码,请输入有效的查询语句并点击测试按钮"), language="python") def execute_and_update_code(current_indicator, title_code_result): """执行并更新代码""" result = st.session_state.code_executor.execute_code(st.session_state.generated_code) if result and hasattr(result, 'get') and result.get('status'): title_code_result.success("执行成功") data = result.get('data', {}) if isinstance(data, (str, int, float)): title_code_result.write(data) else: title_code_result.json(data) # 更新指标的代码 st.session_state.indicator_manager.update_indicator( current_indicator.id, code=st.session_state.generated_code ) else: title_code_result.error(f"执行失败: {result.get('message', '未知错误') if hasattr(result, 'get') else str(result)}") def execute_indicator_code(code, result_container): """执行指标代码""" with st.spinner('正在执行代码...'): # 确保code_executor已初始化 if "code_executor" in st.session_state: result = st.session_state.code_executor.execute_code(code) if result and hasattr(result, 'get') and result.get('status'): result_container.success("执行成功") data = result.get('data', {}) if isinstance(data, (str, int, float)): result_container.write(data) else: result_container.json(data) else: result_container.error(f"执行失败: {result.get('message', '未知错误') if hasattr(result, 'get') else str(result)}") else: result_container.error("代码执行器未初始化") if __name__ == "__main__": main()