From cfb4f4ea6a0350a95cf747433c958c4c8e655cc1 Mon Sep 17 00:00:00 2001 From: paituo <330435863@qq.com> Date: Mon, 7 Jul 2025 09:32:05 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0WEB=E5=85=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main_streamlt.py | 350 +++++++++++++++++++++++++++--------------- src/prompt_manager.py | 11 +- 2 files changed, 234 insertions(+), 127 deletions(-) diff --git a/main_streamlt.py b/main_streamlt.py index 410aed4..41c25e0 100644 --- a/main_streamlt.py +++ b/main_streamlt.py @@ -1,6 +1,6 @@ -import os -import json import streamlit as st +import json +import os import uuid import time from typing import Dict, List, Optional, Any @@ -11,19 +11,26 @@ from datetime import datetime # 添加项目根目录到路径 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# 确保必要的目录存在 +CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache") +DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") +os.makedirs(CACHE_DIR, exist_ok=True) +os.makedirs(DATA_DIR, exist_ok=True) + # 配置日志 current_file = os.path.splitext(os.path.basename(__file__))[0] now_str = datetime.now().strftime("%Y%m%d%H%M%S") log_filename = f"{current_file}_{now_str}.log" -# 确保日志目录存在 -os.makedirs("logs", exist_ok=True) +# 确保日志目录存在并设置在logs目录下 +log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs") +os.makedirs(log_dir, exist_ok=True) logging.basicConfig( level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", handlers=[ - logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"), + logging.FileHandler(os.path.join(log_dir, log_filename), encoding="utf-8"), logging.StreamHandler() ], ) @@ -79,10 +86,11 @@ def load_dependencies(): # 后台数据结构 class Indicator: - def __init__(self, name: str, query: str = "", id: Optional[str] = None): + def __init__(self, name: str, query: str = "", code: str = "", id: Optional[str] = None): self.id = id if id else str(uuid.uuid4()) self.name = name self.query = query + self.code = code self.created_at = time.time() self.updated_at = time.time() @@ -91,6 +99,7 @@ class Indicator: "id": self.id, "name": self.name, "query": self.query, + "code": self.code, "created_at": self.created_at, "updated_at": self.updated_at } @@ -100,6 +109,7 @@ class Indicator: indicator = cls( name=data["name"], query=data["query"], + code=data.get("code", ""), # 使用 get 方法处理可能不存在的 code 字段 id=data["id"] ) indicator.created_at = data.get("created_at", time.time()) @@ -108,7 +118,9 @@ class Indicator: class IndicatorManager: - def __init__(self, save_path: str = "indicator_library.json"): + def __init__(self, save_path: Optional[str] = None): + if save_path is None: + save_path = os.path.join(CACHE_DIR, "indicator_library.json") self.indicators: Dict[str, Indicator] = {} self.save_path = save_path self.load_indicators() @@ -143,7 +155,7 @@ class IndicatorManager: self.save_indicators() return indicator - def update_indicator(self, id: str, name: Optional[str] = None, query: Optional[str] = None) -> Optional[Indicator]: + def update_indicator(self, id: str, name: Optional[str] = None, query: Optional[str] = None, code: Optional[str] = None) -> Optional[Indicator]: """更新指标""" if id not in self.indicators: return None @@ -156,6 +168,8 @@ class IndicatorManager: indicator.name = name if query is not None: indicator.query = query + if code is not None: + indicator.code = code indicator.updated_at = time.time() self.save_indicators() @@ -184,12 +198,6 @@ class IndicatorManager: # 保存会话状态 -def ensure_cache_dir(): - """确保cache目录存在""" - cache_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "cache") - os.makedirs(cache_dir, exist_ok=True) - return cache_dir - @st.cache_data(ttl=3600) def save_session_state(): """保存会话状态到文件""" @@ -197,20 +205,19 @@ def save_session_state(): state = { "current_indicator_id": st.session_state.get("current_indicator_id", None) } - cache_dir = ensure_cache_dir() - state_file = os.path.join(cache_dir, "indicator_creator_state.json") + state_file = os.path.join(CACHE_DIR, "indicator_creator_state.json") with open(state_file, "w", encoding="utf-8") as f: json.dump(state, f, ensure_ascii=False, indent=2) except Exception as e: st.error(f"保存会话状态失败: {str(e)}") + # 加载会话状态 @st.cache_data(ttl=3600) def load_session_state(): """从文件加载会话状态""" try: - cache_dir = ensure_cache_dir() - state_file = os.path.join(cache_dir, "indicator_creator_state.json") + state_file = os.path.join(CACHE_DIR, "indicator_creator_state.json") if os.path.exists(state_file): with open(state_file, "r", encoding="utf-8") as f: state = json.load(f) @@ -220,69 +227,76 @@ def load_session_state(): except Exception as e: st.error(f"加载会话状态失败: {str(e)}") + # 测试指标查询 @st.cache_data(ttl=60) # 缓存结果1分钟 -def test_indicator_query(query: str) -> str: - """测试指标查询""" +def generate_indicator_code(query: str) -> str: + """生成指标代码""" if not query.strip(): - return "查询语句为空,无法测试" + return "查询语句为空,无法生成代码" - # 模拟异步测试过程 - result = "开始测试查询...\n" + result = "" try: # 步骤1:通过调用user_interaction.understand(query)实现 - result += "步骤1: 正在理解查询语句...\n" + result = "正在理解查询语句..." if "user_interaction" not in st.session_state: st.session_state.user_interaction = st.session_state.dialog_manager.user_interaction understanding_result = st.session_state.user_interaction.understand(query) if not understanding_result: - result += "❌ 无法理解查询语句\n" - return result + return "❌ 无法理解查询语句" # 步骤2:验证步骤一返回值 - result += "步骤2: 正在验证理解结果...\n" + result = "正在验证理解结果..." selected_knowledge = [] + entity_info = [] for item in understanding_result: - result += f"- 识别到实体: {item.get('name')}, 约束条件: {item.get('constraints')}\n" + entity_info.append(f"识别到实体: {item.get('name')}, 约束条件: {item.get('constraints')}") selected_knowledge.append(item) if not selected_knowledge: - result += "❌ 没有找到相关知识\n" - return result + return "❌ 没有找到相关知识" + + # 显示识别到的实体信息 + result = "\n".join(entity_info) # 步骤3:dialog_manager.generated_code(query, selected_knowledge) 实现 - result += "步骤3: 正在生成查询代码...\n" + result = "正在生成指标代码..." if "dialog_manager" not in st.session_state: - result += "❌ 对话管理器未初始化\n" - return result + return "❌ 对话管理器未初始化" - generated_code = st.session_state.dialog_manager.generated_code(query, selected_knowledge) - if not generated_code: - result += "❌ 代码生成失败\n" - return result - - result += "✅ 代码生成成功\n" + generated_result = st.session_state.dialog_manager.generated_code(query, selected_knowledge) - # 步骤4:通过调用code_executor.execute_code 实现 - result += "步骤4: 正在执行查询代码...\n" - if "code_executor" not in st.session_state: - st.session_state.code_executor = st.session_state.dialog_manager.code_executor + # 处理返回结构 + if not generated_result: + return "❌ 代码生成失败" - execution_result = st.session_state.code_executor.execute_code(generated_code) - - if execution_result and "error" not in execution_result.lower(): - result += "✅ 查询执行成功!\n" - result += f"查询结果:\n{execution_result}" + # 检查返回的是否是字典结构 + if isinstance(generated_result, dict): + if not generated_result.get('status', False): + return f"❌ 代码生成失败: {generated_result.get('message', '未知错误')}" + + # 提取代码内容 + code = generated_result.get('data', '') + if not code: + return "❌ 生成的代码为空" + + # 将生成的代码保存到session_state中 + st.session_state.generated_code = code else: - result += f"❌ 查询执行失败: {execution_result}\n" + # 如果直接返回的是代码字符串 + st.session_state.generated_code = generated_result + + if "current_indicator_id" in st.session_state: + st.session_state.current_query_id = st.session_state.current_indicator_id + + return "✅ 指标代码生成成功,可以在下方查看并执行" except Exception as e: import traceback error_details = traceback.format_exc() - result += f"❌ 查询测试失败: {str(e)}\n" - result += f"错误详情: {error_details}" + return f"❌ 指标代码生成失败: {str(e)}" return result @@ -338,7 +352,7 @@ def init_app_data(): # Streamlit 应用界面 def main(): - st.set_page_config(page_title="指标创建器", layout="wide") + st.set_page_config(page_title="造价工程指标管理器", layout="wide") # 使用进度条显示加载过程 with st.spinner("正在初始化应用..."): @@ -361,8 +375,7 @@ def main(): # 侧边栏 with st.sidebar: - st.title("造价工程指标管理器") - + # 第一个expander:指标管理 with st.expander("新建指标工具", expanded=True): @@ -381,6 +394,11 @@ def main(): # 创建新指标 new_indicator = st.session_state.indicator_manager.add_indicator(name=new_name) st.session_state.current_indicator_id = new_indicator.id + # 清除已生成的代码 + if "generated_code" in st.session_state: + del st.session_state.generated_code + # 清除第二个expander的状态 + st.session_state.view_indicator_detail = False save_session_state() st.rerun() except ValueError as e: @@ -389,6 +407,11 @@ def main(): if st.button("清空", help="删除所有指标", type="secondary"): st.session_state.indicator_manager.clear_all_indicators() st.session_state.current_indicator_id = None + # 清除已生成的代码 + if "generated_code" in st.session_state: + del st.session_state.generated_code + # 清除第二个expander的状态 + st.session_state.view_indicator_detail = False save_session_state() st.rerun() @@ -405,6 +428,14 @@ def main(): type="primary" if st.session_state.current_indicator_id == indicator.id else "secondary" ): st.session_state.current_indicator_id = indicator.id + # 更新显示的代码 + if indicator.code: + st.session_state.generated_code = indicator.code + elif "generated_code" in st.session_state: + # 如果当前指标没有代码,清除session中的代码 + del st.session_state.generated_code + # 清除第二个expander的状态 + st.session_state.view_indicator_detail = False save_session_state() st.rerun() @@ -412,7 +443,7 @@ def main(): with st.expander("查看指标库", expanded=False): # 设置默认文件路径 - file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tests", "code.jsonl") + file_path = os.path.join(DATA_DIR, "code.jsonl") # 文件路径输入 #file_path = st.text_input("JSONL文件路径", value=default_file_path, key="jsonl_path") @@ -451,71 +482,82 @@ def main(): for i, record in enumerate(records): if st.button(record.get("name", f"指标 {i+1}"), key=f"lib_btn_{i}", use_container_width=True): st.session_state.selected_index = i - # 可以添加查看指标详情的逻辑 + # 清除第一个expander的选中状态 + st.session_state.current_indicator_id = None + # 设置查看指标详情的逻辑 st.session_state.view_indicator_detail = True st.session_state.current_library_indicator = record + save_session_state() + st.rerun() # 主区域 if "view_indicator_detail" in st.session_state and st.session_state.view_indicator_detail and "current_library_indicator" in st.session_state: # 显示从指标库中选择的指标详细信息 library_indicator = st.session_state.current_library_indicator - st.header(f"指标名称: {library_indicator.get('name', '未命名指标')}") + # 指标名称和操作按钮 + col1, col2, col3 = st.columns([1, 4, 2]) + with col1: + st.markdown('
指标名称:
', unsafe_allow_html=True) + with col2: + st.markdown(f'
{library_indicator.get("name", "未命名指标")}
', unsafe_allow_html=True) + + + # 查询语句部分 + title_col, run_btn_col, result_col = st.columns([1, 1, 4]) + with title_col: + st.subheader("查询语句") + with result_col: + title_query_result = st.subheader("") # 显示查询语句 - st.subheader("查询语句") - st.info(library_indicator.get('query', '无查询语句')) + st.text_area( + "查询语句", + value=library_indicator.get('query', '无查询语句'), + height=70, + key="lib_query_display", + disabled=True, + label_visibility="collapsed" + ) + + # 指标代码部分 + title_col, run_btn_col, result_col = st.columns([1, 1, 4]) + with title_col: + st.subheader("指标代码") + with result_col: + title_code_result = st.subheader("") + with run_btn_col: + if st.button("执行代码", key="execute_lib_indicator", use_container_width=True): + with st.spinner('正在执行代码...'): + # 确保code_executor已初始化 + if "code_executor" in st.session_state: + result = st.session_state.code_executor.execute_code(library_indicator.get('code', '')) + if result and hasattr(result, 'get') and result.get('status'): + title_code_result.success("执行成功") + title_code_result.json(result.get('data', {})) + else: + title_code_result.error(f"执行失败: {result.get('message', '未知错误') if hasattr(result, 'get') else str(result)}") + else: + title_code_result.error("代码执行器未初始化") # 显示代码 - st.subheader("指标代码") st.code(library_indicator.get('code', '无代码'), language='python') - # 添加执行按钮 - if st.button("执行指标代码", key="execute_lib_indicator"): - with st.spinner('正在执行代码...'): - # 确保code_executor已初始化 - if "code_executor" in st.session_state: - result = st.session_state.code_executor.execute_code(library_indicator.get('code', '')) - if result and hasattr(result, 'get') and result.get('status'): - st.success("执行成功") - st.json(result.get('data', {})) - else: - st.error(f"执行失败: {result.get('message', '未知错误') if hasattr(result, 'get') else str(result)}") - else: - st.error("代码执行器未初始化") - - # 添加导入按钮 - 将库中指标导入到当前工作区 - if st.button("导入到我的指标", key="import_to_my_indicators"): - try: - new_indicator = st.session_state.indicator_manager.add_indicator( - name=library_indicator.get('name', '导入的指标'), - query=library_indicator.get('query', '') - ) - st.session_state.current_indicator_id = new_indicator.id - st.session_state.view_indicator_detail = False - save_session_state() - st.success(f"已导入指标: {new_indicator.name}") - st.rerun() - except Exception as e: - st.error(f"导入失败: {str(e)}") - - # 添加返回按钮 - if st.button("返回", key="back_to_indicators"): - st.session_state.view_indicator_detail = False - st.rerun() + # 移除导入按钮部分 + elif st.session_state.current_indicator_id is not None: current_indicator = st.session_state.indicator_manager.get_indicator(st.session_state.current_indicator_id) if current_indicator: # 指标名称和保存按钮 - col1, col2, col3 = st.columns([1, 5, 2]) + col1, col2, col3 = st.columns([1, 4, 2]) with col1: st.markdown('
指标名称:
', unsafe_allow_html=True) with col2: - new_name = st.text_input("", value=current_indicator.name, key="indicator_name", label_visibility="collapsed") + new_name = st.text_input("指标名称", value=current_indicator.name, key="indicator_name", label_visibility="collapsed") if new_name != current_indicator.name: if not new_name.strip(): st.error("指标名称不能为空") @@ -531,44 +573,108 @@ def main(): if st.button("💾 保存指标", use_container_width=True): if not new_name.strip(): st.error("指标名称不能为空") + elif not current_indicator.query.strip(): + st.error("请先输入查询语句") + elif not current_indicator.code: + st.error("请先生成指标代码") else: try: - st.toast("指标已保存") + # 保存到指标管理器 st.session_state.indicator_manager.save_indicators() + + # 同时保存到指标库 + library_path = os.path.join(DATA_DIR, "code.jsonl") + new_record = { + "name": current_indicator.name, + "query": current_indicator.query, + "code": current_indicator.code, + "created_at": time.time() + } + + # 追加到文件 + with open(library_path, "a", encoding="utf-8") as f: + f.write(json.dumps(new_record, ensure_ascii=False) + "\n") + + # 清除指标库的缓存,以便重新加载 + load_jsonl.clear() + + st.toast("指标已保存") + st.rerun() # 刷新页面以更新指标库列表 except Exception as e: st.error(f"保存失败: {str(e)}") # 查询语句 - st.subheader("查询语句") - col1, col2 = st.columns([6, 2]) - with col1: - query = st.text_area( - "输入查询语句", - value=current_indicator.query, - height=70, - key="query_input", - placeholder="在此输入查询语句...", - label_visibility="collapsed" + title_col, test_btn_col, result_col = st.columns([1, 1, 4]) + with title_col: + st.subheader("查询语句") + with result_col: + title_query_result = st.subheader("") + + # 查询输入框 + query = st.text_area( + "输入查询语句", + value=current_indicator.query, + height=70, + key="query_input", + placeholder="在此输入查询语句...", + label_visibility="collapsed" + ) + + if query != current_indicator.query: + st.session_state.indicator_manager.update_indicator( + current_indicator.id, query=query ) - - if query != current_indicator.query: - st.session_state.indicator_manager.update_indicator( - current_indicator.id, query=query - ) - with col2: - st.write("") # 添加一些空间以对齐按钮 - st.write("") - if st.button("🧪 测试", use_container_width=True): - with st.spinner("正在测试查询..."): + # 测试按钮改为生成指标代码按钮 + with test_btn_col: + if st.button("🔄 生成代码", key="generate_code_btn", use_container_width=True): + with st.spinner("正在生成指标代码..."): # 使用缓存的测试函数 - result = test_indicator_query(query) - st.session_state.test_result = result + result = generate_indicator_code(query) + + # 根据结果显示不同的状态 + if result.startswith("✅"): + title_query_result.success(result) + # 生成成功后,立即保存代码到指标对象中 + if "generated_code" in st.session_state: + st.session_state.indicator_manager.update_indicator( + current_indicator.id, + code=st.session_state.generated_code + ) + elif result.startswith("❌"): + title_query_result.error(result) + else: + # 显示实体识别结果或其他信息 + title_query_result.info(result) - # 测试结果 - if "test_result" in st.session_state: - st.subheader("测试结果") - st.code(st.session_state.test_result) + # 添加代码执行部分 + title_col, run_btn_col, result_col = st.columns([1, 1, 4]) + with title_col: + st.subheader("指标代码") + with result_col: + title_code_result = st.subheader("") + with run_btn_col: + if st.button("执行代码", key="execute_code_btn", use_container_width=True): + with st.spinner("正在执行代码..."): + if "code_executor" in st.session_state and "generated_code" in st.session_state: + result = st.session_state.code_executor.execute_code(st.session_state.generated_code) + if result and hasattr(result, 'get') and result.get('status'): + title_code_result.success("执行成功") + title_code_result.json(result.get('data', {})) + # 更新指标的代码 + st.session_state.indicator_manager.update_indicator( + current_indicator.id, + code=st.session_state.generated_code + ) + else: + title_code_result.error(f"执行失败: {result.get('message', '未知错误') if hasattr(result, 'get') else str(result)}") + else: + title_code_result.error("代码执行器未初始化或没有生成的代码") + + # 显示代码 + st.code(st.session_state.get("generated_code", "# 暂无代码,请输入有效的查询语句并点击测试按钮"), language="python") + + # 移除保存到指标库按钮部分 else: st.warning("找不到当前指标,可能已被删除") else: diff --git a/src/prompt_manager.py b/src/prompt_manager.py index 116d803..bbfd7bb 100644 --- a/src/prompt_manager.py +++ b/src/prompt_manager.py @@ -74,11 +74,12 @@ def project_get_calculate_function(): return result_dict # 执行规则 -- 参数必须从用户问题或上下文信息中提取 -- 禁止在代码函数范围外添加任何注释或解释或非代码内容 -- 输出代码中必须以def project_get_calculate_function() -> dict函数作为入口函数 -- 必须确保生成的代码可以直接执行,如果函数功能求取数值,project的函数返回结果为空或出错则算成功,data为0,并在message说明错误原因,代码要注意进行各类容错检查 -- ProjectToolkit 类中涉及项目划分的函数已考虑在其及其子孙项目划分下查找,所以无需生成递归子项目划分的代码 +- 参数必须从用户问题或上下文信息中提取。 +- 在代码函数内部生成功能说明,在函数外禁止生成任何注释或解释或非代码内容。 +- 输出代码中必须以def project_get_calculate_function() -> dict函数作为入口函数,该函数成功时返回的'data'通常是浮点或整型值,除非用户要求返回其他类型。 +- 必须确保生成的代码可以直接执行,代码要注意进行各类容错检查。 +- 如果project_get_calculate_function函数需要返回浮点整形数,函数中间发生错误或找不到对象也必须返回成功,data为0,并在message说明错误原因。 +- ProjectToolkit 类中涉及项目划分的函数已考虑在其及其子孙项目划分下查找,所以无需生成递归子项目划分的代码。 - 如果文本中包含范围编码格式则需要进行编码展开,如'YX2-1~7'展开为‘YX2-1/YX2-2/YX2-3/YX2-4/YX2-5/YX2-6/YX2-7’ """ )