优化属性图检索功能及支持OpenAI线上模型
This commit is contained in:
@@ -41,6 +41,15 @@ MODEL_PROVIDER=dashscope
|
|||||||
DASHSCOPE_API_KEY=sk-221d2d202e104618a56002ce2e7dc0d0
|
DASHSCOPE_API_KEY=sk-221d2d202e104618a56002ce2e7dc0d0
|
||||||
MODEL=qwen2-math-72b-instruct
|
MODEL=qwen2-math-72b-instruct
|
||||||
|
|
||||||
|
# #---------- model - openai ----------------
|
||||||
|
# MODEL_PROVIDER=openai
|
||||||
|
# OPENAI_API_KEY=sk-hhoqttvhibirwheyponjifsqwssgxotoqlcjufkidytwxngi
|
||||||
|
# BASE_URL=https://api.siliconflow.cn/v1
|
||||||
|
# MODEL=alibaba/Qwen1.5-110B-Chat
|
||||||
|
# LLM_TEMPERATURE=0.1
|
||||||
|
# CONTEXT_WINDOW = 8192
|
||||||
|
# IS_CHAT_MODEL = true
|
||||||
|
# IS_FUN_CALL_MODEL = false
|
||||||
|
|
||||||
|
|
||||||
#---------- embedding - Xinference ----------------
|
#---------- embedding - Xinference ----------------
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from llama_index.core.storage import StorageContext
|
|||||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||||
from llama_index.core import PropertyGraphIndex
|
from llama_index.core import PropertyGraphIndex
|
||||||
from app.engine.graph.extractor import PrjGraphExtractor
|
from app.engine.graph.extractor import PrjGraphExtractor
|
||||||
|
from app.engine.graph.graphStore import RAGPropertyGraphStore
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -103,7 +104,8 @@ class PropertyGraphChache:
|
|||||||
|
|
||||||
def simplePropertyGraph(self,prjName:str,prjFlag:str,filePath:str):
|
def simplePropertyGraph(self,prjName:str,prjFlag:str,filePath:str):
|
||||||
documents = get_documents(prjFlag)
|
documents = get_documents(prjFlag)
|
||||||
storeContext = StorageContext.from_defaults(vector_store=get_vector_store(prjFlag))
|
storeContext = StorageContext.from_defaults(vector_store=get_vector_store(prjFlag),
|
||||||
|
property_graph_store = RAGPropertyGraphStore())
|
||||||
index = PropertyGraphIndex(
|
index = PropertyGraphIndex(
|
||||||
nodes =documents,
|
nodes =documents,
|
||||||
kg_extractors = [PrjGraphExtractor(prjName)],
|
kg_extractors = [PrjGraphExtractor(prjName)],
|
||||||
|
|||||||
@@ -1,21 +1,20 @@
|
|||||||
import os
|
import os
|
||||||
from llama_index.core.schema import TransformComponent, BaseNode
|
from llama_index.core.schema import TransformComponent, BaseNode
|
||||||
from llama_index.core.graph_stores.types import (
|
from llama_index.core.graph_stores.types import (
|
||||||
EntityNode,
|
|
||||||
Relation,
|
|
||||||
Triplet,
|
|
||||||
KG_NODES_KEY,
|
KG_NODES_KEY,
|
||||||
KG_RELATIONS_KEY,
|
KG_RELATIONS_KEY,
|
||||||
)
|
)
|
||||||
from app.engine.loaders.projectJson import ProjectJson
|
|
||||||
from app.engine.loaders.markdownReader import ChunkMarkdownReader
|
from app.engine.loaders.markdownReader import ChunkMarkdownReader
|
||||||
|
from app.engine.graph.graphTypes import *
|
||||||
|
import uuid
|
||||||
|
|
||||||
class PrjGraphExtractor(TransformComponent):
|
class PrjGraphExtractor(TransformComponent):
|
||||||
ProjectName:str
|
ProjectName:str
|
||||||
|
_nodeMaps = {}
|
||||||
|
_prjID = ''
|
||||||
def __init__(self,PrjName:str):
|
def __init__(self,PrjName:str):
|
||||||
super().__init__(ProjectName = PrjName)
|
super().__init__(ProjectName = PrjName)
|
||||||
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, llama_nodes: list[BaseNode], **kwargs
|
self, llama_nodes: list[BaseNode], **kwargs
|
||||||
) -> list[BaseNode]:
|
) -> list[BaseNode]:
|
||||||
@@ -38,9 +37,9 @@ class PrjGraphExtractor(TransformComponent):
|
|||||||
records:dict[str,list] = self._getRecordNode(llama_node)
|
records:dict[str,list] = self._getRecordNode(llama_node)
|
||||||
fInfos = fileName.split('_')
|
fInfos = fileName.split('_')
|
||||||
if len(fInfos) == 1:
|
if len(fInfos) == 1:
|
||||||
existing_nodes.append(EntityNode(name=fInfos[0], label=fInfos[0]))
|
fileID = self._add_node(existing_nodes = existing_nodes,name=fInfos[0], label=fInfos[0])
|
||||||
elif len(fInfos) == 2:
|
elif len(fInfos) == 2:
|
||||||
existing_nodes.append(EntityNode(name=fileName, label=fInfos[1]))
|
fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fInfos[1])
|
||||||
else:
|
else:
|
||||||
raise ValueError("文件名存在多个下划线")
|
raise ValueError("文件名存在多个下划线")
|
||||||
|
|
||||||
@@ -48,22 +47,20 @@ class PrjGraphExtractor(TransformComponent):
|
|||||||
for record in records:
|
for record in records:
|
||||||
index = index + 1
|
index = index + 1
|
||||||
rcdName = self._getRecordName(fileName,record)
|
rcdName = self._getRecordName(fileName,record)
|
||||||
existing_nodes.append(EntityNode(name=rcdName, label=rcdName,properties = record))
|
rcdid = self._add_node(existing_nodes = existing_nodes,name=rcdName, label=rcdName,properties = record)
|
||||||
existing_relations.append(
|
existing_relations.append(
|
||||||
Relation(
|
RagRelation(
|
||||||
label="包含",
|
label="包含",
|
||||||
source_id= fileName,
|
source_id= fileID,
|
||||||
target_id= rcdName,
|
target_id= rcdid
|
||||||
properties={},
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_relations.append(
|
existing_relations.append(
|
||||||
Relation(
|
RagRelation(
|
||||||
label="包含",
|
label="包含",
|
||||||
source_id= self.ProjectName,
|
source_id= self._prjID,
|
||||||
target_id= fileName,
|
target_id= fileID
|
||||||
properties={},
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -76,28 +73,26 @@ class PrjGraphExtractor(TransformComponent):
|
|||||||
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
|
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
|
||||||
existing_relations:list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
|
existing_relations:list = llama_node.metadata.pop(KG_RELATIONS_KEY, [])
|
||||||
records:dict[str,list] = self._getRecordNode(llama_node)
|
records:dict[str,list] = self._getRecordNode(llama_node)
|
||||||
existing_nodes.append(EntityNode(name=fileName, label=fileName))
|
fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fileName)
|
||||||
|
|
||||||
index = 0
|
index = 0
|
||||||
for record in records:
|
for record in records:
|
||||||
index = index + 1
|
index = index + 1
|
||||||
attName = self._getRecordName(fileName,record)
|
attName = self._getRecordName(fileName,record)
|
||||||
existing_nodes.append(EntityNode(name=attName, label='属性',properties = record))
|
attID = self._add_node(existing_nodes = existing_nodes,name=attName, label= attName,properties = record)
|
||||||
existing_relations.append(
|
existing_relations.append(
|
||||||
Relation(
|
RagRelation(
|
||||||
label="聚合",
|
label="聚合",
|
||||||
source_id= fileName,
|
source_id= fileID,
|
||||||
target_id= attName,
|
target_id= attID
|
||||||
properties={},
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_relations.append(
|
existing_relations.append(
|
||||||
Relation(
|
RagRelation(
|
||||||
label="包含",
|
label="包含",
|
||||||
source_id= self.ProjectName,
|
source_id= self._prjID,
|
||||||
target_id= fileName,
|
target_id= fileID
|
||||||
properties={},
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -118,11 +113,32 @@ class PrjGraphExtractor(TransformComponent):
|
|||||||
|
|
||||||
def _addPrjNode(self,llama_node:BaseNode):
|
def _addPrjNode(self,llama_node:BaseNode):
|
||||||
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
|
existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, [])
|
||||||
existing_nodes.append(EntityNode(name=self.ProjectName, label=self.ProjectName))
|
self._prjID = self._add_node(existing_nodes = existing_nodes,name=self.ProjectName,label=self.ProjectName)
|
||||||
llama_node.metadata[KG_NODES_KEY] = existing_nodes
|
llama_node.metadata[KG_NODES_KEY] = existing_nodes
|
||||||
|
|
||||||
def _getRecordName(self,fileName:str,record:dict):
|
def _getRecordName(self,fileName:str,record:dict):
|
||||||
for name,value in record.items():
|
for name,value in record.items():
|
||||||
if '名称' in name:
|
if '名称' in name:
|
||||||
return value
|
return value
|
||||||
raise ValueError('记录名称为空')
|
raise ValueError('记录名称为空')
|
||||||
|
|
||||||
|
def _add_node(self, existing_nodes: list, name: str, label: str, properties: "dict | None" = None):
    """Register an entity node and return its unique id.

    Nodes already created by this extractor are tracked in ``self._nodeMaps``
    (keyed by node name). If a tracked node has the same name, label and
    properties, its id is returned and nothing is appended. Otherwise a new
    ``RagEntityNode`` with a fresh UUID is created, appended to
    ``existing_nodes`` and remembered for later calls.

    Args:
        existing_nodes: list that receives the newly created node (mutated
            in place only when a new node is created).
        name: node name; also the key used for the duplicate lookup.
        label: node label.
        properties: node properties; ``None`` means "no properties".

    Returns:
        The id (uid) of the matching or newly created node.
    """
    # Build the empty dict per call instead of sharing one mutable default
    # object between every invocation.
    if properties is None:
        properties = {}

    known: list = self._nodeMaps.setdefault(name, [])
    for candidate in known:
        if (
            candidate.properties == properties
            and candidate.name == name
            and candidate.label == label
        ):
            return candidate.id

    node_id = str(uuid.uuid1())
    new_node = RagEntityNode(name=name, label=label, properties=properties, uid=node_id)
    existing_nodes.append(new_node)
    known.append(new_node)
    return node_id
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
|
||||||
|
from typing import Any, List, Dict, Sequence, Tuple, Optional
|
||||||
|
from llama_index.core.graph_stores.simple_labelled import SimplePropertyGraphStore
|
||||||
|
from llama_index.core.graph_stores.types import LabelledNode,ChunkNode,LabelledPropertyGraph
|
||||||
|
from app.engine.graph.graphTypes import *
|
||||||
|
|
||||||
|
class RAGPropertyGraphStore(SimplePropertyGraphStore):
    """Property graph store that rebuilds its content with the project's node types.

    Entity nodes are restored as ``RagEntityNode`` and relations as
    ``RagRelation`` when the store is loaded from a plain dict.
    """

    @classmethod
    def from_dict(
        cls,
        data: dict,
    ) -> "RAGPropertyGraphStore":
        """Build a store from its dict representation.

        Args:
            data: mapping with at least the keys ``"nodes"`` and
                ``"relations"``, each mapping an id to a node / relation
                dict. The mapping passed in is left unmodified.

        Returns:
            A store wrapping the reconstructed graph.

        Raises:
            KeyError: if ``"nodes"`` or ``"relations"`` is missing.
            ValueError: if a node dict has neither a ``"name"`` nor a
                ``"text"`` key, so its type cannot be inferred.
        """
        node_dicts = data["nodes"]
        relation_dicts = data["relations"]

        # Nodes and relations are validated by hand so that the concrete
        # subclass is chosen explicitly for each entry.
        kg_nodes: Dict[str, LabelledNode] = {}
        for node_id, node_dict in node_dicts.items():
            if "name" in node_dict:
                kg_nodes[node_id] = RagEntityNode.model_validate(node_dict)
            elif "text" in node_dict:
                kg_nodes[node_id] = ChunkNode.model_validate(node_dict)
            else:
                raise ValueError(f"Could not infer node type for data: {node_dict!s}")

        kg_relations: Dict[str, RagRelation] = {
            rel_id: RagRelation.model_validate(rel_dict)
            for rel_id, rel_dict in relation_dicts.items()
        }

        # Validate the rest of the graph with empty node/relation tables.
        # A shallow copy is used so the caller's dict is not overwritten.
        skeleton = {**data, "nodes": {}, "relations": {}}
        graph = LabelledPropertyGraph.model_validate(skeleton)

        # Attach the manually validated objects.
        graph.nodes = kg_nodes
        graph.relations = kg_relations

        return cls(graph)
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
from llama_index.core.graph_stores.types import (
|
||||||
|
EntityNode,
|
||||||
|
Relation
|
||||||
|
)
|
||||||
|
|
||||||
|
class RagEntityNode(EntityNode):
    """Entity node whose id is an explicitly supplied ``uid``."""

    # Unique identifier supplied by the creator; returned by ``id``.
    uid: str

    def __str__(self) -> str:
        """Return the node name, followed by its properties when any are shown.

        The ``triplet_source_id`` property is left out of the text. It is
        filtered into a separate dict so that formatting the node does not
        remove the key from the node's own ``properties``.
        """
        shown = {
            key: value
            for key, value in (self.properties or {}).items()
            if key != 'triplet_source_id'
        }
        if shown:
            return f"{self.name} ({shown})"
        return self.name

    @property
    def id(self) -> str:
        """Get the node id (the ``uid`` field, not a value derived from the name)."""
        return self.uid
|
||||||
|
|
||||||
|
class RagRelation(Relation):
    """Relation whose text form omits the ``triplet_source_id`` property."""

    def __str__(self) -> str:
        """Return the relation label, followed by its properties when any are shown.

        The ``triplet_source_id`` property is left out of the text. It is
        filtered into a separate dict so that formatting the relation does
        not remove the key from the relation's own ``properties``.
        """
        shown = {
            key: value
            for key, value in (self.properties or {}).items()
            if key != 'triplet_source_id'
        }
        if shown:
            return f"{self.label} ({shown})"
        return self.label

    @property
    def id(self) -> str:
        """Get the relation id (the relation label)."""
        return self.label
|
||||||
@@ -2,7 +2,7 @@ from llama_index.core.indices.property_graph import LLMSynonymRetriever,VectorCo
|
|||||||
from llama_index.core.indices.property_graph.transformations.schema_llm import *
|
from llama_index.core.indices.property_graph.transformations.schema_llm import *
|
||||||
from llama_index.core import SimpleDirectoryReader
|
from llama_index.core import SimpleDirectoryReader
|
||||||
from llama_index.core import settings
|
from llama_index.core import settings
|
||||||
from llama_index.core import PropertyGraphIndex
|
from llama_index.core import PropertyGraphIndex,KnowledgeGraphIndex
|
||||||
from typing import List,Tuple,Literal
|
from typing import List,Tuple,Literal
|
||||||
from app.settings import init_settings
|
from app.settings import init_settings
|
||||||
import os
|
import os
|
||||||
@@ -15,7 +15,9 @@ from util.register import *
|
|||||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||||
from app.engine.engine import get_node_postprocessors
|
from app.engine.engine import get_node_postprocessors
|
||||||
|
from app.engine.graph.graphStore import RAGPropertyGraphStore
|
||||||
|
|
||||||
|
from app.engine.retriever.graphKeyWordRetriever import GraphKeyWordRetriever
|
||||||
class PropertyGraph:
|
class PropertyGraph:
|
||||||
def __init__(self,prjFlag:str) -> None:
|
def __init__(self,prjFlag:str) -> None:
|
||||||
self._prjFlag = prjFlag
|
self._prjFlag = prjFlag
|
||||||
@@ -44,16 +46,15 @@ class PropertyGraph:
|
|||||||
prjCachePath = GRAPH_STORAGE_DIR + f"/{self._prjFlag}"
|
prjCachePath = GRAPH_STORAGE_DIR + f"/{self._prjFlag}"
|
||||||
if not os.path.exists(prjCachePath):
|
if not os.path.exists(prjCachePath):
|
||||||
return None
|
return None
|
||||||
storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(self._prjFlag))
|
storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(self._prjFlag),
|
||||||
|
property_graph_store = RAGPropertyGraphStore.from_persist_dir(prjCachePath))
|
||||||
index = load_index_from_storage(storeContext)
|
index = load_index_from_storage(storeContext)
|
||||||
return index
|
return index
|
||||||
|
|
||||||
def query(self,query_str:str):
|
def query(self,query_str:str):
|
||||||
index = self.getPropertyGraphIndex()
|
index = self.getPropertyGraphIndex()
|
||||||
synonym_retriver = LLMSynonymRetriever(index.property_graph_store,
|
synonym_retriver = GraphKeyWordRetriever(index.property_graph_store,
|
||||||
llm=settings.Settings.llm,
|
include_text=False
|
||||||
max_keywords=10,
|
|
||||||
include_text=False
|
|
||||||
)
|
)
|
||||||
if index.property_graph_store.supports_vector_queries:
|
if index.property_graph_store.supports_vector_queries:
|
||||||
vector_store = None
|
vector_store = None
|
||||||
@@ -62,7 +63,7 @@ class PropertyGraph:
|
|||||||
vector_retriver = VectorContextRetriever(index.property_graph_store,
|
vector_retriver = VectorContextRetriever(index.property_graph_store,
|
||||||
vector_store = vector_store,
|
vector_store = vector_store,
|
||||||
embed_model=settings.Settings.embed_model,
|
embed_model=settings.Settings.embed_model,
|
||||||
similarity_top_k=5,
|
similarity_top_k=10,
|
||||||
include_text=False
|
include_text=False
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -77,8 +78,8 @@ class PropertyGraph:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
init_settings()
|
init_settings()
|
||||||
init_observability()
|
init_observability()
|
||||||
graph = PropertyGraph('projects_1b20bbf4-3243-4ac3-bcf0-8a91e9157521')
|
graph = PropertyGraph('projects_0ffaf7fb-8a61-46e2-97a2-8f924e9560a7')
|
||||||
graph.query('代码为XLBT的金额是')
|
graph.query('工程属性表有哪些字段')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import Dict,Any
|
|||||||
from llama_index.core import PropertyGraphIndex
|
from llama_index.core import PropertyGraphIndex
|
||||||
from llama_index.core.storage.storage_context import StorageContext
|
from llama_index.core.storage.storage_context import StorageContext
|
||||||
from llama_index.core import load_index_from_storage
|
from llama_index.core import load_index_from_storage
|
||||||
|
from app.engine.graph.graphStore import RAGPropertyGraphStore
|
||||||
|
|
||||||
logger = logging.getLogger("uvicorn")
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
@@ -33,6 +34,7 @@ def getPropertyGraphIndex(prjFlag:str):
|
|||||||
prjCachePath = GRAPH_STORAGE_DIR + f"/{prjFlag}"
|
prjCachePath = GRAPH_STORAGE_DIR + f"/{prjFlag}"
|
||||||
if not os.path.exists(prjCachePath):
|
if not os.path.exists(prjCachePath):
|
||||||
return None
|
return None
|
||||||
storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(prjFlag))
|
storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(prjFlag),
|
||||||
|
property_graph_store = RAGPropertyGraphStore.from_persist_dir(prjCachePath))
|
||||||
index = load_index_from_storage(storeContext)
|
index = load_index_from_storage(storeContext)
|
||||||
return index
|
return index
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
from llama_index.llms.openai import OpenAI
|
||||||
|
from llama_index.core.base.llms.types import LLMMetadata
|
||||||
|
import os
|
||||||
|
|
||||||
|
class SiliconCloudOpenAI(OpenAI):
    """OpenAI LLM client whose metadata is configured through environment variables."""

    @property
    def metadata(self) -> LLMMetadata:
        """Build the LLM metadata from the environment.

        Environment variables read:
            IS_CHAT_MODEL: 'true' or '1' (case-insensitive) means chat model;
                treated as true when unset or blank.
            IS_FUN_CALL_MODEL: 'true' or '1' (case-insensitive) means the
                model supports function calling; treated as false when unset
                or blank.
            CONTEXT_WINDOW: integer context window; llama_index's
                ``DEFAULT_CONTEXT_WINDOW`` is used when unset or blank.

        Raises:
            ValueError: if CONTEXT_WINDOW is set but is not an integer.
        """
        # Imported locally so this block does not depend on a new
        # module-level import.
        from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW

        def env_flag(var_name: str, default: bool) -> bool:
            # An unset or blank variable falls back to the default instead
            # of failing on ``None.lower()``.
            raw = os.getenv(var_name)
            if raw is None or not raw.strip():
                return default
            return raw.strip().lower() in ('true', '1')

        raw_window = os.getenv('CONTEXT_WINDOW')
        if raw_window is not None and raw_window.strip():
            context_window = int(raw_window)
        else:
            context_window = DEFAULT_CONTEXT_WINDOW

        return LLMMetadata(
            context_window=context_window,
            num_output=self.max_tokens or -1,
            is_chat_model=env_flag('IS_CHAT_MODEL', True),
            is_function_calling_model=env_flag('IS_FUN_CALL_MODEL', False),
            model_name=self.model,
        )
|
||||||
@@ -48,32 +48,46 @@ refine_template_str = (
|
|||||||
|
|
||||||
refine_template = PromptTemplate(refine_template_str)
|
refine_template = PromptTemplate(refine_template_str)
|
||||||
|
|
||||||
summary_template_str = (
|
# summary_template_str = (
|
||||||
"# 角色\n"
|
# "# 角色\n"
|
||||||
"你是一名擅长博微造价工程数据问答的专家,可以根据电力工程文件中的内容回答用户问题。\n"
|
# "你是一名擅长博微造价工程数据问答的专家,可以根据电力工程文件中的内容回答用户问题。\n"
|
||||||
"\n"
|
# "\n"
|
||||||
"# 任务描述:\n"
|
# "# 任务描述:\n"
|
||||||
"请仔细阅读所给的文档片段,并根据其内容回答问题。\n"
|
# "请仔细阅读所给的文档片段,并根据其内容回答问题。\n"
|
||||||
"您需要判断文档的内容是否可以回答问题,不要强行回答。如果可以回答,答案必须严格遵循文档内容,即使与事实不符。\n"
|
# "您需要判断文档的内容是否可以回答问题,不要强行回答。如果可以回答,答案必须严格遵循文档内容,即使与事实不符。\n"
|
||||||
"如果答案与事实不符,直接给出答案,不要做解释。\n"
|
# "如果答案与事实不符,直接给出答案,不要做解释。\n"
|
||||||
"\n"
|
# "\n"
|
||||||
"# 回答规则:\n"
|
# "# 回答规则:\n"
|
||||||
"- 请使用与文档材料相同的语言回答问题。\n"
|
# "- 请使用与文档材料相同的语言回答问题。\n"
|
||||||
"- 评估文档是否含有足够信息回答问题。无关时不要回答。\n"
|
# "- 评估文档是否含有足够信息回答问题。无关时不要回答。\n"
|
||||||
"- 如果问题能被回答,你的回答必须严格遵循文档内容,即使与事实不符。一定不要做多余解释。\n"
|
# "- 如果问题能被回答,你的回答必须严格遵循文档内容,即使与事实不符。一定不要做多余解释。\n"
|
||||||
"- 如果问题能被回答,直接引用文档的相关信息保证答案准确、完整,并追求简洁。\n"
|
# "- 如果问题能被回答,直接引用文档的相关信息保证答案准确、完整,并追求简洁。\n"
|
||||||
"- 当文档中只有少量信息与问题相关时,重点关注这部分信息,这种情况下一定回答。\n"
|
# "- 当文档中只有少量信息与问题相关时,重点关注这部分信息,这种情况下一定回答。\n"
|
||||||
"- 当文档中信息与问题无关时,请不要额外发散回答,只需要回答为' '"
|
# "- 当文档中信息与问题无关时,请不要额外发散回答,只需要回答为' '。\n"
|
||||||
"\n"
|
# "\n"
|
||||||
"来自多个来源的文档片段如下,请充分理解以下参考资料内容,组织出满足用户提问的条理清晰的回复。\n"
|
# "来自多个来源的文档片段如下,请充分理解以下参考资料内容,组织出满足用户提问的条理清晰的回复。\n"
|
||||||
"---------------------\n"
|
# "---------------------\n"
|
||||||
"{context_str}\n"
|
# "{context_str}\n"
|
||||||
"---------------------\n"
|
# "---------------------\n"
|
||||||
"鉴于来自多个来源的文档片段而非先验知识,回答查询。\n"
|
# "鉴于来自多个来源的文档片段而非先验知识,回答查询。\n"
|
||||||
"如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
# "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
||||||
"Query: {query_str}\n"
|
# "Query: {query_str}\n"
|
||||||
"Answer: "
|
# "Answer: "
|
||||||
)
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
summary_template_str = """
|
||||||
|
你是一名擅长博微造价工程数据问答的专家,可以根据电力工程文件中的内容回答用户问题。
|
||||||
|
来自多个来源的文档片段如下,请充分理解以下参考资料内容,回答问题。
|
||||||
|
---------------------
|
||||||
|
{context_str}
|
||||||
|
---------------------
|
||||||
|
当你不知道答案的时候,不要编造答案,直接回答不知道,不需要解释为什么不知道。
|
||||||
|
问题: {query_str}
|
||||||
|
回答:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
summary_template = PromptTemplate(summary_template_str)
|
summary_template = PromptTemplate(summary_template_str)
|
||||||
|
|
||||||
simple_template_str = (
|
simple_template_str = (
|
||||||
|
|||||||
@@ -0,0 +1,73 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, List, Sequence, Optional
|
||||||
|
from llama_index.core.schema import BaseNode, NodeWithScore, QueryBundle
|
||||||
|
from llama_index.core.graph_stores.types import (
|
||||||
|
PropertyGraphStore,
|
||||||
|
KG_SOURCE_REL,
|
||||||
|
VECTOR_SOURCE_KEY,
|
||||||
|
)
|
||||||
|
from llama_index.core.indices.property_graph.sub_retrievers.base import BasePGRetriever
|
||||||
|
from llama_index.core.graph_stores.types import (
|
||||||
|
PropertyGraphStore,
|
||||||
|
KG_SOURCE_REL
|
||||||
|
)
|
||||||
|
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||||
|
|
||||||
|
class GraphBM25Retriever(BasePGRetriever):
    """Property-graph sub-retriever that seeds graph lookups with BM25 hits.

    The query is run through a persisted ``CHBM25Retriever``; the hits are
    mapped to graph node ids, the triplets around those nodes are fetched
    from the graph store, and each triplet is scored with the best BM25
    score of its two end nodes.
    """

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        include_text: bool = True,
        path_depth: int = 1,
        similarity_score: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Create the retriever.

        Args:
            graph_store: store queried for nodes and relation maps.
            include_text: forwarded to ``BasePGRetriever``.
            path_depth: depth passed to the relation-map lookup.
            similarity_score: optional minimum score; triplets scoring
                below it are dropped. ``None`` disables the filter.
            **kwargs: forwarded to ``BasePGRetriever``.

        The BM25 index is loaded from the directory named by the
        ``BM_RETRIEVER_PATH`` environment variable (default ``storage_bm``)
        when that directory exists and is not empty.
        """
        self._path_depth = path_depth
        self._similarity_score = similarity_score
        # Always define the attribute so a missing index is detected
        # explicitly rather than through an AttributeError at query time.
        self._bm25Retriever = None
        STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
        if os.path.exists(STORAGE_DIR) and len(os.listdir(STORAGE_DIR)) > 0:
            self._bm25Retriever = CHBM25Retriever.from_persist_dir(STORAGE_DIR)
        super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)

    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Synchronous counterpart of ``aretrieve_from_graph``."""
        kg_ids, scores = self._bm25_hits(query_bundle)
        if not kg_ids:
            # No index loaded or nothing matched: nothing to expand.
            return []
        kg_nodes = self._graph_store.get(ids=kg_ids)
        triplets = self._graph_store.get_rel_map(
            kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
        )
        return self._rank_triplets(triplets, kg_ids, scores)

    async def aretrieve_from_graph(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Retrieve scored triplets for the query via the store's async API."""
        kg_ids, scores = self._bm25_hits(query_bundle)
        if not kg_ids:
            # No index loaded or nothing matched: nothing to expand.
            return []
        kg_nodes = await self._graph_store.aget(ids=kg_ids)
        triplets = await self._graph_store.aget_rel_map(
            kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
        )
        return self._rank_triplets(triplets, kg_ids, scores)

    def _bm25_hits(self, query_bundle: QueryBundle):
        """Run the BM25 query; return ``(kg_ids, scores)`` as parallel lists.

        Both lists are empty when no BM25 index was loaded.
        """
        if self._bm25Retriever is None:
            return [], []
        # NOTE(review): the raw query string is passed, as in the original
        # code; confirm CHBM25Retriever._retrieve expects a str here.
        query_result: List[NodeWithScore] = self._bm25Retriever._retrieve(
            query_bundle.query_str
        )
        nodes = [hit.node for hit in query_result]
        scores = [hit.score for hit in query_result]
        return self._get_kg_ids(nodes), scores

    def _rank_triplets(self, triplets, kg_ids, scores) -> List[NodeWithScore]:
        """Score, optionally filter, and sort triplets (highest score first)."""
        # First occurrence wins, matching the semantics of list.index().
        score_by_id: dict = {}
        for kg_id, score in zip(kg_ids, scores):
            score_by_id.setdefault(kg_id, score)

        scored = [
            (
                triplet,
                max(
                    score_by_id.get(triplet[0].id, 0.0),
                    score_by_id.get(triplet[2].id, 0.0),
                ),
            )
            for triplet in triplets
        ]

        # ``is not None`` so that an explicit threshold of 0.0 is honoured.
        if self._similarity_score is not None:
            scored = [pair for pair in scored if pair[1] >= self._similarity_score]

        scored.sort(key=lambda pair: pair[1], reverse=True)
        return self._get_nodes_with_score(
            [pair[0] for pair in scored], [pair[1] for pair in scored]
        )

    def _get_kg_ids(self, kg_nodes: Sequence[BaseNode]) -> List[str]:
        """Backward compatibility method to get kg_ids from kg_nodes."""
        return [node.metadata.get(VECTOR_SOURCE_KEY, node.id_) for node in kg_nodes]
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
from llama_index.core.llms.llm import LLM
|
||||||
|
from llama_index.core.indices.property_graph.sub_retrievers.base import (
|
||||||
|
BasePGRetriever,
|
||||||
|
)
|
||||||
|
from llama_index.core.graph_stores.types import (
|
||||||
|
PropertyGraphStore,
|
||||||
|
KG_SOURCE_REL,
|
||||||
|
)
|
||||||
|
from llama_index.core.settings import Settings
|
||||||
|
from llama_index.core.schema import (
|
||||||
|
NodeWithScore,
|
||||||
|
QueryBundle,
|
||||||
|
)
|
||||||
|
from llama_index.core.graph_stores.types import EntityNode
|
||||||
|
|
||||||
|
|
||||||
|
class GraphKeyWordRetriever(BasePGRetriever):
    """Property-graph sub-retriever that matches entity names against the query text.

    Every entity node whose name occurs as a substring of the query string
    is selected, and the triplets around the selected nodes are returned.
    """

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        include_text: bool = True,
        path_depth: int = 1,
        llm: Optional[LLM] = None,
        **kwargs: Any,
    ) -> None:
        """Create the retriever.

        Args:
            graph_store: store queried for nodes and relation maps.
            include_text: forwarded to ``BasePGRetriever``.
            path_depth: depth passed to the relation-map lookup.
            llm: stored on the instance; ``Settings.llm`` is used when falsy.
            **kwargs: forwarded to ``BasePGRetriever``.
        """
        self._llm = llm or Settings.llm
        self._path_depth = path_depth
        super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)

    def _mentioned_entities(self, candidates, query_text: str) -> list:
        """Keep the entity nodes whose name is contained in ``query_text``."""
        return [
            candidate
            for candidate in candidates
            if isinstance(candidate, EntityNode) and candidate.name in query_text
        ]

    def _prepare_matches(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Select matching entities and fetch their triplets (synchronous)."""
        matched = self._mentioned_entities(
            self._graph_store.get(), query_bundle.query_str
        )
        rel_map = self._graph_store.get_rel_map(
            matched,
            depth=self._path_depth,
            ignore_rels=[KG_SOURCE_REL],
        )
        return self._get_nodes_with_score(rel_map)

    async def _aprepare_matches(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Select matching entities and fetch their triplets (asynchronous)."""
        all_nodes = await self._graph_store.aget()
        matched = self._mentioned_entities(all_nodes, query_bundle.query_str)
        rel_map = await self._graph_store.aget_rel_map(
            matched,
            depth=self._path_depth,
            ignore_rels=[KG_SOURCE_REL],
        )
        return self._get_nodes_with_score(rel_map)

    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve triplets for the query (synchronous entry point)."""
        return self._prepare_matches(query_bundle)

    async def aretrieve_from_graph(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Retrieve triplets for the query (asynchronous entry point)."""
        return await self._aprepare_matches(query_bundle)
|
||||||
+5
-10
@@ -89,7 +89,6 @@ class OllamaPlatform(ModelPlatform):
|
|||||||
)
|
)
|
||||||
return [rerank]
|
return [rerank]
|
||||||
|
|
||||||
|
|
||||||
@register(ModelPlateCategory,'xinference')
|
@register(ModelPlateCategory,'xinference')
|
||||||
class XinferencePlatform(ModelPlatform):
|
class XinferencePlatform(ModelPlatform):
|
||||||
def model(self):
|
def model(self):
|
||||||
@@ -123,15 +122,11 @@ class XinferencePlatform(ModelPlatform):
|
|||||||
class OpenAIPlatform(ModelPlatform):
|
class OpenAIPlatform(ModelPlatform):
|
||||||
def model(self):
|
def model(self):
|
||||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||||
from llama_index.llms.openai import OpenAI
|
from app.engine.model.siliconCloudOpenAI import SiliconCloudOpenAI
|
||||||
|
return SiliconCloudOpenAI(api_key= os.getenv('OPENAI_API_KEY'),
|
||||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
api_base= os.getenv('BASE_URL'),
|
||||||
config = {
|
model= os.getenv('MODEL'),
|
||||||
"model": os.getenv("MODEL"),
|
temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)))
|
||||||
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
|
|
||||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
|
||||||
}
|
|
||||||
return OpenAI(**config)
|
|
||||||
|
|
||||||
def embedding(self):
|
def embedding(self):
|
||||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||||
|
|||||||
Reference in New Issue
Block a user