Files
zjdataai-app/backend/app/api/routers/request/dbOrm.py
T
2024-09-02 08:49:32 +08:00

240 lines
5.9 KiB
Python

import os
from typing import Dict, List, Any
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, String, Integer, JSON,Float
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import sessionmaker, declarative_base
# Shared declarative base. Every ORM class below subclasses it; DBManager
# finds ORM classes via Base.__subclasses__() and creates tables via
# Base.metadata.tables.
Base = declarative_base()
# ORM classes
class ConversationOrm(Base):
    """ORM mapping for the ``conversations`` table.

    ``inputs`` is stored in a JSON column. ``created_at`` is an integer
    (presumably a Unix timestamp -- TODO confirm with the writers).
    """

    __tablename__ = "conversations"

    id = Column(String, primary_key=True)
    user_id = Column(String)
    name = Column(String)
    inputs = Column(JSON)
    status = Column(String)
    introduction = Column(String)
    created_at = Column(Integer)

    def update(self, data: Dict[str, Any]):
        """Copy the editable fields present in ``data`` onto this row.

        Only ``name`` is editable; any other key in ``data`` is ignored.
        This does not commit -- DBManager.update commits afterwards.
        """
        editable_fields = ("name",)
        for field_name in editable_fields:
            if field_name in data:
                setattr(self, field_name, data[field_name])
class UserOrm(Base):
    """ORM mapping for the ``user`` table, keyed by user id."""

    __tablename__ = "user"

    id = Column(String, primary_key=True)
    # Stored as plain text; no date/time format is enforced at this layer.
    createtime = Column(String)
class ParametersOrm(Base):
    """ORM mapping for the ``parameters`` table (a named JSON value per user)."""

    __tablename__ = "parameters"

    # NOTE(review): user_id alone is the primary key, so a user can hold at
    # most one row here regardless of ``name`` -- confirm this is intended.
    user_id = Column(String, primary_key=True)
    name = Column(String)
    value = Column(JSON)
class MessagesOrm(Base):
    """ORM mapping for the ``messages`` table: one query/answer exchange."""

    __tablename__ = "messages"

    id = Column(String, primary_key=True)
    user_id = Column(String)
    # Plain string column; no ForeignKey to ``conversations`` is declared.
    conversation_id = Column(String)
    inputs = Column(JSON)
    query = Column(String)
    answer = Column(String)
class FeedBackOrm(Base):
    """ORM mapping for the ``feedbacks`` table, one row per message id."""

    __tablename__ = "feedbacks"

    message_id = Column(String, primary_key=True)
    # The query/answer text is stored again here rather than joined from
    # ``messages``.
    query = Column(String)
    answer = Column(String)
    rating = Column(String)
class ProjectInfoOrm(Base):
    """ORM mapping for the ``projectInfos`` table, keyed by project flag."""

    __tablename__ = "projectInfos"

    prjFlag = Column(String, primary_key=True)
    # Spelled "prjectName" in the schema and in ProjectInfoModel; renaming
    # it would change the database column.
    prjectName = Column(String)
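# Data structures: Pydantic models mirroring the ORM tables above; each model's orm() returns its ORM class.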
class ConversationModel(BaseModel):
    """Pydantic view of a ConversationOrm row (``user_id`` is not exposed)."""

    id: str
    name: str
    inputs: Dict[str, Any]
    status: str
    introduction: str
    created_at: int

    @classmethod
    def orm(cls):
        """Return the ORM class this model is populated from."""
        return ConversationOrm

    class Config:
        # Allow construction from ORM object attributes.
        from_attributes = True
class UserModel(BaseModel):
    """Pydantic view of a UserOrm row."""

    id: str
    createtime: str

    @classmethod
    def orm(cls):
        """Return the ORM class this model is populated from."""
        return UserOrm

    class Config:
        # Allow construction from ORM object attributes.
        from_attributes = True
class ParametersModel(BaseModel):
    """Pydantic view of a ParametersOrm row."""

    user_id: str
    name: str
    value: Dict[str, Any]

    @classmethod
    def orm(cls):
        """Return the ORM class this model is populated from."""
        return ParametersOrm

    class Config:
        # Allow construction from ORM object attributes.
        from_attributes = True
class MessagesModel(BaseModel):
    """Pydantic view of a MessagesOrm row (``user_id`` is not exposed)."""

    id: str
    conversation_id: str
    inputs: Dict[str, Any]
    query: str
    answer: str

    @classmethod
    def orm(cls):
        """Return the ORM class this model is populated from."""
        return MessagesOrm

    class Config:
        # Allow construction from ORM object attributes.
        from_attributes = True
class FeedBackModel(BaseModel):
    """Pydantic view of a FeedBackOrm row."""

    message_id: str
    query: str
    answer: str
    rating: str

    @classmethod
    def orm(cls):
        """Return the ORM class this model is populated from."""
        return FeedBackOrm

    class Config:
        # Allow construction from ORM object attributes.
        from_attributes = True
class ProjectInfoModel(BaseModel):
    """Pydantic view of a ProjectInfoOrm row."""

    prjectName: str
    prjFlag: str

    @classmethod
    def orm(cls):
        """Return the ORM class this model is populated from."""
        return ProjectInfoOrm

    class Config:
        # Allow construction from ORM object attributes.
        from_attributes = True
class DBManager:
    """Small CRUD helper over the ORM tables declared on ``Base``.

    Public methods take a table name (an ORM class's ``__tablename__``),
    resolve it to the ORM class, and do their work in a short-lived session.
    The session is always closed, which also rolls back on error. Unknown
    table names are ignored: the method returns ``None`` without touching
    the database.
    """

    def __init__(self, database_url: "str | None" = None) -> None:
        """Create the engine and the session factory.

        Args:
            database_url: SQLAlchemy database URL. When omitted, the
                ``SQLITE_DATABASE_URL`` environment variable is used.
        """
        if database_url is None:
            database_url = os.getenv("SQLITE_DATABASE_URL")
        self._engine = create_engine(database_url)
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)

    def createTable(self, tableName: str):
        """Create ``tableName`` from ``Base.metadata`` if it does not exist yet.

        Does nothing when no engine is configured or when the name is not a
        table declared on ``Base``.
        """
        if self._engine is None:
            return
        table = Base.metadata.tables.get(tableName)
        if table is None:
            return
        if not self.exist(tableName):
            table.create(self._engine)

    def addRecord(self, tableName: str, record: Dict[str, Any]):
        """Insert one row built from ``record`` (column name -> value) and commit."""
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        with self.SessionLocal() as session:
            session.add(ormCls(**record))
            session.commit()

    def addRecords(self, tableName: str, records: List[Dict[str, Any]]):
        """Insert one row per dict in ``records`` and commit them together."""
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        with self.SessionLocal() as session:
            # Session.add() takes a single instance; a list needs add_all().
            session.add_all([ormCls(**record) for record in records])
            session.commit()

    def delete(self, tableName: str, **filters):
        """Delete every row of ``tableName`` matching ``filters`` (column=value).

        With no filter arguments, every row in the table is deleted.
        """
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        with self.SessionLocal() as session:
            for record in session.query(ormCls).filter_by(**filters).all():
                session.delete(record)
            session.commit()

    def update(self, tableName: str, data: Dict[str, Any], **filters):
        """Apply ``data`` to every row of ``tableName`` matching ``filters``.

        With no filter arguments, every row is updated. An ORM class that
        defines its own ``update(data)`` method (e.g. ConversationOrm)
        decides which fields change. For other classes, keys of ``data``
        that name a table column are assigned directly and other keys are
        ignored.
        """
        if not self.exist(tableName):
            return
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        column_names = set(ormCls.__table__.columns.keys())
        with self.SessionLocal() as session:
            for record in session.query(ormCls).filter_by(**filters).all():
                custom_update = getattr(record, "update", None)
                if callable(custom_update):
                    custom_update(data)
                else:
                    # Most ORM classes have no update() method; set the
                    # matching columns generically instead of raising.
                    for key, value in data.items():
                        if key in column_names:
                            setattr(record, key, value)
            session.commit()

    def query(self, tableName: str, **filters):
        """Return rows of ``tableName`` matching ``filters`` as Pydantic models.

        With no filter arguments, all rows are returned. Returns ``None``
        when the table name or its Pydantic model cannot be resolved.
        """
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return None
        modelCls = self._get_model(ormCls)
        if modelCls is None:
            return None
        with self.SessionLocal() as session:
            records = session.query(ormCls).filter_by(**filters).all()
            # Convert while the session is still open. The models use the
            # Pydantic v2 ``from_attributes`` config, and model_validate is
            # v2's replacement for the deprecated from_orm.
            return [modelCls.model_validate(record) for record in records]

    def exist(self, tableName: str) -> bool:
        """Return True if ``tableName`` exists in the database.

        Returns False when no engine is configured.
        """
        if self._engine is None:
            return False
        # sqlalchemy.inspect() replaces the deprecated Inspector.from_engine().
        from sqlalchemy import inspect as sa_inspect

        return sa_inspect(self._engine).has_table(tableName)

    def _get_orm(self, tableName: str):
        """Return the direct ``Base`` subclass whose ``__tablename__`` is ``tableName``, or None."""
        return next(
            (cls for cls in Base.__subclasses__() if cls.__tablename__ == tableName),
            None,
        )

    def _get_model(self, orm: Any):
        """Return the direct BaseModel subclass whose own ``orm()`` returns ``orm``, or None."""
        return next(
            (
                cls
                for cls in BaseModel.__subclasses__()
                if "orm" in cls.__dict__ and cls.orm() == orm
            ),
            None,
        )