From 73565b26e4cd378c2f3775a67a3908bc6f5ec394 Mon Sep 17 00:00:00 2001 From: wanyaokun <12345678> Date: Fri, 30 Aug 2024 10:49:05 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=88=E5=B9=B6Dev=E5=88=86=E6=94=AF?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/.env.example | 2 +- backend/.env.xinference | 2 +- backend/app/api/routers/app.py | 508 +++++++++--------- backend/app/api/routers/request/base.py | 40 +- backend/app/api/routers/request/baseConfig.py | 120 +++-- backend/app/api/routers/request/dbOrm.py | 31 +- backend/app/api/routers/request/models.py | 7 +- .../app/api/routers/services/fileServices.py | 22 +- backend/app/api/services/file.py | 4 +- backend/app/engine/__init__.py | 7 +- backend/app/engine/generate.py | 32 +- backend/app/engine/index.py | 32 +- backend/app/engine/tools/__init__.py | 11 +- backend/config/loaders.yaml | 64 +-- backend/pyproject.toml | 9 +- backend/tests/query.py | 4 +- 16 files changed, 486 insertions(+), 409 deletions(-) diff --git a/backend/.env.example b/backend/.env.example index ec19dfb..a549405 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -80,4 +80,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath - You can install any pip package (if it exists) by running a cell with pip install. " -PRJTOJSON_URL = 'http://10.1.6.60:8092' \ No newline at end of file +PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!" \ No newline at end of file diff --git a/backend/.env.xinference b/backend/.env.xinference index a9375dd..7a00d93 100644 --- a/backend/.env.xinference +++ b/backend/.env.xinference @@ -111,4 +111,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath - You can install any pip package (if it exists) by running a cell with pip install. " -PRJTOJSON_URL = 'http://10.1.6.60:8092' \ No newline at end of file +PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!" \ No newline at end of file diff --git a/backend/app/api/routers/app.py b/backend/app/api/routers/app.py index a900501..59d5a12 100644 --- a/backend/app/api/routers/app.py +++ b/backend/app/api/routers/app.py @@ -1,7 +1,9 @@ import asyncio import json import logging +import time from typing import Dict, List, Any, Optional, AsyncGenerator +from collections import deque from aiostream import stream from fastapi import APIRouter, Request,HTTPException @@ -12,7 +14,8 @@ from llama_index.core.callbacks import CBEventType from llama_index.core.chat_engine.types import StreamingAgentChatResponse from llama_index.core.tools import ToolOutput from pydantic import BaseModel -from app.api.routers.request.base import userMng, conversations,message,parameter +from app.api.routers.request.base import userMng, conversations,message,parameter,feedback +from app.api.routers.request.baseConfig import * from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest from app.engine import get_chat_engine import uuid @@ -23,81 +26,139 @@ logger = logging.getLogger("uvicorn") api_router = r = APIRouter() v1_router = v = APIRouter() -default_conversation_id = '82e8417f-2c3b-4bb5-ab22-2ad318bbd29a' - class ChatCallbackEvent(BaseModel): - event_type: CBEventType + event_type: ChatEventType payload: Optional[Dict[str, Any]] = None - event_id: str = "" - def get_retrieval_message(self) -> dict | None: - if self.payload: - nodes = self.payload.get("nodes") - if nodes: - msg = f"根据查询检索到 {len(nodes)} 源文件" - else: - msg = f"查询检索中: '{self.payload.get('query_str')}'" - return { - "type": "events", - "data": {"title": msg}, - } - else: - return None + def get_common_param(self)-> dict: + return { + 'event': self.event_type.name, + 'conversation_id':self.payload.get("conversation_id"), + 'message_id': self.payload.get("message_id"), + 'created_at': int(time.time()), + 'task_id': self.payload.get("task_id") + } - def get_tool_message(self) -> dict | None: - func_call_args = self.payload.get("function_call") - if func_call_args is not None and "tool" in self.payload: - tool = self.payload.get("tool") - return { - "type": "events", - "data": { - "title": f"调用工具 {tool.name} ,参数: {func_call_args}", + def get_WorkflowStart_param(self) -> dict: + params = self.get_common_param() + params.update({ + 'workflow_run_id':self.payload.get('workflow_run_id'), + 'data':{ + "id": self.payload.get('workflow_run_id'), + "workflow_id": self.payload.get('workflow_id'), + "sequence_number": 1709, + "inputs": { + "sys.query": self.payload.get('query'), + "sys.files": [], + "sys.conversation_id": self.payload.get('conversation_id'), + "sys.user_id": self.payload.get('use_id') }, + "created_at": int(time.time()) } + }) + return params - def _is_output_serializable(self, output: Any) -> bool: - try: - json.dumps(output) - return True - except TypeError: - return False + def get_WorkflowFinished_param(self) -> dict: + params = self.get_common_param() + params.update({ + 'workflow_run_id':self.payload.get('workflow_run_id'), + 'data':{ + "id": self.payload.get('workflow_run_id'), + "workflow_id": self.payload.get('workflow_id'), + "sequence_number": 1709, + "status": "succeeded", + "outputs": { + "answer": self.payload.get('response') + }, + "error": '', + "elapsed_time": 36.03764106379822, + "total_tokens": 11707, + "total_steps": 10, + "created_by": { + "id": str(uuid.uuid4()), + "user": self.payload.get('use_id') + }, + "created_at": int(time.time()), + "finished_at": int(time.time()), + "files": [] + } + }) + return params + + def get_NodeStart_param(self) -> dict: + params = self.get_common_param() + params.update({ + 'workflow_run_id':self.payload.get('workflow_run_id'), + 'data':{ + "id": self.payload.get('nodeid'), + "node_id": self.payload.get('nodeid'), + "node_type": "http-request", + "title": self.payload.get('title'), + "index": self.payload.get('index'), + "predecessor_node_id": self.payload.get('predecessor_node_id'), + "inputs": '', + "created_at": 1724398751, + "extras": {} + } + }) + return params - def get_agent_tool_response(self) -> dict | None: - response = self.payload.get("response") - if response is not None: - sources = response.sources - for source in sources: - # Return the tool response here to include the toolCall information - if isinstance(source, ToolOutput): - if self._is_output_serializable(source.raw_output): - output = source.raw_output - else: - output = source.content + def get_NodeFinished_param(self) -> dict: + params = self.get_common_param() + params.update({ + 'workflow_run_id':self.payload.get('workflow_run_id'), + 'data':{ + "id": self.payload.get('nodeid'), + "node_id": self.payload.get('nodeid'), + "node_type": "http-request", + "title": self.payload.get('title'), + "index": self.payload.get('index'), + "predecessor_node_id": self.payload.get('predecessor_node_id'), + "inputs": '', + "process_data": '', + "outputs": '', + "status": "succeeded", + "error": '', + "elapsed_time": 0.10402441816404462, + "execution_metadata": '', + "created_at": 1724398751, + "finished_at": 1724398751, + "files": [] + } + }) + return params - return { - "type": "tools", - "data": { - "toolOutput": { - "output": output, - "isError": source.is_error, - }, - "toolCall": { - "id": None, # There is no tool id in the ToolOutput - "name": source.tool_name, - "input": source.raw_input, - }, - }, - } + def get_Message_param(self) -> dict: + params = self.get_common_param() + params.update({ + 'id':self.payload.get('message_id'), + 'answer':self.payload.get('answer') + }) + return params + + def get_MessageEnd_param(self) -> dict: + params = self.get_common_param() + params.update({ + 'id':self.payload.get('message_id'), + 'metadata':self.payload.get('metadata') + }) + return params - def to_response(self): + def to_response(self)-> dict|None: try: match self.event_type: - case "retrieve": - return self.get_retrieval_message() - case "function_call": - return self.get_tool_message() - case "agent_step": - return self.get_agent_tool_response() + case "workflow_started": + return self.get_WorkflowStart_param() + case "workflow_finished": + return self.get_WorkflowFinished_param() + case "node_started": + return self.get_NodeStart_param() + case 'node_finished': + return self.get_NodeFinished_param() + case 'message': + return self.get_Message_param() + case 'message_end': + return self.get_MessageEnd_param() case _: return None except Exception as e: @@ -108,19 +169,34 @@ class ChatEventCallbackHandler(BaseCallbackHandler): _aqueue: asyncio.Queue is_done: bool = False - def __init__( - self, - ): + def __init__(self,**params): """Initialize the base callback handler.""" ignored_events = [ - CBEventType.CHUNKING, - CBEventType.NODE_PARSING, - CBEventType.EMBEDDING, - CBEventType.LLM, - CBEventType.TEMPLATING, + # CBEventType.CHUNKING, + # CBEventType.NODE_PARSING, + # CBEventType.EMBEDDING, + # CBEventType.LLM, + # CBEventType.TEMPLATING, ] super().__init__(ignored_events, ignored_events) self._aqueue = asyncio.Queue() + self._response:str = '' + self._params:Dict[str,Any] = params + self._nodeStack:deque = deque() + + #添加工作流开始事件 + data:ChatRequestData = self._params['data'] + args:Dict[str,Any] = self._params['ids'] + args.update( + { + 'use_id': data.user, + 'query': data.query, + 'conversation_id': data.conversation_id + } + ) + wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args) + if wf_event.to_response() is not None: + self._aqueue.put_nowait(wf_event) def on_event_start( self, @@ -129,9 +205,23 @@ class ChatEventCallbackHandler(BaseCallbackHandler): event_id: str = "", **kwargs: Any, ) -> str: - event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload) - if event.to_response() is not None: - self._aqueue.put_nowait(event) + logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload)) + + self._nodeStack.append(event_id) + nindex = self._nodeStack.count() - 1 + args:Dict[str,Any] = self._params['ids'] + args.update( + { + 'nodeid':event_id, + 'title':event_type.name, + 'index':nindex + 1, + 'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else '' + } + ) + nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_START,payload = args) + if nd_event.to_response() is not None: + self._aqueue.put_nowait(nd_event) + def on_event_end( self, @@ -140,12 +230,30 @@ class ChatEventCallbackHandler(BaseCallbackHandler): event_id: str = "", **kwargs: Any, ) -> None: - event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload) - if event.to_response() is not None: - self._aqueue.put_nowait(event) + logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload)) + + #self.response = payload.get("response","") + args:Dict[str,Any] = self._params['ids'] + nodeID = self._nodeStack[-1] + if nodeID == event_id: + nindex = self._nodeStack.count() - 1 + args.update( + { + 'nodeid':event_id, + 'title':event_type.name, + 'index':nindex + 1, + 'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else '' + } + ) + nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_FINISHED,payload = args) + if nd_event.to_response() is not None: + self._aqueue.put_nowait(nd_event) + self._nodeStack.pop() + def start_trace(self, trace_id: Optional[str] = None) -> None: """No-op.""" + logger.info("trace_start:{}\n".format(trace_id)) def end_trace( self, @@ -153,6 +261,24 @@ class ChatEventCallbackHandler(BaseCallbackHandler): trace_map: Optional[Dict[str, List[str]]] = None, ) -> None: """No-op.""" + logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map)) + data:ChatRequestData = self._params['data'] + args:Dict[str,Any] = self._params['ids'] + args.update( + { + 'response':self._response, + 'conversation_id': data.conversation_id + } + ) + wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args) + if wf_event.to_response() is not None: + self._aqueue.put_nowait(wf_event) + + + args:Dict[str,Any] = self._params['ids'] + msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args) + if msgEnt_event.to_response() is not None: + self._aqueue.put_nowait(msgEnt_event) async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]: while not self._aqueue.empty() or not self.is_done: @@ -170,104 +296,38 @@ class IDManager: "workflow_id": str(uuid.uuid4()) } -class DifyChatResponseEvent(BaseModel): - event: str - conversation_id: str - message_id: str - created_at: int = 1724406492 - task_id: str - -class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent): - workflow_run_id:str - data:Dict[str,Any] - def __init__(self,**args): - args['data'] = { - "id": args['workflow_run_id'], - "workflow_id": args['workflow_id'], - "sequence_number": 1709, - "inputs": { - "sys.query": args['query'], - "sys.files": [], - "sys.conversation_id": args['conversation_id'], - "sys.user_id": args['use_id'] - }, - "created_at": 1724406492 - } - args['event'] = 'workflow_started' - super().__init__(**args) - -class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent): - workflow_run_id:str - data:Dict[str,Any] - def __init__(self,**args): - args['event'] = 'workflow_finished' - args['data'] = { - "id": args['workflow_run_id'], - "workflow_id": args['workflow_id'], - "sequence_number": 1709, - "status": "succeeded", - "outputs": { - "answer": args['response'] - }, - "error": '', - "elapsed_time": 36.03764106379822, - "total_tokens": 11707, - "total_steps": 10, - "created_by": { - "id": str(uuid.uuid4()), - "user": args['use_id'] - }, - "created_at": 1724406492, - "finished_at": 1724406528, - "files": [] - } - super().__init__(**args) - -class Message_DifyChatResponseEvent(DifyChatResponseEvent): - id:str - answer:str - def __init__(self,**args): - args['id'] = args['message_id'] - args['event'] = 'message' - super().__init__(**args) - -class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent): - id:str - metadata:Dict[str,Any] = {} - def __init__(self,**args): - args['id'] = args['message_id'] - args['event'] = 'message_end' - super().__init__(**args) - class ChatStreamResponse(StreamingResponse): - TEXT_PREFIX = "data:" - DATA_PREFIX = "data:" + TEXT_PREFIX = "data: " + DATA_PREFIX = "data: " + ids:Dict[str,Any] = {} + data:ChatRequestData = None @classmethod - def convert_text(cls, token: str): - # Escape newlines and double quotes to avoid breaking the stream - token = json.dumps(token) - - #return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}" - return "" + def convert_Message(cls, token: str): + params = cls.ids + params.update({ + 'answer':token, + 'conversation_id':cls.data.conversation_id + }) + event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE,payload = params) + data_str = json.dumps(event.to_response()) + return f"{cls.DATA_PREFIX}{data_str}\n\n" @classmethod - def convert_data(cls, data: dict): + def convert_Event(cls, data: dict): data_str = json.dumps(data) - return f"{cls.DATA_PREFIX}{data_str}\n" - - @classmethod - def convert_event(cls, event: DifyChatResponseEvent): - data_str = json.dumps(event.dict()) - return f"{cls.DATA_PREFIX}{data_str}\n" + return f"{cls.DATA_PREFIX}{data_str}\n\n" def __init__( self, request: Request, event_handler: ChatEventCallbackHandler, response: StreamingAgentChatResponse, - data: ChatRequestData + data: ChatRequestData, + ids:Dict[str,Any] ): + ChatStreamResponse.ids = ids + ChatStreamResponse.data = data content = ChatStreamResponse.content_generator( request, event_handler, response, data ) @@ -281,41 +341,26 @@ class ChatStreamResponse(StreamingResponse): response: StreamingAgentChatResponse, data: ChatRequestData ): - ids = IDManager().createID() + # Yield the text response async def _chat_response_generator(): final_response = "" async for token in response.async_response_gen(): final_response += token - args = ids - args['answer'] = token - args['conversation_id'] = data.conversation_id - event = Message_DifyChatResponseEvent(**args) - yield ChatStreamResponse.convert_event(event) - #yield ChatStreamResponse.convert_text(token) + yield ChatStreamResponse.convert_Message(token) # 存储消息历史 message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response) # the text_generator is the leading stream, once it's finished, also finish the event stream event_handler.is_done = True - # 发送工作流结束事件 - args = ids - args['response'] = final_response - args['conversation_id'] = data.conversation_id - wf_event = Workflow_finished_DifyChatResponseEvent(**args) - yield ChatStreamResponse.convert_event(wf_event) - - msgEnt_event = MessageEnd_DifyChatResponseEvent(**ids) - yield ChatStreamResponse.convert_event(msgEnt_event) - - + # Yield the events from the event handler async def _event_generator(): async for event in event_handler.async_event_gen(): event_response = event.to_response() if event_response is not None: - yield ChatStreamResponse.convert_data(event_response) + yield ChatStreamResponse.convert_Event(event_response) combine = stream.merge(_chat_response_generator(), _event_generator()) is_stream_started = False @@ -324,34 +369,20 @@ class ChatStreamResponse(StreamingResponse): if not is_stream_started: is_stream_started = True - # 发送工作流开始事件 - args = ids - args['use_id'] = data.user - args['query'] = data.query - args['conversation_id'] = data.conversation_id - wf_event = Workflow_started_DifyChatResponseEvent(**args) - yield ChatStreamResponse.convert_event(wf_event) - - # Stream a blank message to start the stream - # 发送一个空消息事件 - #yield ChatStreamResponse.convert_text("") - yield output if await request.is_disconnected(): break - - @v.post("/chat-messages") async def post_conversations(request: Request, data: ChatRequestData): userMng.findNoExistCreate(data.user) - data.conversation_id = default_conversation_id if data.conversation_id is None else data.conversation_id + data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4()) conversaObj = conversations() - conversationinfo = conversaObj.get(data.user, data.conversation_id) + conversationinfo = conversaObj.get(data.conversation_id) if conversationinfo is None: - conversationinfo = conversaObj.add(data.user, "新建会话", data.conversation_id) + conversationinfo = conversaObj.add(data.conversation_id, data.user, "新建会话") # 生成聊天参数 last_message_content = ChatMessage.from_str(data.query) @@ -359,27 +390,36 @@ async def post_conversations(request: Request, data: ChatRequestData): params = data.inputs or {} # 获取聊天引擎对象 - chat_engine = get_chat_engine(filters=filters, params=params) + chat_engine = get_chat_engine(filters=filters, params=params,prjFlag = data.prjFlag) # 启动聊天事件监听 - event_handler = ChatEventCallbackHandler() + ids = IDManager().createID() + event_handler = ChatEventCallbackHandler(ids = ids,data = data) chat_engine.callback_manager.handlers.append(event_handler) # type: ignore # 执行异步聊天 response = await chat_engine.astream_chat(data.query) # 返回异步消息回应 - return ChatStreamResponse(request, event_handler, response, data) + return ChatStreamResponse(request, event_handler, response, data,ids) @v.get("/messages") async def query_messages(user:str, conversation_id:str): - conversation_id = default_conversation_id if conversation_id is None else conversation_id + #conversation_id = default_conversation_id if conversation_id is None else conversation_id datas = [] records = message().gets(user,conversation_id) + if records is None: + return { + "limit": 20, + "has_more": False, + "data": [] + } + for record in records: res = record.dict() + feeds = feedback().query(res['id']) res["message_files"] = [] - res["feedback"] = '' + res["feedback"] = {'rating':feeds['rating'] } if feeds != None else '' res["retriever_resources"] = [] res["created_at"] = 1723444905 res["agent_thoughts"] = [] @@ -416,7 +456,7 @@ async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]): return 'null' @v.get("/conversations") -async def query_conversations(user:str): +async def query_conversations(user:str, first_id:str = None, limit:str = None, pinned:str = None): user_id = '' if user is None else user userMng.findNoExistCreate(user_id) @@ -430,53 +470,27 @@ async def query_conversations(user:str): async def query_parameters(user:str): params = parameter().get(user) if len(params) == 0: - params = { - "opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!", - "suggested_questions": [], - "suggested_questions_after_answer": { - "enabled": False - }, - "speech_to_text": { - "enabled": False - }, - "text_to_speech": { - "enabled": False, - "language": "", - "voice": "" - }, - "retriever_resource": { - "enabled": True - }, - "annotation_reply": { - "enabled": False - }, - "more_like_this": { - "enabled": False - }, - "user_input_form": [], - "sensitive_word_avoidance": { - "enabled": False - }, - "file_upload": { - "image": { - "enabled": False, - "number_limits": 3, - "transfer_methods": [ - "remote_url" - ] - } - }, - "system_parameters": { - "image_file_size_limit": "10" - } - } + params = BaseConfig().ParamterCfg() return params -@r.post("") +@v.post("/messages/{message_id}/feedbacks") +async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]): + if params['rating'] =='null': + feedback().delete(message_id) + else: + condition = {'id':message_id} + results = message().query(**condition) + if len(results) > 0: + result = results[0] + feedback().add(message_id=message_id,query=result['query'], + answer=result['answer'],rating=params['rating']) + +@v.post("") def upload_file(request: ChatFileUploadRequest) -> List[str]: try: logger.info("Processing file") return FileLoadService.process_file(request.base64) except Exception as e: logger.error(f"Error processing file: {e}", exc_info=True) - raise HTTPException(status_code=500, detail="Error processing file") \ No newline at end of file + raise HTTPException(status_code=500, detail="Error processing file") + diff --git a/backend/app/api/routers/request/base.py b/backend/app/api/routers/request/base.py index b7b2ec8..234323b 100644 --- a/backend/app/api/routers/request/base.py +++ b/backend/app/api/routers/request/base.py @@ -18,14 +18,14 @@ class conversations: return datas - def get(self,user_id:str,id:str = ''): - records = dbManage.query(self._tableName,user_id = user_id,id=id) + def get(self, id:str): + records = dbManage.query(self._tableName, id=id) if len(records) >0: return records[0] return None - def add(self,user_id:str,name:str,id:str = ''): - template = BaseConfig.ConversationCfg + def add(self,id:str, user_id:str, name:str): + template = BaseConfig().ConversationCfg() template['id'] = id template['user_id'] = user_id template['name'] = name @@ -111,7 +111,7 @@ class message: return datas def add(self,user_id:str,conversation_id:str,query:str,answer:str): - template = BaseConfig.MessageCfg + template = BaseConfig.MessageCfg() template['id'] = str(uuid.uuid4()) template['user_id'] = user_id template['conversation_id'] = conversation_id @@ -122,4 +122,34 @@ class message: def delete(self,user_id:str): dbManage.delete(self._tableName,user_id = user_id) + def query(self,**condition): + results = [] + records = dbManage.query(self._tableName,**condition) + for record in records: + results.append(record.dict()) + return results +class feedback: + def __init__(self) -> None: + self._tableName = 'feedbacks' + dbManage.createTable(self._tableName) + + def add(self,message_id:str,query:str,answer:str,rating:str): + record = { + 'message_id': message_id, + 'query': query, + 'answer': answer, + 'rating': rating, + } + dbManage.addRecord(self._tableName,record) + + def delete(self,message_id:str): + cond = {'message_id':message_id} + dbManage.delete(self._tableName,**cond) + + def query(self,message_id:str): + cond = {'message_id':message_id} + records = dbManage.query(self._tableName,**cond) + if len(records) > 0: + return records[0].dict() + return None \ No newline at end of file diff --git a/backend/app/api/routers/request/baseConfig.py b/backend/app/api/routers/request/baseConfig.py index 7dce858..d254d8a 100644 --- a/backend/app/api/routers/request/baseConfig.py +++ b/backend/app/api/routers/request/baseConfig.py @@ -1,62 +1,80 @@ +from pydantic import BaseModel +import os +from enum import Enum -class BaseConfig: - ParamterCfg = { - "opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!", - "suggested_questions": [], - "suggested_questions_after_answer": { - "enabled": False - }, - "speech_to_text": { - "enabled": False - }, - "text_to_speech": { - "enabled": False, - "language": "", - "voice": "" - }, - "retriever_resource": { - "enabled": True - }, - "annotation_reply": { - "enabled": False - }, - "more_like_this": { - "enabled": False - }, - "user_input_form": [], - "sensitive_word_avoidance": { - "enabled": False - }, - "file_upload": { - "image": { +class BaseConfig(BaseModel): + projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!") + + def ParamterCfg(self): + questions = os.getenv("CONVERSATION_STARTERS", "dev") + return{ + "opening_statement": self.projectInfo, + "suggested_questions": questions.split('\n'), + "suggested_questions_after_answer": { + "enabled": False + }, + "speech_to_text": { + "enabled": False + }, + "text_to_speech": { "enabled": False, - "number_limits": 3, - "transfer_methods": [ - "remote_url" - ] + "language": "", + "voice": "" + }, + "retriever_resource": { + "enabled": True + }, + "annotation_reply": { + "enabled": False + }, + "more_like_this": { + "enabled": False + }, + "user_input_form": [], + "sensitive_word_avoidance": { + "enabled": False + }, + "file_upload": { + "image": { + "enabled": False, + "number_limits": 3, + "transfer_methods": [ + "remote_url" + ] + } + }, + "system_parameters": { + "image_file_size_limit": "10" } - }, - "system_parameters": { - "image_file_size_limit": "10" } - } + + def ConversationCfg(self): + return{ + "id": "", + 'user_id':'', + "name": "", + "inputs": {}, + "status": "normal", + "introduction": self.projectInfo, + "created_at":'' + } - ConversationCfg = { - "id": "", - 'user_id':'', - "name": "", - "inputs": {}, - "status": "normal", - "introduction": ParamterCfg['opening_statement'], - "created_at":'' - } - - - MessageCfg = { + @classmethod + def MessageCfg(cls): + return { "id": "", 'user_id':'', "conversation_id": "", "inputs": {}, "query": "", "answer": "" - } \ No newline at end of file + } + + +class ChatEventType(str, Enum): + WORKFLOW_START = "workflow_started" + WORKFLOW_FINISHED = "workflow_finished" + NODE_START = "node_started" + NODE_FINISHED = "node_finished" + MESSAGE = "message" + MESSAGE_END = "message_end" \ No newline at end of file diff --git a/backend/app/api/routers/request/dbOrm.py b/backend/app/api/routers/request/dbOrm.py index 796b90c..38af99d 100644 --- a/backend/app/api/routers/request/dbOrm.py +++ b/backend/app/api/routers/request/dbOrm.py @@ -2,7 +2,7 @@ import os from typing import Dict, List, Any from pydantic import BaseModel -from sqlalchemy import create_engine, Column, String, Integer, JSON +from sqlalchemy import create_engine, Column, String, Integer, JSON,Float from sqlalchemy.engine.reflection import Inspector from sqlalchemy.orm import sessionmaker, declarative_base @@ -24,10 +24,6 @@ class ConversationOrm(Base): if 'name' in data: self.name = data['name'] - - - - class UserOrm(Base): __tablename__ = "user" @@ -51,6 +47,14 @@ class MessagesOrm(Base): query = Column(String) answer = Column(String) +class FeedBackOrm(Base): + __tablename__ = "feedbacks" + + message_id = Column(String,primary_key=True) + query = Column(String) + answer = Column(String) + rating = Column(String) + #数据结构 class ConversationModel(BaseModel): id: str @@ -61,7 +65,6 @@ class ConversationModel(BaseModel): created_at: int class Config: - #orm_mode = True from_attributes=True @classmethod @@ -73,7 +76,6 @@ class UserModel(BaseModel): createtime: str class Config: - #orm_mode = True from_attributes=True @classmethod @@ -86,7 +88,6 @@ class ParametersModel(BaseModel): value : Dict[str, Any] class Config: - #orm_mode = True from_attributes=True @classmethod @@ -101,13 +102,25 @@ class MessagesModel(BaseModel): answer : str class Config: - #orm_mode = True from_attributes=True @classmethod def orm(cls): return MessagesOrm +class FeedBackModel(BaseModel): + message_id :str + query :str + answer :str + rating :str + + class Config: + from_attributes=True + + @classmethod + def orm(cls): + return FeedBackOrm + class DBManager: def __init__(self) -> None: DATABASE_URL = os.getenv("SQLITE_DATABASE_URL") diff --git a/backend/app/api/routers/request/models.py b/backend/app/api/routers/request/models.py index d76af75..983999c 100644 --- a/backend/app/api/routers/request/models.py +++ b/backend/app/api/routers/request/models.py @@ -1,7 +1,7 @@ from typing import Dict, Any from pydantic import BaseModel - +from typing import Optional class ChatRequestData(BaseModel): inputs: Dict[str,Any] @@ -10,6 +10,9 @@ class ChatRequestData(BaseModel): response_mode: str files: Any conversation_id: str = None + prjFlag:Optional[str] = '' class ChatFileUploadRequest(BaseModel): - base64: str \ No newline at end of file + base64: str + + diff --git a/backend/app/api/routers/services/fileServices.py b/backend/app/api/routers/services/fileServices.py index d63dc47..ac5ea58 100644 --- a/backend/app/api/routers/services/fileServices.py +++ b/backend/app/api/routers/services/fileServices.py @@ -6,7 +6,7 @@ from app.settings import init_settings from app.engine.loaders import get_document_Types, get_documents,getFileCacahePath from app.engine.vectordb import get_vector_store from app.engine.generate import get_doc_store,run_pipeline,persist_storage - +import tempfile STORAGE_DIR = os.getenv("STORAGE_DIR", "storage") @@ -25,31 +25,31 @@ class FileLoadService: url = load_url, data=response1.text ) - - with open('example.zip','wb') as file: + tempFilePath:str = tempfile.gettempdir() + f"\\{str(uuid4())}.zip" + with open(tempFilePath,'wb') as file: file.write(response2.content) prjID = str(uuid4()) filePath = getFileCacahePath() + f'/Projects/{prjID}' os.makedirs(filePath) import zipfile - with zipfile.ZipFile('example.zip','r') as zip_File: + with zipfile.ZipFile(tempFilePath,'r') as zip_File: for zip_info in zip_File.infolist(): zip_info.filename = zip_info.filename.encode('cp437').decode('gbk') zip_File.extract(zip_info,filePath) - os.remove('example.zip') + os.remove(tempFilePath) return f'Projects_{prjID}' @staticmethod - def process_file(base64_content: str) -> List[str]: - docType = FileLoadService.store_and_parse_file(base64_content) + def process_file(base64_content: str) -> str: + prjFlag = FileLoadService.store_and_parse_file(base64_content) #生成向量并持久化至本地 - init_settings() - documents = get_documents(docType) + documents = get_documents(prjFlag) for doc in documents: doc.metadata["private"] = "false" - docstore = get_doc_store(docType) - vector_store = get_vector_store(docType) + docstore = get_doc_store(prjFlag) + vector_store = get_vector_store(prjFlag) _ = run_pipeline(docstore, vector_store, documents) persist_storage(docstore, vector_store) + return prjFlag diff --git a/backend/app/api/services/file.py b/backend/app/api/services/file.py index e8eb54c..a478570 100644 --- a/backend/app/api/services/file.py +++ b/backend/app/api/services/file.py @@ -87,9 +87,7 @@ class PrivateFileService: nodes = pipeline.run(documents=documents) # Add the nodes to the index and persist it - indexs = get_index() - if len(indexs) > 0: - current_index = list(indexs.values())[0] + current_index = get_index() # Insert the documents into the index if isinstance(current_index, LlamaCloudIndex): diff --git a/backend/app/engine/__init__.py b/backend/app/engine/__init__.py index 4ccc26a..4d9b128 100644 --- a/backend/app/engine/__init__.py +++ b/backend/app/engine/__init__.py @@ -10,12 +10,11 @@ from app.engine.index import get_index from app.engine.tools import ToolFactory -def get_chat_engine(filters=None, params=None): +def get_chat_engine(filters=None, params=None,**args): system_prompt = os.getenv("SYSTEM_PROMPT") top_k = int(os.getenv("TOP_K", "3")) use_reranker = os.getenv("RERANK_ENABLED") tools = [] - # 创建SQL查询工具 # sql_query_engine = create_summary_query_engine(index) # sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine, @@ -25,9 +24,7 @@ def get_chat_engine(filters=None, params=None): #tools.append(sql_query_tool) # Add query tool if index exists - indexs = get_index() - if len(indexs) > 0: - index = list(indexs.values())[0] + index = get_index(**args) if index is not None: summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters) summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool", diff --git a/backend/app/engine/generate.py b/backend/app/engine/generate.py index 87ecfa1..1581194 100644 --- a/backend/app/engine/generate.py +++ b/backend/app/engine/generate.py @@ -5,7 +5,7 @@ load_dotenv() import logging import os -from app.engine.loaders import get_documents +from app.engine.loaders import get_document_Types, get_documents from app.engine.vectordb import get_vector_store from app.settings import init_settings from app.engine.retriever.CHBM25Retriever import CHBM25Retriever @@ -21,12 +21,13 @@ logger = logging.getLogger() STORAGE_DIR = os.getenv("STORAGE_DIR", "storage") -def get_doc_store(): +def get_doc_store(docType:str): # If the storage directory is there, load the document store from it. # If not, set up an in-memory document store since we can't load from a directory that doesn't exist. - if os.path.exists(STORAGE_DIR): - return SimpleDocumentStore.from_persist_dir(STORAGE_DIR) + storeDir = os.path.join(STORAGE_DIR,docType) + if os.path.exists(storeDir): + return SimpleDocumentStore.from_persist_dir(storeDir) else: return SimpleDocumentStore() @@ -71,19 +72,20 @@ def generate_datasource(): logger.info("Generate index for the provided data") # Get the stores and documents or create new ones - documents = get_documents() - # Set private=false to mark the document as public (required for filtering) - for doc in documents: - doc.metadata["private"] = "false" - docstore = get_doc_store() - vector_store = get_vector_store() + docTypes = get_document_Types() + for docType in docTypes: + documents = get_documents(docType) + # Set private=false to mark the document as public (required for filtering) + for doc in documents: + doc.metadata["private"] = "false" + docstore = get_doc_store(docType) + vector_store = get_vector_store(docType) - # Run the ingestion pipeline - _ = run_pipeline(docstore, vector_store, documents) + # Run the ingestion pipeline + _ = run_pipeline(docstore, vector_store, documents) - # Build the index and persist storage - persist_storage(docstore, vector_store) - persist_BMRetriever(vector_store) + # Build the index and persist storage + persist_storage(docstore, vector_store) logger.info("Finished generating the index") diff --git a/backend/app/engine/index.py b/backend/app/engine/index.py index 24f4fd1..2957b26 100644 --- a/backend/app/engine/index.py +++ b/backend/app/engine/index.py @@ -2,22 +2,20 @@ import logging from llama_index.core.indices import VectorStoreIndex from app.engine.vectordb import get_vector_store from app.engine.loaders import get_document_Types - +from typing import Dict,Any logger = logging.getLogger("uvicorn") -indexs = {} - -def get_index(params=None): - global indexs - if len(index) <= 0: - logger.info("Connecting vector store...") - docTypes = get_document_Types() - for docType in docTypes: - store = get_vector_store(docType) - # Load the index from the vector store - # If you are using a vector store that doesn't store text, - # you must load the index from both the vector store and the document store - index = VectorStoreIndex.from_vector_store(store) - logger.info("Finished load index from vector store.") - indexs[docType] = index - return indexs +def get_index(**args): + logger.info("Connecting vector store...") + prjFlags = get_document_Types() + if len(prjFlags)<=0: + return None + prjFlag = args.get('prjFlag','') + flag = prjFlags[0] if prjFlag not in prjFlags else prjFlag + store = get_vector_store(flag) + # Load the index from the vector store + # If you are using a vector store that doesn't store text, + # you must load the index from both the vector store and the document store + index = VectorStoreIndex.from_vector_store(store) + logger.info("Finished load index from vector store.") + return index diff --git a/backend/app/engine/tools/__init__.py b/backend/app/engine/tools/__init__.py index 1aced70..054308e 100644 --- a/backend/app/engine/tools/__init__.py +++ b/backend/app/engine/tools/__init__.py @@ -1,10 +1,9 @@ -import os -import yaml -import json import importlib -from cachetools import cached, LRUCache -from llama_index.core.tools.tool_spec.base import BaseToolSpec +import os + +import yaml from llama_index.core.tools.function_tool import FunctionTool +from llama_index.core.tools.tool_spec.base import BaseToolSpec class ToolType: @@ -46,7 +45,7 @@ class ToolFactory: def from_env() -> list[FunctionTool]: tools = [] if os.path.exists("config/tools.yaml"): - with open("config/tools.yaml", "r") as f: + with open("config/tools.yaml", "r", encoding='UTF-8') as f: tool_configs = yaml.safe_load(f) if tool_configs != None and len(tool_configs.items()) != 0: for tool_type, config_entries in tool_configs.items(): diff --git a/backend/config/loaders.yaml b/backend/config/loaders.yaml index b19f033..af5d2fe 100644 --- a/backend/config/loaders.yaml +++ b/backend/config/loaders.yaml @@ -3,46 +3,46 @@ file: # use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable use_llama_parse: false -#db: +db: # The configuration for the database loader, only supports MySQL and PostgreSQL databases for now. # uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db # query: The query to fetch data from the database. E.g.: SELECT * FROM table - #- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 - #enable: true # 添加 enable 字段 - #queries: - #- sql: select * from ProjectProperties; - #explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。" + - uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 + enable: true # 添加 enable 字段 + queries: + - sql: select * from ProjectProperties; + explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。" - #- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable; - #explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。" + - sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable; + explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。" - #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路'; - #explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" - #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理'; - #explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" - #- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路'; - #explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路'; + explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理'; + explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路'; + explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" - #- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee; - #explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中" + - sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee; + explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中" - #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表' - # explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa' - #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表' - #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)' - #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)' - #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - #- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)' - #explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" - #- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路' - #explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" - #- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理' - #explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路' + explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理' + explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" #web: # driver_arguments: # # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode diff --git a/backend/pyproject.toml b/backend/pyproject.toml index de1fbbb..981083d 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -17,7 +17,7 @@ aiostream = "^0.6.2" llama-index = "0.10.63" cachetools = "^5.3.3" protobuf = "4.25.4" -nltk = "^3.8.2" +nltk = "^3.9.1" jieba = "^0.42.1" #arize-phoenix = "^4.12.0" @@ -35,6 +35,7 @@ chroma="^0.2.0" llama-index-vector-stores-chroma = "^0.1.10" llama-index-readers-json = "^0.1.5" llama-index-retrievers-bm25 = "^0.2.2" +llama-index-experimental = "^0.2.0" duckduckgo_search = "^6.2.6" @@ -62,6 +63,12 @@ version = "^0.8" version = "0.0.7" + +[[tool.poetry.source]] +name = "mirrors" +url = "https://pypi.tuna.tsinghua.edu.cn/simple/" +priority = "default" + [build-system] requires = [ "poetry-core" ] build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/backend/tests/query.py b/backend/tests/query.py index f45d0b5..8c82d28 100644 --- a/backend/tests/query.py +++ b/backend/tests/query.py @@ -19,9 +19,7 @@ def main(): init_settings() init_observability() - indexs = get_index() - if len(indexs) > 0: - index = list(indexs.values())[0] + index = get_index() top_k = 5 filters = generate_filters([])