新增工程数据
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
@@ -555,7 +556,7 @@ def upload_file(request: ChatFileUploadRequest):
|
||||
logger.error(f"Error processing file: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Error processing file")
|
||||
|
||||
@v.post("/applications")
|
||||
@v.post("/project")
|
||||
def upload_file(request: ChatFileUploadRequest):
|
||||
try:
|
||||
logger.info("Processing file")
|
||||
|
||||
@@ -161,11 +161,13 @@ class ProjectInfo:
|
||||
dbManage.createTable(self._tableName)
|
||||
|
||||
def add(self,name:str,flag:str):
|
||||
record = {
|
||||
'prjectName': name,
|
||||
'prjFlag': flag
|
||||
}
|
||||
dbManage.addRecord(self._tableName,record)
|
||||
info = dbManage.query(self._tableName,prjFlag = flag)
|
||||
if len(info) == 0:
|
||||
record = {
|
||||
'prjectName': name,
|
||||
'prjFlag': flag
|
||||
}
|
||||
dbManage.addRecord(self._tableName,record)
|
||||
|
||||
def projectNames(self)->List[str]:
|
||||
records = dbManage.query(self._tableName)
|
||||
|
||||
@@ -39,7 +39,7 @@ class BaseConfig(BaseModel):
|
||||
"type": "select",
|
||||
"max_length": 48,
|
||||
"required": True,
|
||||
"options": [projectInfo]
|
||||
"options": projectInfo
|
||||
}
|
||||
}
|
||||
],
|
||||
|
||||
@@ -13,9 +13,7 @@ from app.api.routers.request.base import ProjectInfo
|
||||
def getPrjFalg(params:dict=None)->str:
|
||||
prjFlag = ''
|
||||
if params is not None:
|
||||
inputs:dict = params.get('inputs')
|
||||
if inputs is not None:
|
||||
prjFlag = ProjectInfo.prjFalg(inputs.get('projectname'))
|
||||
prjFlag = ProjectInfo().prjFalg(params.get('projectname'))
|
||||
return prjFlag
|
||||
|
||||
|
||||
@@ -33,13 +31,7 @@ def get_chat_engine(filters=None, params:dict=None):
|
||||
#tools.append(sql_query_tool)
|
||||
|
||||
# Add query tool if index exists
|
||||
prjFlag = ''
|
||||
if params is not None:
|
||||
inputs:dict = params.get('inputs')
|
||||
if inputs is not None:
|
||||
prjFlag = inputs.get('projectname')
|
||||
|
||||
index = get_index(prjFlag = getPrjFalg(params))
|
||||
index = get_index(getPrjFalg(params))
|
||||
if index is not None:
|
||||
summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters)
|
||||
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
|
||||
|
||||
@@ -5,17 +5,11 @@ from app.engine.loaders import get_document_Types
|
||||
from typing import Dict,Any
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
def get_index(**args):
|
||||
def get_index(prjFlag:str):
|
||||
if prjFlag is None or prjFlag == '':
|
||||
raise ValueError('无效的工程标识')
|
||||
logger.info("Connecting vector store...")
|
||||
if 'prjFlag' in args:
|
||||
prjFlags = get_document_Types()
|
||||
if len(prjFlags)<=0:
|
||||
return None
|
||||
prjFlag = args.get('prjFlag','')
|
||||
flag = prjFlags[0] if prjFlag not in prjFlags else prjFlag
|
||||
else:
|
||||
flag = ''
|
||||
store = get_vector_store(flag)
|
||||
store = get_vector_store(prjFlag)
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished load index from vector store.")
|
||||
return index
|
||||
|
||||
@@ -3,8 +3,10 @@ import yaml
|
||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
|
||||
from app.engine.loaders.projectJson import getProjectName
|
||||
import os
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_configs():
|
||||
@@ -55,6 +57,31 @@ def get_document_Types():
|
||||
types.append(path_difference(rootPath,curDir))
|
||||
return types
|
||||
|
||||
def getProjectInfos():
|
||||
config = load_configs()
|
||||
if config is None or len(config.items()) == 0:
|
||||
return None
|
||||
|
||||
prjDir = None
|
||||
for loader_type, loader_config in config.items():
|
||||
if loader_config.get('enable', True):
|
||||
loader_config = loader_config or []
|
||||
config = FileLoaderConfig(**loader_config)
|
||||
prjDir = config.data_dir
|
||||
break
|
||||
if prjDir is None:
|
||||
return None
|
||||
|
||||
prjInfos = []
|
||||
prjFlags = get_document_Types()
|
||||
for prjFlag in prjFlags:
|
||||
fileDir = os.path.join(config.data_dir,prjFlag.replace('_','\\'))
|
||||
prjInfo = {}
|
||||
prjInfo['flag'] = prjFlag
|
||||
prjInfo['name'] = getProjectName(fileDir)
|
||||
prjInfos.append(prjInfo)
|
||||
return prjInfos
|
||||
|
||||
def get_documents(docType:str):
|
||||
documents = []
|
||||
config = load_configs()
|
||||
@@ -80,4 +107,4 @@ def get_documents(docType:str):
|
||||
raise ValueError(f"Invalid loader type: {loader_type}")
|
||||
documents.extend(document)
|
||||
|
||||
return documents
|
||||
return documents
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
from typing import Dict,List,Any
|
||||
import json,os
|
||||
|
||||
class Record:
|
||||
def __init__(self,datas:Dict[str,Any]) -> None:
|
||||
self._datas:Dict[str,Any] = datas
|
||||
|
||||
def value(self,key:str):
|
||||
if key in self._datas:
|
||||
return self._datas.get(key)
|
||||
return ''
|
||||
|
||||
class Field:
|
||||
def __init__(self,datas:Dict[str,Any]) -> None:
|
||||
self._datas:Dict[str,Any] = datas
|
||||
|
||||
def value(self,key:str):
|
||||
if key in self._datas:
|
||||
return self._datas.get(key)
|
||||
return ''
|
||||
|
||||
class JsonTable:
|
||||
def __init__(self,filePth:str) -> None:
|
||||
self._filePth = filePth
|
||||
self._fields:Dict[str,Field] = {}
|
||||
self._records:List[Record] = []
|
||||
self._name = ''
|
||||
|
||||
def parse(self):
|
||||
with open(self._filePth, 'r',encoding='utf-8') as file:
|
||||
jsObj = json.load(file)
|
||||
data:dict = jsObj.get('table')
|
||||
self._name = data.get('name')
|
||||
Jsfields = data.get('fields')
|
||||
for jsfiled in Jsfields:
|
||||
field = Field(jsfiled)
|
||||
self._fields[field.value('name')] =field
|
||||
|
||||
JsRecords = data.get('records')
|
||||
for jsRecord in JsRecords:
|
||||
self._records.append(Record(jsRecord))
|
||||
|
||||
def records(self):
|
||||
return self._records
|
||||
|
||||
class ProjectJson:
|
||||
def __init__(self,dir:str) -> None:
|
||||
self._dir = dir
|
||||
self._tables:Dict[str,JsonTable] = {}
|
||||
|
||||
def parse(self):
|
||||
json_files = [f for f in os.listdir(self._dir) if f.endswith('.json')]
|
||||
for json_file in json_files:
|
||||
prjPath = os.path.join(self._dir, json_file)
|
||||
tb = JsonTable(prjPath)
|
||||
tb.parse()
|
||||
basename = os.path.splitext(json_file)[0]
|
||||
self._tables[basename] = tb
|
||||
|
||||
def table(self,tableName:str):
|
||||
return self._tables[tableName]
|
||||
|
||||
def getProjectName(dir:str):
|
||||
prjJson = ProjectJson(dir)
|
||||
prjJson.parse()
|
||||
tb:JsonTable = prjJson.table('工程属性')
|
||||
records = tb.records()
|
||||
for record in records:
|
||||
name = record.value('名称')
|
||||
if name == '工程名称':
|
||||
return record.value('值')
|
||||
return ''
|
||||
|
||||
+10
-4
@@ -7,6 +7,8 @@ from llama_index.llms.xinference import Xinference
|
||||
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
||||
|
||||
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||
from app.engine.loaders import getProjectInfos
|
||||
from app.api.routers.request.base import ProjectInfo
|
||||
|
||||
|
||||
def get_node_postprocessors():
|
||||
@@ -53,7 +55,6 @@ def init_settings():
|
||||
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
|
||||
Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
|
||||
|
||||
|
||||
def init_ollama():
|
||||
# from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
# from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
|
||||
@@ -127,7 +128,6 @@ def init_dashscope():
|
||||
Settings.embed_model = DashScopeEmbedding(model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
|
||||
text_type=DashScopeTextEmbeddingType.TEXT_TYPE_QUERY)
|
||||
|
||||
|
||||
def init_azure_openai():
|
||||
# from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
# from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
@@ -162,7 +162,6 @@ def init_azure_openai():
|
||||
# )
|
||||
pass
|
||||
|
||||
|
||||
def init_fastembed():
|
||||
"""
|
||||
Use Qdrant Fastembed as the local embedding provider.
|
||||
@@ -232,4 +231,11 @@ def init_mistral():
|
||||
#
|
||||
# Settings.llm = MistralAI(model=os.getenv("MODEL"))
|
||||
# Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
|
||||
pass
|
||||
pass
|
||||
|
||||
def init_ProjectInfo():
|
||||
prjObj = ProjectInfo()
|
||||
prjInfos:list[tuple] = getProjectInfos()
|
||||
for prjInfo in prjInfos:
|
||||
prjObj.add(prjInfo['name'],prjInfo['flag'])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user