新增属性图谱
This commit is contained in:
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
from llama_index.core.schema import TransformComponent, BaseNode
|
||||
from llama_index.core.graph_stores.types import (
|
||||
EntityNode,
|
||||
Relation,
|
||||
Triplet,
|
||||
KG_NODES_KEY,
|
||||
KG_RELATIONS_KEY,
|
||||
)
|
||||
from app.engine.loaders.projectJson import ProjectJson
|
||||
from app.engine.loaders.markdownReader import ChunkMarkdownReader
|
||||
|
||||
class PrjGraphExtractor(TransformComponent):
|
||||
ProjectName:str
|
||||
def __init__(self,PrjName:str):
|
||||
super().__init__(ProjectName = PrjName)
|
||||
|
||||
|
||||
def __call__(
|
||||
self, llama_nodes: list[BaseNode], **kwargs
|
||||
) -> list[BaseNode]:
|
||||
if len(llama_nodes) > 0:
|
||||
self._addPrjNode(llama_nodes[0])
|
||||
|
||||
for llama_node in llama_nodes:
|
||||
fileName = self._getFileName(llama_node)
|
||||
if fileName == '工程属性':
|
||||
self._dealAttributeNode(llama_node)
|
||||
else:
|
||||
self._dealCommonNode(llama_node)
|
||||
return llama_nodes
|
||||
|
||||
def _dealCommonNode(self,llama_node:BaseNode):
|
||||
fileName = self._getFileName(llama_node)
|
||||
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
|
||||
existing_relations:list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
|
||||
|
||||
records:dict[str,list] = self._getRecordNode(llama_node)
|
||||
fInfos = fileName.split('_')
|
||||
if len(fInfos) == 1:
|
||||
existing_nodes.append(EntityNode(name=fInfos[0], label=fInfos[0]))
|
||||
elif len(fInfos) == 2:
|
||||
existing_nodes.append(EntityNode(name=fileName, label=fInfos[1]))
|
||||
else:
|
||||
raise ValueError("文件名存在多个下划线")
|
||||
|
||||
index = 0
|
||||
for record in records:
|
||||
index = index + 1
|
||||
rcdName = self._getRecordName(fileName,record)
|
||||
existing_nodes.append(EntityNode(name=rcdName, label=rcdName,properties = record))
|
||||
existing_relations.append(
|
||||
Relation(
|
||||
label="包含",
|
||||
source_id= fileName,
|
||||
target_id= rcdName,
|
||||
properties={},
|
||||
)
|
||||
)
|
||||
|
||||
existing_relations.append(
|
||||
Relation(
|
||||
label="包含",
|
||||
source_id= self.ProjectName,
|
||||
target_id= fileName,
|
||||
properties={},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
llama_node.metadata[KG_NODES_KEY] = existing_nodes
|
||||
llama_node.metadata[KG_RELATIONS_KEY] = existing_relations
|
||||
|
||||
def _dealAttributeNode(self,llama_node:BaseNode):
|
||||
fileName = self._getFileName(llama_node)
|
||||
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
|
||||
existing_relations:list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
|
||||
records:dict[str,list] = self._getRecordNode(llama_node)
|
||||
existing_nodes.append(EntityNode(name=fileName, label=fileName))
|
||||
|
||||
index = 0
|
||||
for record in records:
|
||||
index = index + 1
|
||||
attName = self._getRecordName(fileName,record)
|
||||
existing_nodes.append(EntityNode(name=attName, label='属性',properties = record))
|
||||
existing_relations.append(
|
||||
Relation(
|
||||
label="聚合",
|
||||
source_id= fileName,
|
||||
target_id= attName,
|
||||
properties={},
|
||||
)
|
||||
)
|
||||
|
||||
existing_relations.append(
|
||||
Relation(
|
||||
label="包含",
|
||||
source_id= self.ProjectName,
|
||||
target_id= fileName,
|
||||
properties={},
|
||||
)
|
||||
)
|
||||
|
||||
llama_node.metadata[KG_NODES_KEY] = existing_nodes
|
||||
llama_node.metadata[KG_RELATIONS_KEY] = existing_relations
|
||||
|
||||
def _getRecordNode(self,llama_node:BaseNode):
|
||||
content = llama_node.get_content()
|
||||
rd = ChunkMarkdownReader()
|
||||
rd.markdown_to_tups(content)
|
||||
records = rd.records()
|
||||
return records
|
||||
|
||||
def _getFileName(self,llama_node:BaseNode):
|
||||
meta = llama_node.metadata
|
||||
fileName:str = os.path.splitext(meta['file_name'])[0]
|
||||
return fileName
|
||||
|
||||
def _addPrjNode(self,llama_node:BaseNode):
|
||||
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
|
||||
existing_nodes.append(EntityNode(name=self.ProjectName, label=self.ProjectName))
|
||||
llama_node.metadata[KG_NODES_KEY] = existing_nodes
|
||||
|
||||
def _getRecordName(self,fileName:str,record:dict):
|
||||
for name,value in record.items():
|
||||
if '名称' in name:
|
||||
return value
|
||||
raise ValueError('记录名称为空')
|
||||
@@ -0,0 +1,87 @@
|
||||
from llama_index.core.indices.property_graph import LLMSynonymRetriever,VectorContextRetriever,PGRetriever
|
||||
from llama_index.core.indices.property_graph.transformations.schema_llm import *
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core import settings
|
||||
from llama_index.core import PropertyGraphIndex
|
||||
from typing import List,Tuple,Literal
|
||||
from app.settings import init_settings
|
||||
import os
|
||||
from llama_index.core.storage.storage_context import StorageContext
|
||||
from llama_index.core import load_index_from_storage
|
||||
from app.observability import init_observability
|
||||
from app.engine.vectordb import get_Neo4j_Graph_Store
|
||||
from llama_index.core.response_synthesizers import ResponseMode
|
||||
from util.register import *
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||
from app.engine.engine import get_node_postprocessors
|
||||
|
||||
class PropertyGraph:
|
||||
def __init__(self,prjFlag:str) -> None:
|
||||
self._prjFlag = prjFlag
|
||||
|
||||
def create_query_engine(self,retriever):
|
||||
postprocess = get_node_postprocessors()
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever = retriever,
|
||||
text_qa_template=text_qa_template,
|
||||
refine_template=refine_template,
|
||||
summary_template = summary_template,
|
||||
simple_template = simple_template,
|
||||
node_postprocessors=postprocess,
|
||||
use_async=True,
|
||||
streaming=False,
|
||||
response_mode = ResponseMode.TREE_SUMMARIZE
|
||||
)
|
||||
return query_engine
|
||||
|
||||
def getPropertyGraphIndex(self):
|
||||
GRAPH_STORE_TYPE = os.getenv("GRAPH_STORE_TYPE", "")
|
||||
if GRAPH_STORE_TYPE == 'neo4j':
|
||||
index = PropertyGraphIndex.from_existing(property_graph_store= get_Neo4j_Graph_Store(self._prjFlag))
|
||||
else:
|
||||
GRAPH_STORAGE_DIR = os.getenv("GRAPH_STORAGE_DIR", "storage_graph")
|
||||
prjCachePath = GRAPH_STORAGE_DIR + f"/{self._prjFlag}"
|
||||
if not os.path.exists(prjCachePath):
|
||||
return None
|
||||
storeContext = StorageContext.from_defaults(persist_dir = prjCachePath)
|
||||
index = load_index_from_storage(storeContext)
|
||||
return index
|
||||
|
||||
def query(self,query_str:str):
|
||||
index = self.getPropertyGraphIndex()
|
||||
synonym_retriver = LLMSynonymRetriever(index.property_graph_store,
|
||||
llm=settings.Settings.llm,
|
||||
max_keywords=10,
|
||||
include_text=False
|
||||
)
|
||||
if index.property_graph_store.supports_vector_queries:
|
||||
vector_store = None
|
||||
else:
|
||||
vector_store = index.vector_store
|
||||
vector_retriver = VectorContextRetriever(index.property_graph_store,
|
||||
vector_store = vector_store,
|
||||
embed_model=settings.Settings.embed_model,
|
||||
similarity_top_k=5,
|
||||
include_text=False
|
||||
)
|
||||
|
||||
retriever = index.as_retriever(sub_retrievers=[synonym_retriver,vector_retriver])
|
||||
query_engine = self.create_query_engine(retriever)
|
||||
|
||||
response = query_engine.query(query_str)
|
||||
print(response)
|
||||
return str(response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_settings()
|
||||
init_observability()
|
||||
# graph = PropertyGraph('projects_1b20bbf4-3243-4ac3-bcf0-8a91e9157521')
|
||||
# graph.query('代码为XLBT的金额是')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user