diff --git a/backend/.env.example b/backend/.env.example index 37ba235..83e69c1 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -2,6 +2,7 @@ # LLAMA_CLOUD_API_KEY= SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 #SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2 +SQLITE_DATABASE_URL=sqlite:///./source.db DASHSCOPE_API_KEY=sk-02c8540e86d84b7ca0e6f4f51bac6e60 # The provider for the AI models to use. diff --git a/backend/.env.xinference b/backend/.env.xinference index 6dd566f..1dc074c 100644 --- a/backend/.env.xinference +++ b/backend/.env.xinference @@ -2,6 +2,7 @@ # LLAMA_CLOUD_API_KEY= SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 #SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2 +SQLITE_DATABASE_URL=sqlite:///./source.db # The number of similar embeddings to return when retrieving documents. TOP_K=10 diff --git a/backend/app/api/routers/app.py b/backend/app/api/routers/app.py new file mode 100644 index 0000000..cb53bd2 --- /dev/null +++ b/backend/app/api/routers/app.py @@ -0,0 +1,487 @@ +import asyncio +import json +import logging +import time +from typing import Dict, List, Any, Optional, AsyncGenerator + +from aiostream import stream +from fastapi import APIRouter, Request +from fastapi.responses import StreamingResponse +from llama_index.core import BaseCallbackHandler +from llama_index.core.base.llms.types import ChatMessage +from llama_index.core.callbacks import CBEventType +from llama_index.core.chat_engine.types import StreamingAgentChatResponse +from llama_index.core.tools import ToolOutput +from pydantic import BaseModel +from app.api.routers.request.base import userMng, conversations,message,parameter +from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest +from app.engine import get_chat_engine +import uuid + +logger = logging.getLogger("uvicorn") + +api_router = r = APIRouter() +v1_router = v = APIRouter() + +class ChatCallbackEvent(BaseModel): + event_type: CBEventType + payload: Optional[Dict[str, Any]] = None + event_id: str = "" + + def get_retrieval_message(self) -> dict | None: + if self.payload: + nodes = self.payload.get("nodes") + if nodes: + msg = f"根据查询检索到 {len(nodes)} 源文件" + else: + msg = f"查询检索中: '{self.payload.get('query_str')}'" + return { + "type": "events", + "data": {"title": msg}, + } + else: + return None + + def get_tool_message(self) -> dict | None: + func_call_args = self.payload.get("function_call") + if func_call_args is not None and "tool" in self.payload: + tool = self.payload.get("tool") + return { + "type": "events", + "data": { + "title": f"调用工具 {tool.name} ,参数: {func_call_args}", + }, + } + + def _is_output_serializable(self, output: Any) -> bool: + try: + json.dumps(output) + return True + except TypeError: + return False + + def get_agent_tool_response(self) -> dict | None: + response = self.payload.get("response") + if response is not None: + sources = response.sources + for source in sources: + # Return the tool response here to include the toolCall information + if isinstance(source, ToolOutput): + if self._is_output_serializable(source.raw_output): + output = source.raw_output + else: + output = source.content + + return { + "type": "tools", + "data": { + "toolOutput": { + "output": output, + "isError": source.is_error, + }, + "toolCall": { + "id": None, # There is no tool id in the ToolOutput + "name": source.tool_name, + "input": source.raw_input, + }, + }, + } + + def to_response(self): + try: + match self.event_type: + case "retrieve": + return self.get_retrieval_message() + case "function_call": + return self.get_tool_message() + case "agent_step": + return self.get_agent_tool_response() + case _: + return None + except Exception as e: + logger.error(f"转换回应时间时发生错误,原因: {e}") + return None + +class ChatEventCallbackHandler(BaseCallbackHandler): + _aqueue: asyncio.Queue + is_done: bool = False + + def __init__( + self, + ): + """Initialize the base callback handler.""" + ignored_events = [ + # CBEventType.CHUNKING, + # CBEventType.NODE_PARSING, + # CBEventType.EMBEDDING, + # CBEventType.LLM, + # CBEventType.TEMPLATING, + ] + super().__init__(ignored_events, ignored_events) + self._aqueue = asyncio.Queue() + + def on_event_start( + self, + event_type: CBEventType, + payload: Optional[Dict[str, Any]] = None, + event_id: str = "", + **kwargs: Any, + ) -> str: + logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload)) + + event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload) + if event.to_response() is not None: + self._aqueue.put_nowait(event) + + def on_event_end( + self, + event_type: CBEventType, + payload: Optional[Dict[str, Any]] = None, + event_id: str = "", + **kwargs: Any, + ) -> None: + logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload)) + event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload) + if event.to_response() is not None: + self._aqueue.put_nowait(event) + + def start_trace(self, trace_id: Optional[str] = None) -> None: + """No-op.""" + logger.info("trace_start:{}\n".format(trace_id)) + + def end_trace( + self, + trace_id: Optional[str] = None, + trace_map: Optional[Dict[str, List[str]]] = None, + ) -> None: + """No-op.""" + logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map)) + + async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]: + while not self._aqueue.empty() or not self.is_done: + try: + yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1) + except asyncio.TimeoutError: + pass + +class IDManager: + def createID(self): + return { + "message_id" : str(uuid.uuid4()), + 'task_id':str(uuid.uuid4()), + 'workflow_run_id': str(uuid.uuid4()), + "workflow_id": str(uuid.uuid4()) + } + +class DifyChatResponseEvent(BaseModel): + event: str + conversation_id: str + message_id: str + created_at: int = int(time.time()) + task_id: str + +class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent): + workflow_run_id:str + data:Dict[str,Any] + def __init__(self,**args): + args['data'] = { + "id": args['workflow_run_id'], + "workflow_id": args['workflow_id'], + "sequence_number": 1709, + "inputs": { + "sys.query": args['query'], + "sys.files": [], + "sys.conversation_id": args['conversation_id'], + "sys.user_id": args['use_id'] + }, + "created_at": int(time.time()) + } + args['event'] = 'workflow_started' + super().__init__(**args) + +class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent): + workflow_run_id:str + data:Dict[str,Any] + def __init__(self,**args): + args['event'] = 'workflow_finished' + args['data'] = { + "id": args['workflow_run_id'], + "workflow_id": args['workflow_id'], + "sequence_number": 1709, + "status": "succeeded", + "outputs": { + "answer": args['response'] + }, + "error": '', + "elapsed_time": 36.03764106379822, + "total_tokens": 11707, + "total_steps": 10, + "created_by": { + "id": str(uuid.uuid4()), + "user": args['use_id'] + }, + "created_at": int(time.time()), + "finished_at": int(time.time()), + "files": [] + } + super().__init__(**args) + +class Message_DifyChatResponseEvent(DifyChatResponseEvent): + id:str + answer:str + def __init__(self,**args): + args['id'] = args['message_id'] + args['event'] = 'message' + super().__init__(**args) + +class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent): + id:str + metadata:Dict[str,Any] = {} + def __init__(self,**args): + args['id'] = args['message_id'] + args['event'] = 'message_end' + super().__init__(**args) + +class ChatStreamResponse(StreamingResponse): + TEXT_PREFIX = "data: " + DATA_PREFIX = "data: " + + @classmethod + def convert_text(cls, token: str): + # Escape newlines and double quotes to avoid breaking the stream + #token = json.dumps(token) + + #return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}" + return "\n" + + @classmethod + def convert_data(cls, data: dict): + data_str = json.dumps(data) + return f"{cls.DATA_PREFIX}{data_str}\n\n" + + @classmethod + def convert_event(cls, event: DifyChatResponseEvent): + data_str = json.dumps(event.dict()) + return f"{cls.DATA_PREFIX}{data_str}\n\n" + + def __init__( + self, + request: Request, + event_handler: ChatEventCallbackHandler, + response: StreamingAgentChatResponse, + data: ChatRequestData + ): + content = ChatStreamResponse.content_generator( + request, event_handler, response, data + ) + super().__init__(content=content) + + @classmethod + async def content_generator( + cls, + request: Request, + event_handler: ChatEventCallbackHandler, + response: StreamingAgentChatResponse, + data: ChatRequestData + ): + ids = IDManager().createID() + # Yield the text response + async def _chat_response_generator(): + final_response = "" + async for token in response.async_response_gen(): + final_response += token + args = ids + args['answer'] = token + args['conversation_id'] = data.conversation_id + event = Message_DifyChatResponseEvent(**args) + yield ChatStreamResponse.convert_event(event) + #yield ChatStreamResponse.convert_text(token) + + # 存储消息历史 + message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response) + + # the text_generator is the leading stream, once it's finished, also finish the event stream + event_handler.is_done = True + # 发送工作流结束事件 + args = ids + args['response'] = final_response + args['conversation_id'] = data.conversation_id + wf_event = Workflow_finished_DifyChatResponseEvent(**args) + yield ChatStreamResponse.convert_event(wf_event) + + msgEnt_event = MessageEnd_DifyChatResponseEvent(**ids) + yield ChatStreamResponse.convert_event(msgEnt_event) + + + # Yield the events from the event handler + async def _event_generator(): + async for event in event_handler.async_event_gen(): + event_response = event.to_response() + if event_response is not None: + yield ChatStreamResponse.convert_text("") + + combine = stream.merge(_chat_response_generator(), _event_generator()) + is_stream_started = False + async with combine.stream() as streamer: + async for output in streamer: + if not is_stream_started: + is_stream_started = True + + # 发送工作流开始事件 + args = ids + args['use_id'] = data.user + args['query'] = data.query + args['conversation_id'] = data.conversation_id + wf_event = Workflow_started_DifyChatResponseEvent(**args) + yield ChatStreamResponse.convert_event(wf_event) + + # Stream a blank message to start the stream + # 发送一个空消息事件 + #yield ChatStreamResponse.convert_text("") + + yield output + + if await request.is_disconnected(): + break + + + +@v.post("/chat-messages") +async def post_conversations(request: Request, data: ChatRequestData): + userMng.findNoExistCreate(data.user) + data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4()) + + conversaObj = conversations() + conversationinfo = conversaObj.get(data.conversation_id) + if conversationinfo is None: + conversationinfo = conversaObj.add(data.conversation_id, data.user, "新建会话") + + # 生成聊天参数 + last_message_content = ChatMessage.from_str(data.query) + filters = None + params = data.inputs or {} + + # 获取聊天引擎对象 + chat_engine = get_chat_engine(filters=filters, params=params) + + # 启动聊天事件监听 + event_handler = ChatEventCallbackHandler() + chat_engine.callback_manager.handlers.append(event_handler) # type: ignore + + # 执行异步聊天 + response = await chat_engine.astream_chat(data.query) + + # 返回异步消息回应 + return ChatStreamResponse(request, event_handler, response, data) + +@v.get("/messages") +async def query_messages(user:str, conversation_id:str): + #conversation_id = default_conversation_id if conversation_id is None else conversation_id + datas = [] + records = message().gets(user,conversation_id) + if records is None: + return { + "limit": 20, + "has_more": False, + "data": [] + } + + for record in records: + res = record.dict() + res["message_files"] = [] + res["feedback"] = '' + res["retriever_resources"] = [] + res["created_at"] = 1723444905 + res["agent_thoughts"] = [] + res["status"] = "normal" + res["error"] = '' + datas.append(res) + + return { + "limit": 20, + "has_more": False, + "data": datas + } + +@v.post("/conversations/{itemid}/name") +async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]): + consaObj = conversations() + consaObj.rename(itemid,'知识问答') + cond = { + 'id':itemid, + 'user_id':params['user'] + } + results = consaObj.query(**cond) + if len(results) > 0: + res = results[0] + return { + "id": res['id'], + "name": res['name'], + "inputs": res['inputs'], + "status": res['status'], + "introduction": res['introduction'], + "created_at": res['created_at'], + #"工程位置" + } + return 'null' + +@v.get("/conversations") +async def query_conversations(user:str, first_id:str = None, limit:str = None, pinned:str = None): + user_id = '' if user is None else user + userMng.findNoExistCreate(user_id) + + return { + "limit": 20, + "has_more": False, + "data": conversations().gets(user_id) + } + +@v.get("/parameters") +async def query_parameters(user:str): + params = parameter().get(user) + if len(params) == 0: + params = { + "opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!", + "suggested_questions": [], + "suggested_questions_after_answer": { + "enabled": False + }, + "speech_to_text": { + "enabled": False + }, + "text_to_speech": { + "enabled": False, + "language": "", + "voice": "" + }, + "retriever_resource": { + "enabled": True + }, + "annotation_reply": { + "enabled": False + }, + "more_like_this": { + "enabled": False + }, + "user_input_form": [], + "sensitive_word_avoidance": { + "enabled": False + }, + "file_upload": { + "image": { + "enabled": False, + "number_limits": 3, + "transfer_methods": [ + "remote_url" + ] + } + }, + "system_parameters": { + "image_file_size_limit": "10" + } + } + return params + +@r.post("") +def upload_file(request: ChatFileUploadRequest) -> List[str]: + pass \ No newline at end of file diff --git a/backend/app/api/routers/request/base.py b/backend/app/api/routers/request/base.py new file mode 100644 index 0000000..bb90305 --- /dev/null +++ b/backend/app/api/routers/request/base.py @@ -0,0 +1,125 @@ +from datetime import datetime +import uuid +from app.api.routers.request.baseConfig import BaseConfig +from app.api.routers.request.dbOrm import DBManager + +dbManage = DBManager() + +class conversations: + def __init__(self) -> None: + self._tableName = 'conversations' + dbManage.createTable(self._tableName) + + def gets(self,user_id:str): + records = dbManage.query(self._tableName,user_id = user_id) + datas = [] + for record in records: + datas.append(record) + + return datas + + def get(self, id:str): + records = dbManage.query(self._tableName, id=id) + if len(records) >0: + return records[0] + return None + + def add(self,id:str, user_id:str, name:str): + template = BaseConfig.ConversationCfg + template['id'] = id + template['user_id'] = user_id + template['name'] = name + template['created_at'] = 1724399038 + dbManage.addRecord(self._tableName,template) + + def delete(self,id:str): + dbManage.delete(self._tableName,id=id) + + def rename(self,id:str,name:str): + data = {'name':name} + dbManage.update(self._tableName,data,id=id) + + def query(self,**condition): + results = [] + records = dbManage.query(self._tableName,**condition) + for record in records: + results.append(record.dict()) + return results + +class user: + def __init__(self) -> None: + self._tableName = 'user' + dbManage.createTable(self._tableName) + + def gets(self): + return dbManage.query(self._tableName) + + def get(self,id:str): + return dbManage.query(self._tableName,id = id) + + def add(self,id:str): + info = { + 'id':id, + 'createtime': datetime.now().strftime("%Y-%m-%d %H:%M:%S") + } + dbManage.addRecord(self._tableName,info) + + def delete(self,id:str): + dbManage.delete(self._tableName,id = id) + +class userMng: + userObj = user() + @classmethod + def findNoExistCreate(cls,user_id:str): + userInfo = cls.userObj.get(user_id) + if len(userInfo) == 0: + cls.userObj.add(user_id) + + def remove(cls,user_id:str): + cls.userObj.delete(user_id) + +class parameter: + def __init__(self) -> None: + self._tableName = 'parameters' + dbManage.createTable(self._tableName) + + def get(self,user_id:str): + records = dbManage.query(self._tableName,user_id = user_id) + data = {} + for record in records: + key = record['name'] + value = record['value'] + data[key] = value + return data + + def set(self,user_id:str): + dbManage.addRecord(self._tableName,{}) + + def delete(self,user_id:str): + dbManage.delete(self._tableName,user_id = user_id) + +class message: + def __init__(self) -> None: + self._tableName = 'messages' + dbManage.createTable(self._tableName) + + def gets(self,user_id:str,conversation_id:str): + records = dbManage.query(self._tableName,user_id = user_id,conversation_id = conversation_id) + datas = [] + for record in records: + datas.append(record) + return datas + + def add(self,user_id:str,conversation_id:str,query:str,answer:str): + template = BaseConfig.MessageCfg + template['id'] = str(uuid.uuid4()) + template['user_id'] = user_id + template['conversation_id'] = conversation_id + template['query'] = query + template['answer'] = answer + dbManage.addRecord(self._tableName,template) + + def delete(self,user_id:str): + dbManage.delete(self._tableName,user_id = user_id) + + diff --git a/backend/app/api/routers/request/baseConfig.py b/backend/app/api/routers/request/baseConfig.py new file mode 100644 index 0000000..7dce858 --- /dev/null +++ b/backend/app/api/routers/request/baseConfig.py @@ -0,0 +1,62 @@ + +class BaseConfig: + ParamterCfg = { + "opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!", + "suggested_questions": [], + "suggested_questions_after_answer": { + "enabled": False + }, + "speech_to_text": { + "enabled": False + }, + "text_to_speech": { + "enabled": False, + "language": "", + "voice": "" + }, + "retriever_resource": { + "enabled": True + }, + "annotation_reply": { + "enabled": False + }, + "more_like_this": { + "enabled": False + }, + "user_input_form": [], + "sensitive_word_avoidance": { + "enabled": False + }, + "file_upload": { + "image": { + "enabled": False, + "number_limits": 3, + "transfer_methods": [ + "remote_url" + ] + } + }, + "system_parameters": { + "image_file_size_limit": "10" + } + } + + ConversationCfg = { + "id": "", + 'user_id':'', + "name": "", + "inputs": {}, + "status": "normal", + "introduction": ParamterCfg['opening_statement'], + "created_at":'' + } + + + MessageCfg = { + "id": "", + 'user_id':'', + "conversation_id": "", + "inputs": {}, + "query": "", + "answer": "" + } \ No newline at end of file diff --git a/backend/app/api/routers/request/dbOrm.py b/backend/app/api/routers/request/dbOrm.py new file mode 100644 index 0000000..796b90c --- /dev/null +++ b/backend/app/api/routers/request/dbOrm.py @@ -0,0 +1,207 @@ +import os +from typing import Dict, List, Any + +from pydantic import BaseModel +from sqlalchemy import create_engine, Column, String, Integer, JSON +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.orm import sessionmaker, declarative_base + +Base = declarative_base() + +#orm类 +class ConversationOrm(Base): + __tablename__ = "conversations" + + id = Column(String, primary_key=True) + user_id = Column(String) + name = Column(String) + inputs = Column(JSON) + status = Column(String) + introduction = Column(String) + created_at = Column(Integer) + + def update(self,data:Dict[str,Any]): + if 'name' in data: + self.name = data['name'] + + + + + +class UserOrm(Base): + __tablename__ = "user" + + id = Column(String, primary_key=True) + createtime = Column(String) + +class ParametersOrm(Base): + __tablename__ = "parameters" + + user_id = Column(String,primary_key=True) + name = Column(String) + value = Column(JSON) + +class MessagesOrm(Base): + __tablename__ = "messages" + + id = Column(String,primary_key=True) + user_id = Column(String) + conversation_id = Column(String) + inputs = Column(JSON) + query = Column(String) + answer = Column(String) + +#数据结构 +class ConversationModel(BaseModel): + id: str + name: str + inputs: Dict[str, Any] + status: str + introduction: str + created_at: int + + class Config: + #orm_mode = True + from_attributes=True + + @classmethod + def orm(cls): + return ConversationOrm + +class UserModel(BaseModel): + id: str + createtime: str + + class Config: + #orm_mode = True + from_attributes=True + + @classmethod + def orm(cls): + return UserOrm + +class ParametersModel(BaseModel): + user_id : str + name : str + value : Dict[str, Any] + + class Config: + #orm_mode = True + from_attributes=True + + @classmethod + def orm(cls): + return ParametersOrm + +class MessagesModel(BaseModel): + id :str + conversation_id :str + inputs : Dict[str, Any] + query : str + answer : str + + class Config: + #orm_mode = True + from_attributes=True + + @classmethod + def orm(cls): + return MessagesOrm + +class DBManager: + def __init__(self) -> None: + DATABASE_URL = os.getenv("SQLITE_DATABASE_URL") + self._engine = create_engine(DATABASE_URL) + self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self._engine) + + def createTable(self,tableName:str): + if self._engine is None: + return + if not self.exist(tableName): + Base.metadata.tables[tableName].create(self._engine) + + def addRecord(self,tableName:str,record:Dict[str,Any]): + ormCls = self._get_orm(tableName) + if ormCls is None: + return + session = self.SessionLocal() + data = ormCls(**record) + session.add(data) + session.commit() + + def addRecords(self,tableName:str,records:List[Dict[str,Any]]): + ormCls = self._get_orm(tableName) + if ormCls is None: + return + datas = [] + session = self.SessionLocal() + for record in records: + datas.append(ormCls(**record)) + session.add(datas) + session.commit() + + def delete(self,tableName:str,**filter): + session = self.SessionLocal() + ormCls = self._get_orm(tableName) + if ormCls is None: + return + records = session.query(ormCls).filter_by(**filter).all() + if records is not None: + session.delete(records) + session.commit() + + def update(self,tableName:str,data:Dict[str,Any],**filter): + if not self.exist(tableName): + return + session = self.SessionLocal() + ormCls = self._get_orm(tableName) + if ormCls is None: + return + if len(filter) > 0: + records = session.query(ormCls).filter_by(**filter).all() + else: + records = session.query(ormCls).all() + for record in records: + if record is not None: + record.update(data) + session.commit() + + def query(self,tableName:str,**filter): + session = self.SessionLocal() + ormCls = self._get_orm(tableName) + if ormCls is None: + return + modelCls = self._get_model(ormCls) + if modelCls is None: + return + + if filter is not None: + records = session.query(ormCls).filter_by(**filter).all() + else: + records = session.query(ormCls).all() + + datas = [] + for record in records: + datas.append(modelCls.from_orm(record)) + return datas + + def exist(self,tableName:str)->bool: + if self._engine is None: + return + inspector = Inspector.from_engine(self._engine) + return inspector.has_table(tableName) + + def _get_orm(self,tableName:str): + subClss = Base.__subclasses__() + for sunCls in subClss: + if sunCls.__tablename__ == tableName: + return sunCls + return None + + def _get_model(self,orm:Any): + subClss = BaseModel.__subclasses__() + for sunCls in subClss: + if 'orm' in sunCls.__dict__ and sunCls.orm() == orm: + return sunCls + return None + diff --git a/backend/app/api/routers/request/models.py b/backend/app/api/routers/request/models.py new file mode 100644 index 0000000..d76af75 --- /dev/null +++ b/backend/app/api/routers/request/models.py @@ -0,0 +1,15 @@ + +from typing import Dict, Any +from pydantic import BaseModel + + +class ChatRequestData(BaseModel): + inputs: Dict[str,Any] + query: str + user: str + response_mode: str + files: Any + conversation_id: str = None + +class ChatFileUploadRequest(BaseModel): + base64: str \ No newline at end of file diff --git a/backend/app/engine/__init__.py b/backend/app/engine/__init__.py index 1eaf1fe..4ee1c9c 100644 --- a/backend/app/engine/__init__.py +++ b/backend/app/engine/__init__.py @@ -4,7 +4,7 @@ from llama_index.core.agent import AgentRunner, ReActChatFormatter from llama_index.core.settings import Settings from llama_index.core.tools.query_engine import QueryEngineTool -from app.engine.engine import create_query_engine, create_summary_query_engine, create_sql_query_engine +from app.engine.engine import create_query_engine, create_summary_query_engine from app.engine.index import get_index #from app.engine.loaders.db import makeDescriptionByEngine from app.engine.tools import ToolFactory @@ -17,11 +17,11 @@ def get_chat_engine(filters=None, params=None): tools = [] # 创建SQL查询工具 - sql_query_engine = create_sql_query_engine() - sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine, - name="zjdata_query_tool", - description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具" - ) +# sql_query_engine = create_summary_query_engine(index) + # sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine, + # name="zjdata_query_tool", + # description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具" + # ) #tools.append(sql_query_tool) # Add query tool if index exists @@ -31,13 +31,19 @@ def get_chat_engine(filters=None, params=None): summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool", description="适用于任何需要进行全面总结、概括的要求。", ) - query_engine = create_query_engine(index,top_k,use_reranker,filters) + query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "COMPACT") query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool", description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。", ) + + query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "TREE_SUMMARIZE") + query_engine_tool_1 = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool_1", + description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后,且在询问工程中单位的具体数值,例如用量,费率,合计,金额等的时候建议使用“zj_query_tool_1”工具。", + ) tools.append(summary_query_tool) tools.append(query_engine_tool) + tools.append(query_engine_tool_1) # Add additional tools tools += ToolFactory.from_env() diff --git a/backend/app/engine/engine.py b/backend/app/engine/engine.py index 6cb552f..4bbd993 100644 --- a/backend/app/engine/engine.py +++ b/backend/app/engine/engine.py @@ -52,8 +52,8 @@ def get_Retriever(index,**kwargs): sql_database = None sql_obj_index = None -# Create a sql query engine -def create_sql_query_engine(top_k=3, use_reranker=False, filters=None): +# Create a summary query engine +def create_summary_query_engine(top_k=3, use_reranker=False, filters=None): global sql_obj_index global sql_database if sql_obj_index is None or sql_database is None: @@ -86,7 +86,7 @@ def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None return summary_query_engine # Create a query engine -def create_query_engine(index, top_k=3, use_reranker=False, filters=None): +def create_query_engine(index, top_k=3, use_reranker=False, filters=None, response_mode=None): # 创建向量检索查询工具 postprocess = None if use_reranker: @@ -103,6 +103,7 @@ def create_query_engine(index, top_k=3, use_reranker=False, filters=None): node_postprocessors=postprocess, use_async=True, streaming=True, + ResponseMode = response_mode ) return query_engine \ No newline at end of file diff --git a/backend/app/engine/loaders/__init__.py b/backend/app/engine/loaders/__init__.py index a220170..5e4bf7e 100644 --- a/backend/app/engine/loaders/__init__.py +++ b/backend/app/engine/loaders/__init__.py @@ -1,5 +1,4 @@ import logging - import yaml from app.engine.loaders.db import DBLoaderConfig, get_db_documents from app.engine.loaders.file import FileLoaderConfig, get_file_documents @@ -9,7 +8,7 @@ logger = logging.getLogger(__name__) def load_configs(): - with open("config/loaders.yaml") as f: + with open("config/loaders.yaml",encoding='UTF-8') as f: configs = yaml.safe_load(f) return configs @@ -17,24 +16,26 @@ def load_configs(): def get_documents(): documents = [] config = load_configs() + if config is None or len(config.items()) == 0: - return documents + return documents for loader_type, loader_config in config.items(): - logger.info( - f"Loading documents from loader: {loader_type}, config: {loader_config}" - ) + if loader_config.get('enable', True): # 检查 enable 字段 + logger.info( + f"Loading documents from loader: {loader_type}, config: {loader_config}" + ) - loader_config = loader_config or [] - match loader_type: - case "file": - document = get_file_documents(FileLoaderConfig(**loader_config)) - case "web": - document = get_web_documents(WebLoaderConfig(**loader_config)) - case "db": - document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config]) - case _: - raise ValueError(f"Invalid loader type: {loader_type}") - documents.extend(document) + loader_config = loader_config or [] + match loader_type: + case "file": + document = get_file_documents(FileLoaderConfig(**loader_config)) + case "web": + document = get_web_documents(WebLoaderConfig(**loader_config)) + case "db": + document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config]) + case _: + raise ValueError(f"Invalid loader type: {loader_type}") + documents.extend(document) - return documents + return documents \ No newline at end of file diff --git a/backend/app/engine/loaders/db.py b/backend/app/engine/loaders/db.py index d6310e2..00c0381 100644 --- a/backend/app/engine/loaders/db.py +++ b/backend/app/engine/loaders/db.py @@ -2,17 +2,14 @@ import logging from typing import Any, List, Optional from llama_index.core import SQLDatabase, Document -from llama_index.core.objects import SQLTableSchema -from llama_index.core.readers.base import BaseReader from llama_index.readers.database import DatabaseReader from pydantic import BaseModel -from sqlalchemy import create_engine -from sqlalchemy import text +from sqlalchemy import create_engine, text from sqlalchemy.engine import Engine logger = logging.getLogger(__name__) -class CustomDatabaseReader(BaseReader): +class CustomDatabaseReader(DatabaseReader): """Simple Database reader. Concatenates each row into Document used by LlamaIndex. @@ -85,19 +82,20 @@ class CustomDatabaseReader(BaseReader): Returns: List[Document]: A list of Document objects. """ - dco_str = "" + dco_str = "" + with self.sql_database.engine.connect() as connection: if query is None: raise ValueError("A query parameter is necessary to filter the data") else: result = connection.execute(text(query)) - dco_str = ", ".join( + dco_str += ", ".join( [f"{entry}" for entry in result.keys()] - ) + ) + "\n" for item in result.fetchall(): - # fetch each item + # Fetch each item record_str = ", ".join( [f"{entry}" for col, entry in zip(result.keys(), item)] ) @@ -111,45 +109,36 @@ class CustomDatabaseReader(BaseReader): class DBLoaderConfig(BaseModel): uri: str - queries: List[str] + queries: List[dict] -def get_db_documents(configs: list[DBLoaderConfig]): +def get_db_documents(configs: List[DBLoaderConfig]) -> List[Document]: docs = [] - if len(configs) == 0 or configs[0].uri == "": + if not configs or not configs[0].uri: logger.warning( f"Failed to load database, error message: uri is empty. Return as empty document list." ) return docs metadata = { - #'file_name':'', - 'file_type':'application/booway.document.zj', - #'file_path':'', - #'file_size':'', - #'creation_date':'', - #'last_modified_date':'', + 'file_type': 'application/booway.document.zj', } - #from llama_index.readers.database import DatabaseReader for entry in configs: engine = create_engine(entry.uri) sql_database = SQLDatabase(engine) - # table_schema_objs = makeDescriptionByEngine(sql_database) - # table_node_mapping = SQLTableNodeMapping(sql_database) - # - # nodes = table_node_mapping.to_nodes(table_schema_objs) - # for node in nodes: - # node.metadata.update(metadata) - # - # docs.extend(nodes) - - queries = entry.queries or [] loader = CustomDatabaseReader(sql_database) - for query in queries: + for query_dict in entry.queries: + query = query_dict.get("sql", "") + explanation = query_dict.get("explanation", "") logger.info(f"Loading data from database with query: {query}") documents = loader.load_data(query=query) - docs.extend(documents) - return docs + # 添加解释到元数据中 + for doc in documents: + doc.metadata["explanation"] = explanation + doc.metadata.update(metadata) # 更新或添加额外的元数据 + docs.append(doc) + + return docs \ No newline at end of file diff --git a/backend/app/engine/prompt.py b/backend/app/engine/prompt.py index 101b6bf..5871562 100644 --- a/backend/app/engine/prompt.py +++ b/backend/app/engine/prompt.py @@ -5,6 +5,8 @@ text_qa_template_str = ( "你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。" "你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答," "如同直接从文件中提取的内容。\n" + "知识库中已经导入一个工程的全部数据,请你站在当前工程的角度回答用户关于工程文件的问题。\n" + "例如:询问“此工程”指当前导入的工程。询问“此工程名称”指当前导入的工程的工程名称。\n" "## 技能\n" "### 技能 1: 数据查询与提供\n" @@ -39,15 +41,19 @@ refine_template_str = ( "这是原本的问题: {query_str}\n" "我们已经提供了回答: {existing_answer}\n" "现在我们有机会改进这个回答 " - "使用以下更多上下文(仅当需要用时)\n" + "使用以下更多上下文(仅当有助于改进回答时使用)\n" + "你需要仔细的判断新的上下文的信息与原本问题必须一个字都不差,如果有一点差别,那就不能改变我现有的回答。\n" + "在判断回答是否正确的时候,你应该仔细对比新的上下文中包含的信息是否与原本的问题一字不差,如果一字不差,才能当作新的正确回答。\n" + "如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n" + "判断一下如果原回答正确,且在新的上下文仍然包含正确的回答,请将新的回答与原回答一起返回。\n" "------------\n" "{context_msg}\n" "------------\n" - "根据新的上下文, 请改进原来的回答。" - "如果新的上下文没有用, 直接返回原本的回答。\n" - "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n" + "如果回答中已经包含有正确答案,不要返回多余的解释等信息,只返回正确答案\n" + "如果是表结构或者是数据库的相关内容,仅用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n" "改进的回答: " ) + refine_template = PromptTemplate(refine_template_str) summary_template_str = ( diff --git a/backend/app/engine/tools/__init__.py b/backend/app/engine/tools/__init__.py index 1aced70..054308e 100644 --- a/backend/app/engine/tools/__init__.py +++ b/backend/app/engine/tools/__init__.py @@ -1,10 +1,9 @@ -import os -import yaml -import json import importlib -from cachetools import cached, LRUCache -from llama_index.core.tools.tool_spec.base import BaseToolSpec +import os + +import yaml from llama_index.core.tools.function_tool import FunctionTool +from llama_index.core.tools.tool_spec.base import BaseToolSpec class ToolType: @@ -46,7 +45,7 @@ class ToolFactory: def from_env() -> list[FunctionTool]: tools = [] if os.path.exists("config/tools.yaml"): - with open("config/tools.yaml", "r") as f: + with open("config/tools.yaml", "r", encoding='UTF-8') as f: tool_configs = yaml.safe_load(f) if tool_configs != None and len(tool_configs.items()) != 0: for tool_type, config_entries in tool_configs.items(): diff --git a/backend/config/loaders.yaml b/backend/config/loaders.yaml index 66b19d9..af5d2fe 100644 --- a/backend/config/loaders.yaml +++ b/backend/config/loaders.yaml @@ -1,4 +1,5 @@ file: + enable: true # 添加 enable 字段 # use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable use_llama_parse: false @@ -7,14 +8,41 @@ db: # uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db # query: The query to fetch data from the database. E.g.: SELECT * FROM table - uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 - #- uri: mysql+pymysql://zjinfo:Y6EAjEEdSYmskA8B@110.42.234.166:3306/zjinfo -# - uri: mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2 + enable: true # 添加 enable 字段 queries: - - select * from ProjectProperties limit 30; - - select Name, Code, Amount, Amount_Total from TotalCalculateTable - - select SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 limit 50; - - select Name, Code, Rate, Amount from OtherFee + - sql: select * from ProjectProperties; + explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。" + - sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable; + explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。" + + - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路'; + explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理'; + explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路'; + explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" + + - sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee; + explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中" + + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)' + explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中" + + - sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路' + explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" + - sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理' + explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中" #web: # driver_arguments: # # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode diff --git a/backend/data/博微电力造价工程业务数据说明.docx b/backend/data/博微电力造价工程业务数据说明.docx index 425772f..670ce04 100644 Binary files a/backend/data/博微电力造价工程业务数据说明.docx and b/backend/data/博微电力造价工程业务数据说明.docx differ diff --git a/backend/data/工程造价基础知识.doc b/backend/data/工程造价基础知识.doc deleted file mode 100644 index 27d4d9f..0000000 Binary files a/backend/data/工程造价基础知识.doc and /dev/null differ diff --git a/backend/data/工程造价基础知识.docx b/backend/data/工程造价基础知识.docx new file mode 100644 index 0000000..bd7f91d Binary files /dev/null and b/backend/data/工程造价基础知识.docx differ diff --git a/backend/main.py b/backend/main.py index 0f5e9ad..dd12002 100644 --- a/backend/main.py +++ b/backend/main.py @@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse from app.api.routers.chat import chat_router from app.api.routers.upload import file_upload_router +from app.api.routers.app import v1_router from app.settings import init_settings from app.observability import init_observability from fastapi.staticfiles import StaticFiles @@ -56,6 +57,8 @@ mount_static_files("data_output", "/api/files/output") app.include_router(chat_router, prefix="/api/chat") app.include_router(file_upload_router, prefix="/api/chat/upload") +app.include_router(v1_router, prefix="/v1") + @app.get("/") async def redirect_to_docs(): return RedirectResponse(url="/docs")