This commit is contained in:
chentianrui
2024-08-30 18:40:32 +08:00
7 changed files with 362 additions and 294 deletions
+1
View File
@@ -80,3 +80,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
- You can install any pip package (if it exists) by running a cell with pip install. - You can install any pip package (if it exists) by running a cell with pip install.
" "
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
+1
View File
@@ -111,3 +111,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
- You can install any pip package (if it exists) by running a cell with pip install. - You can install any pip package (if it exists) by running a cell with pip install.
" "
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
+232 -229
View File
@@ -3,6 +3,7 @@ import json
import logging import logging
import time import time
from typing import Dict, List, Any, Optional, AsyncGenerator from typing import Dict, List, Any, Optional, AsyncGenerator
from collections import deque
from aiostream import stream from aiostream import stream
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
@@ -13,7 +14,8 @@ from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput from llama_index.core.tools import ToolOutput
from pydantic import BaseModel from pydantic import BaseModel
from app.api.routers.request.base import userMng, conversations,message,parameter from app.api.routers.request.base import userMng, conversations,message,parameter,feedback
from app.api.routers.request.baseConfig import *
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine from app.engine import get_chat_engine
import uuid import uuid
@@ -24,78 +26,138 @@ api_router = r = APIRouter()
v1_router = v = APIRouter() v1_router = v = APIRouter()
class ChatCallbackEvent(BaseModel): class ChatCallbackEvent(BaseModel):
event_type: CBEventType event_type: ChatEventType
payload: Optional[Dict[str, Any]] = None payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_retrieval_message(self) -> dict | None: def get_common_param(self)-> dict:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"根据查询检索到 {len(nodes)} 源文件"
else:
msg = f"查询检索中: '{self.payload.get('query_str')}'"
return { return {
"type": "events", 'event': self.event_type.name,
"data": {"title": msg}, 'conversation_id':self.payload.get("conversation_id"),
} 'message_id': self.payload.get("message_id"),
else: 'created_at': int(time.time()),
return None 'task_id': self.payload.get("task_id")
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
},
} }
def _is_output_serializable(self, output: Any) -> bool: def get_WorkflowStart_param(self) -> dict:
try: params = self.get_common_param()
json.dumps(output) params.update({
return True 'workflow_run_id':self.payload.get('workflow_run_id'),
except TypeError: 'data':{
return False "id": self.payload.get('workflow_run_id'),
"workflow_id": self.payload.get('workflow_id'),
def get_agent_tool_response(self) -> dict | None: "sequence_number": 1709,
response = self.payload.get("response") "inputs": {
if response is not None: "sys.query": self.payload.get('query'),
sources = response.sources "sys.files": [],
for source in sources: "sys.conversation_id": self.payload.get('conversation_id'),
# Return the tool response here to include the toolCall information "sys.user_id": self.payload.get('use_id')
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
}, },
"created_at": int(time.time())
} }
})
return params
def to_response(self): def get_WorkflowFinished_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('workflow_run_id'),
"workflow_id": self.payload.get('workflow_id'),
"sequence_number": 1709,
"status": "succeeded",
"outputs": {
"answer": self.payload.get('response')
},
"error": '',
"elapsed_time": 36.03764106379822,
"total_tokens": 11707,
"total_steps": 10,
"created_by": {
"id": str(uuid.uuid4()),
"user": self.payload.get('use_id')
},
"created_at": int(time.time()),
"finished_at": int(time.time()),
"files": []
}
})
return params
def get_NodeStart_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": self.payload.get('title'),
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
"created_at": 1724398751,
"extras": {}
}
})
return params
def get_NodeFinished_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": self.payload.get('title'),
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
"process_data": '',
"outputs": '',
"status": "succeeded",
"error": '',
"elapsed_time": 0.10402441816404462,
"execution_metadata": '',
"created_at": 1724398751,
"finished_at": 1724398751,
"files": []
}
})
return params
def get_Message_param(self) -> dict:
params = self.get_common_param()
params.update({
'id':self.payload.get('message_id'),
'answer':self.payload.get('answer')
})
return params
def get_MessageEnd_param(self) -> dict:
params = self.get_common_param()
params.update({
'id':self.payload.get('message_id'),
'metadata':self.payload.get('metadata')
})
return params
def to_response(self)-> dict|None:
try: try:
match self.event_type: match self.event_type:
case "retrieve": case "workflow_started":
return self.get_retrieval_message() return self.get_WorkflowStart_param()
case "function_call": case "workflow_finished":
return self.get_tool_message() return self.get_WorkflowFinished_param()
case "agent_step": case "node_started":
return self.get_agent_tool_response() return self.get_NodeStart_param()
case 'node_finished':
return self.get_NodeFinished_param()
case 'message':
return self.get_Message_param()
case 'message_end':
return self.get_MessageEnd_param()
case _: case _:
return None return None
except Exception as e: except Exception as e:
@@ -106,9 +168,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue _aqueue: asyncio.Queue
is_done: bool = False is_done: bool = False
def __init__( def __init__(self,**params):
self,
):
"""Initialize the base callback handler.""" """Initialize the base callback handler."""
ignored_events = [ ignored_events = [
# CBEventType.CHUNKING, # CBEventType.CHUNKING,
@@ -119,6 +179,23 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
] ]
super().__init__(ignored_events, ignored_events) super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue() self._aqueue = asyncio.Queue()
self._response:str = ''
self._params:Dict[str,Any] = params
self._nodeStack:deque = deque()
#添加工作流开始事件
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'use_id': data.user,
'query': data.query,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
def on_event_start( def on_event_start(
self, self,
@@ -129,9 +206,21 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
) -> str: ) -> str:
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload)) logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload) self._nodeStack.append(event_id)
if event.to_response() is not None: nindex = self._nodeStack.count() - 1
self._aqueue.put_nowait(event) args:Dict[str,Any] = self._params['ids']
args.update(
{
'nodeid':event_id,
'title':event_type.name,
'index':nindex + 1,
'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else ''
}
)
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_START,payload = args)
if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event)
def on_event_end( def on_event_end(
self, self,
@@ -141,9 +230,25 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload)) logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None: #self.response = payload.get("response","")
self._aqueue.put_nowait(event) args:Dict[str,Any] = self._params['ids']
nodeID = self._nodeStack[-1]
if nodeID == event_id:
nindex = self._nodeStack.count() - 1
args.update(
{
'nodeid':event_id,
'title':event_type.name,
'index':nindex + 1,
'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else ''
}
)
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_FINISHED,payload = args)
if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event)
self._nodeStack.pop()
def start_trace(self, trace_id: Optional[str] = None) -> None: def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op.""" """No-op."""
@@ -156,6 +261,23 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
) -> None: ) -> None:
"""No-op.""" """No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map)) logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'response':self._response,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
args:Dict[str,Any] = self._params['ids']
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
if msgEnt_event.to_response() is not None:
self._aqueue.put_nowait(msgEnt_event)
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]: async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
while not self._aqueue.empty() or not self.is_done: while not self._aqueue.empty() or not self.is_done:
@@ -173,95 +295,26 @@ class IDManager:
"workflow_id": str(uuid.uuid4()) "workflow_id": str(uuid.uuid4())
} }
class DifyChatResponseEvent(BaseModel):
event: str
conversation_id: str
message_id: str
created_at: int = int(time.time())
task_id: str
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
"sequence_number": 1709,
"inputs": {
"sys.query": args['query'],
"sys.files": [],
"sys.conversation_id": args['conversation_id'],
"sys.user_id": args['use_id']
},
"created_at": int(time.time())
}
args['event'] = 'workflow_started'
super().__init__(**args)
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['event'] = 'workflow_finished'
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
"sequence_number": 1709,
"status": "succeeded",
"outputs": {
"answer": args['response']
},
"error": '',
"elapsed_time": 36.03764106379822,
"total_tokens": 11707,
"total_steps": 10,
"created_by": {
"id": str(uuid.uuid4()),
"user": args['use_id']
},
"created_at": int(time.time()),
"finished_at": int(time.time()),
"files": []
}
super().__init__(**args)
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
id:str
answer:str
def __init__(self,**args):
args['id'] = args['message_id']
args['event'] = 'message'
super().__init__(**args)
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
id:str
metadata:Dict[str,Any] = {}
def __init__(self,**args):
args['id'] = args['message_id']
args['event'] = 'message_end'
super().__init__(**args)
class ChatStreamResponse(StreamingResponse): class ChatStreamResponse(StreamingResponse):
TEXT_PREFIX = "data: " TEXT_PREFIX = "data: "
DATA_PREFIX = "data: " DATA_PREFIX = "data: "
ids:Dict[str,Any] = {}
data:ChatRequestData = None
@classmethod @classmethod
def convert_text(cls, token: str): def convert_Message(cls, token: str):
# Escape newlines and double quotes to avoid breaking the stream params = cls.ids
#token = json.dumps(token) params.update({
'answer':token,
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}" 'conversation_id':cls.data.conversation_id
return "\n" })
event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE,payload = params)
@classmethod data_str = json.dumps(event.to_response())
def convert_data(cls, data: dict):
data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}{data_str}\n\n" return f"{cls.DATA_PREFIX}{data_str}\n\n"
@classmethod @classmethod
def convert_event(cls, event: DifyChatResponseEvent): def convert_Event(cls, data: dict):
data_str = json.dumps(event.dict()) data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}{data_str}\n\n" return f"{cls.DATA_PREFIX}{data_str}\n\n"
def __init__( def __init__(
@@ -269,8 +322,11 @@ class ChatStreamResponse(StreamingResponse):
request: Request, request: Request,
event_handler: ChatEventCallbackHandler, event_handler: ChatEventCallbackHandler,
response: StreamingAgentChatResponse, response: StreamingAgentChatResponse,
data: ChatRequestData data: ChatRequestData,
ids:Dict[str,Any]
): ):
ChatStreamResponse.ids = ids
ChatStreamResponse.data = data
content = ChatStreamResponse.content_generator( content = ChatStreamResponse.content_generator(
request, event_handler, response, data request, event_handler, response, data
) )
@@ -284,41 +340,26 @@ class ChatStreamResponse(StreamingResponse):
response: StreamingAgentChatResponse, response: StreamingAgentChatResponse,
data: ChatRequestData data: ChatRequestData
): ):
ids = IDManager().createID()
# Yield the text response # Yield the text response
async def _chat_response_generator(): async def _chat_response_generator():
final_response = "" final_response = ""
async for token in response.async_response_gen(): async for token in response.async_response_gen():
final_response += token final_response += token
args = ids yield ChatStreamResponse.convert_Message(token)
args['answer'] = token
args['conversation_id'] = data.conversation_id
event = Message_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(event)
#yield ChatStreamResponse.convert_text(token)
# 存储消息历史 # 存储消息历史
message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response) message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response)
# the text_generator is the leading stream, once it's finished, also finish the event stream # the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True event_handler.is_done = True
# 发送工作流结束事件
args = ids
args['response'] = final_response
args['conversation_id'] = data.conversation_id
wf_event = Workflow_finished_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(wf_event)
msgEnt_event = MessageEnd_DifyChatResponseEvent(**ids)
yield ChatStreamResponse.convert_event(msgEnt_event)
# Yield the events from the event handler # Yield the events from the event handler
async def _event_generator(): async def _event_generator():
async for event in event_handler.async_event_gen(): async for event in event_handler.async_event_gen():
event_response = event.to_response() event_response = event.to_response()
if event_response is not None: if event_response is not None:
yield ChatStreamResponse.convert_text("") yield ChatStreamResponse.convert_Event(event_response)
combine = stream.merge(_chat_response_generator(), _event_generator()) combine = stream.merge(_chat_response_generator(), _event_generator())
is_stream_started = False is_stream_started = False
@@ -327,25 +368,11 @@ class ChatStreamResponse(StreamingResponse):
if not is_stream_started: if not is_stream_started:
is_stream_started = True is_stream_started = True
# 发送工作流开始事件
args = ids
args['use_id'] = data.user
args['query'] = data.query
args['conversation_id'] = data.conversation_id
wf_event = Workflow_started_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(wf_event)
# Stream a blank message to start the stream
# 发送一个空消息事件
#yield ChatStreamResponse.convert_text("")
yield output yield output
if await request.is_disconnected(): if await request.is_disconnected():
break break
@v.post("/chat-messages") @v.post("/chat-messages")
async def post_conversations(request: Request, data: ChatRequestData): async def post_conversations(request: Request, data: ChatRequestData):
userMng.findNoExistCreate(data.user) userMng.findNoExistCreate(data.user)
@@ -365,14 +392,15 @@ async def post_conversations(request: Request, data: ChatRequestData):
chat_engine = get_chat_engine(filters=filters, params=params) chat_engine = get_chat_engine(filters=filters, params=params)
# 启动聊天事件监听 # 启动聊天事件监听
event_handler = ChatEventCallbackHandler() ids = IDManager().createID()
event_handler = ChatEventCallbackHandler(ids = ids,data = data)
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
# 执行异步聊天 # 执行异步聊天
response = await chat_engine.astream_chat(data.query) response = await chat_engine.astream_chat(data.query)
# 返回异步消息回应 # 返回异步消息回应
return ChatStreamResponse(request, event_handler, response, data) return ChatStreamResponse(request, event_handler, response, data,ids)
@v.get("/messages") @v.get("/messages")
async def query_messages(user:str, conversation_id:str): async def query_messages(user:str, conversation_id:str):
@@ -388,8 +416,9 @@ async def query_messages(user:str, conversation_id:str):
for record in records: for record in records:
res = record.dict() res = record.dict()
feeds = feedback().query(res['id'])
res["message_files"] = [] res["message_files"] = []
res["feedback"] = '' res["feedback"] = {'rating':feeds['rating'] } if feeds != None else ''
res["retriever_resources"] = [] res["retriever_resources"] = []
res["created_at"] = 1723444905 res["created_at"] = 1723444905
res["agent_thoughts"] = [] res["agent_thoughts"] = []
@@ -440,48 +469,22 @@ async def query_conversations(user:str, first_id:str = None, limit:str = None, p
async def query_parameters(user:str): async def query_parameters(user:str):
params = parameter().get(user) params = parameter().get(user)
if len(params) == 0: if len(params) == 0:
params = { params = BaseConfig().ParamterCfg()
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
"suggested_questions": [],
"suggested_questions_after_answer": {
"enabled": False
},
"speech_to_text": {
"enabled": False
},
"text_to_speech": {
"enabled": False,
"language": "",
"voice": ""
},
"retriever_resource": {
"enabled": True
},
"annotation_reply": {
"enabled": False
},
"more_like_this": {
"enabled": False
},
"user_input_form": [],
"sensitive_word_avoidance": {
"enabled": False
},
"file_upload": {
"image": {
"enabled": False,
"number_limits": 3,
"transfer_methods": [
"remote_url"
]
}
},
"system_parameters": {
"image_file_size_limit": "10"
}
}
return params return params
@v.post("/messages/{message_id}/feedbacks")
async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]):
if params['rating'] =='null':
feedback().delete(message_id)
else:
condition = {'id':message_id}
results = message().query(**condition)
if len(results) > 0:
result = results[0]
feedback().add(message_id=message_id,query=result['query'],
answer=result['answer'],rating=params['rating'])
@r.post("") @r.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]: def upload_file(request: ChatFileUploadRequest) -> List[str]:
pass pass
+32 -2
View File
@@ -25,7 +25,7 @@ class conversations:
return None return None
def add(self,id:str, user_id:str, name:str): def add(self,id:str, user_id:str, name:str):
template = BaseConfig.ConversationCfg template = BaseConfig().ConversationCfg()
template['id'] = id template['id'] = id
template['user_id'] = user_id template['user_id'] = user_id
template['name'] = name template['name'] = name
@@ -111,7 +111,7 @@ class message:
return datas return datas
def add(self,user_id:str,conversation_id:str,query:str,answer:str): def add(self,user_id:str,conversation_id:str,query:str,answer:str):
template = BaseConfig.MessageCfg template = BaseConfig.MessageCfg()
template['id'] = str(uuid.uuid4()) template['id'] = str(uuid.uuid4())
template['user_id'] = user_id template['user_id'] = user_id
template['conversation_id'] = conversation_id template['conversation_id'] = conversation_id
@@ -122,4 +122,34 @@ class message:
def delete(self,user_id:str): def delete(self,user_id:str):
dbManage.delete(self._tableName,user_id = user_id) dbManage.delete(self._tableName,user_id = user_id)
def query(self,**condition):
results = []
records = dbManage.query(self._tableName,**condition)
for record in records:
results.append(record.dict())
return results
class feedback:
def __init__(self) -> None:
self._tableName = 'feedbacks'
dbManage.createTable(self._tableName)
def add(self,message_id:str,query:str,answer:str,rating:str):
record = {
'message_id': message_id,
'query': query,
'answer': answer,
'rating': rating,
}
dbManage.addRecord(self._tableName,record)
def delete(self,message_id:str):
cond = {'message_id':message_id}
dbManage.delete(self._tableName,**cond)
def query(self,message_id:str):
cond = {'message_id':message_id}
records = dbManage.query(self._tableName,**cond)
if len(records) > 0:
return records[0].dict()
return None
+26 -8
View File
@@ -1,8 +1,15 @@
from pydantic import BaseModel
import os
from enum import Enum
class BaseConfig: class BaseConfig(BaseModel):
ParamterCfg = { projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
"suggested_questions": [], def ParamterCfg(self):
questions = os.getenv("CONVERSATION_STARTERS", "dev")
return{
"opening_statement": self.projectInfo,
"suggested_questions": questions.split('\n'),
"suggested_questions_after_answer": { "suggested_questions_after_answer": {
"enabled": False "enabled": False
}, },
@@ -41,18 +48,20 @@ class BaseConfig:
} }
} }
ConversationCfg = { def ConversationCfg(self):
return{
"id": "", "id": "",
'user_id':'', 'user_id':'',
"name": "", "name": "",
"inputs": {}, "inputs": {},
"status": "normal", "status": "normal",
"introduction": ParamterCfg['opening_statement'], "introduction": self.projectInfo,
"created_at":'' "created_at":''
} }
@classmethod
MessageCfg = { def MessageCfg(cls):
return {
"id": "", "id": "",
'user_id':'', 'user_id':'',
"conversation_id": "", "conversation_id": "",
@@ -60,3 +69,12 @@ class BaseConfig:
"query": "", "query": "",
"answer": "" "answer": ""
} }
class ChatEventType(str, Enum):
WORKFLOW_START = "workflow_started"
WORKFLOW_FINISHED = "workflow_finished"
NODE_START = "node_started"
NODE_FINISHED = "node_finished"
MESSAGE = "message"
MESSAGE_END = "message_end"
+22 -9
View File
@@ -2,7 +2,7 @@ import os
from typing import Dict, List, Any from typing import Dict, List, Any
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import create_engine, Column, String, Integer, JSON from sqlalchemy import create_engine, Column, String, Integer, JSON,Float
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.orm import sessionmaker, declarative_base
@@ -24,10 +24,6 @@ class ConversationOrm(Base):
if 'name' in data: if 'name' in data:
self.name = data['name'] self.name = data['name']
class UserOrm(Base): class UserOrm(Base):
__tablename__ = "user" __tablename__ = "user"
@@ -51,6 +47,14 @@ class MessagesOrm(Base):
query = Column(String) query = Column(String)
answer = Column(String) answer = Column(String)
class FeedBackOrm(Base):
__tablename__ = "feedbacks"
message_id = Column(String,primary_key=True)
query = Column(String)
answer = Column(String)
rating = Column(String)
#数据结构 #数据结构
class ConversationModel(BaseModel): class ConversationModel(BaseModel):
id: str id: str
@@ -61,7 +65,6 @@ class ConversationModel(BaseModel):
created_at: int created_at: int
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
@@ -73,7 +76,6 @@ class UserModel(BaseModel):
createtime: str createtime: str
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
@@ -86,7 +88,6 @@ class ParametersModel(BaseModel):
value : Dict[str, Any] value : Dict[str, Any]
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
@@ -101,13 +102,25 @@ class MessagesModel(BaseModel):
answer : str answer : str
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
def orm(cls): def orm(cls):
return MessagesOrm return MessagesOrm
class FeedBackModel(BaseModel):
message_id :str
query :str
answer :str
rating :str
class Config:
from_attributes=True
@classmethod
def orm(cls):
return FeedBackOrm
class DBManager: class DBManager:
def __init__(self) -> None: def __init__(self) -> None:
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL") DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
+3 -1
View File
@@ -1,7 +1,7 @@
from typing import Dict, Any from typing import Dict, Any
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional
class ChatRequestData(BaseModel): class ChatRequestData(BaseModel):
inputs: Dict[str,Any] inputs: Dict[str,Any]
@@ -13,3 +13,5 @@ class ChatRequestData(BaseModel):
class ChatFileUploadRequest(BaseModel): class ChatFileUploadRequest(BaseModel):
base64: str base64: str