迁入项目
@@ -0,0 +1,36 @@
|
||||
.yarn
|
||||
.env
|
||||
*.env
|
||||
.idea
|
||||
.venv
|
||||
.vscode
|
||||
__pycache__
|
||||
wiki_doc.json
|
||||
Data/
|
||||
|
||||
**/node_modules/**
|
||||
.DS_Store
|
||||
*.tsbuildinfo
|
||||
|
||||
dist
|
||||
.turbo
|
||||
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
lerna-debug.log*
|
||||
.pnpm-debug.log*
|
||||
|
||||
# Rush temporary files
|
||||
common/deploy/
|
||||
common/temp/
|
||||
common/autoinstallers/*/.npmrc
|
||||
**/.rush/temp/
|
||||
*.lock
|
||||
*.log
|
||||
*.chunks.jsonl
|
||||
|
||||
# mise
|
||||
mise.toml
|
||||
@@ -1,3 +1,96 @@
|
||||
# agno_agentic_rag
|
||||
# Agentic RAG Agent
|
||||
|
||||
**Agentic RAG Agent** is a chat application that combines models with retrieval-augmented generation.
|
||||
It allows users to ask questions based on custom knowledge bases, documents, and web data, retrieve context-aware answers, and maintain chat history across sessions.
|
||||
|
||||
> Note: Fork and clone this repository if needed
|
||||
|
||||
### 1. Create a virtual environment
|
||||
|
||||
```shell
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
### 2. Install dependencies
|
||||
|
||||
```shell
|
||||
pip install -r cookbook/examples/apps/agentic_rag/requirements.txt
|
||||
```
|
||||
|
||||
### 3. Configure API Keys
|
||||
|
||||
Required:
|
||||
```bash
|
||||
export OPENAI_API_KEY=your_openai_key_here
|
||||
```
|
||||
|
||||
Optional (for additional models):
|
||||
```bash
|
||||
export ANTHROPIC_API_KEY=your_anthropic_key_here
|
||||
export GOOGLE_API_KEY=your_google_key_here
|
||||
export GROQ_API_KEY=your_groq_key_here
|
||||
```
|
||||
|
||||
### 4. Run PgVector
|
||||
|
||||
> Install [docker desktop](https://docs.docker.com/desktop/install/mac-install/) first.
|
||||
|
||||
- Run using a helper script
|
||||
|
||||
```shell
|
||||
./cookbook/scripts/run_pgvector.sh
|
||||
```
|
||||
|
||||
- OR run using the docker run command
|
||||
|
||||
```shell
|
||||
docker run -d \
|
||||
-e POSTGRES_DB=ai \
|
||||
-e POSTGRES_USER=ai \
|
||||
-e POSTGRES_PASSWORD=ai \
|
||||
-e PGDATA=/var/lib/postgresql/data/pgdata \
|
||||
-v pgvolume:/var/lib/postgresql/data \
|
||||
-p 5532:5432 \
|
||||
--name pgvector \
|
||||
agnohq/pgvector:16
|
||||
```
|
||||
|
||||
### 5. Run Agentic RAG App
|
||||
|
||||
```shell
|
||||
streamlit run cookbook/examples/apps/agentic_rag/app.py
|
||||
```
|
||||
|
||||
## 🔧 Customization
|
||||
|
||||
### Model Selection
|
||||
|
||||
The application supports multiple model providers:
|
||||
- OpenAI (o3-min[requirements.in](requirements.in)i, gpt-4o)
|
||||
- Anthropic (claude-3-5-sonnet)
|
||||
- Google (gemini-2.0-flash-exp)
|
||||
- Groq (llama-3.3-70b-versatile)
|
||||
|
||||
### How to Use
|
||||
- Open [localhost:8501](http://localhost:8501) in your browser.
|
||||
- Upload documents or provide URLs (websites, csv, txt, and PDFs) to build a knowledge base.
|
||||
- Enter questions in the chat interface and get context-aware answers.
|
||||
- The app can also answer question using duckduckgo search without any external documents added.
|
||||
|
||||
### Troubleshooting
|
||||
- **Docker Connection Refused**: Ensure `pgvector` containers are running (`docker ps`).
|
||||
- **OpenAI API Errors**: Verify that the `OPENAI_API_KEY` is set and valid.
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
For more detailed information:
|
||||
- [Agno Documentation](https://docs.agno.com)
|
||||
- [Streamlit Documentation](https://docs.streamlit.io)
|
||||
|
||||
## 🤝 Support
|
||||
|
||||
Need help? Join our [Discord community](https://agno.link/discord)
|
||||
|
||||
|
||||
|
||||
agno_agentic_rag
|
||||
@@ -0,0 +1,239 @@
|
||||
"""🤖 Agentic RAG Agent - 您的AI知识助手!
|
||||
|
||||
这个高级示例展示了如何构建一个复杂的RAG(检索增强生成)系统,
|
||||
利用向量搜索和LLMs从任何知识库中提供深入见解。
|
||||
|
||||
该代理可以:
|
||||
- 处理和理解来自多个来源的文档(PDF、网站、文本文件)
|
||||
- 使用向量嵌入构建可搜索的知识库
|
||||
- 跨会话维护对话上下文和记忆
|
||||
- 为其响应提供相关引用和来源
|
||||
- 生成摘要并提取关键见解
|
||||
- 回答后续问题和澄清
|
||||
|
||||
可以尝试的示例查询:
|
||||
- "本文档的关键要点是什么?"
|
||||
- "你能总结主要论点和支持证据吗?"
|
||||
- "有哪些重要地统计数据和发现?"
|
||||
- "这与[主题X]有什么关系?"
|
||||
- "这个分析有哪些局限性或空白?"
|
||||
- "你能更详细地解释[概念X]吗?"
|
||||
- "其他来源支持或反驳这些主张吗?"
|
||||
|
||||
该代理使用:
|
||||
- 向量相似性搜索进行相关文档检索
|
||||
- 对话记忆用于上下文响应
|
||||
- 引用跟踪用于来源归属
|
||||
- 动态知识库更新
|
||||
|
||||
查看README了解如何运行应用程序。
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
from agno.document.chunking.document import DocumentChunking
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
import os
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from agno.agent import Agent, AgentMemory
|
||||
from agno.embedder.openai import OpenAIEmbedder
|
||||
from agno.knowledge import AgentKnowledge
|
||||
from agno.memory.classifier import MemoryClassifier
|
||||
from agno.memory.db.sqlite import SqliteMemoryDb
|
||||
from agno.memory.manager import MemoryManager
|
||||
from agno.memory.summarizer import MemorySummarizer
|
||||
from agno.models.openai import OpenAIChat
|
||||
from agno.storage.json import JsonStorage
|
||||
from agno.vectordb.lancedb import LanceDb
|
||||
from agno.vectordb.search import SearchType
|
||||
from agno.document.reader.json_reader import JSONReader
|
||||
from agno.document.reader.csv_reader import CSVReader
|
||||
from agno.document.reader.pdf_reader import PDFReader
|
||||
from agno.document.reader.text_reader import TextReader
|
||||
|
||||
#db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
|
||||
|
||||
api_key = os.getenv("API_KEY")
|
||||
embedding_model = os.getenv("EMBEDDING_MODEL")
|
||||
embedding_baseUrl = os.getenv("EMBEDDING_BASE_URL")
|
||||
model_baseUrl = os.getenv("MODEL_BASE_URL")
|
||||
|
||||
cwd = Path(__file__).parent.resolve()
|
||||
tmp = cwd.joinpath("tmp")
|
||||
if not tmp.exists():
|
||||
tmp.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
work_context = ""
|
||||
|
||||
def get_sofeware_work_context() -> str:
|
||||
"""返回当前用户使用软件时所处环境."""
|
||||
global work_context
|
||||
return work_context
|
||||
|
||||
def set_sofeware_work_context(context : str):
|
||||
global work_context
|
||||
work_context = context
|
||||
|
||||
def get_reader(file_type: str):
|
||||
"""Return appropriate reader based on file type."""
|
||||
readers = {
|
||||
"pdf": PDFReader(),
|
||||
"csv": CSVReader(),
|
||||
"txt": TextReader(),
|
||||
"md": TextReader(),
|
||||
"json": JSONReader(),
|
||||
}
|
||||
return readers.get(file_type.lower(), None)
|
||||
|
||||
def get_model_by_provider(provider: str, model_name: str):
|
||||
"""根据提供商获取对应的模型实例"""
|
||||
if provider == "openai":
|
||||
model = OpenAIChat(id=model_name, base_url=model_baseUrl, api_key=api_key)
|
||||
model.role_map = {
|
||||
"system": "system",
|
||||
"user": "user",
|
||||
"assistant": "assistant",
|
||||
"tool": "tool",
|
||||
"model": "assistant",
|
||||
}
|
||||
return model
|
||||
# elif provider == "google":
|
||||
# return Gemini(id=model_name)
|
||||
# elif provider == "anthropic":
|
||||
# return Claude(id=model_name)
|
||||
# elif provider == "groq":
|
||||
# return Groq(id=model_name)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model provider: {provider}")
|
||||
|
||||
def initialize_memory(model) -> AgentMemory:
|
||||
"""初始化并返回配置好的AgentMemory实例"""
|
||||
return AgentMemory(
|
||||
db=SqliteMemoryDb(
|
||||
table_name="agent_memory",
|
||||
db_file=os.getenv("MEMORY_DB_FILE", "tmp/agent_memory.db"),
|
||||
), # 在Sqlite中持久化记忆
|
||||
classifier=MemoryClassifier(model=model),
|
||||
summarizer=MemorySummarizer(model=model),
|
||||
manager=MemoryManager(model=model),
|
||||
create_user_memories=True, # 存储用户偏好
|
||||
#create_session_summary=True, # 存储对话摘要
|
||||
)
|
||||
|
||||
def initialize_vector_db() -> LanceDb:
|
||||
"""初始化并返回配置好的LanceDb实例"""
|
||||
return LanceDb(
|
||||
table_name="recipes",
|
||||
uri=os.getenv("VECTOR_DB_PATH", "tmp/lancedb"),
|
||||
search_type=SearchType.hybrid,
|
||||
embedder=OpenAIEmbedder(id=embedding_model, base_url=embedding_baseUrl, api_key=api_key)
|
||||
)
|
||||
|
||||
def initialize_knowledge_base() -> AgentKnowledge:
|
||||
"""初始化并返回配置好的AgentKnowledge实例"""
|
||||
return AgentKnowledge(
|
||||
vector_db=initialize_vector_db(),
|
||||
num_documents=3, # 检索3个最相关的文档
|
||||
chunking_strategy=DocumentChunking(
|
||||
chunk_size=500,
|
||||
overlap=50,
|
||||
), # 固定大小分块
|
||||
optimize_on=1000, # 每1000条数据进行向量优化
|
||||
reader=TextReader(), # 默认文本读取器
|
||||
)
|
||||
|
||||
|
||||
def get_agentic_rag_agent(
|
||||
model_id: str = "openai:gpt-4o",
|
||||
user_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
debug_mode: bool = True,
|
||||
) -> Agent:
|
||||
"""获取一个带有记忆功能的Agentic RAG代理。"""
|
||||
# 解析模型提供商和名称
|
||||
provider, model_name = model_id.split(":")
|
||||
model = get_model_by_provider(provider, model_name)
|
||||
|
||||
# 初始化记忆系统
|
||||
memory = initialize_memory(model)
|
||||
|
||||
# 初始化知识库
|
||||
knowledge_base = initialize_knowledge_base()
|
||||
|
||||
description="""
|
||||
你是一个智能助手,专门为[博微配网计价通D3软件]提供使用支持。你的任务是帮助用户理解和使用这个复杂的配电网工程造价软件系统。
|
||||
|
||||
软件特点
|
||||
1.多页面架构:软件由多个功能页面组成
|
||||
2.复杂控件布局:每个页面包含多种控件(如列表控件、TAB控件、按钮等)
|
||||
3.业务对象丰富:涉及"取费表"、"项目划分"、"工程量"等多种业务对象
|
||||
4.操作多样:支持"添加"、"修改"、"删除"、"导入"、"导出"等多种操作
|
||||
"""
|
||||
|
||||
instructions="""
|
||||
1. 理解用户问题
|
||||
用户正在使用软件过程中遇到问题,向您请求帮助
|
||||
用户所处环境如下:
|
||||
{sofeware_work_context}
|
||||
只从用户问题识别中提到的业务对象(如"如何设置取费费率"→"取费表")
|
||||
只从用户问题识别业务对象的属性字段(如"如何设置取费费率"→"费率")
|
||||
只从用户问题识别用户想要执行的操作(如"如何设置取费费率"→"设置")
|
||||
判断问题类型(功能入口、操作步骤、错误处理等)
|
||||
2. 改写问题
|
||||
将用户问题改写为包含以下要素的标准查询:
|
||||
[问题类型] : [操作类型] + [业务对象] + [属性]
|
||||
[属性]只有用户问题中明确包含才改写,否则为未知。
|
||||
[问题类型]、[操作类型]、[业务对象]为必须输入,如果缺少任一个都需追问用户补全才能进入下一步。
|
||||
示例:
|
||||
原始问题:"如何设置取费费率?"
|
||||
改写后:"操作步骤 : 设置 - 取费 - 费率"
|
||||
3.搜索知识库
|
||||
必须始终使用工具 search_knowledge_base 来搜索知识库
|
||||
在回应前彻底分析所有返回的文档
|
||||
如果返回多个文档,需连贯地综合信息
|
||||
4. 上下文管理:
|
||||
使用工具 get_chat_history 保持对话连续性
|
||||
相关时引用之前的交互
|
||||
记录用户偏好和之前的澄清
|
||||
5. 结果呈现要求
|
||||
以 makedown 格式输出,注意换行和排版
|
||||
避免使用'根据我的知识'或'取决于信息'等模糊表述
|
||||
7. 特殊情况处理
|
||||
如果问题不明确,可以反问请求澄清
|
||||
如果知识库搜索无结果,则直接明确回复不知道
|
||||
对于错误提示,先解释含义再直接回复无法解决
|
||||
"""
|
||||
|
||||
# 创建代理
|
||||
agentic_rag_agent: Agent = Agent(
|
||||
name="博微软件AI助手",
|
||||
session_id=session_id, # 跟踪会话ID以实现持久对话
|
||||
user_id=user_id,
|
||||
model=model,
|
||||
storage=JsonStorage(dir_path=os.getenv("SESSION_STORAGE_PATH", "tmp/agent_sessions_json")), # 持久化会话数据
|
||||
memory=memory, # 为代理添加记忆功能
|
||||
knowledge=knowledge_base, # 添加知识库
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
context={"sofeware_work_context": get_sofeware_work_context},
|
||||
add_context=True,
|
||||
search_knowledge=True, # 此设置赋予模型搜索知识库信息的工具
|
||||
read_chat_history=True, # 此设置赋予模型获取聊天历史的工具
|
||||
#tools=[DuckDuckGoTools()],
|
||||
markdown=True, # 此设置告诉模型以markdown格式格式化消息
|
||||
# add_chat_history_to_messages=True,
|
||||
show_tool_calls=True,
|
||||
add_history_to_messages=True, # 将聊天历史添加到消息中
|
||||
add_datetime_to_instructions=True,
|
||||
add_name_to_instructions=True,
|
||||
debug_mode=debug_mode,
|
||||
read_tool_call_history=True,
|
||||
num_history_responses=3,
|
||||
save_response_to_file=str(tmp.joinpath("{message}.md")),
|
||||
)
|
||||
|
||||
return agentic_rag_agent
|
||||
@@ -0,0 +1,160 @@
|
||||
from dotenv import load_dotenv
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
import threading
|
||||
import nest_asyncio
|
||||
|
||||
from agentic_rag import get_agentic_rag_agent
|
||||
from agno.utils.log import logger
|
||||
from ui import (
|
||||
initialize_ui,
|
||||
show_header,
|
||||
get_modul_option,
|
||||
show_tabs,
|
||||
)
|
||||
from utils import (
|
||||
add_message,
|
||||
session_selector_widget,
|
||||
)
|
||||
import streamlit as st
|
||||
from extra_streamlit_components import CookieManager
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
def initialize_agent(model_id: str):
|
||||
"""Initialize or retrieve the Agentic RAG."""
|
||||
lock.acquire()
|
||||
try:
|
||||
if (
|
||||
not "agentic_rag_agent" in st.session_state
|
||||
or st.session_state.get("agentic_rag_agent") is None
|
||||
or st.session_state.get("current_model") != model_id
|
||||
):
|
||||
logger.info(f"---*--- Creating {model_id} Agent ---*---")
|
||||
agent = get_agentic_rag_agent(
|
||||
model_id=model_id,
|
||||
session_id=st.session_state.get("agentic_rag_agent_session_id"),
|
||||
)
|
||||
st.session_state["agentic_rag_agent"] = agent
|
||||
st.session_state["current_model"] = model_id
|
||||
else:
|
||||
agent = st.session_state.get("agentic_rag_agent")
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
# Load Agent Session
|
||||
try:
|
||||
st.session_state["agentic_rag_agent_session_id"] = agent.load_session()
|
||||
except Exception:
|
||||
st.warning("无法创建Agent会话,请确认数据库是否在运行?")
|
||||
return agent
|
||||
|
||||
def main():
|
||||
initialize_ui()
|
||||
#st.write("")
|
||||
|
||||
# 仅在首次运行时初始化 CookieManager
|
||||
if "cookie_manager" not in st.session_state:
|
||||
st.session_state.cookie_manager = CookieManager()
|
||||
|
||||
# 获取实例(后续直接使用缓存)
|
||||
cookie_manager = st.session_state.cookie_manager
|
||||
|
||||
# 检查并设置 session_id
|
||||
if (not 'agentic_rag_agent_session_id' in st.session_state
|
||||
and cookie_manager.get(cookie='agentic_rag_agent_session_id') is not None):
|
||||
st.session_state["agentic_rag_agent_session_id"] = cookie_manager.get(cookie='agentic_rag_agent_session_id')
|
||||
|
||||
model_id = get_modul_option(st.session_state["model_id"] if "model_id" in st.session_state else 0)
|
||||
#model_id = show_model_selector()
|
||||
# Initialize Agent
|
||||
agentic_rag_agent = initialize_agent(model_id)
|
||||
|
||||
if 'agentic_rag_agent_session_id' in st.session_state:
|
||||
cookie_manager.set('agentic_rag_agent_session_id', st.session_state.get("agentic_rag_agent_session_id"))
|
||||
|
||||
# Load runs from memory
|
||||
agent_runs = agentic_rag_agent.memory.runs
|
||||
if len(agent_runs) > 0:
|
||||
logger.debug("加载历史记录")
|
||||
st.session_state["messages"] = []
|
||||
for _run in agent_runs:
|
||||
if _run.message is not None:
|
||||
add_message(_run.message.role, _run.message.content)
|
||||
if _run.response is not None:
|
||||
add_message("assistant", _run.response.content, _run.response.tools)
|
||||
else:
|
||||
logger.debug("没有找到历史记录")
|
||||
st.session_state["messages"] = []
|
||||
|
||||
chatContainer = st.sidebar.container()
|
||||
lastMsgContainer = st.sidebar.container()
|
||||
|
||||
# Chat input
|
||||
if prompt := st.sidebar.chat_input("👋 问我任何问题!"):
|
||||
add_message("user", prompt)
|
||||
|
||||
# Display UI
|
||||
show_header()
|
||||
show_tabs()
|
||||
|
||||
#show_chat_history(agentic_rag_agent)
|
||||
|
||||
with chatContainer:
|
||||
for message in st.session_state["messages"]:
|
||||
if message["role"] in ["user", "assistant"]:
|
||||
_content = message["content"]
|
||||
if _content is not None:
|
||||
with st.chat_message(message["role"]):
|
||||
#if "tool_calls" in message and message["tool_calls"]:
|
||||
# display_tool_calls(st.empty(), message["tool_calls"])
|
||||
st.markdown(_content)
|
||||
|
||||
with lastMsgContainer:
|
||||
last_message = (st.session_state["messages"][-1] if st.session_state["messages"] else None)
|
||||
if last_message and last_message.get("role") == "user":
|
||||
question = last_message["content"]
|
||||
#with st.chat_message("user"):
|
||||
# st.markdown(question)
|
||||
with st.chat_message("assistant"):
|
||||
# Create container for tool calls
|
||||
tool_calls_container = st.empty()
|
||||
resp_container = st.empty()
|
||||
with st.spinner("🤔 思考中..."):
|
||||
response = ""
|
||||
try:
|
||||
# Run the agent and stream the response
|
||||
run_response = agentic_rag_agent.run(question, stream=True)
|
||||
for _resp_chunk in run_response:
|
||||
# Display tool calls if available
|
||||
#if _resp_chunk.tools and len(_resp_chunk.tools) > 0:
|
||||
# display_tool_calls(tool_calls_container, _resp_chunk.tools)
|
||||
|
||||
# Display response
|
||||
if _resp_chunk.content is not None:
|
||||
response += _resp_chunk.content
|
||||
resp_container.markdown(response)
|
||||
|
||||
add_message("assistant", response, agentic_rag_agent.run_response.tools)
|
||||
except Exception as e:
|
||||
error_message = f"对不起, 发生错误: {str(e)}"
|
||||
add_message("assistant", error_message)
|
||||
st.error(error_message)
|
||||
|
||||
|
||||
####################################################################
|
||||
# Session selector
|
||||
####################################################################
|
||||
session_selector_widget(agentic_rag_agent, model_id)
|
||||
#rename_session_widget(agentic_rag_agent)
|
||||
|
||||
####################################################################
|
||||
# About section
|
||||
####################################################################
|
||||
#about_widget()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,121 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, override
|
||||
|
||||
import httpx
|
||||
from agno.document import Document
|
||||
from agno.reranker.base import Reranker
|
||||
from agno.utils.log import logger
|
||||
import requests
|
||||
from openai import OpenAIError, Omit
|
||||
|
||||
|
||||
class CustomReranker(Reranker):
|
||||
model: str = "BAAI/bge-reranker-v2-m3"
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[Union[str, httpx.URL]] = None
|
||||
top_n: Optional[int] = None
|
||||
return_documents: Optional[bool] = None
|
||||
max_chunks_per_doc: Optional[int] = 1024
|
||||
overlap_tokens: Optional[int] = 80
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | httpx.URL | None = None,
|
||||
model: str | None = None,
|
||||
top_n: int | None = None,
|
||||
return_documents: bool | None = None,
|
||||
max_chunks_per_doc:bool | None = None,
|
||||
overlap_tokens: int | None = None,
|
||||
):
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if api_key is None:
|
||||
raise OpenAIError(
|
||||
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
|
||||
)
|
||||
self.api_key = api_key
|
||||
|
||||
if base_url is None:
|
||||
base_url = os.environ.get("OPENAI_BASE_URL")
|
||||
if base_url is None:
|
||||
base_url = f"https://api.openai.com/v1"
|
||||
self.base_url = base_url
|
||||
self.model = model or self.model
|
||||
self.return_documents = return_documents or self.return_documents
|
||||
self.top_n = top_n or self.top_n
|
||||
self.overlap_tokens = overlap_tokens or self.overlap_tokens
|
||||
self.max_chunks_per_doc = max_chunks_per_doc or self.max_chunks_per_doc
|
||||
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
@override
|
||||
def auth_headers(self) -> dict[str, str]:
|
||||
api_key = self.api_key
|
||||
return {"Authorization": f"Bearer {api_key}"}
|
||||
|
||||
@property
|
||||
@override
|
||||
def default_headers(self) -> dict[str, str | Omit]:
|
||||
return {
|
||||
**super().default_headers,
|
||||
"X-Stainless-Async": "false",
|
||||
"OpenAI-Organization": self.organization if self.organization is not None else Omit(),
|
||||
"OpenAI-Project": self.project if self.project is not None else Omit(),
|
||||
**self._custom_headers,
|
||||
}
|
||||
def _rerank(self, query: str, documents: List[Document]) -> List[Document]:
|
||||
# Validate input documents and top_n
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
top_n = self.top_n
|
||||
if top_n and not (0 < top_n):
|
||||
logger.warning(f"top_n should be a positive integer, got {self.top_n}, setting top_n to None")
|
||||
top_n = None
|
||||
|
||||
compressed_docs: list[Document] = []
|
||||
_docs = [doc.content for doc in documents]
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"documents": _docs,
|
||||
"top_n": top_n,
|
||||
"return_documents": self.return_documents,
|
||||
"max_chunks_per_doc": self.max_chunks_per_doc,
|
||||
"overlap_tokens": self.overlap_tokens,
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
url = f"{self.base_url}/v1/rerank"
|
||||
response = requests.request("POST", url, json=payload, headers=headers)
|
||||
print(response.text)
|
||||
#response = self.client.rerank(query=query, documents=_docs, model=self.model)
|
||||
for r in response.results:
|
||||
doc = documents[r.index]
|
||||
doc.reranking_score = r.relevance_score
|
||||
compressed_docs.append(doc)
|
||||
|
||||
# Order by relevance score
|
||||
compressed_docs.sort(
|
||||
key=lambda x: x.reranking_score if x.reranking_score is not None else float("-inf"),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# Limit to top_n if specified
|
||||
if top_n:
|
||||
compressed_docs = compressed_docs[:top_n]
|
||||
|
||||
return compressed_docs
|
||||
|
||||
def rerank(self, query: str, documents: List[Document]) -> List[Document]:
|
||||
try:
|
||||
return self._rerank(query=query, documents=documents)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reranking documents: {e}. Returning original documents")
|
||||
return documents
|
||||
@@ -0,0 +1,12 @@
|
||||
#!/bin/bash
|
||||
|
||||
############################################################################
|
||||
# Generate requirements.txt from requirements.in
|
||||
############################################################################
|
||||
|
||||
echo "Generating requirements.txt"
|
||||
|
||||
CURR_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
|
||||
UV_CUSTOM_COMPILE_COMMAND="./generate_requirements.sh" \
|
||||
uv pip compile ${CURR_DIR}/requirements.in --no-cache --upgrade -o ${CURR_DIR}/requirements.txt
|
||||
@@ -0,0 +1,39 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from agno.document import Document
|
||||
from agno.utils.log import logger
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from agentic_rag import initialize_knowledge_base, get_reader
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
import os
|
||||
|
||||
def main():
|
||||
print("Hello from agno-agentic-rag!")
|
||||
# 从.env加载知识库来源目录并初始化知识库
|
||||
load_knowledge = os.getenv("LOAD_KNOWLEDGE", "false").lower() == "true"
|
||||
knowledge_source_dir = os.getenv("KNOWLEDGE_SOURCE_DIR")
|
||||
if load_knowledge and knowledge_source_dir and os.path.exists(knowledge_source_dir):
|
||||
# 初始化知识库
|
||||
knowledge_base = initialize_knowledge_base()
|
||||
|
||||
logger.info(f"加载知识库: {knowledge_source_dir}")
|
||||
for root, _, files in os.walk(knowledge_source_dir):
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
file_ext = os.path.splitext(file)[1][1:] # 获取文件扩展名
|
||||
reader = get_reader(file_ext)
|
||||
if reader:
|
||||
try:
|
||||
filePath = Path(file_path)
|
||||
docs: List[Document] = reader.read(filePath)
|
||||
knowledge_base.load_documents(docs, upsert=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法加载文档 {file_path}: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,35 @@
|
||||
[project]
|
||||
name = "agno-agentic-rag"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"agno>=1.2.8",
|
||||
"distro>=1.9.0",
|
||||
"lxml>=5.3.1",
|
||||
"lancedb>=0.21.2",
|
||||
"nest-asyncio>=1.6.0",
|
||||
"streamlit>=1.44.1",
|
||||
"openai",
|
||||
"extra-streamlit-components>=0.1.71",
|
||||
"sqlalchemy>=2.0.38",
|
||||
"websockets>=14.2",
|
||||
"tqdm>=4.67.1",
|
||||
"google-auth>=2.38.0",
|
||||
"anthropic>=0.45.2",
|
||||
"primp>=0.12.1",
|
||||
"groq>=0.18.0",
|
||||
"aiofiles",
|
||||
"pypdf",
|
||||
"beautifulsoup4",
|
||||
"tantivy>=0.22.2",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
url = "https://pypi.python.org/simple/"
|
||||
default = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
#url = "https://pypi.python.org/simple/"
|
||||
url = "https://mirrors.aliyun.com/pypi/simple"
|
||||
default = false
|
||||
@@ -0,0 +1,9 @@
|
||||
agno
|
||||
anthropic
|
||||
duckduckgo_search
|
||||
google-genai
|
||||
groq
|
||||
nest_asyncio
|
||||
openai
|
||||
sqlalchemy
|
||||
streamlit
|
||||
@@ -0,0 +1,223 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# ./generate_requirements.sh
|
||||
agno==1.1.1
|
||||
# via -r cookbook/examples/apps/agentic_rag/requirements.in
|
||||
altair==5.5.0
|
||||
# via streamlit
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anthropic>=0.45.2
|
||||
# via -r cookbook/examples/apps/agentic_rag/requirements.in
|
||||
anyio==4.8.0
|
||||
# via
|
||||
# anthropic
|
||||
# groq
|
||||
# httpx
|
||||
# openai
|
||||
attrs==25.1.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
blinker==1.9.0
|
||||
# via streamlit
|
||||
cachetools==5.5.1
|
||||
# via
|
||||
# google-auth
|
||||
# streamlit
|
||||
certifi==2025.1.31
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
charset-normalizer==3.4.1
|
||||
# via requests
|
||||
click==8.1.8
|
||||
# via
|
||||
# duckduckgo-search
|
||||
# streamlit
|
||||
# typer
|
||||
distro>=1.9.0
|
||||
# via
|
||||
# anthropic
|
||||
# groq
|
||||
# openai
|
||||
docstring-parser==0.16
|
||||
# via agno
|
||||
duckduckgo-search>=7.3.2
|
||||
# via -r cookbook/examples/apps/agentic_rag/requirements.in
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.44
|
||||
# via
|
||||
# agno
|
||||
# streamlit
|
||||
google-auth>=2.38.0
|
||||
# via google-genai
|
||||
google-genai>=1.2.0
|
||||
# via -r cookbook/examples/apps/agentic_rag/requirements.in
|
||||
groq>=0.18.0
|
||||
# via -r cookbook/examples/apps/agentic_rag/requirements.in
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
httpcore==1.0.7
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# agno
|
||||
# anthropic
|
||||
# groq
|
||||
# openai
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
jinja2==3.1.5
|
||||
# via
|
||||
# altair
|
||||
# pydeck
|
||||
jiter>=0.8.2
|
||||
# via
|
||||
# anthropic
|
||||
# openai
|
||||
jsonschema==4.23.0
|
||||
# via altair
|
||||
jsonschema-specifications==2024.10.1
|
||||
# via jsonschema
|
||||
lxml>=5.3.1
|
||||
# via duckduckgo-search
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
markupsafe==3.0.2
|
||||
# via jinja2
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
narwhals==1.26.0
|
||||
# via altair
|
||||
nest-asyncio==1.6.0
|
||||
# via -r cookbook/examples/apps/agentic_rag/requirements.in
|
||||
numpy==2.2.3
|
||||
# via
|
||||
# pandas
|
||||
# pydeck
|
||||
# streamlit
|
||||
openai>=1.63.0
|
||||
# via -r cookbook/examples/apps/agentic_rag/requirements.in
|
||||
packaging==24.2
|
||||
# via
|
||||
# altair
|
||||
# streamlit
|
||||
pandas==2.2.3
|
||||
# via streamlit
|
||||
pillow==11.1.0
|
||||
# via streamlit
|
||||
primp>=0.12.1
|
||||
# via duckduckgo-search
|
||||
protobuf==5.29.3
|
||||
# via streamlit
|
||||
pyarrow==19.0.0
|
||||
# via streamlit
|
||||
pyasn1>=0.6.1
|
||||
# via
|
||||
# pyasn1-modules
|
||||
# rsa
|
||||
pyasn1-modules>=0.4.1
|
||||
# via google-auth
|
||||
pydantic==2.10.6
|
||||
# via
|
||||
# agno
|
||||
# anthropic
|
||||
# google-genai
|
||||
# groq
|
||||
# openai
|
||||
# pydantic-settings
|
||||
pydantic-core==2.27.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.7.1
|
||||
# via agno
|
||||
pydeck==0.9.1
|
||||
# via streamlit
|
||||
pygments==2.19.1
|
||||
# via rich
|
||||
python-dateutil==2.9.0.post0
|
||||
# via pandas
|
||||
python-dotenv==1.0.1
|
||||
# via
|
||||
# agno
|
||||
# pydantic-settings
|
||||
python-multipart==0.0.20
|
||||
# via agno
|
||||
pytz==2025.1
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
# via agno
|
||||
referencing==0.36.2
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
requests==2.32.3
|
||||
# via
|
||||
# google-genai
|
||||
# streamlit
|
||||
rich==13.9.4
|
||||
# via
|
||||
# agno
|
||||
# streamlit
|
||||
# typer
|
||||
rpds-py==0.22.3
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
rsa>=4.9
|
||||
# via google-auth
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anthropic
|
||||
# anyio
|
||||
# groq
|
||||
# openai
|
||||
sqlalchemy>=2.0.38
|
||||
# via -r cookbook/examples/apps/agentic_rag/requirements.in
|
||||
streamlit==1.42.0
|
||||
# via -r cookbook/examples/apps/agentic_rag/requirements.in
|
||||
extra-streamlit-components>=0.1.71
|
||||
# via -r cookbook/examples/apps/agentic_rag/requirements.in
|
||||
tenacity==9.0.0
|
||||
# via streamlit
|
||||
toml==0.10.2
|
||||
# via streamlit
|
||||
tomli==2.2.1
|
||||
# via agno
|
||||
tornado==6.4.2
|
||||
# via streamlit
|
||||
tqdm>=4.67.1
|
||||
# via openai
|
||||
typer==0.15.1
|
||||
# via agno
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# agno
|
||||
# altair
|
||||
# anthropic
|
||||
# anyio
|
||||
# google-genai
|
||||
# groq
|
||||
# openai
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# sqlalchemy
|
||||
# streamlit
|
||||
# typer
|
||||
tzdata==2025.1
|
||||
# via pandas
|
||||
urllib3==2.3.0
|
||||
# via requests
|
||||
websockets>=14.2
|
||||
# via google-genai
|
||||
|
After Width: | Height: | Size: 62 KiB |
|
After Width: | Height: | Size: 126 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 73 KiB |
|
After Width: | Height: | Size: 174 KiB |
|
After Width: | Height: | Size: 83 KiB |
|
After Width: | Height: | Size: 52 KiB |
|
After Width: | Height: | Size: 64 KiB |
@@ -0,0 +1,219 @@
|
||||
from typing import List
|
||||
|
||||
import streamlit as st
|
||||
from agno.agent import Agent
|
||||
from agno.utils.log import logger
|
||||
|
||||
from agentic_rag import work_context, set_sofeware_work_context
|
||||
from utils import (
|
||||
CUSTOM_CSS,
|
||||
add_message,
|
||||
export_chat_history,
|
||||
)
|
||||
|
||||
|
||||
def initialize_ui():
|
||||
"""Initialize Streamlit UI configuration"""
|
||||
st.set_page_config(
|
||||
page_title="智能检索增强生成",
|
||||
page_icon="💎",
|
||||
layout="wide",
|
||||
initial_sidebar_state="expanded",
|
||||
)
|
||||
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
|
||||
|
||||
def show_header():
|
||||
"""Display application header"""
|
||||
st.markdown("<h1 class='main-title'>博微配网计价通D3软件 </h1>", unsafe_allow_html=True)
|
||||
#st.markdown(
|
||||
# "<p class='subtitle'>由Agno驱动的智能研究助手</p>",
|
||||
# unsafe_allow_html=True,
|
||||
#)
|
||||
|
||||
def set_current_page(page : str):
|
||||
set_sofeware_work_context({
|
||||
"软件": "博微配网计价通D3软件",
|
||||
"工程文件": "广州配网造价工程",
|
||||
"已打开页面": ["工程信息", "取费设置", "组合件", "工程量", "材机分析", "工程费用", "报表输出"],
|
||||
"当前页面": page,
|
||||
})
|
||||
|
||||
def show_tabs():
|
||||
"""Display tabs in main content area"""
|
||||
tabs = [
|
||||
{"name": "主页", "icon": "💬", "image": "static/images/主页.png"},
|
||||
{"name": "工程信息", "icon": "📚", "image": "static/images/工程信息.png"},
|
||||
{"name": "取费设置", "icon": "⚙️", "image": "static/images/取费设置.png"},
|
||||
{"name": "组合件", "icon": "⚙️", "image": "static/images/组合件.png"},
|
||||
{"name": "工程量", "icon": "⚙️", "image": "static/images/工程量.png"},
|
||||
{"name": "材机分析", "icon": "⚙️", "image": "static/images/材机分析.png"},
|
||||
{"name": "工程费用", "icon": "⚙️", "image": "static/images/工程费用.png"},
|
||||
{"name": "报表输出", "icon": "⚙️", "image": "static/images/报表输出.png"},
|
||||
]
|
||||
|
||||
tabNames = ["主页", "工程信息", "取费设置", "组合件", "工程量", "材机分析", "工程费用", "报表输出"]
|
||||
selected_tab = st.radio("页面导航:", options=tabNames, index=0, horizontal=True)
|
||||
if selected_tab in tabNames:
|
||||
set_current_page(selected_tab)
|
||||
st.image(f"static/images/{selected_tab}.png")
|
||||
#tab_objects = st.tabs([f"{tab['name']}" for tab in tabs])
|
||||
|
||||
# with tab_objects[0]:
|
||||
# st.image(tabs[0]["image"])
|
||||
# with tab_objects[1]:
|
||||
# st.image(tabs[1]["image"])
|
||||
# with tab_objects[2]:
|
||||
# st.image(tabs[2]["image"])
|
||||
# with tab_objects[3]:
|
||||
# st.image(tabs[3]["image"])
|
||||
# with tab_objects[4]:
|
||||
# st.image(tabs[4]["image"])
|
||||
# with tab_objects[5]:
|
||||
# st.image(tabs[5]["image"])
|
||||
# with tab_objects[6]:
|
||||
# st.image(tabs[6]["image"])
|
||||
# with tab_objects[7]:
|
||||
# st.image(tabs[7]["image"])
|
||||
|
||||
# for index, value in enumerate(tabs):
|
||||
# with tab_objects[index]:
|
||||
# st.image(value["image"])
|
||||
|
||||
# for index, value in enumerate(tabs):
|
||||
# if tabs[index]["name"] == selected_tab:
|
||||
# set_current_page(value["name"])
|
||||
|
||||
model_options = {
|
||||
"Qwen2.5-72B": "openai:Qwen/Qwen2.5-72B-Instruct",
|
||||
"o3-mini": "openai:o3-mini",
|
||||
"gpt-4o": "openai:gpt-4o",
|
||||
"gemini-2.0-flash-exp": "google:gemini-2.0-flash-exp",
|
||||
"claude-3-5-sonnet": "anthropic:claude-3-5-sonnet-20241022",
|
||||
"llama-3.3-70b": "groq:llama-3.3-70b-versatile",
|
||||
}
|
||||
|
||||
def get_modul_option(id: int = 0) -> str:
|
||||
"""Return the selected module option"""
|
||||
return model_options[list(model_options.keys())[id]]
|
||||
|
||||
def show_model_selector(index: int = 0) -> str:
|
||||
"""Display model selection dialog"""
|
||||
selected_model = st.selectbox(
|
||||
"选择模型",
|
||||
options=list(model_options.keys()),
|
||||
index=index,
|
||||
key="model_selector",
|
||||
)
|
||||
session.set("model_index", selected_model)
|
||||
return model_options[selected_model]
|
||||
|
||||
def show_dialog_components(agent: Agent):
|
||||
"""Display all sidebar components in a dialog"""
|
||||
with st.expander("⚙️ 设置", expanded=False):
|
||||
st.markdown("#### 📚 文档管理")
|
||||
input_url = st.text_input("添加URL到知识库")
|
||||
if (
|
||||
input_url and not prompt and not st.session_state.knowledge_base_initialized
|
||||
): # Only load if KB not initialized
|
||||
if input_url not in st.session_state.loaded_urls:
|
||||
alert = st.sidebar.info("Processing URLs...", icon="ℹ️")
|
||||
if input_url.lower().endswith(".pdf"):
|
||||
try:
|
||||
# Download PDF to temporary file
|
||||
response = requests.get(input_url, stream=True, verify=False)
|
||||
response.raise_for_status()
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".pdf", delete=False
|
||||
) as tmp_file:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
tmp_file.write(chunk)
|
||||
tmp_path = tmp_file.name
|
||||
|
||||
reader = PDFReader()
|
||||
docs: List[Document] = reader.read(tmp_path)
|
||||
|
||||
# Clean up temporary file
|
||||
os.unlink(tmp_path)
|
||||
except Exception as e:
|
||||
st.sidebar.error(f"Error processing PDF: {str(e)}")
|
||||
docs = []
|
||||
else:
|
||||
scraper = WebsiteReader(max_links=2, max_depth=1)
|
||||
docs: List[Document] = scraper.read(input_url)
|
||||
|
||||
if docs:
|
||||
agentic_rag_agent.knowledge.load_documents(docs, upsert=True)
|
||||
st.session_state.loaded_urls.add(input_url)
|
||||
st.sidebar.success("URL已添加到知识库")
|
||||
else:
|
||||
st.sidebar.error("无法处理提供的URL")
|
||||
alert.empty()
|
||||
else:
|
||||
st.sidebar.info("URL已加载到知识库")
|
||||
|
||||
# 修正缩进,使其与上下文一致
|
||||
uploaded_file = st.sidebar.file_uploader(
|
||||
"添加文档(.pdf,.csv,.json,.md或.txt)", key="file_upload"
|
||||
)
|
||||
if (
|
||||
uploaded_file and not prompt and not st.session_state.knowledge_base_initialized
|
||||
): # Only load if KB not initialized
|
||||
file_identifier = f"{uploaded_file.name}_{uploaded_file.size}"
|
||||
if file_identifier not in st.session_state.loaded_files:
|
||||
alert = st.sidebar.info("正在处理文档...", icon="ℹ️")
|
||||
file_type = uploaded_file.name.split(".")[-1].lower()
|
||||
reader = get_reader(file_type)
|
||||
if reader:
|
||||
docs = reader.read(uploaded_file)
|
||||
agentic_rag_agent.knowledge.load_documents(docs, upsert=True)
|
||||
st.session_state.loaded_files.add(file_identifier)
|
||||
st.sidebar.success(f"{uploaded_file.name}已添加到知识库")
|
||||
st.session_state.knowledge_base_initialized = True
|
||||
alert.empty()
|
||||
else:
|
||||
st.sidebar.info(f"{uploaded_file.name}已加载到知识库")
|
||||
|
||||
if st.sidebar.button("清空知识库"):
|
||||
agentic_rag_agent.knowledge.vector_db.delete()
|
||||
st.session_state.loaded_urls.clear()
|
||||
st.session_state.loaded_files.clear()
|
||||
st.session_state.knowledge_base_initialized = False # Reset initialization flag
|
||||
st.sidebar.success("知识库已清空")
|
||||
|
||||
def show_sample_questions():
|
||||
"""Display sample questions section"""
|
||||
st.markdown("#### ❓ 示例问题")
|
||||
if st.button("📝 总结"):
|
||||
add_message(
|
||||
"user",
|
||||
"你能总结一下当前知识库中的内容吗(使用`search_knowledge_base`工具)?",
|
||||
)
|
||||
|
||||
def show_utility_buttons():
|
||||
"""Display utility buttons section"""
|
||||
st.markdown("#### 🛠️ 工具")
|
||||
col1, col2 = st.columns([1, 1])
|
||||
with col1:
|
||||
if st.button("🔄 新对话", use_container_width=True):
|
||||
restart_agent()
|
||||
with col2:
|
||||
if st.download_button(
|
||||
"💾 导出会话",
|
||||
export_chat_history(),
|
||||
file_name="rag_chat_history.md",
|
||||
mime="text/markdown",
|
||||
use_container_width=True,
|
||||
):
|
||||
st.sidebar.success("会话历史已导出!")
|
||||
|
||||
def show_chat_history(agent: Agent):
|
||||
pass
|
||||
|
||||
def restart_agent():
|
||||
"""Reset the agent and clear chat history"""
|
||||
logger.debug("---*--- Restarting agent ---*---")
|
||||
st.session_state["agentic_rag_agent"] = None
|
||||
st.session_state["agentic_rag_agent_session_id"] = None
|
||||
st.session_state["messages"] = []
|
||||
st.rerun()
|
||||
@@ -0,0 +1,238 @@
|
||||
from dotenv import load_dotenv
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypeVar
|
||||
|
||||
import streamlit as st
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
|
||||
from agentic_rag import get_agentic_rag_agent
|
||||
from agno.agent import Agent
|
||||
from agno.utils.log import logger
|
||||
|
||||
|
||||
def add_message(
|
||||
role: str, content: str, tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||
) -> None:
|
||||
"""Safely add a message to the session state"""
|
||||
if "messages" not in st.session_state or not isinstance(
|
||||
st.session_state["messages"], list
|
||||
):
|
||||
st.session_state["messages"] = []
|
||||
st.session_state["messages"].append(
|
||||
{"role": role, "content": content, "tool_calls": tool_calls}
|
||||
)
|
||||
|
||||
|
||||
def export_chat_history():
|
||||
"""Export chat history as markdown"""
|
||||
if "messages" in st.session_state:
|
||||
chat_text = "# Auto RAG Agent - 会话历史\n\n"
|
||||
for msg in st.session_state["messages"]:
|
||||
role = "🤖 Assistant" if msg["role"] == "agent" else "👤 User"
|
||||
chat_text += f"### {role}\n{msg['content']}\n\n"
|
||||
if msg.get("tool_calls"):
|
||||
chat_text += "#### 工具调用:\n"
|
||||
for tool in msg["tool_calls"]:
|
||||
if isinstance(tool, dict):
|
||||
tool_name = tool.get("name", "未知工具")
|
||||
else:
|
||||
tool_name = getattr(tool, "name", "未知工具")
|
||||
chat_text += f"- {tool_name}\n"
|
||||
return chat_text
|
||||
return ""
|
||||
|
||||
|
||||
def display_tool_calls(tool_calls_container, tools):
|
||||
"""Display tool calls in a streamlit container with expandable sections.
|
||||
|
||||
Args:
|
||||
tool_calls_container: Streamlit container to display the tool calls
|
||||
tools: List of tool call dictionaries containing name, args, content, and metrics
|
||||
"""
|
||||
with tool_calls_container.container():
|
||||
for tool_call in tools:
|
||||
_tool_name = tool_call.get("tool_name")
|
||||
_tool_args = tool_call.get("tool_args")
|
||||
_content = tool_call.get("content")
|
||||
_metrics = tool_call.get("metrics")
|
||||
|
||||
with st.sidebar.expander(f"🛠️ {_tool_name.replace('_', ' ').title()}", expanded=False):
|
||||
if isinstance(_tool_args, dict) and "query" in _tool_args:
|
||||
st.sidebar.code(_tool_args["query"], language="sql")
|
||||
|
||||
if _tool_args and _tool_args != {"query": None}:
|
||||
st.sidebar.markdown("**参数:**")
|
||||
st.sidebar.json(_tool_args)
|
||||
|
||||
if _content:
|
||||
st.sidebar.markdown("**结果:**")
|
||||
try:
|
||||
st.sidebar.json(_content)
|
||||
except Exception as e:
|
||||
st.sidebar.markdown(_content)
|
||||
|
||||
if _metrics:
|
||||
st.sidebar.markdown("**指标:**")
|
||||
try:
|
||||
st.sidebar.json(_metrics)
|
||||
except Exception as e:
|
||||
st.sidebar.markdown(_metrics)
|
||||
|
||||
def rename_session_widget(agent: Agent) -> None:
|
||||
"""Rename the current session of the agent and save to storage"""
|
||||
|
||||
container = st.sidebar.container()
|
||||
|
||||
# Initialize session_edit_mode if needed
|
||||
if "session_edit_mode" not in st.session_state:
|
||||
st.session_state.session_edit_mode = False
|
||||
|
||||
if st.sidebar.button("✎ 重命名会话"):
|
||||
st.session_state.session_edit_mode = True
|
||||
st.rerun()
|
||||
|
||||
if st.session_state.session_edit_mode:
|
||||
new_session_name = st.sidebar.text_input(
|
||||
"输入新名称:",
|
||||
value=agent.session_name,
|
||||
key="session_name_input",
|
||||
)
|
||||
if st.sidebar.button("保存", type="primary"):
|
||||
if new_session_name:
|
||||
agent.rename_session(new_session_name)
|
||||
st.session_state.session_edit_mode = False
|
||||
st.rerun()
|
||||
|
||||
|
||||
def session_selector_widget(agent: Agent, model_id: str) -> None:
|
||||
"""Display a session selector in the sidebar"""
|
||||
|
||||
if agent.storage:
|
||||
agent_sessions = agent.storage.get_all_sessions()
|
||||
# Get session names if available, otherwise use IDs
|
||||
session_options = []
|
||||
for session in agent_sessions:
|
||||
session_id = session.session_id
|
||||
session_name = (
|
||||
session.session_data.get("session_name", None)
|
||||
if session.session_data
|
||||
else None
|
||||
)
|
||||
display_name = session_name if session_name else session_id
|
||||
session_options.append({"id": session_id, "display": display_name})
|
||||
|
||||
# Display session selector
|
||||
#selected_session = st.sidebar.selectbox(
|
||||
# "会话",
|
||||
# options=[s["display"] for s in session_options],
|
||||
# key="session_selector",
|
||||
#)
|
||||
# Find the selected session ID
|
||||
#selected_session_id = next(
|
||||
# s["id"] for s in session_options if s["display"] == selected_session
|
||||
#)
|
||||
if len(session_options) > 0:
|
||||
selected_session_id = session_options[0]["id"]
|
||||
|
||||
if st.session_state["agentic_rag_agent_session_id"] != selected_session_id:
|
||||
logger.info(
|
||||
f"---*--- Loading {model_id} run: {selected_session_id} ---*---"
|
||||
)
|
||||
st.session_state["agentic_rag_agent"] = get_agentic_rag_agent(
|
||||
model_id=model_id,
|
||||
session_id=selected_session_id,
|
||||
)
|
||||
st.rerun()
|
||||
|
||||
|
||||
def about_widget() -> None:
|
||||
"""Display an about section in the sidebar"""
|
||||
st.sidebar.markdown("---")
|
||||
st.sidebar.markdown("### ℹ️ 关于")
|
||||
st.sidebar.markdown("""
|
||||
本智能检索增强生成助手帮助您使用自然语言查询分析文档和网页内容。
|
||||
|
||||
构建技术:
|
||||
- 🚀 Agno
|
||||
- 💫 Streamlit
|
||||
""")
|
||||
|
||||
|
||||
CUSTOM_CSS = """
|
||||
<style>
|
||||
.sidebar .sidebar-content {
|
||||
width: 200px;
|
||||
margin-left: auto;
|
||||
margin-right: 0;
|
||||
}
|
||||
.main {
|
||||
margin-left: 200px;
|
||||
margin-right: 0;
|
||||
}
|
||||
/* Main Styles */
|
||||
.main-title {
|
||||
text-align: center;
|
||||
background: linear-gradient(45deg, #FF4B2B, #FF416C);
|
||||
-webkit-background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
font-size: 3em;
|
||||
font-weight: bold;
|
||||
padding: 1em 0;
|
||||
}
|
||||
.subtitle {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
margin-bottom: 2em;
|
||||
}
|
||||
.stButton button {
|
||||
width: 100%;
|
||||
border-radius: 20px;
|
||||
margin: 0.2em 0;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
.stButton button:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(0,0,0,0.1);
|
||||
}
|
||||
.chat-container {
|
||||
border-radius: 15px;
|
||||
padding: 1em;
|
||||
margin: 1em 0;
|
||||
background-color: #f5f5f5;
|
||||
}
|
||||
.tool-result {
|
||||
background-color: #f8f9fa;
|
||||
border-radius: 10px;
|
||||
padding: 1em;
|
||||
margin: 1em 0;
|
||||
border-left: 4px solid #3B82F6;
|
||||
}
|
||||
.status-message {
|
||||
padding: 1em;
|
||||
border-radius: 10px;
|
||||
margin: 1em 0;
|
||||
}
|
||||
.success-message {
|
||||
background-color: #d4edda;
|
||||
color: #155724;
|
||||
}
|
||||
.error-message {
|
||||
background-color: #f8d7da;
|
||||
color: #721c24;
|
||||
}
|
||||
/* Dark mode adjustments */
|
||||
@media (prefers-color-scheme: dark) {
|
||||
.chat-container {
|
||||
background-color: #2b2b2b;
|
||||
}
|
||||
.tool-result {
|
||||
background-color: #1e1e1e;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
"""
|
||||