"""Streamlit entry point for the Agentic RAG chat app."""

from dotenv import load_dotenv

# Load the .env file before the project imports below -- presumably so that
# any environment variables they read at import time are already set.
load_dotenv()

import threading

import nest_asyncio

from agentic_rag import get_agentic_rag_agent
from agno.utils.log import logger
from ui import (
    initialize_ui,
    show_header,
    get_modul_option,
    show_tabs,
)
from utils import (
    add_message,
    session_selector_widget,
)
import streamlit as st
from extra_streamlit_components import CookieManager

# nest_asyncio patches asyncio so that event loops may be nested/re-entered.
nest_asyncio.apply()

# Module-level lock that serialises the create-or-reuse check in
# initialize_agent().
lock = threading.Lock()


def initialize_agent(model_id: str):
    """Create or reuse the Agentic RAG agent for the current Streamlit session.

    A new agent is built when none is cached in ``st.session_state`` (missing
    or ``None``) or when ``model_id`` differs from the model the cached agent
    was built with. Afterwards the agent's session is loaded and the id it
    returns is stored in ``st.session_state["agentic_rag_agent_session_id"]``.

    Args:
        model_id: Model identifier forwarded to ``get_agentic_rag_agent``.

    Returns:
        The agent held in ``st.session_state["agentic_rag_agent"]``.
    """
    # `with` guarantees the lock is released even if agent creation raises.
    with lock:
        agent = st.session_state.get("agentic_rag_agent")
        # .get() yields None for a missing key, so this single check covers
        # both "never created" and "explicitly cleared".
        if agent is None or st.session_state.get("current_model") != model_id:
            logger.info(f"---*--- Creating {model_id} Agent ---*---")
            agent = get_agentic_rag_agent(
                model_id=model_id,
                session_id=st.session_state.get("agentic_rag_agent_session_id"),
            )
            st.session_state["agentic_rag_agent"] = agent
            st.session_state["current_model"] = model_id

    # Load the agent session and remember the id it returns, so later reruns
    # (and the cookie main() writes from this key) reuse the same session.
    try:
        st.session_state["agentic_rag_agent_session_id"] = agent.load_session()
    except Exception:
        # Deliberately best-effort: warn in the UI instead of crashing, but
        # keep the traceback in the log so the root cause can be diagnosed.
        logger.exception("Failed to load agent session")
        st.warning("无法创建Agent会话,请确认数据库是否在运行?")
    return agent


def _get_cookie_manager():
    """Return this session's CookieManager, creating it only on the first run."""
    if "cookie_manager" not in st.session_state:
        st.session_state.cookie_manager = CookieManager()
    return st.session_state.cookie_manager


def _restore_session_id(cookie_manager) -> None:
    """Seed the agent session id from the browser cookie if it is not set yet."""
    if "agentic_rag_agent_session_id" in st.session_state:
        return
    # NOTE(review): CookieManager obtains cookies through a frontend component,
    # so they may not be available yet on a session's very first run --
    # confirm that restoring from the cookie works as intended.
    session_id = cookie_manager.get(cookie="agentic_rag_agent_session_id")
    if session_id is not None:
        st.session_state["agentic_rag_agent_session_id"] = session_id


def _load_history(agent) -> None:
    """Rebuild ``st.session_state["messages"]`` from ``agent.memory.runs``."""
    st.session_state["messages"] = []
    runs = agent.memory.runs
    if not runs:
        logger.debug("没有找到历史记录")
        return
    logger.debug("加载历史记录")
    for run in runs:
        if run.message is not None:
            add_message(run.message.role, run.message.content)
        if run.response is not None:
            add_message("assistant", run.response.content, run.response.tools)


def _render_history() -> None:
    """Render each stored user/assistant message whose content is not None."""
    for message in st.session_state["messages"]:
        if message["role"] not in ("user", "assistant"):
            continue
        content = message["content"]
        if content is None:
            continue
        with st.chat_message(message["role"]):
            # NOTE: tool-call rendering (display_tool_calls) is disabled.
            st.markdown(content)


def _answer_last_question(agent) -> None:
    """Stream the agent's answer if the newest message came from the user.

    The streamed text is shown incrementally and then appended to the message
    list. On failure, an error message is stored and shown instead.
    """
    messages = st.session_state["messages"]
    last_message = messages[-1] if messages else None
    if not last_message or last_message.get("role") != "user":
        return

    question = last_message["content"]
    with st.chat_message("assistant"):
        # NOTE: tool-call rendering is disabled, so only the text response
        # gets a placeholder.
        resp_container = st.empty()
        with st.spinner("🤔 思考中..."):
            response = ""
            try:
                for chunk in agent.run(question, stream=True):
                    if chunk.content is not None:
                        response += chunk.content
                        resp_container.markdown(response)
                add_message("assistant", response, agent.run_response.tools)
            except Exception as e:
                # Shown to the user rather than raised, but logged with the
                # traceback so the failure is not lost.
                logger.exception("Agent run failed")
                error_message = f"对不起, 发生错误: {e}"
                add_message("assistant", error_message)
                st.error(error_message)


def main():
    """Render the Agentic RAG chat app for one Streamlit script run.

    Statement order matters:
    - The session id is restored from the cookie before the agent is created.
    - The sidebar containers are created before the chat input.
    - The agent answers only after a new prompt, if any, was appended.
    """
    initialize_ui()

    cookie_manager = _get_cookie_manager()
    _restore_session_id(cookie_manager)

    model_id = get_modul_option(st.session_state.get("model_id", 0))
    agentic_rag_agent = initialize_agent(model_id)

    # Persist the (possibly new) session id so a later visit can restore it.
    if "agentic_rag_agent_session_id" in st.session_state:
        cookie_manager.set(
            "agentic_rag_agent_session_id",
            st.session_state.get("agentic_rag_agent_session_id"),
        )

    _load_history(agentic_rag_agent)

    # Reserve the sidebar slots before the chat input is drawn, so the
    # conversation is laid out above it; they are filled further below.
    chat_container = st.sidebar.container()
    last_msg_container = st.sidebar.container()

    if prompt := st.sidebar.chat_input("👋 问我任何问题!"):
        add_message("user", prompt)

    show_header()
    show_tabs()

    with chat_container:
        _render_history()

    with last_msg_container:
        _answer_last_question(agentic_rag_agent)

    # Session selector. (rename_session_widget and about_widget are disabled.)
    session_selector_widget(agentic_rag_agent, model_id)


if __name__ == "__main__":
    # Build the app only when this file is executed directly, not on import.
    main()