Files
agno_agentic_rag/utils.py
T
2025-04-08 11:38:01 +08:00

239 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from dotenv import load_dotenv
# Load the .env file.
# NOTE(review): this runs before the remaining imports, presumably so that the
# modules imported below can read the environment variables at import time —
# confirm before reordering these lines.
load_dotenv()
from typing import Any, Dict, List, Optional, TypeVar
import streamlit as st
# Generic type variable; not referenced in the visible part of this file.
T = TypeVar('T')
from agentic_rag import get_agentic_rag_agent
from agno.agent import Agent
from agno.utils.log import logger
def add_message(
    role: str, content: str, tool_calls: Optional[List[Dict[str, Any]]] = None
) -> None:
    """Append one chat message to ``st.session_state["messages"]``.

    The message list is (re)created as an empty list when the key is missing
    or when whatever is stored under it is not a list.

    Args:
        role: Role label stored with the message.
        content: Message text.
        tool_calls: Optional tool calls attached to the message.
    """
    state = st.session_state
    has_history = "messages" in state and isinstance(state["messages"], list)
    if not has_history:
        state["messages"] = []

    entry = {"role": role, "content": content, "tool_calls": tool_calls}
    state["messages"].append(entry)
def export_chat_history():
    """Export the chat history held in the session state as markdown.

    Returns:
        A markdown document with one ``###`` section per message and, for
        messages that carry tool calls, a bullet list of the tool names.
        An empty string when the session state has no ``"messages"`` entry.
    """
    if "messages" not in st.session_state:
        return ""

    # Collect the pieces and join once instead of growing a string with +=.
    parts = ["# Auto RAG Agent - 会话历史\n\n"]
    for msg in st.session_state["messages"]:
        # add_message() stores whatever role string the caller passes; treat
        # both "agent" and the usual chat role "assistant" as the model side
        # so assistant replies are not exported as user messages.
        # NOTE(review): confirm which role string the app actually uses.
        is_assistant = msg["role"] in ("agent", "assistant")
        role = "🤖 Assistant" if is_assistant else "👤 User"
        parts.append(f"### {role}\n{msg['content']}\n\n")

        tool_calls = msg.get("tool_calls")
        if not tool_calls:
            continue

        parts.append("#### 工具调用:\n")
        for tool in tool_calls:
            # display_tool_calls() reads the name from "tool_name"; look there
            # first and keep "name" as a fallback for older entries.
            if isinstance(tool, dict):
                tool_name = tool.get("tool_name") or tool.get("name")
            else:
                tool_name = getattr(tool, "tool_name", None) or getattr(
                    tool, "name", None
                )
            parts.append(f"- {tool_name or '未知工具'}\n")

    return "".join(parts)
def display_tool_calls(tool_calls_container, tools):
    """Render tool calls as collapsed expanders in the Streamlit sidebar.

    Args:
        tool_calls_container: Streamlit element whose ``container()`` is
            entered while the tool calls are rendered.
        tools: Iterable of tool calls. Each item is a dict (or an object
            exposing the same fields as attributes) providing ``tool_name``,
            ``tool_args``, ``content`` and ``metrics``. ``None`` or an empty
            iterable renders nothing.
    """

    def _field(tool_call, key):
        # Tool calls may be dicts or plain objects (export_chat_history
        # accepts both), so read the field either way.
        if isinstance(tool_call, dict):
            return tool_call.get(key)
        return getattr(tool_call, key, None)

    def _json_or_markdown(value):
        # Best effort: show the value as JSON, fall back to markdown when
        # rendering it as JSON raises.
        try:
            st.json(value)
        except Exception:
            st.markdown(value)

    with tool_calls_container.container():
        for tool_call in tools or ():
            # A missing name must not crash the title formatting below.
            _tool_name = _field(tool_call, "tool_name") or "未知工具"
            _tool_args = _field(tool_call, "tool_args")
            _content = _field(tool_call, "content")
            _metrics = _field(tool_call, "metrics")

            title = str(_tool_name).replace("_", " ").title()
            with st.sidebar.expander(f"🛠️ {title}", expanded=False):
                # Inside this `with` block plain st.* calls render into the
                # expander. Calling st.sidebar.* here would ignore the active
                # `with` container and write to the sidebar root, leaving the
                # expander empty.
                if isinstance(_tool_args, dict) and "query" in _tool_args:
                    st.code(_tool_args["query"], language="sql")

                if _tool_args and _tool_args != {"query": None}:
                    st.markdown("**参数:**")
                    st.json(_tool_args)

                if _content:
                    st.markdown("**结果:**")
                    _json_or_markdown(_content)

                if _metrics:
                    st.markdown("**指标:**")
                    _json_or_markdown(_metrics)
def rename_session_widget(agent: Agent) -> None:
    """Sidebar controls for renaming the agent's current session.

    A rename button switches the widget into edit mode; in edit mode a text
    input pre-filled with ``agent.session_name`` and a save button are shown.
    Saving a non-empty name calls ``agent.rename_session`` and leaves edit
    mode.

    Args:
        agent: Agent whose session is renamed.
    """
    sidebar = st.sidebar
    state = st.session_state

    sidebar.container()

    # The edit-mode flag must exist before it is read further down.
    if "session_edit_mode" not in state:
        state.session_edit_mode = False

    if sidebar.button("✎ 重命名会话"):
        state.session_edit_mode = True
        st.rerun()

    if not state.session_edit_mode:
        return

    new_session_name = sidebar.text_input(
        "输入新名称:",
        value=agent.session_name,
        key="session_name_input",
    )
    save_clicked = sidebar.button("保存", type="primary")
    # NOTE(review): edit mode is only left after a non-empty name was saved;
    # an empty name keeps the input open. Confirm this matches the intended
    # nesting of the save branch.
    if save_clicked and new_session_name:
        agent.rename_session(new_session_name)
        state.session_edit_mode = False
        st.rerun()
def session_selector_widget(agent: Agent, model_id: str) -> None:
    """Make the first stored session the active one, reloading the agent if needed.

    Does nothing when the agent has no storage or the storage holds no
    sessions. When the first stored session differs from
    ``st.session_state["agentic_rag_agent_session_id"]``, a new agent is built
    for that session, stored under ``st.session_state["agentic_rag_agent"]``
    and the script is rerun.

    Args:
        agent: Agent whose storage is queried for sessions.
        model_id: Model identifier passed on to ``get_agentic_rag_agent``.
    """
    if not agent.storage:
        return

    # Label every stored session with its saved name, falling back to its ID.
    session_options = [
        {
            "id": stored.session_id,
            "display": (stored.session_data or {}).get("session_name")
            or stored.session_id,
        }
        for stored in agent.storage.get_all_sessions()
    ]

    # NOTE: the interactive selector is currently disabled; the first stored
    # session is always used instead.
    # selected_session = st.sidebar.selectbox(
    #     "会话",
    #     options=[s["display"] for s in session_options],
    #     key="session_selector",
    # )
    # selected_session_id = next(
    #     s["id"] for s in session_options if s["display"] == selected_session
    # )
    if not session_options:
        return

    selected_session_id = session_options[0]["id"]
    if st.session_state["agentic_rag_agent_session_id"] == selected_session_id:
        return

    logger.info(f"---*--- Loading {model_id} run: {selected_session_id} ---*---")
    st.session_state["agentic_rag_agent"] = get_agentic_rag_agent(
        model_id=model_id,
        session_id=selected_session_id,
    )
    st.rerun()
def about_widget() -> None:
    """Render the "about" section in the sidebar.

    Writes a horizontal rule, a heading and a short description of the
    assistant together with the technologies it is built with.
    """
    st.sidebar.markdown("---")
    # The heading used to carry a stray invisible character (apparently an
    # emoji variation selector whose base glyph had been lost); the info
    # emoji is restored so the heading matches the other emoji-prefixed labels.
    st.sidebar.markdown("### ℹ️ 关于")
    st.sidebar.markdown(
        "\n"
        "本智能检索增强生成助手帮助您使用自然语言查询分析文档和网页内容。\n"
        "构建技术:\n"
        "- 🚀 Agno\n"
        "- 💫 Streamlit\n"
    )
# Stylesheet for the app, wrapped in a <style> element.
# It covers: sidebar content width and main-area margin (200px), the gradient
# .main-title and the .subtitle, full-width rounded .stButton buttons with a
# hover lift, .chat-container and .tool-result boxes, status / success / error
# message boxes, and dark-mode overrides for .chat-container and .tool-result.
# NOTE(review): nothing in the visible part of this file consumes CUSTOM_CSS;
# presumably the app injects it as raw HTML — confirm at the call site.
CUSTOM_CSS = """
<style>
.sidebar .sidebar-content {
width: 200px;
margin-left: auto;
margin-right: 0;
}
.main {
margin-left: 200px;
margin-right: 0;
}
/* Main Styles */
.main-title {
text-align: center;
background: linear-gradient(45deg, #FF4B2B, #FF416C);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 3em;
font-weight: bold;
padding: 1em 0;
}
.subtitle {
text-align: center;
color: #666;
margin-bottom: 2em;
}
.stButton button {
width: 100%;
border-radius: 20px;
margin: 0.2em 0;
transition: all 0.3s ease;
}
.stButton button:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(0,0,0,0.1);
}
.chat-container {
border-radius: 15px;
padding: 1em;
margin: 1em 0;
background-color: #f5f5f5;
}
.tool-result {
background-color: #f8f9fa;
border-radius: 10px;
padding: 1em;
margin: 1em 0;
border-left: 4px solid #3B82F6;
}
.status-message {
padding: 1em;
border-radius: 10px;
margin: 1em 0;
}
.success-message {
background-color: #d4edda;
color: #155724;
}
.error-message {
background-color: #f8d7da;
color: #721c24;
}
/* Dark mode adjustments */
@media (prefers-color-scheme: dark) {
.chat-container {
background-color: #2b2b2b;
}
.tool-result {
background-color: #1e1e1e;
}
}
</style>
"""