Files
GraphRAG/graph/graph_rag.py
T
2025-03-31 17:28:23 +08:00

81 lines
2.7 KiB
Python

from .graph_query import KnowledgeGraphQuerier
from retriever import GraphRetriever
from .generator import ResponseGenerator
from typing import Dict, Any, Optional
class GraphRAG:
def __init__(self, neo4j_uri="bolt://10.1.6.34:7687", neo4j_auth=("neo4j", "password")):
"""
初始化GraphRAG系统
参数:
neo4j_uri: Neo4j数据库URI
neo4j_auth: Neo4j认证信息
"""
# 初始化知识图谱查询器
self.graph_querier = KnowledgeGraphQuerier(uri=neo4j_uri, auth=neo4j_auth)
# 初始化检索器 - 使用自定义的embedding模型
self.retriever = GraphRetriever(self.graph_querier)
# 初始化生成器 - 使用自定义的llm模型
self.generator = ResponseGenerator()
def process_query(self, query: str, top_k: int = 5, slots: Dict[str, Any] = None, nlu_data: Dict[str, Any] = None) -> Dict[str, Any]:
"""
处理用户查询
参数:
query: 用户查询
top_k: 检索的结果数量
slots: 槽位信息,用于检索
nlu_data: 意图识别和槽位提取的结果
返回:
包含检索结果和生成回答的字典
"""
# 1. 检索相关信息
retrieved_info = self.retriever.retrieve(query, top_k=top_k, slots=slots)
# 2. 生成回答
response = self.generator.generate_response(query, retrieved_info, nlu_data)
# 3. 返回结果
return {
"query": query,
"nlu_data": nlu_data,
"retrieved_info": retrieved_info,
"response": response
}
def close(self):
"""关闭资源"""
self.graph_querier.close()
# # 示例用法
# if __name__ == "__main__":
# # 初始化GraphRAG系统
# rag = GraphRAG(
# neo4j_uri="bolt://10.1.6.34:7687",
# neo4j_auth=("neo4j", "neo4j"),
# embedding_model="shibing624/text2vec-base-chinese",
# llm_api_url="http://localhost:8000/v1/chat/completions" # 根据您的LLM API调整
# )
# try:
# # 处理查询
# query = "配网D3软件的工程量计算功能是什么?"
# result = rag.process_query(query)
# # 打印结果
# print(f"查询: {result['query']}")
# print("\n检索到的信息:")
# for info in result['retrieved_info']:
# print(f"- {info['text']} (相似度: {info.get('similarity', 'N/A')})")
# print("\n生成的回答:")
# print(result['response'])
# finally:
# # 关闭资源
# rag.close()