优化属性图检索功能及支持OpenAI线上模型

This commit is contained in:
wanyaokun
2024-09-20 17:34:38 +08:00
parent 092f7230c1
commit f7260da6d9
12 changed files with 350 additions and 76 deletions
+45 -29
View File
@@ -1,21 +1,20 @@
import os
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.graph_stores.types import (
EntityNode,
Relation,
Triplet,
KG_NODES_KEY,
KG_RELATIONS_KEY,
)
from app.engine.loaders.projectJson import ProjectJson
from app.engine.loaders.markdownReader import ChunkMarkdownReader
from app.engine.graph.graphTypes import *
import uuid
class PrjGraphExtractor(TransformComponent):
ProjectName:str
_nodeMaps = {}
_prjID = ''
def __init__(self,PrjName:str):
super().__init__(ProjectName = PrjName)
def __call__(
self, llama_nodes: list[BaseNode], **kwargs
) -> list[BaseNode]:
@@ -38,9 +37,9 @@ class PrjGraphExtractor(TransformComponent):
records:dict[str,list] = self._getRecordNode(llama_node)
fInfos = fileName.split('_')
if len(fInfos) == 1:
existing_nodes.append(EntityNode(name=fInfos[0], label=fInfos[0]))
fileID = self._add_node(existing_nodes = existing_nodes,name=fInfos[0], label=fInfos[0])
elif len(fInfos) == 2:
existing_nodes.append(EntityNode(name=fileName, label=fInfos[1]))
fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fInfos[1])
else:
raise ValueError("文件名存在多个下划线")
@@ -48,22 +47,20 @@ class PrjGraphExtractor(TransformComponent):
for record in records:
index = index + 1
rcdName = self._getRecordName(fileName,record)
existing_nodes.append(EntityNode(name=rcdName, label=rcdName,properties = record))
rcdid = self._add_node(existing_nodes = existing_nodes,name=rcdName, label=rcdName,properties = record)
existing_relations.append(
Relation(
RagRelation(
label="包含",
source_id= fileName,
target_id= rcdName,
properties={},
source_id= fileID,
target_id= rcdid
)
)
existing_relations.append(
Relation(
RagRelation(
label="包含",
source_id= self.ProjectName,
target_id= fileName,
properties={},
source_id= self._prjID,
target_id= fileID
)
)
@@ -76,28 +73,26 @@ class PrjGraphExtractor(TransformComponent):
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
existing_relations:list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
records:dict[str,list] = self._getRecordNode(llama_node)
existing_nodes.append(EntityNode(name=fileName, label=fileName))
fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fileName)
index = 0
for record in records:
index = index + 1
attName = self._getRecordName(fileName,record)
existing_nodes.append(EntityNode(name=attName, label='属性',properties = record))
attID = self._add_node(existing_nodes = existing_nodes,name=attName, label= attName,properties = record)
existing_relations.append(
Relation(
RagRelation(
label="聚合",
source_id= fileName,
target_id= attName,
properties={},
source_id= fileID,
target_id= attID
)
)
existing_relations.append(
Relation(
RagRelation(
label="包含",
source_id= self.ProjectName,
target_id= fileName,
properties={},
source_id= self._prjID,
target_id= fileID
)
)
@@ -118,11 +113,32 @@ class PrjGraphExtractor(TransformComponent):
def _addPrjNode(self,llama_node:BaseNode):
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
existing_nodes.append(EntityNode(name=self.ProjectName, label=self.ProjectName))
self._prjID = self._add_node(existing_nodes = existing_nodes,name=self.ProjectName,label=self.ProjectName)
llama_node.metadata[KG_NODES_KEY] = existing_nodes
def _getRecordName(self,fileName:str,record:dict):
for name,value in record.items():
if '名称' in name:
return value
raise ValueError('记录名称为空')
raise ValueError('记录名称为空')
def _add_node(self,existing_nodes:list,name:str,label:str,properties:dict = {}):
id:str = ''
if name in self._nodeMaps:
nodes:list[RagEntityNode] = self._nodeMaps[name]
for node in nodes:
if node.properties == properties and node.name == name and node.label == label:
id = node.id
break
if id =='':
id = str(uuid.uuid1())
newNode = RagEntityNode(name = name,label=label,properties = properties,uid = id)
existing_nodes.append(newNode)
if name in self._nodeMaps:
nodes:list[RagEntityNode] = self._nodeMaps[name]
nodes.append(newNode)
else:
self._nodeMaps[name] = [newNode]
return id