首次提交:上传本地文件夹
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,192 @@
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from dotenv import load_dotenv
|
||||
from utils.llm import llm
|
||||
from utils.prompt import PROMPTS
|
||||
|
||||
# Load variables from a .env file into the process environment (python-dotenv).
# NOTE(review): utils.llm and utils.prompt are imported above, i.e. before this
# call runs - confirm they do not read environment variables at import time.
load_dotenv()
|
||||
|
||||
|
||||
class EntityRelationExtractor:
    """Extract entities, relationships and keywords from text with an LLM.

    The LLM reply is parsed as one record per ``record_delimiter``; each
    record is a ``tuple_delimiter``-separated list of fields, optionally
    wrapped in parentheses, whose first field names the record type
    (``entity``, ``relationship`` or ``content_keywords``).
    """

    def __init__(self):
        # Use the model object imported from utils.llm.
        self.llm = llm
        # Delimiters substituted into the prompt and used again when parsing.
        self.tuple_delimiter = "|||"
        self.record_delimiter = "\n"
        self.completion_delimiter = ""

    def extract_from_text(self, text, entity_types=None):
        """Extract entities and relationships from a piece of text.

        Args:
            text: The text to analyse.
            entity_types: Iterable of entity type names to offer the model.
                Falls back to ``PROMPTS["DEFAULT_ENTITY_TYPES"]`` when omitted
                or empty.

        Returns:
            dict: ``{"entities": [...], "relations": [...], "keywords": [...]}``
            on success, or ``{"error": message}`` if anything raised.
        """
        try:
            # Honour the caller's types; previously the argument was always
            # overwritten by the defaults.
            if not entity_types:
                entity_types = PROMPTS["DEFAULT_ENTITY_TYPES"]
            relationship_types = PROMPTS["DEFAULT_RELATIONSHIP_TYPES"]

            # Build the system prompt by filling in the delimiter placeholders.
            prompt = PROMPTS["entity_extraction"]
            prompt = prompt.replace("{tuple_delimiter}", self.tuple_delimiter)
            prompt = prompt.replace("{record_delimiter}", self.record_delimiter)
            prompt = prompt.replace("{completion_delimiter}", self.completion_delimiter)

            # Append the type lists and the text itself as the user message.
            entity_types_str = ", ".join(entity_types)
            relationship_types_str = ", ".join(relationship_types)
            user_message = (
                f"实体类型列表: {entity_types_str}\n\n关系类型列表: {relationship_types_str}\n\n文本内容:\n{text}"
            )

            # Call the LLM and parse its reply.
            full_prompt = f"System: {prompt}\n\nUser: {user_message}"
            response = self.llm.generate(full_prompt)
            return self._parse_extraction_result(response)

        except Exception as e:
            # Best-effort API: callers check for the "error" key.
            return {"error": str(e)}

    @staticmethod
    def _clean_field(field):
        """Strip surrounding whitespace and double quotes from one field."""
        return field.strip().strip('"').strip()

    def _parse_extraction_result(self, result):
        """Parse the model reply into entities, relations and keywords.

        Lines that are blank, have an unknown record type, or carry too few
        fields for their record type are skipped.

        Args:
            result: Raw reply text from the model.

        Returns:
            dict: ``{"entities": [...], "relations": [...], "keywords": [...]}``.
        """
        entities = []
        relations = []
        keywords = []

        for raw_line in result.strip().split(self.record_delimiter):
            line = raw_line.strip()
            if not line:
                continue

            # Drop the optional surrounding parentheses.
            if line.startswith("(") and line.endswith(")"):
                line = line[1:-1]

            parts = [self._clean_field(p) for p in line.split(self.tuple_delimiter)]
            record_type = parts[0]

            # The minimum field count is checked per record type: a keyword
            # record has only two fields and must not be rejected up front.
            if record_type == "entity":
                if len(parts) >= 4:
                    entities.append(
                        {
                            "name": parts[1],
                            "type": parts[2],
                            "description": parts[3],
                        }
                    )
            elif record_type == "relationship":
                if len(parts) >= 5:
                    confidence = 1.0
                    if len(parts) > 5:
                        try:
                            confidence = float(parts[5])
                        except ValueError:
                            # A malformed score should not discard the
                            # whole reply; keep the default.
                            confidence = 1.0
                    relations.append(
                        {
                            "source": parts[1],
                            "target": parts[2],
                            "description": parts[3],
                            "type": parts[4],
                            "confidence": confidence,
                        }
                    )
            elif record_type == "content_keywords":
                if len(parts) >= 2:
                    keywords = [kw.strip() for kw in parts[1].split(",") if kw.strip()]

        return {"entities": entities, "relations": relations, "keywords": keywords}

    def extract_from_file(self, file_path, entity_types=None):
        """Extract entities and relationships from a UTF-8 text file.

        Args:
            file_path: Path of the file to read.
            entity_types: Forwarded to :meth:`extract_from_text`.

        Returns:
            dict: The result of :meth:`extract_from_text`, or
            ``{"error": message}`` if the file could not be read.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                text = file.read()
            return self.extract_from_text(text, entity_types)
        except Exception as e:
            return {"error": f"读取文件时出错: {str(e)}"}

    def extract_from_documents(self, documents, entity_types=None):
        """Extract from several documents and merge the results.

        Entities are de-duplicated by name (first occurrence wins), relations
        are concatenated, and keywords are merged into a set. Documents whose
        extraction returned an error are skipped.

        Args:
            documents: Iterable of document texts.
            entity_types: Forwarded to :meth:`extract_from_text`.

        Returns:
            dict: ``{"entities": [...], "relations": [...], "keywords": [...]}``.
        """
        all_entities = {}
        all_relations = []
        all_keywords = set()

        for doc in documents:
            result = self.extract_from_text(doc, entity_types)
            if "error" in result:
                continue

            for entity in result.get("entities", []):
                # setdefault keeps the first entity seen under each name.
                all_entities.setdefault(entity["name"], entity)

            all_relations.extend(result.get("relations", []))
            all_keywords.update(result.get("keywords", []))

        return {"entities": list(all_entities.values()), "relations": all_relations, "keywords": list(all_keywords)}
|
||||
|
||||
|
||||
# Manual smoke test: run one extraction on a sample help article and print it.
if __name__ == "__main__":
    extractor = EntityRelationExtractor()

    # Sample text (a markdown help article about importing/clearing a logo).
    test_text = """

(电力建设计价通软件) (计价通)导入或清除电子徽标

# (电力建设计价通软件) (计价通)导入或清除电子徽标

## 使用场景
1.单位对打印的报表有要求,需要插入显示企业电子徽标。

2.报表中有电子徽标,现在想要清除。
## 功能入口
【报表输出】界面——“导入电子徽标”按钮。

![](https://fuwu.bwjf.cn/group1/M00/12/B6/rBAFDmaqDSuAPdRFAAKUiomRv4g111.png)
## 操作步骤

### 导入电子徽标
1.左侧勾选需要显示徽标的报表,点击“选择徽标”,选中需要导入的电子徽标(可识别bmp,gif,jpg格式文件),点击“打开”即可导入徽标;

![](https://fuwu.bwjf.cn/group1/M00/12/B6/rBAFDmaqDTqAVAKKAAHF_dCmZIg627.png)
### **清除电子徽标**
左侧勾选需要清除徽标的报表,点击“清除徽标”按钮,点击“确定”,即可清除徽标。

![](https://fuwu.bwjf.cn/group1/M00/12/B6/rBAFDmaqDT6AD8BjAAGfv8KSh4k730.png)

"""

    # Run the extraction.
    print("正在提取实体和关系...")
    result = extractor.extract_from_text(test_text)

    # Print the outcome; extract_from_text reports failures via an "error" key.
    print("\n提取结果:")
    if "error" in result:
        print(f"错误: {result['error']}")
    else:
        print(f"发现 {len(result.get('entities', []))} 个实体和 {len(result.get('relations', []))} 个关系")

        # Entities
        print("\n实体:")
        for entity in result.get("entities", []):
            print(f"- {entity['name']} (类型: {entity['type']})")
            print(f" 描述: {entity['description']}")

        # Relationships
        print("\n关系:")
        for relation in result.get("relations", []):
            print(f"- {relation['source']} --[{relation['type']}]--> {relation['target']}")
            print(f" 描述: {relation['description']}")
            print(f" 置信度: {relation['confidence']}")

        # Keywords (only when any were returned)
        if "keywords" in result and result["keywords"]:
            print("\n关键词:")
            print(", ".join(result["keywords"]))
|
||||
@@ -0,0 +1,125 @@
|
||||
from typing import List, Dict, Any
|
||||
# from llm import llm as llm_model
|
||||
from utils.llm import search_llm as llm_model
|
||||
from utils.prompt import RESPONSE_TEMPLATE
|
||||
import time
|
||||
import logging
|
||||
import sys
|
||||
|
||||
# Configure logging: INFO level, timestamped records written to stdout.
# NOTE(review): basicConfig runs at import time and configures the root logger
# for the whole process - confirm that is intended for a library module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
# Module-level logger used by ResponseGenerator.
logger = logging.getLogger("ResponseGenerator")
|
||||
|
||||
class ResponseGenerator:
    """Generate an answer from retrieved graph nodes using the module's LLM."""

    def __init__(self):
        """Initialise the generator; generation uses the module-level ``llm_model``."""
        logger.info("ResponseGenerator初始化完成")

    def generate_response(self, query: str, retrieved_info: List[Dict[str, Any]], nlu_data: Dict = None) -> str:
        """Generate a response for a user query.

        Args:
            query: The user's query.
            retrieved_info: Retrieved items; each is a dict that may carry a
                ``"node"`` dict with node properties.
            nlu_data: Optional intent-recognition / slot-extraction result.

        Returns:
            The generated answer, or a fixed apology message when nothing was
            retrieved, the model returned an empty reply, or the call raised.
        """
        # Nothing retrieved: answer with the default message.
        if not retrieved_info:
            logger.warning("没有检索到相关信息,返回默认回答")
            return "抱歉,我没有找到与您问题相关的信息。请尝试使用其他关键词或更具体的问题。"

        logger.info("开始构建提示")
        prompt = self._build_prompt(query, retrieved_info, nlu_data)
        logger.info(f"提示构建完成,长度: {len(prompt)}")

        logger.info("开始调用LLM模型生成回答")
        start_time = time.time()

        try:
            # NOTE: no timeout is applied here; the call blocks until the
            # model returns or raises.
            response = llm_model.invoke(prompt)
            logger.info(f"LLM响应完成,耗时: {time.time() - start_time:.2f}秒")

            # Treat an empty or whitespace-only reply as a failure.
            if not response or not response.strip():
                logger.error("LLM返回了空响应")
                return "抱歉,生成回答时出现了问题。请稍后再试。"

            return response

        except Exception as e:
            logger.error(f"调用LLM模型时出错: {str(e)}")
            return f"抱歉,生成回答时出现了错误: {str(e)}"

    @staticmethod
    def _format_context(retrieved_info: List[Dict[str, Any]]) -> str:
        """Render retrieved nodes as blank-line separated context paragraphs."""
        context_parts = []
        for info in retrieved_info:
            node = info.get("node") or {}

            # A missing, None or empty label list yields an empty type
            # instead of raising.
            labels = node.get("labels") or [""]
            node_type = labels[0]
            name = node.get("original_name", "") or node.get("display_name", "")
            description = node.get("描述", "")

            if node_type == "功能名称":
                context_parts.append(f"功能: {name}\n描述: {description}")
            else:
                context_parts.append(f"{node_type}: {name}")

        return "\n\n".join(context_parts)

    @staticmethod
    def _format_intent_info(nlu_data: Optional[Dict] = None) -> str:
        """Render the first/second-level intents and extracted slots as text."""
        if not nlu_data or '意图' not in nlu_data:
            return ""

        intent_data = nlu_data['意图'] or {}
        first_level = intent_data.get('一级意图') or {}
        second_level = intent_data.get('二级意图') or {}

        first_intent = first_level.get('name', "未知")
        second_intent = second_level.get('name', "未知")
        intent_info = f"用户一级意图: {first_intent}\n用户二级意图: {second_intent}\n"

        slots = second_level.get('slot_lv2')
        if slots:
            intent_info += "提取的槽位:\n"
            for slot_name, slot_value in slots.items():
                # Skip empty slots and the "unknown" placeholder.
                if slot_value and slot_value != '未知':
                    intent_info += f"- {slot_name}: {slot_value}\n"

        return intent_info

    def _build_prompt(self, query: str, retrieved_info: List[Dict[str, Any]], nlu_data: Dict = None) -> str:
        """Build the LLM prompt from the query, retrieved nodes and NLU data.

        Args:
            query: The user's query.
            retrieved_info: Retrieved items (see :meth:`generate_response`).
            nlu_data: Optional intent-recognition / slot-extraction result.

        Returns:
            ``RESPONSE_TEMPLATE`` formatted with ``intent_info``, ``context``
            and ``query``.
        """
        return RESPONSE_TEMPLATE.format(
            intent_info=self._format_intent_info(nlu_data),
            context=self._format_context(retrieved_info),
            query=query
        )
|
||||
@@ -0,0 +1,280 @@
|
||||
import pandas as pd
|
||||
from neo4j import GraphDatabase
|
||||
import networkx as nx
|
||||
from pyvis.network import Network
|
||||
import networkx as nx
|
||||
|
||||
|
||||
|
||||
# Neo4j connection settings used by create_knowledge_graph and
# export_graph_to_html below.
# NOTE(review): the host and credentials are hard-coded in source - consider
# moving them to configuration or environment variables.
URI = "bolt://10.1.6.34:7687"
AUTH = ("neo4j", "password")
|
||||
|
||||
def create_knowledge_graph(excel_file):
    """Wipe the Neo4j database and rebuild the graph from a spreadsheet.

    Every row contributes a chain of nodes below the root "配网D3软件" node:
    up to four module levels (empty cells are skipped) followed by an optional
    function node carrying the "功能说明" text as its "描述" property. Each
    node is linked to its parent with a "包含" relationship and is keyed by
    its parent path plus its own name.

    Errors raised while talking to the database are printed, not propagated.

    Args:
        excel_file: Path of the .xlsx file to read (via openpyxl).
    """
    sheet = pd.read_excel(excel_file, engine="openpyxl")
    driver = GraphDatabase.driver(URI, auth=AUTH)

    # (spreadsheet column, node label) per hierarchy level, outermost first.
    hierarchy = (
        ("一级模块", "页面"),
        ("二级模块", "TAB控件"),
        ("三级模块", "分组控件"),
        ("四级模块", "属性控件"),
    )

    def clear_database(tx):
        """Delete every node and relationship in the database."""
        tx.run("MATCH (n) DETACH DELETE n")
        print("数据库已清空!")

    def verify_empty_database(tx):
        """Return (and print) the number of nodes left in the database."""
        count = tx.run("MATCH (n) RETURN count(n) as count").single()["count"]
        print(f"验证结果: 数据库中剩余节点数量 = {count}")
        return count

    def add_node(tx, label, name, properties=None, parent_path=""):
        """Merge a node keyed by parent path + name and return that key."""
        props = properties if properties is not None else {}
        props["display_name"] = name
        props["original_name"] = name
        if parent_path:
            props["parent_path"] = parent_path
        # The label is also stored as a plain property.
        props["node_type"] = label

        unique_id = name if not parent_path else f"{parent_path}|{name}"
        tx.run(
            f"MERGE (n:{label} {{name: $unique_id}}) SET n += $properties",
            unique_id=unique_id,
            properties=props,
        )
        return unique_id

    def add_relationship(tx, start_label, start_name, end_label, end_name, rel_type="包含"):
        """Merge a relationship between two nodes identified by unique key."""
        cypher = (
            f"MATCH (a:{start_label} {{name: $start_name}}), "
            f"(b:{end_label} {{name: $end_name}}) "
            f"MERGE (a)-[r:{rel_type}]->(b)"
        )
        tx.run(cypher, start_name=start_name, end_name=end_name)

    def cell_text(row, column):
        """Return the cell as stripped text; a missing/NaN cell becomes ""."""
        value = row.get(column, "")
        return "" if pd.isna(value) else str(value).strip()

    def process_batch(tx, rows, root_name):
        """Write the node chain for every row of one batch."""
        for _, row in rows.iterrows():
            module_names = [cell_text(row, column) for column, _ in hierarchy]
            function_name = cell_text(row, "功能名称")
            description = cell_text(row, "功能说明")

            # The function node attaches to the deepest non-empty module,
            # or to the root when the row names no module at all.
            parent_label, parent_id = "软件", root_name
            path = "配网D3软件"

            for (_, label), name in zip(hierarchy, module_names):
                if not name:
                    continue
                node_id = add_node(tx, label, name, {}, path)
                add_relationship(tx, parent_label, parent_id, label, node_id)
                parent_label, parent_id = label, node_id
                path = f"{path}|{name}"

            if function_name:
                function_id = add_node(tx, "功能名称", function_name, {"描述": description}, path)
                add_relationship(tx, parent_label, parent_id, "功能名称", function_id)

    try:
        with driver.session() as session:
            # Start from an empty database and report whether that worked.
            session.write_transaction(clear_database)
            remaining = session.write_transaction(verify_empty_database)
            if remaining == 0:
                print("数据库清空成功!")
            else:
                print(f"警告: 数据库清空不完全,仍有 {remaining} 个节点!")

            # Root node of the whole hierarchy.
            root_name = session.write_transaction(add_node, "软件", "配网D3软件", {})

            # One write transaction per batch of rows.
            batch_size = 100
            total = len(sheet)
            for start in range(0, total, batch_size):
                session.write_transaction(process_batch, sheet.iloc[start:start + batch_size], root_name)
                print(f"已处理 {min(start + batch_size, total)}/{total} 条记录")

            print("知识图谱构建完成!")
    except Exception as e:
        print(f"构建知识图谱时发生错误: {e}")
    finally:
        driver.close()
|
||||
|
||||
|
||||
def export_graph_to_html(output_file="knowledge_graph.html", limit=1000):
    """Export the knowledge graph stored in Neo4j to an interactive HTML file.

    Args:
        output_file: Path of the HTML file to write.
        limit: Maximum number of nodes fetched; the same cap is applied to
            the relationship query. Keeps the rendered graph small enough
            for a browser.

    Returns:
        bool: True if the file was written, False if any step raised (the
        error is printed).
    """
    # Bound before the try so the finally clause can test it directly.
    driver = None
    try:
        # Sent as a query parameter below rather than spliced into the
        # Cypher text.
        limit = int(limit)

        driver = GraphDatabase.driver(URI, auth=AUTH)
        G = nx.DiGraph()

        with driver.session() as session:
            # Fetch nodes (elementId() is used instead of the deprecated id()).
            nodes_result = session.run(
                "MATCH (n) "
                "RETURN elementId(n) as id, labels(n) as labels, n.name as name, "
                "n.display_name as display_name, n.original_name as original_name "
                "LIMIT $limit",
                limit=limit,
            )

            # Node colour per label; unknown labels fall back to grey.
            color_map = {
                "软件": "#FF5733",
                "页面": "#33FF57",
                "页面控件": "#57FF33",
                "TAB控件": "#3357FF",
                "分组控件": "#FF33A8",
                "属性控件": "#33FFF5",
                "功能名称": "#F5FF33"
            }

            # Remember the fetched ids so only edges between them are loaded.
            node_ids = []
            for record in nodes_result:
                node_id = record["id"]
                node_ids.append(node_id)
                node_label = record["labels"][0] if record["labels"] else "Unknown"
                # Prefer the original name over the unique path-based name.
                node_name = record["original_name"] or record["name"]
                display_name = record["display_name"] if record["display_name"] else node_name

                G.add_node(
                    node_id,
                    label=display_name,
                    title=f"{node_label}: {display_name}",
                    color=color_map.get(node_label, "#CCCCCC")
                )

            if node_ids:
                edges_result = session.run(
                    """
                    MATCH (a)-[r]->(b)
                    WHERE elementId(a) IN $node_ids AND elementId(b) IN $node_ids
                    RETURN elementId(a) as source, elementId(b) as target, type(r) as type
                    LIMIT $limit
                    """,
                    node_ids=node_ids,
                    limit=limit,
                )

                for record in edges_result:
                    G.add_edge(record["source"], record["target"], title=record["type"])

        # Convert the NetworkX graph into a Pyvis network.
        net = Network(height="800px", width="100%", directed=True, notebook=False)
        net.from_nx(G)

        # Physics layout and interaction options.
        net.set_options("""
        {
          "physics": {
            "forceAtlas2Based": {
              "gravitationalConstant": -50,
              "centralGravity": 0.01,
              "springLength": 100,
              "springConstant": 0.1
            },
            "maxVelocity": 50,
            "solver": "forceAtlas2Based",
            "timestep": 0.35,
            "stabilization": {
              "enabled": true,
              "iterations": 1000
            }
          },
          "interaction": {
            "navigationButtons": true,
            "keyboard": true
          }
        }
        """)

        net.save_graph(output_file)
        print(f"知识图谱已成功导出为HTML文件: {output_file}")
        return True

    except Exception as e:
        print(f"导出知识图谱时发生错误: {e}")
        return False
    finally:
        if driver is not None:
            driver.close()
|
||||
|
||||
|
||||
# Build the knowledge graph.
# NOTE(review): this call sits at module level, so merely importing this
# module runs it - and create_knowledge_graph starts by deleting every node in
# the database. Consider guarding it with `if __name__ == "__main__":`.
create_knowledge_graph("E:\\文件\\LLM_model\\RAG\\code\\GraphRAG\\data\\博微配网工程计价通D3软件产品功能清单.xlsx")

# Export to an HTML file (disabled).
# export_graph_to_html("配网D3软件知识图谱.html")
|
||||
@@ -0,0 +1,197 @@
|
||||
import os
|
||||
from neo4j import GraphDatabase
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
class KnowledgeGraphQuerier:
    """Query helper around a Neo4j knowledge graph.

    All queries issued here are ``MATCH ... RETURN`` reads. Nodes are looked
    up through their ``original_name`` / ``display_name`` properties.
    """

    def __init__(self, uri, auth):
        """Connect to the database.

        Args:
            uri: Neo4j database URI.
            auth: Neo4j credentials as ``(username, password)``.
        """
        self.uri = uri
        self.auth = auth
        self.driver = GraphDatabase.driver(uri, auth=auth)
        # Kept so that callers reading ``self.session`` keep working; the
        # query methods below each open their own short-lived session.
        self.session = self.driver.session()
        print(f"已连接到Neo4j数据库: {uri}")

    def close(self):
        """Close the session and the driver."""
        if hasattr(self, 'session') and self.session:
            self.session.close()
        if hasattr(self, 'driver') and self.driver:
            self.driver.close()
        print("已关闭Neo4j数据库连接")

    def get_node_by_name(self, name: str, node_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Find nodes whose original or display name equals ``name``.

        Args:
            name: Node name to match exactly.
            node_type: Optional label restricting the match.

        Returns:
            One property dict per matching node.
        """
        # NOTE(review): node_type is interpolated into the Cypher text
        # (labels cannot be query parameters) - only pass trusted values.
        label_clause = f":{node_type}" if node_type else ""
        query = f"""
        MATCH (n{label_clause})
        WHERE n.original_name = $name OR n.display_name = $name
        RETURN n
        """
        with self.driver.session() as session:
            result = session.run(query, name=name)
            return [dict(record["n"]) for record in result]

    def get_related_nodes(self, node_name: str, relationship_type: Optional[str] = None,
                          max_depth: int = 2) -> List[Dict[str, Any]]:
        """Return nodes reachable from the named node within ``max_depth`` hops.

        Relationships are followed in either direction.

        Args:
            node_name: Name of the start node.
            relationship_type: Optional relationship type to follow.
            max_depth: Maximum path length.

        Returns:
            One property dict per row returned by the query.
        """
        # Path bounds cannot be parameters; coerce so only an integer can
        # ever reach the query text.
        depth = int(max_depth)
        # NOTE(review): relationship_type is interpolated as-is - only pass
        # trusted values.
        rel_clause = f"r:{relationship_type}" if relationship_type else "r"
        query = f"""
        MATCH (n)-[{rel_clause}*1..{depth}]-(related)
        WHERE n.original_name = $name OR n.display_name = $name
        RETURN related
        """
        with self.driver.session() as session:
            result = session.run(query, name=node_name)
            return [dict(record["related"]) for record in result]

    def search_by_keyword(self, keyword: str) -> List[Dict[str, Any]]:
        """Search nodes whose name fields or description contain ``keyword``.

        Args:
            keyword: Substring to search for.

        Returns:
            Up to 50 property dicts, each extended with a ``"labels"`` list.
        """
        query = """
        MATCH (n)
        WHERE n.name CONTAINS $keyword
        OR n.display_name CONTAINS $keyword
        OR n.original_name CONTAINS $keyword
        OR n.描述 CONTAINS $keyword
        RETURN n
        LIMIT 50
        """
        # Use a scoped session like the sibling methods instead of the
        # long-lived self.session.
        with self.driver.session() as session:
            result = session.run(query, keyword=keyword)
            nodes = []
            for record in result:
                node = dict(record["n"])
                node["labels"] = list(record["n"].labels)
                nodes.append(node)
            return nodes

    def get_path_between_nodes(self, start_name: str, end_name: str) -> List[List[Dict[str, Any]]]:
        """Return shortest paths between two named nodes.

        Args:
            start_name: Name of the start node.
            end_name: Name of the end node.

        Returns:
            A list of paths; each path is the list of its nodes' property
            dicts in path order.
        """
        query = """
        MATCH p = shortestPath((a)-[*]-(b))
        WHERE (a.original_name = $start_name OR a.display_name = $start_name)
        AND (b.original_name = $end_name OR b.display_name = $end_name)
        RETURN p
        """
        with self.driver.session() as session:
            result = session.run(query, start_name=start_name, end_name=end_name)
            return [[dict(node) for node in record["p"].nodes] for record in result]

    def get_function_details(self, function_name: str) -> Dict[str, Any]:
        """Return the properties of a function node plus its path from the root.

        Args:
            function_name: ``original_name`` of the function node.

        Returns:
            The node's property dict, with a ``"path"`` entry (names joined
            by ``" > "``) when a path from a ``软件`` root exists; ``{}`` when
            no such function node is found.
        """
        # The node and its path are fetched together so the path always
        # belongs to the returned node. LIMIT 1 keeps single() well defined
        # when several function nodes share the same original name.
        query = """
        MATCH (n:功能名称 {original_name: $name})
        OPTIONAL MATCH p = (root:软件)-[*]->(n)
        RETURN n, p
        LIMIT 1
        """
        with self.driver.session() as session:
            record = session.run(query, name=function_name).single()
            if not record:
                return {}

            node = dict(record["n"])
            path = record["p"]
            if path is not None:
                names = [dict(path_node).get("original_name", "") for path_node in path.nodes]
                node["path"] = " > ".join(filter(None, names))
            return node
|
||||
@@ -0,0 +1,81 @@
|
||||
from .graph_query import KnowledgeGraphQuerier
|
||||
from retriever import GraphRetriever
|
||||
from .generator import ResponseGenerator
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
class GraphRAG:
    """Facade that wires graph querying, retrieval and answer generation."""

    def __init__(self, neo4j_uri="bolt://10.1.6.34:7687", neo4j_auth=("neo4j", "password")):
        """Create the querier, the retriever built on it, and the generator.

        Args:
            neo4j_uri: Neo4j database URI.
            neo4j_auth: Neo4j credentials.
        """
        querier = KnowledgeGraphQuerier(uri=neo4j_uri, auth=neo4j_auth)
        self.graph_querier = querier
        self.retriever = GraphRetriever(querier)
        self.generator = ResponseGenerator()

    def process_query(self, query: str, top_k: int = 5, slots: Dict[str, Any] = None, nlu_data: Dict[str, Any] = None) -> Dict[str, Any]:
        """Retrieve context for a query and generate an answer from it.

        Args:
            query: The user's query.
            top_k: Number of results requested from the retriever.
            slots: Slot information forwarded to the retriever.
            nlu_data: Intent/slot extraction result forwarded to the generator.

        Returns:
            A dict with the keys ``query``, ``nlu_data``, ``retrieved_info``
            and ``response``.
        """
        # Step 1: retrieval.
        hits = self.retriever.retrieve(query, top_k=top_k, slots=slots)
        # Step 2: generation on top of whatever was retrieved.
        answer = self.generator.generate_response(query, hits, nlu_data)
        # Step 3: bundle inputs and outputs for the caller.
        return dict(
            query=query,
            nlu_data=nlu_data,
            retrieved_info=hits,
            response=answer,
        )

    def close(self):
        """Release the graph querier's database resources."""
        self.graph_querier.close()
|
||||
|
||||
# # 示例用法
|
||||
# if __name__ == "__main__":
|
||||
# # 初始化GraphRAG系统
|
||||
# rag = GraphRAG(
|
||||
# neo4j_uri="bolt://10.1.6.34:7687",
|
||||
# neo4j_auth=("neo4j", "neo4j"),
|
||||
# embedding_model="shibing624/text2vec-base-chinese",
|
||||
# llm_api_url="http://localhost:8000/v1/chat/completions" # 根据您的LLM API调整
|
||||
# )
|
||||
|
||||
# try:
|
||||
# # 处理查询
|
||||
# query = "配网D3软件的工程量计算功能是什么?"
|
||||
# result = rag.process_query(query)
|
||||
|
||||
# # 打印结果
|
||||
# print(f"查询: {result['query']}")
|
||||
# print("\n检索到的信息:")
|
||||
# for info in result['retrieved_info']:
|
||||
# print(f"- {info['text']} (相似度: {info.get('similarity', 'N/A')})")
|
||||
|
||||
# print("\n生成的回答:")
|
||||
# print(result['response'])
|
||||
|
||||
# finally:
|
||||
# # 关闭资源
|
||||
# rag.close()
|
||||
Reference in New Issue
Block a user