This commit is contained in:
2024-08-28 17:41:52 +08:00
8 changed files with 137 additions and 32 deletions
+75 -6
View File
@@ -12,15 +12,14 @@ from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput
from pydantic import BaseModel
from app.api.routers.events import EventCallbackHandler
from app.api.routers.request.base import userMng, conversations,message
from app.api.routers.request.models import ChatRequestData
from app.api.routers.request.base import userMng, conversations,message,parameter
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine
import uuid
logger = logging.getLogger("uvicorn")
api_router = r = APIRouter()
v1_router = v = APIRouter()
default_conversation_id = '82e8417f-2c3b-4bb5-ab22-2ad318bbd29a'
@@ -341,6 +340,8 @@ class ChatStreamResponse(StreamingResponse):
if await request.is_disconnected():
break
@v.post("/chat-messages")
async def post_conversations(request: Request, data: ChatRequestData):
userMng.findNoExistCreate(data.user)
@@ -392,8 +393,26 @@ async def query_messages(user:str, conversation_id:str):
}
@v.post("/conversations/{itemid}/name")
async def post_conversations(user:str):
pass
async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]):
consaObj = conversations()
consaObj.rename(itemid,'知识问答')
cond = {
'id':itemid,
'user_id':params['user']
}
results = consaObj.query(**cond)
if len(results) > 0:
res = results[0]
return {
"id": res['id'],
"name": res['name'],
"inputs": res['inputs'],
"status": res['status'],
"introduction": res['introduction'],
"created_at": res['created_at'],
#"工程位置"
}
return 'null'
@v.get("/conversations")
async def query_conversations(user:str):
@@ -405,3 +424,53 @@ async def query_conversations(user:str):
"has_more": False,
"data": conversations().gets(user_id)
}
@v.get("/parameters")
async def query_parameters(user:str):
params = parameter().get(user)
if len(params) == 0:
params = {
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
"suggested_questions": [],
"suggested_questions_after_answer": {
"enabled": False
},
"speech_to_text": {
"enabled": False
},
"text_to_speech": {
"enabled": False,
"language": "",
"voice": ""
},
"retriever_resource": {
"enabled": True
},
"annotation_reply": {
"enabled": False
},
"more_like_this": {
"enabled": False
},
"user_input_form": [],
"sensitive_word_avoidance": {
"enabled": False
},
"file_upload": {
"image": {
"enabled": False,
"number_limits": 3,
"transfer_methods": [
"remote_url"
]
}
},
"system_parameters": {
"image_file_size_limit": "10"
}
}
return params
@r.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]:
pass
+11 -19
View File
@@ -35,10 +35,17 @@ class conversations:
def delete(self,id:str):
dbManage.delete(self._tableName,id=id)
def rename(self,id:str):
data = {'name':''}
def rename(self,id:str,name:str):
data = {'name':name}
dbManage.update(self._tableName,data,id=id)
def query(self,**condition):
results = []
records = dbManage.query(self._tableName,**condition)
for record in records:
results.append(record.dict())
return results
class user:
def __init__(self) -> None:
self._tableName = 'user'
@@ -83,23 +90,8 @@ class parameter:
key = record['name']
value = record['value']
data[key] = value
return {
'opening_statement':data['opening_statement'],
'suggested_questions':data['suggested_questions'],
'suggested_questions_after_answer':data['suggested_questions_after_answer'],
'speech_to_text':data['speech_to_text'],
'text_to_speech':data['text_to_speech'],
'retriever_resource':data['retriever_resource'],
'annotation_reply':data['annotation_reply'],
'more_like_this':data['more_like_this'],
'user_input_form':data['user_input_form'],
'sensitive_word_avoidance':data['sensitive_word_avoidance'],
'file_upload':data['file_upload'],
'system_parameters':data['system_parameters'],
'opening_statement':data['opening_statement'],
}
return data
def set(self,user_id:str):
dbManage.addRecord(self._tableName,{})
+18 -4
View File
@@ -20,6 +20,14 @@ class ConversationOrm(Base):
introduction = Column(String)
created_at = Column(Integer)
def update(self,data:Dict[str,Any]):
if 'name' in data:
self.name = data['name']
class UserOrm(Base):
__tablename__ = "user"
@@ -143,14 +151,20 @@ class DBManager:
session.commit()
def update(self,tableName:str,data:Dict[str,Any],**filter):
if not self.exist(tableName):
return
session = self.SessionLocal()
ormCls = self._get_orm(tableName)
if ormCls is None:
return
record = session.query(ormCls).filter_by(**filter).first()
if record is not None:
record.update(data)
session.commit()
if len(filter) > 0:
records = session.query(ormCls).filter_by(**filter).all()
else:
records = session.query(ormCls).all()
for record in records:
if record is not None:
record.update(data)
session.commit()
def query(self,tableName:str,**filter):
session = self.SessionLocal()
+3 -1
View File
@@ -1,6 +1,5 @@
from typing import Dict, Any
from pydantic import BaseModel
@@ -11,3 +10,6 @@ class ChatRequestData(BaseModel):
response_mode: str
files: Any
conversation_id: str = None
class ChatFileUploadRequest(BaseModel):
base64: str
+7 -1
View File
@@ -31,13 +31,19 @@ def get_chat_engine(filters=None, params=None):
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
description="适用于任何需要进行全面总结、概括的要求。",
)
query_engine = create_query_engine(index,top_k,use_reranker,filters)
query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "COMPACT")
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
)
query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "TREE_SUMMARIZE")
query_engine_tool_1 = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool_1",
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后,且在询问工程中单位的具体数值,例如用量,费率,合计,金额等的时候建议使用“zj_query_tool_1”工具。",
)
tools.append(summary_query_tool)
tools.append(query_engine_tool)
tools.append(query_engine_tool_1)
# Add additional tools
tools += ToolFactory.from_env()
+2 -1
View File
@@ -86,7 +86,7 @@ def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None
return summary_query_engine
# Create a query engine
def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
def create_query_engine(index, top_k=3, use_reranker=False, filters=None, response_mode=None):
# 创建向量检索查询工具
postprocess = None
if use_reranker:
@@ -103,6 +103,7 @@ def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
node_postprocessors=postprocess,
use_async=True,
streaming=True,
ResponseMode = response_mode
)
return query_engine
+21
View File
@@ -12,16 +12,37 @@ db:
queries:
- sql: select * from ProjectProperties;
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
#web:
# driver_arguments:
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode