上传文件
This commit is contained in:
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
实现知识图谱节点name属性向量化
|
||||
"""
|
||||
|
||||
from llm import Embedding
|
||||
from neo4j import GraphDatabase
|
||||
import time
|
||||
|
||||
# 初始化 Embedding 模型
|
||||
embeddings = Embedding(url="http://172.20.0.145:9995/v1", api_key="xxx", model_name="bge-m3")
|
||||
|
||||
# Neo4j 连接信息
|
||||
url = "bolt://172.20.0.145:7687"
|
||||
username = "neo4j"
|
||||
password = "password"
|
||||
|
||||
driver = GraphDatabase.driver(url, auth=(username, password))
|
||||
|
||||
# 定义需要处理的节点类型 - 使用正确的Neo4j标签格式
|
||||
node_labels = [
|
||||
"ProjectDivisionSet",
|
||||
"ProjectDivisionTree",
|
||||
"ProjectDivisionItem",
|
||||
# 对于复合标签,在查询时使用多标签形式
|
||||
{"labels": ["ProjectQuantity", "Quota"], "display": "ProjectQuantity+Quota"},
|
||||
{"labels": ["ProjectQuantity", "MainMaterial"], "display": "ProjectQuantity+MainMaterial"},
|
||||
{"labels": ["ProjectQuantity", "Equipment"], "display": "ProjectQuantity+Equipment"},
|
||||
"MaterialOrEquipment",
|
||||
"FeeTableTemplateSet",
|
||||
"FeeTableTemplateItem",
|
||||
"FeeCollection",
|
||||
"FeeScheduleSet",
|
||||
"FeeScheduleItem",
|
||||
"Fee",
|
||||
]
|
||||
|
||||
|
||||
def create_vector_index():
|
||||
"""为每个标签创建向量索引"""
|
||||
with driver.session() as session:
|
||||
dimension = 1024 # BGE-M3模型的向量维度
|
||||
|
||||
for i, label_info in enumerate(node_labels):
|
||||
try:
|
||||
# 处理不同格式的标签
|
||||
if isinstance(label_info, dict): # 复合标签
|
||||
labels = label_info["labels"]
|
||||
display_name = label_info["display"]
|
||||
# 为每个单独的标签创建索引,而不是尝试创建复合标签的索引
|
||||
for j, single_label in enumerate(labels):
|
||||
index_name = f"entity_embedding_index_{i}_{j}"
|
||||
create_single_index(
|
||||
session, single_label, index_name, dimension, f"{display_name}的组成标签 {single_label}"
|
||||
)
|
||||
else: # 单一标签
|
||||
index_name = f"entity_embedding_index_{i}"
|
||||
create_single_index(session, label_info, index_name, dimension, label_info)
|
||||
except Exception as e:
|
||||
if isinstance(label_info, dict):
|
||||
print(f"❌ 创建复合标签 {label_info['display']} 的索引失败: {e}")
|
||||
else:
|
||||
print(f"❌ 创建标签 {label_info} 的索引失败: {e}")
|
||||
|
||||
|
||||
def create_single_index(session, label, index_name, dimension, display_name):
|
||||
"""为单个标签创建向量索引,使用最新的Neo4j语法"""
|
||||
try:
|
||||
# 检查索引是否存在
|
||||
check_index_query = """
|
||||
SHOW INDEXES
|
||||
YIELD name
|
||||
WHERE name = $index_name
|
||||
RETURN count(*) > 0 AS exists
|
||||
"""
|
||||
|
||||
result = session.run(check_index_query, index_name=index_name)
|
||||
index_exists = result.single()["exists"]
|
||||
|
||||
if not index_exists:
|
||||
print(f"正在为标签 {display_name} 创建向量索引...")
|
||||
|
||||
# 使用最新的Neo4j向量索引语法
|
||||
create_index_query = f"""
|
||||
CREATE VECTOR INDEX {index_name}
|
||||
FOR (n:{label})
|
||||
ON (n.embedding)
|
||||
OPTIONS {{indexConfig: {{`vector.dimensions`: {dimension}}}}}
|
||||
"""
|
||||
|
||||
session.run(create_index_query)
|
||||
print(f"✅ 标签 {display_name} 的向量索引创建成功")
|
||||
else:
|
||||
print(f"✅ 标签 {display_name} 的向量索引已存在,跳过创建步骤")
|
||||
except Exception as e:
|
||||
print(f"❌ 创建标签 {display_name} 的向量索引失败: {e}")
|
||||
|
||||
|
||||
def generate_and_store_embeddings():
|
||||
with driver.session() as session:
|
||||
for label_info in node_labels:
|
||||
# 处理不同格式的标签
|
||||
if isinstance(label_info, dict): # 复合标签
|
||||
labels = label_info["labels"]
|
||||
display_name = label_info["display"]
|
||||
# 为Neo4j查询构建多标签模式
|
||||
label_pattern = ":".join(labels)
|
||||
else: # 单一标签
|
||||
label_pattern = label_info
|
||||
display_name = label_info
|
||||
|
||||
print(f"\n🔍 Processing nodes of type: {display_name}")
|
||||
start_time = time.time()
|
||||
|
||||
# 查询该类型的所有节点 elementId 和 name
|
||||
query = f"""
|
||||
MATCH (n:{label_pattern})
|
||||
WHERE n.name IS NOT NULL
|
||||
RETURN elementId(n) AS id, n.name AS name
|
||||
"""
|
||||
result = session.run(query)
|
||||
|
||||
count = 0
|
||||
batch = []
|
||||
|
||||
for record in result:
|
||||
node_id = record["id"]
|
||||
name = record["name"]
|
||||
|
||||
if not name or not name.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
vector = embeddings.embed(name)
|
||||
batch.append({"id": node_id, "vector": vector})
|
||||
count += 1
|
||||
|
||||
# 批量写入,减少数据库交互次数
|
||||
if len(batch) >= 50:
|
||||
write_batch(session, batch)
|
||||
batch = []
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing node {node_id}: {e}")
|
||||
|
||||
# 写入剩余批次
|
||||
if batch:
|
||||
write_batch(session, batch)
|
||||
|
||||
duration = time.time() - start_time
|
||||
print(f"✅ Processed {count} nodes of type {display_name} in {duration:.2f} seconds.")
|
||||
|
||||
|
||||
def write_batch(session, batch):
|
||||
"""批量写入 embedding 到 Neo4j"""
|
||||
session.run(
|
||||
"""
|
||||
UNWIND $batch AS item
|
||||
MATCH (n) WHERE elementId(n) = item.id
|
||||
SET n.embedding = item.vector
|
||||
""",
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
|
||||
def verify_embeddings():
|
||||
"""验证embedding是否正确存储"""
|
||||
with driver.session() as session:
|
||||
verify_query = """
|
||||
MATCH (n)
|
||||
WHERE n.embedding IS NOT NULL
|
||||
RETURN n.name, size(n.embedding) as dimension
|
||||
LIMIT 5
|
||||
"""
|
||||
result = session.run(verify_query)
|
||||
print("\n验证embedding存储情况:")
|
||||
|
||||
records = list(result)
|
||||
if not records:
|
||||
print("❌ 未找到任何带有embedding的节点")
|
||||
else:
|
||||
for record in records:
|
||||
print(f"节点: {record['n.name']}, 向量维度: {record['dimension']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 第一步:创建向量索引
|
||||
create_vector_index()
|
||||
|
||||
# 第二步:生成并存储embeddings
|
||||
generate_and_store_embeddings()
|
||||
|
||||
# 第三步:验证结果
|
||||
verify_embeddings()
|
||||
Reference in New Issue
Block a user