198 lines
5.8 KiB
Python
198 lines
5.8 KiB
Python
# Streamlit UI helpers for the Agentic RAG app: chat-history session state,
# Markdown export, tool-call rendering, session renaming, an "About" panel
# and custom CSS.
from dotenv import load_dotenv

# Load the .env file. This runs before the imports below, presumably so that
# modules such as agentic_rag can see those variables at import time.
# NOTE(review): confirm that this ordering is actually required.
load_dotenv()

from typing import Any, Dict, List, Optional, TypeVar

import streamlit as st

# Generic type variable. NOTE(review): not referenced by the code visible in
# this file; possibly removable.
T = TypeVar('T')

# NOTE(review): get_agentic_rag_agent and logger are not referenced by the code
# visible in this file; Agent is used only as a type annotation.
from agentic_rag import get_agentic_rag_agent
from agno.agent import Agent
from agno.utils.log import logger
|
||
|
||
|
||
def add_message(
    role: str, content: str, tool_calls: Optional[List[Dict[str, Any]]] = None
) -> None:
    """Append one chat message to ``st.session_state["messages"]``.

    If the "messages" entry is missing, or holds something other than a list,
    it is first replaced by an empty list.

    Args:
        role: Sender of the message (e.g. "user").
        content: Message text.
        tool_calls: Optional tool-call records attached to the message.
    """
    session = st.session_state
    has_valid_history = "messages" in session and isinstance(session["messages"], list)
    if not has_valid_history:
        session["messages"] = []

    entry = {"role": role, "content": content, "tool_calls": tool_calls}
    session["messages"].append(entry)
|
||
|
||
|
||
def export_chat_history(messages: Optional[List[Dict[str, Any]]] = None) -> str:
    """Export the chat history as a Markdown document.

    Args:
        messages: Messages to export, each a dict with "role", "content" and
            optionally "tool_calls" (the shape written by ``add_message``).
            When omitted, ``st.session_state["messages"]`` is used.

    Returns:
        The Markdown text, or "" when ``messages`` is omitted and the session
        state has no "messages" entry.
    """
    if messages is None:
        if "messages" not in st.session_state:
            return ""
        messages = st.session_state["messages"]

    parts = ["# Auto RAG Agent - 会话历史\n\n"]
    for msg in messages:
        # Treat both "assistant" and "agent" as the model's turns; previously
        # only "agent" was recognised, so "assistant" messages were labelled
        # as user messages.
        is_assistant = msg.get("role") in ("assistant", "agent")
        role = "🤖 Assistant" if is_assistant else "👤 User"
        content = msg.get("content")
        text = "" if content is None else content
        parts.append(f"### {role}\n{text}\n\n")

        tool_calls = msg.get("tool_calls")
        if tool_calls:
            parts.append("#### 工具调用:\n")
            parts.extend(f"- {_tool_display_name(tool)}\n" for tool in tool_calls)

    # Join once rather than growing a string with += inside the loop.
    return "".join(parts)


def _tool_display_name(tool: Any) -> str:
    """Return a tool call's name, or "未知工具" when it has none.

    Checks "tool_name" first (the key display_tool_calls reads), then "name".
    Works for both dicts and attribute-bearing objects.
    """
    if isinstance(tool, dict):
        name = tool.get("tool_name") or tool.get("name")
    else:
        name = getattr(tool, "tool_name", None) or getattr(tool, "name", None)
    return name or "未知工具"
|
||
|
||
|
||
def display_tool_calls(tool_calls_container, tools):
    """Display tool calls as expandable sections in the Streamlit sidebar.

    Args:
        tool_calls_container: Streamlit container/placeholder. Only its
            ``.container()`` is entered here; the expanders themselves are
            created in the sidebar.
        tools: Tool calls to show. Each item is a dict (or an object with the
            same attributes) carrying "tool_name", "tool_args", "content" and
            "metrics". ``None`` is treated as "no tool calls".
    """
    # NOTE(review): expanders are appended to st.sidebar rather than placed in
    # tool_calls_container, so if this runs more than once per script run each
    # call adds another set of expanders. Confirm this placement is intended.
    with tool_calls_container.container():
        for tool_call in tools or ():
            tool_name = _tool_call_field(tool_call, "tool_name") or "未知工具"
            tool_args = _tool_call_field(tool_call, "tool_args")
            content = _tool_call_field(tool_call, "content")
            metrics = _tool_call_field(tool_call, "metrics")

            title = str(tool_name).replace("_", " ").title()
            # Write through the expander object itself. st.sidebar.* calls
            # always target the sidebar root, even inside a `with` block, so
            # they previously left the expander empty.
            expander = st.sidebar.expander(f"🛠️ {title}", expanded=False)

            if isinstance(tool_args, dict) and tool_args.get("query") is not None:
                expander.code(tool_args["query"], language="sql")

            if tool_args and tool_args != {"query": None}:
                expander.markdown("**参数:**")
                expander.json(tool_args)

            if content:
                expander.markdown("**结果:**")
                _json_or_markdown(expander, content)

            if metrics:
                expander.markdown("**指标:**")
                _json_or_markdown(expander, metrics)


def _tool_call_field(tool_call, key):
    """Read ``key`` from a tool call given either as a dict or as an object."""
    if isinstance(tool_call, dict):
        return tool_call.get(key)
    return getattr(tool_call, key, None)


def _json_or_markdown(target, value):
    """Render ``value`` with ``target.json``; fall back to markdown if that raises."""
    try:
        target.json(value)
    except Exception:  # best-effort display: show the raw value instead
        target.markdown(str(value))
|
||
|
||
def rename_session_widget(agent: Agent) -> None:
    """Render sidebar controls that rename the agent's current session.

    Clicking "✎ 重命名会话" turns on edit mode and reruns the script. Edit
    mode is stored in ``st.session_state.session_edit_mode`` so that it
    survives reruns. In edit mode, a text input prefilled with
    ``agent.session_name`` and a "保存" button are shown. Saving a non-blank
    name (after stripping whitespace) calls ``agent.rename_session``, leaves
    edit mode and reruns.

    Args:
        agent: Agent whose current session is renamed.
    """
    if "session_edit_mode" not in st.session_state:
        st.session_state.session_edit_mode = False

    if st.sidebar.button("✎ 重命名会话"):
        st.session_state.session_edit_mode = True
        st.rerun()

    if not st.session_state.session_edit_mode:
        return

    new_session_name = st.sidebar.text_input(
        "输入新名称:",
        # Guard against a session without a name yet.
        value=agent.session_name or "",
        key="session_name_input",
    )
    if st.sidebar.button("保存", type="primary"):
        cleaned_name = (new_session_name or "").strip()
        # Ignore empty/whitespace-only names and stay in edit mode.
        if cleaned_name:
            agent.rename_session(cleaned_name)
            st.session_state.session_edit_mode = False
            st.rerun()
|
||
|
||
def about_widget() -> None:
    """Render the sidebar "About" section.

    Emits, in order: a horizontal rule, a "关于" heading, and a short blurb
    describing the assistant and the libraries it is built with.
    """
    sidebar = st.sidebar
    about_text = """
本智能检索增强生成助手帮助您使用自然语言查询分析文档和网页内容。

构建技术:
- 🚀 Agno
- 💫 Streamlit
"""
    for section in ("---", "### ℹ️ 关于", about_text):
        sidebar.markdown(section)
|
||
|
||
|
||
CUSTOM_CSS = """
|
||
<style>
|
||
.sidebar .sidebar-content {
|
||
width: 200px;
|
||
margin-left: auto;
|
||
margin-right: 0;
|
||
}
|
||
.main {
|
||
margin-left: 200px;
|
||
margin-right: 0;
|
||
}
|
||
/* Main Styles */
|
||
.main-title {
|
||
text-align: center;
|
||
background: linear-gradient(45deg, #FF4B2B, #FF416C);
|
||
-webkit-background-clip: text;
|
||
-webkit-text-fill-color: transparent;
|
||
font-size: 3em;
|
||
font-weight: bold;
|
||
padding: 1em 0;
|
||
}
|
||
.subtitle {
|
||
text-align: center;
|
||
color: #666;
|
||
margin-bottom: 2em;
|
||
}
|
||
.stButton button {
|
||
width: 100%;
|
||
border-radius: 20px;
|
||
margin: 0.2em 0;
|
||
transition: all 0.3s ease;
|
||
}
|
||
.stButton button:hover {
|
||
transform: translateY(-2px);
|
||
box-shadow: 0 5px 15px rgba(0,0,0,0.1);
|
||
}
|
||
.chat-container {
|
||
border-radius: 15px;
|
||
padding: 1em;
|
||
margin: 1em 0;
|
||
background-color: #f5f5f5;
|
||
}
|
||
.tool-result {
|
||
background-color: #f8f9fa;
|
||
border-radius: 10px;
|
||
padding: 1em;
|
||
margin: 1em 0;
|
||
border-left: 4px solid #3B82F6;
|
||
}
|
||
.status-message {
|
||
padding: 1em;
|
||
border-radius: 10px;
|
||
margin: 1em 0;
|
||
}
|
||
.success-message {
|
||
background-color: #d4edda;
|
||
color: #155724;
|
||
}
|
||
.error-message {
|
||
background-color: #f8d7da;
|
||
color: #721c24;
|
||
}
|
||
/* Dark mode adjustments */
|
||
@media (prefers-color-scheme: dark) {
|
||
.chat-container {
|
||
background-color: #2b2b2b;
|
||
}
|
||
.tool-result {
|
||
background-color: #1e1e1e;
|
||
}
|
||
}
|
||
</style>
|
||
"""
|