193 lines
6.5 KiB
Python
193 lines
6.5 KiB
Python
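"""
实现知识图谱节点name属性向量化

Vectorize the ``name`` property of knowledge-graph nodes stored in Neo4j:
create a vector index on ``embedding`` for each configured node label, embed
every node's ``name`` with the configured embedding model, store the vector in
the node's ``embedding`` property, and print a small sample to verify it.
"""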
|
|
|
|
import os
import time

from neo4j import GraphDatabase

from llm import Embedding
|
|
|
|
# Embedding model used to vectorize node names (BGE-M3).
# Every connection setting below can be overridden with an environment
# variable. The fallbacks are the values this script previously hard-coded,
# so behavior is unchanged when the variables are unset. Prefer supplying
# secrets (API key, Neo4j password) via the environment instead of editing
# them into the source.
embeddings = Embedding(
    url=os.environ.get("EMBEDDING_URL", "http://172.20.0.145:9995/v1"),
    api_key=os.environ.get("EMBEDDING_API_KEY", "xxx"),
    model_name=os.environ.get("EMBEDDING_MODEL", "bge-m3"),
)

# Neo4j connection settings.
url = os.environ.get("NEO4J_URL", "bolt://172.20.0.145:7687")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "password")

driver = GraphDatabase.driver(url, auth=(username, password))

# Node types to process.
# - A plain string is a single Neo4j label.
# - A dict describes a compound type. "labels" lists labels a node must carry
#   together (queried as a multi-label pattern, e.g. (n:ProjectQuantity:Quota)).
#   "display" is the name used in log output.
node_labels = [
    "ProjectDivisionSet",
    "ProjectDivisionTree",
    "ProjectDivisionItem",
    {"labels": ["ProjectQuantity", "Quota"], "display": "ProjectQuantity+Quota"},
    {"labels": ["ProjectQuantity", "MainMaterial"], "display": "ProjectQuantity+MainMaterial"},
    {"labels": ["ProjectQuantity", "Equipment"], "display": "ProjectQuantity+Equipment"},
    "MaterialOrEquipment",
    "FeeTableTemplateSet",
    "FeeTableTemplateItem",
    "FeeCollection",
    "FeeScheduleSet",
    "FeeScheduleItem",
    "Fee",
]
|
|
|
|
|
|
def create_vector_index():
|
|
"""为每个标签创建向量索引"""
|
|
with driver.session() as session:
|
|
dimension = 1024 # BGE-M3模型的向量维度
|
|
|
|
for i, label_info in enumerate(node_labels):
|
|
try:
|
|
# 处理不同格式的标签
|
|
if isinstance(label_info, dict): # 复合标签
|
|
labels = label_info["labels"]
|
|
display_name = label_info["display"]
|
|
# 为每个单独的标签创建索引,而不是尝试创建复合标签的索引
|
|
for j, single_label in enumerate(labels):
|
|
index_name = f"entity_embedding_index_{i}_{j}"
|
|
create_single_index(
|
|
session, single_label, index_name, dimension, f"{display_name}的组成标签 {single_label}"
|
|
)
|
|
else: # 单一标签
|
|
index_name = f"entity_embedding_index_{i}"
|
|
create_single_index(session, label_info, index_name, dimension, label_info)
|
|
except Exception as e:
|
|
if isinstance(label_info, dict):
|
|
print(f"❌ 创建复合标签 {label_info['display']} 的索引失败: {e}")
|
|
else:
|
|
print(f"❌ 创建标签 {label_info} 的索引失败: {e}")
|
|
|
|
|
|
def create_single_index(session, label, index_name, dimension, display_name):
|
|
"""为单个标签创建向量索引,使用最新的Neo4j语法"""
|
|
try:
|
|
# 检查索引是否存在
|
|
check_index_query = """
|
|
SHOW INDEXES
|
|
YIELD name
|
|
WHERE name = $index_name
|
|
RETURN count(*) > 0 AS exists
|
|
"""
|
|
|
|
result = session.run(check_index_query, index_name=index_name)
|
|
index_exists = result.single()["exists"]
|
|
|
|
if not index_exists:
|
|
print(f"正在为标签 {display_name} 创建向量索引...")
|
|
|
|
# 使用最新的Neo4j向量索引语法
|
|
create_index_query = f"""
|
|
CREATE VECTOR INDEX {index_name}
|
|
FOR (n:{label})
|
|
ON (n.embedding)
|
|
OPTIONS {{indexConfig: {{`vector.dimensions`: {dimension}}}}}
|
|
"""
|
|
|
|
session.run(create_index_query)
|
|
print(f"✅ 标签 {display_name} 的向量索引创建成功")
|
|
else:
|
|
print(f"✅ 标签 {display_name} 的向量索引已存在,跳过创建步骤")
|
|
except Exception as e:
|
|
print(f"❌ 创建标签 {display_name} 的向量索引失败: {e}")
|
|
|
|
|
|
def generate_and_store_embeddings():
|
|
with driver.session() as session:
|
|
for label_info in node_labels:
|
|
# 处理不同格式的标签
|
|
if isinstance(label_info, dict): # 复合标签
|
|
labels = label_info["labels"]
|
|
display_name = label_info["display"]
|
|
# 为Neo4j查询构建多标签模式
|
|
label_pattern = ":".join(labels)
|
|
else: # 单一标签
|
|
label_pattern = label_info
|
|
display_name = label_info
|
|
|
|
print(f"\n🔍 Processing nodes of type: {display_name}")
|
|
start_time = time.time()
|
|
|
|
# 查询该类型的所有节点 elementId 和 name
|
|
query = f"""
|
|
MATCH (n:{label_pattern})
|
|
WHERE n.name IS NOT NULL
|
|
RETURN elementId(n) AS id, n.name AS name
|
|
"""
|
|
result = session.run(query)
|
|
|
|
count = 0
|
|
batch = []
|
|
|
|
for record in result:
|
|
node_id = record["id"]
|
|
name = record["name"]
|
|
|
|
if not name or not name.strip():
|
|
continue
|
|
|
|
try:
|
|
vector = embeddings.embed(name)
|
|
batch.append({"id": node_id, "vector": vector})
|
|
count += 1
|
|
|
|
# 批量写入,减少数据库交互次数
|
|
if len(batch) >= 50:
|
|
write_batch(session, batch)
|
|
batch = []
|
|
except Exception as e:
|
|
print(f"❌ Error processing node {node_id}: {e}")
|
|
|
|
# 写入剩余批次
|
|
if batch:
|
|
write_batch(session, batch)
|
|
|
|
duration = time.time() - start_time
|
|
print(f"✅ Processed {count} nodes of type {display_name} in {duration:.2f} seconds.")
|
|
|
|
|
|
def write_batch(session, batch):
    """Write a batch of embeddings to Neo4j with a single query.

    Args:
        session: open Neo4j session used for the write.
        batch: list of ``{"id": <elementId>, "vector": <embedding>}`` dicts.
            Each node whose ``elementId`` matches ``id`` gets
            ``n.embedding = vector``. Ids that match no node are ignored by
            the ``MATCH``.

    Raises:
        Whatever the Neo4j driver raises if the write fails. The result is
        consumed here, so such an error is raised by this call and not
        deferred to the session's next query.
    """
    if not batch:
        return  # nothing to write; skip the round-trip
    session.run(
        """
        UNWIND $batch AS item
        MATCH (n) WHERE elementId(n) = item.id
        SET n.embedding = item.vector
        """,
        batch=batch,
    ).consume()
|
|
|
|
|
|
def verify_embeddings():
    """Print up to five sample nodes that already carry an ``embedding`` property.

    For each sampled node, print its ``name`` and the length of its embedding
    vector. If no node has an embedding, print a single failure line instead.
    """
    sample_query = """
    MATCH (n)
    WHERE n.embedding IS NOT NULL
    RETURN n.name, size(n.embedding) as dimension
    LIMIT 5
    """
    with driver.session() as session:
        sample = session.run(sample_query)
        print("\n验证embedding存储情况:")

        rows = [row for row in sample]
        if rows:
            for row in rows:
                print(f"节点: {row['n.name']}, 向量维度: {row['dimension']}")
        else:
            print("❌ 未找到任何带有embedding的节点")
|
|
|
|
|
|
if __name__ == "__main__":
    try:
        # Step 1: create the vector indexes.
        create_vector_index()

        # Step 2: embed every node's name and store it on the node.
        generate_and_store_embeddings()

        # Step 3: print a small sample to confirm the embeddings were stored.
        verify_embeddings()
    finally:
        # Close the driver (and its connection pool) even if a step raised.
        driver.close()
|