from .graph_query import KnowledgeGraphQuerier from retriever import GraphRetriever from .generator import ResponseGenerator from typing import Dict, Any, Optional class GraphRAG: def __init__(self, neo4j_uri="bolt://10.1.6.34:7687", neo4j_auth=("neo4j", "password")): """ 初始化GraphRAG系统 参数: neo4j_uri: Neo4j数据库URI neo4j_auth: Neo4j认证信息 """ # 初始化知识图谱查询器 self.graph_querier = KnowledgeGraphQuerier(uri=neo4j_uri, auth=neo4j_auth) # 初始化检索器 - 使用自定义的embedding模型 self.retriever = GraphRetriever(self.graph_querier) # 初始化生成器 - 使用自定义的llm模型 self.generator = ResponseGenerator() def process_query(self, query: str, top_k: int = 5, slots: Dict[str, Any] = None, nlu_data: Dict[str, Any] = None) -> Dict[str, Any]: """ 处理用户查询 参数: query: 用户查询 top_k: 检索的结果数量 slots: 槽位信息,用于检索 nlu_data: 意图识别和槽位提取的结果 返回: 包含检索结果和生成回答的字典 """ # 1. 检索相关信息 retrieved_info = self.retriever.retrieve(query, top_k=top_k, slots=slots) # 2. 生成回答 response = self.generator.generate_response(query, retrieved_info, nlu_data) # 3. 返回结果 return { "query": query, "nlu_data": nlu_data, "retrieved_info": retrieved_info, "response": response } def close(self): """关闭资源""" self.graph_querier.close() # # 示例用法 # if __name__ == "__main__": # # 初始化GraphRAG系统 # rag = GraphRAG( # neo4j_uri="bolt://10.1.6.34:7687", # neo4j_auth=("neo4j", "neo4j"), # embedding_model="shibing624/text2vec-base-chinese", # llm_api_url="http://localhost:8000/v1/chat/completions" # 根据您的LLM API调整 # ) # try: # # 处理查询 # query = "配网D3软件的工程量计算功能是什么?" # result = rag.process_query(query) # # 打印结果 # print(f"查询: {result['query']}") # print("\n检索到的信息:") # for info in result['retrieved_info']: # print(f"- {info['text']} (相似度: {info.get('similarity', 'N/A')})") # print("\n生成的回答:") # print(result['response']) # finally: # # 关闭资源 # rag.close()