优化属性图检索功能及支持OpenAI线上模型
This commit is contained in:
@@ -1,21 +1,20 @@
|
||||
import os
|
||||
from llama_index.core.schema import TransformComponent, BaseNode
|
||||
from llama_index.core.graph_stores.types import (
|
||||
EntityNode,
|
||||
Relation,
|
||||
Triplet,
|
||||
KG_NODES_KEY,
|
||||
KG_RELATIONS_KEY,
|
||||
)
|
||||
from app.engine.loaders.projectJson import ProjectJson
|
||||
from app.engine.loaders.markdownReader import ChunkMarkdownReader
|
||||
from app.engine.graph.graphTypes import *
|
||||
import uuid
|
||||
|
||||
class PrjGraphExtractor(TransformComponent):
|
||||
ProjectName:str
|
||||
_nodeMaps = {}
|
||||
_prjID = ''
|
||||
def __init__(self,PrjName:str):
|
||||
super().__init__(ProjectName = PrjName)
|
||||
|
||||
|
||||
|
||||
def __call__(
|
||||
self, llama_nodes: list[BaseNode], **kwargs
|
||||
) -> list[BaseNode]:
|
||||
@@ -38,9 +37,9 @@ class PrjGraphExtractor(TransformComponent):
|
||||
records:dict[str,list] = self._getRecordNode(llama_node)
|
||||
fInfos = fileName.split('_')
|
||||
if len(fInfos) == 1:
|
||||
existing_nodes.append(EntityNode(name=fInfos[0], label=fInfos[0]))
|
||||
fileID = self._add_node(existing_nodes = existing_nodes,name=fInfos[0], label=fInfos[0])
|
||||
elif len(fInfos) == 2:
|
||||
existing_nodes.append(EntityNode(name=fileName, label=fInfos[1]))
|
||||
fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fInfos[1])
|
||||
else:
|
||||
raise ValueError("文件名存在多个下划线")
|
||||
|
||||
@@ -48,22 +47,20 @@ class PrjGraphExtractor(TransformComponent):
|
||||
for record in records:
|
||||
index = index + 1
|
||||
rcdName = self._getRecordName(fileName,record)
|
||||
existing_nodes.append(EntityNode(name=rcdName, label=rcdName,properties = record))
|
||||
rcdid = self._add_node(existing_nodes = existing_nodes,name=rcdName, label=rcdName,properties = record)
|
||||
existing_relations.append(
|
||||
Relation(
|
||||
RagRelation(
|
||||
label="包含",
|
||||
source_id= fileName,
|
||||
target_id= rcdName,
|
||||
properties={},
|
||||
source_id= fileID,
|
||||
target_id= rcdid
|
||||
)
|
||||
)
|
||||
|
||||
existing_relations.append(
|
||||
Relation(
|
||||
RagRelation(
|
||||
label="包含",
|
||||
source_id= self.ProjectName,
|
||||
target_id= fileName,
|
||||
properties={},
|
||||
source_id= self._prjID,
|
||||
target_id= fileID
|
||||
)
|
||||
)
|
||||
|
||||
@@ -76,28 +73,26 @@ class PrjGraphExtractor(TransformComponent):
|
||||
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
|
||||
existing_relations:list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
|
||||
records:dict[str,list] = self._getRecordNode(llama_node)
|
||||
existing_nodes.append(EntityNode(name=fileName, label=fileName))
|
||||
fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fileName)
|
||||
|
||||
index = 0
|
||||
for record in records:
|
||||
index = index + 1
|
||||
attName = self._getRecordName(fileName,record)
|
||||
existing_nodes.append(EntityNode(name=attName, label='属性',properties = record))
|
||||
attID = self._add_node(existing_nodes = existing_nodes,name=attName, label= attName,properties = record)
|
||||
existing_relations.append(
|
||||
Relation(
|
||||
RagRelation(
|
||||
label="聚合",
|
||||
source_id= fileName,
|
||||
target_id= attName,
|
||||
properties={},
|
||||
source_id= fileID,
|
||||
target_id= attID
|
||||
)
|
||||
)
|
||||
|
||||
existing_relations.append(
|
||||
Relation(
|
||||
RagRelation(
|
||||
label="包含",
|
||||
source_id= self.ProjectName,
|
||||
target_id= fileName,
|
||||
properties={},
|
||||
source_id= self._prjID,
|
||||
target_id= fileID
|
||||
)
|
||||
)
|
||||
|
||||
@@ -118,11 +113,32 @@ class PrjGraphExtractor(TransformComponent):
|
||||
|
||||
def _addPrjNode(self,llama_node:BaseNode):
|
||||
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
|
||||
existing_nodes.append(EntityNode(name=self.ProjectName, label=self.ProjectName))
|
||||
self._prjID = self._add_node(existing_nodes = existing_nodes,name=self.ProjectName,label=self.ProjectName)
|
||||
llama_node.metadata[KG_NODES_KEY] = existing_nodes
|
||||
|
||||
def _getRecordName(self,fileName:str,record:dict):
|
||||
for name,value in record.items():
|
||||
if '名称' in name:
|
||||
return value
|
||||
raise ValueError('记录名称为空')
|
||||
raise ValueError('记录名称为空')
|
||||
|
||||
def _add_node(self,existing_nodes:list,name:str,label:str,properties:dict = {}):
|
||||
id:str = ''
|
||||
if name in self._nodeMaps:
|
||||
nodes:list[RagEntityNode] = self._nodeMaps[name]
|
||||
for node in nodes:
|
||||
if node.properties == properties and node.name == name and node.label == label:
|
||||
id = node.id
|
||||
break
|
||||
|
||||
if id =='':
|
||||
id = str(uuid.uuid1())
|
||||
newNode = RagEntityNode(name = name,label=label,properties = properties,uid = id)
|
||||
existing_nodes.append(newNode)
|
||||
if name in self._nodeMaps:
|
||||
nodes:list[RagEntityNode] = self._nodeMaps[name]
|
||||
nodes.append(newNode)
|
||||
else:
|
||||
self._nodeMaps[name] = [newNode]
|
||||
|
||||
return id
|
||||
@@ -0,0 +1,42 @@
|
||||
|
||||
from typing import Any, List, Dict, Sequence, Tuple, Optional
|
||||
from llama_index.core.graph_stores.simple_labelled import SimplePropertyGraphStore
|
||||
from llama_index.core.graph_stores.types import LabelledNode,ChunkNode,LabelledPropertyGraph
|
||||
from app.engine.graph.graphTypes import *
|
||||
|
||||
class RAGPropertyGraphStore(SimplePropertyGraphStore):
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
data: dict,
|
||||
) -> "SimplePropertyGraphStore":
|
||||
"""Load from dict."""
|
||||
# need to load nodes manually
|
||||
node_dicts = data["nodes"]
|
||||
relation_dicts = data["relations"]
|
||||
|
||||
kg_nodes: Dict[str, LabelledNode] = {}
|
||||
kg_relations: Dict[str, LabelledNode] = {}
|
||||
for id, node_dict in node_dicts.items():
|
||||
if "name" in node_dict:
|
||||
kg_nodes[id] = RagEntityNode.model_validate(node_dict)
|
||||
elif "text" in node_dict:
|
||||
kg_nodes[id] = ChunkNode.model_validate(node_dict)
|
||||
else:
|
||||
raise ValueError(f"Could not infer node type for data: {node_dict!s}")
|
||||
|
||||
for id, node_dict in relation_dicts.items():
|
||||
kg_relations[id] = RagRelation.model_validate(node_dict)
|
||||
|
||||
# clear the nodes, to load later
|
||||
data["nodes"] = {}
|
||||
data["relations"] = {}
|
||||
|
||||
# load the graph
|
||||
graph = LabelledPropertyGraph.model_validate(data)
|
||||
|
||||
# add the node back
|
||||
graph.nodes = kg_nodes
|
||||
graph.relations = kg_relations
|
||||
|
||||
return cls(graph)
|
||||
@@ -0,0 +1,38 @@
|
||||
from llama_index.core.graph_stores.types import (
|
||||
EntityNode,
|
||||
Relation
|
||||
)
|
||||
|
||||
class RagEntityNode(EntityNode):
|
||||
uid : str
|
||||
def __str__(self) -> str:
|
||||
"""Return the string representation of the node."""
|
||||
if self.properties:
|
||||
prop = self.properties
|
||||
if 'triplet_source_id' in prop:
|
||||
prop.pop('triplet_source_id')
|
||||
if len(prop) > 0:
|
||||
return f"{self.name} ({prop})"
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
"""Get the node id."""
|
||||
#return self.name.replace('"', " ")
|
||||
return self.uid
|
||||
|
||||
class RagRelation(Relation):
|
||||
def __str__(self) -> str:
|
||||
"""Return the string representation of the relation."""
|
||||
if self.properties:
|
||||
prop = self.properties
|
||||
if 'triplet_source_id' in prop:
|
||||
prop.pop('triplet_source_id')
|
||||
if len(prop) > 0:
|
||||
return f"{self.label} ({prop})"
|
||||
return self.label
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
"""Get the relation id."""
|
||||
return self.label
|
||||
@@ -2,7 +2,7 @@ from llama_index.core.indices.property_graph import LLMSynonymRetriever,VectorCo
|
||||
from llama_index.core.indices.property_graph.transformations.schema_llm import *
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core import settings
|
||||
from llama_index.core import PropertyGraphIndex
|
||||
from llama_index.core import PropertyGraphIndex,KnowledgeGraphIndex
|
||||
from typing import List,Tuple,Literal
|
||||
from app.settings import init_settings
|
||||
import os
|
||||
@@ -15,7 +15,9 @@ from util.register import *
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||
from app.engine.engine import get_node_postprocessors
|
||||
from app.engine.graph.graphStore import RAGPropertyGraphStore
|
||||
|
||||
from app.engine.retriever.graphKeyWordRetriever import GraphKeyWordRetriever
|
||||
class PropertyGraph:
|
||||
def __init__(self,prjFlag:str) -> None:
|
||||
self._prjFlag = prjFlag
|
||||
@@ -44,16 +46,15 @@ class PropertyGraph:
|
||||
prjCachePath = GRAPH_STORAGE_DIR + f"/{self._prjFlag}"
|
||||
if not os.path.exists(prjCachePath):
|
||||
return None
|
||||
storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(self._prjFlag))
|
||||
storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(self._prjFlag),
|
||||
property_graph_store = RAGPropertyGraphStore.from_persist_dir(prjCachePath))
|
||||
index = load_index_from_storage(storeContext)
|
||||
return index
|
||||
|
||||
def query(self,query_str:str):
|
||||
index = self.getPropertyGraphIndex()
|
||||
synonym_retriver = LLMSynonymRetriever(index.property_graph_store,
|
||||
llm=settings.Settings.llm,
|
||||
max_keywords=10,
|
||||
include_text=False
|
||||
synonym_retriver = GraphKeyWordRetriever(index.property_graph_store,
|
||||
include_text=False
|
||||
)
|
||||
if index.property_graph_store.supports_vector_queries:
|
||||
vector_store = None
|
||||
@@ -62,7 +63,7 @@ class PropertyGraph:
|
||||
vector_retriver = VectorContextRetriever(index.property_graph_store,
|
||||
vector_store = vector_store,
|
||||
embed_model=settings.Settings.embed_model,
|
||||
similarity_top_k=5,
|
||||
similarity_top_k=10,
|
||||
include_text=False
|
||||
)
|
||||
|
||||
@@ -77,8 +78,8 @@ class PropertyGraph:
|
||||
if __name__ == "__main__":
|
||||
init_settings()
|
||||
init_observability()
|
||||
graph = PropertyGraph('projects_1b20bbf4-3243-4ac3-bcf0-8a91e9157521')
|
||||
graph.query('代码为XLBT的金额是')
|
||||
graph = PropertyGraph('projects_0ffaf7fb-8a61-46e2-97a2-8f924e9560a7')
|
||||
graph.query('工程属性表有哪些字段')
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user