完善整个示例,并增加查询测试模式

This commit is contained in:
2025-04-10 07:18:52 +08:00
parent 211db332c0
commit 6356f53ef7
4 changed files with 238 additions and 105 deletions
+13 -10
View File
@@ -1,3 +1,4 @@
import re
from typing import Optional
from dotenv import load_dotenv
@@ -32,21 +33,22 @@ def initialize_agent(model_id: str, session_id: Optional[str] = None):
if session_id is None:
session_id = st.session_state.get("agentic_rag_agent_session_id")
agent = st.session_state.get("agentic_rag_agent") if "agentic_rag_agent" in st.session_state else None
try:
if (
not "agentic_rag_agent" in st.session_state
or st.session_state.get("agentic_rag_agent") is None
or st.session_state.get("current_model") != model_id
if (st.session_state.get("current_model") != model_id
or agent is None
or agent.session_id != session_id
):
logger.info(f"---*--- Creating {model_id} Agent ---*---")
#agent = get_agentic_rag_agent(
# model_id=model_id,
# session_id=session_id,
#)
agent = get_workflow(
agent = get_agentic_rag_agent(
model_id=model_id,
session_id=session_id,
)
#agent = get_workflow(
# model_id=model_id,
# session_id=session_id,
#)
st.session_state["agentic_rag_agent"] = agent
st.session_state["current_model"] = model_id
else:
@@ -120,6 +122,7 @@ def main():
with st.chat_message(message["role"]):
#if "tool_calls" in message and message["tool_calls"]:
# display_tool_calls(st.empty(), message["tool_calls"])
_content = re.sub(r'<context>\s*{.*?}\s*</context>', '', _content, flags=re.DOTALL)
st.markdown(_content)
with lastMsgContainer:
@@ -136,7 +139,7 @@ def main():
response = ""
try:
# Run the agent and stream the response
run_response = agentic_rag_agent.run(message=question, stream=False)
run_response = agentic_rag_agent.run(message=question, stream=True)
for _resp_chunk in run_response:
# Display tool calls if available
#if _resp_chunk.tools and len(_resp_chunk.tools) > 0: