22 Commits

Author SHA1 Message Date
ly 72ddf46fc7 Merge pull request '增加新的前端子模块' (#4) from dev into main
Reviewed-on: #4
2024-08-29 10:51:50 +08:00
ly 0db159ac89 增加新的前端子模块 2024-08-29 10:48:40 +08:00
ly f57c0c84ef Merge pull request 'dev' (#3) from dev into main
Reviewed-on: #3
2024-08-29 10:13:10 +08:00
ly 131d6ef1d1 完善接口,实现对DIFY前端消息流传输的支持 2024-08-29 08:26:59 +08:00
ly 9b47e1a6e1 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-28 17:41:52 +08:00
wanyaokun 20510a937b Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-28 17:38:43 +08:00
wanyaokun a7c79df339 修改web请求接口 2024-08-28 17:35:28 +08:00
chentianrui 327bba75d5 修改了语句错误 2024-08-28 17:24:55 +08:00
chentianrui d1242d2080 修改了从数据库中查找取费表和工程量表,新加了一个树状搜索总结搜索引擎 2024-08-28 14:46:13 +08:00
ly 0f09551f5d Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-28 11:49:22 +08:00
chentianrui 8a5facb5b6 增加了判断是否使用数据库 2024-08-28 09:45:01 +08:00
chentianrui 0f7c900c1e 更改了提示词 2024-08-28 09:42:12 +08:00
chentianrui b008ad9766 更改了提示词 2024-08-28 09:39:57 +08:00
ly 56459c164e 配置文件增加UTF8编码格式支持,以免解析中文时出现问题 2024-08-28 08:04:01 +08:00
wanyaokun 07a3b2a147 修改POST和Get请求代码 2024-08-27 17:48:38 +08:00
ly b4c571cddb 增加对接DIFY前端支持功能 2024-08-27 08:43:00 +08:00
ly 7068b058e8 调整文件格式为DOCX 2024-08-27 08:40:46 +08:00
wanyaokun 33b2281b7b 修改ID为空的问题 2024-08-26 20:16:58 +08:00
wanyaokun 1704b61609 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-26 19:58:57 +08:00
wanyaokun afccaf6eb5 新增Web前后端通信代码 2024-08-26 19:57:22 +08:00
ly 9ee24627c2 Merge pull request 'dev' (#2) from dev into main
Reviewed-on: #2
2024-08-23 09:37:06 +08:00
ly 88761a5d10 Merge pull request 'dev' (#1) from dev into main
Reviewed-on: #1
2024-08-22 09:41:13 +08:00
19 changed files with 992 additions and 70 deletions
+3
View File
@@ -0,0 +1,3 @@
[submodule "webapp"]
path = webapp
url = https://git.97id.com/ly/webapp.git
+1
View File
@@ -2,6 +2,7 @@
# LLAMA_CLOUD_API_KEY=
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
SQLITE_DATABASE_URL=sqlite:///./source.db
DASHSCOPE_API_KEY=sk-02c8540e86d84b7ca0e6f4f51bac6e60
# The provider for the AI models to use.
+1
View File
@@ -2,6 +2,7 @@
# LLAMA_CLOUD_API_KEY=
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
SQLITE_DATABASE_URL=sqlite:///./source.db
# The number of similar embeddings to return when retrieving documents.
TOP_K=10
+487
View File
@@ -0,0 +1,487 @@
import asyncio
import json
import logging
import time
from typing import Dict, List, Any, Optional, AsyncGenerator
from aiostream import stream
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from llama_index.core import BaseCallbackHandler
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput
from pydantic import BaseModel
from app.api.routers.request.base import userMng, conversations,message,parameter
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine
import uuid
# Log through uvicorn's logger so messages share the server's handlers.
logger = logging.getLogger("uvicorn")
# `r` and `v` are short aliases used by the route decorators in this module.
api_router = r = APIRouter()
v1_router = v = APIRouter()
class ChatCallbackEvent(BaseModel):
event_type: CBEventType
payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_retrieval_message(self) -> dict | None:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"根据查询检索到 {len(nodes)} 源文件"
else:
msg = f"查询检索中: '{self.payload.get('query_str')}'"
return {
"type": "events",
"data": {"title": msg},
}
else:
return None
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
},
}
def _is_output_serializable(self, output: Any) -> bool:
try:
json.dumps(output)
return True
except TypeError:
return False
def get_agent_tool_response(self) -> dict | None:
response = self.payload.get("response")
if response is not None:
sources = response.sources
for source in sources:
# Return the tool response here to include the toolCall information
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
},
}
def to_response(self):
try:
match self.event_type:
case "retrieve":
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case _:
return None
except Exception as e:
logger.error(f"转换回应时间时发生错误,原因: {e}")
return None
class ChatEventCallbackHandler(BaseCallbackHandler):
    """Callback handler that queues renderable chat events for streaming.

    Every start/end event is logged. Events for which
    ``ChatCallbackEvent.to_response()`` produces a message are put on an
    asyncio queue that ``async_event_gen`` drains.
    """

    _aqueue: asyncio.Queue
    # Flipped to True from outside once the text stream has finished; it lets
    # async_event_gen terminate after the queue has been drained.
    is_done: bool = False

    def __init__(
        self,
    ):
        """Initialize the handler with no ignored event types and an empty queue."""
        # Nothing is ignored at the moment. Candidates that were considered:
        # CBEventType.CHUNKING, NODE_PARSING, EMBEDDING, LLM, TEMPLATING.
        ignored_events = []
        super().__init__(ignored_events, ignored_events)
        self._aqueue = asyncio.Queue()

    def _enqueue_if_renderable(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]],
        event_id: str,
    ) -> None:
        """Queue the event when it renders to a client message."""
        event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
        if event.to_response() is not None:
            self._aqueue.put_nowait(event)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Log the starting event, queue it if renderable, and return its id."""
        logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
        self._enqueue_if_renderable(event_type, payload, event_id)
        # The signature promises a str; previously this fell through and
        # returned None.
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Log the finished event and queue it if renderable."""
        logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
        self._enqueue_if_renderable(event_type, payload, event_id)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Log the start of a trace; no other action is taken."""
        logger.info("trace_start:{}\n".format(trace_id))

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Log the end of a trace; no other action is taken."""
        logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))

    async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
        """Yield queued events until the queue is empty and ``is_done`` is set.

        The queue is polled with a 0.1 s timeout so that the loop re-checks
        ``is_done`` regularly instead of blocking forever on an empty queue.
        """
        while not self._aqueue.empty() or not self.is_done:
            try:
                yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                pass
class IDManager:
    """Factory for the identifiers attached to one streamed reply."""

    def createID(self):
        """Return a dict holding four fresh random UUID4 strings.

        Keys, in order: ``message_id``, ``task_id``, ``workflow_run_id``,
        ``workflow_id``.
        """
        id_names = ("message_id", "task_id", "workflow_run_id", "workflow_id")
        return {id_name: str(uuid.uuid4()) for id_name in id_names}
class DifyChatResponseEvent(BaseModel):
    """Fields shared by every event written to the chat response stream."""

    event: str
    conversation_id: str
    message_id: str
    # The class-level default is evaluated only once, when the class is
    # defined; __init__ below supplies the real per-event timestamp.
    created_at: int = int(time.time())
    task_id: str

    def __init__(self, **args):
        """Create the event, stamping ``created_at`` with the current time.

        An explicitly passed ``created_at`` is kept. Without this override
        every event would carry the timestamp of module import.
        """
        args.setdefault('created_at', int(time.time()))
        super().__init__(**args)
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
    """The ``workflow_started`` event that opens a response stream."""

    workflow_run_id: str
    data: Dict[str, Any]

    def __init__(self, **args):
        """Assemble the ``data`` payload from the keyword arguments.

        Reads ``workflow_run_id``, ``workflow_id``, ``query``,
        ``conversation_id`` and ``use_id`` from ``args``; a missing key raises
        KeyError.
        """
        run_id = args['workflow_run_id']
        workflow_id = args['workflow_id']
        system_inputs = {
            "sys.query": args['query'],
            "sys.files": [],
            "sys.conversation_id": args['conversation_id'],
            "sys.user_id": args['use_id'],
        }
        payload = {
            "id": run_id,
            "workflow_id": workflow_id,
            "sequence_number": 1709,
            "inputs": system_inputs,
            "created_at": int(time.time()),
        }
        super().__init__(**{**args, 'data': payload, 'event': 'workflow_started'})
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
    """The ``workflow_finished`` event carrying the complete answer."""

    workflow_run_id: str
    data: Dict[str, Any]

    def __init__(self, **args):
        """Assemble the ``data`` payload from the keyword arguments.

        Reads ``workflow_run_id``, ``workflow_id``, ``response`` and ``use_id``
        from ``args``; a missing key raises KeyError. Status, elapsed time,
        token count and step count are fixed literal values.
        """
        run_id = args['workflow_run_id']
        workflow_id = args['workflow_id']
        outputs = {"answer": args['response']}
        creator = {
            "id": str(uuid.uuid4()),
            "user": args['use_id'],
        }
        payload = {
            "id": run_id,
            "workflow_id": workflow_id,
            "sequence_number": 1709,
            "status": "succeeded",
            "outputs": outputs,
            "error": '',
            "elapsed_time": 36.03764106379822,
            "total_tokens": 11707,
            "total_steps": 10,
            "created_by": creator,
            "created_at": int(time.time()),
            "finished_at": int(time.time()),
            "files": [],
        }
        super().__init__(**{**args, 'event': 'workflow_finished', 'data': payload})
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
    """A ``message`` event carrying one chunk of the answer text."""

    id: str
    answer: str

    def __init__(self, **args):
        """Mirror ``message_id`` into ``id`` and tag the event as ``message``."""
        overrides = {'id': args['message_id'], 'event': 'message'}
        super().__init__(**{**args, **overrides})
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
    """The ``message_end`` event that closes a streamed answer."""

    id: str
    metadata: Dict[str, Any] = {}

    def __init__(self, **args):
        """Mirror ``message_id`` into ``id`` and tag the event as ``message_end``."""
        overrides = {'id': args['message_id'], 'event': 'message_end'}
        super().__init__(**{**args, **overrides})
class ChatStreamResponse(StreamingResponse):
    """Streaming response that emits the chat answer as ``data: <json>`` frames.

    The stream consists of one ``workflow_started`` event, one ``message``
    event per answer token, then ``workflow_finished`` and ``message_end``.
    """

    TEXT_PREFIX = "data: "
    DATA_PREFIX = "data: "

    @classmethod
    def convert_text(cls, token: str):
        """Return a bare newline; ``token`` is currently ignored."""
        return "\n"

    @classmethod
    def convert_data(cls, data: dict):
        """Serialise ``data`` as one ``data: <json>`` frame ending in a blank line."""
        data_str = json.dumps(data)
        return f"{cls.DATA_PREFIX}{data_str}\n\n"

    @classmethod
    def convert_event(cls, event: DifyChatResponseEvent):
        """Serialise an event model as one ``data: <json>`` frame ending in a blank line."""
        data_str = json.dumps(event.dict())
        return f"{cls.DATA_PREFIX}{data_str}\n\n"

    def __init__(
        self,
        request: Request,
        event_handler: ChatEventCallbackHandler,
        response: StreamingAgentChatResponse,
        data: ChatRequestData
    ):
        """Wire the content generator into the streaming response.

        Args:
            request: Incoming request, polled for client disconnects.
            event_handler: Handler collecting callback events for this chat.
            response: Streaming chat response producing the answer tokens.
            data: The parsed chat request (user, query, conversation id).
        """
        content = ChatStreamResponse.content_generator(
            request, event_handler, response, data
        )
        # The body is a sequence of "data: ...\n\n" frames, i.e. server-sent
        # events, so advertise the matching media type.
        super().__init__(content=content, media_type="text/event-stream")

    @classmethod
    async def content_generator(
        cls,
        request: Request,
        event_handler: ChatEventCallbackHandler,
        response: StreamingAgentChatResponse,
        data: ChatRequestData
    ):
        """Yield the encoded frames of one chat reply.

        The answer-token stream and the callback-event stream are merged. The
        first merged output triggers the ``workflow_started`` event. After the
        last token the exchange is stored in the message history and the
        closing events are sent.
        """
        # Everything the event constructors read is assembled once, up front.
        # Each event receives its own keyword expansion of this dict, so no
        # event depends on keys another code path happened to add earlier.
        base_args = {
            **IDManager().createID(),
            'conversation_id': data.conversation_id,
            'use_id': data.user,
            'query': data.query,
        }

        # Yield the text response
        async def _chat_response_generator():
            answer_parts = []
            try:
                async for token in response.async_response_gen():
                    answer_parts.append(token)
                    event = Message_DifyChatResponseEvent(**base_args, answer=token)
                    yield ChatStreamResponse.convert_event(event)
                final_response = "".join(answer_parts)
                # Persist the exchange in the message history.
                message().add(
                    user_id=data.user,
                    conversation_id=data.conversation_id,
                    query=data.query,
                    answer=final_response,
                )
            finally:
                # The text generator is the leading stream: once it ends, even
                # through an error, the event stream must end as well.
                event_handler.is_done = True
            # Send the workflow-finished event, then the message-end event.
            wf_event = Workflow_finished_DifyChatResponseEvent(
                **base_args, response=final_response
            )
            yield ChatStreamResponse.convert_event(wf_event)
            msg_end_event = MessageEnd_DifyChatResponseEvent(**base_args)
            yield ChatStreamResponse.convert_event(msg_end_event)

        # Yield the events from the event handler
        async def _event_generator():
            async for event in event_handler.async_event_gen():
                event_response = event.to_response()
                if event_response is not None:
                    yield ChatStreamResponse.convert_text("")

        combine = stream.merge(_chat_response_generator(), _event_generator())
        is_stream_started = False
        async with combine.stream() as streamer:
            async for output in streamer:
                if not is_stream_started:
                    is_stream_started = True
                    # Send the workflow-started event ahead of the first output.
                    wf_event = Workflow_started_DifyChatResponseEvent(**base_args)
                    yield ChatStreamResponse.convert_event(wf_event)
                yield output
                if await request.is_disconnected():
                    break
@v.post("/chat-messages")
async def post_conversations(request: Request, data: ChatRequestData):
    """Handle a chat message and stream the answer back.

    Makes sure the user and the conversation exist (creating them when
    needed), runs the chat engine on ``data.query`` and returns a
    ``ChatStreamResponse`` that streams the reply.
    """
    userMng.findNoExistCreate(data.user)
    # A request without a conversation id starts a new conversation.
    data.conversation_id = data.conversation_id or str(uuid.uuid4())
    conversation_store = conversations()
    if conversation_store.get(data.conversation_id) is None:
        conversation_store.add(data.conversation_id, data.user, "新建会话")
    # Build the chat engine; no metadata filters are applied.
    chat_engine = get_chat_engine(filters=None, params=data.inputs or {})
    # Attach a listener for this request's callback events.
    event_handler = ChatEventCallbackHandler()
    chat_engine.callback_manager.handlers.append(event_handler)  # type: ignore
    # Run the chat asynchronously and stream the reply.
    response = await chat_engine.astream_chat(data.query)
    return ChatStreamResponse(request, event_handler, response, data)
@v.get("/messages")
async def query_messages(user: str, conversation_id: str):
    """Return the stored messages of one conversation.

    Each stored record is extended with fixed placeholder fields (files,
    feedback, resources, timestamp, thoughts, status, error) before being
    returned under ``data``.
    """
    records = message().gets(user, conversation_id)
    if records is None:
        history = []
    else:
        history = [
            {
                **record.dict(),
                "message_files": [],
                "feedback": '',
                "retriever_resources": [],
                "created_at": 1723444905,
                "agent_thoughts": [],
                "status": "normal",
                "error": '',
            }
            for record in records
        ]
    return {
        "limit": 20,
        "has_more": False,
        "data": history,
    }
@v.post("/conversations/{itemid}/name")
async def post_conversations(request: Request, itemid: str, params: Dict[str, Any]):
    """Rename a conversation to the fixed title and return its record.

    The conversation is renamed first. It is then looked up by id and by the
    ``user`` entry of the request body. The string ``'null'`` is returned when
    no record matches.
    """
    store = conversations()
    store.rename(itemid, '知识问答')
    matches = store.query(id=itemid, user_id=params['user'])
    if not matches:
        return 'null'
    first = matches[0]
    returned_fields = ("id", "name", "inputs", "status", "introduction", "created_at")
    return {field: first[field] for field in returned_fields}
@v.get("/conversations")
async def query_conversations(user: str, first_id: str = None, limit: str = None, pinned: str = None):
    """List the conversations of ``user``, creating the user record if missing.

    ``first_id``, ``limit`` and ``pinned`` are accepted but not used.
    """
    user_id = user if user is not None else ''
    userMng.findNoExistCreate(user_id)
    owned = conversations().gets(user_id)
    return {
        "limit": 20,
        "has_more": False,
        "data": owned,
    }
@v.get("/parameters")
async def query_parameters(user: str):
    """Return the stored parameters of ``user``, or the built-in defaults.

    The defaults are used when nothing is stored for the user.
    """
    stored = parameter().get(user)
    if len(stored) != 0:
        return stored

    def switched_off():
        # A fresh dict per feature so the entries never share one object.
        return {"enabled": False}

    return {
        "opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
        "suggested_questions": [],
        "suggested_questions_after_answer": switched_off(),
        "speech_to_text": switched_off(),
        "text_to_speech": {
            "enabled": False,
            "language": "",
            "voice": "",
        },
        "retriever_resource": {"enabled": True},
        "annotation_reply": switched_off(),
        "more_like_this": switched_off(),
        "user_input_form": [],
        "sensitive_word_avoidance": switched_off(),
        "file_upload": {
            "image": {
                "enabled": False,
                "number_limits": 3,
                "transfer_methods": ["remote_url"],
            },
        },
        "system_parameters": {
            "image_file_size_limit": "10",
        },
    }
# Placeholder for the file-upload endpoint: the body is empty, so the handler
# currently returns None although it is annotated as returning List[str].
# TODO: implement the upload or remove the route.
@r.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]:
    pass
+125
View File
@@ -0,0 +1,125 @@
from datetime import datetime
import uuid
from app.api.routers.request.baseConfig import BaseConfig
from app.api.routers.request.dbOrm import DBManager
# Single DBManager instance shared by every helper class in this module.
dbManage = DBManager()
class conversations:
    """Accessor for the ``conversations`` table."""

    def __init__(self) -> None:
        """Remember the table name and ask the DB manager to create the table."""
        self._tableName = 'conversations'
        dbManage.createTable(self._tableName)

    def gets(self, user_id: str):
        """Return a list with every conversation record of ``user_id``."""
        return list(dbManage.query(self._tableName, user_id=user_id))

    def get(self, id: str):
        """Return the conversation with the given id, or None when absent."""
        records = dbManage.query(self._tableName, id=id)
        return records[0] if len(records) > 0 else None

    def add(self, id: str, user_id: str, name: str):
        """Insert a new conversation built from ``BaseConfig.ConversationCfg``."""
        from copy import deepcopy

        # Work on a private copy: the template is a class-level dict shared by
        # every caller and must not be modified in place.
        record = deepcopy(BaseConfig.ConversationCfg)
        record['id'] = id
        record['user_id'] = user_id
        record['name'] = name
        # Store the real creation time instead of a fixed constant.
        record['created_at'] = int(datetime.now().timestamp())
        dbManage.addRecord(self._tableName, record)

    def delete(self, id: str):
        """Delete the conversation with the given id."""
        dbManage.delete(self._tableName, id=id)

    def rename(self, id: str, name: str):
        """Set the name of the conversation with the given id."""
        dbManage.update(self._tableName, {'name': name}, id=id)

    def query(self, **condition):
        """Return the records matching ``condition`` as plain dicts."""
        records = dbManage.query(self._tableName, **condition)
        return [record.dict() for record in records]
class user:
    """Accessor for the ``user`` table."""

    def __init__(self) -> None:
        """Remember the table name and ask the DB manager to create the table."""
        self._tableName = 'user'
        dbManage.createTable(self._tableName)

    def gets(self):
        """Return every user record."""
        return dbManage.query(self._tableName)

    def get(self, id: str):
        """Return the records whose id equals ``id``."""
        return dbManage.query(self._tableName, id=id)

    def add(self, id: str):
        """Insert a user whose ``createtime`` is the current local time."""
        created = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        dbManage.addRecord(self._tableName, {'id': id, 'createtime': created})

    def delete(self, id: str):
        """Delete the user with the given id."""
        dbManage.delete(self._tableName, id=id)
class userMng:
    """Class-level helpers around a single shared ``user`` accessor."""

    # Created once, when the class body is executed.
    userObj = user()

    @classmethod
    def findNoExistCreate(cls, user_id: str):
        """Create the user record when no record with ``user_id`` exists."""
        userInfo = cls.userObj.get(user_id)
        if len(userInfo) == 0:
            cls.userObj.add(user_id)

    @classmethod
    def remove(cls, user_id: str):
        """Delete the user record with ``user_id``."""
        # Without @classmethod, `userMng.remove(uid)` bound uid to `cls` and
        # failed with a missing-argument TypeError.
        cls.userObj.delete(user_id)
class parameter:
    """Accessor for the ``parameters`` table."""

    def __init__(self) -> None:
        """Remember the table name and ask the DB manager to create the table."""
        self._tableName = 'parameters'
        dbManage.createTable(self._tableName)

    def get(self, user_id: str):
        """Return the parameters of ``user_id`` as a ``{name: value}`` dict.

        The dict is empty when nothing is stored for the user.
        """
        records = dbManage.query(self._tableName, user_id=user_id)
        # The records are model objects, not dicts, so their fields are read
        # as attributes; subscripting them raised TypeError.
        return {record.name: record.value for record in records}

    def set(self, user_id: str):
        """Insert an empty record.

        NOTE(review): ``user_id`` is not used and no name/value is written,
        so this looks unfinished — confirm the intended behaviour.
        """
        dbManage.addRecord(self._tableName, {})

    def delete(self, user_id: str):
        """Delete the parameters stored for ``user_id``."""
        dbManage.delete(self._tableName, user_id=user_id)
class message:
    """Accessor for the ``messages`` table."""

    def __init__(self) -> None:
        """Remember the table name and ask the DB manager to create the table."""
        self._tableName = 'messages'
        dbManage.createTable(self._tableName)

    def gets(self, user_id: str, conversation_id: str):
        """Return a list with the messages of one user in one conversation."""
        return list(
            dbManage.query(
                self._tableName, user_id=user_id, conversation_id=conversation_id
            )
        )

    def add(self, user_id: str, conversation_id: str, query: str, answer: str):
        """Insert one question/answer exchange under a fresh random id."""
        from copy import deepcopy

        # Work on a private copy: the template is a class-level dict shared by
        # every caller and must not be modified in place.
        record = deepcopy(BaseConfig.MessageCfg)
        record['id'] = str(uuid.uuid4())
        record['user_id'] = user_id
        record['conversation_id'] = conversation_id
        record['query'] = query
        record['answer'] = answer
        dbManage.addRecord(self._tableName, record)

    def delete(self, user_id: str):
        """Delete every message stored for ``user_id``."""
        dbManage.delete(self._tableName, user_id=user_id)
@@ -0,0 +1,62 @@
class BaseConfig:
    """Template dictionaries used as defaults for stored records."""

    # Default application parameters.
    ParamterCfg = {
        "opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
        "suggested_questions": [],
        "suggested_questions_after_answer": {"enabled": False},
        "speech_to_text": {"enabled": False},
        "text_to_speech": {"enabled": False, "language": "", "voice": ""},
        "retriever_resource": {"enabled": True},
        "annotation_reply": {"enabled": False},
        "more_like_this": {"enabled": False},
        "user_input_form": [],
        "sensitive_word_avoidance": {"enabled": False},
        "file_upload": {
            "image": {
                "enabled": False,
                "number_limits": 3,
                "transfer_methods": ["remote_url"],
            },
        },
        "system_parameters": {"image_file_size_limit": "10"},
    }

    # Blank conversation record; the introduction reuses the opening statement.
    ConversationCfg = {
        "id": "",
        "user_id": "",
        "name": "",
        "inputs": {},
        "status": "normal",
        "introduction": ParamterCfg["opening_statement"],
        "created_at": "",
    }

    # Blank message record.
    MessageCfg = {
        "id": "",
        "user_id": "",
        "conversation_id": "",
        "inputs": {},
        "query": "",
        "answer": "",
    }
+207
View File
@@ -0,0 +1,207 @@
import os
from typing import Dict, List, Any
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, String, Integer, JSON
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import sessionmaker, declarative_base
# Declarative base class shared by every ORM class below.
Base = declarative_base()
# ORM classes
class ConversationOrm(Base):
    """ORM mapping of the ``conversations`` table."""

    __tablename__ = "conversations"

    id = Column(String, primary_key=True)
    user_id = Column(String)
    name = Column(String)
    inputs = Column(JSON)
    status = Column(String)
    introduction = Column(String)
    created_at = Column(Integer)

    def update(self, data: Dict[str, Any]):
        """Apply ``data`` to this row; only the ``name`` key is honoured."""
        if 'name' not in data:
            return
        self.name = data['name']
class UserOrm(Base):
    """ORM mapping of the ``user`` table."""

    __tablename__ = "user"

    id = Column(String, primary_key=True)
    createtime = Column(String)
class ParametersOrm(Base):
    """ORM mapping of the ``parameters`` table, keyed by ``user_id``."""

    __tablename__ = "parameters"

    user_id = Column(String, primary_key=True)
    name = Column(String)
    value = Column(JSON)
class MessagesOrm(Base):
    """ORM mapping of the ``messages`` table."""

    __tablename__ = "messages"

    id = Column(String, primary_key=True)
    user_id = Column(String)
    conversation_id = Column(String)
    inputs = Column(JSON)
    query = Column(String)
    answer = Column(String)
# Data structures (pydantic models returned to callers)
class ConversationModel(BaseModel):
    """Pydantic view of a ``ConversationOrm`` row."""

    id: str
    name: str
    inputs: Dict[str, Any]
    status: str
    introduction: str
    created_at: int

    class Config:
        # Lets the model be populated from ORM object attributes.
        from_attributes = True

    @classmethod
    def orm(cls):
        """Return the ORM class this model is built from."""
        return ConversationOrm
class UserModel(BaseModel):
    """Pydantic view of a ``UserOrm`` row."""

    id: str
    createtime: str

    class Config:
        # Lets the model be populated from ORM object attributes.
        from_attributes = True

    @classmethod
    def orm(cls):
        """Return the ORM class this model is built from."""
        return UserOrm
class ParametersModel(BaseModel):
    """Pydantic view of a ``ParametersOrm`` row."""

    user_id: str
    name: str
    value: Dict[str, Any]

    class Config:
        # Lets the model be populated from ORM object attributes.
        from_attributes = True

    @classmethod
    def orm(cls):
        """Return the ORM class this model is built from."""
        return ParametersOrm
class MessagesModel(BaseModel):
    """Pydantic view of a ``MessagesOrm`` row (``user_id`` is not exposed)."""

    id: str
    conversation_id: str
    inputs: Dict[str, Any]
    query: str
    answer: str

    class Config:
        # Lets the model be populated from ORM object attributes.
        from_attributes = True

    @classmethod
    def orm(cls):
        """Return the ORM class this model is built from."""
        return MessagesOrm
class DBManager:
    """Small SQLAlchemy helper that addresses tables by name.

    The engine URL is read from the ``SQLITE_DATABASE_URL`` environment
    variable. Table names are resolved to the ORM classes derived from
    ``Base`` and to the pydantic models whose ``orm()`` names that ORM class.
    Each operation opens its own session and closes it when done.
    """

    def __init__(self) -> None:
        """Create the engine and the session factory."""
        database_url = os.getenv("SQLITE_DATABASE_URL")
        self._engine = create_engine(database_url)
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)

    def createTable(self, tableName: str):
        """Create ``tableName`` from the ORM metadata unless it already exists."""
        if self._engine is None:
            return
        if not self.exist(tableName):
            Base.metadata.tables[tableName].create(self._engine)

    def addRecord(self, tableName: str, record: Dict[str, Any]):
        """Insert one row built from ``record``; unknown tables are ignored."""
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        with self.SessionLocal() as session:
            session.add(ormCls(**record))
            session.commit()

    def addRecords(self, tableName: str, records: List[Dict[str, Any]]):
        """Insert several rows in one transaction; unknown tables are ignored."""
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        rows = [ormCls(**record) for record in records]
        with self.SessionLocal() as session:
            # add() takes a single instance; a collection needs add_all().
            session.add_all(rows)
            session.commit()

    def delete(self, tableName: str, **filter):
        """Delete every row of ``tableName`` matching the keyword filter."""
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        with self.SessionLocal() as session:
            # Session.delete() takes one instance, so delete row by row.
            for row in session.query(ormCls).filter_by(**filter).all():
                session.delete(row)
            session.commit()

    def update(self, tableName: str, data: Dict[str, Any], **filter):
        """Apply ``data`` to every matching row through the row's ``update`` method.

        Nothing happens when the table does not exist or is not mapped.
        An empty filter selects every row.
        """
        if not self.exist(tableName):
            return
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return
        with self.SessionLocal() as session:
            # filter_by() with no criteria selects all rows, which covers the
            # "no filter" case without a separate branch.
            for row in session.query(ormCls).filter_by(**filter).all():
                if row is not None:
                    row.update(data)
            session.commit()

    def query(self, tableName: str, **filter):
        """Return the matching rows converted to their pydantic models.

        An empty list is returned when the table has no ORM class or no
        model, so callers can always iterate or take ``len()`` of the result.
        """
        ormCls = self._get_orm(tableName)
        if ormCls is None:
            return []
        modelCls = self._get_model(ormCls)
        if modelCls is None:
            return []
        with self.SessionLocal() as session:
            rows = session.query(ormCls).filter_by(**filter).all()
            # Convert while the session is still open.
            return [modelCls.from_orm(row) for row in rows]

    def exist(self, tableName: str) -> bool:
        """Return True when the database already contains ``tableName``."""
        if self._engine is None:
            return False
        inspector = Inspector.from_engine(self._engine)
        return inspector.has_table(tableName)

    def _get_orm(self, tableName: str):
        """Return the direct ``Base`` subclass mapped to ``tableName``, or None."""
        for sub_cls in Base.__subclasses__():
            if sub_cls.__tablename__ == tableName:
                return sub_cls
        return None

    def _get_model(self, orm: Any):
        """Return the direct ``BaseModel`` subclass whose own ``orm()`` is ``orm``, or None."""
        for sub_cls in BaseModel.__subclasses__():
            if 'orm' in sub_cls.__dict__ and sub_cls.orm() == orm:
                return sub_cls
        return None
+15
View File
@@ -0,0 +1,15 @@
from typing import Dict, Any
from pydantic import BaseModel
class ChatRequestData(BaseModel):
    """Request body of the chat-message endpoint."""

    inputs: Dict[str, Any]
    query: str
    user: str
    response_mode: str
    files: Any
    # May be omitted or null; declaring it as a plain `str` with a None
    # default does not make an explicit null acceptable to validation.
    conversation_id: str | None = None
class ChatFileUploadRequest(BaseModel):
    """Request body of the file-upload endpoint."""

    # NOTE(review): presumably the base64-encoded file content — confirm.
    base64: str
+7 -1
View File
@@ -31,13 +31,19 @@ def get_chat_engine(filters=None, params=None):
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
description="适用于任何需要进行全面总结、概括的要求。",
)
query_engine = create_query_engine(index,top_k,use_reranker,filters)
query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "COMPACT")
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
)
query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "TREE_SUMMARIZE")
query_engine_tool_1 = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool_1",
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后,且在询问工程中单位的具体数值,例如用量,费率,合计,金额等的时候建议使用“zj_query_tool_1”工具。",
)
tools.append(summary_query_tool)
tools.append(query_engine_tool)
tools.append(query_engine_tool_1)
# Add additional tools
tools += ToolFactory.from_env()
+2 -1
View File
@@ -86,7 +86,7 @@ def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None
return summary_query_engine
# Create a query engine
def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
def create_query_engine(index, top_k=3, use_reranker=False, filters=None, response_mode=None):
# 创建向量检索查询工具
postprocess = None
if use_reranker:
@@ -103,6 +103,7 @@ def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
node_postprocessors=postprocess,
use_async=True,
streaming=True,
ResponseMode = response_mode
)
return query_engine
+3 -2
View File
@@ -1,5 +1,4 @@
import logging
import yaml
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
@@ -9,7 +8,7 @@ logger = logging.getLogger(__name__)
def load_configs():
with open("config/loaders.yaml") as f:
with open("config/loaders.yaml",encoding='UTF-8') as f:
configs = yaml.safe_load(f)
return configs
@@ -17,10 +16,12 @@ def load_configs():
def get_documents():
documents = []
config = load_configs()
if config is None or len(config.items()) == 0:
return documents
for loader_type, loader_config in config.items():
if loader_config.get('enable', True): # 检查 enable 字段
logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}"
)
+18 -29
View File
@@ -2,17 +2,14 @@ import logging
from typing import Any, List, Optional
from llama_index.core import SQLDatabase, Document
from llama_index.core.objects import SQLTableSchema
from llama_index.core.readers.base import BaseReader
from llama_index.readers.database import DatabaseReader
from pydantic import BaseModel
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
logger = logging.getLogger(__name__)
class CustomDatabaseReader(BaseReader):
class CustomDatabaseReader(DatabaseReader):
"""Simple Database reader.
Concatenates each row into Document used by LlamaIndex.
@@ -86,18 +83,19 @@ class CustomDatabaseReader(BaseReader):
List[Document]: A list of Document objects.
"""
dco_str = ""
with self.sql_database.engine.connect() as connection:
if query is None:
raise ValueError("A query parameter is necessary to filter the data")
else:
result = connection.execute(text(query))
dco_str = ", ".join(
dco_str += ", ".join(
[f"{entry}" for entry in result.keys()]
)
) + "\n"
for item in result.fetchall():
# fetch each item
# Fetch each item
record_str = ", ".join(
[f"{entry}" for col, entry in zip(result.keys(), item)]
)
@@ -111,45 +109,36 @@ class CustomDatabaseReader(BaseReader):
class DBLoaderConfig(BaseModel):
uri: str
queries: List[str]
queries: List[dict]
def get_db_documents(configs: list[DBLoaderConfig]):
def get_db_documents(configs: List[DBLoaderConfig]) -> List[Document]:
docs = []
if len(configs) == 0 or configs[0].uri == "":
if not configs or not configs[0].uri:
logger.warning(
f"Failed to load database, error message: uri is empty. Return as empty document list."
)
return docs
metadata = {
#'file_name':'',
'file_type': 'application/booway.document.zj',
#'file_path':'',
#'file_size':'',
#'creation_date':'',
#'last_modified_date':'',
}
#from llama_index.readers.database import DatabaseReader
for entry in configs:
engine = create_engine(entry.uri)
sql_database = SQLDatabase(engine)
# table_schema_objs = makeDescriptionByEngine(sql_database)
# table_node_mapping = SQLTableNodeMapping(sql_database)
#
# nodes = table_node_mapping.to_nodes(table_schema_objs)
# for node in nodes:
# node.metadata.update(metadata)
#
# docs.extend(nodes)
queries = entry.queries or []
loader = CustomDatabaseReader(sql_database)
for query in queries:
for query_dict in entry.queries:
query = query_dict.get("sql", "")
explanation = query_dict.get("explanation", "")
logger.info(f"Loading data from database with query: {query}")
documents = loader.load_data(query=query)
docs.extend(documents)
# 添加解释到元数据中
for doc in documents:
doc.metadata["explanation"] = explanation
doc.metadata.update(metadata) # 更新或添加额外的元数据
docs.append(doc)
return docs
+10 -4
View File
@@ -5,6 +5,8 @@ text_qa_template_str = (
"你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
"你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
"如同直接从文件中提取的内容。\n"
"知识库中已经导入一个工程的全部数据,请你站在当前工程的角度回答用户关于工程文件的问题。\n"
"例如:询问“此工程”指当前导入的工程。询问“此工程名称”指当前导入的工程的工程名称。\n"
"## 技能\n"
"### 技能 1: 数据查询与提供\n"
@@ -39,15 +41,19 @@ refine_template_str = (
"这是原本的问题: {query_str}\n"
"我们已经提供了回答: {existing_answer}\n"
"现在我们有机会改进这个回答 "
"使用以下更多上下文(仅当需要用时\n"
"使用以下更多上下文(仅当有助于改进回答时使用\n"
"你需要仔细的判断新的上下文的信息与原本问题必须一个字都不差,如果有一点差别,那就不能改变我现有的回答。\n"
"在判断回答是否正确的时候,你应该仔细对比新的上下文中包含的信息是否与原本的问题一字不差,如果一字不差,才能当作新的正确回答。\n"
"如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n"
"判断一下如果原回答正确,且在新的上下文仍然包含正确的回答,请将新的回答与原回答一起返回。\n"
"------------\n"
"{context_msg}\n"
"------------\n"
"根据新的上下文, 请改进原来的回答。"
"如果新的上下文没有用, 直接返回原本的回答\n"
"如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
"如果回答中已经包含有正确答案,不要返回多余的解释等信息,只返回正确答案\n"
"如果是表结构或者是数据库的相关内容,仅用于推导问题,不需要告诉用户数据库或表结构等物理信息\n"
"改进的回答: "
)
refine_template = PromptTemplate(refine_template_str)
summary_template_str = (
+5 -6
View File
@@ -1,10 +1,9 @@
import os
import yaml
import json
import importlib
from cachetools import cached, LRUCache
from llama_index.core.tools.tool_spec.base import BaseToolSpec
import os
import yaml
from llama_index.core.tools.function_tool import FunctionTool
from llama_index.core.tools.tool_spec.base import BaseToolSpec
class ToolType:
@@ -46,7 +45,7 @@ class ToolFactory:
def from_env() -> list[FunctionTool]:
tools = []
if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f:
with open("config/tools.yaml", "r", encoding='UTF-8') as f:
tool_configs = yaml.safe_load(f)
if tool_configs != None and len(tool_configs.items()) != 0:
for tool_type, config_entries in tool_configs.items():
+23 -8
View File
@@ -1,4 +1,5 @@
file:
enable: true # 添加 enable 字段
# use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable
use_llama_parse: false
@@ -7,27 +8,41 @@ db:
# uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
# query: The query to fetch data from the database. E.g.: SELECT * FROM table
- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
#- uri: mysql+pymysql://zjinfo:Y6EAjEEdSYmskA8B@110.42.234.166:3306/zjinfo
# - uri: mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
enable: true # 添加 enable 字段
queries:
- sql: select * from ProjectProperties limit 30;
- sql: select * from ProjectProperties;
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '线路' limit 50;
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '余物清理' limit 50;
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '拆除线路' limit 50;
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from ProjectQuantities where Professional_Type = '线路'
explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from ProjectQuantities where Professional_Type = '余物清理'
explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
#web:
# driver_arguments:
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode
Binary file not shown.
+3
View File
@@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from app.api.routers.chat import chat_router
from app.api.routers.upload import file_upload_router
from app.api.routers.app import v1_router
from app.settings import init_settings
from app.observability import init_observability
from fastapi.staticfiles import StaticFiles
@@ -56,6 +57,8 @@ mount_static_files("data_output", "/api/files/output")
app.include_router(chat_router, prefix="/api/chat")
app.include_router(file_upload_router, prefix="/api/chat/upload")
app.include_router(v1_router, prefix="/v1")
@app.get("/")
async def redirect_to_docs():
return RedirectResponse(url="/docs")
Submodule
+1
Submodule webapp added at 77dbc14a64