首次提交:上传本地文件夹
This commit is contained in:
@@ -0,0 +1,81 @@
|
||||
from .graph_query import KnowledgeGraphQuerier
|
||||
from retriever import GraphRetriever
|
||||
from .generator import ResponseGenerator
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
class GraphRAG:
|
||||
def __init__(self, neo4j_uri="bolt://10.1.6.34:7687", neo4j_auth=("neo4j", "password")):
|
||||
"""
|
||||
初始化GraphRAG系统
|
||||
|
||||
参数:
|
||||
neo4j_uri: Neo4j数据库URI
|
||||
neo4j_auth: Neo4j认证信息
|
||||
"""
|
||||
# 初始化知识图谱查询器
|
||||
self.graph_querier = KnowledgeGraphQuerier(uri=neo4j_uri, auth=neo4j_auth)
|
||||
|
||||
# 初始化检索器 - 使用自定义的embedding模型
|
||||
self.retriever = GraphRetriever(self.graph_querier)
|
||||
|
||||
# 初始化生成器 - 使用自定义的llm模型
|
||||
self.generator = ResponseGenerator()
|
||||
|
||||
def process_query(self, query: str, top_k: int = 5, slots: Dict[str, Any] = None, nlu_data: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
处理用户查询
|
||||
|
||||
参数:
|
||||
query: 用户查询
|
||||
top_k: 检索的结果数量
|
||||
slots: 槽位信息,用于检索
|
||||
nlu_data: 意图识别和槽位提取的结果
|
||||
|
||||
返回:
|
||||
包含检索结果和生成回答的字典
|
||||
"""
|
||||
# 1. 检索相关信息
|
||||
retrieved_info = self.retriever.retrieve(query, top_k=top_k, slots=slots)
|
||||
|
||||
# 2. 生成回答
|
||||
response = self.generator.generate_response(query, retrieved_info, nlu_data)
|
||||
|
||||
# 3. 返回结果
|
||||
return {
|
||||
"query": query,
|
||||
"nlu_data": nlu_data,
|
||||
"retrieved_info": retrieved_info,
|
||||
"response": response
|
||||
}
|
||||
|
||||
def close(self):
|
||||
"""关闭资源"""
|
||||
self.graph_querier.close()
|
||||
|
||||
# # 示例用法
|
||||
# if __name__ == "__main__":
|
||||
# # 初始化GraphRAG系统
|
||||
# rag = GraphRAG(
|
||||
# neo4j_uri="bolt://10.1.6.34:7687",
|
||||
# neo4j_auth=("neo4j", "neo4j"),
|
||||
# embedding_model="shibing624/text2vec-base-chinese",
|
||||
# llm_api_url="http://localhost:8000/v1/chat/completions" # 根据您的LLM API调整
|
||||
# )
|
||||
|
||||
# try:
|
||||
# # 处理查询
|
||||
# query = "配网D3软件的工程量计算功能是什么?"
|
||||
# result = rag.process_query(query)
|
||||
|
||||
# # 打印结果
|
||||
# print(f"查询: {result['query']}")
|
||||
# print("\n检索到的信息:")
|
||||
# for info in result['retrieved_info']:
|
||||
# print(f"- {info['text']} (相似度: {info.get('similarity', 'N/A')})")
|
||||
|
||||
# print("\n生成的回答:")
|
||||
# print(result['response'])
|
||||
|
||||
# finally:
|
||||
# # 关闭资源
|
||||
# rag.close()
|
||||
Reference in New Issue
Block a user