迁入项目
This commit is contained in:
@@ -0,0 +1,160 @@
|
||||
from dotenv import load_dotenv
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
import threading
|
||||
import nest_asyncio
|
||||
|
||||
from agentic_rag import get_agentic_rag_agent
|
||||
from agno.utils.log import logger
|
||||
from ui import (
|
||||
initialize_ui,
|
||||
show_header,
|
||||
get_modul_option,
|
||||
show_tabs,
|
||||
)
|
||||
from utils import (
|
||||
add_message,
|
||||
session_selector_widget,
|
||||
)
|
||||
import streamlit as st
|
||||
from extra_streamlit_components import CookieManager
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
def initialize_agent(model_id: str):
|
||||
"""Initialize or retrieve the Agentic RAG."""
|
||||
lock.acquire()
|
||||
try:
|
||||
if (
|
||||
not "agentic_rag_agent" in st.session_state
|
||||
or st.session_state.get("agentic_rag_agent") is None
|
||||
or st.session_state.get("current_model") != model_id
|
||||
):
|
||||
logger.info(f"---*--- Creating {model_id} Agent ---*---")
|
||||
agent = get_agentic_rag_agent(
|
||||
model_id=model_id,
|
||||
session_id=st.session_state.get("agentic_rag_agent_session_id"),
|
||||
)
|
||||
st.session_state["agentic_rag_agent"] = agent
|
||||
st.session_state["current_model"] = model_id
|
||||
else:
|
||||
agent = st.session_state.get("agentic_rag_agent")
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
# Load Agent Session
|
||||
try:
|
||||
st.session_state["agentic_rag_agent_session_id"] = agent.load_session()
|
||||
except Exception:
|
||||
st.warning("无法创建Agent会话,请确认数据库是否在运行?")
|
||||
return agent
|
||||
|
||||
def main():
|
||||
initialize_ui()
|
||||
#st.write("")
|
||||
|
||||
# 仅在首次运行时初始化 CookieManager
|
||||
if "cookie_manager" not in st.session_state:
|
||||
st.session_state.cookie_manager = CookieManager()
|
||||
|
||||
# 获取实例(后续直接使用缓存)
|
||||
cookie_manager = st.session_state.cookie_manager
|
||||
|
||||
# 检查并设置 session_id
|
||||
if (not 'agentic_rag_agent_session_id' in st.session_state
|
||||
and cookie_manager.get(cookie='agentic_rag_agent_session_id') is not None):
|
||||
st.session_state["agentic_rag_agent_session_id"] = cookie_manager.get(cookie='agentic_rag_agent_session_id')
|
||||
|
||||
model_id = get_modul_option(st.session_state["model_id"] if "model_id" in st.session_state else 0)
|
||||
#model_id = show_model_selector()
|
||||
# Initialize Agent
|
||||
agentic_rag_agent = initialize_agent(model_id)
|
||||
|
||||
if 'agentic_rag_agent_session_id' in st.session_state:
|
||||
cookie_manager.set('agentic_rag_agent_session_id', st.session_state.get("agentic_rag_agent_session_id"))
|
||||
|
||||
# Load runs from memory
|
||||
agent_runs = agentic_rag_agent.memory.runs
|
||||
if len(agent_runs) > 0:
|
||||
logger.debug("加载历史记录")
|
||||
st.session_state["messages"] = []
|
||||
for _run in agent_runs:
|
||||
if _run.message is not None:
|
||||
add_message(_run.message.role, _run.message.content)
|
||||
if _run.response is not None:
|
||||
add_message("assistant", _run.response.content, _run.response.tools)
|
||||
else:
|
||||
logger.debug("没有找到历史记录")
|
||||
st.session_state["messages"] = []
|
||||
|
||||
chatContainer = st.sidebar.container()
|
||||
lastMsgContainer = st.sidebar.container()
|
||||
|
||||
# Chat input
|
||||
if prompt := st.sidebar.chat_input("👋 问我任何问题!"):
|
||||
add_message("user", prompt)
|
||||
|
||||
# Display UI
|
||||
show_header()
|
||||
show_tabs()
|
||||
|
||||
#show_chat_history(agentic_rag_agent)
|
||||
|
||||
with chatContainer:
|
||||
for message in st.session_state["messages"]:
|
||||
if message["role"] in ["user", "assistant"]:
|
||||
_content = message["content"]
|
||||
if _content is not None:
|
||||
with st.chat_message(message["role"]):
|
||||
#if "tool_calls" in message and message["tool_calls"]:
|
||||
# display_tool_calls(st.empty(), message["tool_calls"])
|
||||
st.markdown(_content)
|
||||
|
||||
with lastMsgContainer:
|
||||
last_message = (st.session_state["messages"][-1] if st.session_state["messages"] else None)
|
||||
if last_message and last_message.get("role") == "user":
|
||||
question = last_message["content"]
|
||||
#with st.chat_message("user"):
|
||||
# st.markdown(question)
|
||||
with st.chat_message("assistant"):
|
||||
# Create container for tool calls
|
||||
tool_calls_container = st.empty()
|
||||
resp_container = st.empty()
|
||||
with st.spinner("🤔 思考中..."):
|
||||
response = ""
|
||||
try:
|
||||
# Run the agent and stream the response
|
||||
run_response = agentic_rag_agent.run(question, stream=True)
|
||||
for _resp_chunk in run_response:
|
||||
# Display tool calls if available
|
||||
#if _resp_chunk.tools and len(_resp_chunk.tools) > 0:
|
||||
# display_tool_calls(tool_calls_container, _resp_chunk.tools)
|
||||
|
||||
# Display response
|
||||
if _resp_chunk.content is not None:
|
||||
response += _resp_chunk.content
|
||||
resp_container.markdown(response)
|
||||
|
||||
add_message("assistant", response, agentic_rag_agent.run_response.tools)
|
||||
except Exception as e:
|
||||
error_message = f"对不起, 发生错误: {str(e)}"
|
||||
add_message("assistant", error_message)
|
||||
st.error(error_message)
|
||||
|
||||
|
||||
####################################################################
|
||||
# Session selector
|
||||
####################################################################
|
||||
session_selector_widget(agentic_rag_agent, model_id)
|
||||
#rename_session_widget(agentic_rag_agent)
|
||||
|
||||
####################################################################
|
||||
# About section
|
||||
####################################################################
|
||||
#about_widget()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user