This commit is contained in:
chentianrui
2024-08-28 18:12:37 +08:00
4 changed files with 107 additions and 30 deletions
+75 -6
View File
@@ -12,15 +12,14 @@ from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput from llama_index.core.tools import ToolOutput
from pydantic import BaseModel from pydantic import BaseModel
from app.api.routers.request.base import userMng, conversations,message,parameter
from app.api.routers.events import EventCallbackHandler from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.api.routers.request.base import userMng, conversations,message
from app.api.routers.request.models import ChatRequestData
from app.engine import get_chat_engine from app.engine import get_chat_engine
import uuid import uuid
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
api_router = r = APIRouter()
v1_router = v = APIRouter() v1_router = v = APIRouter()
default_conversation_id = '82e8417f-2c3b-4bb5-ab22-2ad318bbd29a' default_conversation_id = '82e8417f-2c3b-4bb5-ab22-2ad318bbd29a'
@@ -341,6 +340,8 @@ class ChatStreamResponse(StreamingResponse):
if await request.is_disconnected(): if await request.is_disconnected():
break break
@v.post("/chat-messages") @v.post("/chat-messages")
async def post_conversations(request: Request, data: ChatRequestData): async def post_conversations(request: Request, data: ChatRequestData):
userMng.findNoExistCreate(data.user) userMng.findNoExistCreate(data.user)
@@ -392,8 +393,26 @@ async def query_messages(user:str, conversation_id:str):
} }
@v.post("/conversations/{itemid}/name") @v.post("/conversations/{itemid}/name")
async def post_conversations(user:str): async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]):
pass consaObj = conversations()
consaObj.rename(itemid,'知识问答')
cond = {
'id':itemid,
'user_id':params['user']
}
results = consaObj.query(**cond)
if len(results) > 0:
res = results[0]
return {
"id": res['id'],
"name": res['name'],
"inputs": res['inputs'],
"status": res['status'],
"introduction": res['introduction'],
"created_at": res['created_at'],
#"工程位置"
}
return 'null'
@v.get("/conversations") @v.get("/conversations")
async def query_conversations(user:str): async def query_conversations(user:str):
@@ -405,3 +424,53 @@ async def query_conversations(user:str):
"has_more": False, "has_more": False,
"data": conversations().gets(user_id) "data": conversations().gets(user_id)
} }
@v.get("/parameters")
async def query_parameters(user:str):
params = parameter().get(user)
if len(params) == 0:
params = {
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
"suggested_questions": [],
"suggested_questions_after_answer": {
"enabled": False
},
"speech_to_text": {
"enabled": False
},
"text_to_speech": {
"enabled": False,
"language": "",
"voice": ""
},
"retriever_resource": {
"enabled": True
},
"annotation_reply": {
"enabled": False
},
"more_like_this": {
"enabled": False
},
"user_input_form": [],
"sensitive_word_avoidance": {
"enabled": False
},
"file_upload": {
"image": {
"enabled": False,
"number_limits": 3,
"transfer_methods": [
"remote_url"
]
}
},
"system_parameters": {
"image_file_size_limit": "10"
}
}
return params
@r.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]:
pass
+10 -18
View File
@@ -35,10 +35,17 @@ class conversations:
def delete(self,id:str): def delete(self,id:str):
dbManage.delete(self._tableName,id=id) dbManage.delete(self._tableName,id=id)
def rename(self,id:str): def rename(self,id:str,name:str):
data = {'name':''} data = {'name':name}
dbManage.update(self._tableName,data,id=id) dbManage.update(self._tableName,data,id=id)
def query(self,**condition):
results = []
records = dbManage.query(self._tableName,**condition)
for record in records:
results.append(record.dict())
return results
class user: class user:
def __init__(self) -> None: def __init__(self) -> None:
self._tableName = 'user' self._tableName = 'user'
@@ -83,22 +90,7 @@ class parameter:
key = record['name'] key = record['name']
value = record['value'] value = record['value']
data[key] = value data[key] = value
return data
return {
'opening_statement':data['opening_statement'],
'suggested_questions':data['suggested_questions'],
'suggested_questions_after_answer':data['suggested_questions_after_answer'],
'speech_to_text':data['speech_to_text'],
'text_to_speech':data['text_to_speech'],
'retriever_resource':data['retriever_resource'],
'annotation_reply':data['annotation_reply'],
'more_like_this':data['more_like_this'],
'user_input_form':data['user_input_form'],
'sensitive_word_avoidance':data['sensitive_word_avoidance'],
'file_upload':data['file_upload'],
'system_parameters':data['system_parameters'],
'opening_statement':data['opening_statement'],
}
def set(self,user_id:str): def set(self,user_id:str):
dbManage.addRecord(self._tableName,{}) dbManage.addRecord(self._tableName,{})
+18 -4
View File
@@ -20,6 +20,14 @@ class ConversationOrm(Base):
introduction = Column(String) introduction = Column(String)
created_at = Column(Integer) created_at = Column(Integer)
def update(self,data:Dict[str,Any]):
if 'name' in data:
self.name = data['name']
class UserOrm(Base): class UserOrm(Base):
__tablename__ = "user" __tablename__ = "user"
@@ -143,14 +151,20 @@ class DBManager:
session.commit() session.commit()
def update(self,tableName:str,data:Dict[str,Any],**filter): def update(self,tableName:str,data:Dict[str,Any],**filter):
if not self.exist(tableName):
return
session = self.SessionLocal() session = self.SessionLocal()
ormCls = self._get_orm(tableName) ormCls = self._get_orm(tableName)
if ormCls is None: if ormCls is None:
return return
record = session.query(ormCls).filter_by(**filter).first() if len(filter) > 0:
if record is not None: records = session.query(ormCls).filter_by(**filter).all()
record.update(data) else:
session.commit() records = session.query(ormCls).all()
for record in records:
if record is not None:
record.update(data)
session.commit()
def query(self,tableName:str,**filter): def query(self,tableName:str,**filter):
session = self.SessionLocal() session = self.SessionLocal()
+3 -1
View File
@@ -1,6 +1,5 @@
from typing import Dict, Any from typing import Dict, Any
from pydantic import BaseModel from pydantic import BaseModel
@@ -11,3 +10,6 @@ class ChatRequestData(BaseModel):
response_mode: str response_mode: str
files: Any files: Any
conversation_id: str = None conversation_id: str = None
class ChatFileUploadRequest(BaseModel):
base64: str