新增工程信息、检索的知识片段节点回传、下一轮建议问题列表

This commit is contained in:
wanyaokun
2024-08-30 16:42:36 +08:00
parent 73565b26e4
commit e7628809ad
11 changed files with 411 additions and 134 deletions
+127 -50
View File
@@ -3,7 +3,6 @@ import json
import logging
import time
from typing import Dict, List, Any, Optional, AsyncGenerator
from collections import deque
from aiostream import stream
from fastapi import APIRouter, Request,HTTPException
@@ -13,17 +12,19 @@ from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput
from llama_index.core.schema import NodeWithScore
from pydantic import BaseModel
from app.api.routers.request.base import userMng, conversations,message,parameter,feedback
from app.api.routers.request.base import userMng, conversations,message,ProjectInfo,feedback
from app.api.routers.request.baseConfig import *
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine
import uuid
from app.api.routers.services.fileServices import FileLoadService
from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService
from app.api.routers.services.suggestion import NextQuestionSuggestion
import time
logger = logging.getLogger("uvicorn")
api_router = r = APIRouter()
v1_router = v = APIRouter()
class ChatCallbackEvent(BaseModel):
@@ -32,7 +33,7 @@ class ChatCallbackEvent(BaseModel):
def get_common_param(self)-> dict:
return {
'event': self.event_type.name,
'event': self.event_type.value,
'conversation_id':self.payload.get("conversation_id"),
'message_id': self.payload.get("message_id"),
'created_at': int(time.time()),
@@ -48,7 +49,7 @@ class ChatCallbackEvent(BaseModel):
"workflow_id": self.payload.get('workflow_id'),
"sequence_number": 1709,
"inputs": {
"sys.query": self.payload.get('query'),
"sys.query": f"开始查询 {self.payload.get('query')}",
"sys.files": [],
"sys.conversation_id": self.payload.get('conversation_id'),
"sys.user_id": self.payload.get('use_id')
@@ -93,7 +94,7 @@ class ChatCallbackEvent(BaseModel):
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": self.payload.get('title'),
"title": f"正在执行事件:{self.payload.get('title')}",
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
@@ -111,7 +112,7 @@ class ChatCallbackEvent(BaseModel):
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": self.payload.get('title'),
"title": f"事件执行结束:{self.payload.get('title')}",
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
@@ -138,15 +139,54 @@ class ChatCallbackEvent(BaseModel):
def get_MessageEnd_param(self) -> dict:
params = self.get_common_param()
nodeInfos = []
source_nodes = self.payload.get('source_node')
if source_nodes is not None:
for i in range(len(source_nodes)):
source_node:NodeWithScore = source_nodes[i]
metadata:dict = source_node.node.metadata
nodeInfo = {
"position": i,
"dataset_id": metadata.get("pipeline_id"),
"dataset_name": metadata.get("file_name"),
"document_id": source_node.node_id,
"document_name": metadata.get("file_name"),
"data_source_type": "upload_file",
"segment_id": source_node.node_id,
"retriever_from": "workflow",
"score": source_node.score,
"hit_count": 1,
"word_count": 632,
"segment_position": i,
"index_node_hash": "",
"content": source_node.text
}
nodeInfos.append(nodeInfo)
params.update({
'id':self.payload.get('message_id'),
'metadata':self.payload.get('metadata')
'metadata':{
"retriever_resources":nodeInfos,
"usage":{
"prompt_tokens": 4972,
"prompt_unit_price": "0.0",
"prompt_price_unit": "0.0",
"prompt_price": "0.0",
"completion_tokens": 332,
"completion_unit_price": "0.0",
"completion_price_unit": "0.0",
"completion_price": "0.0",
"total_tokens": 5304,
"total_price": "0.0",
"currency": "USD",
"latency": 4.897703120019287
}
}
})
return params
def to_response(self)-> dict|None:
try:
match self.event_type:
match self.event_type.value:
case "workflow_started":
return self.get_WorkflowStart_param()
case "workflow_finished":
@@ -180,24 +220,18 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
]
super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue()
self._response:str = ''
self._response: StreamingAgentChatResponse = None
self._params:Dict[str,Any] = params
self._nodeStack:deque = deque()
self._nodeStack:List[str] = []
#添加工作流开始事件
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'use_id': data.user,
'query': data.query,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
wf_event = self.makeWorkflow_startEvent()
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
def setResponse(self,response: StreamingAgentChatResponse):
self._response = response
def on_event_start(
self,
event_type: CBEventType,
@@ -208,7 +242,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
self._nodeStack.append(event_id)
nindex = self._nodeStack.count() - 1
nindex = len(self._nodeStack) - 1
args:Dict[str,Any] = self._params['ids']
args.update(
{
@@ -222,7 +256,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event)
def on_event_end(
self,
event_type: CBEventType,
@@ -236,7 +269,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
args:Dict[str,Any] = self._params['ids']
nodeID = self._nodeStack[-1]
if nodeID == event_id:
nindex = self._nodeStack.count() - 1
nindex = len(self._nodeStack) - 1
args.update(
{
'nodeid':event_id,
@@ -250,7 +283,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
self._aqueue.put_nowait(nd_event)
self._nodeStack.pop()
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op."""
logger.info("trace_start:{}\n".format(trace_id))
@@ -262,21 +294,12 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
) -> None:
"""No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'response':self._response,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
wf_event = self.makeWorkflow_finishedEvent()
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
args:Dict[str,Any] = self._params['ids']
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
msgEnt_event = self.makeMessage_EndEvent()
if msgEnt_event.to_response() is not None:
self._aqueue.put_nowait(msgEnt_event)
@@ -287,6 +310,38 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
except asyncio.TimeoutError:
pass
def makeWorkflow_startEvent(self)->ChatCallbackEvent:
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'use_id': data.user,
'query': data.query,
'conversation_id': data.conversation_id
}
)
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
def makeWorkflow_finishedEvent(self)->ChatCallbackEvent:
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'response': '',
'conversation_id': data.conversation_id
}
)
return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
def makeMessage_EndEvent(self)->ChatCallbackEvent:
args:Dict[str,Any] = self._params['ids']
if self._response is not None:
args.update({
'source_node': self._response.source_nodes
})
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
return msgEnt_event
class IDManager:
def createID(self):
return {
@@ -390,7 +445,7 @@ async def post_conversations(request: Request, data: ChatRequestData):
params = data.inputs or {}
# 获取聊天引擎对象
chat_engine = get_chat_engine(filters=filters, params=params,prjFlag = data.prjFlag)
chat_engine = get_chat_engine(filters=filters, params=params)
# 启动聊天事件监听
ids = IDManager().createID()
@@ -399,7 +454,7 @@ async def post_conversations(request: Request, data: ChatRequestData):
# 执行异步聊天
response = await chat_engine.astream_chat(data.query)
event_handler.setResponse(response)
# 返回异步消息回应
return ChatStreamResponse(request, event_handler, response, data,ids)
@@ -468,29 +523,51 @@ async def query_conversations(user:str, first_id:str = None, limit:str = None, p
@v.get("/parameters")
async def query_parameters(user:str):
params = parameter().get(user)
if len(params) == 0:
params = BaseConfig().ParamterCfg()
return params
prjObj = ProjectInfo()
return BaseConfig().ParamterCfg(projectInfo = prjObj.projectNames())
@v.post("/messages/{message_id}/feedbacks")
async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]):
if params['rating'] =='null':
if params['rating'] is None:
feedback().delete(message_id)
else:
condition = {'id':message_id}
results = message().query(**condition)
results = message().query(message_id)
if len(results) > 0:
result = results[0]
feedback().add(message_id=message_id,query=result['query'],
answer=result['answer'],rating=params['rating'])
@v.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]:
@v.post("/files/upload")
def upload_file(request: ChatFileUploadRequest):
try:
logger.info("Processing file")
return FileLoadService.process_file(request.base64)
resluts = ChatFileService.process_file(request.base64)
return {
'id':resluts.get('id'),
'name': resluts.get('name'),
'size': resluts.get('size'),
'extension':resluts.get('extension'),
'mime_type':resluts.get('mime_type'),
'created_by':str(uuid.uuid4()),
'created_at':int(time.time())
}
except Exception as e:
logger.error(f"Error processing file: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Error processing file")
@v.post("/applications")
def upload_file(request: ChatFileUploadRequest):
try:
logger.info("Processing file")
return PrjFileLoadService.process_file(request.base64)
except Exception as e:
logger.error(f"Error processing file: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Error processing file")
@v.post("/messages/{message_id}/suggested")
async def post_suggested(request: Request,message_id:str,user:str):
questions = await NextQuestionSuggestion.suggest_next_questions(message_id)
return {
"result": "success",
"data":questions
}
+33 -2
View File
@@ -2,7 +2,7 @@ from datetime import datetime
import uuid
from app.api.routers.request.baseConfig import BaseConfig
from app.api.routers.request.dbOrm import DBManager
from typing import List
dbManage = DBManager()
class conversations:
@@ -122,8 +122,9 @@ class message:
def delete(self,user_id:str):
dbManage.delete(self._tableName,user_id = user_id)
def query(self,**condition):
def query(self,id:str):
results = []
condition = {'id':id}
records = dbManage.query(self._tableName,**condition)
for record in records:
results.append(record.dict())
@@ -153,3 +154,33 @@ class feedback:
if len(records) > 0:
return records[0].dict()
return None
class ProjectInfo:
def __init__(self) -> None:
self._tableName = 'projectInfos'
dbManage.createTable(self._tableName)
def add(self,name:str,flag:str):
record = {
'prjectName': name,
'prjFlag': flag
}
dbManage.addRecord(self._tableName,record)
def projectNames(self)->List[str]:
records = dbManage.query(self._tableName)
names = []
for record in records:
data:dict = record.dict()
name = data.get('prjectName')
if name !='':
names.append(name)
return names
def prjFalg(self,name:str):
records = dbManage.query(self._tableName)
for record in records:
data:dict = record.dict()
if data.get('prjectName') == name:
return data['prjFlag']
return ''
+14 -2
View File
@@ -5,7 +5,8 @@ from enum import Enum
class BaseConfig(BaseModel):
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
def ParamterCfg(self):
def ParamterCfg(self,**args):
projectInfo = args.get('projectInfo')
questions = os.getenv("CONVERSATION_STARTERS", "dev")
return{
"opening_statement": self.projectInfo,
@@ -30,7 +31,18 @@ class BaseConfig(BaseModel):
"more_like_this": {
"enabled": False
},
"user_input_form": [],
"user_input_form": [
{
"select": {
"variable": "projectname",
"label": "\u5de5\u7a0b\u540d\u79f0",
"type": "select",
"max_length": 48,
"required": True,
"options": [projectInfo]
}
}
],
"sensitive_word_avoidance": {
"enabled": False
},
+18
View File
@@ -55,6 +55,13 @@ class FeedBackOrm(Base):
answer = Column(String)
rating = Column(String)
class ProjectInfoOrm(Base):
__tablename__ = "projectInfos"
prjFlag = Column(String,primary_key=True)
prjectName = Column(String)
#数据结构
class ConversationModel(BaseModel):
id: str
@@ -121,6 +128,17 @@ class FeedBackModel(BaseModel):
def orm(cls):
return FeedBackOrm
class ProjectInfoModel(BaseModel):
prjectName:str
prjFlag:str
class Config:
from_attributes=True
@classmethod
def orm(cls):
return ProjectInfoOrm
class DBManager:
def __init__(self) -> None:
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
@@ -10,7 +10,6 @@ class ChatRequestData(BaseModel):
response_mode: str
files: Any
conversation_id: str = None
prjFlag:Optional[str] = ''
class ChatFileUploadRequest(BaseModel):
base64: str
@@ -1,16 +1,23 @@
import base64,os
from typing import List
import base64,os,mimetypes,requests,tempfile
from typing import List,Dict,Any
from uuid import uuid4
import requests
from app.settings import init_settings
from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath
from app.engine.vectordb import get_vector_store
from app.engine.generate import get_doc_store,run_pipeline,persist_storage
import tempfile
from llama_index.core.schema import Document
from pathlib import Path
from llama_index.core.readers.file.base import (
_try_loading_included_file_formats as get_file_loaders_map,
)
from llama_index.readers.file import FlatReader
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core import VectorStoreIndex
from app.engine.index import get_index
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
class FileLoadService:
class PrjFileLoadService:
@staticmethod
def store_and_parse_file(file_data):
prjtoJson_url = os.getenv('PRJTOJSON_URL')
@@ -20,29 +27,40 @@ class FileLoadService:
url = convert_url,
files=files
)
if response1.text is None or response1.text=='':
return None
load_url = prjtoJson_url +'/file_download'
response2 = requests.post(
url = load_url,
data=response1.text
)
tempFilePath:str = tempfile.gettempdir() + f"\\{str(uuid4())}.zip"
with open(tempFilePath,'wb') as file:
file.write(response2.content)
if response2.text is None or response2.content=='':
return None
prjID = str(uuid4())
filePath = getFileCacahePath() + f'/Projects/{prjID}'
os.makedirs(filePath)
import zipfile
with zipfile.ZipFile(tempFilePath,'r') as zip_File:
for zip_info in zip_File.infolist():
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
zip_File.extract(zip_info,filePath)
os.remove(tempFilePath)
return f'Projects_{prjID}'
try:
tempFilePath:str = tempfile.gettempdir() + f"\\{uuid4().hex}.zip"
with open(tempFilePath,'wb') as file:
file.write(response2.content)
prjID = str(uuid4())
filePath = getFileCacahePath() + f'/Projects/{prjID}'
os.makedirs(filePath)
import zipfile
with zipfile.ZipFile(tempFilePath,'r') as zip_File:
for zip_info in zip_File.infolist():
zip_info.filename = zip_info.filename.encode('cp437').decode('gbk')
zip_File.extract(zip_info,filePath)
os.remove(tempFilePath)
return f'Projects_{prjID}'
except Exception as e:
return None
@staticmethod
def process_file(base64_content: str) -> str:
prjFlag = FileLoadService.store_and_parse_file(base64_content)
prjFlag = PrjFileLoadService.store_and_parse_file(base64_content)
if prjFlag is None:
return None
#生成向量并持久化至本地
documents = get_documents(prjFlag)
for doc in documents:
@@ -53,3 +71,64 @@ class FileLoadService:
persist_storage(docstore, vector_store)
return prjFlag
class ChatFileService:
PRIVATE_STORE_PATH = os.getenv('CHAT_UPLOAD_FILECACHE','output/uploaded')
resluts:Dict[str,Any] = {}
@staticmethod
def process_file(base64_content: str) -> dict:
file_data, extension = ChatFileService.preprocess_base64_file(base64_content)
documents = ChatFileService.store_and_parse_file(file_data, extension)
pipeline = IngestionPipeline()
nodes = pipeline.run(documents=documents)
current_index = get_index()
pipeline = IngestionPipeline()
nodes = pipeline.run(documents=documents)
if current_index is None:
current_index = VectorStoreIndex(nodes=nodes)
else:
current_index.insert_nodes(nodes=nodes)
current_index.storage_context.persist(
persist_dir=os.environ.get("STORAGE_DIR", "storage")
)
return ChatFileService.resluts
@staticmethod
def preprocess_base64_file(base64_content: str) -> tuple:
header, data = base64_content.split(",", 1)
mime_type = header.split(";")[0].split(":", 1)[1]
extension = mimetypes.guess_extension(mime_type)
ChatFileService.resluts['mime_type'] = mime_type
ChatFileService.resluts['extension'] = extension
return base64.b64decode(data), extension
@staticmethod
def store_and_parse_file(file_data, extension) -> List[Document]:
os.makedirs(ChatFileService.PRIVATE_STORE_PATH, exist_ok=True)
fileID = uuid4().hex
file_name = f"{fileID}{extension}"
file_path = Path(os.path.join(ChatFileService.PRIVATE_STORE_PATH, file_name))
ChatFileService.resluts['id'] = fileID
ChatFileService.resluts['file_name'] = file_name
with open(file_path, "wb") as f:
f.write(file_data)
ChatFileService.resluts['size'] = os.path.getsize(file_path)
reader_cls = ChatFileService.default_file_loaders_map().get(extension)
if reader_cls is None:
raise ValueError(f"File extension {extension} is not supported")
reader = reader_cls()
documents = reader.load_data(file_path)
for doc in documents:
doc.metadata["file_name"] = file_name
doc.metadata["private"] = "true"
return documents
@staticmethod
def default_file_loaders_map():
default_loaders = get_file_loaders_map()
default_loaders[".txt"] = FlatReader
return default_loaders
@@ -0,0 +1,43 @@
from typing import List
from app.api.routers.request.base import message
from llama_index.core.prompts import PromptTemplate
from llama_index.core.settings import Settings
from pydantic import BaseModel
NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
"你是一个乐于助人的助手!你的任务是对用户可能会问的下一个问题给出建议。 "
"\n这是对话历史记录"
"\n---------------------\n{conversation}\n---------------------"
"考虑到对话历史记录,仅限于现在知识库已有内容, 请给我 $number_of_questions 个你接下来可能会问题的问题!"
)
N_QUESTION_TO_GENERATE = 3
class NextQuestions(BaseModel):
"""A list of questions that user might ask next"""
questions: List[str]
class NextQuestionSuggestion:
@staticmethod
async def suggest_next_questions(
message_id: str,
number_of_questions: int = N_QUESTION_TO_GENERATE,
) -> List[str]:
last_user_message = None
last_assistant_message = None
results = message().query(message_id)
if len(results) > 0:
last_user_message = results[0]['query']
last_assistant_message = results[0]['answer']
conversation: str = f"{last_user_message}\n{last_assistant_message}"
output: NextQuestions = await Settings.llm.astructured_predict(
NextQuestions,
prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
conversation=conversation,
nun_questions=number_of_questions,
)
return output.questions
return []
+18 -2
View File
@@ -8,9 +8,18 @@ from app.engine.engine import create_query_engine, create_summary_query_engine
from app.engine.index import get_index
#from app.engine.loaders.db import makeDescriptionByEngine
from app.engine.tools import ToolFactory
from app.api.routers.request.base import ProjectInfo
def getPrjFalg(params:dict=None)->str:
prjFlag = ''
if params is not None:
inputs:dict = params.get('inputs')
if inputs is not None:
prjFlag = ProjectInfo.prjFalg(inputs.get('projectname'))
return prjFlag
def get_chat_engine(filters=None, params=None,**args):
def get_chat_engine(filters=None, params:dict=None):
system_prompt = os.getenv("SYSTEM_PROMPT")
top_k = int(os.getenv("TOP_K", "3"))
use_reranker = os.getenv("RERANK_ENABLED")
@@ -24,7 +33,13 @@ def get_chat_engine(filters=None, params=None,**args):
#tools.append(sql_query_tool)
# Add query tool if index exists
index = get_index(**args)
prjFlag = ''
if params is not None:
inputs:dict = params.get('inputs')
if inputs is not None:
prjFlag = inputs.get('projectname')
index = get_index(prjFlag = getPrjFalg(params))
if index is not None:
summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters)
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
@@ -57,6 +72,7 @@ def get_chat_engine(filters=None, params=None,**args):
verbose=True,
)
return agentrunner
# create the function calling worker for reasoning
# worker = FunctionCallingAgentWorker.from_tools(
# tools, verbose=True
+8 -8
View File
@@ -7,15 +7,15 @@ logger = logging.getLogger("uvicorn")
def get_index(**args):
logger.info("Connecting vector store...")
prjFlags = get_document_Types()
if len(prjFlags)<=0:
return None
prjFlag = args.get('prjFlag','')
flag = prjFlags[0] if prjFlag not in prjFlags else prjFlag
if 'prjFlag' in args:
prjFlags = get_document_Types()
if len(prjFlags)<=0:
return None
prjFlag = args.get('prjFlag','')
flag = prjFlags[0] if prjFlag not in prjFlags else prjFlag
else:
flag = ''
store = get_vector_store(flag)
# Load the index from the vector store
# If you are using a vector store that doesn't store text,
# you must load the index from both the vector store and the document store
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished load index from vector store.")
return index
+17 -15
View File
@@ -58,24 +58,26 @@ def get_document_Types():
def get_documents(docType:str):
documents = []
config = load_configs()
if config is None or len(config.items()) == 0:
return documents
return documents
for loader_type, loader_config in config.items():
logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}"
)
if loader_config.get('enable', True): # 检查 enable 字段
logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}"
)
loader_config = loader_config or []
match loader_type:
case "file":
document = get_file_documents(FileLoaderConfig(**loader_config),docType)
case "web":
document = get_web_documents(WebLoaderConfig(**loader_config))
case "db":
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
case _:
raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document)
loader_config = loader_config or []
match loader_type:
case "file":
document = get_file_documents(FileLoaderConfig(**loader_config),docType)
case "web":
document = get_web_documents(WebLoaderConfig(**loader_config))
case "db":
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
case _:
raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document)
return documents
+32 -32
View File
@@ -3,46 +3,46 @@ file:
# use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable
use_llama_parse: false
db:
#db:
# The configuration for the database loader, only supports MySQL and PostgreSQL databases for now.
# uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
# query: The query to fetch data from the database. E.g.: SELECT * FROM table
- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
enable: true # 添加 enable 字段
queries:
- sql: select * from ProjectProperties;
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
#- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
#enable: false # 添加 enable 字段
#queries:
#- sql: select * from ProjectProperties;
#explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
#- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
#explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
#- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
#explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
#- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
#explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
#- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
#explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
#- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
#explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
#- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
#explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
#- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
#explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
#- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
#explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
#web:
# driver_arguments:
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode