半成品,为了保存记录,请勿使用。

This commit is contained in:
2025-04-09 20:44:26 +08:00
parent b6b697efdb
commit 211db332c0
3 changed files with 320 additions and 52 deletions
+53 -7
View File
@@ -1,10 +1,12 @@
from typing import Optional
from dotenv import load_dotenv
# 加载.env文件
load_dotenv()
import threading
import nest_asyncio
from agentic_rag import get_agentic_rag_agent
from agentic_rag import get_agentic_rag_agent, get_workflow
from agno.utils.log import logger
from ui import (
initialize_ui,
@@ -14,7 +16,6 @@ from ui import (
)
from utils import (
add_message,
session_selector_widget,
)
import streamlit as st
from extra_streamlit_components import CookieManager
@@ -24,9 +25,13 @@ nest_asyncio.apply()
lock = threading.Lock()
def initialize_agent(model_id: str):
def initialize_agent(model_id: str, session_id: Optional[str] = None):
"""Initialize or retrieve the Agentic RAG."""
lock.acquire()
if session_id is None:
session_id = st.session_state.get("agentic_rag_agent_session_id")
try:
if (
not "agentic_rag_agent" in st.session_state
@@ -34,9 +39,13 @@ def initialize_agent(model_id: str):
or st.session_state.get("current_model") != model_id
):
logger.info(f"---*--- Creating {model_id} Agent ---*---")
agent = get_agentic_rag_agent(
#agent = get_agentic_rag_agent(
# model_id=model_id,
# session_id=session_id,
#)
agent = get_workflow(
model_id=model_id,
session_id=st.session_state.get("agentic_rag_agent_session_id"),
session_id=session_id,
)
st.session_state["agentic_rag_agent"] = agent
st.session_state["current_model"] = model_id
@@ -127,7 +136,7 @@ def main():
response = ""
try:
# Run the agent and stream the response
run_response = agentic_rag_agent.run(question, stream=True)
run_response = agentic_rag_agent.run(message=question, stream=False)
for _resp_chunk in run_response:
# Display tool calls if available
#if _resp_chunk.tools and len(_resp_chunk.tools) > 0:
@@ -148,7 +157,44 @@ def main():
####################################################################
# Session selector
####################################################################
session_selector_widget(agentic_rag_agent, model_id)
if agentic_rag_agent.storage:
agent_sessions = agentic_rag_agent.storage.get_all_sessions()
# Get session names if available, otherwise use IDs
session_options = []
for session in agent_sessions:
session_id = session.session_id
session_name = (
session.session_data.get("session_name", None)
if session.session_data
else None
)
display_name = session_name if session_name else session_id
session_options.append({"id": session_id, "display": display_name})
# Display session selector
# selected_session = st.sidebar.selectbox(
# "会话",
# options=[s["display"] for s in session_options],
# key="session_selector",
# )
# Find the selected session ID
# selected_session_id = next(
# s["id"] for s in session_options if s["display"] == selected_session
# )
if len(session_options) > 0:
selected_session_id = session_options[0]["id"]
if ('agentic_rag_agent_session_id' in st.session_state and
selected_session_id is not None and
st.session_state["agentic_rag_agent_session_id"] != selected_session_id):
logger.info(
f"---*--- Loading {model_id} run: {selected_session_id} ---*---"
)
st.session_state["agentic_rag_agent"] = initialize_agent(
model_id=model_id,
session_id=selected_session_id,
)
st.rerun()
#rename_session_widget(agentic_rag_agent)
####################################################################