243 lines
9.2 KiB
Python
243 lines
9.2 KiB
Python
"""🤖 Agentic RAG Agent - 您的AI知识助手!
|
|
|
|
这个高级示例展示了如何构建一个复杂的RAG(检索增强生成)系统,
|
|
利用向量搜索和LLMs从任何知识库中提供深入见解。
|
|
|
|
该代理可以:
|
|
- 处理和理解来自多个来源的文档(PDF、网站、文本文件)
|
|
- 使用向量嵌入构建可搜索的知识库
|
|
- 跨会话维护对话上下文和记忆
|
|
- 为其响应提供相关引用和来源
|
|
- 生成摘要并提取关键见解
|
|
- 回答后续问题和澄清
|
|
|
|
可以尝试的示例查询:
|
|
- "本文档的关键要点是什么?"
|
|
- "你能总结主要论点和支持证据吗?"
|
|
- "有哪些重要地统计数据和发现?"
|
|
- "这与[主题X]有什么关系?"
|
|
- "这个分析有哪些局限性或空白?"
|
|
- "你能更详细地解释[概念X]吗?"
|
|
- "其他来源支持或反驳这些主张吗?"
|
|
|
|
该代理使用:
|
|
- 向量相似性搜索进行相关文档检索
|
|
- 对话记忆用于上下文响应
|
|
- 引用跟踪用于来源归属
|
|
- 动态知识库更新
|
|
|
|
查看README了解如何运行应用程序。
|
|
"""
|
|
from pathlib import Path
|
|
|
|
from agno.document.chunking.document import DocumentChunking
|
|
from agno.models.deepseek import DeepSeek
|
|
from dotenv import load_dotenv
|
|
|
|
# 加载.env文件
|
|
load_dotenv()
|
|
import os
|
|
|
|
from typing import Optional
|
|
|
|
from agno.agent import Agent, AgentMemory
|
|
from agno.embedder.openai import OpenAIEmbedder
|
|
from agno.knowledge import AgentKnowledge
|
|
from agno.memory.classifier import MemoryClassifier
|
|
from agno.memory.db.sqlite import SqliteMemoryDb
|
|
from agno.memory.manager import MemoryManager
|
|
from agno.memory.summarizer import MemorySummarizer
|
|
from agno.models.openai import OpenAIChat
|
|
from agno.storage.json import JsonStorage
|
|
from agno.vectordb.lancedb import LanceDb
|
|
from agno.vectordb.search import SearchType
|
|
from agno.document.reader.json_reader import JSONReader
|
|
from agno.document.reader.csv_reader import CSVReader
|
|
from agno.document.reader.pdf_reader import PDFReader
|
|
from agno.document.reader.text_reader import TextReader
|
|
|
|
#db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
|
|
|
|
api_key = os.getenv("API_KEY")
|
|
embedding_model = os.getenv("EMBEDDING_MODEL")
|
|
embedding_baseUrl = os.getenv("EMBEDDING_BASE_URL")
|
|
model_baseUrl = os.getenv("MODEL_BASE_URL")
|
|
|
|
cwd = Path(__file__).parent.resolve()
|
|
tmp = cwd.joinpath("tmp")
|
|
if not tmp.exists():
|
|
tmp.mkdir(exist_ok=True, parents=True)
|
|
|
|
work_context = ""
|
|
|
|
def get_sofeware_work_context() -> str:
|
|
"""返回当前用户使用软件时所处环境."""
|
|
global work_context
|
|
return work_context
|
|
|
|
def set_sofeware_work_context(context : str):
|
|
global work_context
|
|
work_context = context
|
|
|
|
def get_reader(file_type: str):
|
|
"""Return appropriate reader based on file type."""
|
|
readers = {
|
|
"pdf": PDFReader(),
|
|
"csv": CSVReader(),
|
|
"txt": TextReader(),
|
|
"md": TextReader(),
|
|
"json": JSONReader(),
|
|
}
|
|
return readers.get(file_type.lower(), None)
|
|
|
|
def get_model_by_provider(provider: str, model_name: str):
|
|
"""根据提供商获取对应的模型实例"""
|
|
if provider == "openai":
|
|
model = OpenAIChat(id=model_name, base_url=model_baseUrl, api_key=api_key)
|
|
model.role_map = {
|
|
"system": "system",
|
|
"user": "user",
|
|
"assistant": "assistant",
|
|
"tool": "tool",
|
|
"model": "assistant",
|
|
}
|
|
return model
|
|
# elif provider == "google":
|
|
# return Gemini(id=model_name)
|
|
# elif provider == "anthropic":
|
|
# return Claude(id=model_name)
|
|
# elif provider == "groq":
|
|
# return Groq(id=model_name)
|
|
elif provider == "deepseek":
|
|
return DeepSeek(id=model_name, base_url=model_baseUrl, api_key=api_key)
|
|
else:
|
|
raise ValueError(f"Unsupported model provider: {provider}")
|
|
|
|
def initialize_memory(model) -> AgentMemory:
|
|
"""初始化并返回配置好的AgentMemory实例"""
|
|
return AgentMemory(
|
|
db=SqliteMemoryDb(
|
|
table_name="agent_memory",
|
|
db_file=os.getenv("MEMORY_DB_FILE", "tmp/agent_memory.db"),
|
|
), # 在Sqlite中持久化记忆
|
|
classifier=MemoryClassifier(model=model),
|
|
summarizer=MemorySummarizer(model=model),
|
|
manager=MemoryManager(model=model),
|
|
create_user_memories=True, # 存储用户偏好
|
|
#create_session_summary=True, # 存储对话摘要
|
|
)
|
|
|
|
def initialize_vector_db() -> LanceDb:
|
|
"""初始化并返回配置好的LanceDb实例"""
|
|
return LanceDb(
|
|
table_name="recipes",
|
|
uri=os.getenv("VECTOR_DB_PATH", "tmp/lancedb"),
|
|
search_type=SearchType.hybrid,
|
|
embedder=OpenAIEmbedder(id=embedding_model, base_url=embedding_baseUrl, api_key=api_key)
|
|
)
|
|
|
|
def initialize_knowledge_base() -> AgentKnowledge:
|
|
"""初始化并返回配置好的AgentKnowledge实例"""
|
|
return AgentKnowledge(
|
|
vector_db=initialize_vector_db(),
|
|
num_documents=3, # 检索3个最相关的文档
|
|
chunking_strategy=DocumentChunking(
|
|
chunk_size=500,
|
|
overlap=50,
|
|
), # 固定大小分块
|
|
optimize_on=1000, # 每1000条数据进行向量优化
|
|
reader=TextReader(), # 默认文本读取器
|
|
)
|
|
|
|
|
|
def get_agentic_rag_agent(
|
|
model_id: str = "openai:gpt-4o",
|
|
user_id: Optional[str] = None,
|
|
session_id: Optional[str] = None,
|
|
debug_mode: bool = True,
|
|
) -> Agent:
|
|
"""获取一个带有记忆功能的Agentic RAG代理。"""
|
|
# 解析模型提供商和名称
|
|
provider, model_name = model_id.split(":")
|
|
model = get_model_by_provider(provider, model_name)
|
|
|
|
# 初始化记忆系统
|
|
memory = initialize_memory(model)
|
|
|
|
# 初始化知识库
|
|
knowledge_base = initialize_knowledge_base()
|
|
|
|
description="""
|
|
你是一个智能助手,专门为[博微配网计价通D3软件]提供使用支持。你的任务是帮助用户理解和使用这个复杂的配电网工程造价软件系统。
|
|
|
|
软件特点
|
|
1.多页面架构:软件由多个功能页面组成
|
|
2.复杂控件布局:每个页面包含多种控件(如列表控件、TAB控件、按钮等)
|
|
3.业务对象丰富:涉及"取费表"、"项目划分"、"工程量"等多种业务对象
|
|
4.操作多样:支持"添加"、"修改"、"删除"、"导入"、"导出"等多种操作
|
|
"""
|
|
|
|
instructions="""
|
|
1. 理解用户问题
|
|
用户正在使用软件过程中遇到问题,向您请求帮助
|
|
用户所处环境如下:
|
|
{sofeware_work_context}
|
|
只从用户问题识别中提到的业务对象(如"如何设置取费费率"→"取费表")
|
|
只从用户问题识别业务对象的属性字段(如"如何设置取费费率"→"费率")
|
|
只从用户问题识别用户想要执行的操作(如"如何设置取费费率"→"设置")
|
|
判断问题类型(功能入口、操作步骤、错误处理等)
|
|
2. 改写问题
|
|
将用户问题改写为包含以下要素的标准查询:
|
|
[问题类型] : [操作类型] + [业务对象] + [属性]
|
|
[属性]只有用户问题中明确包含才改写,否则为未知。
|
|
[问题类型]、[操作类型]、[业务对象]为必须输入,如果缺少任一个都需追问用户补全才能进入下一步。
|
|
示例:
|
|
原始问题:"如何设置取费费率?"
|
|
改写后:"操作步骤 : 设置 - 取费 - 费率"
|
|
3.搜索知识库
|
|
必须始终使用工具 search_knowledge_base 来搜索知识库
|
|
在回应前彻底分析所有返回的文档
|
|
如果返回多个文档,需连贯地综合信息
|
|
4. 上下文管理:
|
|
使用工具 get_chat_history 保持对话连续性
|
|
相关时引用之前的交互
|
|
记录用户偏好和之前的澄清
|
|
5. 结果呈现要求
|
|
以 makedown 格式输出,注意换行和排版
|
|
避免使用'根据我的知识'或'取决于信息'等模糊表述
|
|
7. 特殊情况处理
|
|
如果问题不明确,可以反问请求澄清
|
|
如果知识库搜索无结果,则直接明确回复不知道
|
|
对于错误提示,先解释含义再直接回复无法解决
|
|
"""
|
|
|
|
# 创建代理
|
|
agentic_rag_agent: Agent = Agent(
|
|
name="博微软件AI助手",
|
|
session_id=session_id, # 跟踪会话ID以实现持久对话
|
|
user_id=user_id,
|
|
model=model,
|
|
storage=JsonStorage(dir_path=os.getenv("SESSION_STORAGE_PATH", "tmp/agent_sessions_json")), # 持久化会话数据
|
|
memory=memory, # 为代理添加记忆功能
|
|
knowledge=knowledge_base, # 添加知识库
|
|
description=description,
|
|
instructions=instructions,
|
|
context={"sofeware_work_context": get_sofeware_work_context},
|
|
add_context=True,
|
|
search_knowledge=True, # 此设置赋予模型搜索知识库信息的工具
|
|
read_chat_history=True, # 此设置赋予模型获取聊天历史的工具
|
|
#tools=[DuckDuckGoTools()],
|
|
markdown=True, # 此设置告诉模型以markdown格式格式化消息
|
|
# add_chat_history_to_messages=True,
|
|
show_tool_calls=True,
|
|
add_history_to_messages=True, # 将聊天历史添加到消息中
|
|
add_datetime_to_instructions=True,
|
|
add_name_to_instructions=True,
|
|
debug_mode=debug_mode,
|
|
read_tool_call_history=True,
|
|
num_history_responses=3,
|
|
save_response_to_file=str(tmp.joinpath("{message}.md")),
|
|
)
|
|
|
|
return agentic_rag_agent
|