20 Commits

Author SHA1 Message Date
ly 72ddf46fc7 Merge pull request '增加新的前端子模块' (#4) from dev into main
Reviewed-on: #4
2024-08-29 10:51:50 +08:00
ly 0db159ac89 增加新的前端子模块 2024-08-29 10:48:40 +08:00
ly f57c0c84ef Merge pull request 'dev' (#3) from dev into main
Reviewed-on: #3
2024-08-29 10:13:10 +08:00
ly 131d6ef1d1 完善接口,实现对DIFY前端消息流传输的支持 2024-08-29 08:26:59 +08:00
ly 9b47e1a6e1 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-28 17:41:52 +08:00
wanyaokun 20510a937b Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-28 17:38:43 +08:00
wanyaokun a7c79df339 修改web请求接口 2024-08-28 17:35:28 +08:00
chentianrui 327bba75d5 修改了语句错误 2024-08-28 17:24:55 +08:00
chentianrui d1242d2080 修改了从数据库中查找取费表和工程量表,新加了一个树状搜索总结搜索引擎 2024-08-28 14:46:13 +08:00
ly 0f09551f5d Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-28 11:49:22 +08:00
chentianrui 8a5facb5b6 增加了判断是否使用数据库 2024-08-28 09:45:01 +08:00
chentianrui 0f7c900c1e 更改了提示词 2024-08-28 09:42:12 +08:00
chentianrui b008ad9766 更改了提示词 2024-08-28 09:39:57 +08:00
ly 56459c164e 配置文件增加UTF8编码格式支持,以免解析中文时出现问题 2024-08-28 08:04:01 +08:00
wanyaokun 07a3b2a147 修改POST和Get请求代码 2024-08-27 17:48:38 +08:00
ly b4c571cddb 增加对接DIFY前端支持功能 2024-08-27 08:43:00 +08:00
ly 7068b058e8 调整文件格式为DOCX 2024-08-27 08:40:46 +08:00
wanyaokun 33b2281b7b 修改ID为空的问题 2024-08-26 20:16:58 +08:00
ly 9ee24627c2 Merge pull request 'dev' (#2) from dev into main
Reviewed-on: #2
2024-08-23 09:37:06 +08:00
ly 88761a5d10 Merge pull request 'dev' (#1) from dev into main
Reviewed-on: #1
2024-08-22 09:41:13 +08:00
17 changed files with 628 additions and 163 deletions
+3
View File
@@ -0,0 +1,3 @@
[submodule "webapp"]
path = webapp
url = https://git.97id.com/ly/webapp.git
+458 -40
View File
@@ -1,34 +1,432 @@
import os import asyncio
from typing import Dict, List, Any, Optional, cast import json
from fastapi import APIRouter,Request import logging
from app.api.routers.request.base import userMng,conversations import time
from app.api.routers.request.models import ChatRequestData from typing import Dict, List, Any, Optional, AsyncGenerator
from aiostream import stream
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from llama_index.core import BaseCallbackHandler
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput
from pydantic import BaseModel
from app.api.routers.request.base import userMng, conversations,message,parameter
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine
import uuid
logger = logging.getLogger("uvicorn")
api_router = r = APIRouter() api_router = r = APIRouter()
v1_router = v = APIRouter() v1_router = v = APIRouter()
class ChatCallbackEvent(BaseModel):
event_type: CBEventType
payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_retrieval_message(self) -> dict | None:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"根据查询检索到 {len(nodes)} 源文件"
else:
msg = f"查询检索中: '{self.payload.get('query_str')}'"
return {
"type": "events",
"data": {"title": msg},
}
else:
return None
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
},
}
def _is_output_serializable(self, output: Any) -> bool:
try:
json.dumps(output)
return True
except TypeError:
return False
def get_agent_tool_response(self) -> dict | None:
response = self.payload.get("response")
if response is not None:
sources = response.sources
for source in sources:
# Return the tool response here to include the toolCall information
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
},
}
def to_response(self):
try:
match self.event_type:
case "retrieve":
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case _:
return None
except Exception as e:
logger.error(f"转换回应时间时发生错误,原因: {e}")
return None
class ChatEventCallbackHandler(BaseCallbackHandler):
    """Callback handler that queues reportable chat-engine events.

    Start and end events are logged, wrapped in ``ChatCallbackEvent`` and put
    on an asyncio queue when they convert to a non-``None`` response.
    ``async_event_gen`` drains that queue until ``is_done`` is set.
    """

    _aqueue: asyncio.Queue
    # Set to True by the consumer of the answer stream to stop async_event_gen.
    is_done: bool = False

    def __init__(
        self,
    ):
        """Initialize the handler with no ignored event types and an empty queue."""
        ignored_events: List[CBEventType] = []
        super().__init__(ignored_events, ignored_events)
        self._aqueue = asyncio.Queue()

    def _enqueue_if_reportable(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]],
        event_id: str,
    ) -> None:
        """Queue the event when it converts to a non-``None`` response."""
        event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
        if event.to_response() is not None:
            self._aqueue.put_nowait(event)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Log the starting event and queue it when it is reportable.

        Returns:
            The ``event_id`` that was passed in, matching the declared
            ``str`` return type (the original returned ``None``).
        """
        logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
        self._enqueue_if_reportable(event_type, payload, event_id)
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Log the finished event and queue it when it is reportable."""
        logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
        self._enqueue_if_reportable(event_type, payload, event_id)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Log the start of a trace; no other action is taken."""
        logger.info("trace_start:{}\n".format(trace_id))

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Log the end of a trace and its trace map; no other action is taken."""
        logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))

    async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
        """Yield queued events until the queue is empty and ``is_done`` is set.

        The queue is polled with a 0.1 s timeout so the loop re-checks
        ``is_done`` regularly instead of blocking on an empty queue.
        """
        while not self._aqueue.empty() or not self.is_done:
            try:
                yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                pass
class IDManager:
    """Generates the identifiers attached to one chat response."""

    def createID(self):
        """Return a dict with four freshly generated UUID4 strings.

        Keys, in order: ``message_id``, ``task_id``, ``workflow_run_id``,
        ``workflow_id``.
        """
        id_names = ("message_id", "task_id", "workflow_run_id", "workflow_id")
        return {name: str(uuid.uuid4()) for name in id_names}
class DifyChatResponseEvent(BaseModel):
    """Fields shared by every Dify-style streamed chat event."""

    event: str
    conversation_id: str
    message_id: str
    # This default is evaluated once, when the class is defined. It is kept
    # only as a fallback for construction paths that bypass __init__;
    # __init__ stamps the real creation time on each instance.
    created_at: int = int(time.time())
    task_id: str

    def __init__(self, **data: Any):
        """Create the event, stamping ``created_at`` with the current time.

        An explicitly supplied ``created_at`` is left untouched.
        """
        data.setdefault("created_at", int(time.time()))
        super().__init__(**data)
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
    """Event announcing that a workflow run has started.

    The constructor requires the keyword arguments ``workflow_run_id``,
    ``workflow_id``, ``query``, ``conversation_id`` and ``use_id`` (plus the
    base-event fields). It always sets ``event`` and ``data`` itself.
    """

    workflow_run_id: str
    data: Dict[str, Any]

    def __init__(self, **args):
        """Assemble the ``data`` payload from the keyword arguments."""
        run_id = args['workflow_run_id']
        workflow_id = args['workflow_id']
        query = args['query']
        conversation_id = args['conversation_id']
        user_id = args['use_id']
        started_at = int(time.time())

        sys_inputs = {
            "sys.query": query,
            "sys.files": [],
            "sys.conversation_id": conversation_id,
            "sys.user_id": user_id,
        }
        args['data'] = {
            "id": run_id,
            "workflow_id": workflow_id,
            "sequence_number": 1709,  # fixed value; not derived from any input
            "inputs": sys_inputs,
            "created_at": started_at,
        }
        args['event'] = 'workflow_started'
        super().__init__(**args)
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
    """Event announcing that a workflow run has finished.

    The constructor requires the keyword arguments ``workflow_run_id``,
    ``workflow_id``, ``response`` and ``use_id`` (plus the base-event
    fields). It always sets ``event`` and ``data`` itself.
    """

    workflow_run_id: str
    data: Dict[str, Any]

    def __init__(self, **args):
        """Assemble the ``data`` payload from the keyword arguments."""
        args['event'] = 'workflow_finished'

        run_id = args['workflow_run_id']
        workflow_id = args['workflow_id']
        outputs = {"answer": args['response']}
        creator = {
            "id": str(uuid.uuid4()),
            "user": args['use_id'],
        }
        created_at = int(time.time())
        finished_at = int(time.time())

        # Status, timing, token and step figures below are fixed values;
        # they are not measured from the actual run.
        args['data'] = {
            "id": run_id,
            "workflow_id": workflow_id,
            "sequence_number": 1709,
            "status": "succeeded",
            "outputs": outputs,
            "error": '',
            "elapsed_time": 36.03764106379822,
            "total_tokens": 11707,
            "total_steps": 10,
            "created_by": creator,
            "created_at": created_at,
            "finished_at": finished_at,
            "files": [],
        }
        super().__init__(**args)
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
    """Event carrying one chunk of the answer text."""

    id: str
    answer: str

    def __init__(self, **args):
        """Set ``id`` to the message id and ``event`` to ``'message'``."""
        overrides = {'id': args['message_id'], 'event': 'message'}
        super().__init__(**{**args, **overrides})
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
    """Event marking the end of a message."""

    id: str
    metadata: Dict[str, Any] = {}

    def __init__(self, **args):
        """Set ``id`` to the message id and ``event`` to ``'message_end'``."""
        overrides = {'id': args['message_id'], 'event': 'message_end'}
        super().__init__(**{**args, **overrides})
class ChatStreamResponse(StreamingResponse):
    """Streaming response that relays a chat answer as Dify-style events.

    Each chunk is a ``data: <json>`` line followed by a blank line. The
    response is declared as ``text/event-stream`` so clients and proxies
    treat it as a server-sent event stream.
    """

    TEXT_PREFIX = "data: "
    DATA_PREFIX = "data: "

    @classmethod
    def convert_text(cls, token: str):
        """Return a bare newline; the ``token`` argument is not emitted."""
        return "\n"

    @classmethod
    def convert_data(cls, data: dict):
        """Serialize ``data`` as one ``data:``-prefixed stream chunk."""
        data_str = json.dumps(data)
        return f"{cls.DATA_PREFIX}{data_str}\n\n"

    @classmethod
    def convert_event(cls, event: DifyChatResponseEvent):
        """Serialize ``event`` as one ``data:``-prefixed stream chunk."""
        data_str = json.dumps(event.dict())
        return f"{cls.DATA_PREFIX}{data_str}\n\n"

    def __init__(
        self,
        request: Request,
        event_handler: ChatEventCallbackHandler,
        response: StreamingAgentChatResponse,
        data: ChatRequestData
    ):
        """Wire the content generator into the streaming response."""
        content = ChatStreamResponse.content_generator(
            request, event_handler, response, data
        )
        super().__init__(content=content, media_type="text/event-stream")

    @classmethod
    async def content_generator(
        cls,
        request: Request,
        event_handler: ChatEventCallbackHandler,
        response: StreamingAgentChatResponse,
        data: ChatRequestData
    ):
        """Yield the serialized event stream for one chat request.

        Order of output: a workflow-started event before the first chunk,
        one message event per answer token (interleaved with newline chunks
        for callback events), then workflow-finished and message-end events.
        The full answer is stored in the message history once the token
        stream is exhausted. Streaming stops when the client disconnects.
        """
        ids = IDManager().createID()
        # Fields common to every event. Each constructor receives its own
        # expansion of this dict plus explicit extras, so no event depends
        # on keys that another code path happened to add earlier.
        base_args = {**ids, "conversation_id": data.conversation_id}

        # Yield the text response
        async def _chat_response_generator():
            tokens = []
            async for token in response.async_response_gen():
                tokens.append(token)
                event = Message_DifyChatResponseEvent(**base_args, answer=token)
                yield cls.convert_event(event)
            final_response = "".join(tokens)

            # Store the exchange in the message history.
            message().add(
                user_id=data.user,
                conversation_id=data.conversation_id,
                query=data.query,
                answer=final_response,
            )

            # the text_generator is the leading stream, once it's finished, also finish the event stream
            event_handler.is_done = True

            # Send the workflow-finished event. The user id is passed
            # explicitly so an empty answer cannot fail on a missing key.
            wf_event = Workflow_finished_DifyChatResponseEvent(
                **base_args, response=final_response, use_id=data.user
            )
            yield cls.convert_event(wf_event)

            msg_end_event = MessageEnd_DifyChatResponseEvent(**base_args)
            yield cls.convert_event(msg_end_event)

        # Yield the events from the event handler
        async def _event_generator():
            async for event in event_handler.async_event_gen():
                event_response = event.to_response()
                if event_response is not None:
                    yield cls.convert_text("")

        combine = stream.merge(_chat_response_generator(), _event_generator())
        is_stream_started = False
        async with combine.stream() as streamer:
            async for output in streamer:
                if not is_stream_started:
                    is_stream_started = True
                    # Send the workflow-started event ahead of the first chunk.
                    wf_event = Workflow_started_DifyChatResponseEvent(
                        **base_args, use_id=data.user, query=data.query
                    )
                    yield cls.convert_event(wf_event)

                yield output

                if await request.is_disconnected():
                    break
@v.post("/chat-messages") @v.post("/chat-messages")
async def post_conversations(request: Request,data: ChatRequestData): async def post_conversations(request: Request, data: ChatRequestData):
userMng.findNoExistCreate(data.user) userMng.findNoExistCreate(data.user)
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
conversaObj = conversations() conversaObj = conversations()
conversationinfo = conversaObj.get(data.user) conversationinfo = conversaObj.get(data.conversation_id)
if conversationinfo is None: if conversationinfo is None:
conversationinfo = conversaObj.add(data.user, "新建会话") conversationinfo = conversaObj.add(data.conversation_id, data.user, "新建会话")
return None # 生成聊天参数
last_message_content = ChatMessage.from_str(data.query)
filters = None
params = data.inputs or {}
# 获取聊天引擎对象
chat_engine = get_chat_engine(filters=filters, params=params)
# 启动聊天事件监听
event_handler = ChatEventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
# 执行异步聊天
response = await chat_engine.astream_chat(data.query)
# 返回异步消息回应
return ChatStreamResponse(request, event_handler, response, data)
@v.get("/messages") @v.get("/messages")
async def query_messages(user:str, conversation_id:str): async def query_messages(user:str, conversation_id:str):
pass #conversation_id = default_conversation_id if conversation_id is None else conversation_id
datas = []
records = message().gets(user,conversation_id)
if records is None:
return {
"limit": 20,
"has_more": False,
"data": []
}
for record in records:
res = record.dict()
res["message_files"] = []
res["feedback"] = ''
res["retriever_resources"] = []
res["created_at"] = 1723444905
res["agent_thoughts"] = []
res["status"] = "normal"
res["error"] = ''
datas.append(res)
return {
"limit": 20,
"has_more": False,
"data": datas
}
@v.post("/conversations/{itemid}/name") @v.post("/conversations/{itemid}/name")
async def post_conversations(user:str): async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]):
pass consaObj = conversations()
consaObj.rename(itemid,'知识问答')
cond = {
'id':itemid,
'user_id':params['user']
}
results = consaObj.query(**cond)
if len(results) > 0:
res = results[0]
return {
"id": res['id'],
"name": res['name'],
"inputs": res['inputs'],
"status": res['status'],
"introduction": res['introduction'],
"created_at": res['created_at'],
#"工程位置"
}
return 'null'
@v.get("/conversations") @v.get("/conversations")
async def query_conversations(user:str): async def query_conversations(user:str, first_id:str = None, limit:str = None, pinned:str = None):
user_id = '' if user is None else user user_id = '' if user is None else user
userMng.findNoExistCreate(user_id) userMng.findNoExistCreate(user_id)
@@ -38,32 +436,52 @@ async def query_conversations(user:str):
"data": conversations().gets(user_id) "data": conversations().gets(user_id)
} }
def _default_parameters():
    """Build a fresh copy of the fallback application parameters."""
    disabled = "enabled"
    image_upload = {
        disabled: False,
        "number_limits": 3,
        "transfer_methods": [
            "remote_url"
        ]
    }
    return {
        "opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
        "suggested_questions": [],
        "suggested_questions_after_answer": {disabled: False},
        "speech_to_text": {disabled: False},
        "text_to_speech": {
            disabled: False,
            "language": "",
            "voice": ""
        },
        "retriever_resource": {disabled: True},
        "annotation_reply": {disabled: False},
        "more_like_this": {disabled: False},
        "user_input_form": [],
        "sensitive_word_avoidance": {disabled: False},
        "file_upload": {"image": image_upload},
        "system_parameters": {
            "image_file_size_limit": "10"
        }
    }


@v.get("/parameters")
async def query_parameters(user:str):
    """Return the stored parameters for ``user``.

    Falls back to the built-in defaults when the stored parameters are empty.
    """
    stored = parameter().get(user)
    if len(stored) == 0:
        return _default_parameters()
    return stored
@r.get("/conversations") @r.post("")
async def query_conversations(first_id:int = None, limit:int = None, pinned:bool = None): def upload_file(request: ChatFileUploadRequest) -> List[str]:
pass pass
#meta查询
@r.get("/meta")
async def query_meta():
pass
#name查询
@r.get("/name查询")
def query_name():
with sessionlocal() as session:
name = session.query(NameOrm).first()
return Name.from_orm(name)
#parameters查询
@r.get("/parameters")
async def query_parameters():
pass
#msite查询
@r.get("/site")
async def query_site():
pass
+31 -33
View File
@@ -1,9 +1,7 @@
import os
from typing import Dict, List, Any, Optional, cast
import json
from app.api.routers.request.dbOrm import DBManager
from app.api.routers.request.baseConfig import BaseConfig
from datetime import datetime from datetime import datetime
import uuid
from app.api.routers.request.baseConfig import BaseConfig
from app.api.routers.request.dbOrm import DBManager
dbManage = DBManager() dbManage = DBManager()
@@ -20,28 +18,34 @@ class conversations:
return datas return datas
def get(self,user_id:str,id:str = ''): def get(self, id:str):
records = dbManage.query(self._tableName,user_id = user_id,id = id) records = dbManage.query(self._tableName, id=id)
if len(records) >0: if len(records) >0:
return records[0] return records[0]
return None return None
def add(self,user_id:str,name:str,id:str = ''): def add(self,id:str, user_id:str, name:str):
template = BaseConfig.ConversationCfg template = BaseConfig.ConversationCfg
template['id'] = id template['id'] = id
template['user_id'] = user_id template['user_id'] = user_id
template['name'] = name template['name'] = name
template['created_at'] = 1724399038 template['created_at'] = 1724399038
dbManage.addRecord(self._tableName,template) dbManage.addRecord(self._tableName,template)
def delete(self,id:str): def delete(self,id:str):
dbManage.delete(self._tableName,id=id) dbManage.delete(self._tableName,id=id)
def rename(self,id:str): def rename(self,id:str,name:str):
data = {'name':''} data = {'name':name}
dbManage.update(self._tableName,data,id=id) dbManage.update(self._tableName,data,id=id)
def query(self,**condition):
results = []
records = dbManage.query(self._tableName,**condition)
for record in records:
results.append(record.dict())
return results
class user: class user:
def __init__(self) -> None: def __init__(self) -> None:
self._tableName = 'user' self._tableName = 'user'
@@ -68,7 +72,7 @@ class userMng:
@classmethod @classmethod
def findNoExistCreate(cls,user_id:str): def findNoExistCreate(cls,user_id:str):
userInfo = cls.userObj.get(user_id) userInfo = cls.userObj.get(user_id)
if userInfo is None: if len(userInfo) == 0:
cls.userObj.add(user_id) cls.userObj.add(user_id)
def remove(cls,user_id:str): def remove(cls,user_id:str):
@@ -86,22 +90,7 @@ class parameter:
key = record['name'] key = record['name']
value = record['value'] value = record['value']
data[key] = value data[key] = value
return data
return {
'opening_statement':data['opening_statement'],
'suggested_questions':data['suggested_questions'],
'suggested_questions_after_answer':data['suggested_questions_after_answer'],
'speech_to_text':data['speech_to_text'],
'text_to_speech':data['text_to_speech'],
'retriever_resource':data['retriever_resource'],
'annotation_reply':data['annotation_reply'],
'more_like_this':data['more_like_this'],
'user_input_form':data['user_input_form'],
'sensitive_word_avoidance':data['sensitive_word_avoidance'],
'file_upload':data['file_upload'],
'system_parameters':data['system_parameters'],
'opening_statement':data['opening_statement'],
}
def set(self,user_id:str): def set(self,user_id:str):
dbManage.addRecord(self._tableName,{}) dbManage.addRecord(self._tableName,{})
@@ -114,14 +103,23 @@ class message:
self._tableName = 'messages' self._tableName = 'messages'
dbManage.createTable(self._tableName) dbManage.createTable(self._tableName)
def gets(self,user_id:str): def gets(self,user_id:str,conversation_id:str):
return dbManage.query(self._tableName,user_id = user_id) records = dbManage.query(self._tableName,user_id = user_id,conversation_id = conversation_id)
datas = []
for record in records:
datas.append(record)
return datas
def add(self,user_id:str): def add(self,user_id:str,conversation_id:str,query:str,answer:str):
dbManage.addRecord(self._tableName,{}) template = BaseConfig.MessageCfg
template['id'] = str(uuid.uuid4())
template['user_id'] = user_id
template['conversation_id'] = conversation_id
template['query'] = query
template['answer'] = answer
dbManage.addRecord(self._tableName,template)
def delete(self,user_id:str): def delete(self,user_id:str):
dbManage.delete(self._tableName,user_id = user_id) dbManage.delete(self._tableName,user_id = user_id)
@@ -50,3 +50,13 @@ class BaseConfig:
"introduction": ParamterCfg['opening_statement'], "introduction": ParamterCfg['opening_statement'],
"created_at":'' "created_at":''
} }
MessageCfg = {
"id": "",
'user_id':'',
"conversation_id": "",
"inputs": {},
"query": "",
"answer": ""
}
+24 -11
View File
@@ -1,11 +1,10 @@
import os import os
from typing import Dict, List, Any, Optional, cast from typing import Dict, List, Any
from fastapi import APIRouter from pydantic import BaseModel
from pydantic import BaseModel, Field from sqlalchemy import create_engine, Column, String, Integer, JSON
from sqlalchemy import create_engine, Column, String, Integer, Boolean, JSON,ForeignKey
from sqlalchemy.orm import sessionmaker, declarative_base,relationship
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import sessionmaker, declarative_base
Base = declarative_base() Base = declarative_base()
@@ -21,6 +20,14 @@ class ConversationOrm(Base):
introduction = Column(String) introduction = Column(String)
created_at = Column(Integer) created_at = Column(Integer)
def update(self,data:Dict[str,Any]):
if 'name' in data:
self.name = data['name']
class UserOrm(Base): class UserOrm(Base):
__tablename__ = "user" __tablename__ = "user"
@@ -42,7 +49,7 @@ class MessagesOrm(Base):
conversation_id = Column(String) conversation_id = Column(String)
inputs = Column(JSON) inputs = Column(JSON)
query = Column(String) query = Column(String)
answer = Column(JSON) answer = Column(String)
#数据结构 #数据结构
class ConversationModel(BaseModel): class ConversationModel(BaseModel):
@@ -91,7 +98,7 @@ class MessagesModel(BaseModel):
conversation_id :str conversation_id :str
inputs : Dict[str, Any] inputs : Dict[str, Any]
query : str query : str
answer : Dict[str, Any] answer : str
class Config: class Config:
#orm_mode = True #orm_mode = True
@@ -144,14 +151,20 @@ class DBManager:
session.commit() session.commit()
def update(self,tableName:str,data:Dict[str,Any],**filter): def update(self,tableName:str,data:Dict[str,Any],**filter):
if not self.exist(tableName):
return
session = self.SessionLocal() session = self.SessionLocal()
ormCls = self._get_orm(tableName) ormCls = self._get_orm(tableName)
if ormCls is None: if ormCls is None:
return return
record = session.query(ormCls).filter_by(**filter).first() if len(filter) > 0:
if record is not None: records = session.query(ormCls).filter_by(**filter).all()
record.update(data) else:
session.commit() records = session.query(ormCls).all()
for record in records:
if record is not None:
record.update(data)
session.commit()
def query(self,tableName:str,**filter): def query(self,tableName:str,**filter):
session = self.SessionLocal() session = self.SessionLocal()
+6 -1
View File
@@ -1,5 +1,7 @@
from typing import Dict, Any
from pydantic import BaseModel from pydantic import BaseModel
from typing import Dict, List, Any, Optional, cast
class ChatRequestData(BaseModel): class ChatRequestData(BaseModel):
inputs: Dict[str,Any] inputs: Dict[str,Any]
@@ -7,4 +9,7 @@ class ChatRequestData(BaseModel):
user: str user: str
response_mode: str response_mode: str
files: Any files: Any
conversation_id: str = None
class ChatFileUploadRequest(BaseModel):
base64: str
+7 -1
View File
@@ -31,13 +31,19 @@ def get_chat_engine(filters=None, params=None):
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool", summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
description="适用于任何需要进行全面总结、概括的要求。", description="适用于任何需要进行全面总结、概括的要求。",
) )
query_engine = create_query_engine(index,top_k,use_reranker,filters) query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "COMPACT")
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool", query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。", description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
) )
query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "TREE_SUMMARIZE")
query_engine_tool_1 = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool_1",
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后,且在询问工程中单位的具体数值,例如用量,费率,合计,金额等的时候建议使用“zj_query_tool_1”工具。",
)
tools.append(summary_query_tool) tools.append(summary_query_tool)
tools.append(query_engine_tool) tools.append(query_engine_tool)
tools.append(query_engine_tool_1)
# Add additional tools # Add additional tools
tools += ToolFactory.from_env() tools += ToolFactory.from_env()
+2 -1
View File
@@ -86,7 +86,7 @@ def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None
return summary_query_engine return summary_query_engine
# Create a query engine # Create a query engine
def create_query_engine(index, top_k=3, use_reranker=False, filters=None): def create_query_engine(index, top_k=3, use_reranker=False, filters=None, response_mode=None):
# 创建向量检索查询工具 # 创建向量检索查询工具
postprocess = None postprocess = None
if use_reranker: if use_reranker:
@@ -103,6 +103,7 @@ def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
node_postprocessors=postprocess, node_postprocessors=postprocess,
use_async=True, use_async=True,
streaming=True, streaming=True,
ResponseMode = response_mode
) )
return query_engine return query_engine
+18 -17
View File
@@ -1,5 +1,4 @@
import logging import logging
import yaml import yaml
from app.engine.loaders.db import DBLoaderConfig, get_db_documents from app.engine.loaders.db import DBLoaderConfig, get_db_documents
from app.engine.loaders.file import FileLoaderConfig, get_file_documents from app.engine.loaders.file import FileLoaderConfig, get_file_documents
@@ -9,7 +8,7 @@ logger = logging.getLogger(__name__)
def load_configs(): def load_configs():
with open("config/loaders.yaml") as f: with open("config/loaders.yaml",encoding='UTF-8') as f:
configs = yaml.safe_load(f) configs = yaml.safe_load(f)
return configs return configs
@@ -17,24 +16,26 @@ def load_configs():
def get_documents(): def get_documents():
documents = [] documents = []
config = load_configs() config = load_configs()
if config is None or len(config.items()) == 0: if config is None or len(config.items()) == 0:
return documents return documents
for loader_type, loader_config in config.items(): for loader_type, loader_config in config.items():
logger.info( if loader_config.get('enable', True): # 检查 enable 字段
f"Loading documents from loader: {loader_type}, config: {loader_config}" logger.info(
) f"Loading documents from loader: {loader_type}, config: {loader_config}"
)
loader_config = loader_config or [] loader_config = loader_config or []
match loader_type: match loader_type:
case "file": case "file":
document = get_file_documents(FileLoaderConfig(**loader_config)) document = get_file_documents(FileLoaderConfig(**loader_config))
case "web": case "web":
document = get_web_documents(WebLoaderConfig(**loader_config)) document = get_web_documents(WebLoaderConfig(**loader_config))
case "db": case "db":
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config]) document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
case _: case _:
raise ValueError(f"Invalid loader type: {loader_type}") raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document) documents.extend(document)
return documents return documents
+19 -30
View File
@@ -2,17 +2,14 @@ import logging
from typing import Any, List, Optional from typing import Any, List, Optional
from llama_index.core import SQLDatabase, Document from llama_index.core import SQLDatabase, Document
from llama_index.core.objects import SQLTableSchema
from llama_index.core.readers.base import BaseReader
from llama_index.readers.database import DatabaseReader from llama_index.readers.database import DatabaseReader
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import create_engine from sqlalchemy import create_engine, text
from sqlalchemy import text
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CustomDatabaseReader(BaseReader): class CustomDatabaseReader(DatabaseReader):
"""Simple Database reader. """Simple Database reader.
Concatenates each row into Document used by LlamaIndex. Concatenates each row into Document used by LlamaIndex.
@@ -86,18 +83,19 @@ class CustomDatabaseReader(BaseReader):
List[Document]: A list of Document objects. List[Document]: A list of Document objects.
""" """
dco_str = "" dco_str = ""
with self.sql_database.engine.connect() as connection: with self.sql_database.engine.connect() as connection:
if query is None: if query is None:
raise ValueError("A query parameter is necessary to filter the data") raise ValueError("A query parameter is necessary to filter the data")
else: else:
result = connection.execute(text(query)) result = connection.execute(text(query))
dco_str = ", ".join( dco_str += ", ".join(
[f"{entry}" for entry in result.keys()] [f"{entry}" for entry in result.keys()]
) ) + "\n"
for item in result.fetchall(): for item in result.fetchall():
# fetch each item # Fetch each item
record_str = ", ".join( record_str = ", ".join(
[f"{entry}" for col, entry in zip(result.keys(), item)] [f"{entry}" for col, entry in zip(result.keys(), item)]
) )
@@ -111,45 +109,36 @@ class CustomDatabaseReader(BaseReader):
class DBLoaderConfig(BaseModel): class DBLoaderConfig(BaseModel):
uri: str uri: str
queries: List[str] queries: List[dict]
def get_db_documents(configs: list[DBLoaderConfig]): def get_db_documents(configs: List[DBLoaderConfig]) -> List[Document]:
docs = [] docs = []
if len(configs) == 0 or configs[0].uri == "": if not configs or not configs[0].uri:
logger.warning( logger.warning(
f"Failed to load database, error message: uri is empty. Return as empty document list." f"Failed to load database, error message: uri is empty. Return as empty document list."
) )
return docs return docs
metadata = { metadata = {
#'file_name':'', 'file_type': 'application/booway.document.zj',
'file_type':'application/booway.document.zj',
#'file_path':'',
#'file_size':'',
#'creation_date':'',
#'last_modified_date':'',
} }
#from llama_index.readers.database import DatabaseReader
for entry in configs: for entry in configs:
engine = create_engine(entry.uri) engine = create_engine(entry.uri)
sql_database = SQLDatabase(engine) sql_database = SQLDatabase(engine)
# table_schema_objs = makeDescriptionByEngine(sql_database)
# table_node_mapping = SQLTableNodeMapping(sql_database)
#
# nodes = table_node_mapping.to_nodes(table_schema_objs)
# for node in nodes:
# node.metadata.update(metadata)
#
# docs.extend(nodes)
queries = entry.queries or []
loader = CustomDatabaseReader(sql_database) loader = CustomDatabaseReader(sql_database)
for query in queries: for query_dict in entry.queries:
query = query_dict.get("sql", "")
explanation = query_dict.get("explanation", "")
logger.info(f"Loading data from database with query: {query}") logger.info(f"Loading data from database with query: {query}")
documents = loader.load_data(query=query) documents = loader.load_data(query=query)
docs.extend(documents) # 添加解释到元数据中
for doc in documents:
doc.metadata["explanation"] = explanation
doc.metadata.update(metadata) # 更新或添加额外的元数据
docs.append(doc)
return docs return docs
+10 -4
View File
@@ -5,6 +5,8 @@ text_qa_template_str = (
"你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。" "你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
"你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答," "你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
"如同直接从文件中提取的内容。\n" "如同直接从文件中提取的内容。\n"
"知识库中已经导入一个工程的全部数据,请你站在当前工程的角度回答用户关于工程文件的问题。\n"
"例如:询问“此工程”指当前导入的工程。询问“此工程名称”指当前导入的工程的工程名称。\n"
"## 技能\n" "## 技能\n"
"### 技能 1: 数据查询与提供\n" "### 技能 1: 数据查询与提供\n"
@@ -39,15 +41,19 @@ refine_template_str = (
"这是原本的问题: {query_str}\n" "这是原本的问题: {query_str}\n"
"我们已经提供了回答: {existing_answer}\n" "我们已经提供了回答: {existing_answer}\n"
"现在我们有机会改进这个回答 " "现在我们有机会改进这个回答 "
"使用以下更多上下文(仅当需要用时\n" "使用以下更多上下文(仅当有助于改进回答时使用\n"
"你需要仔细的判断新的上下文的信息与原本问题必须一个字都不差,如果有一点差别,那就不能改变我现有的回答。\n"
"在判断回答是否正确的时候,你应该仔细对比新的上下文中包含的信息是否与原本的问题一字不差,如果一字不差,才能当作新的正确回答。\n"
"如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n"
"判断一下如果原回答正确,且在新的上下文仍然包含正确的回答,请将新的回答与原回答一起返回。\n"
"------------\n" "------------\n"
"{context_msg}\n" "{context_msg}\n"
"------------\n" "------------\n"
"根据新的上下文, 请改进原来的回答。" "如果回答中已经包含有正确答案,不要返回多余的解释等信息,只返回正确答案\n"
"如果新的上下文没有用, 直接返回原本的回答\n" "如果是表结构或者是数据库的相关内容,仅用于推导问题,不需要告诉用户数据库或表结构等物理信息\n"
"如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
"改进的回答: " "改进的回答: "
) )
refine_template = PromptTemplate(refine_template_str) refine_template = PromptTemplate(refine_template_str)
summary_template_str = ( summary_template_str = (
+5 -6
View File
@@ -1,10 +1,9 @@
import os
import yaml
import json
import importlib import importlib
from cachetools import cached, LRUCache import os
from llama_index.core.tools.tool_spec.base import BaseToolSpec
import yaml
from llama_index.core.tools.function_tool import FunctionTool from llama_index.core.tools.function_tool import FunctionTool
from llama_index.core.tools.tool_spec.base import BaseToolSpec
class ToolType: class ToolType:
@@ -46,7 +45,7 @@ class ToolFactory:
def from_env() -> list[FunctionTool]: def from_env() -> list[FunctionTool]:
tools = [] tools = []
if os.path.exists("config/tools.yaml"): if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f: with open("config/tools.yaml", "r", encoding='UTF-8') as f:
tool_configs = yaml.safe_load(f) tool_configs = yaml.safe_load(f)
if tool_configs != None and len(tool_configs.items()) != 0: if tool_configs != None and len(tool_configs.items()) != 0:
for tool_type, config_entries in tool_configs.items(): for tool_type, config_entries in tool_configs.items():
+23 -8
View File
@@ -1,4 +1,5 @@
file: file:
enable: true # 添加 enable 字段
# use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable # use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable
use_llama_parse: false use_llama_parse: false
@@ -7,27 +8,41 @@ db:
# uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db # uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
# query: The query to fetch data from the database. E.g.: SELECT * FROM table # query: The query to fetch data from the database. E.g.: SELECT * FROM table
- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 - uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
#- uri: mysql+pymysql://zjinfo:Y6EAjEEdSYmskA8B@110.42.234.166:3306/zjinfo enable: true # 添加 enable 字段
# - uri: mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
queries: queries:
- sql: select * from ProjectProperties limit 30; - sql: select * from ProjectProperties;
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。" explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable; - sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。" explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '线路' limit 50; - sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '余物清理' limit 50;
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '拆除线路' limit 50;
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。" explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee; - sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中" explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
explanation: "取费表名称为线路取费表(调试工程)aa的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
explanation: "取费表名称为大型土石方取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
explanation: "取费表名称为线路取费表(余物清理)的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
explanation: "取费表名称为线路取费表(余物清理)(1)的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
explanation: "取费表名称为线路取费表(拆除)的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from ProjectQuantities where Professional_Type = '线路'
explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
- sql: select Name, Code, Calculation_Formula, Rate from ProjectQuantities where Professional_Type = '余物清理'
explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
#web: #web:
# driver_arguments: # driver_arguments:
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode # # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode
Binary file not shown.
+2 -2
View File
@@ -12,7 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from app.api.routers.chat import chat_router from app.api.routers.chat import chat_router
from app.api.routers.upload import file_upload_router from app.api.routers.upload import file_upload_router
from app.api.routers.app import api_router,v1_router from app.api.routers.app import v1_router
from app.settings import init_settings from app.settings import init_settings
from app.observability import init_observability from app.observability import init_observability
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
@@ -56,7 +56,7 @@ mount_static_files("data", "/api/files/data")
mount_static_files("data_output", "/api/files/output") mount_static_files("data_output", "/api/files/output")
app.include_router(chat_router, prefix="/api/chat") app.include_router(chat_router, prefix="/api/chat")
app.include_router(file_upload_router, prefix="/api/chat/upload") app.include_router(file_upload_router, prefix="/api/chat/upload")
app.include_router(api_router, prefix="/api")
app.include_router(v1_router, prefix="/v1") app.include_router(v1_router, prefix="/v1")
@app.get("/") @app.get("/")
Submodule
+1
Submodule webapp added at 77dbc14a64