新增工程数据

This commit is contained in:
wanyaokun
2024-08-30 19:12:40 +08:00
parent c4088fe963
commit ced3199550
37 changed files with 157 additions and 35 deletions
+2 -10
View File
@@ -13,9 +13,7 @@ from app.api.routers.request.base import ProjectInfo
def getPrjFalg(params:dict=None)->str:
prjFlag = ''
if params is not None:
inputs:dict = params.get('inputs')
if inputs is not None:
prjFlag = ProjectInfo.prjFalg(inputs.get('projectname'))
prjFlag = ProjectInfo().prjFalg(params.get('projectname'))
return prjFlag
@@ -33,13 +31,7 @@ def get_chat_engine(filters=None, params:dict=None):
#tools.append(sql_query_tool)
# Add query tool if index exists
prjFlag = ''
if params is not None:
inputs:dict = params.get('inputs')
if inputs is not None:
prjFlag = inputs.get('projectname')
index = get_index(prjFlag = getPrjFalg(params))
index = get_index(getPrjFalg(params))
if index is not None:
summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters)
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
+4 -10
View File
@@ -5,17 +5,11 @@ from app.engine.loaders import get_document_Types
from typing import Dict,Any
logger = logging.getLogger("uvicorn")
def get_index(**args):
def get_index(prjFlag:str):
if prjFlag is None or prjFlag == '':
raise ValueError('无效的工程标识')
logger.info("Connecting vector store...")
if 'prjFlag' in args:
prjFlags = get_document_Types()
if len(prjFlags)<=0:
return None
prjFlag = args.get('prjFlag','')
flag = prjFlags[0] if prjFlag not in prjFlags else prjFlag
else:
flag = ''
store = get_vector_store(flag)
store = get_vector_store(prjFlag)
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished load index from vector store.")
return index
+28 -1
View File
@@ -3,8 +3,10 @@ import yaml
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
from app.engine.loaders.projectJson import getProjectName
import os
logger = logging.getLogger(__name__)
def load_configs():
@@ -55,6 +57,31 @@ def get_document_Types():
types.append(path_difference(rootPath,curDir))
return types
def getProjectInfos():
config = load_configs()
if config is None or len(config.items()) == 0:
return None
prjDir = None
for loader_type, loader_config in config.items():
if loader_config.get('enable', True):
loader_config = loader_config or []
config = FileLoaderConfig(**loader_config)
prjDir = config.data_dir
break
if prjDir is None:
return None
prjInfos = []
prjFlags = get_document_Types()
for prjFlag in prjFlags:
fileDir = os.path.join(config.data_dir,prjFlag.replace('_','\\'))
prjInfo = {}
prjInfo['flag'] = prjFlag
prjInfo['name'] = getProjectName(fileDir)
prjInfos.append(prjInfo)
return prjInfos
def get_documents(docType:str):
documents = []
config = load_configs()
@@ -80,4 +107,4 @@ def get_documents(docType:str):
raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document)
return documents
return documents
+73
View File
@@ -0,0 +1,73 @@
from typing import Dict,List,Any
import json,os
class Record:
def __init__(self,datas:Dict[str,Any]) -> None:
self._datas:Dict[str,Any] = datas
def value(self,key:str):
if key in self._datas:
return self._datas.get(key)
return ''
class Field:
def __init__(self,datas:Dict[str,Any]) -> None:
self._datas:Dict[str,Any] = datas
def value(self,key:str):
if key in self._datas:
return self._datas.get(key)
return ''
class JsonTable:
def __init__(self,filePth:str) -> None:
self._filePth = filePth
self._fields:Dict[str,Field] = {}
self._records:List[Record] = []
self._name = ''
def parse(self):
with open(self._filePth, 'r',encoding='utf-8') as file:
jsObj = json.load(file)
data:dict = jsObj.get('table')
self._name = data.get('name')
Jsfields = data.get('fields')
for jsfiled in Jsfields:
field = Field(jsfiled)
self._fields[field.value('name')] =field
JsRecords = data.get('records')
for jsRecord in JsRecords:
self._records.append(Record(jsRecord))
def records(self):
return self._records
class ProjectJson:
def __init__(self,dir:str) -> None:
self._dir = dir
self._tables:Dict[str,JsonTable] = {}
def parse(self):
json_files = [f for f in os.listdir(self._dir) if f.endswith('.json')]
for json_file in json_files:
prjPath = os.path.join(self._dir, json_file)
tb = JsonTable(prjPath)
tb.parse()
basename = os.path.splitext(json_file)[0]
self._tables[basename] = tb
def table(self,tableName:str):
return self._tables[tableName]
def getProjectName(dir:str):
prjJson = ProjectJson(dir)
prjJson.parse()
tb:JsonTable = prjJson.table('工程属性')
records = tb.records()
for record in records:
name = record.value('名称')
if name == '工程名称':
return record.value('')
return ''