Compare commits
22 Commits
b052d373f1
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 72ddf46fc7 | |||
| 0db159ac89 | |||
| f57c0c84ef | |||
| 131d6ef1d1 | |||
| 9b47e1a6e1 | |||
| 20510a937b | |||
| a7c79df339 | |||
| 327bba75d5 | |||
| d1242d2080 | |||
| 0f09551f5d | |||
| 8a5facb5b6 | |||
| 0f7c900c1e | |||
| b008ad9766 | |||
| 56459c164e | |||
| 07a3b2a147 | |||
| b4c571cddb | |||
| 7068b058e8 | |||
| 33b2281b7b | |||
| 1704b61609 | |||
| afccaf6eb5 | |||
| 9ee24627c2 | |||
| 88761a5d10 |
@@ -0,0 +1,3 @@
|
|||||||
|
[submodule "webapp"]
|
||||||
|
path = webapp
|
||||||
|
url = https://git.97id.com/ly/webapp.git
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
# LLAMA_CLOUD_API_KEY=
|
# LLAMA_CLOUD_API_KEY=
|
||||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||||
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||||
|
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||||
|
|
||||||
DASHSCOPE_API_KEY=sk-02c8540e86d84b7ca0e6f4f51bac6e60
|
DASHSCOPE_API_KEY=sk-02c8540e86d84b7ca0e6f4f51bac6e60
|
||||||
# The provider for the AI models to use.
|
# The provider for the AI models to use.
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# LLAMA_CLOUD_API_KEY=
|
# LLAMA_CLOUD_API_KEY=
|
||||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||||
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||||
|
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||||
|
|
||||||
# The number of similar embeddings to return when retrieving documents.
|
# The number of similar embeddings to return when retrieving documents.
|
||||||
TOP_K=10
|
TOP_K=10
|
||||||
|
|||||||
@@ -0,0 +1,487 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||||
|
|
||||||
|
from aiostream import stream
|
||||||
|
from fastapi import APIRouter, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from llama_index.core import BaseCallbackHandler
|
||||||
|
from llama_index.core.base.llms.types import ChatMessage
|
||||||
|
from llama_index.core.callbacks import CBEventType
|
||||||
|
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||||
|
from llama_index.core.tools import ToolOutput
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from app.api.routers.request.base import userMng, conversations,message,parameter
|
||||||
|
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
|
||||||
|
from app.engine import get_chat_engine
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
|
api_router = r = APIRouter()
|
||||||
|
v1_router = v = APIRouter()
|
||||||
|
|
||||||
|
class ChatCallbackEvent(BaseModel):
|
||||||
|
event_type: CBEventType
|
||||||
|
payload: Optional[Dict[str, Any]] = None
|
||||||
|
event_id: str = ""
|
||||||
|
|
||||||
|
def get_retrieval_message(self) -> dict | None:
|
||||||
|
if self.payload:
|
||||||
|
nodes = self.payload.get("nodes")
|
||||||
|
if nodes:
|
||||||
|
msg = f"根据查询检索到 {len(nodes)} 源文件"
|
||||||
|
else:
|
||||||
|
msg = f"查询检索中: '{self.payload.get('query_str')}'"
|
||||||
|
return {
|
||||||
|
"type": "events",
|
||||||
|
"data": {"title": msg},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_tool_message(self) -> dict | None:
|
||||||
|
func_call_args = self.payload.get("function_call")
|
||||||
|
if func_call_args is not None and "tool" in self.payload:
|
||||||
|
tool = self.payload.get("tool")
|
||||||
|
return {
|
||||||
|
"type": "events",
|
||||||
|
"data": {
|
||||||
|
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _is_output_serializable(self, output: Any) -> bool:
|
||||||
|
try:
|
||||||
|
json.dumps(output)
|
||||||
|
return True
|
||||||
|
except TypeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_agent_tool_response(self) -> dict | None:
|
||||||
|
response = self.payload.get("response")
|
||||||
|
if response is not None:
|
||||||
|
sources = response.sources
|
||||||
|
for source in sources:
|
||||||
|
# Return the tool response here to include the toolCall information
|
||||||
|
if isinstance(source, ToolOutput):
|
||||||
|
if self._is_output_serializable(source.raw_output):
|
||||||
|
output = source.raw_output
|
||||||
|
else:
|
||||||
|
output = source.content
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "tools",
|
||||||
|
"data": {
|
||||||
|
"toolOutput": {
|
||||||
|
"output": output,
|
||||||
|
"isError": source.is_error,
|
||||||
|
},
|
||||||
|
"toolCall": {
|
||||||
|
"id": None, # There is no tool id in the ToolOutput
|
||||||
|
"name": source.tool_name,
|
||||||
|
"input": source.raw_input,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_response(self):
|
||||||
|
try:
|
||||||
|
match self.event_type:
|
||||||
|
case "retrieve":
|
||||||
|
return self.get_retrieval_message()
|
||||||
|
case "function_call":
|
||||||
|
return self.get_tool_message()
|
||||||
|
case "agent_step":
|
||||||
|
return self.get_agent_tool_response()
|
||||||
|
case _:
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"转换回应时间时发生错误,原因: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||||
|
_aqueue: asyncio.Queue
|
||||||
|
is_done: bool = False
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
"""Initialize the base callback handler."""
|
||||||
|
ignored_events = [
|
||||||
|
# CBEventType.CHUNKING,
|
||||||
|
# CBEventType.NODE_PARSING,
|
||||||
|
# CBEventType.EMBEDDING,
|
||||||
|
# CBEventType.LLM,
|
||||||
|
# CBEventType.TEMPLATING,
|
||||||
|
]
|
||||||
|
super().__init__(ignored_events, ignored_events)
|
||||||
|
self._aqueue = asyncio.Queue()
|
||||||
|
|
||||||
|
def on_event_start(
|
||||||
|
self,
|
||||||
|
event_type: CBEventType,
|
||||||
|
payload: Optional[Dict[str, Any]] = None,
|
||||||
|
event_id: str = "",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||||
|
|
||||||
|
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||||
|
if event.to_response() is not None:
|
||||||
|
self._aqueue.put_nowait(event)
|
||||||
|
|
||||||
|
def on_event_end(
|
||||||
|
self,
|
||||||
|
event_type: CBEventType,
|
||||||
|
payload: Optional[Dict[str, Any]] = None,
|
||||||
|
event_id: str = "",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||||
|
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||||
|
if event.to_response() is not None:
|
||||||
|
self._aqueue.put_nowait(event)
|
||||||
|
|
||||||
|
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||||
|
"""No-op."""
|
||||||
|
logger.info("trace_start:{}\n".format(trace_id))
|
||||||
|
|
||||||
|
def end_trace(
|
||||||
|
self,
|
||||||
|
trace_id: Optional[str] = None,
|
||||||
|
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""No-op."""
|
||||||
|
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||||
|
|
||||||
|
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
||||||
|
while not self._aqueue.empty() or not self.is_done:
|
||||||
|
try:
|
||||||
|
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class IDManager:
|
||||||
|
def createID(self):
|
||||||
|
return {
|
||||||
|
"message_id" : str(uuid.uuid4()),
|
||||||
|
'task_id':str(uuid.uuid4()),
|
||||||
|
'workflow_run_id': str(uuid.uuid4()),
|
||||||
|
"workflow_id": str(uuid.uuid4())
|
||||||
|
}
|
||||||
|
|
||||||
|
class DifyChatResponseEvent(BaseModel):
|
||||||
|
event: str
|
||||||
|
conversation_id: str
|
||||||
|
message_id: str
|
||||||
|
created_at: int = int(time.time())
|
||||||
|
task_id: str
|
||||||
|
|
||||||
|
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||||
|
workflow_run_id:str
|
||||||
|
data:Dict[str,Any]
|
||||||
|
def __init__(self,**args):
|
||||||
|
args['data'] = {
|
||||||
|
"id": args['workflow_run_id'],
|
||||||
|
"workflow_id": args['workflow_id'],
|
||||||
|
"sequence_number": 1709,
|
||||||
|
"inputs": {
|
||||||
|
"sys.query": args['query'],
|
||||||
|
"sys.files": [],
|
||||||
|
"sys.conversation_id": args['conversation_id'],
|
||||||
|
"sys.user_id": args['use_id']
|
||||||
|
},
|
||||||
|
"created_at": int(time.time())
|
||||||
|
}
|
||||||
|
args['event'] = 'workflow_started'
|
||||||
|
super().__init__(**args)
|
||||||
|
|
||||||
|
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||||
|
workflow_run_id:str
|
||||||
|
data:Dict[str,Any]
|
||||||
|
def __init__(self,**args):
|
||||||
|
args['event'] = 'workflow_finished'
|
||||||
|
args['data'] = {
|
||||||
|
"id": args['workflow_run_id'],
|
||||||
|
"workflow_id": args['workflow_id'],
|
||||||
|
"sequence_number": 1709,
|
||||||
|
"status": "succeeded",
|
||||||
|
"outputs": {
|
||||||
|
"answer": args['response']
|
||||||
|
},
|
||||||
|
"error": '',
|
||||||
|
"elapsed_time": 36.03764106379822,
|
||||||
|
"total_tokens": 11707,
|
||||||
|
"total_steps": 10,
|
||||||
|
"created_by": {
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"user": args['use_id']
|
||||||
|
},
|
||||||
|
"created_at": int(time.time()),
|
||||||
|
"finished_at": int(time.time()),
|
||||||
|
"files": []
|
||||||
|
}
|
||||||
|
super().__init__(**args)
|
||||||
|
|
||||||
|
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||||
|
id:str
|
||||||
|
answer:str
|
||||||
|
def __init__(self,**args):
|
||||||
|
args['id'] = args['message_id']
|
||||||
|
args['event'] = 'message'
|
||||||
|
super().__init__(**args)
|
||||||
|
|
||||||
|
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||||
|
id:str
|
||||||
|
metadata:Dict[str,Any] = {}
|
||||||
|
def __init__(self,**args):
|
||||||
|
args['id'] = args['message_id']
|
||||||
|
args['event'] = 'message_end'
|
||||||
|
super().__init__(**args)
|
||||||
|
|
||||||
|
class ChatStreamResponse(StreamingResponse):
|
||||||
|
TEXT_PREFIX = "data: "
|
||||||
|
DATA_PREFIX = "data: "
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_text(cls, token: str):
|
||||||
|
# Escape newlines and double quotes to avoid breaking the stream
|
||||||
|
#token = json.dumps(token)
|
||||||
|
|
||||||
|
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}"
|
||||||
|
return "\n"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_data(cls, data: dict):
|
||||||
|
data_str = json.dumps(data)
|
||||||
|
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_event(cls, event: DifyChatResponseEvent):
|
||||||
|
data_str = json.dumps(event.dict())
|
||||||
|
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
request: Request,
|
||||||
|
event_handler: ChatEventCallbackHandler,
|
||||||
|
response: StreamingAgentChatResponse,
|
||||||
|
data: ChatRequestData
|
||||||
|
):
|
||||||
|
content = ChatStreamResponse.content_generator(
|
||||||
|
request, event_handler, response, data
|
||||||
|
)
|
||||||
|
super().__init__(content=content)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def content_generator(
|
||||||
|
cls,
|
||||||
|
request: Request,
|
||||||
|
event_handler: ChatEventCallbackHandler,
|
||||||
|
response: StreamingAgentChatResponse,
|
||||||
|
data: ChatRequestData
|
||||||
|
):
|
||||||
|
ids = IDManager().createID()
|
||||||
|
# Yield the text response
|
||||||
|
async def _chat_response_generator():
|
||||||
|
final_response = ""
|
||||||
|
async for token in response.async_response_gen():
|
||||||
|
final_response += token
|
||||||
|
args = ids
|
||||||
|
args['answer'] = token
|
||||||
|
args['conversation_id'] = data.conversation_id
|
||||||
|
event = Message_DifyChatResponseEvent(**args)
|
||||||
|
yield ChatStreamResponse.convert_event(event)
|
||||||
|
#yield ChatStreamResponse.convert_text(token)
|
||||||
|
|
||||||
|
# 存储消息历史
|
||||||
|
message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response)
|
||||||
|
|
||||||
|
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||||
|
event_handler.is_done = True
|
||||||
|
# 发送工作流结束事件
|
||||||
|
args = ids
|
||||||
|
args['response'] = final_response
|
||||||
|
args['conversation_id'] = data.conversation_id
|
||||||
|
wf_event = Workflow_finished_DifyChatResponseEvent(**args)
|
||||||
|
yield ChatStreamResponse.convert_event(wf_event)
|
||||||
|
|
||||||
|
msgEnt_event = MessageEnd_DifyChatResponseEvent(**ids)
|
||||||
|
yield ChatStreamResponse.convert_event(msgEnt_event)
|
||||||
|
|
||||||
|
|
||||||
|
# Yield the events from the event handler
|
||||||
|
async def _event_generator():
|
||||||
|
async for event in event_handler.async_event_gen():
|
||||||
|
event_response = event.to_response()
|
||||||
|
if event_response is not None:
|
||||||
|
yield ChatStreamResponse.convert_text("")
|
||||||
|
|
||||||
|
combine = stream.merge(_chat_response_generator(), _event_generator())
|
||||||
|
is_stream_started = False
|
||||||
|
async with combine.stream() as streamer:
|
||||||
|
async for output in streamer:
|
||||||
|
if not is_stream_started:
|
||||||
|
is_stream_started = True
|
||||||
|
|
||||||
|
# 发送工作流开始事件
|
||||||
|
args = ids
|
||||||
|
args['use_id'] = data.user
|
||||||
|
args['query'] = data.query
|
||||||
|
args['conversation_id'] = data.conversation_id
|
||||||
|
wf_event = Workflow_started_DifyChatResponseEvent(**args)
|
||||||
|
yield ChatStreamResponse.convert_event(wf_event)
|
||||||
|
|
||||||
|
# Stream a blank message to start the stream
|
||||||
|
# 发送一个空消息事件
|
||||||
|
#yield ChatStreamResponse.convert_text("")
|
||||||
|
|
||||||
|
yield output
|
||||||
|
|
||||||
|
if await request.is_disconnected():
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@v.post("/chat-messages")
|
||||||
|
async def post_conversations(request: Request, data: ChatRequestData):
|
||||||
|
userMng.findNoExistCreate(data.user)
|
||||||
|
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
|
||||||
|
|
||||||
|
conversaObj = conversations()
|
||||||
|
conversationinfo = conversaObj.get(data.conversation_id)
|
||||||
|
if conversationinfo is None:
|
||||||
|
conversationinfo = conversaObj.add(data.conversation_id, data.user, "新建会话")
|
||||||
|
|
||||||
|
# 生成聊天参数
|
||||||
|
last_message_content = ChatMessage.from_str(data.query)
|
||||||
|
filters = None
|
||||||
|
params = data.inputs or {}
|
||||||
|
|
||||||
|
# 获取聊天引擎对象
|
||||||
|
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||||
|
|
||||||
|
# 启动聊天事件监听
|
||||||
|
event_handler = ChatEventCallbackHandler()
|
||||||
|
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||||
|
|
||||||
|
# 执行异步聊天
|
||||||
|
response = await chat_engine.astream_chat(data.query)
|
||||||
|
|
||||||
|
# 返回异步消息回应
|
||||||
|
return ChatStreamResponse(request, event_handler, response, data)
|
||||||
|
|
||||||
|
@v.get("/messages")
|
||||||
|
async def query_messages(user:str, conversation_id:str):
|
||||||
|
#conversation_id = default_conversation_id if conversation_id is None else conversation_id
|
||||||
|
datas = []
|
||||||
|
records = message().gets(user,conversation_id)
|
||||||
|
if records is None:
|
||||||
|
return {
|
||||||
|
"limit": 20,
|
||||||
|
"has_more": False,
|
||||||
|
"data": []
|
||||||
|
}
|
||||||
|
|
||||||
|
for record in records:
|
||||||
|
res = record.dict()
|
||||||
|
res["message_files"] = []
|
||||||
|
res["feedback"] = ''
|
||||||
|
res["retriever_resources"] = []
|
||||||
|
res["created_at"] = 1723444905
|
||||||
|
res["agent_thoughts"] = []
|
||||||
|
res["status"] = "normal"
|
||||||
|
res["error"] = ''
|
||||||
|
datas.append(res)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"limit": 20,
|
||||||
|
"has_more": False,
|
||||||
|
"data": datas
|
||||||
|
}
|
||||||
|
|
||||||
|
@v.post("/conversations/{itemid}/name")
|
||||||
|
async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]):
|
||||||
|
consaObj = conversations()
|
||||||
|
consaObj.rename(itemid,'知识问答')
|
||||||
|
cond = {
|
||||||
|
'id':itemid,
|
||||||
|
'user_id':params['user']
|
||||||
|
}
|
||||||
|
results = consaObj.query(**cond)
|
||||||
|
if len(results) > 0:
|
||||||
|
res = results[0]
|
||||||
|
return {
|
||||||
|
"id": res['id'],
|
||||||
|
"name": res['name'],
|
||||||
|
"inputs": res['inputs'],
|
||||||
|
"status": res['status'],
|
||||||
|
"introduction": res['introduction'],
|
||||||
|
"created_at": res['created_at'],
|
||||||
|
#"工程位置"
|
||||||
|
}
|
||||||
|
return 'null'
|
||||||
|
|
||||||
|
@v.get("/conversations")
|
||||||
|
async def query_conversations(user:str, first_id:str = None, limit:str = None, pinned:str = None):
|
||||||
|
user_id = '' if user is None else user
|
||||||
|
userMng.findNoExistCreate(user_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"limit": 20,
|
||||||
|
"has_more": False,
|
||||||
|
"data": conversations().gets(user_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
@v.get("/parameters")
|
||||||
|
async def query_parameters(user:str):
|
||||||
|
params = parameter().get(user)
|
||||||
|
if len(params) == 0:
|
||||||
|
params = {
|
||||||
|
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
|
||||||
|
"suggested_questions": [],
|
||||||
|
"suggested_questions_after_answer": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"speech_to_text": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"text_to_speech": {
|
||||||
|
"enabled": False,
|
||||||
|
"language": "",
|
||||||
|
"voice": ""
|
||||||
|
},
|
||||||
|
"retriever_resource": {
|
||||||
|
"enabled": True
|
||||||
|
},
|
||||||
|
"annotation_reply": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"more_like_this": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"user_input_form": [],
|
||||||
|
"sensitive_word_avoidance": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"file_upload": {
|
||||||
|
"image": {
|
||||||
|
"enabled": False,
|
||||||
|
"number_limits": 3,
|
||||||
|
"transfer_methods": [
|
||||||
|
"remote_url"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"system_parameters": {
|
||||||
|
"image_file_size_limit": "10"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
|
||||||
|
@r.post("")
|
||||||
|
def upload_file(request: ChatFileUploadRequest) -> List[str]:
|
||||||
|
pass
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
import uuid
|
||||||
|
from app.api.routers.request.baseConfig import BaseConfig
|
||||||
|
from app.api.routers.request.dbOrm import DBManager
|
||||||
|
|
||||||
|
dbManage = DBManager()
|
||||||
|
|
||||||
|
class conversations:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._tableName = 'conversations'
|
||||||
|
dbManage.createTable(self._tableName)
|
||||||
|
|
||||||
|
def gets(self,user_id:str):
|
||||||
|
records = dbManage.query(self._tableName,user_id = user_id)
|
||||||
|
datas = []
|
||||||
|
for record in records:
|
||||||
|
datas.append(record)
|
||||||
|
|
||||||
|
return datas
|
||||||
|
|
||||||
|
def get(self, id:str):
|
||||||
|
records = dbManage.query(self._tableName, id=id)
|
||||||
|
if len(records) >0:
|
||||||
|
return records[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def add(self,id:str, user_id:str, name:str):
|
||||||
|
template = BaseConfig.ConversationCfg
|
||||||
|
template['id'] = id
|
||||||
|
template['user_id'] = user_id
|
||||||
|
template['name'] = name
|
||||||
|
template['created_at'] = 1724399038
|
||||||
|
dbManage.addRecord(self._tableName,template)
|
||||||
|
|
||||||
|
def delete(self,id:str):
|
||||||
|
dbManage.delete(self._tableName,id=id)
|
||||||
|
|
||||||
|
def rename(self,id:str,name:str):
|
||||||
|
data = {'name':name}
|
||||||
|
dbManage.update(self._tableName,data,id=id)
|
||||||
|
|
||||||
|
def query(self,**condition):
|
||||||
|
results = []
|
||||||
|
records = dbManage.query(self._tableName,**condition)
|
||||||
|
for record in records:
|
||||||
|
results.append(record.dict())
|
||||||
|
return results
|
||||||
|
|
||||||
|
class user:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._tableName = 'user'
|
||||||
|
dbManage.createTable(self._tableName)
|
||||||
|
|
||||||
|
def gets(self):
|
||||||
|
return dbManage.query(self._tableName)
|
||||||
|
|
||||||
|
def get(self,id:str):
|
||||||
|
return dbManage.query(self._tableName,id = id)
|
||||||
|
|
||||||
|
def add(self,id:str):
|
||||||
|
info = {
|
||||||
|
'id':id,
|
||||||
|
'createtime': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
}
|
||||||
|
dbManage.addRecord(self._tableName,info)
|
||||||
|
|
||||||
|
def delete(self,id:str):
|
||||||
|
dbManage.delete(self._tableName,id = id)
|
||||||
|
|
||||||
|
class userMng:
|
||||||
|
userObj = user()
|
||||||
|
@classmethod
|
||||||
|
def findNoExistCreate(cls,user_id:str):
|
||||||
|
userInfo = cls.userObj.get(user_id)
|
||||||
|
if len(userInfo) == 0:
|
||||||
|
cls.userObj.add(user_id)
|
||||||
|
|
||||||
|
def remove(cls,user_id:str):
|
||||||
|
cls.userObj.delete(user_id)
|
||||||
|
|
||||||
|
class parameter:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._tableName = 'parameters'
|
||||||
|
dbManage.createTable(self._tableName)
|
||||||
|
|
||||||
|
def get(self,user_id:str):
|
||||||
|
records = dbManage.query(self._tableName,user_id = user_id)
|
||||||
|
data = {}
|
||||||
|
for record in records:
|
||||||
|
key = record['name']
|
||||||
|
value = record['value']
|
||||||
|
data[key] = value
|
||||||
|
return data
|
||||||
|
|
||||||
|
def set(self,user_id:str):
|
||||||
|
dbManage.addRecord(self._tableName,{})
|
||||||
|
|
||||||
|
def delete(self,user_id:str):
|
||||||
|
dbManage.delete(self._tableName,user_id = user_id)
|
||||||
|
|
||||||
|
class message:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._tableName = 'messages'
|
||||||
|
dbManage.createTable(self._tableName)
|
||||||
|
|
||||||
|
def gets(self,user_id:str,conversation_id:str):
|
||||||
|
records = dbManage.query(self._tableName,user_id = user_id,conversation_id = conversation_id)
|
||||||
|
datas = []
|
||||||
|
for record in records:
|
||||||
|
datas.append(record)
|
||||||
|
return datas
|
||||||
|
|
||||||
|
def add(self,user_id:str,conversation_id:str,query:str,answer:str):
|
||||||
|
template = BaseConfig.MessageCfg
|
||||||
|
template['id'] = str(uuid.uuid4())
|
||||||
|
template['user_id'] = user_id
|
||||||
|
template['conversation_id'] = conversation_id
|
||||||
|
template['query'] = query
|
||||||
|
template['answer'] = answer
|
||||||
|
dbManage.addRecord(self._tableName,template)
|
||||||
|
|
||||||
|
def delete(self,user_id:str):
|
||||||
|
dbManage.delete(self._tableName,user_id = user_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
|
||||||
|
class BaseConfig:
|
||||||
|
ParamterCfg = {
|
||||||
|
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
|
||||||
|
"suggested_questions": [],
|
||||||
|
"suggested_questions_after_answer": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"speech_to_text": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"text_to_speech": {
|
||||||
|
"enabled": False,
|
||||||
|
"language": "",
|
||||||
|
"voice": ""
|
||||||
|
},
|
||||||
|
"retriever_resource": {
|
||||||
|
"enabled": True
|
||||||
|
},
|
||||||
|
"annotation_reply": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"more_like_this": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"user_input_form": [],
|
||||||
|
"sensitive_word_avoidance": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"file_upload": {
|
||||||
|
"image": {
|
||||||
|
"enabled": False,
|
||||||
|
"number_limits": 3,
|
||||||
|
"transfer_methods": [
|
||||||
|
"remote_url"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"system_parameters": {
|
||||||
|
"image_file_size_limit": "10"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ConversationCfg = {
|
||||||
|
"id": "",
|
||||||
|
'user_id':'',
|
||||||
|
"name": "",
|
||||||
|
"inputs": {},
|
||||||
|
"status": "normal",
|
||||||
|
"introduction": ParamterCfg['opening_statement'],
|
||||||
|
"created_at":''
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MessageCfg = {
|
||||||
|
"id": "",
|
||||||
|
'user_id':'',
|
||||||
|
"conversation_id": "",
|
||||||
|
"inputs": {},
|
||||||
|
"query": "",
|
||||||
|
"answer": ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,207 @@
|
|||||||
|
import os
|
||||||
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import create_engine, Column, String, Integer, JSON
|
||||||
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
|
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
#orm类
|
||||||
|
class ConversationOrm(Base):
|
||||||
|
__tablename__ = "conversations"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True)
|
||||||
|
user_id = Column(String)
|
||||||
|
name = Column(String)
|
||||||
|
inputs = Column(JSON)
|
||||||
|
status = Column(String)
|
||||||
|
introduction = Column(String)
|
||||||
|
created_at = Column(Integer)
|
||||||
|
|
||||||
|
def update(self,data:Dict[str,Any]):
|
||||||
|
if 'name' in data:
|
||||||
|
self.name = data['name']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class UserOrm(Base):
|
||||||
|
__tablename__ = "user"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True)
|
||||||
|
createtime = Column(String)
|
||||||
|
|
||||||
|
class ParametersOrm(Base):
|
||||||
|
__tablename__ = "parameters"
|
||||||
|
|
||||||
|
user_id = Column(String,primary_key=True)
|
||||||
|
name = Column(String)
|
||||||
|
value = Column(JSON)
|
||||||
|
|
||||||
|
class MessagesOrm(Base):
|
||||||
|
__tablename__ = "messages"
|
||||||
|
|
||||||
|
id = Column(String,primary_key=True)
|
||||||
|
user_id = Column(String)
|
||||||
|
conversation_id = Column(String)
|
||||||
|
inputs = Column(JSON)
|
||||||
|
query = Column(String)
|
||||||
|
answer = Column(String)
|
||||||
|
|
||||||
|
#数据结构
|
||||||
|
class ConversationModel(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
inputs: Dict[str, Any]
|
||||||
|
status: str
|
||||||
|
introduction: str
|
||||||
|
created_at: int
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
#orm_mode = True
|
||||||
|
from_attributes=True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def orm(cls):
|
||||||
|
return ConversationOrm
|
||||||
|
|
||||||
|
class UserModel(BaseModel):
|
||||||
|
id: str
|
||||||
|
createtime: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
#orm_mode = True
|
||||||
|
from_attributes=True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def orm(cls):
|
||||||
|
return UserOrm
|
||||||
|
|
||||||
|
class ParametersModel(BaseModel):
|
||||||
|
user_id : str
|
||||||
|
name : str
|
||||||
|
value : Dict[str, Any]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
#orm_mode = True
|
||||||
|
from_attributes=True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def orm(cls):
|
||||||
|
return ParametersOrm
|
||||||
|
|
||||||
|
class MessagesModel(BaseModel):
|
||||||
|
id :str
|
||||||
|
conversation_id :str
|
||||||
|
inputs : Dict[str, Any]
|
||||||
|
query : str
|
||||||
|
answer : str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
#orm_mode = True
|
||||||
|
from_attributes=True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def orm(cls):
|
||||||
|
return MessagesOrm
|
||||||
|
|
||||||
|
class DBManager:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
|
||||||
|
self._engine = create_engine(DATABASE_URL)
|
||||||
|
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)
|
||||||
|
|
||||||
|
def createTable(self,tableName:str):
|
||||||
|
if self._engine is None:
|
||||||
|
return
|
||||||
|
if not self.exist(tableName):
|
||||||
|
Base.metadata.tables[tableName].create(self._engine)
|
||||||
|
|
||||||
|
def addRecord(self,tableName:str,record:Dict[str,Any]):
|
||||||
|
ormCls = self._get_orm(tableName)
|
||||||
|
if ormCls is None:
|
||||||
|
return
|
||||||
|
session = self.SessionLocal()
|
||||||
|
data = ormCls(**record)
|
||||||
|
session.add(data)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def addRecords(self,tableName:str,records:List[Dict[str,Any]]):
|
||||||
|
ormCls = self._get_orm(tableName)
|
||||||
|
if ormCls is None:
|
||||||
|
return
|
||||||
|
datas = []
|
||||||
|
session = self.SessionLocal()
|
||||||
|
for record in records:
|
||||||
|
datas.append(ormCls(**record))
|
||||||
|
session.add(datas)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def delete(self,tableName:str,**filter):
|
||||||
|
session = self.SessionLocal()
|
||||||
|
ormCls = self._get_orm(tableName)
|
||||||
|
if ormCls is None:
|
||||||
|
return
|
||||||
|
records = session.query(ormCls).filter_by(**filter).all()
|
||||||
|
if records is not None:
|
||||||
|
session.delete(records)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def update(self,tableName:str,data:Dict[str,Any],**filter):
|
||||||
|
if not self.exist(tableName):
|
||||||
|
return
|
||||||
|
session = self.SessionLocal()
|
||||||
|
ormCls = self._get_orm(tableName)
|
||||||
|
if ormCls is None:
|
||||||
|
return
|
||||||
|
if len(filter) > 0:
|
||||||
|
records = session.query(ormCls).filter_by(**filter).all()
|
||||||
|
else:
|
||||||
|
records = session.query(ormCls).all()
|
||||||
|
for record in records:
|
||||||
|
if record is not None:
|
||||||
|
record.update(data)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def query(self,tableName:str,**filter):
|
||||||
|
session = self.SessionLocal()
|
||||||
|
ormCls = self._get_orm(tableName)
|
||||||
|
if ormCls is None:
|
||||||
|
return
|
||||||
|
modelCls = self._get_model(ormCls)
|
||||||
|
if modelCls is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if filter is not None:
|
||||||
|
records = session.query(ormCls).filter_by(**filter).all()
|
||||||
|
else:
|
||||||
|
records = session.query(ormCls).all()
|
||||||
|
|
||||||
|
datas = []
|
||||||
|
for record in records:
|
||||||
|
datas.append(modelCls.from_orm(record))
|
||||||
|
return datas
|
||||||
|
|
||||||
|
def exist(self,tableName:str)->bool:
|
||||||
|
if self._engine is None:
|
||||||
|
return
|
||||||
|
inspector = Inspector.from_engine(self._engine)
|
||||||
|
return inspector.has_table(tableName)
|
||||||
|
|
||||||
|
def _get_orm(self,tableName:str):
|
||||||
|
subClss = Base.__subclasses__()
|
||||||
|
for sunCls in subClss:
|
||||||
|
if sunCls.__tablename__ == tableName:
|
||||||
|
return sunCls
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_model(self,orm:Any):
|
||||||
|
subClss = BaseModel.__subclasses__()
|
||||||
|
for sunCls in subClss:
|
||||||
|
if 'orm' in sunCls.__dict__ and sunCls.orm() == orm:
|
||||||
|
return sunCls
|
||||||
|
return None
|
||||||
|
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
|
||||||
|
from typing import Dict, Any
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequestData(BaseModel):
|
||||||
|
inputs: Dict[str,Any]
|
||||||
|
query: str
|
||||||
|
user: str
|
||||||
|
response_mode: str
|
||||||
|
files: Any
|
||||||
|
conversation_id: str = None
|
||||||
|
|
||||||
|
class ChatFileUploadRequest(BaseModel):
|
||||||
|
base64: str
|
||||||
@@ -31,13 +31,19 @@ def get_chat_engine(filters=None, params=None):
|
|||||||
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
|
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
|
||||||
description="适用于任何需要进行全面总结、概括的要求。",
|
description="适用于任何需要进行全面总结、概括的要求。",
|
||||||
)
|
)
|
||||||
query_engine = create_query_engine(index,top_k,use_reranker,filters)
|
query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "COMPACT")
|
||||||
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
|
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
|
||||||
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
|
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "TREE_SUMMARIZE")
|
||||||
|
query_engine_tool_1 = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool_1",
|
||||||
|
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后,且在询问工程中单位的具体数值,例如用量,费率,合计,金额等的时候建议使用“zj_query_tool_1”工具。",
|
||||||
|
)
|
||||||
|
|
||||||
tools.append(summary_query_tool)
|
tools.append(summary_query_tool)
|
||||||
tools.append(query_engine_tool)
|
tools.append(query_engine_tool)
|
||||||
|
tools.append(query_engine_tool_1)
|
||||||
|
|
||||||
# Add additional tools
|
# Add additional tools
|
||||||
tools += ToolFactory.from_env()
|
tools += ToolFactory.from_env()
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None
|
|||||||
return summary_query_engine
|
return summary_query_engine
|
||||||
|
|
||||||
# Create a query engine
|
# Create a query engine
|
||||||
def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
|
def create_query_engine(index, top_k=3, use_reranker=False, filters=None, response_mode=None):
|
||||||
# 创建向量检索查询工具
|
# 创建向量检索查询工具
|
||||||
postprocess = None
|
postprocess = None
|
||||||
if use_reranker:
|
if use_reranker:
|
||||||
@@ -103,6 +103,7 @@ def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
|
|||||||
node_postprocessors=postprocess,
|
node_postprocessors=postprocess,
|
||||||
use_async=True,
|
use_async=True,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
|
ResponseMode = response_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
return query_engine
|
return query_engine
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||||
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||||
@@ -9,7 +8,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def load_configs():
|
def load_configs():
|
||||||
with open("config/loaders.yaml") as f:
|
with open("config/loaders.yaml",encoding='UTF-8') as f:
|
||||||
configs = yaml.safe_load(f)
|
configs = yaml.safe_load(f)
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
@@ -17,24 +16,26 @@ def load_configs():
|
|||||||
def get_documents():
|
def get_documents():
|
||||||
documents = []
|
documents = []
|
||||||
config = load_configs()
|
config = load_configs()
|
||||||
|
|
||||||
if config is None or len(config.items()) == 0:
|
if config is None or len(config.items()) == 0:
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
for loader_type, loader_config in config.items():
|
for loader_type, loader_config in config.items():
|
||||||
logger.info(
|
if loader_config.get('enable', True): # 检查 enable 字段
|
||||||
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
logger.info(
|
||||||
)
|
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
||||||
|
)
|
||||||
|
|
||||||
loader_config = loader_config or []
|
loader_config = loader_config or []
|
||||||
match loader_type:
|
match loader_type:
|
||||||
case "file":
|
case "file":
|
||||||
document = get_file_documents(FileLoaderConfig(**loader_config))
|
document = get_file_documents(FileLoaderConfig(**loader_config))
|
||||||
case "web":
|
case "web":
|
||||||
document = get_web_documents(WebLoaderConfig(**loader_config))
|
document = get_web_documents(WebLoaderConfig(**loader_config))
|
||||||
case "db":
|
case "db":
|
||||||
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
|
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Invalid loader type: {loader_type}")
|
raise ValueError(f"Invalid loader type: {loader_type}")
|
||||||
documents.extend(document)
|
documents.extend(document)
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
@@ -2,17 +2,14 @@ import logging
|
|||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from llama_index.core import SQLDatabase, Document
|
from llama_index.core import SQLDatabase, Document
|
||||||
from llama_index.core.objects import SQLTableSchema
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
from llama_index.readers.database import DatabaseReader
|
from llama_index.readers.database import DatabaseReader
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine, text
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class CustomDatabaseReader(BaseReader):
|
class CustomDatabaseReader(DatabaseReader):
|
||||||
"""Simple Database reader.
|
"""Simple Database reader.
|
||||||
|
|
||||||
Concatenates each row into Document used by LlamaIndex.
|
Concatenates each row into Document used by LlamaIndex.
|
||||||
@@ -85,19 +82,20 @@ class CustomDatabaseReader(BaseReader):
|
|||||||
Returns:
|
Returns:
|
||||||
List[Document]: A list of Document objects.
|
List[Document]: A list of Document objects.
|
||||||
"""
|
"""
|
||||||
dco_str = ""
|
dco_str = ""
|
||||||
|
|
||||||
with self.sql_database.engine.connect() as connection:
|
with self.sql_database.engine.connect() as connection:
|
||||||
if query is None:
|
if query is None:
|
||||||
raise ValueError("A query parameter is necessary to filter the data")
|
raise ValueError("A query parameter is necessary to filter the data")
|
||||||
else:
|
else:
|
||||||
result = connection.execute(text(query))
|
result = connection.execute(text(query))
|
||||||
|
|
||||||
dco_str = ", ".join(
|
dco_str += ", ".join(
|
||||||
[f"{entry}" for entry in result.keys()]
|
[f"{entry}" for entry in result.keys()]
|
||||||
)
|
) + "\n"
|
||||||
|
|
||||||
for item in result.fetchall():
|
for item in result.fetchall():
|
||||||
# fetch each item
|
# Fetch each item
|
||||||
record_str = ", ".join(
|
record_str = ", ".join(
|
||||||
[f"{entry}" for col, entry in zip(result.keys(), item)]
|
[f"{entry}" for col, entry in zip(result.keys(), item)]
|
||||||
)
|
)
|
||||||
@@ -111,45 +109,36 @@ class CustomDatabaseReader(BaseReader):
|
|||||||
|
|
||||||
class DBLoaderConfig(BaseModel):
|
class DBLoaderConfig(BaseModel):
|
||||||
uri: str
|
uri: str
|
||||||
queries: List[str]
|
queries: List[dict]
|
||||||
|
|
||||||
def get_db_documents(configs: list[DBLoaderConfig]):
|
def get_db_documents(configs: List[DBLoaderConfig]) -> List[Document]:
|
||||||
docs = []
|
docs = []
|
||||||
|
|
||||||
if len(configs) == 0 or configs[0].uri == "":
|
if not configs or not configs[0].uri:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to load database, error message: uri is empty. Return as empty document list."
|
f"Failed to load database, error message: uri is empty. Return as empty document list."
|
||||||
)
|
)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
#'file_name':'',
|
'file_type': 'application/booway.document.zj',
|
||||||
'file_type':'application/booway.document.zj',
|
|
||||||
#'file_path':'',
|
|
||||||
#'file_size':'',
|
|
||||||
#'creation_date':'',
|
|
||||||
#'last_modified_date':'',
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#from llama_index.readers.database import DatabaseReader
|
|
||||||
for entry in configs:
|
for entry in configs:
|
||||||
engine = create_engine(entry.uri)
|
engine = create_engine(entry.uri)
|
||||||
sql_database = SQLDatabase(engine)
|
sql_database = SQLDatabase(engine)
|
||||||
|
|
||||||
# table_schema_objs = makeDescriptionByEngine(sql_database)
|
|
||||||
# table_node_mapping = SQLTableNodeMapping(sql_database)
|
|
||||||
#
|
|
||||||
# nodes = table_node_mapping.to_nodes(table_schema_objs)
|
|
||||||
# for node in nodes:
|
|
||||||
# node.metadata.update(metadata)
|
|
||||||
#
|
|
||||||
# docs.extend(nodes)
|
|
||||||
|
|
||||||
queries = entry.queries or []
|
|
||||||
loader = CustomDatabaseReader(sql_database)
|
loader = CustomDatabaseReader(sql_database)
|
||||||
for query in queries:
|
for query_dict in entry.queries:
|
||||||
|
query = query_dict.get("sql", "")
|
||||||
|
explanation = query_dict.get("explanation", "")
|
||||||
logger.info(f"Loading data from database with query: {query}")
|
logger.info(f"Loading data from database with query: {query}")
|
||||||
documents = loader.load_data(query=query)
|
documents = loader.load_data(query=query)
|
||||||
|
|
||||||
docs.extend(documents)
|
# 添加解释到元数据中
|
||||||
return docs
|
for doc in documents:
|
||||||
|
doc.metadata["explanation"] = explanation
|
||||||
|
doc.metadata.update(metadata) # 更新或添加额外的元数据
|
||||||
|
docs.append(doc)
|
||||||
|
|
||||||
|
return docs
|
||||||
@@ -5,6 +5,8 @@ text_qa_template_str = (
|
|||||||
"你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
|
"你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
|
||||||
"你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
|
"你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
|
||||||
"如同直接从文件中提取的内容。\n"
|
"如同直接从文件中提取的内容。\n"
|
||||||
|
"知识库中已经导入一个工程的全部数据,请你站在当前工程的角度回答用户关于工程文件的问题。\n"
|
||||||
|
"例如:询问“此工程”指当前导入的工程。询问“此工程名称”指当前导入的工程的工程名称。\n"
|
||||||
|
|
||||||
"## 技能\n"
|
"## 技能\n"
|
||||||
"### 技能 1: 数据查询与提供\n"
|
"### 技能 1: 数据查询与提供\n"
|
||||||
@@ -39,15 +41,19 @@ refine_template_str = (
|
|||||||
"这是原本的问题: {query_str}\n"
|
"这是原本的问题: {query_str}\n"
|
||||||
"我们已经提供了回答: {existing_answer}\n"
|
"我们已经提供了回答: {existing_answer}\n"
|
||||||
"现在我们有机会改进这个回答 "
|
"现在我们有机会改进这个回答 "
|
||||||
"使用以下更多上下文(仅当需要用时)\n"
|
"使用以下更多上下文(仅当有助于改进回答时使用)\n"
|
||||||
|
"你需要仔细的判断新的上下文的信息与原本问题必须一个字都不差,如果有一点差别,那就不能改变我现有的回答。\n"
|
||||||
|
"在判断回答是否正确的时候,你应该仔细对比新的上下文中包含的信息是否与原本的问题一字不差,如果一字不差,才能当作新的正确回答。\n"
|
||||||
|
"如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n"
|
||||||
|
"判断一下如果原回答正确,且在新的上下文仍然包含正确的回答,请将新的回答与原回答一起返回。\n"
|
||||||
"------------\n"
|
"------------\n"
|
||||||
"{context_msg}\n"
|
"{context_msg}\n"
|
||||||
"------------\n"
|
"------------\n"
|
||||||
"根据新的上下文, 请改进原来的回答。"
|
"如果回答中已经包含有正确答案,不要返回多余的解释等信息,只返回正确答案\n"
|
||||||
"如果新的上下文没有用, 直接返回原本的回答。\n"
|
"如果是表结构或者是数据库的相关内容,仅用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
||||||
"如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
|
||||||
"改进的回答: "
|
"改进的回答: "
|
||||||
)
|
)
|
||||||
|
|
||||||
refine_template = PromptTemplate(refine_template_str)
|
refine_template = PromptTemplate(refine_template_str)
|
||||||
|
|
||||||
summary_template_str = (
|
summary_template_str = (
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
import os
|
|
||||||
import yaml
|
|
||||||
import json
|
|
||||||
import importlib
|
import importlib
|
||||||
from cachetools import cached, LRUCache
|
import os
|
||||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
|
||||||
|
import yaml
|
||||||
from llama_index.core.tools.function_tool import FunctionTool
|
from llama_index.core.tools.function_tool import FunctionTool
|
||||||
|
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||||
|
|
||||||
|
|
||||||
class ToolType:
|
class ToolType:
|
||||||
@@ -46,7 +45,7 @@ class ToolFactory:
|
|||||||
def from_env() -> list[FunctionTool]:
|
def from_env() -> list[FunctionTool]:
|
||||||
tools = []
|
tools = []
|
||||||
if os.path.exists("config/tools.yaml"):
|
if os.path.exists("config/tools.yaml"):
|
||||||
with open("config/tools.yaml", "r") as f:
|
with open("config/tools.yaml", "r", encoding='UTF-8') as f:
|
||||||
tool_configs = yaml.safe_load(f)
|
tool_configs = yaml.safe_load(f)
|
||||||
if tool_configs != None and len(tool_configs.items()) != 0:
|
if tool_configs != None and len(tool_configs.items()) != 0:
|
||||||
for tool_type, config_entries in tool_configs.items():
|
for tool_type, config_entries in tool_configs.items():
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
file:
|
file:
|
||||||
|
enable: true # 添加 enable 字段
|
||||||
# use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable
|
# use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable
|
||||||
use_llama_parse: false
|
use_llama_parse: false
|
||||||
|
|
||||||
@@ -7,27 +8,41 @@ db:
|
|||||||
# uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
|
# uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
|
||||||
# query: The query to fetch data from the database. E.g.: SELECT * FROM table
|
# query: The query to fetch data from the database. E.g.: SELECT * FROM table
|
||||||
- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||||
#- uri: mysql+pymysql://zjinfo:Y6EAjEEdSYmskA8B@110.42.234.166:3306/zjinfo
|
enable: true # 添加 enable 字段
|
||||||
# - uri: mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
|
||||||
queries:
|
queries:
|
||||||
- sql: select * from ProjectProperties limit 30;
|
- sql: select * from ProjectProperties;
|
||||||
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
|
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
|
||||||
|
|
||||||
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
|
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
|
||||||
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
|
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
|
||||||
|
|
||||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '线路' limit 50;
|
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
|
||||||
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||||
|
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
|
||||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '余物清理' limit 50;
|
|
||||||
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||||
|
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
|
||||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '拆除线路' limit 50;
|
|
||||||
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||||
|
|
||||||
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
|
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
|
||||||
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
|
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
|
||||||
|
|
||||||
|
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
|
||||||
|
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||||
|
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
|
||||||
|
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||||
|
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
|
||||||
|
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||||
|
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
|
||||||
|
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||||
|
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
|
||||||
|
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||||
|
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
|
||||||
|
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||||
|
|
||||||
|
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
|
||||||
|
explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||||
|
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
|
||||||
|
explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||||
#web:
|
#web:
|
||||||
# driver_arguments:
|
# driver_arguments:
|
||||||
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode
|
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from app.api.routers.chat import chat_router
|
from app.api.routers.chat import chat_router
|
||||||
from app.api.routers.upload import file_upload_router
|
from app.api.routers.upload import file_upload_router
|
||||||
|
from app.api.routers.app import v1_router
|
||||||
from app.settings import init_settings
|
from app.settings import init_settings
|
||||||
from app.observability import init_observability
|
from app.observability import init_observability
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
@@ -56,6 +57,8 @@ mount_static_files("data_output", "/api/files/output")
|
|||||||
app.include_router(chat_router, prefix="/api/chat")
|
app.include_router(chat_router, prefix="/api/chat")
|
||||||
app.include_router(file_upload_router, prefix="/api/chat/upload")
|
app.include_router(file_upload_router, prefix="/api/chat/upload")
|
||||||
|
|
||||||
|
app.include_router(v1_router, prefix="/v1")
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def redirect_to_docs():
|
async def redirect_to_docs():
|
||||||
return RedirectResponse(url="/docs")
|
return RedirectResponse(url="/docs")
|
||||||
|
|||||||
Submodule
+1
Submodule webapp added at 77dbc14a64
Reference in New Issue
Block a user