# File info (source-viewer residue): 109 lines, 3.1 KiB, Python
from langchain_community.vectorstores import Neo4jVector
|
|
from llm import Embedding
|
|
import os
|
|
import ssl
|
|
|
|
# Point SSL_CERT_FILE at the interpreter's default CA bundle to work around
# SSL certificate-verification failures. ``cafile`` is None when that bundle
# file does not exist, and assigning None into os.environ raises TypeError,
# so only export it when a real path was resolved.
_default_cafile = ssl.get_default_verify_paths().cafile
if _default_cafile is not None:
    os.environ["SSL_CERT_FILE"] = _default_cafile
|
|
|
|
|
|
# Adapter exposing our project Embedding class through the interface
# LangChain vector stores call.
class EmbeddingAdapter:
    """Wrap an object with an ``embed(text)`` method as a LangChain-style embedder.

    ``embed_query`` and ``embed_documents`` are the method names LangChain
    vector stores invoke; both delegate to the wrapped model's ``embed`` and
    pass its return values through unchanged.
    """

    def __init__(self, embedding_model):
        self.model = embedding_model

    def embed_query(self, text):
        """Return the wrapped model's embedding for a single string."""
        return self.model.embed(text)

    def embed_documents(self, documents):
        """Return a list of embeddings, one ``embed`` call per document, in order."""
        return list(map(self.model.embed, documents))
|
|
|
|
|
|
# Embedding model client (project ``llm.Embedding``). The endpoint, API key
# and model name can be overridden through environment variables so they do
# not have to be edited in source; the defaults are the values previously
# hard-coded here.
raw_embeddings = Embedding(
    url=os.environ.get("EMBEDDING_API_URL", "http://172.20.0.145:9995/v1"),
    api_key=os.environ.get("EMBEDDING_API_KEY", "xxx"),
    model_name=os.environ.get("EMBEDDING_MODEL", "bge-m3"),
)

# Wrap the client so it exposes embed_query / embed_documents for LangChain.
embeddings = EmbeddingAdapter(raw_embeddings)
|
|
|
|
# Neo4j connection settings. Read from the same environment variables that
# LangChain's Neo4j integration uses (NEO4J_URI / NEO4J_USERNAME /
# NEO4J_PASSWORD), falling back to the previously hard-coded values so
# existing setups keep working. Prefer setting real credentials via the
# environment rather than in source.
url = os.environ.get("NEO4J_URI", "bolt://172.20.0.145:7687")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "password")
|
|
|
|
# Node labels the search is restricted to. Other labels that were listed
# here but disabled (add one to the list to include it in the search):
#   ProjectQuantity, Quota, MainMaterial, Equipment, MaterialOrEquipment,
#   FeeTableTemplateSet, FeeTableTemplateItem, FeeCollection,
#   FeeScheduleSet, FeeScheduleItem, Fee
node_labels = [
    "ProjectDivisionSet",
    "ProjectDivisionTree",
    "ProjectDivisionItem",
]

# Cypher predicate that is true when ``node`` carries any of the labels above.
# Labels are backtick-quoted (embedded backticks doubled) so unusual names
# still parse. An empty label list yields "true" (no label restriction)
# instead of the invalid ``WHERE ()``.
where_label_clause = (
    " OR ".join(f"node:`{label.replace('`', '``')}`" for label in node_labels)
    or "true"
)

# Custom retrieval query. Neo4jVector appends it directly after its own
# vector-index call
#   CALL db.index.vector.queryNodes($index, $k, $embedding) YIELD node, score
# so it must work on the ``node`` / ``score`` variables that call yields and
# return the ``text``, ``score`` and ``metadata`` columns.
#
# The earlier version started with ``MATCH (n)`` and recomputed cosine
# similarity by hand. Appended after the index call, that formed a cartesian
# product (every labelled node once per index hit), so the top rows were
# copies of the same node. It also ignored ``k`` (hard-coded LIMIT 5),
# scanned all nodes instead of using the index, and returned the raw
# embedding vector inside the metadata.
#
# NOTE: the label filter runs after the index returns its top-k hits, so
# fewer than k rows may come back when some hits carry other labels.
retrieval_query = f"""
WITH node, score
WHERE ({where_label_clause})
  AND node.name IS NOT NULL
RETURN node.name AS text,
       score,
       {{
         labels: labels(node),
         metadata: node {{ .*, embedding: Null }}
       }} AS metadata
ORDER BY score DESC
"""
|
|
|
|
# Build the vector store on top of the existing Neo4j index
# "entity_embedding_index_0", shaping results with the custom retrieval
# query. This runs when the module is executed/imported.
vectorstore = Neo4jVector.from_existing_index(
    url=url,
    username=username,
    password=password,
    embedding=embeddings,
    index_name="entity_embedding_index_0",
    embedding_node_property="embedding",
    retrieval_query=retrieval_query,
)
|
|
|
|
|
|
def search_knowledge_graph(query, k=5):
|
|
"""搜索知识图谱中与查询最相似的节点"""
|
|
print(f"\n执行查询: '{query}'")
|
|
|
|
try:
|
|
# 使用标准查询
|
|
results = vectorstore.similarity_search_with_score(query, k=k)
|
|
|
|
print(f"找到 {len(results)} 个相关节点:")
|
|
for i, (doc, score) in enumerate(results):
|
|
print(f"{i+1}. {doc.page_content} (相似度: {score:.4f})")
|
|
if hasattr(doc, "metadata") and doc.metadata:
|
|
print(f" 元数据: {doc.metadata}")
|
|
|
|
return results
|
|
except Exception as e:
|
|
print(f"搜索出错: {e}")
|
|
return []
|
|
|
|
|
|
# Example query. Runs only when this file is executed as a script, so
# importing the module (e.g. to reuse search_knowledge_graph) does not fire
# a search as a side effect.
if __name__ == "__main__":
    search_knowledge_graph("土建工程", k=5)
|