From e7628809ad51d2b0509070e945cdfafe355c5938 Mon Sep 17 00:00:00 2001 From: wanyaokun <12345678> Date: Fri, 30 Aug 2024 16:42:36 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=B7=A5=E7=A8=8B=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=E3=80=81=E6=A3=80=E7=B4=A2=E7=9A=84=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E7=89=87=E6=AE=B5=E8=8A=82=E7=82=B9=E5=9B=9E=E4=BC=A0=E3=80=81?= =?UTF-8?q?=E4=B8=8B=E4=B8=80=E8=BD=AE=E5=BB=BA=E8=AE=AE=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E5=88=97=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/routers/app.py | 179 +++++++++++++----- backend/app/api/routers/request/base.py | 37 +++- backend/app/api/routers/request/baseConfig.py | 16 +- backend/app/api/routers/request/dbOrm.py | 18 ++ backend/app/api/routers/request/models.py | 1 - .../app/api/routers/services/fileServices.py | 119 ++++++++++-- .../app/api/routers/services/suggestion.py | 43 +++++ backend/app/engine/__init__.py | 20 +- backend/app/engine/index.py | 16 +- backend/app/engine/loaders/__init__.py | 32 ++-- backend/config/loaders.yaml | 64 +++---- 11 files changed, 411 insertions(+), 134 deletions(-) create mode 100644 backend/app/api/routers/services/suggestion.py diff --git a/backend/app/api/routers/app.py b/backend/app/api/routers/app.py index 59d5a12..860d5ce 100644 --- a/backend/app/api/routers/app.py +++ b/backend/app/api/routers/app.py @@ -3,7 +3,6 @@ import json import logging import time from typing import Dict, List, Any, Optional, AsyncGenerator -from collections import deque from aiostream import stream from fastapi import APIRouter, Request,HTTPException @@ -13,17 +12,19 @@ from llama_index.core.base.llms.types import ChatMessage from llama_index.core.callbacks import CBEventType from llama_index.core.chat_engine.types import StreamingAgentChatResponse from llama_index.core.tools import ToolOutput +from llama_index.core.schema import NodeWithScore from pydantic import BaseModel -from app.api.routers.request.base import userMng, conversations,message,parameter,feedback +from app.api.routers.request.base import userMng, conversations,message,ProjectInfo,feedback from app.api.routers.request.baseConfig import * from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest from app.engine import get_chat_engine import uuid -from app.api.routers.services.fileServices import FileLoadService +from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileService +from app.api.routers.services.suggestion import NextQuestionSuggestion +import time logger = logging.getLogger("uvicorn") -api_router = r = APIRouter() v1_router = v = APIRouter() class ChatCallbackEvent(BaseModel): @@ -32,7 +33,7 @@ class ChatCallbackEvent(BaseModel): def get_common_param(self)-> dict: return { - 'event': self.event_type.name, + 'event': self.event_type.value, 'conversation_id':self.payload.get("conversation_id"), 'message_id': self.payload.get("message_id"), 'created_at': int(time.time()), @@ -48,7 +49,7 @@ class ChatCallbackEvent(BaseModel): "workflow_id": self.payload.get('workflow_id'), "sequence_number": 1709, "inputs": { - "sys.query": self.payload.get('query'), + "sys.query": f"开始查询 {self.payload.get('query')}", "sys.files": [], "sys.conversation_id": self.payload.get('conversation_id'), "sys.user_id": self.payload.get('use_id') @@ -93,7 +94,7 @@ class ChatCallbackEvent(BaseModel): "id": self.payload.get('nodeid'), "node_id": self.payload.get('nodeid'), "node_type": "http-request", - "title": self.payload.get('title'), + "title": f"正在执行事件:{self.payload.get('title')}", "index": self.payload.get('index'), "predecessor_node_id": self.payload.get('predecessor_node_id'), "inputs": '', @@ -111,7 +112,7 @@ class ChatCallbackEvent(BaseModel): "id": self.payload.get('nodeid'), "node_id": self.payload.get('nodeid'), "node_type": "http-request", - "title": self.payload.get('title'), + "title": f"事件执行结束:{self.payload.get('title')}", "index": self.payload.get('index'), "predecessor_node_id": self.payload.get('predecessor_node_id'), "inputs": '', @@ -138,15 +139,54 @@ class ChatCallbackEvent(BaseModel): def get_MessageEnd_param(self) -> dict: params = self.get_common_param() + nodeInfos = [] + source_nodes = self.payload.get('source_node') + if source_nodes is not None: + for i in range(len(source_nodes)): + source_node:NodeWithScore = source_nodes[i] + metadata:dict = source_node.node.metadata + nodeInfo = { + "position": i, + "dataset_id": metadata.get("pipeline_id"), + "dataset_name": metadata.get("file_name"), + "document_id": source_node.node_id, + "document_name": metadata.get("file_name"), + "data_source_type": "upload_file", + "segment_id": source_node.node_id, + "retriever_from": "workflow", + "score": source_node.score, + "hit_count": 1, + "word_count": 632, + "segment_position": i, + "index_node_hash": "", + "content": source_node.text + } + nodeInfos.append(nodeInfo) params.update({ 'id':self.payload.get('message_id'), - 'metadata':self.payload.get('metadata') + 'metadata':{ + "retriever_resources":nodeInfos, + "usage":{ + "prompt_tokens": 4972, + "prompt_unit_price": "0.0", + "prompt_price_unit": "0.0", + "prompt_price": "0.0", + "completion_tokens": 332, + "completion_unit_price": "0.0", + "completion_price_unit": "0.0", + "completion_price": "0.0", + "total_tokens": 5304, + "total_price": "0.0", + "currency": "USD", + "latency": 4.897703120019287 + } + } }) return params def to_response(self)-> dict|None: try: - match self.event_type: + match self.event_type.value: case "workflow_started": return self.get_WorkflowStart_param() case "workflow_finished": @@ -180,24 +220,18 @@ class ChatEventCallbackHandler(BaseCallbackHandler): ] super().__init__(ignored_events, ignored_events) self._aqueue = asyncio.Queue() - self._response:str = '' + self._response: StreamingAgentChatResponse = None self._params:Dict[str,Any] = params - self._nodeStack:deque = deque() + self._nodeStack:List[str] = [] #添加工作流开始事件 - data:ChatRequestData = self._params['data'] - args:Dict[str,Any] = self._params['ids'] - args.update( - { - 'use_id': data.user, - 'query': data.query, - 'conversation_id': data.conversation_id - } - ) - wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args) + wf_event = self.makeWorkflow_startEvent() if wf_event.to_response() is not None: self._aqueue.put_nowait(wf_event) + def setResponse(self,response: StreamingAgentChatResponse): + self._response = response + def on_event_start( self, event_type: CBEventType, @@ -208,7 +242,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler): logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload)) self._nodeStack.append(event_id) - nindex = self._nodeStack.count() - 1 + nindex = len(self._nodeStack) - 1 args:Dict[str,Any] = self._params['ids'] args.update( { @@ -222,7 +256,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler): if nd_event.to_response() is not None: self._aqueue.put_nowait(nd_event) - def on_event_end( self, event_type: CBEventType, @@ -236,7 +269,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler): args:Dict[str,Any] = self._params['ids'] nodeID = self._nodeStack[-1] if nodeID == event_id: - nindex = self._nodeStack.count() - 1 + nindex = len(self._nodeStack) - 1 args.update( { 'nodeid':event_id, @@ -250,7 +283,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler): self._aqueue.put_nowait(nd_event) self._nodeStack.pop() - def start_trace(self, trace_id: Optional[str] = None) -> None: """No-op.""" logger.info("trace_start:{}\n".format(trace_id)) @@ -262,23 +294,14 @@ class ChatEventCallbackHandler(BaseCallbackHandler): ) -> None: """No-op.""" logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map)) - data:ChatRequestData = self._params['data'] - args:Dict[str,Any] = self._params['ids'] - args.update( - { - 'response':self._response, - 'conversation_id': data.conversation_id - } - ) - wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args) + + wf_event = self.makeWorkflow_finishedEvent() if wf_event.to_response() is not None: self._aqueue.put_nowait(wf_event) - - args:Dict[str,Any] = self._params['ids'] - msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args) + msgEnt_event = self.makeMessage_EndEvent() if msgEnt_event.to_response() is not None: - self._aqueue.put_nowait(msgEnt_event) + self._aqueue.put_nowait(msgEnt_event) async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]: while not self._aqueue.empty() or not self.is_done: @@ -287,6 +310,38 @@ class ChatEventCallbackHandler(BaseCallbackHandler): except asyncio.TimeoutError: pass + def makeWorkflow_startEvent(self)->ChatCallbackEvent: + data:ChatRequestData = self._params['data'] + args:Dict[str,Any] = self._params['ids'] + args.update( + { + 'use_id': data.user, + 'query': data.query, + 'conversation_id': data.conversation_id + } + ) + return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args) + + def makeWorkflow_finishedEvent(self)->ChatCallbackEvent: + data:ChatRequestData = self._params['data'] + args:Dict[str,Any] = self._params['ids'] + args.update( + { + 'response': '', + 'conversation_id': data.conversation_id + } + ) + return ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args) + + def makeMessage_EndEvent(self)->ChatCallbackEvent: + args:Dict[str,Any] = self._params['ids'] + if self._response is not None: + args.update({ + 'source_node': self._response.source_nodes + }) + msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args) + return msgEnt_event + class IDManager: def createID(self): return { @@ -390,7 +445,7 @@ async def post_conversations(request: Request, data: ChatRequestData): params = data.inputs or {} # 获取聊天引擎对象 - chat_engine = get_chat_engine(filters=filters, params=params,prjFlag = data.prjFlag) + chat_engine = get_chat_engine(filters=filters, params=params) # 启动聊天事件监听 ids = IDManager().createID() @@ -399,7 +454,7 @@ async def post_conversations(request: Request, data: ChatRequestData): # 执行异步聊天 response = await chat_engine.astream_chat(data.query) - + event_handler.setResponse(response) # 返回异步消息回应 return ChatStreamResponse(request, event_handler, response, data,ids) @@ -468,29 +523,51 @@ async def query_conversations(user:str, first_id:str = None, limit:str = None, p @v.get("/parameters") async def query_parameters(user:str): - params = parameter().get(user) - if len(params) == 0: - params = BaseConfig().ParamterCfg() - return params + prjObj = ProjectInfo() + return BaseConfig().ParamterCfg(projectInfo = prjObj.projectNames()) @v.post("/messages/{message_id}/feedbacks") async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]): - if params['rating'] =='null': + if params['rating'] is None: feedback().delete(message_id) else: - condition = {'id':message_id} - results = message().query(**condition) + results = message().query(message_id) if len(results) > 0: result = results[0] feedback().add(message_id=message_id,query=result['query'], answer=result['answer'],rating=params['rating']) -@v.post("") -def upload_file(request: ChatFileUploadRequest) -> List[str]: +@v.post("/files/upload") +def upload_file(request: ChatFileUploadRequest): try: logger.info("Processing file") - return FileLoadService.process_file(request.base64) + resluts = ChatFileService.process_file(request.base64) + return { + 'id':resluts.get('id'), + 'name': resluts.get('name'), + 'size': resluts.get('size'), + 'extension':resluts.get('extension'), + 'mime_type':resluts.get('mime_type'), + 'created_by':str(uuid.uuid4()), + 'created_at':int(time.time()) + } except Exception as e: logger.error(f"Error processing file: {e}", exc_info=True) raise HTTPException(status_code=500, detail="Error processing file") +@v.post("/applications") +def upload_file(request: ChatFileUploadRequest): + try: + logger.info("Processing file") + return PrjFileLoadService.process_file(request.base64) + except Exception as e: + logger.error(f"Error processing file: {e}", exc_info=True) + raise HTTPException(status_code=500, detail="Error processing file") + +@v.post("/messages/{message_id}/suggested") +async def post_suggested(request: Request,message_id:str,user:str): + questions = await NextQuestionSuggestion.suggest_next_questions(message_id) + return { + "result": "success", + "data":questions + } \ No newline at end of file diff --git a/backend/app/api/routers/request/base.py b/backend/app/api/routers/request/base.py index 234323b..98c398a 100644 --- a/backend/app/api/routers/request/base.py +++ b/backend/app/api/routers/request/base.py @@ -2,7 +2,7 @@ from datetime import datetime import uuid from app.api.routers.request.baseConfig import BaseConfig from app.api.routers.request.dbOrm import DBManager - +from typing import List dbManage = DBManager() class conversations: @@ -122,8 +122,9 @@ class message: def delete(self,user_id:str): dbManage.delete(self._tableName,user_id = user_id) - def query(self,**condition): + def query(self,id:str): results = [] + condition = {'id':id} records = dbManage.query(self._tableName,**condition) for record in records: results.append(record.dict()) @@ -152,4 +153,34 @@ class feedback: records = dbManage.query(self._tableName,**cond) if len(records) > 0: return records[0].dict() - return None \ No newline at end of file + return None + +class ProjectInfo: + def __init__(self) -> None: + self._tableName = 'projectInfos' + dbManage.createTable(self._tableName) + + def add(self,name:str,flag:str): + record = { + 'prjectName': name, + 'prjFlag': flag + } + dbManage.addRecord(self._tableName,record) + + def projectNames(self)->List[str]: + records = dbManage.query(self._tableName) + names = [] + for record in records: + data:dict = record.dict() + name = data.get('prjectName') + if name !='': + names.append(name) + return names + + def prjFalg(self,name:str): + records = dbManage.query(self._tableName) + for record in records: + data:dict = record.dict() + if data.get('prjectName') == name: + return data['prjFlag'] + return '' \ No newline at end of file diff --git a/backend/app/api/routers/request/baseConfig.py b/backend/app/api/routers/request/baseConfig.py index d254d8a..53202af 100644 --- a/backend/app/api/routers/request/baseConfig.py +++ b/backend/app/api/routers/request/baseConfig.py @@ -5,7 +5,8 @@ from enum import Enum class BaseConfig(BaseModel): projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!") - def ParamterCfg(self): + def ParamterCfg(self,**args): + projectInfo = args.get('projectInfo') questions = os.getenv("CONVERSATION_STARTERS", "dev") return{ "opening_statement": self.projectInfo, @@ -30,7 +31,18 @@ class BaseConfig(BaseModel): "more_like_this": { "enabled": False }, - "user_input_form": [], + "user_input_form": [ + { + "select": { + "variable": "projectname", + "label": "\u5de5\u7a0b\u540d\u79f0", + "type": "select", + "max_length": 48, + "required": True, + "options": [projectInfo] + } + } + ], "sensitive_word_avoidance": { "enabled": False }, diff --git a/backend/app/api/routers/request/dbOrm.py b/backend/app/api/routers/request/dbOrm.py index 38af99d..71cd1cd 100644 --- a/backend/app/api/routers/request/dbOrm.py +++ b/backend/app/api/routers/request/dbOrm.py @@ -55,6 +55,13 @@ class FeedBackOrm(Base): answer = Column(String) rating = Column(String) +class ProjectInfoOrm(Base): + __tablename__ = "projectInfos" + + prjFlag = Column(String,primary_key=True) + prjectName = Column(String) + + #数据结构 class ConversationModel(BaseModel): id: str @@ -121,6 +128,17 @@ class FeedBackModel(BaseModel): def orm(cls): return FeedBackOrm +class ProjectInfoModel(BaseModel): + prjectName:str + prjFlag:str + + class Config: + from_attributes=True + + @classmethod + def orm(cls): + return ProjectInfoOrm + class DBManager: def __init__(self) -> None: DATABASE_URL = os.getenv("SQLITE_DATABASE_URL") diff --git a/backend/app/api/routers/request/models.py b/backend/app/api/routers/request/models.py index 983999c..c06204a 100644 --- a/backend/app/api/routers/request/models.py +++ b/backend/app/api/routers/request/models.py @@ -10,7 +10,6 @@ class ChatRequestData(BaseModel): response_mode: str files: Any conversation_id: str = None - prjFlag:Optional[str] = '' class ChatFileUploadRequest(BaseModel): base64: str diff --git a/backend/app/api/routers/services/fileServices.py b/backend/app/api/routers/services/fileServices.py index ac5ea58..6b5e16a 100644 --- a/backend/app/api/routers/services/fileServices.py +++ b/backend/app/api/routers/services/fileServices.py @@ -1,16 +1,23 @@ -import base64,os -from typing import List +import base64,os,mimetypes,requests,tempfile +from typing import List,Dict,Any from uuid import uuid4 -import requests from app.settings import init_settings from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath from app.engine.vectordb import get_vector_store from app.engine.generate import get_doc_store,run_pipeline,persist_storage -import tempfile +from llama_index.core.schema import Document +from pathlib import Path +from llama_index.core.readers.file.base import ( + _try_loading_included_file_formats as get_file_loaders_map, +) +from llama_index.readers.file import FlatReader +from llama_index.core.ingestion import IngestionPipeline +from llama_index.core import VectorStoreIndex +from app.engine.index import get_index STORAGE_DIR = os.getenv("STORAGE_DIR", "storage") -class FileLoadService: +class PrjFileLoadService: @staticmethod def store_and_parse_file(file_data): prjtoJson_url = os.getenv('PRJTOJSON_URL') @@ -20,29 +27,40 @@ class FileLoadService: url = convert_url, files=files ) + if response1.text is None or response1.text=='': + return None + load_url = prjtoJson_url +'/file_download' response2 = requests.post( url = load_url, data=response1.text ) - tempFilePath:str = tempfile.gettempdir() + f"\\{str(uuid4())}.zip" - with open(tempFilePath,'wb') as file: - file.write(response2.content) - - prjID = str(uuid4()) - filePath = getFileCacahePath() + f'/Projects/{prjID}' - os.makedirs(filePath) - import zipfile - with zipfile.ZipFile(tempFilePath,'r') as zip_File: - for zip_info in zip_File.infolist(): - zip_info.filename = zip_info.filename.encode('cp437').decode('gbk') - zip_File.extract(zip_info,filePath) - os.remove(tempFilePath) - return f'Projects_{prjID}' + if response2.text is None or response2.content=='': + return None + + try: + tempFilePath:str = tempfile.gettempdir() + f"\\{uuid4().hex}.zip" + with open(tempFilePath,'wb') as file: + file.write(response2.content) + prjID = str(uuid4()) + filePath = getFileCacahePath() + f'/Projects/{prjID}' + os.makedirs(filePath) + import zipfile + with zipfile.ZipFile(tempFilePath,'r') as zip_File: + for zip_info in zip_File.infolist(): + zip_info.filename = zip_info.filename.encode('cp437').decode('gbk') + zip_File.extract(zip_info,filePath) + os.remove(tempFilePath) + return f'Projects_{prjID}' + except Exception as e: + return None + @staticmethod def process_file(base64_content: str) -> str: - prjFlag = FileLoadService.store_and_parse_file(base64_content) + prjFlag = PrjFileLoadService.store_and_parse_file(base64_content) + if prjFlag is None: + return None #生成向量并持久化至本地 documents = get_documents(prjFlag) for doc in documents: @@ -53,3 +71,64 @@ class FileLoadService: persist_storage(docstore, vector_store) return prjFlag +class ChatFileService: + PRIVATE_STORE_PATH = os.getenv('CHAT_UPLOAD_FILECACHE','output/uploaded') + resluts:Dict[str,Any] = {} + + @staticmethod + def process_file(base64_content: str) -> dict: + file_data, extension = ChatFileService.preprocess_base64_file(base64_content) + documents = ChatFileService.store_and_parse_file(file_data, extension) + + pipeline = IngestionPipeline() + nodes = pipeline.run(documents=documents) + current_index = get_index() + pipeline = IngestionPipeline() + nodes = pipeline.run(documents=documents) + if current_index is None: + current_index = VectorStoreIndex(nodes=nodes) + else: + current_index.insert_nodes(nodes=nodes) + current_index.storage_context.persist( + persist_dir=os.environ.get("STORAGE_DIR", "storage") + ) + + return ChatFileService.resluts + + @staticmethod + def preprocess_base64_file(base64_content: str) -> tuple: + header, data = base64_content.split(",", 1) + mime_type = header.split(";")[0].split(":", 1)[1] + extension = mimetypes.guess_extension(mime_type) + ChatFileService.resluts['mime_type'] = mime_type + ChatFileService.resluts['extension'] = extension + return base64.b64decode(data), extension + + @staticmethod + def store_and_parse_file(file_data, extension) -> List[Document]: + os.makedirs(ChatFileService.PRIVATE_STORE_PATH, exist_ok=True) + fileID = uuid4().hex + file_name = f"{fileID}{extension}" + file_path = Path(os.path.join(ChatFileService.PRIVATE_STORE_PATH, file_name)) + ChatFileService.resluts['id'] = fileID + ChatFileService.resluts['file_name'] = file_name + + with open(file_path, "wb") as f: + f.write(file_data) + + ChatFileService.resluts['size'] = os.path.getsize(file_path) + reader_cls = ChatFileService.default_file_loaders_map().get(extension) + if reader_cls is None: + raise ValueError(f"File extension {extension} is not supported") + reader = reader_cls() + documents = reader.load_data(file_path) + for doc in documents: + doc.metadata["file_name"] = file_name + doc.metadata["private"] = "true" + return documents + + @staticmethod + def default_file_loaders_map(): + default_loaders = get_file_loaders_map() + default_loaders[".txt"] = FlatReader + return default_loaders \ No newline at end of file diff --git a/backend/app/api/routers/services/suggestion.py b/backend/app/api/routers/services/suggestion.py new file mode 100644 index 0000000..b372b4b --- /dev/null +++ b/backend/app/api/routers/services/suggestion.py @@ -0,0 +1,43 @@ +from typing import List + +from app.api.routers.request.base import message +from llama_index.core.prompts import PromptTemplate +from llama_index.core.settings import Settings +from pydantic import BaseModel + +NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate( + "你是一个乐于助人的助手!你的任务是对用户可能会问的下一个问题给出建议。 " + "\n这是对话历史记录" + "\n---------------------\n{conversation}\n---------------------" + "考虑到对话历史记录,仅限于现在知识库已有内容, 请给我 $number_of_questions 个你接下来可能会问题的问题!" +) +N_QUESTION_TO_GENERATE = 3 + + +class NextQuestions(BaseModel): + """A list of questions that user might ask next""" + + questions: List[str] + + +class NextQuestionSuggestion: + @staticmethod + async def suggest_next_questions( + message_id: str, + number_of_questions: int = N_QUESTION_TO_GENERATE, + ) -> List[str]: + last_user_message = None + last_assistant_message = None + results = message().query(message_id) + if len(results) > 0: + last_user_message = results[0]['query'] + last_assistant_message = results[0]['answer'] + conversation: str = f"{last_user_message}\n{last_assistant_message}" + output: NextQuestions = await Settings.llm.astructured_predict( + NextQuestions, + prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT, + conversation=conversation, + nun_questions=number_of_questions, + ) + return output.questions + return [] \ No newline at end of file diff --git a/backend/app/engine/__init__.py b/backend/app/engine/__init__.py index 4d9b128..0a6cf31 100644 --- a/backend/app/engine/__init__.py +++ b/backend/app/engine/__init__.py @@ -8,9 +8,18 @@ from app.engine.engine import create_query_engine, create_summary_query_engine from app.engine.index import get_index #from app.engine.loaders.db import makeDescriptionByEngine from app.engine.tools import ToolFactory +from app.api.routers.request.base import ProjectInfo + +def getPrjFalg(params:dict=None)->str: + prjFlag = '' + if params is not None: + inputs:dict = params.get('inputs') + if inputs is not None: + prjFlag = ProjectInfo.prjFalg(inputs.get('projectname')) + return prjFlag -def get_chat_engine(filters=None, params=None,**args): +def get_chat_engine(filters=None, params:dict=None): system_prompt = os.getenv("SYSTEM_PROMPT") top_k = int(os.getenv("TOP_K", "3")) use_reranker = os.getenv("RERANK_ENABLED") @@ -24,7 +33,13 @@ def get_chat_engine(filters=None, params=None,**args): #tools.append(sql_query_tool) # Add query tool if index exists - index = get_index(**args) + prjFlag = '' + if params is not None: + inputs:dict = params.get('inputs') + if inputs is not None: + prjFlag = inputs.get('projectname') + + index = get_index(prjFlag = getPrjFalg(params)) if index is not None: summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters) summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool", @@ -57,6 +72,7 @@ def get_chat_engine(filters=None, params=None,**args): verbose=True, ) return agentrunner + # create the function calling worker for reasoning # worker = FunctionCallingAgentWorker.from_tools( # tools, verbose=True diff --git a/backend/app/engine/index.py b/backend/app/engine/index.py index 2957b26..e64dc25 100644 --- a/backend/app/engine/index.py +++ b/backend/app/engine/index.py @@ -7,15 +7,15 @@ logger = logging.getLogger("uvicorn") def get_index(**args): logger.info("Connecting vector store...") - prjFlags = get_document_Types() - if len(prjFlags)<=0: - return None - prjFlag = args.get('prjFlag','') - flag = prjFlags[0] if prjFlag not in prjFlags else prjFlag + if 'prjFlag' in args: + prjFlags = get_document_Types() + if len(prjFlags)<=0: + return None + prjFlag = args.get('prjFlag','') + flag = prjFlags[0] if prjFlag not in prjFlags else prjFlag + else: + flag = '' store = get_vector_store(flag) - # Load the index from the vector store - # If you are using a vector store that doesn't store text, - # you must load the index from both the vector store and the document store index = VectorStoreIndex.from_vector_store(store) logger.info("Finished load index from vector store.") return index diff --git a/backend/app/engine/loaders/__init__.py b/backend/app/engine/loaders/__init__.py index 3155028..6596419 100644 --- a/backend/app/engine/loaders/__init__.py +++ b/backend/app/engine/loaders/__init__.py @@ -58,24 +58,26 @@ def get_document_Types(): def get_documents(docType:str): documents = [] config = load_configs() + if config is None or len(config.items()) == 0: - return documents + return documents for loader_type, loader_config in config.items(): - logger.info( - f"Loading documents from loader: {loader_type}, config: {loader_config}" - ) + if loader_config.get('enable', True): # 检查 enable 字段 + logger.info( + f"Loading documents from loader: {loader_type}, config: {loader_config}" + ) - loader_config = loader_config or [] - match loader_type: - case "file": - document = get_file_documents(FileLoaderConfig(**loader_config),docType) - case "web": - document = get_web_documents(WebLoaderConfig(**loader_config)) - case "db": - document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config]) - case _: - raise ValueError(f"Invalid loader type: {loader_type}") - documents.extend(document) + loader_config = loader_config or [] + match loader_type: + case "file": + document = get_file_documents(FileLoaderConfig(**loader_config),docType) + case "web": + document = get_web_documents(WebLoaderConfig(**loader_config)) + case "db": + document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config]) + case _: + raise ValueError(f"Invalid loader type: {loader_type}") + documents.extend(document) return documents \ No newline at end of file diff --git a/backend/config/loaders.yaml b/backend/config/loaders.yaml index af5d2fe..9070ff0 100644 --- a/backend/config/loaders.yaml +++ b/backend/config/loaders.yaml @@ -3,46 +3,46 @@ file: # use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable use_llama_parse: false -db: +#db: # The configuration for the database loader, only supports MySQL and PostgreSQL databases for now. # uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db # query: The query to fetch data from the database. E.g.: SELECT * FROM table - - uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 - enable: true # 添加 enable 字段 - queries: - - sql: select * from ProjectProperties; - explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。" + #- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 + #enable: false # 添加 enable 字段 + #queries: + #- sql: select * from ProjectProperties; + #explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。" - - sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable; - explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。" + #- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable; + #explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。" - - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路'; - explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" - - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理'; - explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" - - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路'; - explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路'; + #explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理'; + #explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路'; + #explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" - - sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee; - explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中" + #- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee; + #explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)' - explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表' + #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa' + #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表' + #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)' + #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)' + #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)' + #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路' - explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" - - sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理' - explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路' + #explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" + #- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理' + #explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" #web: # driver_arguments: # # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode