207 lines
7.6 KiB
Python
207 lines
7.6 KiB
Python
from typing import Optional
|
|
|
|
from dotenv import load_dotenv
|
|
# Load environment variables from the .env file.
# This runs before the project imports below
# (NOTE(review): presumably because agentic_rag / ui / utils read settings
# from the environment at import time — confirm before reordering).
load_dotenv()
|
|
import threading
|
|
import nest_asyncio
|
|
|
|
from agentic_rag import get_agentic_rag_agent, get_workflow
|
|
from agno.utils.log import logger
|
|
from ui import (
|
|
initialize_ui,
|
|
show_header,
|
|
get_modul_option,
|
|
show_tabs,
|
|
)
|
|
from utils import (
|
|
add_message,
|
|
)
|
|
import streamlit as st
|
|
from extra_streamlit_components import CookieManager
|
|
|
|
# Patch asyncio so an event loop may be re-entered (nest_asyncio).
# NOTE(review): presumably needed because agno code calls asyncio helpers
# while Streamlit already has a running loop — confirm.
nest_asyncio.apply()

# Lock held by initialize_agent() while it creates/caches the agent in
# st.session_state.
# NOTE(review): Streamlit re-executes this script on each rerun, so this
# module-level lock is likely recreated per run and may not serialize
# anything across runs or sessions — confirm; st.cache_resource is the usual
# way to share a single lock process-wide.
lock = threading.Lock()
|
|
|
|
def initialize_agent(model_id: str, session_id: Optional[str] = None):
    """Initialize or retrieve the Agentic RAG workflow for this Streamlit session.

    The workflow is cached in ``st.session_state["agentic_rag_agent"]``. It is
    (re)built via ``get_workflow`` when any of these hold:

    * no workflow is cached yet;
    * ``model_id`` differs from ``st.session_state["current_model"]``;
    * an explicit ``session_id`` is requested that differs from
      ``st.session_state["agentic_rag_agent_session_id"]``.

    After that, ``agent.load_session()`` is called and its return value is
    stored in ``st.session_state["agentic_rag_agent_session_id"]``. If loading
    fails, the error is logged, a warning is shown, and the agent is still
    returned (best effort).

    Args:
        model_id: Identifier of the model the workflow is built with.
        session_id: Session to build the workflow for. When ``None`` the id
            already stored in ``st.session_state["agentic_rag_agent_session_id"]``
            is used (which may itself be absent/None).

    Returns:
        The cached or newly created object returned by ``get_workflow``.
    """
    # `with` guarantees the lock is released even if creation raises.
    with lock:
        current_session_id = st.session_state.get("agentic_rag_agent_session_id")
        # An explicitly requested session that is not the loaded one must force a
        # rebuild; otherwise the requested session_id would be silently ignored
        # whenever an agent for the same model is already cached.
        session_changed = session_id is not None and session_id != current_session_id
        if session_id is None:
            session_id = current_session_id

        agent = st.session_state.get("agentic_rag_agent")
        if (
            agent is None
            or st.session_state.get("current_model") != model_id
            or session_changed
        ):
            logger.info(f"---*--- Creating {model_id} Agent ---*---")
            # get_agentic_rag_agent(model_id=..., session_id=...) was the previous
            # factory; the workflow variant is used instead.
            agent = get_workflow(
                model_id=model_id,
                session_id=session_id,
            )
            st.session_state["agentic_rag_agent"] = agent
            st.session_state["current_model"] = model_id

    # Load the agent session. This is best effort: the warning (roughly "could not
    # create Agent session, is the database running?") keeps the UI usable.
    try:
        st.session_state["agentic_rag_agent_session_id"] = agent.load_session()
    except Exception as exc:
        logger.warning(f"Failed to load agent session: {exc}")
        st.warning("无法创建Agent会话,请确认数据库是否在运行?")

    return agent
|
|
|
|
def _get_cookie_manager():
    """Return the CookieManager cached in session_state, creating it on the first run only."""
    if "cookie_manager" not in st.session_state:
        st.session_state.cookie_manager = CookieManager()
    # Later runs reuse the cached instance.
    return st.session_state.cookie_manager


def _restore_session_id_from_cookie(cookie_manager) -> None:
    """Seed the agent session id from the browser cookie when session_state has none."""
    if "agentic_rag_agent_session_id" in st.session_state:
        return
    cookie_session_id = cookie_manager.get(cookie="agentic_rag_agent_session_id")
    if cookie_session_id is not None:
        st.session_state["agentic_rag_agent_session_id"] = cookie_session_id


def _load_history(agent) -> None:
    """Rebuild st.session_state["messages"] from the agent's ``memory.runs``.

    For each run, its ``message`` (if any) is added with its own role, and its
    ``response`` (if any) is added as an assistant message together with its tools.
    """
    agent_runs = agent.memory.runs
    st.session_state["messages"] = []
    if not agent_runs:
        logger.debug("没有找到历史记录")
        return

    logger.debug("加载历史记录")
    for run in agent_runs:
        if run.message is not None:
            add_message(run.message.role, run.message.content)
        if run.response is not None:
            add_message("assistant", run.response.content, run.response.tools)


def _render_history(container) -> None:
    """Render every user/assistant message whose content is not None into *container*."""
    with container:
        for message in st.session_state["messages"]:
            role = message["role"]
            if role not in ("user", "assistant"):
                continue
            content = message["content"]
            if content is None:
                continue
            with st.chat_message(role):
                # Tool-call display (display_tool_calls) is currently disabled.
                st.markdown(content)


def _answer_last_user_message(agent, container) -> None:
    """If the newest message is from the user, run *agent* on it inside *container*.

    The accumulated reply is shown as it is built and then stored with
    ``add_message``. On any error, an error message is stored, shown, and logged.
    """
    with container:
        messages = st.session_state["messages"]
        last_message = messages[-1] if messages else None
        if not (last_message and last_message.get("role") == "user"):
            return

        question = last_message["content"]
        with st.chat_message("assistant"):
            # Placeholder reserved for tool-call display (currently disabled).
            tool_calls_container = st.empty()
            resp_container = st.empty()
            with st.spinner("🤔 思考中..."):
                response = ""
                try:
                    # NOTE(review): called with stream=False but iterated chunk by
                    # chunk; presumably the workflow's run() yields responses —
                    # confirm against get_workflow.
                    run_response = agent.run(message=question, stream=False)
                    for chunk in run_response:
                        if chunk.content is not None:
                            response += chunk.content
                            resp_container.markdown(response)

                    add_message("assistant", response, agent.run_response.tools)
                except Exception as e:
                    logger.exception("Agent run failed")
                    error_message = f"对不起, 发生错误: {str(e)}"
                    add_message("assistant", error_message)
                    st.error(error_message)


def _switch_to_latest_stored_session(agent, model_id: str) -> None:
    """Switch to the first session returned by storage if it differs from the current one.

    Does nothing when the agent has no storage, storage returns no sessions, or
    no session id is in session_state. When a switch happens, the agent is
    re-initialized for that session and the script is rerun.
    """
    if not agent.storage:
        return

    agent_sessions = agent.storage.get_all_sessions()
    # Use the session name when available, otherwise the id. "display" is only
    # needed by the (currently disabled) sidebar selectbox ("会话") that used to
    # let the user pick a session.
    session_options = []
    for session in agent_sessions:
        session_name = (
            session.session_data.get("session_name", None)
            if session.session_data
            else None
        )
        session_options.append(
            {"id": session.session_id, "display": session_name or session.session_id}
        )

    # Default to None so an empty storage cannot leave this name unbound.
    selected_session_id = session_options[0]["id"] if session_options else None

    if (
        "agentic_rag_agent_session_id" in st.session_state
        and selected_session_id is not None
        and st.session_state["agentic_rag_agent_session_id"] != selected_session_id
    ):
        logger.info(f"---*--- Loading {model_id} run: {selected_session_id} ---*---")
        st.session_state["agentic_rag_agent"] = initialize_agent(
            model_id=model_id,
            session_id=selected_session_id,
        )
        st.rerun()
    # rename_session_widget / about_widget are currently disabled.


def main():
    """Render the chat app: restore the session, show history, answer, sync sessions.

    Rendering order is significant in Streamlit and is kept as follows: UI init,
    model selection, agent init, sidebar chat containers, chat input, header and
    tabs, history, then the reply to the latest user message.
    """
    initialize_ui()

    cookie_manager = _get_cookie_manager()
    _restore_session_id_from_cookie(cookie_manager)

    model_id = get_modul_option(st.session_state.get("model_id", 0))
    agentic_rag_agent = initialize_agent(model_id)

    # Persist the (possibly new) session id in the browser cookie.
    if "agentic_rag_agent_session_id" in st.session_state:
        cookie_manager.set(
            "agentic_rag_agent_session_id",
            st.session_state.get("agentic_rag_agent_session_id"),
        )

    _load_history(agentic_rag_agent)

    # Created up front so history and the new reply land above the chat input.
    chat_container = st.sidebar.container()
    last_msg_container = st.sidebar.container()

    if prompt := st.sidebar.chat_input("👋 问我任何问题!"):
        add_message("user", prompt)

    show_header()
    show_tabs()

    _render_history(chat_container)
    _answer_last_user_message(agentic_rag_agent, last_msg_container)
    _switch_to_latest_stored_session(agentic_rag_agent, model_id)
|
|
|
|
# Script entry point (NOTE(review): `streamlit run` is expected to execute
# this file as __main__ on every rerun — confirm).
if __name__ == "__main__":
    main()
|