优化属性图检索功能及支持OpenAI线上模型

This commit is contained in:
wanyaokun
2024-09-20 17:34:38 +08:00
parent 092f7230c1
commit f7260da6d9
12 changed files with 350 additions and 76 deletions
+10 -9
View File
@@ -2,7 +2,7 @@ from llama_index.core.indices.property_graph import LLMSynonymRetriever,VectorCo
from llama_index.core.indices.property_graph.transformations.schema_llm import *
from llama_index.core import SimpleDirectoryReader
from llama_index.core import settings
from llama_index.core import PropertyGraphIndex
from llama_index.core import PropertyGraphIndex,KnowledgeGraphIndex
from typing import List,Tuple,Literal
from app.settings import init_settings
import os
@@ -15,7 +15,9 @@ from util.register import *
from llama_index.core.query_engine import RetrieverQueryEngine
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
from app.engine.engine import get_node_postprocessors
from app.engine.graph.graphStore import RAGPropertyGraphStore
from app.engine.retriever.graphKeyWordRetriever import GraphKeyWordRetriever
class PropertyGraph:
def __init__(self,prjFlag:str) -> None:
self._prjFlag = prjFlag
@@ -44,16 +46,15 @@ class PropertyGraph:
prjCachePath = GRAPH_STORAGE_DIR + f"/{self._prjFlag}"
if not os.path.exists(prjCachePath):
return None
storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(self._prjFlag))
storeContext = StorageContext.from_defaults(persist_dir = prjCachePath,vector_store = get_vector_store(self._prjFlag),
property_graph_store = RAGPropertyGraphStore.from_persist_dir(prjCachePath))
index = load_index_from_storage(storeContext)
return index
def query(self,query_str:str):
index = self.getPropertyGraphIndex()
synonym_retriver = LLMSynonymRetriever(index.property_graph_store,
llm=settings.Settings.llm,
max_keywords=10,
include_text=False
synonym_retriver = GraphKeyWordRetriever(index.property_graph_store,
include_text=False
)
if index.property_graph_store.supports_vector_queries:
vector_store = None
@@ -62,7 +63,7 @@ class PropertyGraph:
vector_retriver = VectorContextRetriever(index.property_graph_store,
vector_store = vector_store,
embed_model=settings.Settings.embed_model,
similarity_top_k=5,
similarity_top_k=10,
include_text=False
)
@@ -77,8 +78,8 @@ class PropertyGraph:
if __name__ == "__main__":
init_settings()
init_observability()
graph = PropertyGraph('projects_1b20bbf4-3243-4ac3-bcf0-8a91e9157521')
graph.query('代码为XLBT的金额是')
graph = PropertyGraph('projects_0ffaf7fb-8a61-46e2-97a2-8f924e9560a7')
graph.query('工程属性表有哪些字段')