新增工程信息、检索的知识片段节点回传、下一轮建议问题列表

This commit is contained in:
wanyaokun
2024-08-30 16:42:36 +08:00
parent 73565b26e4
commit e7628809ad
11 changed files with 411 additions and 134 deletions
+34 -3
View File
@@ -2,7 +2,7 @@ from datetime import datetime
import uuid
from app.api.routers.request.baseConfig import BaseConfig
from app.api.routers.request.dbOrm import DBManager
from typing import List
dbManage = DBManager()
class conversations:
@@ -122,8 +122,9 @@ class message:
def delete(self,user_id:str):
dbManage.delete(self._tableName,user_id = user_id)
def query(self,**condition):
def query(self,id:str):
results = []
condition = {'id':id}
records = dbManage.query(self._tableName,**condition)
for record in records:
results.append(record.dict())
@@ -152,4 +153,34 @@ class feedback:
records = dbManage.query(self._tableName,**cond)
if len(records) > 0:
return records[0].dict()
return None
return None
class ProjectInfo:
def __init__(self) -> None:
self._tableName = 'projectInfos'
dbManage.createTable(self._tableName)
def add(self,name:str,flag:str):
record = {
'prjectName': name,
'prjFlag': flag
}
dbManage.addRecord(self._tableName,record)
def projectNames(self)->List[str]:
records = dbManage.query(self._tableName)
names = []
for record in records:
data:dict = record.dict()
name = data.get('prjectName')
if name !='':
names.append(name)
return names
def prjFalg(self,name:str):
records = dbManage.query(self._tableName)
for record in records:
data:dict = record.dict()
if data.get('prjectName') == name:
return data['prjFlag']
return ''
+14 -2
View File
@@ -5,7 +5,8 @@ from enum import Enum
class BaseConfig(BaseModel):
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
def ParamterCfg(self):
def ParamterCfg(self,**args):
projectInfo = args.get('projectInfo')
questions = os.getenv("CONVERSATION_STARTERS", "dev")
return{
"opening_statement": self.projectInfo,
@@ -30,7 +31,18 @@ class BaseConfig(BaseModel):
"more_like_this": {
"enabled": False
},
"user_input_form": [],
"user_input_form": [
{
"select": {
"variable": "projectname",
"label": "\u5de5\u7a0b\u540d\u79f0",
"type": "select",
"max_length": 48,
"required": True,
"options": [projectInfo]
}
}
],
"sensitive_word_avoidance": {
"enabled": False
},
+18
View File
@@ -55,6 +55,13 @@ class FeedBackOrm(Base):
answer = Column(String)
rating = Column(String)
class ProjectInfoOrm(Base):
__tablename__ = "projectInfos"
prjFlag = Column(String,primary_key=True)
prjectName = Column(String)
#数据结构
class ConversationModel(BaseModel):
id: str
@@ -121,6 +128,17 @@ class FeedBackModel(BaseModel):
def orm(cls):
return FeedBackOrm
class ProjectInfoModel(BaseModel):
prjectName:str
prjFlag:str
class Config:
from_attributes=True
@classmethod
def orm(cls):
return ProjectInfoOrm
class DBManager:
def __init__(self) -> None:
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
@@ -10,7 +10,6 @@ class ChatRequestData(BaseModel):
response_mode: str
files: Any
conversation_id: str = None
prjFlag:Optional[str] = ''
class ChatFileUploadRequest(BaseModel):
base64: str