239 lines
5.9 KiB
Python
239 lines
5.9 KiB
Python
import os
|
|
from typing import Dict, List, Any
|
|
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import create_engine, Column, String, Integer, JSON,Float
|
|
from sqlalchemy.engine.reflection import Inspector
|
|
from sqlalchemy.orm import sessionmaker, declarative_base
|
|
|
|
# Shared declarative base. Every ORM class below subclasses it, and DBManager
# finds tables through Base.__subclasses__() and Base.metadata.tables.
Base = declarative_base()
# ---------------------------------------------------------------------------
# ORM classes
# ---------------------------------------------------------------------------
class ConversationOrm(Base):
    """SQLAlchemy mapping for the ``conversations`` table."""

    __tablename__ = "conversations"

    id = Column(String, primary_key=True)
    user_id = Column(String)
    name = Column(String)
    inputs = Column(JSON)
    status = Column(String)
    introduction = Column(String)
    created_at = Column(Integer)

    def update(self, data: Dict[str, Any]):
        """Copy the editable fields found in *data* onto this row.

        Only ``name`` is editable. Any other key in *data* is ignored, and a
        missing ``name`` key leaves the row unchanged.
        """
        editable_fields = ("name",)
        for field_name in editable_fields:
            if field_name in data:
                setattr(self, field_name, data[field_name])
class UserOrm(Base):
    """SQLAlchemy mapping for the ``user`` table."""

    __tablename__ = "user"

    id = Column(String, primary_key=True)
    # Stored as a plain string column; the timestamp format is not visible
    # here. NOTE(review): confirm the format expected by callers.
    createtime = Column(String)
class ParametersOrm(Base):
    """SQLAlchemy mapping for the ``parameters`` table."""

    __tablename__ = "parameters"

    # NOTE(review): user_id alone is the primary key, so each user can own at
    # most one row (a single name/value pair). If several named parameters per
    # user are intended, the key probably needs to be (user_id, name). Confirm.
    user_id = Column(String,primary_key=True)
    name = Column(String)
    # JSON column; ParametersModel validates it as a Dict[str, Any].
    value = Column(JSON)
class MessagesOrm(Base):
    """SQLAlchemy mapping for the ``messages`` table (one query/answer pair per row)."""

    __tablename__ = "messages"

    id = Column(String,primary_key=True)
    user_id = Column(String)
    # Plain string column; no ForeignKey constraint to conversations is declared.
    conversation_id = Column(String)
    inputs = Column(JSON)
    query = Column(String)
    answer = Column(String)
class FeedBackOrm(Base):
    """SQLAlchemy mapping for the ``feedbacks`` table."""

    __tablename__ = "feedbacks"

    # message_id is the primary key, so at most one feedback row per message.
    message_id = Column(String,primary_key=True)
    query = Column(String)
    answer = Column(String)
    # Stored as a string; the allowed rating values are not defined here.
    rating = Column(String)
class ProjectInfoOrm(Base):
    """SQLAlchemy mapping for the ``projectInfos`` table."""

    __tablename__ = "projectInfos"

    prjFlag = Column(String,primary_key=True)
    # NOTE(review): "prjectName" looks like a misspelling of "projectName", but it
    # is the column/attribute name used by ProjectInfoModel and possibly by
    # existing databases, so renaming it would be a schema change.
    prjectName = Column(String)
# ---------------------------------------------------------------------------
# Data structures: Pydantic models that DBManager.query builds from ORM rows
# ---------------------------------------------------------------------------
class ConversationModel(BaseModel):
    """Pydantic view of a ``conversations`` row (mirrors ``ConversationOrm``)."""

    id: str
    name: str
    inputs: Dict[str, Any]
    status: str
    introduction: str
    created_at: int

    # Pydantic v2 configuration: allow validation straight from ORM attributes
    # (``model_validate(orm_obj)``). Replaces the inner ``class Config``, which
    # Pydantic v2 still accepts but deprecates.
    model_config = {"from_attributes": True}

    @classmethod
    def orm(cls):
        """Return the ORM class this model mirrors (matched by DBManager._get_model)."""
        return ConversationOrm
class UserModel(BaseModel):
    """Pydantic view of a ``user`` row (mirrors ``UserOrm``)."""

    id: str
    createtime: str

    # Pydantic v2 configuration: allow validation straight from ORM attributes
    # (``model_validate(orm_obj)``). Replaces the deprecated inner ``class Config``.
    model_config = {"from_attributes": True}

    @classmethod
    def orm(cls):
        """Return the ORM class this model mirrors (matched by DBManager._get_model)."""
        return UserOrm
class ParametersModel(BaseModel):
    """Pydantic view of a ``parameters`` row (mirrors ``ParametersOrm``)."""

    user_id: str
    name: str
    # The ORM column is generic JSON; this model only accepts a JSON object.
    value: Dict[str, Any]

    # Pydantic v2 configuration: allow validation straight from ORM attributes
    # (``model_validate(orm_obj)``). Replaces the deprecated inner ``class Config``.
    model_config = {"from_attributes": True}

    @classmethod
    def orm(cls):
        """Return the ORM class this model mirrors (matched by DBManager._get_model)."""
        return ParametersOrm
class MessagesModel(BaseModel):
    """Pydantic view of a ``messages`` row (mirrors ``MessagesOrm``).

    The ORM's ``user_id`` column is intentionally not part of this model.
    """

    id: str
    conversation_id: str
    inputs: Dict[str, Any]
    query: str
    answer: str

    # Pydantic v2 configuration: allow validation straight from ORM attributes
    # (``model_validate(orm_obj)``). Replaces the deprecated inner ``class Config``.
    model_config = {"from_attributes": True}

    @classmethod
    def orm(cls):
        """Return the ORM class this model mirrors (matched by DBManager._get_model)."""
        return MessagesOrm
class FeedBackModel(BaseModel):
    """Pydantic view of a ``feedbacks`` row (mirrors ``FeedBackOrm``)."""

    message_id: str
    query: str
    answer: str
    rating: str

    # Pydantic v2 configuration: allow validation straight from ORM attributes
    # (``model_validate(orm_obj)``). Replaces the deprecated inner ``class Config``.
    model_config = {"from_attributes": True}

    @classmethod
    def orm(cls):
        """Return the ORM class this model mirrors (matched by DBManager._get_model)."""
        return FeedBackOrm
class ProjectInfoModel(BaseModel):
    """Pydantic view of a ``projectInfos`` row (mirrors ``ProjectInfoOrm``)."""

    # Field names must match the ORM attribute names, including "prjectName".
    prjectName: str
    prjFlag: str

    # Pydantic v2 configuration: allow validation straight from ORM attributes
    # (``model_validate(orm_obj)``). Replaces the deprecated inner ``class Config``.
    model_config = {"from_attributes": True}

    @classmethod
    def orm(cls):
        """Return the ORM class this model mirrors (matched by DBManager._get_model)."""
        return ProjectInfoOrm
class DBManager:
    """Generic CRUD helper over the ORM classes declared on ``Base``.

    Tables are addressed by their ``__tablename__`` string. ``addRecord``,
    ``addRecords``, ``delete``, ``update`` and ``query`` return ``None``
    silently when no ORM class has that table name. Every call opens its own
    session and closes it before returning.
    """

    def __init__(self) -> None:
        # NOTE(review): if SQLITE_DATABASE_URL is unset, DATABASE_URL is None and
        # create_engine() will reject it, so construction fails here.
        DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
        self._engine = create_engine(DATABASE_URL)
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)

    def createTable(self, tableName: str):
        """Create the table *tableName* if it does not exist in the database yet.

        Raises KeyError if *tableName* is not a table declared on ``Base``.
        """
        if self._engine is None:
            return
        if not self.exist(tableName):
            Base.metadata.tables[tableName].create(self._engine)

    def addRecord(self, tableName: str, record: Dict[str, Any]):
        """Insert one row built from the column -> value mapping *record*."""
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        # The session context manager closes the session (discarding anything
        # uncommitted) even if building the row or committing raises.
        with self.SessionLocal() as session:
            session.add(ormCls(**record))
            session.commit()

    def addRecords(self, tableName: str, records: List[Dict[str, Any]]):
        """Insert one row per column -> value mapping in *records*, in a single commit."""
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        with self.SessionLocal() as session:
            # Session.add() accepts a single instance; collections go through add_all().
            session.add_all([ormCls(**record) for record in records])
            session.commit()

    def delete(self, tableName: str, **filters):
        """Delete every row of *tableName* matching the ``filter_by`` keywords.

        WARNING: with no keyword filters every row of the table is deleted.
        """
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        with self.SessionLocal() as session:
            for record in session.query(ormCls).filter_by(**filters).all():
                # Session.delete() takes one instance at a time, not a list.
                session.delete(record)
            session.commit()

    def update(self, tableName: str, data: Dict[str, Any], **filters):
        """Apply *data* to every row of *tableName* matching the ``filter_by`` keywords.

        With no filters every row is updated. If the ORM class defines its own
        ``update(data)`` method (e.g. ``ConversationOrm``), that method decides
        which fields change. Otherwise each key of *data* that names a table
        column is assigned directly, and other keys are ignored.
        """
        if not self.exist(tableName):
            return
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        column_names = set(ormCls.__table__.columns.keys())
        with self.SessionLocal() as session:
            # filter_by() with no keywords applies no criteria, i.e. matches all rows.
            for record in session.query(ormCls).filter_by(**filters).all():
                custom_update = getattr(record, "update", None)
                if callable(custom_update):
                    custom_update(data)
                else:
                    # Previously this raised AttributeError for ORM classes
                    # without an update() method.
                    for key, value in data.items():
                        if key in column_names:
                            setattr(record, key, value)
            session.commit()

    def query(self, tableName: str, **filters):
        """Return the rows of *tableName* matching the ``filter_by`` keywords as Pydantic models.

        Returns ``None`` when the table has no ORM class or no Pydantic model
        maps to that ORM class. Otherwise returns a (possibly empty) list.
        """
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        modelCls = self._get_model(ormCls)
        if modelCls is None:
            return
        with self.SessionLocal() as session:
            # filter_by() with no keywords matches every row.
            records = session.query(ormCls).filter_by(**filters).all()
            # Convert while the session is still open. model_validate() is the
            # Pydantic v2 replacement for the deprecated from_orm().
            return [modelCls.model_validate(record) for record in records]

    def exist(self, tableName: str) -> bool:
        """Return True if the table *tableName* already exists in the database."""
        if self._engine is None:
            return False
        # Local import: Inspector.from_engine() is deprecated in favour of
        # sqlalchemy.inspect(), and "inspect" is not imported at module level.
        from sqlalchemy import inspect as sa_inspect
        return sa_inspect(self._engine).has_table(tableName)

    def _get_orm(self, tableName: str):
        """Return the direct ``Base`` subclass whose ``__tablename__`` is *tableName*, or None."""
        return next(
            (cls for cls in Base.__subclasses__() if cls.__tablename__ == tableName),
            None,
        )

    def _get_model(self, orm: Any):
        """Return the direct ``BaseModel`` subclass whose own ``orm()`` returns *orm*, or None."""
        return next(
            (
                cls
                for cls in BaseModel.__subclasses__()
                # Only classes that define orm() themselves are candidates.
                if "orm" in cls.__dict__ and cls.orm() == orm
            ),
            None,
        )
|