Files
langchain_projectagent/main_streamlt.py
T
2025-07-07 08:51:19 +08:00

579 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import json
import streamlit as st
import uuid
import time
from typing import Dict, List, Optional, Any
import logging
import sys
from datetime import datetime
# Put the project root (the parent of this file's directory) on sys.path so
# the lazily imported ``src.*`` packages can be resolved.
sys.path.append(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

# Logging: one timestamped log file per process start, named after this
# script, plus console output; the root logger is set to DEBUG.
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = "{}_{}.log".format(current_file, now_str)

# NOTE: "logs" is relative to the current working directory (not to this
# file), so its location depends on where the app is launched from.
os.makedirs("logs", exist_ok=True)
logging.basicConfig(
    handlers=[
        logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
        logging.StreamHandler(),
    ],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
logger = logging.getLogger(name=current_file)
def setup_logger(logger_name):
    """
    Quiet a (typically third-party) logger.

    Raises the named logger's threshold to WARNING and turns off propagation,
    so its records are no longer passed up to the root logger's handlers.
    Child loggers left at level NOTSET inherit the WARNING effective level.

    :param logger_name: name passed to ``logging.getLogger``
    :return: the configured ``logging.Logger``
    """
    quieted = logging.getLogger(logger_name)
    quieted.setLevel(logging.WARNING)
    quieted.propagate = False
    return quieted


# Chatty client libraries whose DEBUG/INFO output would flood the app log.
logger_names = [
    "httpx",
    "openai",
    "langsmith.client",
    "neo4j",
    "urllib3",
    "httpcore",
]
for name in logger_names:
    setup_logger(name)
# Project imports are deferred until first use, then cached.
@st.cache_resource
def load_dependencies():
    """
    Import the project's ``src.*`` components on first call and return them.

    ``st.cache_resource`` shares the returned mapping across reruns and
    sessions, so the imports happen only once per process.

    :return: dict mapping a component name to the imported class/function
    """
    from src.config import Config
    from src.document_loader import load_file
    from src.multi_llm_client import MultiAPIKeyChatOpenAI
    from src.user_interaction import UserInteraction
    from src.dialog_manager import DialogManager
    from src.code_executor import CodeExecutor
    from src.neo4j_raw_retriever import Neo4jRawRetriever
    from src.prompt_manager import PromptManager
    from src.embedding_client import EmbeddingClient
    # NOTE(review): ProjectToolkit is imported but not part of the result.
    from src.project import ProjectBuilder, ProjectToolkit
    from src.project_implementation import ProjectToolkitNeo4j

    return dict(
        Config=Config,
        load_file=load_file,
        MultiAPIKeyChatOpenAI=MultiAPIKeyChatOpenAI,
        UserInteraction=UserInteraction,
        DialogManager=DialogManager,
        CodeExecutor=CodeExecutor,
        Neo4jRawRetriever=Neo4jRawRetriever,
        PromptManager=PromptManager,
        EmbeddingClient=EmbeddingClient,
        ProjectBuilder=ProjectBuilder,
        ProjectToolkitNeo4j=ProjectToolkitNeo4j,
    )
# Back-end data model
class Indicator:
    """A user-defined indicator: a display name plus the query text behind it.

    ``created_at`` / ``updated_at`` hold ``time.time()`` timestamps.
    """

    def __init__(self, name: str, query: str = "", id: Optional[str] = None):
        # A falsy id (None or "") means "new indicator": mint a random UUID4.
        self.id = id if id else str(uuid.uuid4())
        self.name = name
        self.query = query
        # Take one timestamp so a fresh indicator has created_at == updated_at.
        now = time.time()
        self.created_at = now
        self.updated_at = now

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a JSON-serializable dict (inverse of ``from_dict``)."""
        return {
            "id": self.id,
            "name": self.name,
            "query": self.query,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Indicator':
        """Rebuild an Indicator from a dict produced by ``to_dict``.

        Only ``name`` is required. A missing ``query`` becomes ``""`` and a
        missing ``id`` gets a fresh UUID, matching the constructor defaults.
        Missing timestamps default to the current time.

        :raises KeyError: if ``name`` is absent.
        """
        indicator = cls(
            name=data["name"],
            query=data.get("query", ""),
            id=data.get("id"),
        )
        # Fall back to the constructor's "now" for records without timestamps.
        indicator.created_at = data.get("created_at", indicator.created_at)
        indicator.updated_at = data.get("updated_at", indicator.updated_at)
        return indicator
class IndicatorManager:
    """In-memory collection of ``Indicator`` objects backed by a JSON file.

    The file holds a JSON array of ``Indicator.to_dict()`` records. Every
    mutating method writes the whole collection back to disk immediately.
    Load/save failures are reported in the UI with ``st.error``, not raised.
    """

    def __init__(self, save_path: str = "indicator_library.json"):
        # Indicators keyed by their id.
        self.indicators: Dict[str, "Indicator"] = {}
        self.save_path = save_path
        self.load_indicators()

    def load_indicators(self) -> None:
        """Load indicators from ``save_path`` if the file exists.

        On failure an ``st.error`` is shown; records parsed before the failure
        stay loaded.
        """
        try:
            if os.path.exists(self.save_path):
                with open(self.save_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                for indicator_data in data:
                    indicator = Indicator.from_dict(indicator_data)
                    self.indicators[indicator.id] = indicator
        except Exception as e:
            st.error(f"加载指标库失败: {str(e)}")

    def save_indicators(self) -> None:
        """Write all indicators to ``save_path`` as a JSON array.

        The JSON is written to a sibling ``.tmp`` file first and then moved
        over the target with ``os.replace``. An interrupted or failed dump
        therefore cannot truncate the existing library file. On failure an
        ``st.error`` is shown and the temp file is removed.
        """
        tmp_path = f"{self.save_path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(
                    [ind.to_dict() for ind in self.indicators.values()],
                    f,
                    ensure_ascii=False,
                    indent=2,
                )
            os.replace(tmp_path, self.save_path)
        except Exception as e:
            st.error(f"保存指标库失败: {str(e)}")
            # Best-effort cleanup; the temp file may not exist at all.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def add_indicator(self, name: str, query: str = "") -> "Indicator":
        """Create, store and persist a new indicator.

        :raises ValueError: if ``name`` is blank.
        """
        if not name.strip():
            raise ValueError("指标名称不能为空")
        indicator = Indicator(name=name, query=query)
        self.indicators[indicator.id] = indicator
        self.save_indicators()
        return indicator

    def update_indicator(self, id: str, name: Optional[str] = None, query: Optional[str] = None) -> Optional["Indicator"]:
        """Update the name and/or query of an indicator and persist.

        Arguments left as ``None`` are not changed. ``updated_at`` is bumped.

        :return: the updated indicator, or ``None`` if ``id`` is unknown.
        :raises ValueError: if ``name`` is given but blank.
        """
        if id not in self.indicators:
            return None
        if name is not None and not name.strip():
            raise ValueError("指标名称不能为空")
        indicator = self.indicators[id]
        if name is not None:
            indicator.name = name
        if query is not None:
            indicator.query = query
        indicator.updated_at = time.time()
        self.save_indicators()
        return indicator

    def delete_indicator(self, id: str) -> bool:
        """Remove an indicator and persist; return whether it existed."""
        if id in self.indicators:
            del self.indicators[id]
            self.save_indicators()
            return True
        return False

    def clear_all_indicators(self) -> None:
        """Remove every indicator and persist the empty library."""
        self.indicators.clear()
        self.save_indicators()

    def get_indicator(self, id: str) -> Optional["Indicator"]:
        """Return the indicator with ``id``, or ``None``."""
        return self.indicators.get(id)

    def get_all_indicators(self) -> List["Indicator"]:
        """Return all indicators as a new list (insertion order)."""
        return list(self.indicators.values())
# Session-state persistence helpers
def ensure_cache_dir():
    """Create (if needed) and return the project-level ``cache`` directory.

    The directory is ``<project root>/cache``, where the project root is the
    parent of the directory containing this file.

    :return: absolute path of the cache directory
    """
    here = os.path.abspath(__file__)
    project_root = os.path.dirname(os.path.dirname(here))
    target = os.path.join(project_root, "cache")
    os.makedirs(target, exist_ok=True)
    return target
def save_session_state():
    """Persist selected session-state keys to ``<cache>/indicator_creator_state.json``.

    Currently only ``current_indicator_id`` is saved. Failures are shown with
    ``st.error`` rather than raised.

    Deliberately NOT wrapped in ``st.cache_data``. That decorator memoizes on
    the (empty) argument list. After the first call, every later call within
    the TTL would return the cached ``None`` without writing the file, so new
    selections would silently never be persisted.
    """
    try:
        state = {
            "current_indicator_id": st.session_state.get("current_indicator_id", None)
        }
        cache_dir = ensure_cache_dir()
        state_file = os.path.join(cache_dir, "indicator_creator_state.json")
        with open(state_file, "w", encoding="utf-8") as f:
            json.dump(state, f, ensure_ascii=False, indent=2)
    except Exception as e:
        st.error(f"保存会话状态失败: {str(e)}")
# Session-state loading
def load_session_state():
    """Restore persisted keys from ``<cache>/indicator_creator_state.json``.

    Only keys that are not already in ``st.session_state`` are set, so values
    from the running session take precedence. Failures, including a file
    that does not contain a JSON object, are shown with ``st.error``.

    Deliberately NOT wrapped in ``st.cache_data``. The function's only effect
    is the write into ``st.session_state``. Caching it would skip that effect
    for every later call, including new browser sessions, within the TTL.
    """
    try:
        cache_dir = ensure_cache_dir()
        state_file = os.path.join(cache_dir, "indicator_creator_state.json")
        if not os.path.exists(state_file):
            return
        with open(state_file, "r", encoding="utf-8") as f:
            state = json.load(f)
        if not isinstance(state, dict):
            raise ValueError(f"expected a JSON object, got {type(state).__name__}")
        for key, value in state.items():
            if key not in st.session_state:
                st.session_state[key] = value
    except Exception as e:
        st.error(f"加载会话状态失败: {str(e)}")
# Indicator query dry-run
def _resolve_session_component(key: str):
    """Return ``st.session_state[key]``, falling back to the dialog manager.

    If ``key`` is missing but a ``dialog_manager`` is in session state, the
    attribute of the same name on that manager is stored in session state and
    returned. Returns ``None`` when no dialog manager is available.
    """
    if key in st.session_state:
        return st.session_state[key]
    dialog_manager = st.session_state.get("dialog_manager")
    if dialog_manager is None:
        return None
    component = getattr(dialog_manager, key)
    st.session_state[key] = component
    return component


def _interpret_execution_result(execution_result):
    """Classify a ``code_executor.execute_code`` result.

    Two result shapes are accepted:

    - mapping-like (has ``.get``): success iff ``result.get("status")`` is
      truthy. The detail is ``result.get("data", {})`` on success and
      ``result.get("message", "未知错误")`` on failure.
    - anything else (e.g. a string): success iff it is truthy and its text
      does not contain "error" (case-insensitive). The detail is the
      result itself.

    :return: ``(ok, detail)``
    """
    if hasattr(execution_result, "get"):
        if execution_result.get("status"):
            return True, execution_result.get("data", {})
        return False, execution_result.get("message", "未知错误")
    ok = bool(execution_result) and "error" not in str(execution_result).lower()
    return ok, execution_result


@st.cache_data(ttl=60)  # cache each query's report for one minute
def test_indicator_query(query: str) -> str:
    """Dry-run an indicator query end to end and return a textual report.

    Steps:

    1. ``user_interaction.understand(query)``
    2. list the entities it returned
    3. ``dialog_manager.generated_code(query, knowledge)``
    4. ``code_executor.execute_code(code)``

    The report gets one line per step and stops at the first failing step.
    Exceptions are caught and their traceback appended to the report.
    Reports, including failures, are cached per ``query`` for 60 seconds.

    :param query: the indicator's query text
    :return: multi-line human-readable report
    """
    if not query.strip():
        return "查询语句为空,无法测试"
    result = "开始测试查询...\n"
    try:
        # Step 1: let the LLM interpret the query.
        result += "步骤1: 正在理解查询语句...\n"
        user_interaction = _resolve_session_component("user_interaction")
        if user_interaction is None:
            result += "❌ 对话管理器未初始化\n"
            return result
        understanding_result = user_interaction.understand(query)
        if not understanding_result:
            result += "❌ 无法理解查询语句\n"
            return result

        # Step 2: report what was recognised.
        # Items are presumably dicts with "name"/"constraints" keys.
        result += "步骤2: 正在验证理解结果...\n"
        selected_knowledge = []
        for item in understanding_result:
            result += f"- 识别到实体: {item.get('name')}, 约束条件: {item.get('constraints')}\n"
            selected_knowledge.append(item)
        if not selected_knowledge:
            result += "❌ 没有找到相关知识\n"
            return result

        # Step 3: generate query code.
        result += "步骤3: 正在生成查询代码...\n"
        if "dialog_manager" not in st.session_state:
            result += "❌ 对话管理器未初始化\n"
            return result
        generated_code = st.session_state.dialog_manager.generated_code(query, selected_knowledge)
        if not generated_code:
            result += "❌ 代码生成失败\n"
            return result
        result += "✅ 代码生成成功\n"

        # Step 4: execute it. Results may be dicts or strings; see
        # _interpret_execution_result.
        result += "步骤4: 正在执行查询代码...\n"
        code_executor = _resolve_session_component("code_executor")
        if code_executor is None:
            result += "❌ 代码执行器未初始化\n"
            return result
        ok, detail = _interpret_execution_result(code_executor.execute_code(generated_code))
        if ok:
            result += "✅ 查询执行成功!\n"
            result += f"查询结果:\n{detail}"
        else:
            result += f"❌ 查询执行失败: {detail}\n"
    except Exception as e:
        import traceback
        logger.exception("Indicator query test failed")
        result += f"❌ 查询测试失败: {str(e)}\n"
        result += f"错误详情: {traceback.format_exc()}"
    return result
@st.cache_resource
def init_app_data():
    """
    Build the application's long-lived service objects.

    ``st.cache_resource`` runs this once per process, so all sessions share
    the same instances.

    :return: dict with keys ``dialog_manager``, ``user_interaction``,
        ``code_executor`` and ``knowledge_retriever``
    """
    deps = load_dependencies()
    read_file = deps["load_file"]
    chat_client_cls = deps["MultiAPIKeyChatOpenAI"]

    config = deps["Config"]()
    business_structure = read_file(config.business_object_structure_path)
    bowei_api_docs = read_file(config.bowei_api_docs_path)

    # Two LLM clients: a general one and a separate one for code generation.
    llm_client = chat_client_cls(config.openai)
    user_interaction = deps["UserInteraction"](llm_client.llm, business_structure)
    llm_client_coder = chat_client_cls(config.openai_coder)
    prompt_manager = deps["PromptManager"]()

    neo4j_settings = config.neo4j_conf
    embedding_settings = config.embedding
    # NOTE(review): embedding_client is built but never used or returned.
    embedding_client = deps["EmbeddingClient"](embedding_settings)  # noqa: F841

    knowledge_retriever = deps["Neo4jRawRetriever"](neo4j_settings)
    # Register the Neo4j toolkit implementation with the retriever's driver.
    deps["ProjectBuilder"].register(deps["ProjectToolkitNeo4j"], knowledge_retriever.driver)

    code_executor = deps["CodeExecutor"](
        prompt_manager.prompts, llm_client_coder, config.max_retries
    )
    dialog_manager = deps["DialogManager"](
        llm_client,
        business_structure,
        bowei_api_docs,
        code_executor,
        knowledge_retriever,
        prompt_manager,
    )

    # The caller copies these into st.session_state.
    services = {}
    services["dialog_manager"] = dialog_manager
    services["user_interaction"] = user_interaction
    services["code_executor"] = code_executor
    services["knowledge_retriever"] = knowledge_retriever
    return services
# Streamlit UI
def _select_indicator(indicator_id) -> None:
    """Make ``indicator_id`` the active indicator and leave the library view.

    Leaving the library-detail view is required because the main area renders
    the library detail in preference to the editor whenever
    ``view_indicator_detail`` is set. The selection is persisted.
    """
    st.session_state.current_indicator_id = indicator_id
    st.session_state.view_indicator_detail = False
    save_session_state()


@st.cache_data
def _load_jsonl(file_path, _mtime=None):
    """Load a JSONL file and return one parsed object per non-blank line.

    ``_mtime`` is only part of the cache key: passing the file's modification
    time makes edits to the file invalidate the cached result. On any error
    an ``st.error`` is shown and ``[]`` is returned.
    """
    records = []
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                if line.strip():
                    records.append(json.loads(line))
        return records
    except Exception as e:
        st.error(f"加载文件失败: {str(e)}")
        return []


def _render_my_indicators() -> None:
    """Sidebar expander: the user's indicators plus New / Clear buttons."""
    manager = st.session_state.indicator_manager
    with st.expander("新建指标工具", expanded=True):
        col1, col2, col3 = st.columns([2, 1, 1])
        with col1:
            st.subheader("指标列表")
        with col2:
            if st.button("新建"):
                try:
                    # Default name "指标 <n+1>", n = current number of indicators.
                    indicator_count = len(manager.get_all_indicators())
                    new_indicator = manager.add_indicator(name=f"指标 {indicator_count + 1}")
                except ValueError as e:
                    st.error(str(e))
                else:
                    _select_indicator(new_indicator.id)
                    st.rerun()
        with col3:
            if st.button("清空", help="删除所有指标", type="secondary"):
                manager.clear_all_indicators()
                st.session_state.current_indicator_id = None
                save_session_state()
                st.rerun()

        indicators = manager.get_all_indicators()
        if not indicators:
            st.info("暂无指标,请点击\"新建\"按钮创建")
        for indicator in sorted(indicators, key=lambda x: x.created_at):
            is_current = st.session_state.current_indicator_id == indicator.id
            if st.button(
                indicator.name,
                key=f"btn_{indicator.id}",
                use_container_width=True,
                type="primary" if is_current else "secondary",
            ):
                _select_indicator(indicator.id)
                st.rerun()


def _render_library_browser() -> None:
    """Sidebar expander: read-only indicator library loaded from ../tests/code.jsonl."""
    with st.expander("查看指标库", expanded=False):
        file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tests", "code.jsonl")
        if not os.path.exists(file_path):
            st.warning(f"文件不存在: {file_path}")
            return
        records = _load_jsonl(file_path, os.path.getmtime(file_path))
        if not records:
            st.warning("没有找到有效的指标")
            return
        if "selected_index" not in st.session_state:
            st.session_state.selected_index = 0
        st.subheader("指标列表")
        for i, record in enumerate(records):
            if st.button(record.get("name", f"指标 {i+1}"), key=f"lib_btn_{i}", use_container_width=True):
                # No rerun needed: the main area is drawn after the sidebar
                # in this same script run and will see these values.
                st.session_state.selected_index = i
                st.session_state.view_indicator_detail = True
                st.session_state.current_library_indicator = record


def _render_sidebar() -> None:
    """Draw the whole sidebar."""
    with st.sidebar:
        st.title("造价工程指标管理器")
        _render_my_indicators()
        _render_library_browser()


def _render_library_detail(library_indicator) -> None:
    """Main area: show a library record with Execute / Import / Back actions."""
    st.header(f"指标名称: {library_indicator.get('name', '未命名指标')}")
    st.subheader("查询语句")
    st.info(library_indicator.get('query', '无查询语句'))
    st.subheader("指标代码")
    st.code(library_indicator.get('code', '无代码'), language='python')

    if st.button("执行指标代码", key="execute_lib_indicator"):
        with st.spinner('正在执行代码...'):
            if "code_executor" in st.session_state:
                result = st.session_state.code_executor.execute_code(library_indicator.get('code', ''))
                if result and hasattr(result, 'get') and result.get('status'):
                    st.success("执行成功")
                    st.json(result.get('data', {}))
                else:
                    st.error(f"执行失败: {result.get('message', '未知错误') if hasattr(result, 'get') else str(result)}")
            else:
                st.error("代码执行器未初始化")

    if st.button("导入到我的指标", key="import_to_my_indicators"):
        try:
            new_indicator = st.session_state.indicator_manager.add_indicator(
                name=library_indicator.get('name', '导入的指标'),
                query=library_indicator.get('query', '')
            )
        except Exception as e:
            st.error(f"导入失败: {str(e)}")
        else:
            # st.rerun() is kept outside the try so its control-flow
            # exception can never be caught by the handler above.
            _select_indicator(new_indicator.id)
            st.success(f"已导入指标: {new_indicator.name}")
            st.rerun()

    if st.button("返回", key="back_to_indicators"):
        st.session_state.view_indicator_detail = False
        st.rerun()


def _render_indicator_editor(current_indicator) -> None:
    """Main area: edit the name/query of ``current_indicator`` and test the query."""
    manager = st.session_state.indicator_manager

    col1, col2, col3 = st.columns([1, 5, 2])
    with col1:
        st.markdown('<div><b style="font-size: 24px;">指标名称:</b></div>', unsafe_allow_html=True)
    with col2:
        # Widget keys include the indicator id so that a value typed for one
        # indicator can never be carried over to (and saved into) another.
        new_name = st.text_input(
            "指标名称",
            value=current_indicator.name,
            key=f"indicator_name_{current_indicator.id}",
            label_visibility="collapsed",
        )
        if new_name != current_indicator.name:
            if not new_name.strip():
                st.error("指标名称不能为空")
            else:
                try:
                    manager.update_indicator(current_indicator.id, name=new_name)
                except ValueError as e:
                    st.error(str(e))
    with col3:
        if st.button("💾 保存指标", use_container_width=True):
            if not new_name.strip():
                st.error("指标名称不能为空")
            else:
                try:
                    manager.save_indicators()
                except Exception as e:
                    st.error(f"保存失败: {str(e)}")
                else:
                    st.toast("指标已保存")

    st.subheader("查询语句")
    col1, col2 = st.columns([6, 2])
    with col1:
        query = st.text_area(
            "输入查询语句",
            value=current_indicator.query,
            height=70,
            key=f"query_input_{current_indicator.id}",
            placeholder="在此输入查询语句...",
            label_visibility="collapsed"
        )
        if query != current_indicator.query:
            manager.update_indicator(current_indicator.id, query=query)
    with col2:
        # Two blank lines to vertically align the button with the text area.
        st.write("")
        st.write("")
        if st.button("🧪 测试", use_container_width=True):
            with st.spinner("正在测试查询..."):
                st.session_state.test_result = test_indicator_query(query)
                st.session_state.test_result_indicator_id = current_indicator.id

    # Only show a test report that belongs to the indicator being edited.
    if ("test_result" in st.session_state
            and st.session_state.get("test_result_indicator_id") == current_indicator.id):
        st.subheader("测试结果")
        st.code(st.session_state.test_result)


def main():
    """Streamlit entry point: initialise services and state, then draw the page."""
    st.set_page_config(page_title="指标创建器", layout="wide")
    with st.spinner("正在初始化应用..."):
        app_data = init_app_data()
    # Expose the shared service objects through st.session_state.
    for key, value in app_data.items():
        st.session_state[key] = value

    if "indicator_manager" not in st.session_state:
        st.session_state.indicator_manager = IndicatorManager()
    # load_session_state only fills keys that are still missing, so it must
    # run before the default below. Otherwise the persisted
    # current_indicator_id is never restored.
    load_session_state()
    if "current_indicator_id" not in st.session_state:
        st.session_state.current_indicator_id = None

    _render_sidebar()

    if st.session_state.get("view_indicator_detail") and "current_library_indicator" in st.session_state:
        _render_library_detail(st.session_state.current_library_indicator)
    elif st.session_state.current_indicator_id is not None:
        current_indicator = st.session_state.indicator_manager.get_indicator(st.session_state.current_indicator_id)
        if current_indicator:
            _render_indicator_editor(current_indicator)
        else:
            st.warning("找不到当前指标,可能已被删除")
    else:
        st.info("请在左侧创建或选择一个指标")


if __name__ == "__main__":
    main()