From f7260da6d9e176dbfb234b2c0f9cf500bd58f171 Mon Sep 17 00:00:00 2001 From: wanyaokun <12345678> Date: Fri, 20 Sep 2024 17:34:38 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=B1=9E=E6=80=A7=E5=9B=BE?= =?UTF-8?q?=E6=A3=80=E7=B4=A2=E5=8A=9F=E8=83=BD=E5=8F=8A=E6=94=AF=E6=8C=81?= =?UTF-8?q?OpenAI=E7=BA=BF=E4=B8=8A=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/.env.example | 9 +++ backend/app/engine/generate.py | 4 +- backend/app/engine/graph/extractor.py | 74 +++++++++++-------- backend/app/engine/graph/graphStore.py | 42 +++++++++++ backend/app/engine/graph/graphTypes.py | 38 ++++++++++ backend/app/engine/graph/propertyGraph.py | 19 ++--- backend/app/engine/index.py | 4 +- .../app/engine/model/siliconCloudOpenAI.py | 19 +++++ backend/app/engine/prompt.py | 66 ++++++++++------- .../engine/retriever/graphBM25Retriever.py | 73 ++++++++++++++++++ .../engine/retriever/graphKeyWordRetriever.py | 63 ++++++++++++++++ backend/app/settings.py | 15 ++-- 12 files changed, 350 insertions(+), 76 deletions(-) create mode 100644 backend/app/engine/graph/graphStore.py create mode 100644 backend/app/engine/graph/graphTypes.py create mode 100644 backend/app/engine/model/siliconCloudOpenAI.py create mode 100644 backend/app/engine/retriever/graphBM25Retriever.py create mode 100644 backend/app/engine/retriever/graphKeyWordRetriever.py diff --git a/backend/.env.example b/backend/.env.example index c3d85dd..05200a6 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -41,6 +41,15 @@ MODEL_PROVIDER=dashscope DASHSCOPE_API_KEY=sk-221d2d202e104618a56002ce2e7dc0d0 MODEL=qwen2-math-72b-instruct +# #---------- model - openai ---------------- +# MODEL_PROVIDER=openai +# OPENAI_API_KEY=sk-hhoqttvhibirwheyponjifsqwssgxotoqlcjufkidytwxngi +# BASE_URL=https://api.siliconflow.cn/v1 +# MODEL=alibaba/Qwen1.5-110B-Chat +# LLM_TEMPERATURE=0.1 +# CONTEXT_WINDOW = 8192 +# IS_CHAT_MODEL = true +# IS_FUN_CALL_MODEL = false #---------- embedding - Xinference ---------------- diff --git a/backend/app/engine/generate.py b/backend/app/engine/generate.py index e2e7a9f..dd1ea5c 100644 --- a/backend/app/engine/generate.py +++ b/backend/app/engine/generate.py @@ -16,6 +16,7 @@ from llama_index.core.storage import StorageContext from llama_index.core.storage.docstore import SimpleDocumentStore from llama_index.core import PropertyGraphIndex from app.engine.graph.extractor import PrjGraphExtractor +from app.engine.graph.graphStore import RAGPropertyGraphStore logging.basicConfig(level=logging.INFO) logger = logging.getLogger() @@ -103,7 +104,8 @@ class PropertyGraphChache: def simplePropertyGraph(self,prjName:str,prjFlag:str,filePath:str): documents = get_documents(prjFlag) - storeContext = StorageContext.from_defaults(vector_store=get_vector_store(prjFlag)) + storeContext = StorageContext.from_defaults(vector_store=get_vector_store(prjFlag), + property_graph_store = RAGPropertyGraphStore()) index = PropertyGraphIndex( nodes =documents, kg_extractors = [PrjGraphExtractor(prjName)], diff --git a/backend/app/engine/graph/extractor.py b/backend/app/engine/graph/extractor.py index f5e8a8e..2fbe41c 100644 --- a/backend/app/engine/graph/extractor.py +++ b/backend/app/engine/graph/extractor.py @@ -1,21 +1,20 @@ import os from llama_index.core.schema import TransformComponent, BaseNode from llama_index.core.graph_stores.types import ( - EntityNode, - Relation, - Triplet, KG_NODES_KEY, KG_RELATIONS_KEY, ) -from app.engine.loaders.projectJson import ProjectJson from app.engine.loaders.markdownReader import ChunkMarkdownReader +from app.engine.graph.graphTypes import * +import uuid class PrjGraphExtractor(TransformComponent): ProjectName:str + _nodeMaps = {} + _prjID = '' def __init__(self,PrjName:str): super().__init__(ProjectName = PrjName) - - + def __call__( self, llama_nodes: list[BaseNode], **kwargs ) -> list[BaseNode]: @@ -38,9 +37,9 @@ class PrjGraphExtractor(TransformComponent): records:dict[str,list] = self._getRecordNode(llama_node) fInfos = fileName.split('_') if len(fInfos) == 1: - existing_nodes.append(EntityNode(name=fInfos[0], label=fInfos[0])) + fileID = self._add_node(existing_nodes = existing_nodes,name=fInfos[0], label=fInfos[0]) elif len(fInfos) == 2: - existing_nodes.append(EntityNode(name=fileName, label=fInfos[1])) + fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fInfos[1]) else: raise ValueError("文件名存在多个下划线") @@ -48,22 +47,20 @@ class PrjGraphExtractor(TransformComponent): for record in records: index = index + 1 rcdName = self._getRecordName(fileName,record) - existing_nodes.append(EntityNode(name=rcdName, label=rcdName,properties = record)) + rcdid = self._add_node(existing_nodes = existing_nodes,name=rcdName, label=rcdName,properties = record) existing_relations.append( - Relation( + RagRelation( label="包含", - source_id= fileName, - target_id= rcdName, - properties={}, + source_id= fileID, + target_id= rcdid ) ) existing_relations.append( - Relation( + RagRelation( label="包含", - source_id= self.ProjectName, - target_id= fileName, - properties={}, + source_id= self._prjID, + target_id= fileID ) ) @@ -76,28 +73,26 @@ class PrjGraphExtractor(TransformComponent): existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, []) existing_relations:list = llama_node.metadata.pop(KG_RELATIONS_KEY, []) records:dict[str,list] = self._getRecordNode(llama_node) - existing_nodes.append(EntityNode(name=fileName, label=fileName)) + fileID = self._add_node(existing_nodes = existing_nodes,name=fileName, label=fileName) index = 0 for record in records: index = index + 1 attName = self._getRecordName(fileName,record) - existing_nodes.append(EntityNode(name=attName, label='属性',properties = record)) + attID = self._add_node(existing_nodes = existing_nodes,name=attName, label= attName,properties = record) existing_relations.append( - Relation( + RagRelation( label="聚合", - source_id= fileName, - target_id= attName, - properties={}, + source_id= fileID, + target_id= attID ) ) existing_relations.append( - Relation( + RagRelation( label="包含", - source_id= self.ProjectName, - target_id= fileName, - properties={}, + source_id= self._prjID, + target_id= fileID ) ) @@ -118,11 +113,32 @@ class PrjGraphExtractor(TransformComponent): def _addPrjNode(self,llama_node:BaseNode): existing_nodes:list = llama_node.metadata.pop(KG_NODES_KEY, []) - existing_nodes.append(EntityNode(name=self.ProjectName, label=self.ProjectName)) + self._prjID = self._add_node(existing_nodes = existing_nodes,name=self.ProjectName,label=self.ProjectName) llama_node.metadata[KG_NODES_KEY] = existing_nodes def _getRecordName(self,fileName:str,record:dict): for name,value in record.items(): if '名称' in name: return value - raise ValueError('记录名称为空') \ No newline at end of file + raise ValueError('记录名称为空') + + def _add_node(self,existing_nodes:list,name:str,label:str,properties:dict = {}): + id:str = '' + if name in self._nodeMaps: + nodes:list[RagEntityNode] = self._nodeMaps[name] + for node in nodes: + if node.properties == properties and node.name == name and node.label == label: + id = node.id + break + + if id =='': + id = str(uuid.uuid1()) + newNode = RagEntityNode(name = name,label=label,properties = properties,uid = id) + existing_nodes.append(newNode) + if name in self._nodeMaps: + nodes:list[RagEntityNode] = self._nodeMaps[name] + nodes.append(newNode) + else: + self._nodeMaps[name] = [newNode] + + return id \ No newline at end of file diff --git a/backend/app/engine/graph/graphStore.py b/backend/app/engine/graph/graphStore.py new file mode 100644 index 0000000..60ac9ba --- /dev/null +++ b/backend/app/engine/graph/graphStore.py @@ -0,0 +1,42 @@ + +from typing import Any, List, Dict, Sequence, Tuple, Optional +from llama_index.core.graph_stores.simple_labelled import SimplePropertyGraphStore +from llama_index.core.graph_stores.types import LabelledNode,ChunkNode,LabelledPropertyGraph +from app.engine.graph.graphTypes import * + +class RAGPropertyGraphStore(SimplePropertyGraphStore): + @classmethod + def from_dict( + cls, + data: dict, + ) -> "SimplePropertyGraphStore": + """Load from dict.""" + # need to load nodes manually + node_dicts = data["nodes"] + relation_dicts = data["relations"] + + kg_nodes: Dict[str, LabelledNode] = {} + kg_relations: Dict[str, LabelledNode] = {} + for id, node_dict in node_dicts.items(): + if "name" in node_dict: + kg_nodes[id] = RagEntityNode.model_validate(node_dict) + elif "text" in node_dict: + kg_nodes[id] = ChunkNode.model_validate(node_dict) + else: + raise ValueError(f"Could not infer node type for data: {node_dict!s}") + + for id, node_dict in relation_dicts.items(): + kg_relations[id] = RagRelation.model_validate(node_dict) + + # clear the nodes, to load later + data["nodes"] = {} + data["relations"] = {} + + # load the graph + graph = LabelledPropertyGraph.model_validate(data) + + # add the node back + graph.nodes = kg_nodes + graph.relations = kg_relations + + return cls(graph) \ No newline at end of file diff --git a/backend/app/engine/graph/graphTypes.py b/backend/app/engine/graph/graphTypes.py new file mode 100644 index 0000000..6202c7b --- /dev/null +++ b/backend/app/engine/graph/graphTypes.py @@ -0,0 +1,38 @@ +from llama_index.core.graph_stores.types import ( + EntityNode, + Relation +) + +class RagEntityNode(EntityNode): + uid : str + def __str__(self) -> str: + """Return the string representation of the node.""" + if self.properties: + prop = self.properties + if 'triplet_source_id' in prop: + prop.pop('triplet_source_id') + if len(prop) > 0: + return f"{self.name} ({prop})" + return self.name + + @property + def id(self) -> str: + """Get the node id.""" + #return self.name.replace('"', " ") + return self.uid + +class RagRelation(Relation): + def __str__(self) -> str: + """Return the string representation of the relation.""" + if self.properties: + prop = self.properties + if 'triplet_source_id' in prop: + prop.pop('triplet_source_id') + if len(prop) > 0: + return f"{self.label} ({prop})" + return self.label + + @property + def id(self) -> str: + """Get the relation id.""" + return self.label \ No newline at end of file diff --git a/backend/app/engine/graph/propertyGraph.py b/backend/app/engine/graph/propertyGraph.py index cba20c6..89f868d 100644 --- a/backend/app/engine/graph/propertyGraph.py +++ b/backend/app/engine/graph/propertyGraph.py @@ -2,7 +2,7 @@ from llama_index.core.indices.property_graph import LLMSynonymRetriever,VectorCo from llama_index.core.indices.property_graph.transformations.schema_llm import * from llama_index.core import SimpleDirectoryReader from llama_index.core import settings -from llama_index.core import PropertyGraphIndex +from llama_index.core import PropertyGraphIndex,KnowledgeGraphIndex from typing import List,Tuple,Literal from app.settings import init_settings import os @@ -15,7 +15,9 @@ from util.register import * from llama_index.core.query_engine import RetrieverQueryEngine from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template from app.engine.engine import get_node_postprocessors +from app.engine.graph.graphStore import RAGPropertyGraphStore +from app.engine.retriever.graphKeyWordRetriever import GraphKeyWordRetriever class PropertyGraph: def __init__(self,prjFlag:str) -> None: self._prjFlag = prjFlag @@ -44,16 +46,15 @@ class PropertyGraph: prjCachePath = GRAPH_STORAGE_DIR + f"/{self._prjFlag}" if not os.path.exists(prjCachePath): return None - storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(self._prjFlag)) + storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(self._prjFlag), + property_graph_store = RAGPropertyGraphStore.from_persist_dir(prjCachePath)) index = load_index_from_storage(storeContext) return index def query(self,query_str:str): index = self.getPropertyGraphIndex() - synonym_retriver = LLMSynonymRetriever(index.property_graph_store, - llm=settings.Settings.llm, - max_keywords=10, - include_text=False + synonym_retriver = GraphKeyWordRetriever(index.property_graph_store, + include_text=False ) if index.property_graph_store.supports_vector_queries: vector_store = None @@ -62,7 +63,7 @@ class PropertyGraph: vector_retriver = VectorContextRetriever(index.property_graph_store, vector_store = vector_store, embed_model=settings.Settings.embed_model, - similarity_top_k=5, + similarity_top_k=10, include_text=False ) @@ -77,8 +78,8 @@ class PropertyGraph: if __name__ == "__main__": init_settings() init_observability() - graph = PropertyGraph('projects_1b20bbf4-3243-4ac3-bcf0-8a91e9157521') - graph.query('代码为XLBT的金额是') + graph = PropertyGraph('projects_0ffaf7fb-8a61-46e2-97a2-8f924e9560a7') + graph.query('工程属性表有哪些字段') diff --git a/backend/app/engine/index.py b/backend/app/engine/index.py index 176e445..3351df4 100644 --- a/backend/app/engine/index.py +++ b/backend/app/engine/index.py @@ -5,6 +5,7 @@ from typing import Dict,Any from llama_index.core import PropertyGraphIndex from llama_index.core.storage.storage_context import StorageContext from llama_index.core import load_index_from_storage +from app.engine.graph.graphStore import RAGPropertyGraphStore logger = logging.getLogger("uvicorn") @@ -33,6 +34,7 @@ def getPropertyGraphIndex(prjFlag:str): prjCachePath = GRAPH_STORAGE_DIR + f"/{prjFlag}" if not os.path.exists(prjCachePath): return None - storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(prjFlag)) + storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(prjFlag), + property_graph_store = RAGPropertyGraphStore.from_persist_dir(prjCachePath)) index = load_index_from_storage(storeContext) return index \ No newline at end of file diff --git a/backend/app/engine/model/siliconCloudOpenAI.py b/backend/app/engine/model/siliconCloudOpenAI.py new file mode 100644 index 0000000..1d68b63 --- /dev/null +++ b/backend/app/engine/model/siliconCloudOpenAI.py @@ -0,0 +1,19 @@ +from llama_index.llms.openai import OpenAI +from llama_index.core.base.llms.types import LLMMetadata +import os + +class SiliconCloudOpenAI(OpenAI): + @property + def metadata(self) -> LLMMetadata: + bIsChat = os.getenv('IS_CHAT_MODEL') + bIsFuncall = os.getenv('IS_FUN_CALL_MODEL') + bIsChat = True if bIsChat.lower() in ['true','1'] else False + bIsFuncall = True if bIsFuncall.lower() in ['true','1'] else False + + return LLMMetadata( + context_window= int(os.getenv('CONTEXT_WINDOW')), + num_output=self.max_tokens or -1, + is_chat_model=bIsChat, + is_function_calling_model=bIsFuncall, + model_name=self.model, + ) \ No newline at end of file diff --git a/backend/app/engine/prompt.py b/backend/app/engine/prompt.py index 5e3c954..1f90ebe 100644 --- a/backend/app/engine/prompt.py +++ b/backend/app/engine/prompt.py @@ -48,32 +48,46 @@ refine_template_str = ( refine_template = PromptTemplate(refine_template_str) -summary_template_str = ( - "# 角色\n" - "你是一名擅长博微造价工程数据问答的专家,可以根据电力工程文件中的内容回答用户问题。\n" - "\n" - "# 任务描述:\n" - "请仔细阅读所给的文档片段,并根据其内容回答问题。\n" - "您需要判断文档的内容是否可以回答问题,不要强行回答。如果可以回答,答案必须严格遵循文档内容,即使与事实不符。\n" - "如果答案与事实不符,直接给出答案,不要做解释。\n" - "\n" - "# 回答规则:\n" - "- 请使用与文档材料相同的语言回答问题。\n" - "- 评估文档是否含有足够信息回答问题。无关时不要回答。\n" - "- 如果问题能被回答,你的回答必须严格遵循文档内容,即使与事实不符。一定不要做多余解释。\n" - "- 如果问题能被回答,直接引用文档的相关信息保证答案准确、完整,并追求简洁。\n" - "- 当文档中只有少量信息与问题相关时,重点关注这部分信息,这种情况下一定回答。\n" - "- 当文档中信息与问题无关时,请不要额外发散回答,只需要回答为' '" - "\n" - "来自多个来源的文档片段如下,请充分理解以下参考资料内容,组织出满足用户提问的条理清晰的回复。\n" - "---------------------\n" - "{context_str}\n" - "---------------------\n" - "鉴于来自多个来源的文档片段而非先验知识,回答查询。\n" - "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n" - "Query: {query_str}\n" - "Answer: " -) +# summary_template_str = ( +# "# 角色\n" +# "你是一名擅长博微造价工程数据问答的专家,可以根据电力工程文件中的内容回答用户问题。\n" +# "\n" +# "# 任务描述:\n" +# "请仔细阅读所给的文档片段,并根据其内容回答问题。\n" +# "您需要判断文档的内容是否可以回答问题,不要强行回答。如果可以回答,答案必须严格遵循文档内容,即使与事实不符。\n" +# "如果答案与事实不符,直接给出答案,不要做解释。\n" +# "\n" +# "# 回答规则:\n" +# "- 请使用与文档材料相同的语言回答问题。\n" +# "- 评估文档是否含有足够信息回答问题。无关时不要回答。\n" +# "- 如果问题能被回答,你的回答必须严格遵循文档内容,即使与事实不符。一定不要做多余解释。\n" +# "- 如果问题能被回答,直接引用文档的相关信息保证答案准确、完整,并追求简洁。\n" +# "- 当文档中只有少量信息与问题相关时,重点关注这部分信息,这种情况下一定回答。\n" +# "- 当文档中信息与问题无关时,请不要额外发散回答,只需要回答为' '。\n" +# "\n" +# "来自多个来源的文档片段如下,请充分理解以下参考资料内容,组织出满足用户提问的条理清晰的回复。\n" +# "---------------------\n" +# "{context_str}\n" +# "---------------------\n" +# "鉴于来自多个来源的文档片段而非先验知识,回答查询。\n" +# "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n" +# "Query: {query_str}\n" +# "Answer: " +# ) + + +summary_template_str = """ + 你是一名擅长博微造价工程数据问答的专家,可以根据电力工程文件中的内容回答用户问题。 + 来自多个来源的文档片段如下,请充分理解以下参考资料内容,回答问题。 + --------------------- + {context_str} + --------------------- + 当你不知道答案的时候,不要编造答案,直接回答不知道,不需要解释为什么不知道。 + 问题: {query_str} + 回答: +""" + + summary_template = PromptTemplate(summary_template_str) simple_template_str = ( diff --git a/backend/app/engine/retriever/graphBM25Retriever.py b/backend/app/engine/retriever/graphBM25Retriever.py new file mode 100644 index 0000000..a9ed361 --- /dev/null +++ b/backend/app/engine/retriever/graphBM25Retriever.py @@ -0,0 +1,73 @@ +import os +from typing import Any, List, Sequence, Optional +from llama_index.core.schema import BaseNode, NodeWithScore, QueryBundle +from llama_index.core.graph_stores.types import ( + PropertyGraphStore, + KG_SOURCE_REL, + VECTOR_SOURCE_KEY, +) +from llama_index.core.indices.property_graph.sub_retrievers.base import BasePGRetriever +from llama_index.core.graph_stores.types import ( + PropertyGraphStore, + KG_SOURCE_REL +) +from app.engine.retriever.CHBM25Retriever import CHBM25Retriever + +class GraphBM25Retriever(BasePGRetriever): + def __init__( + self, + graph_store: PropertyGraphStore, + include_text: bool = True, + path_depth: int = 1, + similarity_score: Optional[float] = None, + **kwargs: Any, + ) -> None: + self._path_depth = path_depth + self._similarity_score = similarity_score + STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm") + if os.path.exists(STORAGE_DIR) and len(os.listdir(STORAGE_DIR)) > 0: + self._bm25Retriever = CHBM25Retriever.from_persist_dir(STORAGE_DIR) + super().__init__(graph_store=graph_store, include_text=include_text, **kwargs) + + async def aretrieve_from_graph( + self, query_bundle: QueryBundle + ) -> List[NodeWithScore]: + query_result:List[NodeWithScore] = self._bm25Retriever._retrieve(query_bundle.query_str) + nodes,scores = [],[] + for scoreNode in query_result: + nodes.append(scoreNode.node) + scores.append(scoreNode.score) + + kg_ids = self._get_kg_ids(nodes) + kg_nodes = await self._graph_store.aget(ids=kg_ids) + triplets = await self._graph_store.aget_rel_map( + kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL] + ) + new_scores = [] + for triplet in triplets: + score1 = ( + scores[kg_ids.index(triplet[0].id)] if triplet[0].id in kg_ids else 0.0 + ) + score2 = ( + scores[kg_ids.index(triplet[2].id)] if triplet[2].id in kg_ids else 0.0 + ) + new_scores.append(max(score1, score2)) + + assert len(triplets) == len(new_scores) + + if self._similarity_score: + filtered_data = [ + (triplet, score) + for triplet, score in zip(triplets, new_scores) + if score >= self._similarity_score + ] + + top_k = sorted(filtered_data, key=lambda x: x[1], reverse=True) + else: + top_k = sorted(zip(triplets, new_scores), key=lambda x: x[1], reverse=True) + + return self._get_nodes_with_score([x[0] for x in top_k], [x[1] for x in top_k]) + + def _get_kg_ids(self, kg_nodes: Sequence[BaseNode]) -> List[str]: + """Backward compatibility method to get kg_ids from kg_nodes.""" + return [node.metadata.get(VECTOR_SOURCE_KEY, node.id_) for node in kg_nodes] \ No newline at end of file diff --git a/backend/app/engine/retriever/graphKeyWordRetriever.py b/backend/app/engine/retriever/graphKeyWordRetriever.py new file mode 100644 index 0000000..305fd78 --- /dev/null +++ b/backend/app/engine/retriever/graphKeyWordRetriever.py @@ -0,0 +1,63 @@ +from typing import Any, Callable, List, Optional, Union +from llama_index.core.llms.llm import LLM +from llama_index.core.indices.property_graph.sub_retrievers.base import ( + BasePGRetriever, +) +from llama_index.core.graph_stores.types import ( + PropertyGraphStore, + KG_SOURCE_REL, +) +from llama_index.core.settings import Settings +from llama_index.core.schema import ( + NodeWithScore, + QueryBundle, +) +from llama_index.core.graph_stores.types import EntityNode + + +class GraphKeyWordRetriever(BasePGRetriever): + def __init__( + self, + graph_store: PropertyGraphStore, + include_text: bool = True, + path_depth: int = 1, + llm: Optional[LLM] = None, + **kwargs: Any, + ) -> None: + self._llm = llm or Settings.llm + self._path_depth = path_depth + super().__init__(graph_store=graph_store, include_text=include_text, **kwargs) + + def _prepare_matches(self,query_bundle: QueryBundle) -> List[NodeWithScore]: + kg_nodes = [] + labelNodes = self._graph_store.get() + for labelNode in labelNodes: + if isinstance(labelNode,EntityNode) and labelNode.name in query_bundle.query_str: + kg_nodes.append(labelNode) + triplets = self._graph_store.get_rel_map( + kg_nodes, + depth=self._path_depth, + ignore_rels=[KG_SOURCE_REL], + ) + return self._get_nodes_with_score(triplets) + + async def _aprepare_matches(self,query_bundle: QueryBundle) -> List[NodeWithScore]: + kg_nodes = [] + labelNodes = await self._graph_store.aget() + for labelNode in labelNodes: + if isinstance(labelNode,EntityNode) and labelNode.name in query_bundle.query_str: + kg_nodes.append(labelNode) + triplets = await self._graph_store.aget_rel_map( + kg_nodes, + depth=self._path_depth, + ignore_rels=[KG_SOURCE_REL], + ) + return self._get_nodes_with_score(triplets) + + def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + return self._prepare_matches(query_bundle) + + async def aretrieve_from_graph( + self, query_bundle: QueryBundle + ) -> List[NodeWithScore]: + return await self._aprepare_matches(query_bundle) diff --git a/backend/app/settings.py b/backend/app/settings.py index e91a971..45d5ad6 100644 --- a/backend/app/settings.py +++ b/backend/app/settings.py @@ -89,7 +89,6 @@ class OllamaPlatform(ModelPlatform): ) return [rerank] - @register(ModelPlateCategory,'xinference') class XinferencePlatform(ModelPlatform): def model(self): @@ -123,15 +122,11 @@ class XinferencePlatform(ModelPlatform): class OpenAIPlatform(ModelPlatform): def model(self): from llama_index.core.constants import DEFAULT_TEMPERATURE - from llama_index.llms.openai import OpenAI - - max_tokens = os.getenv("LLM_MAX_TOKENS") - config = { - "model": os.getenv("MODEL"), - "temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)), - "max_tokens": int(max_tokens) if max_tokens is not None else None, - } - return OpenAI(**config) + from app.engine.model.siliconCloudOpenAI import SiliconCloudOpenAI + return SiliconCloudOpenAI(api_key= os.getenv('OPENAI_API_KEY'), + api_base= os.getenv('BASE_URL'), + model= os.getenv('MODEL'), + temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE))) def embedding(self): from llama_index.embeddings.openai import OpenAIEmbedding