Files
KG_generation/KG_vectorization.py
T
chentianrui 9609bb67b4 上传文件
2025-08-01 15:31:56 +08:00

193 lines
6.5 KiB
Python

"""
实现知识图谱节点name属性向量化
"""
from llm import Embedding
from neo4j import GraphDatabase
import time
# Embedding client configured for the "bge-m3" model at the endpoint below.
# NOTE(review): the endpoint URL and API key are hard-coded — consider loading them from configuration.
embeddings = Embedding(url="http://172.20.0.145:9995/v1", api_key="xxx", model_name="bge-m3")
# Neo4j connection settings.
# NOTE(review): credentials are hard-coded in source — consider environment variables or a config file.
url = "bolt://172.20.0.145:7687"
username = "neo4j"
password = "password"
# Module-level driver created at import time; the functions in this file open sessions from it.
driver = GraphDatabase.driver(url, auth=(username, password))
# Node types to process. Each entry is either:
#   - a plain string: a single Neo4j label, or
#   - a dict: "labels" lists labels a node must carry together (joined with ":" in MATCH
#     patterns), and "display" is the name used in progress messages.
node_labels = [
    "ProjectDivisionSet",
    "ProjectDivisionTree",
    "ProjectDivisionItem",
    # Composite labels: queried using the multi-label form.
    {"labels": ["ProjectQuantity", "Quota"], "display": "ProjectQuantity+Quota"},
    {"labels": ["ProjectQuantity", "MainMaterial"], "display": "ProjectQuantity+MainMaterial"},
    {"labels": ["ProjectQuantity", "Equipment"], "display": "ProjectQuantity+Equipment"},
    "MaterialOrEquipment",
    "FeeTableTemplateSet",
    "FeeTableTemplateItem",
    "FeeCollection",
    "FeeScheduleSet",
    "FeeScheduleItem",
    "Fee",
]
def create_vector_index():
    """Create a vector index on ``embedding`` for every label in ``node_labels``.

    Plain string entries get one index named ``entity_embedding_index_{i}``,
    where ``i`` is the entry's position in ``node_labels``. Composite entries
    (dicts) get one index per component label, named
    ``entity_embedding_index_{i}_{j}``.

    A label that occurs in more than one entry (for example a label shared by
    several composite entries) is only processed the first time it is seen.
    Requesting a second index on the same label and property under a different
    name is not caught by the name-based existence check and would be reported
    as a failure, so the repeat is skipped instead.

    Failures are printed and do not stop processing of the remaining entries.
    """
    dimension = 1024  # Vector dimension of the BGE-M3 model.
    seen_labels = set()  # Labels already handled during this run.
    with driver.session() as session:
        for i, label_info in enumerate(node_labels):
            try:
                # Normalise both entry formats to (label, index_name, description) triples.
                if isinstance(label_info, dict):  # Composite entry.
                    display_name = label_info["display"]
                    targets = [
                        (
                            single_label,
                            f"entity_embedding_index_{i}_{j}",
                            f"{display_name}的组成标签 {single_label}",
                        )
                        for j, single_label in enumerate(label_info["labels"])
                    ]
                else:  # Single label.
                    targets = [(label_info, f"entity_embedding_index_{i}", label_info)]

                for label, index_name, description in targets:
                    if label in seen_labels:
                        print(f"✅ 标签 {description} 的向量索引已在本次运行中处理,跳过重复创建")
                        continue
                    seen_labels.add(label)
                    create_single_index(session, label, index_name, dimension, description)
            except Exception as e:
                if isinstance(label_info, dict):
                    print(f"❌ 创建复合标签 {label_info['display']} 的索引失败: {e}")
                else:
                    print(f"❌ 创建标签 {label_info} 的索引失败: {e}")
def _quote_identifier(identifier):
    """Return ``identifier`` wrapped in backticks for use as a Cypher name."""
    # Labels and index names cannot be passed as query parameters, so they are
    # quoted instead; an embedded backtick is escaped by doubling it.
    return "`" + str(identifier).replace("`", "``") + "`"


def create_single_index(session, label, index_name, dimension, display_name):
    """Create a vector index on ``embedding`` for one label unless it already exists.

    Args:
        session: Open Neo4j session used to run the queries.
        label: Node label whose ``embedding`` property is indexed.
        index_name: Name of the index; existence is checked by this name.
        dimension: Vector dimension written into the index configuration.
        display_name: Description of the label, used only in progress messages.

    Any error is printed rather than raised, so callers can continue with
    other labels.
    """
    try:
        # Look the index up by name first so that re-runs are reported as a skip.
        check_index_query = """
        SHOW INDEXES
        YIELD name
        WHERE name = $index_name
        RETURN count(*) > 0 AS exists
        """
        index_exists = session.run(check_index_query, index_name=index_name).single()["exists"]
        if index_exists:
            print(f"✅ 标签 {display_name} 的向量索引已存在,跳过创建步骤")
            return

        print(f"正在为标签 {display_name} 创建向量索引...")
        # The similarity function is stated explicitly: older Neo4j 5.x
        # releases require it, and 'cosine' is the default where it is optional.
        create_index_query = f"""
        CREATE VECTOR INDEX {_quote_identifier(index_name)}
        FOR (n:{_quote_identifier(label)})
        ON (n.embedding)
        OPTIONS {{indexConfig: {{
            `vector.dimensions`: {int(dimension)},
            `vector.similarity_function`: 'cosine'
        }}}}
        """
        # Consume the result so a server-side failure is raised inside this try block.
        session.run(create_index_query).consume()
        print(f"✅ 标签 {display_name} 的向量索引创建成功")
    except Exception as e:
        print(f"❌ 创建标签 {display_name} 的向量索引失败: {e}")
def _flush_embedding_batch(session, batch):
    """Write one batch via ``write_batch``; return how many items were stored."""
    try:
        write_batch(session, batch)
    except Exception as e:
        print(f"❌ Error writing batch of {len(batch)} embeddings: {e}")
        return 0
    return len(batch)


def generate_and_store_embeddings(batch_size=50):
    """Embed the ``name`` of every configured node and store it as ``embedding``.

    For each entry in ``node_labels`` the nodes carrying that label (or, for a
    composite entry, all of its labels) and a non-null ``name`` are read, the
    name is embedded with ``embeddings.embed``, and the vectors are written
    back in batches through ``write_batch``.

    Args:
        batch_size: Number of embeddings sent per write. Defaults to 50.

    Names that are empty, whitespace-only, or not strings are skipped. An
    embedding failure skips that node only. A failed write is reported once
    and its batch is dropped, so it is not re-sent with every later node and
    does not abort the remaining labels. The printed count covers nodes whose
    embeddings were written.
    """
    with driver.session() as session:
        for label_info in node_labels:
            # Resolve both entry formats to a MATCH label pattern and a display name.
            if isinstance(label_info, dict):  # Composite entry.
                label_pattern = ":".join(label_info["labels"])
                display_name = label_info["display"]
            else:  # Single label.
                label_pattern = label_info
                display_name = label_info

            print(f"\n🔍 Processing nodes of type: {display_name}")
            start_time = time.time()

            # Fetch the elementId and name of every node of this type.
            query = f"""
            MATCH (n:{label_pattern})
            WHERE n.name IS NOT NULL
            RETURN elementId(n) AS id, n.name AS name
            """
            # Read all rows up front so the writes below never run on this
            # session while its read result is still being streamed.
            rows = [(record["id"], record["name"]) for record in session.run(query)]

            count = 0
            batch = []
            for node_id, name in rows:
                if not isinstance(name, str) or not name.strip():
                    continue
                try:
                    vector = embeddings.embed(name)
                except Exception as e:
                    print(f"❌ Error processing node {node_id}: {e}")
                    continue
                batch.append({"id": node_id, "vector": vector})
                # Write in batches to reduce database round trips.
                if len(batch) >= batch_size:
                    count += _flush_embedding_batch(session, batch)
                    batch = []

            # Write whatever is left over.
            if batch:
                count += _flush_embedding_batch(session, batch)

            duration = time.time() - start_time
            print(f"✅ Processed {count} nodes of type {display_name} in {duration:.2f} seconds.")
def write_batch(session, batch):
    """Write a batch of embeddings to Neo4j.

    Args:
        session: Open Neo4j session used to run the update.
        batch: List of dicts with keys ``"id"`` (the node's elementId) and
            ``"vector"`` (the embedding stored in ``n.embedding``).

    An empty batch is a no-op. The result is consumed before returning so that
    a failed write is raised here rather than at some later use of the session.
    """
    if not batch:
        return
    query = """
    UNWIND $batch AS item
    MATCH (n) WHERE elementId(n) = item.id
    SET n.embedding = item.vector
    """
    session.run(query, batch=batch).consume()
def verify_embeddings():
    """Check that embeddings were stored by printing a small sample.

    Queries up to five nodes that have an ``embedding`` property and prints
    each node's name together with the length of its stored vector, or a
    notice when no such node is found.
    """
    sample_query = (
        "\n"
        "MATCH (n)\n"
        "WHERE n.embedding IS NOT NULL\n"
        "RETURN n.name, size(n.embedding) as dimension\n"
        "LIMIT 5\n"
    )
    with driver.session() as session:
        cursor = session.run(sample_query)
        print("\n验证embedding存储情况:")
        sampled = list(cursor)
        if sampled:
            for row in sampled:
                node_name = row["n.name"]
                vector_size = row["dimension"]
                print(f"节点: {node_name}, 向量维度: {vector_size}")
            return
        print("❌ 未找到任何带有embedding的节点")
if __name__ == "__main__":
    try:
        # Step 1: create the vector indexes.
        create_vector_index()
        # Step 2: generate and store the embeddings.
        generate_and_store_embeddings()
        # Step 3: verify the result.
        verify_embeddings()
    finally:
        # Close the module-level driver so its connections are released even
        # when one of the steps raises.
        driver.close()