增加streamlt界面入口

This commit is contained in:
2025-07-07 08:51:19 +08:00
parent d1c129c691
commit dc43b69164
+579
View File
@@ -0,0 +1,579 @@
import os
import json
import streamlit as st
import uuid
import time
from typing import Dict, List, Optional, Any
import logging
import sys
from datetime import datetime
# 添加项目根目录到路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 配置日志
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
# 确保日志目录存在
os.makedirs("logs", exist_ok=True)
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger(current_file)
def setup_logger(logger_name):
"""
设置指定名称的logger,将其级别设置为WARNING并禁用传播
:param logger_name: logger的名称
"""
logger = logging.getLogger(logger_name)
logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别
logger.propagate = False # 可选:禁用传播(防止被根logger处理)
return logger
logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"]
for name in logger_names:
setup_logger(name)
# 延迟导入依赖,只有在需要时才加载
@st.cache_resource
def load_dependencies():
"""
延迟加载依赖项,并缓存资源以提高性能
"""
from src.config import Config
from src.document_loader import load_file
from src.multi_llm_client import MultiAPIKeyChatOpenAI
from src.user_interaction import UserInteraction
from src.dialog_manager import DialogManager
from src.code_executor import CodeExecutor
from src.neo4j_raw_retriever import Neo4jRawRetriever
from src.prompt_manager import PromptManager
from src.embedding_client import EmbeddingClient
from src.project import ProjectBuilder, ProjectToolkit
from src.project_implementation import ProjectToolkitNeo4j
return {
"Config": Config,
"load_file": load_file,
"MultiAPIKeyChatOpenAI": MultiAPIKeyChatOpenAI,
"UserInteraction": UserInteraction,
"DialogManager": DialogManager,
"CodeExecutor": CodeExecutor,
"Neo4jRawRetriever": Neo4jRawRetriever,
"PromptManager": PromptManager,
"EmbeddingClient": EmbeddingClient,
"ProjectBuilder": ProjectBuilder,
"ProjectToolkitNeo4j": ProjectToolkitNeo4j
}
# 后台数据结构
class Indicator:
def __init__(self, name: str, query: str = "", id: Optional[str] = None):
self.id = id if id else str(uuid.uuid4())
self.name = name
self.query = query
self.created_at = time.time()
self.updated_at = time.time()
def to_dict(self) -> Dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"query": self.query,
"created_at": self.created_at,
"updated_at": self.updated_at
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Indicator':
indicator = cls(
name=data["name"],
query=data["query"],
id=data["id"]
)
indicator.created_at = data.get("created_at", time.time())
indicator.updated_at = data.get("updated_at", time.time())
return indicator
class IndicatorManager:
def __init__(self, save_path: str = "indicator_library.json"):
self.indicators: Dict[str, Indicator] = {}
self.save_path = save_path
self.load_indicators()
def load_indicators(self) -> None:
"""从文件加载指标库"""
try:
if os.path.exists(self.save_path):
with open(self.save_path, "r", encoding="utf-8") as f:
data = json.load(f)
for indicator_data in data:
indicator = Indicator.from_dict(indicator_data)
self.indicators[indicator.id] = indicator
except Exception as e:
st.error(f"加载指标库失败: {str(e)}")
def save_indicators(self) -> None:
"""保存指标库到文件"""
try:
with open(self.save_path, "w", encoding="utf-8") as f:
json.dump([ind.to_dict() for ind in self.indicators.values()], f, ensure_ascii=False, indent=2)
except Exception as e:
st.error(f"保存指标库失败: {str(e)}")
def add_indicator(self, name: str, query: str = "") -> Indicator:
"""添加新指标"""
if not name.strip():
raise ValueError("指标名称不能为空")
indicator = Indicator(name=name, query=query)
self.indicators[indicator.id] = indicator
self.save_indicators()
return indicator
def update_indicator(self, id: str, name: Optional[str] = None, query: Optional[str] = None) -> Optional[Indicator]:
"""更新指标"""
if id not in self.indicators:
return None
if name is not None and not name.strip():
raise ValueError("指标名称不能为空")
indicator = self.indicators[id]
if name is not None:
indicator.name = name
if query is not None:
indicator.query = query
indicator.updated_at = time.time()
self.save_indicators()
return indicator
def delete_indicator(self, id: str) -> bool:
"""删除指标"""
if id in self.indicators:
del self.indicators[id]
self.save_indicators()
return True
return False
def clear_all_indicators(self) -> None:
"""清除所有指标"""
self.indicators.clear()
self.save_indicators()
def get_indicator(self, id: str) -> Optional[Indicator]:
"""获取指标"""
return self.indicators.get(id)
def get_all_indicators(self) -> List[Indicator]:
"""获取所有指标"""
return list(self.indicators.values())
# 保存会话状态
def ensure_cache_dir():
"""确保cache目录存在"""
cache_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "cache")
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
@st.cache_data(ttl=3600)
def save_session_state():
"""保存会话状态到文件"""
try:
state = {
"current_indicator_id": st.session_state.get("current_indicator_id", None)
}
cache_dir = ensure_cache_dir()
state_file = os.path.join(cache_dir, "indicator_creator_state.json")
with open(state_file, "w", encoding="utf-8") as f:
json.dump(state, f, ensure_ascii=False, indent=2)
except Exception as e:
st.error(f"保存会话状态失败: {str(e)}")
# 加载会话状态
@st.cache_data(ttl=3600)
def load_session_state():
"""从文件加载会话状态"""
try:
cache_dir = ensure_cache_dir()
state_file = os.path.join(cache_dir, "indicator_creator_state.json")
if os.path.exists(state_file):
with open(state_file, "r", encoding="utf-8") as f:
state = json.load(f)
for key, value in state.items():
if key not in st.session_state:
st.session_state[key] = value
except Exception as e:
st.error(f"加载会话状态失败: {str(e)}")
# 测试指标查询
@st.cache_data(ttl=60) # 缓存结果1分钟
def test_indicator_query(query: str) -> str:
"""测试指标查询"""
if not query.strip():
return "查询语句为空,无法测试"
# 模拟异步测试过程
result = "开始测试查询...\n"
try:
# 步骤1:通过调用user_interaction.understand(query)实现
result += "步骤1: 正在理解查询语句...\n"
if "user_interaction" not in st.session_state:
st.session_state.user_interaction = st.session_state.dialog_manager.user_interaction
understanding_result = st.session_state.user_interaction.understand(query)
if not understanding_result:
result += "❌ 无法理解查询语句\n"
return result
# 步骤2:验证步骤一返回值
result += "步骤2: 正在验证理解结果...\n"
selected_knowledge = []
for item in understanding_result:
result += f"- 识别到实体: {item.get('name')}, 约束条件: {item.get('constraints')}\n"
selected_knowledge.append(item)
if not selected_knowledge:
result += "❌ 没有找到相关知识\n"
return result
# 步骤3dialog_manager.generated_code(query, selected_knowledge) 实现
result += "步骤3: 正在生成查询代码...\n"
if "dialog_manager" not in st.session_state:
result += "❌ 对话管理器未初始化\n"
return result
generated_code = st.session_state.dialog_manager.generated_code(query, selected_knowledge)
if not generated_code:
result += "❌ 代码生成失败\n"
return result
result += "✅ 代码生成成功\n"
# 步骤4:通过调用code_executor.execute_code 实现
result += "步骤4: 正在执行查询代码...\n"
if "code_executor" not in st.session_state:
st.session_state.code_executor = st.session_state.dialog_manager.code_executor
execution_result = st.session_state.code_executor.execute_code(generated_code)
if execution_result and "error" not in execution_result.lower():
result += "✅ 查询执行成功!\n"
result += f"查询结果:\n{execution_result}"
else:
result += f"❌ 查询执行失败: {execution_result}\n"
except Exception as e:
import traceback
error_details = traceback.format_exc()
result += f"❌ 查询测试失败: {str(e)}\n"
result += f"错误详情: {error_details}"
return result
@st.cache_resource
def init_app_data():
"""
初始化应用数据,使用缓存资源以避免重复加载
"""
# 加载依赖
deps = load_dependencies()
config = deps["Config"]()
business_structure = deps["load_file"](config.business_object_structure_path)
bowei_api_docs = deps["load_file"](config.bowei_api_docs_path)
llm_client = deps["MultiAPIKeyChatOpenAI"](config.openai)
user_interaction = deps["UserInteraction"](llm_client.llm, business_structure)
llm_client_coder = deps["MultiAPIKeyChatOpenAI"](config.openai_coder)
prompt_manager = deps["PromptManager"]()
neo4j_conf = config.neo4j_conf
embedding_conf = config.embedding
embedding_client = deps["EmbeddingClient"](embedding_conf)
# 创建Neo4j检索器
knowledge_retriever = deps["Neo4jRawRetriever"](neo4j_conf)
deps["ProjectBuilder"].register(deps["ProjectToolkitNeo4j"], knowledge_retriever.driver)
code_executor = deps["CodeExecutor"](prompt_manager.prompts, llm_client_coder, config.max_retries)
dialog_manager = deps["DialogManager"](
llm_client,
business_structure,
bowei_api_docs,
code_executor,
knowledge_retriever,
prompt_manager,
)
# 将必要的对象存储在session_state中
return {
"dialog_manager": dialog_manager,
"user_interaction": user_interaction,
"code_executor": code_executor,
"knowledge_retriever": knowledge_retriever
}
# Streamlit 应用界面
def main():
st.set_page_config(page_title="指标创建器", layout="wide")
# 使用进度条显示加载过程
with st.spinner("正在初始化应用..."):
# 初始化应用数据
app_data = init_app_data()
# 将必要的对象存储在session_state中
for key, value in app_data.items():
st.session_state[key] = value
# 初始化会话状态
if "indicator_manager" not in st.session_state:
st.session_state.indicator_manager = IndicatorManager()
if "current_indicator_id" not in st.session_state:
st.session_state.current_indicator_id = None
# 加载保存的会话状态
load_session_state()
# 侧边栏
with st.sidebar:
st.title("造价工程指标管理器")
# 第一个expander:指标管理
with st.expander("新建指标工具", expanded=True):
# 显示清除指标链接
col1, col2, col3 = st.columns([2, 1, 1])
with col1:
st.subheader("指标列表")
with col2:
# 新建指标按钮
if st.button("新建"):
try:
# 生成新指标名称
indicator_count = len(st.session_state.indicator_manager.get_all_indicators())
new_name = f"指标 {indicator_count + 1}"
# 创建新指标
new_indicator = st.session_state.indicator_manager.add_indicator(name=new_name)
st.session_state.current_indicator_id = new_indicator.id
save_session_state()
st.rerun()
except ValueError as e:
st.error(str(e))
with col3:
if st.button("清空", help="删除所有指标", type="secondary"):
st.session_state.indicator_manager.clear_all_indicators()
st.session_state.current_indicator_id = None
save_session_state()
st.rerun()
# 显示指标列表
indicators = st.session_state.indicator_manager.get_all_indicators()
if not indicators:
st.info("暂无指标,请点击\"新建指标\"按钮创建")
else:
for indicator in sorted(indicators, key=lambda x: x.created_at):
if st.button(
indicator.name,
key=f"btn_{indicator.id}",
use_container_width=True,
type="primary" if st.session_state.current_indicator_id == indicator.id else "secondary"
):
st.session_state.current_indicator_id = indicator.id
save_session_state()
st.rerun()
# 第二个expander:指标库浏览
with st.expander("查看指标库", expanded=False):
# 设置默认文件路径
file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tests", "code.jsonl")
# 文件路径输入
#file_path = st.text_input("JSONL文件路径", value=default_file_path, key="jsonl_path")
# 加载JSONL文件的函数
@st.cache_data
def load_jsonl(file_path):
"""加载JSONL文件并返回JSON指标列表"""
records = []
try:
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
if line.strip():
records.append(json.loads(line))
return records
except Exception as e:
st.error(f"加载文件失败: {str(e)}")
return []
# 检查文件是否存在
if not os.path.exists(file_path):
st.warning(f"文件不存在: {file_path}")
else:
# 加载JSONL文件
records = load_jsonl(file_path)
if not records:
st.warning("没有找到有效的指标")
else:
# 初始化会话状态
if "selected_index" not in st.session_state:
st.session_state.selected_index = 0
# 显示指标列表
st.subheader("指标列表")
for i, record in enumerate(records):
if st.button(record.get("name", f"指标 {i+1}"), key=f"lib_btn_{i}", use_container_width=True):
st.session_state.selected_index = i
# 可以添加查看指标详情的逻辑
st.session_state.view_indicator_detail = True
st.session_state.current_library_indicator = record
# 主区域
if "view_indicator_detail" in st.session_state and st.session_state.view_indicator_detail and "current_library_indicator" in st.session_state:
# 显示从指标库中选择的指标详细信息
library_indicator = st.session_state.current_library_indicator
st.header(f"指标名称: {library_indicator.get('name', '未命名指标')}")
# 显示查询语句
st.subheader("查询语句")
st.info(library_indicator.get('query', '无查询语句'))
# 显示代码
st.subheader("指标代码")
st.code(library_indicator.get('code', '无代码'), language='python')
# 添加执行按钮
if st.button("执行指标代码", key="execute_lib_indicator"):
with st.spinner('正在执行代码...'):
# 确保code_executor已初始化
if "code_executor" in st.session_state:
result = st.session_state.code_executor.execute_code(library_indicator.get('code', ''))
if result and hasattr(result, 'get') and result.get('status'):
st.success("执行成功")
st.json(result.get('data', {}))
else:
st.error(f"执行失败: {result.get('message', '未知错误') if hasattr(result, 'get') else str(result)}")
else:
st.error("代码执行器未初始化")
# 添加导入按钮 - 将库中指标导入到当前工作区
if st.button("导入到我的指标", key="import_to_my_indicators"):
try:
new_indicator = st.session_state.indicator_manager.add_indicator(
name=library_indicator.get('name', '导入的指标'),
query=library_indicator.get('query', '')
)
st.session_state.current_indicator_id = new_indicator.id
st.session_state.view_indicator_detail = False
save_session_state()
st.success(f"已导入指标: {new_indicator.name}")
st.rerun()
except Exception as e:
st.error(f"导入失败: {str(e)}")
# 添加返回按钮
if st.button("返回", key="back_to_indicators"):
st.session_state.view_indicator_detail = False
st.rerun()
elif st.session_state.current_indicator_id is not None:
current_indicator = st.session_state.indicator_manager.get_indicator(st.session_state.current_indicator_id)
if current_indicator:
# 指标名称和保存按钮
col1, col2, col3 = st.columns([1, 5, 2])
with col1:
st.markdown('<div><b style="font-size: 24px;">指标名称:</b></div>', unsafe_allow_html=True)
with col2:
new_name = st.text_input("", value=current_indicator.name, key="indicator_name", label_visibility="collapsed")
if new_name != current_indicator.name:
if not new_name.strip():
st.error("指标名称不能为空")
else:
try:
st.session_state.indicator_manager.update_indicator(
current_indicator.id, name=new_name
)
except ValueError as e:
st.error(str(e))
with col3:
if st.button("💾 保存指标", use_container_width=True):
if not new_name.strip():
st.error("指标名称不能为空")
else:
try:
st.toast("指标已保存")
st.session_state.indicator_manager.save_indicators()
except Exception as e:
st.error(f"保存失败: {str(e)}")
# 查询语句
st.subheader("查询语句")
col1, col2 = st.columns([6, 2])
with col1:
query = st.text_area(
"输入查询语句",
value=current_indicator.query,
height=70,
key="query_input",
placeholder="在此输入查询语句...",
label_visibility="collapsed"
)
if query != current_indicator.query:
st.session_state.indicator_manager.update_indicator(
current_indicator.id, query=query
)
with col2:
st.write("") # 添加一些空间以对齐按钮
st.write("")
if st.button("🧪 测试", use_container_width=True):
with st.spinner("正在测试查询..."):
# 使用缓存的测试函数
result = test_indicator_query(query)
st.session_state.test_result = result
# 测试结果
if "test_result" in st.session_state:
st.subheader("测试结果")
st.code(st.session_state.test_result)
else:
st.warning("找不到当前指标,可能已被删除")
else:
st.info("请在左侧创建或选择一个指标")
if __name__ == "__main__":
main()