Files
langchain_projectagent/main_streamlt.py
T
2025-07-07 09:32:05 +08:00

685 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import streamlit as st
import json
import os
import uuid
import time
from typing import Dict, List, Optional, Any
import logging
import sys
from datetime import datetime
# Make the project root (the parent of this script's directory) importable,
# presumably so the ``src.*`` imports in load_dependencies resolve.
_script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(_script_dir))

# Working directories next to this script; created eagerly so later reads and
# writes can assume they exist.
CACHE_DIR = os.path.join(_script_dir, "cache")
DATA_DIR = os.path.join(_script_dir, "data")
for _needed_dir in (CACHE_DIR, DATA_DIR):
    os.makedirs(_needed_dir, exist_ok=True)
# Logging: a timestamped log file under ./logs plus console output, both at DEBUG.
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"
# Keep log files in the "logs" directory next to this script.
log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
os.makedirs(log_dir, exist_ok=True)
# Streamlit re-executes this script on every interaction. logging.basicConfig()
# does nothing once the root logger has handlers, but the FileHandler passed to
# it would still be constructed on every rerun, which opens (creates) a fresh,
# empty log file each time. Build the handlers only when basicConfig would
# actually install them (the same "root has no handlers" condition it checks).
if not logging.getLogger().handlers:
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(os.path.join(log_dir, log_filename), encoding="utf-8"),
            logging.StreamHandler(),
        ],
    )
logger = logging.getLogger(current_file)
def setup_logger(logger_name):
    """Quiet a named (typically third-party) logger.

    Raises the logger's threshold to WARNING and turns off propagation so its
    records are not passed on to the root logger's handlers.

    :param logger_name: name of the logger to configure
    :return: the configured ``logging.Logger``
    """
    target = logging.getLogger(logger_name)
    target.setLevel(logging.WARNING)
    # Optional: stop records bubbling up to the root logger's handlers.
    # NOTE: with no handlers of its own, WARNING+ records from this logger go to
    # logging.lastResort (stderr) instead of the log file.
    target.propagate = False
    return target


# Client-library loggers that are lowered to WARNING.
logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"]
for name in logger_names:
    setup_logger(name)
# Deferred imports: project modules are only loaded when first needed.
@st.cache_resource
def load_dependencies():
    """Import the project's ``src`` classes and functions lazily.

    Wrapped in ``st.cache_resource`` so the imports and the returned mapping are
    produced once and shared across reruns and sessions.

    :return: dict mapping a public name to the imported class/function
    """
    from src.config import Config
    from src.document_loader import load_file
    from src.multi_llm_client import MultiAPIKeyChatOpenAI
    from src.user_interaction import UserInteraction
    from src.dialog_manager import DialogManager
    from src.code_executor import CodeExecutor
    from src.neo4j_raw_retriever import Neo4jRawRetriever
    from src.prompt_manager import PromptManager
    from src.embedding_client import EmbeddingClient
    # ProjectToolkit is imported but not exported in the mapping below.
    from src.project import ProjectBuilder, ProjectToolkit
    from src.project_implementation import ProjectToolkitNeo4j

    return dict(
        Config=Config,
        load_file=load_file,
        MultiAPIKeyChatOpenAI=MultiAPIKeyChatOpenAI,
        UserInteraction=UserInteraction,
        DialogManager=DialogManager,
        CodeExecutor=CodeExecutor,
        Neo4jRawRetriever=Neo4jRawRetriever,
        PromptManager=PromptManager,
        EmbeddingClient=EmbeddingClient,
        ProjectBuilder=ProjectBuilder,
        ProjectToolkitNeo4j=ProjectToolkitNeo4j,
    )
# Backend data structures
class Indicator:
    """One indicator: a display name, the query it is built from, and its code."""

    def __init__(self, name: str, query: str = "", code: str = "", id: Optional[str] = None):
        """Create an indicator.

        :param name: display name
        :param query: query text the code is generated from (may be empty)
        :param code: generated code (may be empty)
        :param id: existing identifier; when falsy a random UUID4 string is used
        """
        self.id = id if id else str(uuid.uuid4())
        self.name = name
        self.query = query
        self.code = code
        # A single timestamp so a brand-new indicator has created_at == updated_at.
        now = time.time()
        self.created_at = now
        self.updated_at = now

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable dict of all fields (inverse of :meth:`from_dict`)."""
        return {
            "id": self.id,
            "name": self.name,
            "query": self.query,
            "code": self.code,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Indicator':
        """Rebuild an indicator from a dict produced by :meth:`to_dict`.

        Only ``name`` is required. Missing ``query``/``code`` default to ``""``,
        a missing ``id`` gets a fresh UUID, and missing timestamps default to
        the current time.

        :raises KeyError: if ``name`` is missing
        """
        indicator = cls(
            name=data["name"],
            # Tolerate records lacking optional fields, consistently for all of them.
            query=data.get("query", ""),
            code=data.get("code", ""),
            id=data.get("id"),
        )
        now = time.time()
        indicator.created_at = data.get("created_at", now)
        indicator.updated_at = data.get("updated_at", now)
        return indicator
class IndicatorManager:
    """In-memory collection of :class:`Indicator` objects persisted to a JSON file.

    The file holds a JSON list of ``Indicator.to_dict()`` records. Every
    successful mutating call rewrites the whole file. I/O errors are reported
    with ``st.error`` rather than raised.
    """

    def __init__(self, save_path: Optional[str] = None):
        """
        :param save_path: JSON file to load from and save to; defaults to
            ``CACHE_DIR/indicator_library.json``.
        """
        if save_path is None:
            save_path = os.path.join(CACHE_DIR, "indicator_library.json")
        self.indicators: Dict[str, Indicator] = {}
        self.save_path = save_path
        self.load_indicators()

    def load_indicators(self) -> None:
        """Load indicators from ``self.save_path`` if the file exists.

        A malformed individual record is reported and skipped, so one bad entry
        no longer prevents every record after it from loading (which the next
        save would then have silently discarded).
        """
        try:
            if not os.path.exists(self.save_path):
                return
            with open(self.save_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            for indicator_data in data:
                try:
                    indicator = Indicator.from_dict(indicator_data)
                except (KeyError, TypeError, AttributeError) as e:
                    st.error(f"加载指标库失败: {str(e)}")
                    continue
                self.indicators[indicator.id] = indicator
        except Exception as e:
            st.error(f"加载指标库失败: {str(e)}")

    def save_indicators(self) -> None:
        """Write all indicators to ``self.save_path`` as a JSON list.

        The data goes to a sibling ``.tmp`` file first and is then moved over
        the target with ``os.replace``, so a crash mid-write cannot leave a
        truncated, unparseable library behind.
        """
        tmp_path = f"{self.save_path}.tmp"
        try:
            parent = os.path.dirname(self.save_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump([ind.to_dict() for ind in self.indicators.values()], f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.save_path)
        except Exception as e:
            logger.exception("Failed to save indicator library to %s", self.save_path)
            # Best-effort cleanup of a partially written temp file.
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except OSError:
                pass
            st.error(f"保存指标库失败: {str(e)}")

    def add_indicator(self, name: str, query: str = "") -> Indicator:
        """Create, store and persist a new indicator.

        :raises ValueError: if ``name`` is empty or whitespace-only
        """
        if not name.strip():
            raise ValueError("指标名称不能为空")
        indicator = Indicator(name=name, query=query)
        self.indicators[indicator.id] = indicator
        self.save_indicators()
        return indicator

    def update_indicator(self, id: str, name: Optional[str] = None, query: Optional[str] = None, code: Optional[str] = None) -> Optional[Indicator]:
        """Update the given fields of an indicator (``None`` means "leave unchanged").

        :return: the updated indicator, or ``None`` if ``id`` is unknown
        :raises ValueError: if ``name`` is given but empty or whitespace-only
        """
        if id not in self.indicators:
            return None
        if name is not None and not name.strip():
            raise ValueError("指标名称不能为空")
        indicator = self.indicators[id]
        if name is not None:
            indicator.name = name
        if query is not None:
            indicator.query = query
        if code is not None:
            indicator.code = code
        indicator.updated_at = time.time()
        self.save_indicators()
        return indicator

    def delete_indicator(self, id: str) -> bool:
        """Delete an indicator; return ``True`` if it existed."""
        if id in self.indicators:
            del self.indicators[id]
            self.save_indicators()
            return True
        return False

    def clear_all_indicators(self) -> None:
        """Remove every indicator and persist the now-empty library."""
        self.indicators.clear()
        self.save_indicators()

    def get_indicator(self, id: str) -> Optional[Indicator]:
        """Return the indicator with ``id``, or ``None``."""
        return self.indicators.get(id)

    def get_all_indicators(self) -> List[Indicator]:
        """Return all indicators (in insertion order of the underlying dict)."""
        return list(self.indicators.values())
# Persist session state
def save_session_state():
    """Write the persistable part of the UI session to ``CACHE_DIR/indicator_creator_state.json``.

    Only ``current_indicator_id`` is saved. NOTE(review): the file is a single
    shared path, so every browser session writes to the same file.

    Deliberately NOT wrapped in ``st.cache_data``. The previous
    ``@st.cache_data(ttl=3600)`` on this zero-argument, side-effect-only
    function meant the body ran at most once per hour. Every later call
    returned the cached ``None`` without writing the file.
    """
    try:
        state = {
            "current_indicator_id": st.session_state.get("current_indicator_id", None)
        }
        state_file = os.path.join(CACHE_DIR, "indicator_creator_state.json")
        with open(state_file, "w", encoding="utf-8") as f:
            json.dump(state, f, ensure_ascii=False, indent=2)
    except Exception as e:
        st.error(f"保存会话状态失败: {str(e)}")
# Load persisted session state
def load_session_state():
    """Seed ``st.session_state`` from ``CACHE_DIR/indicator_creator_state.json``.

    Only keys that are not yet present in the session are set. Values already
    chosen in the current session are therefore never overwritten, and calling
    this on every rerun is harmless.

    Deliberately NOT wrapped in ``st.cache_data``. That cache is shared across
    sessions and skips the function body on a hit. Its only effect is a write
    to ``st.session_state``, so within the old 1 h TTL new sessions were never
    seeded.
    """
    try:
        state_file = os.path.join(CACHE_DIR, "indicator_creator_state.json")
        if not os.path.exists(state_file):
            return
        with open(state_file, "r", encoding="utf-8") as f:
            state = json.load(f)
        for key, value in state.items():
            if key not in st.session_state:
                st.session_state[key] = value
    except Exception as e:
        st.error(f"加载会话状态失败: {str(e)}")
# Indicator code generation
def generate_indicator_code(query: str) -> str:
    """Generate indicator code for ``query`` and store it in the session.

    Steps:
      1. ``user_interaction.understand(query)`` returns the recognised items.
      2. Those items become ``selected_knowledge``; each is expected to be
         dict-like with ``name``/``constraints`` keys (inferred from the
         ``.get`` calls — TODO confirm against UserInteraction).
      3. ``dialog_manager.generated_code(query, selected_knowledge)`` returns
         either a ``{"status", "message", "data"}`` dict or a code string.

    On success the code is stored in ``st.session_state.generated_code``. If a
    ``current_indicator_id`` is set, it is copied to
    ``st.session_state.current_query_id``.

    Deliberately NOT wrapped in ``st.cache_data``. The result that matters is
    the side effect on ``st.session_state``. A cache hit (a repeated query
    within the TTL, or the same query from another session) returned "✅"
    without setting ``generated_code``, so stale or missing code was then saved
    to the indicator. It also prevented deliberate regeneration.

    :param query: the query text
    :return: a status message; it starts with "✅" on success and "❌" on
        failure, except for the empty-query message.
    """
    if not query.strip():
        return "查询语句为空,无法生成代码"
    try:
        # Check the dialog manager first: it is also the fallback source of
        # user_interaction, and there is no point calling the LLM without it.
        if "dialog_manager" not in st.session_state:
            return "❌ 对话管理器未初始化"
        dialog_manager = st.session_state.dialog_manager

        # Step 1: understand the query.
        if "user_interaction" not in st.session_state:
            st.session_state.user_interaction = dialog_manager.user_interaction
        understanding_result = st.session_state.user_interaction.understand(query)
        if not understanding_result:
            return "❌ 无法理解查询语句"

        # Step 2: take the recognised items as the knowledge for generation.
        selected_knowledge = list(understanding_result)
        if not selected_knowledge:
            return "❌ 没有找到相关知识"
        for item in selected_knowledge:
            logger.debug("Recognised entity: %s, constraints: %s", item.get("name"), item.get("constraints"))

        # Step 3: generate the code.
        generated_result = dialog_manager.generated_code(query, selected_knowledge)
        if not generated_result:
            return "❌ 代码生成失败"
        if isinstance(generated_result, dict):
            if not generated_result.get("status", False):
                return f"❌ 代码生成失败: {generated_result.get('message', '未知错误')}"
            code = generated_result.get("data", "")
            if not code:
                return "❌ 生成的代码为空"
        else:
            # A plain code string was returned.
            code = generated_result

        st.session_state.generated_code = code
        if "current_indicator_id" in st.session_state:
            st.session_state.current_query_id = st.session_state.current_indicator_id
        return "✅ 指标代码生成成功,可以在下方查看并执行"
    except Exception as e:
        logger.exception("Indicator code generation failed for query: %s", query)
        return f"❌ 指标代码生成失败: {str(e)}"
@st.cache_resource
def init_app_data():
    """Build the application's long-lived service objects.

    Wrapped in ``st.cache_resource`` so configuration loading, the LLM clients
    and the Neo4j retriever are created once and shared across reruns and
    sessions.

    :return: dict with ``dialog_manager``, ``user_interaction``,
        ``code_executor`` and ``knowledge_retriever``. ``main`` copies these
        entries into ``st.session_state``.
    """
    deps = load_dependencies()
    settings = deps["Config"]()
    read_file = deps["load_file"]
    chat_client_cls = deps["MultiAPIKeyChatOpenAI"]

    # Documents later passed to UserInteraction / DialogManager.
    business_structure = read_file(settings.business_object_structure_path)
    api_docs = read_file(settings.bowei_api_docs_path)

    # General LLM client, plus a separately configured one handed to CodeExecutor.
    general_llm = chat_client_cls(settings.openai)
    interaction = deps["UserInteraction"](general_llm.llm, business_structure)
    coder_llm = chat_client_cls(settings.openai_coder)
    prompts = deps["PromptManager"]()

    neo4j_settings = settings.neo4j_conf
    # NOTE(review): this embedding client is never used below. It is still
    # constructed (and kept alive until return) in case its constructor has
    # required side effects — confirm, and remove if not.
    _embedding_client = deps["EmbeddingClient"](settings.embedding)

    # Neo4j retriever; ProjectToolkitNeo4j is registered with ProjectBuilder
    # using the retriever's driver.
    retriever = deps["Neo4jRawRetriever"](neo4j_settings)
    deps["ProjectBuilder"].register(deps["ProjectToolkitNeo4j"], retriever.driver)

    executor = deps["CodeExecutor"](prompts.prompts, coder_llm, settings.max_retries)
    manager = deps["DialogManager"](
        general_llm,
        business_structure,
        api_docs,
        executor,
        retriever,
        prompts,
    )

    return {
        "dialog_manager": manager,
        "user_interaction": interaction,
        "code_executor": executor,
        "knowledge_retriever": retriever,
    }
# Streamlit application UI
@st.cache_data
def _load_jsonl(file_path):
    """Load a JSONL file and return its non-blank lines parsed as JSON objects.

    Cached per ``file_path``. Code that appends to the file must call
    ``_load_jsonl.clear()`` so the next render sees the new records.
    On any read or parse error, an error is shown and ``[]`` is returned.
    """
    records = []
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                if line.strip():
                    records.append(json.loads(line))
        return records
    except Exception as e:
        st.error(f"加载文件失败: {str(e)}")
        return []


def _clear_generated_code() -> None:
    """Remove any generated code held in the session (no-op if absent)."""
    if "generated_code" in st.session_state:
        del st.session_state.generated_code


def _show_execution_result(slot, result) -> bool:
    """Render a CodeExecutor result into the element ``slot``; return True on success.

    ``result`` is read as a dict-like object with ``status``/``data``/``message``
    keys (inferred from this usage — TODO confirm against
    ``CodeExecutor.execute_code``). On success, the banner and the JSON data
    are placed in one container. Writing them to ``slot`` one after another
    would make the JSON replace the banner.
    """
    if result and hasattr(result, "get") and result.get("status"):
        with slot.container():
            st.success("执行成功")
            st.json(result.get("data", {}))
        return True
    message = result.get("message", "未知错误") if hasattr(result, "get") else str(result)
    slot.error(f"执行失败: {message}")
    return False


def _render_sidebar() -> None:
    """Render the sidebar: the working-indicator tools and the indicator library browser."""
    manager = st.session_state.indicator_manager
    with st.sidebar:
        # Expander 1: create, clear and select working indicators.
        with st.expander("新建指标工具", expanded=True):
            col1, col2, col3 = st.columns([2, 1, 1])
            with col1:
                st.subheader("指标列表")
            with col2:
                if st.button("新建"):
                    try:
                        indicator_count = len(manager.get_all_indicators())
                        new_indicator = manager.add_indicator(name=f"指标 {indicator_count + 1}")
                        st.session_state.current_indicator_id = new_indicator.id
                        _clear_generated_code()
                        # Leave library-detail mode.
                        st.session_state.view_indicator_detail = False
                        save_session_state()
                        st.rerun()
                    except ValueError as e:
                        st.error(str(e))
            with col3:
                if st.button("清空", help="删除所有指标", type="secondary"):
                    manager.clear_all_indicators()
                    st.session_state.current_indicator_id = None
                    _clear_generated_code()
                    st.session_state.view_indicator_detail = False
                    save_session_state()
                    st.rerun()

            indicators = manager.get_all_indicators()
            if not indicators:
                st.info("暂无指标,请点击\"新建指标\"按钮创建")
            else:
                for indicator in sorted(indicators, key=lambda x: x.created_at):
                    is_current = st.session_state.current_indicator_id == indicator.id
                    if st.button(
                        indicator.name,
                        key=f"btn_{indicator.id}",
                        use_container_width=True,
                        type="primary" if is_current else "secondary",
                    ):
                        st.session_state.current_indicator_id = indicator.id
                        # Show the selected indicator's stored code, or none.
                        if indicator.code:
                            st.session_state.generated_code = indicator.code
                        else:
                            _clear_generated_code()
                        st.session_state.view_indicator_detail = False
                        save_session_state()
                        st.rerun()

        # Expander 2: browse the saved indicator library (DATA_DIR/code.jsonl).
        with st.expander("查看指标库", expanded=False):
            file_path = os.path.join(DATA_DIR, "code.jsonl")
            if not os.path.exists(file_path):
                st.warning(f"文件不存在: {file_path}")
                return
            records = _load_jsonl(file_path)
            if not records:
                st.warning("没有找到有效的指标")
                return
            if "selected_index" not in st.session_state:
                st.session_state.selected_index = 0
            st.subheader("指标列表")
            for i, record in enumerate(records):
                if st.button(record.get("name", f"指标 {i+1}"), key=f"lib_btn_{i}", use_container_width=True):
                    st.session_state.selected_index = i
                    # Viewing a library entry deselects the working indicator.
                    st.session_state.current_indicator_id = None
                    st.session_state.view_indicator_detail = True
                    st.session_state.current_library_indicator = record
                    save_session_state()
                    st.rerun()


def _render_library_detail(library_indicator: Dict[str, Any]) -> None:
    """Render a read-only view of a library record, with a button to execute its code."""
    import html  # only needed to escape the name that is rendered as raw HTML

    col1, col2, _ = st.columns([1, 4, 2])
    with col1:
        st.markdown('<div><b style="font-size: 24px;">指标名称:</b></div>', unsafe_allow_html=True)
    with col2:
        # The name comes from an editable JSONL file and is rendered with
        # unsafe_allow_html, so escape it to prevent HTML/script injection.
        safe_name = html.escape(str(library_indicator.get("name", "未命名指标")))
        st.markdown(f'<div style="font-size: 20px; padding-top: 4px;">{safe_name}</div>', unsafe_allow_html=True)

    # Query section
    title_col, _, result_col = st.columns([1, 1, 4])
    with title_col:
        st.subheader("查询语句")
    with result_col:
        st.subheader("")  # empty result slot on the title row
    st.text_area(
        "查询语句",
        value=library_indicator.get('query', '无查询语句'),
        height=70,
        # One key per selected record, so switching records shows the new query
        # rather than a value retained under a shared key.
        key=f"lib_query_display_{st.session_state.get('selected_index', 0)}",
        disabled=True,
        label_visibility="collapsed",
    )

    # Code section
    title_col, run_btn_col, result_col = st.columns([1, 1, 4])
    with title_col:
        st.subheader("指标代码")
    with result_col:
        title_code_result = st.subheader("")
    with run_btn_col:
        if st.button("执行代码", key="execute_lib_indicator", use_container_width=True):
            with st.spinner('正在执行代码...'):
                if "code_executor" in st.session_state:
                    result = st.session_state.code_executor.execute_code(library_indicator.get('code', ''))
                    _show_execution_result(title_code_result, result)
                else:
                    title_code_result.error("代码执行器未初始化")
    st.code(library_indicator.get('code', '无代码'), language='python')


def _render_indicator_editor(current_indicator) -> None:
    """Render the editor for a working indicator: name, query, code generation and execution."""
    manager = st.session_state.indicator_manager

    # After a session restore, current_indicator_id may be set without the
    # indicator's code loaded into the session; show its stored code.
    if current_indicator.code and "generated_code" not in st.session_state:
        st.session_state.generated_code = current_indicator.code

    # Name and save button
    col1, col2, col3 = st.columns([1, 4, 2])
    with col1:
        st.markdown('<div><b style="font-size: 24px;">指标名称:</b></div>', unsafe_allow_html=True)
    with col2:
        # Key per indicator: a shared key could carry the previous indicator's
        # text over, which the rename logic below would then apply to this one.
        new_name = st.text_input(
            "指标名称",
            value=current_indicator.name,
            key=f"indicator_name_{current_indicator.id}",
            label_visibility="collapsed",
        )
        if new_name != current_indicator.name:
            if not new_name.strip():
                st.error("指标名称不能为空")
            else:
                try:
                    manager.update_indicator(current_indicator.id, name=new_name)
                except ValueError as e:
                    st.error(str(e))
    with col3:
        if st.button("💾 保存指标", use_container_width=True):
            if not new_name.strip():
                st.error("指标名称不能为空")
            elif not current_indicator.query.strip():
                st.error("请先输入查询语句")
            elif not current_indicator.code:
                st.error("请先生成指标代码")
            else:
                try:
                    manager.save_indicators()
                    # Also append the indicator to the shared library file.
                    library_path = os.path.join(DATA_DIR, "code.jsonl")
                    new_record = {
                        "name": current_indicator.name,
                        "query": current_indicator.query,
                        "code": current_indicator.code,
                        "created_at": time.time(),
                    }
                    with open(library_path, "a", encoding="utf-8") as f:
                        f.write(json.dumps(new_record, ensure_ascii=False) + "\n")
                    # Invalidate the cached library so the sidebar list reloads.
                    _load_jsonl.clear()
                    st.toast("指标已保存")
                    st.rerun()
                except Exception as e:
                    st.error(f"保存失败: {str(e)}")

    # Query section
    title_col, test_btn_col, result_col = st.columns([1, 1, 4])
    with title_col:
        st.subheader("查询语句")
    with result_col:
        title_query_result = st.subheader("")
    query = st.text_area(
        "输入查询语句",
        value=current_indicator.query,
        height=70,
        key=f"query_input_{current_indicator.id}",  # per indicator, same reason as the name input
        placeholder="在此输入查询语句...",
        label_visibility="collapsed",
    )
    if query != current_indicator.query:
        manager.update_indicator(current_indicator.id, query=query)

    with test_btn_col:
        if st.button("🔄 生成代码", key="generate_code_btn", use_container_width=True):
            with st.spinner("正在生成指标代码..."):
                result = generate_indicator_code(query)
                if result.startswith("✅"):
                    title_query_result.success(result)
                    # Persist the freshly generated code on the indicator immediately.
                    if "generated_code" in st.session_state:
                        manager.update_indicator(current_indicator.id, code=st.session_state.generated_code)
                elif result.startswith("❌"):
                    title_query_result.error(result)
                else:
                    title_query_result.info(result)

    # Code section
    title_col, run_btn_col, result_col = st.columns([1, 1, 4])
    with title_col:
        st.subheader("指标代码")
    with result_col:
        title_code_result = st.subheader("")
    with run_btn_col:
        if st.button("执行代码", key="execute_code_btn", use_container_width=True):
            with st.spinner("正在执行代码..."):
                if "code_executor" in st.session_state and "generated_code" in st.session_state:
                    result = st.session_state.code_executor.execute_code(st.session_state.generated_code)
                    if _show_execution_result(title_code_result, result):
                        manager.update_indicator(current_indicator.id, code=st.session_state.generated_code)
                else:
                    title_code_result.error("代码执行器未初始化或没有生成的代码")
    st.code(st.session_state.get("generated_code", "# 暂无代码,请输入有效的查询语句并点击测试按钮"), language="python")


def main():
    """Streamlit entry point: initialise services and state, then render the sidebar and main area."""
    st.set_page_config(page_title="造价工程指标管理器", layout="wide")
    with st.spinner("正在初始化应用..."):
        app_data = init_app_data()
        # Expose the shared service objects through the session state.
        for key, value in app_data.items():
            st.session_state[key] = value

    if "indicator_manager" not in st.session_state:
        st.session_state.indicator_manager = IndicatorManager()
    # Restore persisted state BEFORE applying defaults: load_session_state only
    # fills missing keys, so defaulting current_indicator_id to None first
    # masked the saved value and it was never restored.
    load_session_state()
    if "current_indicator_id" not in st.session_state:
        st.session_state.current_indicator_id = None

    _render_sidebar()

    if st.session_state.get("view_indicator_detail") and "current_library_indicator" in st.session_state:
        _render_library_detail(st.session_state.current_library_indicator)
    elif st.session_state.current_indicator_id is not None:
        current_indicator = st.session_state.indicator_manager.get_indicator(st.session_state.current_indicator_id)
        if current_indicator:
            _render_indicator_editor(current_indicator)
        else:
            st.warning("找不到当前指标,可能已被删除")
    else:
        st.info("请在左侧创建或选择一个指标")


if __name__ == "__main__":
    main()