import streamlit as st
import json
import os
import uuid
import time
from typing import Dict, List, Optional, Any
import logging
import sys
from datetime import datetime

# Make the project root importable so that the `src.*` packages resolve.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Ensure the working directories exist next to this file.
CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)

# Logging configuration: one timestamped log file per process under ./logs,
# mirrored to the console.
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
os.makedirs(log_dir, exist_ok=True)

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(os.path.join(log_dir, log_filename), encoding="utf-8"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(current_file)


def setup_logger(logger_name):
    """Quieten a named logger.

    Sets the logger's level to WARNING and disables propagation so its
    records are not handled by the root logger configured above.

    :param logger_name: name of the logger to configure
    :return: the configured ``logging.Logger``
    """
    quiet_logger = logging.getLogger(logger_name)
    quiet_logger.setLevel(logging.WARNING)
    quiet_logger.propagate = False
    return quiet_logger


logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"]
for name in logger_names:
    setup_logger(name)


@st.cache_resource
def load_dependencies():
    """Import the project modules lazily and return the needed names.

    The imports run on the first call only; ``st.cache_resource`` keeps the
    returned mapping for later calls.

    :return: dict mapping a name to the imported class or function
    """
    from src.config import Config
    from src.document_loader import load_file
    from src.multi_llm_client import MultiAPIKeyChatOpenAI
    from src.user_interaction import UserInteraction
    from src.dialog_manager import DialogManager
    from src.code_executor import CodeExecutor
    from src.neo4j_raw_retriever import Neo4jRawRetriever
    from src.prompt_manager import PromptManager
    from src.embedding_client import EmbeddingClient
    from src.project import ProjectBuilder, ProjectToolkit
    from src.project_implementation import ProjectToolkitNeo4j

    return {
        "Config": Config,
        "load_file": load_file,
        "MultiAPIKeyChatOpenAI": MultiAPIKeyChatOpenAI,
        "UserInteraction": UserInteraction,
        "DialogManager": DialogManager,
        "CodeExecutor": CodeExecutor,
        "Neo4jRawRetriever": Neo4jRawRetriever,
        "PromptManager": PromptManager,
        "EmbeddingClient": EmbeddingClient,
        "ProjectBuilder": ProjectBuilder,
        "ProjectToolkitNeo4j": ProjectToolkitNeo4j,
    }


# Backing data structures
class Indicator:
    """A named indicator with its query text and generated code."""

    def __init__(self, name: str, query: str = "", code: str = "", id: Optional[str] = None):
        """Create an indicator.

        :param name: display name
        :param query: query text the code is generated from
        :param code: generated code, empty until generated
        :param id: existing identifier; a new UUID4 string is used when falsy
        """
        # `id` shadows the builtin, but it is part of the keyword interface.
        self.id = id if id else str(uuid.uuid4())
        self.name = name
        self.query = query
        self.code = code
        now = time.time()
        self.created_at = now
        self.updated_at = now

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable dict of all fields."""
        return {
            "id": self.id,
            "name": self.name,
            "query": self.query,
            "code": self.code,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Indicator':
        """Build an indicator from a dict produced by :meth:`to_dict`.

        ``name`` and ``id`` are required (``KeyError`` when missing).
        ``query`` and ``code`` default to ``""`` and the timestamps default
        to the current time, so older records still load.
        """
        indicator = cls(
            name=data["name"],
            query=data.get("query", ""),
            code=data.get("code", ""),
            id=data["id"],
        )
        indicator.created_at = data.get("created_at", time.time())
        indicator.updated_at = data.get("updated_at", time.time())
        return indicator


class IndicatorManager:
    """In-memory indicator library persisted as a JSON list on disk."""

    def __init__(self, save_path: Optional[str] = None):
        """Load the library from ``save_path``.

        :param save_path: JSON file to use; defaults to
            ``indicator_library.json`` inside ``CACHE_DIR``
        """
        if save_path is None:
            save_path = os.path.join(CACHE_DIR, "indicator_library.json")
        self.indicators: Dict[str, Indicator] = {}
        self.save_path = save_path
        self.load_indicators()

    def load_indicators(self) -> None:
        """Load the indicator library from ``self.save_path`` if it exists.

        Failures are logged and shown in the UI instead of being raised.
        """
        try:
            if os.path.exists(self.save_path):
                with open(self.save_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                for indicator_data in data:
                    indicator = Indicator.from_dict(indicator_data)
                    self.indicators[indicator.id] = indicator
        except Exception as e:
            logger.exception("Failed to load indicator library from %s", self.save_path)
            st.error(f"加载指标库失败: {str(e)}")

    def save_indicators(self) -> None:
        """Write the whole library to ``self.save_path``.

        Failures are logged and shown in the UI instead of being raised.
        """
        try:
            payload = [ind.to_dict() for ind in self.indicators.values()]
            with open(self.save_path, "w", encoding="utf-8") as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)
        except Exception as e:
            logger.exception("Failed to save indicator library to %s", self.save_path)
            st.error(f"保存指标库失败: {str(e)}")

    def add_indicator(self, name: str, query: str = "") -> Indicator:
        """Create, store and persist a new indicator.

        :raises ValueError: when ``name`` is empty or whitespace only
        """
        if not name.strip():
            raise ValueError("指标名称不能为空")
        indicator = Indicator(name=name, query=query)
        self.indicators[indicator.id] = indicator
        self.save_indicators()
        return indicator

    def update_indicator(
        self,
        id: str,
        name: Optional[str] = None,
        query: Optional[str] = None,
        code: Optional[str] = None,
    ) -> Optional[Indicator]:
        """Update the given fields of an indicator and persist the library.

        Arguments left as ``None`` are not changed.

        :return: the updated indicator, or ``None`` when ``id`` is unknown
        :raises ValueError: when ``name`` is given but empty or whitespace
        """
        if id not in self.indicators:
            return None
        if name is not None and not name.strip():
            raise ValueError("指标名称不能为空")

        indicator = self.indicators[id]
        if name is not None:
            indicator.name = name
        if query is not None:
            indicator.query = query
        if code is not None:
            indicator.code = code
        indicator.updated_at = time.time()
        self.save_indicators()
        return indicator

    def delete_indicator(self, id: str) -> bool:
        """Delete an indicator; return ``True`` when it existed."""
        if id in self.indicators:
            del self.indicators[id]
            self.save_indicators()
            return True
        return False

    def clear_all_indicators(self) -> None:
        """Remove every indicator and persist the empty library."""
        self.indicators.clear()
        self.save_indicators()

    def get_indicator(self, id: str) -> Optional[Indicator]:
        """Return the indicator with this id, or ``None``."""
        return self.indicators.get(id)

    def get_all_indicators(self) -> List[Indicator]:
        """Return all indicators as a new list."""
        return list(self.indicators.values())


def save_session_state():
    """Persist ``current_indicator_id`` to a JSON file in ``CACHE_DIR``.

    This function must not be cached: it takes no arguments and exists only
    for its side effect, so a cached version would skip the write on every
    call after the first one within the TTL.
    """
    try:
        state = {
            "current_indicator_id": st.session_state.get("current_indicator_id", None)
        }
        state_file = os.path.join(CACHE_DIR, "indicator_creator_state.json")
        with open(state_file, "w", encoding="utf-8") as f:
            json.dump(state, f, ensure_ascii=False, indent=2)
    except Exception as e:
        logger.exception("Failed to save session state")
        st.error(f"保存会话状态失败: {str(e)}")


def load_session_state():
    """Copy saved values from the state file into ``st.session_state``.

    Keys already present in the session are never overwritten. This function
    must not be cached: its only effect is to populate the session state, so
    a cache hit would leave a new session unpopulated.
    """
    try:
        state_file = os.path.join(CACHE_DIR, "indicator_creator_state.json")
        if os.path.exists(state_file):
            with open(state_file, "r", encoding="utf-8") as f:
                state = json.load(f)
            for key, value in state.items():
                if key not in st.session_state:
                    st.session_state[key] = value
    except Exception as e:
        logger.exception("Failed to load session state")
        st.error(f"加载会话状态失败: {str(e)}")


def generate_indicator_code(query: str) -> str:
    """Generate indicator code for ``query`` and store it in the session.

    On success the code is written to ``st.session_state.generated_code``
    and, when ``current_indicator_id`` is set in the session, that id is
    copied to ``st.session_state.current_query_id``.

    The function is deliberately not cached: it writes to the session state,
    and a cache hit would return the success message without storing code.

    :param query: query text to generate code for
    :return: a user-facing status message; failures are reported in the
        message rather than raised
    """
    if not query.strip():
        return "查询语句为空,无法生成代码"

    try:
        # Check up front: the dialog manager is needed by steps 1 and 3.
        if "dialog_manager" not in st.session_state:
            return "❌ 对话管理器未初始化"
        dialog_manager = st.session_state.dialog_manager

        # Step 1: understand the query.
        if "user_interaction" not in st.session_state:
            st.session_state.user_interaction = dialog_manager.user_interaction
        understanding_result = st.session_state.user_interaction.understand(query)
        if not understanding_result:
            return "❌ 无法理解查询语句"

        # Step 2: collect the recognised entities.
        selected_knowledge = []
        for item in understanding_result:
            logger.debug(
                "Recognised entity: %s, constraints: %s",
                item.get("name"),
                item.get("constraints"),
            )
            selected_knowledge.append(item)
        if not selected_knowledge:
            return "❌ 没有找到相关知识"

        # Step 3: generate the code.
        generated_result = dialog_manager.generated_code(query, selected_knowledge)
        if not generated_result:
            return "❌ 代码生成失败"

        if isinstance(generated_result, dict):
            # Structured result with 'status', 'message' and 'data' keys.
            if not generated_result.get("status", False):
                return f"❌ 代码生成失败: {generated_result.get('message', '未知错误')}"
            code = generated_result.get("data", "")
            if not code:
                return "❌ 生成的代码为空"
            st.session_state.generated_code = code
        else:
            # Any other truthy result is stored as the code itself.
            st.session_state.generated_code = generated_result

        if "current_indicator_id" in st.session_state:
            st.session_state.current_query_id = st.session_state.current_indicator_id
        return "✅ 指标代码生成成功,可以在下方查看并执行"
    except Exception as e:
        # UI boundary: log the traceback and report the failure as text.
        logger.exception("Indicator code generation failed")
        return f"❌ 指标代码生成失败: {str(e)}"


@st.cache_resource
def init_app_data():
    """Build the long-lived application objects once and cache them.

    :return: dict with the keys ``dialog_manager``, ``user_interaction``,
        ``code_executor`` and ``knowledge_retriever``
    """
    deps = load_dependencies()

    config = deps["Config"]()
    business_structure = deps["load_file"](config.business_object_structure_path)
    bowei_api_docs = deps["load_file"](config.bowei_api_docs_path)

    llm_client = deps["MultiAPIKeyChatOpenAI"](config.openai)
    user_interaction = deps["UserInteraction"](llm_client.llm, business_structure)
    llm_client_coder = deps["MultiAPIKeyChatOpenAI"](config.openai_coder)
    prompt_manager = deps["PromptManager"]()

    neo4j_conf = config.neo4j_conf
    embedding_conf = config.embedding
    # NOTE(review): the embedding client is not used below; it is still
    # constructed in case construction has side effects - confirm.
    embedding_client = deps["EmbeddingClient"](embedding_conf)

    # Create the Neo4j retriever and register its driver with the builder.
    knowledge_retriever = deps["Neo4jRawRetriever"](neo4j_conf)
    deps["ProjectBuilder"].register(deps["ProjectToolkitNeo4j"], knowledge_retriever.driver)

    code_executor = deps["CodeExecutor"](prompt_manager.prompts, llm_client_coder, config.max_retries)
    dialog_manager = deps["DialogManager"](
        llm_client,
        business_structure,
        bowei_api_docs,
        code_executor,
        knowledge_retriever,
        prompt_manager,
    )

    # Return the objects; the caller decides where to store them.
    return {
        "dialog_manager": dialog_manager,
        "user_interaction": user_interaction,
        "code_executor": code_executor,
        "knowledge_retriever": knowledge_retriever,
    }


# NOTE(review): the next line is the start of main(), which continues past the visible source; it is left untouched.
# Streamlit 应用界面 def main(): st.set_page_config(page_title="造价工程指标管理器", layout="wide") # 使用进度条显示加载过程 with st.spinner("正在初始化应用..."): # 初始化应用数据 app_data = init_app_data() # 将必要的对象存储在session_state中 for key, value in app_data.items(): st.session_state[key] = value # 初始化会话状态 if "indicator_manager" not in st.session_state: st.session_state.indicator_manager = IndicatorManager() if "current_indicator_id" not in st.session_state: st.session_state.current_indicator_id = None # 加载保存的会话状态 load_session_state() # 侧边栏 with st.sidebar: # 第一个expander:指标管理 with st.expander("新建指标工具", expanded=True): # 显示清除指标链接 col1, col2, col3 = st.columns([2, 1, 1]) with col1: st.subheader("指标列表") with col2: # 新建指标按钮 if st.button("新建"): try: # 生成新指标名称 indicator_count = len(st.session_state.indicator_manager.get_all_indicators()) new_name = f"指标 {indicator_count + 1}" # 创建新指标 new_indicator = st.session_state.indicator_manager.add_indicator(name=new_name) st.session_state.current_indicator_id = new_indicator.id # 清除已生成的代码 if "generated_code" in st.session_state: del st.session_state.generated_code # 清除第二个expander的状态 st.session_state.view_indicator_detail = False save_session_state() st.rerun() except ValueError as e: st.error(str(e)) with col3: if st.button("清空", help="删除所有指标", type="secondary"): st.session_state.indicator_manager.clear_all_indicators() st.session_state.current_indicator_id = None # 清除已生成的代码 if "generated_code" in st.session_state: del st.session_state.generated_code # 清除第二个expander的状态 st.session_state.view_indicator_detail = False save_session_state() st.rerun() # 显示指标列表 indicators = st.session_state.indicator_manager.get_all_indicators() if not indicators: st.info("暂无指标,请点击\"新建指标\"按钮创建") else: for indicator in sorted(indicators, key=lambda x: x.created_at): if st.button( indicator.name, key=f"btn_{indicator.id}", use_container_width=True, type="primary" if 
st.session_state.current_indicator_id == indicator.id else "secondary" ): st.session_state.current_indicator_id = indicator.id # 更新显示的代码 if indicator.code: st.session_state.generated_code = indicator.code elif "generated_code" in st.session_state: # 如果当前指标没有代码,清除session中的代码 del st.session_state.generated_code # 清除第二个expander的状态 st.session_state.view_indicator_detail = False save_session_state() st.rerun() # 第二个expander:指标库浏览 with st.expander("查看指标库", expanded=False): # 设置默认文件路径 file_path = os.path.join(DATA_DIR, "code.jsonl") # 文件路径输入 #file_path = st.text_input("JSONL文件路径", value=default_file_path, key="jsonl_path") # 加载JSONL文件的函数 @st.cache_data def load_jsonl(file_path): """加载JSONL文件并返回JSON指标列表""" records = [] try: with open(file_path, "r", encoding="utf-8") as file: for line in file: if line.strip(): records.append(json.loads(line)) return records except Exception as e: st.error(f"加载文件失败: {str(e)}") return [] # 检查文件是否存在 if not os.path.exists(file_path): st.warning(f"文件不存在: {file_path}") else: # 加载JSONL文件 records = load_jsonl(file_path) if not records: st.warning("没有找到有效的指标") else: # 初始化会话状态 if "selected_index" not in st.session_state: st.session_state.selected_index = 0 # 显示指标列表 st.subheader("指标列表") for i, record in enumerate(records): if st.button(record.get("name", f"指标 {i+1}"), key=f"lib_btn_{i}", use_container_width=True): st.session_state.selected_index = i # 清除第一个expander的选中状态 st.session_state.current_indicator_id = None # 设置查看指标详情的逻辑 st.session_state.view_indicator_detail = True st.session_state.current_library_indicator = record save_session_state() st.rerun() # 主区域 if "view_indicator_detail" in st.session_state and st.session_state.view_indicator_detail and "current_library_indicator" in st.session_state: # 显示从指标库中选择的指标详细信息 library_indicator = st.session_state.current_library_indicator # 指标名称和操作按钮 col1, col2, col3 = st.columns([1, 4, 2]) with col1: st.markdown('