81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
from .graph_query import KnowledgeGraphQuerier
|
|
from retriever import GraphRetriever
|
|
from .generator import ResponseGenerator
|
|
from typing import Dict, Any, Optional
|
|
|
|
class GraphRAG:
|
|
def __init__(self, neo4j_uri="bolt://10.1.6.34:7687", neo4j_auth=("neo4j", "password")):
|
|
"""
|
|
初始化GraphRAG系统
|
|
|
|
参数:
|
|
neo4j_uri: Neo4j数据库URI
|
|
neo4j_auth: Neo4j认证信息
|
|
"""
|
|
# 初始化知识图谱查询器
|
|
self.graph_querier = KnowledgeGraphQuerier(uri=neo4j_uri, auth=neo4j_auth)
|
|
|
|
# 初始化检索器 - 使用自定义的embedding模型
|
|
self.retriever = GraphRetriever(self.graph_querier)
|
|
|
|
# 初始化生成器 - 使用自定义的llm模型
|
|
self.generator = ResponseGenerator()
|
|
|
|
def process_query(self, query: str, top_k: int = 5, slots: Dict[str, Any] = None, nlu_data: Dict[str, Any] = None) -> Dict[str, Any]:
|
|
"""
|
|
处理用户查询
|
|
|
|
参数:
|
|
query: 用户查询
|
|
top_k: 检索的结果数量
|
|
slots: 槽位信息,用于检索
|
|
nlu_data: 意图识别和槽位提取的结果
|
|
|
|
返回:
|
|
包含检索结果和生成回答的字典
|
|
"""
|
|
# 1. 检索相关信息
|
|
retrieved_info = self.retriever.retrieve(query, top_k=top_k, slots=slots)
|
|
|
|
# 2. 生成回答
|
|
response = self.generator.generate_response(query, retrieved_info, nlu_data)
|
|
|
|
# 3. 返回结果
|
|
return {
|
|
"query": query,
|
|
"nlu_data": nlu_data,
|
|
"retrieved_info": retrieved_info,
|
|
"response": response
|
|
}
|
|
|
|
def close(self):
|
|
"""关闭资源"""
|
|
self.graph_querier.close()
|
|
|
|
# # 示例用法
|
|
# if __name__ == "__main__":
|
|
# # 初始化GraphRAG系统
|
|
# rag = GraphRAG(
|
|
# neo4j_uri="bolt://10.1.6.34:7687",
|
|
# neo4j_auth=("neo4j", "neo4j"),
|
|
# embedding_model="shibing624/text2vec-base-chinese",
|
|
# llm_api_url="http://localhost:8000/v1/chat/completions" # 根据您的LLM API调整
|
|
# )
|
|
|
|
# try:
|
|
# # 处理查询
|
|
# query = "配网D3软件的工程量计算功能是什么?"
|
|
# result = rag.process_query(query)
|
|
|
|
# # 打印结果
|
|
# print(f"查询: {result['query']}")
|
|
# print("\n检索到的信息:")
|
|
# for info in result['retrieved_info']:
|
|
# print(f"- {info['text']} (相似度: {info.get('similarity', 'N/A')})")
|
|
|
|
# print("\n生成的回答:")
|
|
# print(result['response'])
|
|
|
|
# finally:
|
|
# # 关闭资源
|
|
# rag.close() |