"""Streamlit entry point for creating, editing and testing indicators.

Reconstructed from commit dc43b69 ("Add Streamlit UI entry point"), which adds
this module as ``main_streamlt.py``.
"""
import os
import json
import streamlit as st
import uuid
import time
from typing import Dict, List, Optional, Any
import logging
import sys
from datetime import datetime

# Make the parent of this file's directory importable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Logging configuration: one log file per process start, named after this file.
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"

# Make sure the log directory exists before the file handler opens it.
os.makedirs("logs", exist_ok=True)

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
        logging.StreamHandler(),
    ],
)

logger = logging.getLogger(current_file)


def setup_logger(logger_name):
    """Quiet a noisy third-party logger.

    Sets the named logger to WARNING and disables propagation so its records
    are not handled by the root logger.

    :param logger_name: name of the logger to configure
    :return: the configured logger
    """
    target = logging.getLogger(logger_name)
    target.setLevel(logging.WARNING)
    target.propagate = False
    return target


logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"]
for name in logger_names:
    setup_logger(name)


@st.cache_resource
def load_dependencies():
    """Import the project modules lazily and return them by name.

    The imports run inside the function so they are only paid for when the
    app actually initialises; ``st.cache_resource`` keeps the resulting
    mapping for later reruns.
    """
    from src.config import Config
    from src.document_loader import load_file
    from src.multi_llm_client import MultiAPIKeyChatOpenAI
    from src.user_interaction import UserInteraction
    from src.dialog_manager import DialogManager
    from src.code_executor import CodeExecutor
    from src.neo4j_raw_retriever import Neo4jRawRetriever
    from src.prompt_manager import PromptManager
    from src.embedding_client import EmbeddingClient
    from src.project import ProjectBuilder, ProjectToolkit
    from src.project_implementation import ProjectToolkitNeo4j

    return {
        "Config": Config,
        "load_file": load_file,
        "MultiAPIKeyChatOpenAI": MultiAPIKeyChatOpenAI,
        "UserInteraction": UserInteraction,
        "DialogManager": DialogManager,
        "CodeExecutor": CodeExecutor,
        "Neo4jRawRetriever": Neo4jRawRetriever,
        "PromptManager": PromptManager,
        "EmbeddingClient": EmbeddingClient,
        "ProjectBuilder": ProjectBuilder,
        "ProjectToolkitNeo4j": ProjectToolkitNeo4j,
    }


# Backing data structures
class Indicator:
    """A named indicator with an optional query and creation/update timestamps."""

    def __init__(self, name: str, query: str = "", id: Optional[str] = None):
        # A fresh UUID is generated when no id is supplied.
        self.id = id if id else str(uuid.uuid4())
        self.name = name
        self.query = query
        self.created_at = time.time()
        self.updated_at = time.time()

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable mapping of this indicator."""
        return {
            "id": self.id,
            "name": self.name,
            "query": self.query,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Indicator':
        """Build an indicator from a mapping produced by :meth:`to_dict`.

        ``name`` is required (``KeyError`` if absent); a missing ``query``
        defaults to an empty string, a missing ``id`` yields a new UUID, and
        missing timestamps default to the current time.
        """
        indicator = cls(
            name=data["name"],
            query=data.get("query", ""),
            id=data.get("id"),
        )
        indicator.created_at = data.get("created_at", time.time())
        indicator.updated_at = data.get("updated_at", time.time())
        return indicator


class IndicatorManager:
    """In-memory indicator collection persisted as a JSON list on disk."""

    def __init__(self, save_path: str = "indicator_library.json"):
        self.indicators: Dict[str, Indicator] = {}
        self.save_path = save_path
        self.load_indicators()

    def load_indicators(self) -> None:
        """Load indicators from ``save_path``; failures are shown via ``st.error``."""
        try:
            if os.path.exists(self.save_path):
                with open(self.save_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                for indicator_data in data:
                    indicator = Indicator.from_dict(indicator_data)
                    self.indicators[indicator.id] = indicator
        except Exception as e:
            st.error(f"加载指标库失败: {str(e)}")

    def save_indicators(self) -> None:
        """Write all indicators to ``save_path``; failures are shown via ``st.error``."""
        try:
            with open(self.save_path, "w", encoding="utf-8") as f:
                json.dump(
                    [ind.to_dict() for ind in self.indicators.values()],
                    f,
                    ensure_ascii=False,
                    indent=2,
                )
        except Exception as e:
            st.error(f"保存指标库失败: {str(e)}")

    def add_indicator(self, name: str, query: str = "") -> Indicator:
        """Create, store and persist a new indicator.

        :raises ValueError: if ``name`` is empty or whitespace only
        """
        if not name.strip():
            raise ValueError("指标名称不能为空")

        indicator = Indicator(name=name, query=query)
        self.indicators[indicator.id] = indicator
        self.save_indicators()
        return indicator

    def update_indicator(self, id: str, name: Optional[str] = None, query: Optional[str] = None) -> Optional[Indicator]:
        """Update the name and/or query of an indicator and persist the change.

        :return: the updated indicator, or ``None`` when ``id`` is unknown
        :raises ValueError: if ``name`` is given but empty or whitespace only
        """
        if id not in self.indicators:
            return None

        if name is not None and not name.strip():
            raise ValueError("指标名称不能为空")

        indicator = self.indicators[id]
        if name is not None:
            indicator.name = name
        if query is not None:
            indicator.query = query

        indicator.updated_at = time.time()
        self.save_indicators()
        return indicator

    def delete_indicator(self, id: str) -> bool:
        """Delete an indicator; return ``True`` if it existed."""
        if id in self.indicators:
            del self.indicators[id]
            self.save_indicators()
            return True
        return False

    def clear_all_indicators(self) -> None:
        """Remove every indicator and persist the empty library."""
        self.indicators.clear()
        self.save_indicators()

    def get_indicator(self, id: str) -> Optional[Indicator]:
        """Return the indicator with the given id, or ``None``."""
        return self.indicators.get(id)

    def get_all_indicators(self) -> List[Indicator]:
        """Return all indicators as a list."""
        return list(self.indicators.values())


def ensure_cache_dir():
    """Create the ``cache`` directory (if needed) and return its path.

    The directory lives next to the parent of this file's directory.
    """
    cache_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "cache")
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


def save_session_state():
    """Write the persisted part of the session state to the cache file.

    Deliberately not wrapped in ``st.cache_data``: the function takes no
    arguments, so a cached version would skip the write on every call after
    the first until the TTL expired.
    """
    try:
        state = {
            "current_indicator_id": st.session_state.get("current_indicator_id", None)
        }
        cache_dir = ensure_cache_dir()
        state_file = os.path.join(cache_dir, "indicator_creator_state.json")
        with open(state_file, "w", encoding="utf-8") as f:
            json.dump(state, f, ensure_ascii=False, indent=2)
    except Exception as e:
        st.error(f"保存会话状态失败: {str(e)}")


def load_session_state():
    """Copy saved values into ``st.session_state`` for keys that are not set yet.

    Not cached, for the same reason as :func:`save_session_state`: the body
    exists only for its side effect on the session state.
    """
    try:
        cache_dir = ensure_cache_dir()
        state_file = os.path.join(cache_dir, "indicator_creator_state.json")
        if os.path.exists(state_file):
            with open(state_file, "r", encoding="utf-8") as f:
                state = json.load(f)
            for key, value in state.items():
                if key not in st.session_state:
                    st.session_state[key] = value
    except Exception as e:
        st.error(f"加载会话状态失败: {str(e)}")


def _execution_ok(result: Any) -> bool:
    """Decide whether a code-execution result represents success.

    NOTE(review): the return type of ``code_executor.execute_code`` is not
    visible here; callers in this file treat it both as text and as a mapping
    with a ``status`` key, so both shapes are accepted -- confirm.
    """
    if not result:
        return False
    if hasattr(result, "get"):
        return bool(result.get("status"))
    return "error" not in str(result).lower()


def test_indicator_query(query: str) -> str:
    """Run the understand -> generate -> execute pipeline for ``query``.

    Returns a human-readable, multi-line report. The function reads the
    collaborators from ``st.session_state`` and is therefore not cached:
    a cached report would hide changes in that state and would replay
    failures for the TTL.
    """
    if not query.strip():
        return "查询语句为空,无法测试"

    result = "开始测试查询...\n"

    try:
        # Step 1: understand the query via user_interaction.understand(query).
        result += "步骤1: 正在理解查询语句...\n"
        if "user_interaction" not in st.session_state:
            # The fallback needs the dialog manager, so check for it first.
            if "dialog_manager" not in st.session_state:
                result += "❌ 对话管理器未初始化\n"
                return result
            st.session_state.user_interaction = st.session_state.dialog_manager.user_interaction

        understanding_result = st.session_state.user_interaction.understand(query)
        if not understanding_result:
            result += "❌ 无法理解查询语句\n"
            return result

        # Step 2: report what step 1 returned.
        result += "步骤2: 正在验证理解结果...\n"
        selected_knowledge = []
        for item in understanding_result:
            result += f"- 识别到实体: {item.get('name')}, 约束条件: {item.get('constraints')}\n"
            selected_knowledge.append(item)

        if not selected_knowledge:
            result += "❌ 没有找到相关知识\n"
            return result

        # Step 3: generate code via dialog_manager.generated_code(...).
        result += "步骤3: 正在生成查询代码...\n"
        if "dialog_manager" not in st.session_state:
            result += "❌ 对话管理器未初始化\n"
            return result

        generated_code = st.session_state.dialog_manager.generated_code(query, selected_knowledge)
        if not generated_code:
            result += "❌ 代码生成失败\n"
            return result

        result += "✅ 代码生成成功\n"

        # Step 4: execute the generated code via code_executor.execute_code.
        result += "步骤4: 正在执行查询代码...\n"
        if "code_executor" not in st.session_state:
            st.session_state.code_executor = st.session_state.dialog_manager.code_executor

        execution_result = st.session_state.code_executor.execute_code(generated_code)

        if _execution_ok(execution_result):
            result += "✅ 查询执行成功!\n"
            result += f"查询结果:\n{execution_result}"
        else:
            result += f"❌ 查询执行失败: {execution_result}\n"

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        result += f"❌ 查询测试失败: {str(e)}\n"
        result += f"错误详情: {error_details}"

    return result


@st.cache_resource
def init_app_data():
    """Build the long-lived application objects once and cache them.

    Returns a mapping with the dialog manager, user interaction helper, code
    executor and knowledge retriever; ``main`` copies these entries into
    ``st.session_state``.
    """
    deps = load_dependencies()

    config = deps["Config"]()

    business_structure = deps["load_file"](config.business_object_structure_path)
    bowei_api_docs = deps["load_file"](config.bowei_api_docs_path)

    llm_client = deps["MultiAPIKeyChatOpenAI"](config.openai)
    user_interaction = deps["UserInteraction"](llm_client.llm, business_structure)

    llm_client_coder = deps["MultiAPIKeyChatOpenAI"](config.openai_coder)

    prompt_manager = deps["PromptManager"]()

    neo4j_conf = config.neo4j_conf
    embedding_conf = config.embedding

    embedding_client = deps["EmbeddingClient"](embedding_conf)

    # Create the Neo4j retriever and register its driver with the project builder.
    knowledge_retriever = deps["Neo4jRawRetriever"](neo4j_conf)

    deps["ProjectBuilder"].register(deps["ProjectToolkitNeo4j"], knowledge_retriever.driver)

    code_executor = deps["CodeExecutor"](prompt_manager.prompts, llm_client_coder, config.max_retries)

    dialog_manager = deps["DialogManager"](
        llm_client,
        business_structure,
        bowei_api_docs,
        code_executor,
        knowledge_retriever,
        prompt_manager,
    )

    return {
        "dialog_manager": dialog_manager,
        "user_interaction": user_interaction,
        "code_executor": code_executor,
        "knowledge_retriever": knowledge_retriever,
    }


@st.cache_data
def load_jsonl(file_path: str, mtime: float = 0.0) -> List[Dict[str, Any]]:
    """Parse a JSONL file into a list of records, skipping blank lines.

    ``mtime`` takes part in the cache key only, so an edited file is re-read.
    Errors propagate to the caller; exceptions are not cached, so a transient
    failure does not stick.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]


def _select_indicator(indicator_id: Optional[str]) -> None:
    """Make ``indicator_id`` current, leave the library detail view, persist."""
    st.session_state.current_indicator_id = indicator_id
    st.session_state.view_indicator_detail = False
    save_session_state()


def _render_sidebar() -> None:
    """Render the sidebar: the user's indicator list and the library browser."""
    manager = st.session_state.indicator_manager

    with st.sidebar:
        st.title("造价工程指标管理器")

        # First expander: manage the user's own indicators.
        with st.expander("新建指标工具", expanded=True):
            col1, col2, col3 = st.columns([2, 1, 1])
            with col1:
                st.subheader("指标列表")
            with col2:
                if st.button("新建"):
                    try:
                        indicator_count = len(manager.get_all_indicators())
                        new_indicator = manager.add_indicator(name=f"指标 {indicator_count + 1}")
                    except ValueError as e:
                        st.error(str(e))
                    else:
                        _select_indicator(new_indicator.id)
                        st.rerun()
            with col3:
                if st.button("清空", help="删除所有指标", type="secondary"):
                    manager.clear_all_indicators()
                    _select_indicator(None)
                    st.rerun()

            indicators = manager.get_all_indicators()
            if not indicators:
                st.info("暂无指标,请点击\"新建指标\"按钮创建")
            else:
                for indicator in sorted(indicators, key=lambda x: x.created_at):
                    is_current = st.session_state.current_indicator_id == indicator.id
                    if st.button(
                        indicator.name,
                        key=f"btn_{indicator.id}",
                        use_container_width=True,
                        type="primary" if is_current else "secondary",
                    ):
                        _select_indicator(indicator.id)
                        st.rerun()

        # Second expander: browse the indicator library stored as JSONL.
        with st.expander("查看指标库", expanded=False):
            file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tests", "code.jsonl")

            if not os.path.exists(file_path):
                st.warning(f"文件不存在: {file_path}")
            else:
                try:
                    records = load_jsonl(file_path, os.path.getmtime(file_path))
                except Exception as e:
                    st.error(f"加载文件失败: {str(e)}")
                    records = []

                if not records:
                    st.warning("没有找到有效的指标")
                else:
                    if "selected_index" not in st.session_state:
                        st.session_state.selected_index = 0

                    st.subheader("指标列表")
                    for i, record in enumerate(records):
                        if st.button(record.get("name", f"指标 {i+1}"), key=f"lib_btn_{i}", use_container_width=True):
                            st.session_state.selected_index = i
                            st.session_state.view_indicator_detail = True
                            st.session_state.current_library_indicator = record


def _render_library_detail(library_indicator: Dict[str, Any]) -> None:
    """Render the read-only detail view of an indicator picked from the library."""
    st.header(f"指标名称: {library_indicator.get('name', '未命名指标')}")

    st.subheader("查询语句")
    st.info(library_indicator.get('query', '无查询语句'))

    st.subheader("指标代码")
    st.code(library_indicator.get('code', '无代码'), language='python')

    if st.button("执行指标代码", key="execute_lib_indicator"):
        with st.spinner('正在执行代码...'):
            if "code_executor" in st.session_state:
                result = st.session_state.code_executor.execute_code(library_indicator.get('code', ''))
                if _execution_ok(result):
                    st.success("执行成功")
                    if hasattr(result, 'get'):
                        st.json(result.get('data', {}))
                    else:
                        st.code(str(result))
                else:
                    st.error(f"执行失败: {result.get('message', '未知错误') if hasattr(result, 'get') else str(result)}")
            else:
                st.error("代码执行器未初始化")

    # Import the library indicator into the user's own indicators.
    if st.button("导入到我的指标", key="import_to_my_indicators"):
        try:
            new_indicator = st.session_state.indicator_manager.add_indicator(
                name=library_indicator.get('name', '导入的指标'),
                query=library_indicator.get('query', ''),
            )
        except Exception as e:
            st.error(f"导入失败: {str(e)}")
        else:
            _select_indicator(new_indicator.id)
            st.success(f"已导入指标: {new_indicator.name}")
            # Kept outside the try block so the rerun signal cannot be caught
            # by the broad ``except Exception`` above.
            st.rerun()

    if st.button("返回", key="back_to_indicators"):
        st.session_state.view_indicator_detail = False
        st.rerun()


def _render_indicator_editor(current_indicator: Indicator) -> None:
    """Render the editor (name, query, test) for one of the user's indicators."""
    manager = st.session_state.indicator_manager
    # Widget keys include the indicator id: with a fixed key Streamlit keeps
    # the previous widget value when another indicator is selected, and that
    # stale value would then be written into the newly selected indicator.
    suffix = current_indicator.id

    col1, col2, col3 = st.columns([1, 5, 2])

    with col1:
        # NOTE(review): the original literal carried inline markup that could
        # not be recovered from the source; only the visible text is kept.
        st.markdown("指标名称:")

    with col2:
        new_name = st.text_input(
            "指标名称",
            value=current_indicator.name,
            key=f"indicator_name_{suffix}",
            label_visibility="collapsed",
        )
        if new_name != current_indicator.name:
            if not new_name.strip():
                st.error("指标名称不能为空")
            else:
                try:
                    manager.update_indicator(current_indicator.id, name=new_name)
                except ValueError as e:
                    st.error(str(e))
                else:
                    # The sidebar was drawn before the rename; redraw it.
                    st.rerun()

    with col3:
        if st.button("💾 保存指标", use_container_width=True):
            if not new_name.strip():
                st.error("指标名称不能为空")
            else:
                try:
                    manager.save_indicators()
                except Exception as e:
                    st.error(f"保存失败: {str(e)}")
                else:
                    st.toast("指标已保存")

    st.subheader("查询语句")
    col1, col2 = st.columns([6, 2])
    with col1:
        query = st.text_area(
            "输入查询语句",
            value=current_indicator.query,
            height=70,
            key=f"query_input_{suffix}",
            placeholder="在此输入查询语句...",
            label_visibility="collapsed",
        )

        if query != current_indicator.query:
            manager.update_indicator(current_indicator.id, query=query)

    result_key = f"test_result_{suffix}"
    with col2:
        st.write("")  # spacer lines so the button lines up with the text area
        st.write("")
        if st.button("🧪 测试", use_container_width=True):
            with st.spinner("正在测试查询..."):
                st.session_state[result_key] = test_indicator_query(query)

    if result_key in st.session_state:
        st.subheader("测试结果")
        st.code(st.session_state[result_key])


def main():
    """Streamlit entry point: initialise state, then draw sidebar and main area."""
    st.set_page_config(page_title="指标创建器", layout="wide")

    with st.spinner("正在初始化应用..."):
        app_data = init_app_data()

        # Expose the cached application objects through the session state.
        for key, value in app_data.items():
            st.session_state[key] = value

    if "indicator_manager" not in st.session_state:
        st.session_state.indicator_manager = IndicatorManager()

    # Restore saved values before applying defaults: the loader only fills
    # keys that are still missing, so defaulting first would mask them.
    load_session_state()

    if "current_indicator_id" not in st.session_state:
        st.session_state.current_indicator_id = None

    _render_sidebar()

    # Main area
    if st.session_state.get("view_indicator_detail") and "current_library_indicator" in st.session_state:
        _render_library_detail(st.session_state.current_library_indicator)
    elif st.session_state.current_indicator_id is not None:
        current_indicator = st.session_state.indicator_manager.get_indicator(
            st.session_state.current_indicator_id
        )
        if current_indicator:
            _render_indicator_editor(current_indicator)
        else:
            st.warning("找不到当前指标,可能已被删除")
    else:
        st.info("请在左侧创建或选择一个指标")


if __name__ == "__main__":
    main()