Compare commits
65 Commits
88761a5d10
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
| e634746a52 | |||
| d12800e14e | |||
| c1df0d1bba | |||
| 0664952ecd | |||
| 7023b54246 | |||
| aee6aa3c04 | |||
| 680e24c516 | |||
| 6663ee8976 | |||
| 0a5f335981 | |||
| 2901bd9eaf | |||
| 453b3ca55c | |||
| 03c4eb1af1 | |||
| 480a1f7fdc | |||
| cdc9d84a1e | |||
| 50f35bb0c9 | |||
| 4a8c79e83d | |||
| f0afd1a4bb | |||
| de34c3938c | |||
| eb572eff27 | |||
| 2706cf9d5a | |||
| 5fa4752d6e | |||
| aff1793c4e | |||
| 0db159ac89 | |||
| 131d6ef1d1 | |||
| 3ee1ba529f | |||
| 576a2ae737 | |||
| 9b47e1a6e1 | |||
| 20510a937b | |||
| a7c79df339 | |||
| 327bba75d5 | |||
| d1242d2080 | |||
| 0f09551f5d | |||
| 8a5facb5b6 | |||
| 0f7c900c1e | |||
| b008ad9766 | |||
| 56459c164e | |||
| 07a3b2a147 | |||
| b4c571cddb | |||
| 7068b058e8 | |||
| 33b2281b7b | |||
| 1704b61609 | |||
| afccaf6eb5 | |||
| b052d373f1 | |||
| 7462244f01 | |||
| a200e8adfc | |||
| 2b64aca26b | |||
| 7691b22274 | |||
| d1117c73c4 | |||
| 5fc8375a06 | |||
| cf1ed4e71d | |||
| 8050551a53 | |||
| 513ce73190 | |||
| 48d10fd1f3 | |||
| 9cbe414a0c | |||
| 4c1c67aa50 | |||
| 59ef831a41 | |||
| 3ceb30c375 | |||
| e71da586e3 | |||
| b3a575d158 | |||
| db006985d7 | |||
| 870af69189 | |||
| 3460b8410e | |||
| 586bb76c9c | |||
| 8d7190d0b6 | |||
| 043aea6cca |
@@ -0,0 +1,3 @@
|
||||
[submodule "webapp"]
|
||||
path = webapp
|
||||
url = https://git.97id.com/ly/webapp.git
|
||||
@@ -1,7 +1,13 @@
|
||||
JIEBA_DATA=./nltk_data
|
||||
NLTK_DATA=./nltk_data
|
||||
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||
DATA_SOURCE_CACHE=./restapi
|
||||
|
||||
# The Llama Cloud API key.
|
||||
# LLAMA_CLOUD_API_KEY=
|
||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||
|
||||
DASHSCOPE_API_KEY=sk-02c8540e86d84b7ca0e6f4f51bac6e60
|
||||
# The provider for the AI models to use.
|
||||
@@ -49,6 +55,7 @@ VECTOR_STORE_COLLECTION=default
|
||||
# Specify this if you are using a local vector database.
|
||||
# Otherwise, use VECTOR_STORE__HOST and VECTOR_STORE__PORT config above
|
||||
VECTOR_STORE_PATH=./storage_vector
|
||||
BM_RETRIEVER_PATH =./storage_bm
|
||||
|
||||
|
||||
|
||||
@@ -78,3 +85,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
|
||||
- You can install any pip package (if it exists) by running a cell with pip install.
|
||||
"
|
||||
|
||||
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
|
||||
+13
-5
@@ -1,17 +1,24 @@
|
||||
JIEBA_DATA=./nltk_data
|
||||
NLTK_DATA=./nltk_data
|
||||
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||
DATA_SOURCE_CACHE=./restapi
|
||||
|
||||
# The Llama Cloud API key.
|
||||
# LLAMA_CLOUD_API_KEY=
|
||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||
|
||||
# The number of similar embeddings to return when retrieving documents.
|
||||
TOP_K=10
|
||||
#--------------------------
|
||||
# 是否启用混合检索
|
||||
HYBRID_ENABLED = true
|
||||
# 混合检索阈值
|
||||
HYBRID_ALPHA = 0.6
|
||||
#--------------------------
|
||||
# 是否启用检索重排功能
|
||||
RERANK_ENABLED=true
|
||||
# 是否启用混合检索
|
||||
HYBRID_ENABLED = false
|
||||
# 混合检索阈值
|
||||
HYBRID_ALPHA = 0.5
|
||||
# Rerank model
|
||||
RERANK_MODEL=bge-reranker-v2-m3
|
||||
RERANK_BASE_URL=http://10.1.16.39:9995
|
||||
@@ -80,7 +87,7 @@ VECTOR_STORE_COLLECTION=default
|
||||
# Specify this if you are using a local vector database.
|
||||
# Otherwise, use VECTOR_STORE__HOST and VECTOR_STORE__PORT config above
|
||||
VECTOR_STORE_PATH=./storage_vector
|
||||
|
||||
BM_RETRIEVER_PATH =./storage_bm
|
||||
|
||||
|
||||
PHOENIX_API_KEY=123456
|
||||
@@ -109,3 +116,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
|
||||
- You can install any pip package (if it exists) by running a cell with pip install.
|
||||
"
|
||||
|
||||
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
|
||||
@@ -0,0 +1,490 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||
from collections import deque
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core import BaseCallbackHandler
|
||||
from llama_index.core.base.llms.types import ChatMessage
|
||||
from llama_index.core.callbacks import CBEventType
|
||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||
from llama_index.core.tools import ToolOutput
|
||||
from pydantic import BaseModel
|
||||
from app.api.routers.request.base import userMng, conversations,message,parameter,feedback
|
||||
from app.api.routers.request.baseConfig import *
|
||||
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
|
||||
from app.engine import get_chat_engine
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
api_router = r = APIRouter()
|
||||
v1_router = v = APIRouter()
|
||||
|
||||
class ChatCallbackEvent(BaseModel):
|
||||
event_type: ChatEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
|
||||
def get_common_param(self)-> dict:
|
||||
return {
|
||||
'event': self.event_type.name,
|
||||
'conversation_id':self.payload.get("conversation_id"),
|
||||
'message_id': self.payload.get("message_id"),
|
||||
'created_at': int(time.time()),
|
||||
'task_id': self.payload.get("task_id")
|
||||
}
|
||||
|
||||
def get_WorkflowStart_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('workflow_run_id'),
|
||||
"workflow_id": self.payload.get('workflow_id'),
|
||||
"sequence_number": 1709,
|
||||
"inputs": {
|
||||
"sys.query": self.payload.get('query'),
|
||||
"sys.files": [],
|
||||
"sys.conversation_id": self.payload.get('conversation_id'),
|
||||
"sys.user_id": self.payload.get('use_id')
|
||||
},
|
||||
"created_at": int(time.time())
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def get_WorkflowFinished_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('workflow_run_id'),
|
||||
"workflow_id": self.payload.get('workflow_id'),
|
||||
"sequence_number": 1709,
|
||||
"status": "succeeded",
|
||||
"outputs": {
|
||||
"answer": self.payload.get('response')
|
||||
},
|
||||
"error": '',
|
||||
"elapsed_time": 36.03764106379822,
|
||||
"total_tokens": 11707,
|
||||
"total_steps": 10,
|
||||
"created_by": {
|
||||
"id": str(uuid.uuid4()),
|
||||
"user": self.payload.get('use_id')
|
||||
},
|
||||
"created_at": int(time.time()),
|
||||
"finished_at": int(time.time()),
|
||||
"files": []
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def get_NodeStart_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('nodeid'),
|
||||
"node_id": self.payload.get('nodeid'),
|
||||
"node_type": "http-request",
|
||||
"title": self.payload.get('title'),
|
||||
"index": self.payload.get('index'),
|
||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
||||
"inputs": '',
|
||||
"created_at": 1724398751,
|
||||
"extras": {}
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def get_NodeFinished_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('nodeid'),
|
||||
"node_id": self.payload.get('nodeid'),
|
||||
"node_type": "http-request",
|
||||
"title": self.payload.get('title'),
|
||||
"index": self.payload.get('index'),
|
||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
||||
"inputs": '',
|
||||
"process_data": '',
|
||||
"outputs": '',
|
||||
"status": "succeeded",
|
||||
"error": '',
|
||||
"elapsed_time": 0.10402441816404462,
|
||||
"execution_metadata": '',
|
||||
"created_at": 1724398751,
|
||||
"finished_at": 1724398751,
|
||||
"files": []
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def get_Message_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'id':self.payload.get('message_id'),
|
||||
'answer':self.payload.get('answer')
|
||||
})
|
||||
return params
|
||||
|
||||
def get_MessageEnd_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'id':self.payload.get('message_id'),
|
||||
'metadata':self.payload.get('metadata')
|
||||
})
|
||||
return params
|
||||
|
||||
def to_response(self)-> dict|None:
|
||||
try:
|
||||
match self.event_type:
|
||||
case "workflow_started":
|
||||
return self.get_WorkflowStart_param()
|
||||
case "workflow_finished":
|
||||
return self.get_WorkflowFinished_param()
|
||||
case "node_started":
|
||||
return self.get_NodeStart_param()
|
||||
case 'node_finished':
|
||||
return self.get_NodeFinished_param()
|
||||
case 'message':
|
||||
return self.get_Message_param()
|
||||
case 'message_end':
|
||||
return self.get_MessageEnd_param()
|
||||
case _:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"转换回应时间时发生错误,原因: {e}")
|
||||
return None
|
||||
|
||||
class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
_aqueue: asyncio.Queue
|
||||
is_done: bool = False
|
||||
|
||||
def __init__(self,**params):
|
||||
"""Initialize the base callback handler."""
|
||||
ignored_events = [
|
||||
# CBEventType.CHUNKING,
|
||||
# CBEventType.NODE_PARSING,
|
||||
# CBEventType.EMBEDDING,
|
||||
# CBEventType.LLM,
|
||||
# CBEventType.TEMPLATING,
|
||||
]
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
self._response:str = ''
|
||||
self._params:Dict[str,Any] = params
|
||||
self._nodeStack:deque = deque()
|
||||
|
||||
#添加工作流开始事件
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'use_id': data.user,
|
||||
'query': data.query,
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||
|
||||
self._nodeStack.append(event_id)
|
||||
nindex = self._nodeStack.count() - 1
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'nodeid':event_id,
|
||||
'title':event_type.name,
|
||||
'index':nindex + 1,
|
||||
'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else ''
|
||||
}
|
||||
)
|
||||
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_START,payload = args)
|
||||
if nd_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(nd_event)
|
||||
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||
|
||||
#self.response = payload.get("response","")
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
nodeID = self._nodeStack[-1]
|
||||
if nodeID == event_id:
|
||||
nindex = self._nodeStack.count() - 1
|
||||
args.update(
|
||||
{
|
||||
'nodeid':event_id,
|
||||
'title':event_type.name,
|
||||
'index':nindex + 1,
|
||||
'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else ''
|
||||
}
|
||||
)
|
||||
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_FINISHED,payload = args)
|
||||
if nd_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(nd_event)
|
||||
self._nodeStack.pop()
|
||||
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_start:{}\n".format(trace_id))
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'response':self._response,
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
|
||||
if msgEnt_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(msgEnt_event)
|
||||
|
||||
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
||||
while not self._aqueue.empty() or not self.is_done:
|
||||
try:
|
||||
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
class IDManager:
|
||||
def createID(self):
|
||||
return {
|
||||
"message_id" : str(uuid.uuid4()),
|
||||
'task_id':str(uuid.uuid4()),
|
||||
'workflow_run_id': str(uuid.uuid4()),
|
||||
"workflow_id": str(uuid.uuid4())
|
||||
}
|
||||
|
||||
class ChatStreamResponse(StreamingResponse):
|
||||
TEXT_PREFIX = "data: "
|
||||
DATA_PREFIX = "data: "
|
||||
ids:Dict[str,Any] = {}
|
||||
data:ChatRequestData = None
|
||||
|
||||
@classmethod
|
||||
def convert_Message(cls, token: str):
|
||||
params = cls.ids
|
||||
params.update({
|
||||
'answer':token,
|
||||
'conversation_id':cls.data.conversation_id
|
||||
})
|
||||
event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE,payload = params)
|
||||
data_str = json.dumps(event.to_response())
|
||||
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||
|
||||
@classmethod
|
||||
def convert_Event(cls, data: dict):
|
||||
data_str = json.dumps(data)
|
||||
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request: Request,
|
||||
event_handler: ChatEventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
data: ChatRequestData,
|
||||
ids:Dict[str,Any]
|
||||
):
|
||||
ChatStreamResponse.ids = ids
|
||||
ChatStreamResponse.data = data
|
||||
content = ChatStreamResponse.content_generator(
|
||||
request, event_handler, response, data
|
||||
)
|
||||
super().__init__(content=content)
|
||||
|
||||
@classmethod
|
||||
async def content_generator(
|
||||
cls,
|
||||
request: Request,
|
||||
event_handler: ChatEventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
data: ChatRequestData
|
||||
):
|
||||
|
||||
# Yield the text response
|
||||
async def _chat_response_generator():
|
||||
final_response = ""
|
||||
async for token in response.async_response_gen():
|
||||
final_response += token
|
||||
yield ChatStreamResponse.convert_Message(token)
|
||||
|
||||
# 存储消息历史
|
||||
message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response)
|
||||
|
||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||
event_handler.is_done = True
|
||||
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
async for event in event_handler.async_event_gen():
|
||||
event_response = event.to_response()
|
||||
if event_response is not None:
|
||||
yield ChatStreamResponse.convert_Event(event_response)
|
||||
|
||||
combine = stream.merge(_chat_response_generator(), _event_generator())
|
||||
is_stream_started = False
|
||||
async with combine.stream() as streamer:
|
||||
async for output in streamer:
|
||||
if not is_stream_started:
|
||||
is_stream_started = True
|
||||
|
||||
yield output
|
||||
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
@v.post("/chat-messages")
|
||||
async def post_conversations(request: Request, data: ChatRequestData):
|
||||
userMng.findNoExistCreate(data.user)
|
||||
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
|
||||
|
||||
conversaObj = conversations()
|
||||
conversationinfo = conversaObj.get(data.conversation_id)
|
||||
if conversationinfo is None:
|
||||
conversationinfo = conversaObj.add(data.conversation_id, data.user, "新建会话")
|
||||
|
||||
# 生成聊天参数
|
||||
last_message_content = ChatMessage.from_str(data.query)
|
||||
filters = None
|
||||
params = data.inputs or {}
|
||||
|
||||
# 获取聊天引擎对象
|
||||
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||
|
||||
# 启动聊天事件监听
|
||||
ids = IDManager().createID()
|
||||
event_handler = ChatEventCallbackHandler(ids = ids,data = data)
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
|
||||
# 执行异步聊天
|
||||
response = await chat_engine.astream_chat(data.query)
|
||||
|
||||
# 返回异步消息回应
|
||||
return ChatStreamResponse(request, event_handler, response, data,ids)
|
||||
|
||||
@v.get("/messages")
|
||||
async def query_messages(user:str, conversation_id:str):
|
||||
#conversation_id = default_conversation_id if conversation_id is None else conversation_id
|
||||
datas = []
|
||||
records = message().gets(user,conversation_id)
|
||||
if records is None:
|
||||
return {
|
||||
"limit": 20,
|
||||
"has_more": False,
|
||||
"data": []
|
||||
}
|
||||
|
||||
for record in records:
|
||||
res = record.dict()
|
||||
feeds = feedback().query(res['id'])
|
||||
res["message_files"] = []
|
||||
res["feedback"] = {'rating':feeds['rating'] } if feeds != None else ''
|
||||
res["retriever_resources"] = []
|
||||
res["created_at"] = 1723444905
|
||||
res["agent_thoughts"] = []
|
||||
res["status"] = "normal"
|
||||
res["error"] = ''
|
||||
datas.append(res)
|
||||
|
||||
return {
|
||||
"limit": 20,
|
||||
"has_more": False,
|
||||
"data": datas
|
||||
}
|
||||
|
||||
@v.post("/conversations/{itemid}/name")
|
||||
async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]):
|
||||
consaObj = conversations()
|
||||
consaObj.rename(itemid,'知识问答')
|
||||
cond = {
|
||||
'id':itemid,
|
||||
'user_id':params['user']
|
||||
}
|
||||
results = consaObj.query(**cond)
|
||||
if len(results) > 0:
|
||||
res = results[0]
|
||||
return {
|
||||
"id": res['id'],
|
||||
"name": res['name'],
|
||||
"inputs": res['inputs'],
|
||||
"status": res['status'],
|
||||
"introduction": res['introduction'],
|
||||
"created_at": res['created_at'],
|
||||
#"工程位置"
|
||||
}
|
||||
return 'null'
|
||||
|
||||
@v.get("/conversations")
|
||||
async def query_conversations(user:str, first_id:str = None, limit:str = None, pinned:str = None):
|
||||
user_id = '' if user is None else user
|
||||
userMng.findNoExistCreate(user_id)
|
||||
|
||||
return {
|
||||
"limit": 20,
|
||||
"has_more": False,
|
||||
"data": conversations().gets(user_id)
|
||||
}
|
||||
|
||||
@v.get("/parameters")
|
||||
async def query_parameters(user:str):
|
||||
params = parameter().get(user)
|
||||
if len(params) == 0:
|
||||
params = BaseConfig().ParamterCfg()
|
||||
return params
|
||||
|
||||
@v.post("/messages/{message_id}/feedbacks")
|
||||
async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]):
|
||||
if params['rating'] =='null':
|
||||
feedback().delete(message_id)
|
||||
else:
|
||||
condition = {'id':message_id}
|
||||
results = message().query(**condition)
|
||||
if len(results) > 0:
|
||||
result = results[0]
|
||||
feedback().add(message_id=message_id,query=result['query'],
|
||||
answer=result['answer'],rating=params['rating'])
|
||||
|
||||
@r.post("")
|
||||
def upload_file(request: ChatFileUploadRequest) -> List[str]:
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from app.api.routers.request.baseConfig import BaseConfig
|
||||
from app.api.routers.request.dbOrm import DBManager
|
||||
|
||||
dbManage = DBManager()
|
||||
|
||||
class conversations:
|
||||
def __init__(self) -> None:
|
||||
self._tableName = 'conversations'
|
||||
dbManage.createTable(self._tableName)
|
||||
|
||||
def gets(self,user_id:str):
|
||||
records = dbManage.query(self._tableName,user_id = user_id)
|
||||
datas = []
|
||||
for record in records:
|
||||
datas.append(record)
|
||||
|
||||
return datas
|
||||
|
||||
def get(self, id:str):
|
||||
records = dbManage.query(self._tableName, id=id)
|
||||
if len(records) >0:
|
||||
return records[0]
|
||||
return None
|
||||
|
||||
def add(self,id:str, user_id:str, name:str):
|
||||
template = BaseConfig().ConversationCfg()
|
||||
template['id'] = id
|
||||
template['user_id'] = user_id
|
||||
template['name'] = name
|
||||
template['created_at'] = 1724399038
|
||||
dbManage.addRecord(self._tableName,template)
|
||||
|
||||
def delete(self,id:str):
|
||||
dbManage.delete(self._tableName,id=id)
|
||||
|
||||
def rename(self,id:str,name:str):
|
||||
data = {'name':name}
|
||||
dbManage.update(self._tableName,data,id=id)
|
||||
|
||||
def query(self,**condition):
|
||||
results = []
|
||||
records = dbManage.query(self._tableName,**condition)
|
||||
for record in records:
|
||||
results.append(record.dict())
|
||||
return results
|
||||
|
||||
class user:
|
||||
def __init__(self) -> None:
|
||||
self._tableName = 'user'
|
||||
dbManage.createTable(self._tableName)
|
||||
|
||||
def gets(self):
|
||||
return dbManage.query(self._tableName)
|
||||
|
||||
def get(self,id:str):
|
||||
return dbManage.query(self._tableName,id = id)
|
||||
|
||||
def add(self,id:str):
|
||||
info = {
|
||||
'id':id,
|
||||
'createtime': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
dbManage.addRecord(self._tableName,info)
|
||||
|
||||
def delete(self,id:str):
|
||||
dbManage.delete(self._tableName,id = id)
|
||||
|
||||
class userMng:
|
||||
userObj = user()
|
||||
@classmethod
|
||||
def findNoExistCreate(cls,user_id:str):
|
||||
userInfo = cls.userObj.get(user_id)
|
||||
if len(userInfo) == 0:
|
||||
cls.userObj.add(user_id)
|
||||
|
||||
def remove(cls,user_id:str):
|
||||
cls.userObj.delete(user_id)
|
||||
|
||||
class parameter:
|
||||
def __init__(self) -> None:
|
||||
self._tableName = 'parameters'
|
||||
dbManage.createTable(self._tableName)
|
||||
|
||||
def get(self,user_id:str):
|
||||
records = dbManage.query(self._tableName,user_id = user_id)
|
||||
data = {}
|
||||
for record in records:
|
||||
key = record['name']
|
||||
value = record['value']
|
||||
data[key] = value
|
||||
return data
|
||||
|
||||
def set(self,user_id:str):
|
||||
dbManage.addRecord(self._tableName,{})
|
||||
|
||||
def delete(self,user_id:str):
|
||||
dbManage.delete(self._tableName,user_id = user_id)
|
||||
|
||||
class message:
|
||||
def __init__(self) -> None:
|
||||
self._tableName = 'messages'
|
||||
dbManage.createTable(self._tableName)
|
||||
|
||||
def gets(self,user_id:str,conversation_id:str):
|
||||
records = dbManage.query(self._tableName,user_id = user_id,conversation_id = conversation_id)
|
||||
datas = []
|
||||
for record in records:
|
||||
datas.append(record)
|
||||
return datas
|
||||
|
||||
def add(self,user_id:str,conversation_id:str,query:str,answer:str):
|
||||
template = BaseConfig.MessageCfg()
|
||||
template['id'] = str(uuid.uuid4())
|
||||
template['user_id'] = user_id
|
||||
template['conversation_id'] = conversation_id
|
||||
template['query'] = query
|
||||
template['answer'] = answer
|
||||
dbManage.addRecord(self._tableName,template)
|
||||
|
||||
def delete(self,user_id:str):
|
||||
dbManage.delete(self._tableName,user_id = user_id)
|
||||
|
||||
def query(self,**condition):
|
||||
results = []
|
||||
records = dbManage.query(self._tableName,**condition)
|
||||
for record in records:
|
||||
results.append(record.dict())
|
||||
return results
|
||||
|
||||
class feedback:
|
||||
def __init__(self) -> None:
|
||||
self._tableName = 'feedbacks'
|
||||
dbManage.createTable(self._tableName)
|
||||
|
||||
def add(self,message_id:str,query:str,answer:str,rating:str):
|
||||
record = {
|
||||
'message_id': message_id,
|
||||
'query': query,
|
||||
'answer': answer,
|
||||
'rating': rating,
|
||||
}
|
||||
dbManage.addRecord(self._tableName,record)
|
||||
|
||||
def delete(self,message_id:str):
|
||||
cond = {'message_id':message_id}
|
||||
dbManage.delete(self._tableName,**cond)
|
||||
|
||||
def query(self,message_id:str):
|
||||
cond = {'message_id':message_id}
|
||||
records = dbManage.query(self._tableName,**cond)
|
||||
if len(records) > 0:
|
||||
return records[0].dict()
|
||||
return None
|
||||
@@ -0,0 +1,80 @@
|
||||
from pydantic import BaseModel
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
|
||||
|
||||
def ParamterCfg(self):
|
||||
questions = os.getenv("CONVERSATION_STARTERS", "dev")
|
||||
return{
|
||||
"opening_statement": self.projectInfo,
|
||||
"suggested_questions": questions.split('\n'),
|
||||
"suggested_questions_after_answer": {
|
||||
"enabled": False
|
||||
},
|
||||
"speech_to_text": {
|
||||
"enabled": False
|
||||
},
|
||||
"text_to_speech": {
|
||||
"enabled": False,
|
||||
"language": "",
|
||||
"voice": ""
|
||||
},
|
||||
"retriever_resource": {
|
||||
"enabled": True
|
||||
},
|
||||
"annotation_reply": {
|
||||
"enabled": False
|
||||
},
|
||||
"more_like_this": {
|
||||
"enabled": False
|
||||
},
|
||||
"user_input_form": [],
|
||||
"sensitive_word_avoidance": {
|
||||
"enabled": False
|
||||
},
|
||||
"file_upload": {
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"transfer_methods": [
|
||||
"remote_url"
|
||||
]
|
||||
}
|
||||
},
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": "10"
|
||||
}
|
||||
}
|
||||
|
||||
def ConversationCfg(self):
|
||||
return{
|
||||
"id": "",
|
||||
'user_id':'',
|
||||
"name": "",
|
||||
"inputs": {},
|
||||
"status": "normal",
|
||||
"introduction": self.projectInfo,
|
||||
"created_at":''
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def MessageCfg(cls):
|
||||
return {
|
||||
"id": "",
|
||||
'user_id':'',
|
||||
"conversation_id": "",
|
||||
"inputs": {},
|
||||
"query": "",
|
||||
"answer": ""
|
||||
}
|
||||
|
||||
|
||||
class ChatEventType(str, Enum):
|
||||
WORKFLOW_START = "workflow_started"
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_START = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
MESSAGE = "message"
|
||||
MESSAGE_END = "message_end"
|
||||
@@ -0,0 +1,220 @@
|
||||
import os
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import create_engine, Column, String, Integer, JSON,Float
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
#orm类
|
||||
class ConversationOrm(Base):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
name = Column(String)
|
||||
inputs = Column(JSON)
|
||||
status = Column(String)
|
||||
introduction = Column(String)
|
||||
created_at = Column(Integer)
|
||||
|
||||
def update(self,data:Dict[str,Any]):
|
||||
if 'name' in data:
|
||||
self.name = data['name']
|
||||
|
||||
class UserOrm(Base):
|
||||
__tablename__ = "user"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
createtime = Column(String)
|
||||
|
||||
class ParametersOrm(Base):
|
||||
__tablename__ = "parameters"
|
||||
|
||||
user_id = Column(String,primary_key=True)
|
||||
name = Column(String)
|
||||
value = Column(JSON)
|
||||
|
||||
class MessagesOrm(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
id = Column(String,primary_key=True)
|
||||
user_id = Column(String)
|
||||
conversation_id = Column(String)
|
||||
inputs = Column(JSON)
|
||||
query = Column(String)
|
||||
answer = Column(String)
|
||||
|
||||
class FeedBackOrm(Base):
|
||||
__tablename__ = "feedbacks"
|
||||
|
||||
message_id = Column(String,primary_key=True)
|
||||
query = Column(String)
|
||||
answer = Column(String)
|
||||
rating = Column(String)
|
||||
|
||||
#数据结构
|
||||
class ConversationModel(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
inputs: Dict[str, Any]
|
||||
status: str
|
||||
introduction: str
|
||||
created_at: int
|
||||
|
||||
class Config:
|
||||
from_attributes=True
|
||||
|
||||
@classmethod
|
||||
def orm(cls):
|
||||
return ConversationOrm
|
||||
|
||||
class UserModel(BaseModel):
|
||||
id: str
|
||||
createtime: str
|
||||
|
||||
class Config:
|
||||
from_attributes=True
|
||||
|
||||
@classmethod
|
||||
def orm(cls):
|
||||
return UserOrm
|
||||
|
||||
class ParametersModel(BaseModel):
|
||||
user_id : str
|
||||
name : str
|
||||
value : Dict[str, Any]
|
||||
|
||||
class Config:
|
||||
from_attributes=True
|
||||
|
||||
@classmethod
|
||||
def orm(cls):
|
||||
return ParametersOrm
|
||||
|
||||
class MessagesModel(BaseModel):
|
||||
id :str
|
||||
conversation_id :str
|
||||
inputs : Dict[str, Any]
|
||||
query : str
|
||||
answer : str
|
||||
|
||||
class Config:
|
||||
from_attributes=True
|
||||
|
||||
@classmethod
|
||||
def orm(cls):
|
||||
return MessagesOrm
|
||||
|
||||
class FeedBackModel(BaseModel):
|
||||
message_id :str
|
||||
query :str
|
||||
answer :str
|
||||
rating :str
|
||||
|
||||
class Config:
|
||||
from_attributes=True
|
||||
|
||||
@classmethod
|
||||
def orm(cls):
|
||||
return FeedBackOrm
|
||||
|
||||
class DBManager:
|
||||
def __init__(self) -> None:
|
||||
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
|
||||
self._engine = create_engine(DATABASE_URL)
|
||||
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)
|
||||
|
||||
def createTable(self,tableName:str):
|
||||
if self._engine is None:
|
||||
return
|
||||
if not self.exist(tableName):
|
||||
Base.metadata.tables[tableName].create(self._engine)
|
||||
|
||||
def addRecord(self,tableName:str,record:Dict[str,Any]):
|
||||
ormCls = self._get_orm(tableName)
|
||||
if ormCls is None:
|
||||
return
|
||||
session = self.SessionLocal()
|
||||
data = ormCls(**record)
|
||||
session.add(data)
|
||||
session.commit()
|
||||
|
||||
def addRecords(self,tableName:str,records:List[Dict[str,Any]]):
|
||||
ormCls = self._get_orm(tableName)
|
||||
if ormCls is None:
|
||||
return
|
||||
datas = []
|
||||
session = self.SessionLocal()
|
||||
for record in records:
|
||||
datas.append(ormCls(**record))
|
||||
session.add(datas)
|
||||
session.commit()
|
||||
|
||||
def delete(self,tableName:str,**filter):
|
||||
session = self.SessionLocal()
|
||||
ormCls = self._get_orm(tableName)
|
||||
if ormCls is None:
|
||||
return
|
||||
records = session.query(ormCls).filter_by(**filter).all()
|
||||
if records is not None:
|
||||
session.delete(records)
|
||||
session.commit()
|
||||
|
||||
def update(self,tableName:str,data:Dict[str,Any],**filter):
|
||||
if not self.exist(tableName):
|
||||
return
|
||||
session = self.SessionLocal()
|
||||
ormCls = self._get_orm(tableName)
|
||||
if ormCls is None:
|
||||
return
|
||||
if len(filter) > 0:
|
||||
records = session.query(ormCls).filter_by(**filter).all()
|
||||
else:
|
||||
records = session.query(ormCls).all()
|
||||
for record in records:
|
||||
if record is not None:
|
||||
record.update(data)
|
||||
session.commit()
|
||||
|
||||
def query(self,tableName:str,**filter):
|
||||
session = self.SessionLocal()
|
||||
ormCls = self._get_orm(tableName)
|
||||
if ormCls is None:
|
||||
return
|
||||
modelCls = self._get_model(ormCls)
|
||||
if modelCls is None:
|
||||
return
|
||||
|
||||
if filter is not None:
|
||||
records = session.query(ormCls).filter_by(**filter).all()
|
||||
else:
|
||||
records = session.query(ormCls).all()
|
||||
|
||||
datas = []
|
||||
for record in records:
|
||||
datas.append(modelCls.from_orm(record))
|
||||
return datas
|
||||
|
||||
def exist(self,tableName:str)->bool:
|
||||
if self._engine is None:
|
||||
return
|
||||
inspector = Inspector.from_engine(self._engine)
|
||||
return inspector.has_table(tableName)
|
||||
|
||||
def _get_orm(self,tableName:str):
|
||||
subClss = Base.__subclasses__()
|
||||
for sunCls in subClss:
|
||||
if sunCls.__tablename__ == tableName:
|
||||
return sunCls
|
||||
return None
|
||||
|
||||
def _get_model(self,orm:Any):
|
||||
subClss = BaseModel.__subclasses__()
|
||||
for sunCls in subClss:
|
||||
if 'orm' in sunCls.__dict__ and sunCls.orm() == orm:
|
||||
return sunCls
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class ChatRequestData(BaseModel):
|
||||
inputs: Dict[str,Any]
|
||||
query: str
|
||||
user: str
|
||||
response_mode: str
|
||||
files: Any
|
||||
conversation_id: str = None
|
||||
|
||||
class ChatFileUploadRequest(BaseModel):
|
||||
base64: str
|
||||
|
||||
|
||||
+25
-122
@@ -1,154 +1,63 @@
|
||||
import os
|
||||
|
||||
from llama_index.core import SQLDatabase, SummaryIndex, VectorStoreIndex
|
||||
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
|
||||
from llama_index.core.agent import AgentRunner, ReActChatFormatter
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.agent import AgentRunner, StructuredPlannerAgent, FunctionCallingAgentWorker
|
||||
from llama_index.core.tools.query_engine import QueryEngineTool
|
||||
from sqlalchemy import create_engine, Engine
|
||||
|
||||
from app.engine.loaders.db import makeDescriptionByEngine
|
||||
from app.engine.tools import ToolFactory
|
||||
from app.engine.engine import create_query_engine, create_summary_query_engine
|
||||
from app.engine.index import get_index
|
||||
from app.settings import get_node_postprocessors
|
||||
#from app.engine.loaders.db import makeDescriptionByEngine
|
||||
from app.engine.tools import ToolFactory
|
||||
|
||||
from llama_index.core.retrievers import BaseRetriever
|
||||
from llama_index.core import QueryBundle
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from typing import List, Any, Optional,Dict
|
||||
from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine
|
||||
|
||||
class HybridRetriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
vector_index,
|
||||
similarity_top_k: int = 2,
|
||||
out_top_k: Optional[int] = None,
|
||||
alpha: float = 0.5,
|
||||
filters = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
from llama_index.retrievers.bm25 import BM25Retriever
|
||||
from nltk.corpus import stopwords
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._vector_index = vector_index
|
||||
self._embed_model = vector_index._embed_model
|
||||
self._out_top_k = out_top_k or similarity_top_k
|
||||
self._vecRetriever = vector_index.as_retriever(
|
||||
similarity_top_k=similarity_top_k,filters = filters
|
||||
)
|
||||
self._bm25Retriever = BM25Retriever.from_defaults(similarity_top_k=similarity_top_k,
|
||||
nodes=self._vector_index.vector_store.get_nodes(None),
|
||||
language=stopwords.words('chinese'))
|
||||
self._alpha = alpha
|
||||
|
||||
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
vecNodes:List[NodeWithScore] = self._vecRetriever.retrieve(query_bundle.query_str)
|
||||
bmNodes:List[NodeWithScore] = self._bm25Retriever.retrieve(query_bundle.query_str)
|
||||
|
||||
bmDic:Dict[str,NodeWithScore] = {}
|
||||
for node in bmNodes:
|
||||
bmDic[node.node_id] = node
|
||||
|
||||
result_tups = []
|
||||
for i in range(len(vecNodes)):
|
||||
node = vecNodes[i]
|
||||
bmScore = 0.0
|
||||
if node.node_id in bmDic:
|
||||
bmScore = bmDic[node.node_id].score
|
||||
bmDic.pop(node.node_id)
|
||||
else:
|
||||
bmScore = 0.0
|
||||
full_similarity = (self._alpha * node.score) + (
|
||||
(1 - self._alpha) * bmScore
|
||||
)
|
||||
result_tups.append((full_similarity, node))
|
||||
|
||||
for _,node in bmDic.items():
|
||||
full_similarity = (1 - self._alpha) * node.score
|
||||
result_tups.append((full_similarity, node))
|
||||
|
||||
result_tups = sorted(result_tups, key=lambda x: x[0], reverse=True)
|
||||
for full_score, node in result_tups:
|
||||
node.score = full_score
|
||||
return [n for _, n in result_tups][:self._out_top_k]
|
||||
|
||||
def get_Retriever(index,**kwargs):
|
||||
bEnableHybrid = True if os.getenv("HYBRID_ENABLED",False).title() == 'True' else False
|
||||
if bEnableHybrid:
|
||||
alpha = float(os.getenv("HYBRID_ALPHA", "0.5"))
|
||||
retriever = HybridRetriever(index,alpha = alpha,**kwargs)
|
||||
else:
|
||||
retriever = index.as_retriever(**kwargs)
|
||||
return retriever
|
||||
|
||||
sql_database = None
|
||||
sql_obj_index = None
|
||||
|
||||
def get_chat_engine(filters=None, params=None):
|
||||
system_prompt = os.getenv("SYSTEM_PROMPT")
|
||||
top_k = int(os.getenv("TOP_K", "3"))
|
||||
use_reranker = os.getenv("RERANK_ENABLED")
|
||||
tools = []
|
||||
|
||||
global sql_obj_index
|
||||
global sql_database
|
||||
if sql_obj_index is None:
|
||||
sqlengine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
|
||||
sql_database = SQLDatabase(sqlengine)
|
||||
table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||
table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||
|
||||
sql_obj_index = ObjectIndex.from_objects(
|
||||
table_schema_objs,
|
||||
table_node_mapping,
|
||||
index_cls=VectorStoreIndex,
|
||||
)
|
||||
|
||||
# 创建SQL查询工具
|
||||
sql_query_engine = SQLTableRetrieverQueryEngine(sql_database,
|
||||
sql_obj_index.as_retriever(similarity_top_k=top_k),
|
||||
verbose=True,)
|
||||
sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine,
|
||||
name="zjdata_query_tool",
|
||||
description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具")
|
||||
# sql_query_engine = create_summary_query_engine(index)
|
||||
# sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine,
|
||||
# name="zjdata_query_tool",
|
||||
# description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具"
|
||||
# )
|
||||
#tools.append(sql_query_tool)
|
||||
|
||||
# Add query tool if index exists
|
||||
index = get_index()
|
||||
if index is not None:
|
||||
summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None))
|
||||
summary_query_engine = summary_index.as_query_engine()
|
||||
summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters)
|
||||
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
|
||||
description="适用于任何需要进行全面总结、概括的要求。",
|
||||
#description="适用于任何需要对所有内容进行全面总结的请求。有关电力造价领域更具体部分的问题,请使用zj_query_engine_tool",
|
||||
)
|
||||
|
||||
# 创建向量检索查询工具
|
||||
postprocess = get_node_postprocessors()
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
get_Retriever(index,similarity_top_k=top_k,
|
||||
filters=filters),
|
||||
node_postprocessors=postprocess,
|
||||
)
|
||||
|
||||
query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "COMPACT")
|
||||
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
|
||||
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。如果你不知道答案,就说你不知道,不要编造答案。",
|
||||
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
|
||||
)
|
||||
|
||||
query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "TREE_SUMMARIZE")
|
||||
query_engine_tool_1 = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool_1",
|
||||
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后,且在询问工程中单位的具体数值,例如用量,费率,合计,金额等的时候建议使用“zj_query_tool_1”工具。",
|
||||
)
|
||||
|
||||
tools.append(summary_query_tool)
|
||||
tools.append(query_engine_tool)
|
||||
#tools.append(sql_query_tool)
|
||||
tools.append(query_engine_tool_1)
|
||||
|
||||
# Add additional tools
|
||||
tools += ToolFactory.from_env()
|
||||
|
||||
return AgentRunner.from_llm(
|
||||
prefix_messages = ("""您的设计旨在帮助完成各种任务,从回答问题到提供其他类型分析的摘要。\n\n##工具\n\n你可以访问各种工具。你有责任按照你认为合适的顺序使用这些工具来完成当前的任务。\n这可能需要将任务分解为子任务,并使用不同的工具来完成每个子任务。\n\n你可以访问以下工具:\n{tool_desc}\n\n\n##输出格式\n\n请用与问题相同的语言回答,并使用以下格式:\n\n \nThought: 用户当前的语言是:(user's language)。我需要使用工具来帮助我回答问题。\nAction: 如果使用工具,则为工具名称(one of {tool_names})。\nAction Input: 输入给工具的内容,使用JSON格式表示kwargs(例如{{\"input\": \"hello world\", \"num_beams\": 5}})\n \n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n切勿用Markdown代码标记包围你的响应。如果需要,可以在响应中使用代码标记。\n\n请为Action Input使用有效的JSON格式。不要这样做{{\'input\': \'hello world\', \'num_beams\': 5}}。\n\n如果使用此格式,用户将以下面的格式进行回应:\n\n \nObservation: 工具响应\n \n\n你应该继续重复上述格式,直到你有足够的信息来回答问题而无需使用更多工具。此时,你必须使用以下两种格式之一进行回答:\n\n \nThought: 我可以不用任何工具来回答。我将使用用户的语言来回答。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n \nThought: 我无法使用提供的工具回答问题。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n##如果从工具中得到的回应是Empty Response,那么只需要回答“我不知道”,不需要额外回答别的内容。## 当前对话\n\n以下是当前对话,由人类和助手的消息交替组成。\n""")
|
||||
react_chat_formatter = ReActChatFormatter.from_defaults(prefix_messages)
|
||||
agentrunner = AgentRunner.from_llm(
|
||||
llm=Settings.llm,
|
||||
tools=tools,
|
||||
react_chat_formatter=react_chat_formatter,
|
||||
system_prompt=system_prompt,
|
||||
verbose=True,
|
||||
)
|
||||
return agentrunner
|
||||
# create the function calling worker for reasoning
|
||||
# worker = FunctionCallingAgentWorker.from_tools(
|
||||
# tools, verbose=True
|
||||
@@ -156,9 +65,3 @@ def get_chat_engine(filters=None, params=None):
|
||||
#
|
||||
# # wrap the worker in the top-level planner
|
||||
# return StructuredPlannerAgent(worker, tools)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
import os
|
||||
|
||||
from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex
|
||||
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import ResponseMode
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||
from app.engine.retriever.HybridRetriever import HybridRetriever
|
||||
from app.settings import get_node_postprocessors
|
||||
|
||||
def makeDescriptionByEngine(sql_database:SQLDatabase):
|
||||
reader = DatabaseReader(sql_database)
|
||||
|
||||
table_names = sql_database.get_usable_table_names()
|
||||
table_schema_objs = []
|
||||
for table_name in table_names:
|
||||
columns = sql_database.get_table_columns(table_name)
|
||||
if len(columns) > 150:
|
||||
continue
|
||||
stats_txt = ""
|
||||
|
||||
if table_name == 'gongchengshuxing':
|
||||
stats_txt = '该表中有以下属性:'
|
||||
documents = reader.load_data(query='select name from gongchengshuxing')
|
||||
for index in range(len(documents) if len(documents) < 30 else 30):
|
||||
if index == 0:
|
||||
continue
|
||||
elif index > 1:
|
||||
stats_txt += ','
|
||||
stats_txt += documents[index].text.split(':')[1]
|
||||
|
||||
tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt))
|
||||
table_schema_objs.append(tbSchema)
|
||||
|
||||
return table_schema_objs
|
||||
|
||||
def get_Retriever(index,**kwargs):
|
||||
strEnableHybrid = os.getenv("HYBRID_ENABLED",'False')
|
||||
bEnableHybrid = True if strEnableHybrid is not None and strEnableHybrid.title() == 'True' else False
|
||||
if bEnableHybrid:
|
||||
alpha = float(os.getenv("HYBRID_ALPHA", "0.5"))
|
||||
retriever = HybridRetriever(index,alpha = alpha,**kwargs)
|
||||
else:
|
||||
retriever = index.as_retriever(**kwargs)
|
||||
return retriever
|
||||
|
||||
|
||||
sql_database = None
|
||||
sql_obj_index = None
|
||||
|
||||
# Create a summary query engine
|
||||
def create_summary_query_engine(top_k=3, use_reranker=False, filters=None):
|
||||
global sql_obj_index
|
||||
global sql_database
|
||||
if sql_obj_index is None or sql_database is None:
|
||||
sqlengine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
|
||||
sql_database = SQLDatabase(sqlengine)
|
||||
table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||
table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||
|
||||
sql_obj_index = ObjectIndex.from_objects(
|
||||
table_schema_objs,
|
||||
table_node_mapping,
|
||||
index_cls=VectorStoreIndex,
|
||||
)
|
||||
|
||||
# 创建SQL查询工具
|
||||
sql_query_engine = SQLTableRetrieverQueryEngine(sql_database,
|
||||
sql_obj_index.as_retriever(similarity_top_k=top_k),
|
||||
verbose=True,
|
||||
)
|
||||
return sql_query_engine
|
||||
|
||||
# Create a summary query engine
|
||||
def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None):
|
||||
summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None))
|
||||
summary_query_engine = summary_index.as_query_engine(
|
||||
response_mode=ResponseMode.TREE_SUMMARIZE,
|
||||
use_async=True,
|
||||
streaming=True,
|
||||
)
|
||||
return summary_query_engine
|
||||
|
||||
# Create a query engine
|
||||
def create_query_engine(index, top_k=3, use_reranker=False, filters=None, response_mode=None):
|
||||
# 创建向量检索查询工具
|
||||
postprocess = None
|
||||
if use_reranker:
|
||||
postprocess = get_node_postprocessors()
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
get_Retriever(index,
|
||||
similarity_top_k=top_k,
|
||||
filters=filters),
|
||||
text_qa_template=text_qa_template,
|
||||
refine_template=refine_template,
|
||||
summary_template = summary_template,
|
||||
simple_template = simple_template,
|
||||
node_postprocessors=postprocess,
|
||||
use_async=True,
|
||||
streaming=True,
|
||||
ResponseMode = response_mode
|
||||
)
|
||||
|
||||
return query_engine
|
||||
@@ -8,6 +8,7 @@ import os
|
||||
from app.engine.loaders import get_documents
|
||||
from app.engine.vectordb import get_vector_store
|
||||
from app.settings import init_settings
|
||||
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.settings import Settings
|
||||
@@ -58,6 +59,13 @@ def persist_storage(docstore, vector_store):
|
||||
storage_context.persist(STORAGE_DIR)
|
||||
|
||||
|
||||
def persist_BMRetriever(vector_store):
|
||||
STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
|
||||
top_k = int(os.getenv("TOP_K", "3"))
|
||||
bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=top_k,nodes=vector_store.get_nodes([]))
|
||||
bmRetriver.persist(STORAGE_DIR)
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Generate index for the provided data")
|
||||
@@ -75,6 +83,7 @@ def generate_datasource():
|
||||
|
||||
# Build the index and persist storage
|
||||
persist_storage(docstore, vector_store)
|
||||
persist_BMRetriever(vector_store)
|
||||
|
||||
logger.info("Finished generating the index")
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
|
||||
import yaml
|
||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||
@@ -9,7 +8,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_configs():
|
||||
with open("config/loaders.yaml") as f:
|
||||
with open("config/loaders.yaml",encoding='UTF-8') as f:
|
||||
configs = yaml.safe_load(f)
|
||||
return configs
|
||||
|
||||
@@ -17,10 +16,12 @@ def load_configs():
|
||||
def get_documents():
|
||||
documents = []
|
||||
config = load_configs()
|
||||
|
||||
if config is None or len(config.items()) == 0:
|
||||
return documents
|
||||
|
||||
for loader_type, loader_config in config.items():
|
||||
if loader_config.get('enable', True): # 检查 enable 字段
|
||||
logger.info(
|
||||
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
||||
)
|
||||
|
||||
@@ -1,24 +1,15 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.schema import Document
|
||||
from llama_index.core.utilities.sql_wrapper import SQLDatabase
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Engine
|
||||
from llama_index.core import SQLDatabase, Document
|
||||
from llama_index.core.objects import SQLTableSchema, SQLTableNodeMapping
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from pydantic import BaseModel, validator
|
||||
from llama_index.core.indices.vector_store import VectorStoreIndex
|
||||
from sqlalchemy import create_engine
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CustomDatabaseReader(BaseReader):
|
||||
class CustomDatabaseReader(DatabaseReader):
|
||||
"""Simple Database reader.
|
||||
|
||||
Concatenates each row into Document used by LlamaIndex.
|
||||
@@ -92,18 +83,19 @@ class CustomDatabaseReader(BaseReader):
|
||||
List[Document]: A list of Document objects.
|
||||
"""
|
||||
dco_str = ""
|
||||
|
||||
with self.sql_database.engine.connect() as connection:
|
||||
if query is None:
|
||||
raise ValueError("A query parameter is necessary to filter the data")
|
||||
else:
|
||||
result = connection.execute(text(query))
|
||||
|
||||
dco_str = ", ".join(
|
||||
dco_str += ", ".join(
|
||||
[f"{entry}" for entry in result.keys()]
|
||||
)
|
||||
) + "\n"
|
||||
|
||||
for item in result.fetchall():
|
||||
# fetch each item
|
||||
# Fetch each item
|
||||
record_str = ", ".join(
|
||||
[f"{entry}" for col, entry in zip(result.keys(), item)]
|
||||
)
|
||||
@@ -117,71 +109,36 @@ class CustomDatabaseReader(BaseReader):
|
||||
|
||||
class DBLoaderConfig(BaseModel):
|
||||
uri: str
|
||||
queries: List[str]
|
||||
queries: List[dict]
|
||||
|
||||
def makeDescriptionByEngine(sql_database:SQLDatabase):
|
||||
reader = DatabaseReader(sql_database)
|
||||
|
||||
table_names = sql_database.get_usable_table_names()
|
||||
table_schema_objs = []
|
||||
for table_name in table_names:
|
||||
columns = sql_database.get_table_columns(table_name)
|
||||
if len(columns) > 150:
|
||||
continue
|
||||
stats_txt = ""
|
||||
|
||||
if table_name == 'gongchengshuxing':
|
||||
stats_txt = '该表中有以下属性:'
|
||||
documents = reader.load_data(query='select name from gongchengshuxing')
|
||||
for index in range(len(documents) if len(documents) < 30 else 30):
|
||||
if index == 0:
|
||||
continue
|
||||
elif index > 1:
|
||||
stats_txt += ','
|
||||
stats_txt += documents[index].text.split(':')[1]
|
||||
|
||||
tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt))
|
||||
table_schema_objs.append(tbSchema)
|
||||
|
||||
return table_schema_objs
|
||||
|
||||
def get_db_documents(configs: list[DBLoaderConfig]):
|
||||
def get_db_documents(configs: List[DBLoaderConfig]) -> List[Document]:
|
||||
docs = []
|
||||
|
||||
if len(configs) == 0 or configs[0].uri == "":
|
||||
if not configs or not configs[0].uri:
|
||||
logger.warning(
|
||||
f"Failed to load database, error message: uri is empty. Return as empty document list."
|
||||
)
|
||||
return docs
|
||||
|
||||
metadata = {
|
||||
#'file_name':'',
|
||||
'file_type':'application/booway.document.zj',
|
||||
#'file_path':'',
|
||||
#'file_size':'',
|
||||
#'creation_date':'',
|
||||
#'last_modified_date':'',
|
||||
'file_type': 'application/booway.document.zj',
|
||||
}
|
||||
|
||||
#from llama_index.readers.database import DatabaseReader
|
||||
for entry in configs:
|
||||
engine = create_engine(entry.uri)
|
||||
sql_database = SQLDatabase(engine)
|
||||
|
||||
# table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||
# table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||
#
|
||||
# nodes = table_node_mapping.to_nodes(table_schema_objs)
|
||||
# for node in nodes:
|
||||
# node.metadata.update(metadata)
|
||||
#
|
||||
# docs.extend(nodes)
|
||||
|
||||
queries = entry.queries or []
|
||||
loader = CustomDatabaseReader(sql_database)
|
||||
for query in queries:
|
||||
for query_dict in entry.queries:
|
||||
query = query_dict.get("sql", "")
|
||||
explanation = query_dict.get("explanation", "")
|
||||
logger.info(f"Loading data from database with query: {query}")
|
||||
documents = loader.load_data(query=query)
|
||||
|
||||
docs.extend(documents)
|
||||
# 添加解释到元数据中
|
||||
for doc in documents:
|
||||
doc.metadata["explanation"] = explanation
|
||||
doc.metadata.update(metadata) # 更新或添加额外的元数据
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
@@ -0,0 +1,95 @@
|
||||
from llama_index.core import PromptTemplate
|
||||
|
||||
text_qa_template_str = (
|
||||
"# 角色\n"
|
||||
"你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
|
||||
"你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
|
||||
"如同直接从文件中提取的内容。\n"
|
||||
"知识库中已经导入一个工程的全部数据,请你站在当前工程的角度回答用户关于工程文件的问题。\n"
|
||||
"例如:询问“此工程”指当前导入的工程。询问“此工程名称”指当前导入的工程的工程名称。\n"
|
||||
|
||||
"## 技能\n"
|
||||
"### 技能 1: 数据查询与提供\n"
|
||||
"- 准确回答所有关于电力工程造价的相关问题。\n"
|
||||
"- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n"
|
||||
"- 确保提供的信息严格基于工程文档中的记录。\n"
|
||||
|
||||
"### 技能 2: 技术性解释\n"
|
||||
"- 解释造价工程中的技术术语和概念。\n"
|
||||
"- 为复杂的工程细节提供清晰易懂的说明。\n"
|
||||
|
||||
"## 约束\n"
|
||||
"- 仅回答与电力工程造价文件相关的具体问题。\n"
|
||||
"- 不进行任何超出文件内容的猜测或假设。\n"
|
||||
"- 所有回答均基于文件内容,采用客观和技术性的语言。\n"
|
||||
"- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n"
|
||||
"以下为上下文信息\n"
|
||||
"---------------------\n"
|
||||
"{context_str}\n"
|
||||
"---------------------\n"
|
||||
"请根据上下文信息而非先前知识回答我的问题或回复我的指令。前面的上下文信息可能有用,也可能没用,你需要从我给出的上下文信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。我的问题或指令是什么语种,你就用什么语种回复。\n"
|
||||
"如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
||||
|
||||
"问题:{query_str}\n"
|
||||
"你的回复: "
|
||||
)
|
||||
|
||||
|
||||
text_qa_template = PromptTemplate(text_qa_template_str)
|
||||
|
||||
refine_template_str = (
|
||||
"这是原本的问题: {query_str}\n"
|
||||
"我们已经提供了回答: {existing_answer}\n"
|
||||
"现在我们有机会改进这个回答 "
|
||||
"使用以下更多上下文(仅当有助于改进回答时使用)\n"
|
||||
"你需要仔细的判断新的上下文的信息与原本问题必须一个字都不差,如果有一点差别,那就不能改变我现有的回答。\n"
|
||||
"在判断回答是否正确的时候,你应该仔细对比新的上下文中包含的信息是否与原本的问题一字不差,如果一字不差,才能当作新的正确回答。\n"
|
||||
"如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n"
|
||||
"判断一下如果原回答正确,且在新的上下文仍然包含正确的回答,请将新的回答与原回答一起返回。\n"
|
||||
"------------\n"
|
||||
"{context_msg}\n"
|
||||
"------------\n"
|
||||
"如果回答中已经包含有正确答案,不要返回多余的解释等信息,只返回正确答案\n"
|
||||
"如果是表结构或者是数据库的相关内容,仅用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
||||
"改进的回答: "
|
||||
)
|
||||
|
||||
refine_template = PromptTemplate(refine_template_str)
|
||||
|
||||
summary_template_str = (
|
||||
"# 角色\n"
|
||||
"你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
|
||||
"你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
|
||||
"如同直接从文件中提取的内容。\n"
|
||||
|
||||
"## 技能\n"
|
||||
"### 技能 1: 数据查询与提供\n"
|
||||
"- 准确回答所有关于电力工程造价的相关问题。\n"
|
||||
"- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n"
|
||||
"- 确保提供的信息严格基于工程文档中的记录。\n"
|
||||
|
||||
"### 技能 2: 技术性解释\n"
|
||||
"- 解释造价工程中的技术术语和概念。\n"
|
||||
"- 为复杂的工程细节提供清晰易懂的说明。\n"
|
||||
|
||||
"## 约束\n"
|
||||
"- 仅回答与电力工程造价文件相关的具体问题。\n"
|
||||
"- 不进行任何超出文件内容的猜测或假设。\n"
|
||||
"- 所有回答均基于文件内容,采用客观和技术性的语言。\n"
|
||||
"- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n"
|
||||
"来自多个来源的上下文信息如下。\n"
|
||||
"---------------------\n"
|
||||
"{context_str}\n"
|
||||
"---------------------\n"
|
||||
"鉴于来自多个来源的信息而非先验知识, "
|
||||
"回答查询。\n"
|
||||
"如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
||||
"Query: {query_str}\n"
|
||||
"Answer: "
|
||||
)
|
||||
summary_template = PromptTemplate(summary_template_str)
|
||||
|
||||
simple_template_str = (
|
||||
"{query_str}"
|
||||
)
|
||||
simple_template = PromptTemplate(simple_template_str)
|
||||
@@ -0,0 +1,133 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
|
||||
from llama_index.core.indices.vector_store.base import VectorStoreIndex
|
||||
from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
|
||||
from llama_index.core.storage.docstore.types import BaseDocumentStore
|
||||
from llama_index.core.vector_stores.utils import (
|
||||
node_to_metadata_dict,
|
||||
metadata_dict_to_node,
|
||||
)
|
||||
|
||||
import bm25s
|
||||
from app.engine.retriever.CHTokener import chTokenize
|
||||
|
||||
CHDEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}
|
||||
|
||||
CHDEFAULT_PERSIST_FILENAME = "retriever.json"
|
||||
|
||||
class CHBM25Retriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
nodes: Optional[List[BaseNode]] = None,
|
||||
existing_bm25: Optional[bm25s.BM25] = None,
|
||||
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
objects: Optional[List[IndexNode]] = None,
|
||||
object_map: Optional[dict] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.similarity_top_k = similarity_top_k
|
||||
if existing_bm25 is not None:
|
||||
self.bm25 = existing_bm25
|
||||
self.corpus = existing_bm25.corpus
|
||||
else:
|
||||
from nltk.corpus import stopwords
|
||||
if nodes is None:
|
||||
raise ValueError("Please pass nodes or an existing BM25 object.")
|
||||
|
||||
self.corpus = [node_to_metadata_dict(node) for node in nodes]
|
||||
|
||||
corpus_tokens = chTokenize(
|
||||
[node.get_content() for node in nodes],
|
||||
show_progress=verbose,
|
||||
)
|
||||
self.bm25 = bm25s.BM25()
|
||||
self.bm25.index(corpus_tokens, show_progress=verbose)
|
||||
super().__init__(
|
||||
callback_manager=callback_manager,
|
||||
object_map=object_map,
|
||||
objects=objects,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_defaults(
|
||||
cls,
|
||||
index: Optional[VectorStoreIndex] = None,
|
||||
nodes: Optional[List[BaseNode]] = None,
|
||||
docstore: Optional[BaseDocumentStore] = None,
|
||||
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||
verbose: bool = False,
|
||||
) -> "CHBM25Retriever":
|
||||
if sum(bool(val) for val in [index, nodes, docstore]) != 1:
|
||||
raise ValueError("Please pass exactly one of index, nodes, or docstore.")
|
||||
|
||||
if index is not None:
|
||||
docstore = index.docstore
|
||||
|
||||
if docstore is not None:
|
||||
nodes = cast(List[BaseNode], list(docstore.docs.values()))
|
||||
|
||||
assert (
|
||||
nodes is not None
|
||||
), "Please pass exactly one of index, nodes, or docstore."
|
||||
|
||||
return cls(
|
||||
nodes=nodes,
|
||||
similarity_top_k=similarity_top_k,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
def get_persist_args(self) -> Dict[str, Any]:
|
||||
"""Get Persist Args Dict to Save."""
|
||||
return {
|
||||
CHDEFAULT_PERSIST_ARGS[key]: getattr(self, key)
|
||||
for key in CHDEFAULT_PERSIST_ARGS
|
||||
if hasattr(self, key)
|
||||
}
|
||||
|
||||
def persist(self, path: str, **kwargs: Any) -> None:
|
||||
"""Persist the retriever to a directory."""
|
||||
self.bm25.save(path, corpus=self.corpus, **kwargs)
|
||||
with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME), "w") as f:
|
||||
json.dump(self.get_persist_args(), f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_persist_dir(cls, path: str, **kwargs: Any) -> "CHBM25Retriever":
|
||||
"""Load the retriever from a directory."""
|
||||
bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
|
||||
with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME)) as f:
|
||||
retriever_data = json.load(f)
|
||||
return cls(existing_bm25=bm25, **retriever_data)
|
||||
|
||||
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
query = query_bundle.query_str
|
||||
tokenized_query = chTokenize(
|
||||
query,show_progress=self._verbose
|
||||
)
|
||||
indexes, scores = self.bm25.retrieve(
|
||||
tokenized_query, k=self.similarity_top_k, show_progress=self._verbose
|
||||
)
|
||||
|
||||
# batched, but only one query
|
||||
indexes = indexes[0]
|
||||
scores = scores[0]
|
||||
|
||||
nodes: List[NodeWithScore] = []
|
||||
for idx, score in zip(indexes, scores):
|
||||
# idx can be an int or a dict of the node
|
||||
if isinstance(idx, dict):
|
||||
node = metadata_dict_to_node(idx)
|
||||
else:
|
||||
node_dict = self.corpus[int(idx)]
|
||||
node = metadata_dict_to_node(node_dict)
|
||||
nodes.append(NodeWithScore(node=node, score=float(score)))
|
||||
|
||||
return nodes
|
||||
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Union, Callable, NamedTuple
|
||||
from bm25s.tokenization import *
|
||||
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
except ImportError:
|
||||
|
||||
def tqdm(iterable, *args, **kwargs):
|
||||
return iterable
|
||||
|
||||
import jieba
|
||||
jiebapath = os.environ.get("JIEBA_DATA", "")
|
||||
jieba.set_dictionary(os.path.join(jiebapath, 'dict.txt')) #设置字典
|
||||
jieba.initialize() #初始化jeiba
|
||||
|
||||
def chinese_tokenizer(text: str) -> List[str]:
|
||||
from nltk.corpus import stopwords
|
||||
tokens = jieba.lcut(text)
|
||||
return [token for token in tokens if token not in stopwords.words('chinese')]
|
||||
|
||||
def chTokenize(
|
||||
texts,
|
||||
show_progress: bool = True,
|
||||
leave: bool = False,
|
||||
) -> Union[List[List[str]], Tokenized]:
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
corpus_ids = []
|
||||
token_to_index = {}
|
||||
|
||||
for text in tqdm(
|
||||
texts, desc="Split strings", leave=leave, disable=not show_progress
|
||||
):
|
||||
|
||||
splitted = chinese_tokenizer(text)
|
||||
doc_ids = []
|
||||
|
||||
for token in splitted:
|
||||
if token not in token_to_index:
|
||||
token_to_index[token] = len(token_to_index)
|
||||
|
||||
token_id = token_to_index[token]
|
||||
doc_ids.append(token_id)
|
||||
|
||||
corpus_ids.append(doc_ids)
|
||||
|
||||
return Tokenized(ids=corpus_ids, vocab=token_to_index)
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
import os
|
||||
from typing import Optional, Any, Dict, List
|
||||
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
|
||||
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||
|
||||
|
||||
class HybridRetriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
vector_index,
|
||||
similarity_top_k: int = 2,
|
||||
out_top_k: Optional[int] = None,
|
||||
alpha: float = 0.5,
|
||||
filters = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._vector_index = vector_index
|
||||
self._embed_model = vector_index._embed_model
|
||||
self._out_top_k = out_top_k or similarity_top_k
|
||||
self._vecRetriever = vector_index.as_retriever(
|
||||
similarity_top_k=similarity_top_k,filters = filters
|
||||
)
|
||||
|
||||
STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
|
||||
if os.path.exists(STORAGE_DIR) and len(os.listdir(STORAGE_DIR)) > 0:
|
||||
self._bm25Retriever = CHBM25Retriever.from_persist_dir(STORAGE_DIR)
|
||||
else:
|
||||
bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=similarity_top_k,nodes=self._vector_index.vector_store.get_nodes(None))
|
||||
bmRetriver.persist(STORAGE_DIR)
|
||||
self._alpha = alpha
|
||||
|
||||
|
||||
|
||||
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
vecNodes:List[NodeWithScore] = self._vecRetriever.retrieve(query_bundle.query_str)
|
||||
bmNodes:List[NodeWithScore] = self._bm25Retriever.retrieve(query_bundle.query_str)
|
||||
|
||||
bmDic:Dict[str,NodeWithScore] = {}
|
||||
for node in bmNodes:
|
||||
bmDic[node.node_id] = node
|
||||
|
||||
result_tups = []
|
||||
for i in range(len(vecNodes)):
|
||||
node = vecNodes[i]
|
||||
bmScore = 0.0
|
||||
if node.node_id in bmDic:
|
||||
bmScore = bmDic[node.node_id].score
|
||||
bmDic.pop(node.node_id)
|
||||
else:
|
||||
bmScore = 0.0
|
||||
full_similarity = (self._alpha * node.score) + (
|
||||
(1 - self._alpha) * bmScore
|
||||
)
|
||||
result_tups.append((full_similarity, node))
|
||||
|
||||
for _,node in bmDic.items():
|
||||
full_similarity = (1 - self._alpha) * node.score
|
||||
result_tups.append((full_similarity, node))
|
||||
|
||||
result_tups = sorted(result_tups, key=lambda x: x[0], reverse=True)
|
||||
for full_score, node in result_tups:
|
||||
node.score = full_score
|
||||
return [n for _, n in result_tups][:self._out_top_k]
|
||||
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import importlib
|
||||
from cachetools import cached, LRUCache
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
import os
|
||||
|
||||
import yaml
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
|
||||
|
||||
class ToolType:
|
||||
@@ -46,7 +45,7 @@ class ToolFactory:
|
||||
def from_env() -> list[FunctionTool]:
|
||||
tools = []
|
||||
if os.path.exists("config/tools.yaml"):
|
||||
with open("config/tools.yaml", "r") as f:
|
||||
with open("config/tools.yaml", "r", encoding='UTF-8') as f:
|
||||
tool_configs = yaml.safe_load(f)
|
||||
if tool_configs != None and len(tool_configs.items()) != 0:
|
||||
for tool_type, config_entries in tool_configs.items():
|
||||
|
||||
@@ -3,11 +3,10 @@ from typing import Dict
|
||||
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.core.settings import Settings
|
||||
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||
from llama_index.llms.xinference import Xinference
|
||||
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
||||
|
||||
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||
|
||||
|
||||
def get_node_postprocessors():
|
||||
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
file:
|
||||
enable: true # 添加 enable 字段
|
||||
# use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable
|
||||
use_llama_parse: false
|
||||
|
||||
@@ -7,14 +8,41 @@ db:
|
||||
# uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
|
||||
# query: The query to fetch data from the database. E.g.: SELECT * FROM table
|
||||
- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
#- uri: mysql+pymysql://zjinfo:Y6EAjEEdSYmskA8B@110.42.234.166:3306/zjinfo
|
||||
# - uri: mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||
enable: true # 添加 enable 字段
|
||||
queries:
|
||||
- select * from ProjectProperties limit 30;
|
||||
- select Name, Code, Amount, Amount_Total from TotalCalculateTable
|
||||
- select SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 limit 50;
|
||||
- select Name, Code, Rate, Amount from OtherFee
|
||||
- sql: select * from ProjectProperties;
|
||||
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
|
||||
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
|
||||
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
|
||||
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
|
||||
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
|
||||
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
|
||||
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
|
||||
explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
|
||||
explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
#web:
|
||||
# driver_arguments:
|
||||
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
+7
-5
@@ -1,6 +1,5 @@
|
||||
from dotenv import load_dotenv
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
@@ -11,16 +10,20 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import RedirectResponse
|
||||
from app.api.routers.chat import chat_router
|
||||
from app.api.routers.upload import file_upload_router
|
||||
from app.api.routers.app import v1_router
|
||||
from app.settings import init_settings
|
||||
from app.observability import init_observability
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from phoenix.trace import using_project
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
usPrj = using_project(os.getenv("PHOENIX_PROJECT_NAME"))
|
||||
usPrj.__enter__()
|
||||
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
@@ -52,12 +55,12 @@ mount_static_files("data_output", "/api/files/output")
|
||||
app.include_router(chat_router, prefix="/api/chat")
|
||||
app.include_router(file_upload_router, prefix="/api/chat/upload")
|
||||
|
||||
# Redirect to documentation page when accessing base URL
|
||||
app.include_router(v1_router, prefix="/v1")
|
||||
|
||||
@app.get("/")
|
||||
async def redirect_to_docs():
|
||||
return RedirectResponse(url="/docs")
|
||||
|
||||
SentenceSplitter
|
||||
if __name__ == "__main__":
|
||||
app_host = os.getenv("APP_HOST", "0.0.0.0")
|
||||
app_port = int(os.getenv("APP_PORT", "8000"))
|
||||
@@ -65,4 +68,3 @@ if __name__ == "__main__":
|
||||
reload = False
|
||||
uvicorn.run(app="main:app", host=app_host, port=app_port, reload=reload)
|
||||
|
||||
#usPrj.__exit__()
|
||||
|
||||
Binary file not shown.
Binary file not shown.
+349046
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
@@ -17,7 +17,8 @@ aiostream = "^0.6.2"
|
||||
llama-index = "0.10.63"
|
||||
cachetools = "^5.3.3"
|
||||
protobuf = "4.25.4"
|
||||
nltk = "^3.8.2"
|
||||
nltk = "^3.9.1"
|
||||
jieba = "^0.42.1"
|
||||
|
||||
#arize-phoenix = "^4.12.0"
|
||||
openinference-instrumentation-llama-index="2.2.3"
|
||||
@@ -34,6 +35,7 @@ chroma="^0.2.0"
|
||||
llama-index-vector-stores-chroma = "^0.1.10"
|
||||
llama-index-readers-json = "^0.1.5"
|
||||
llama-index-retrievers-bm25 = "^0.2.2"
|
||||
llama-index-experimental = "^0.2.0"
|
||||
|
||||
duckduckgo_search = "^6.2.6"
|
||||
|
||||
@@ -61,6 +63,12 @@ version = "^0.8"
|
||||
version = "0.0.7"
|
||||
|
||||
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = "mirrors"
|
||||
url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
|
||||
priority = "default"
|
||||
|
||||
[build-system]
|
||||
requires = [ "poetry-core" ]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1,138 @@
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.evaluation import (
|
||||
FaithfulnessEvaluator,
|
||||
DatasetGenerator,
|
||||
CorrectnessEvaluator,
|
||||
SemanticSimilarityEvaluator,
|
||||
)
|
||||
from llama_index.experimental.param_tuner import ParamTuner
|
||||
from llama_index.experimental.param_tuner.base import RunResult
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
import asyncio
|
||||
|
||||
# 初始化环境
|
||||
from app.observability import init_observability
|
||||
from app.settings import init_settings
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
# 读取文档
|
||||
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
||||
|
||||
# 参数字典
|
||||
param_dict = {
|
||||
"chunk_size": [512, 1024],
|
||||
"top_k": [1, 5],
|
||||
"temperature": [0.1, 1.0]
|
||||
}
|
||||
|
||||
# 辅助函数
|
||||
def _build_index(chunk_size, documents):
|
||||
# 构建索引
|
||||
splitter = SentenceSplitter(chunk_size=chunk_size)
|
||||
vector_index = VectorStoreIndex.from_documents(
|
||||
documents, transformations=[splitter],
|
||||
)
|
||||
return vector_index
|
||||
|
||||
# 评估函数
|
||||
def evaluate_query_engine(query_engine, questions):
|
||||
loop = asyncio.get_event_loop()
|
||||
correct, total = loop.run_until_complete(_evaluate_query_engine_async(query_engine, questions))
|
||||
return correct, total
|
||||
|
||||
async def _evaluate_query_engine_async(query_engine, questions):
|
||||
c = [query_engine.aquery(q) for q in questions]
|
||||
gathering_future = asyncio.gather(*c)
|
||||
results = await gathering_future
|
||||
|
||||
total_correct = 0
|
||||
for r in results:
|
||||
eval_result = (
|
||||
1 if FaithfulnessEvaluator().evaluate_response(response=r).passing else 0
|
||||
)
|
||||
total_correct += eval_result
|
||||
|
||||
return total_correct, len(results)
|
||||
|
||||
|
||||
|
||||
# 生成问题
|
||||
question_generator = DatasetGenerator.from_documents(documents)
|
||||
eval_questions = question_generator.generate_questions_from_nodes(1) # 假设生成10个问题
|
||||
|
||||
# 打印生成的问题
|
||||
for i, q in enumerate(eval_questions, start=1):
|
||||
print(f"问题 {i}: {q}")
|
||||
|
||||
# 目标函数
|
||||
def objective_function(params_dict, documents, questions):
|
||||
chunk_size = params_dict["chunk_size"]
|
||||
top_k = params_dict["top_k"]
|
||||
temperature = params_dict["temperature"]
|
||||
|
||||
# 构建索引
|
||||
vector_index = _build_index(chunk_size, documents)
|
||||
|
||||
# 查询引擎
|
||||
query_engine = vector_index.as_query_engine(
|
||||
similarity_top_k=top_k, temperature=temperature
|
||||
)
|
||||
|
||||
# 评估查询引擎
|
||||
correct, total = 0, len(questions)
|
||||
question_answers = [] # 添加列表来收集问题和答案
|
||||
|
||||
for question in questions:
|
||||
response = query_engine.query(question)
|
||||
if response is not None:
|
||||
question_answers.append((question, response.response))
|
||||
eval_result = FaithfulnessEvaluator().evaluate_response(response=response, query_str=question)
|
||||
if eval_result.passing:
|
||||
correct += 1
|
||||
|
||||
# 计算分数
|
||||
score = correct / total if total > 0 else 0
|
||||
return RunResult(score=score, params=params_dict, question_answers=question_answers)
|
||||
|
||||
# 创建 ParamTuner 实例
|
||||
param_tuner = ParamTuner(
|
||||
param_fn=lambda params_dict: objective_function(params_dict, documents, eval_questions),
|
||||
param_dict=param_dict,
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
# 调用 tune 方法
|
||||
results = param_tuner.tune()
|
||||
best_result = results.best_run_result
|
||||
best_top_k = best_result.params["top_k"]
|
||||
best_chunk_size = best_result.params["chunk_size"]
|
||||
best_temperature = best_result.params["temperature"]
|
||||
print(f"得分: {best_result.score}")
|
||||
print(f"Top-k: {best_top_k}")
|
||||
print(f"文本块大小: {best_chunk_size}")
|
||||
print(f"温度: {best_temperature}")
|
||||
|
||||
# 使用最佳参数再次运行查询引擎,并打印问题与答案
|
||||
best_vector_index = _build_index(best_chunk_size, documents)
|
||||
best_query_engine = best_vector_index.as_query_engine(
|
||||
similarity_top_k=best_top_k, temperature=best_temperature
|
||||
)
|
||||
|
||||
best_question_answers = []
|
||||
for question in eval_questions:
|
||||
response = best_query_engine.query(question)
|
||||
if response is not None:
|
||||
best_question_answers.append((question, response.response))
|
||||
|
||||
# 打印最佳参数下的问题与答案
|
||||
for i, (question, answer) in enumerate(best_question_answers, start=1):
|
||||
print(f"最佳参数 - 问题 {i}: {question}\n答案: {answer}\n")
|
||||
@@ -0,0 +1,81 @@
|
||||
from app.observability import init_observability
|
||||
from app.settings import init_settings
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core import (
|
||||
VectorStoreIndex,
|
||||
SimpleDirectoryReader,
|
||||
Response,
|
||||
)
|
||||
from llama_index.core.evaluation import (
|
||||
FaithfulnessEvaluator,
|
||||
DatasetGenerator,
|
||||
CorrectnessEvaluator,
|
||||
SemanticSimilarityEvaluator,)
|
||||
|
||||
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
faith_evaluator_qwen = FaithfulnessEvaluator() #诚实度评测
|
||||
corr_evaluator_qwen = CorrectnessEvaluator() #准确率评测
|
||||
Seman_evaluator_qwen = SemanticSimilarityEvaluator()#嵌入相似度评估
|
||||
|
||||
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
||||
|
||||
splitter = SentenceSplitter(chunk_size=512)
|
||||
|
||||
|
||||
vector_index = VectorStoreIndex.from_documents(
|
||||
documents, transformations=[splitter],
|
||||
)
|
||||
|
||||
|
||||
# # 运行评估
|
||||
# query_engine = vector_index.as_query_engine()
|
||||
# response_vector = query_engine.query("工程监理费的金额是多少?")
|
||||
# eval_result = evaluator_qwen.evaluate_response(response=response_vector)
|
||||
|
||||
# print(response_vector)
|
||||
# print(eval_result)
|
||||
|
||||
|
||||
question_generator = DatasetGenerator.from_documents(documents)
|
||||
eval_questions = question_generator.generate_questions_from_nodes(5)
|
||||
print(eval_questions)
|
||||
|
||||
import asyncio
|
||||
|
||||
async def evaluate_query_engine_async(query_engine, questions):
|
||||
c = [query_engine.aquery(q) for q in questions]
|
||||
gathering_future = asyncio.gather(*c)
|
||||
results = await gathering_future
|
||||
#print(results)
|
||||
|
||||
total_correct = 0
|
||||
for r in results:
|
||||
eval_result = (
|
||||
1 if faith_evaluator_qwen.evaluate_response(response=r).passing else 0
|
||||
)
|
||||
total_correct += eval_result
|
||||
|
||||
return total_correct, len(results)
|
||||
|
||||
def evaluate_query_engine(query_engine, questions):
|
||||
loop = asyncio.get_event_loop()
|
||||
correct, total = loop.run_until_complete(evaluate_query_engine_async(query_engine, questions))
|
||||
return correct, total
|
||||
|
||||
# 使用 evaluate_query_engine 函数
|
||||
vector_query_engine = vector_index.as_query_engine()
|
||||
correct, total = evaluate_query_engine(vector_query_engine, eval_questions[:5])
|
||||
|
||||
print(f"score: {correct}/{total}")
|
||||
@@ -0,0 +1,121 @@
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from llama_index.core.evaluation import CorrectnessEvaluator
|
||||
from app.engine import get_chat_engine
|
||||
from app.engine.index import get_index
|
||||
from app.observability import init_observability
|
||||
from app.settings import init_settings
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
index = get_index()
|
||||
|
||||
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
from llama_index.core.prompts import (
|
||||
ChatMessage,
|
||||
ChatPromptTemplate,
|
||||
MessageRole
|
||||
)
|
||||
|
||||
DEFAULT_SYSTEM_TEMPLATE = """
|
||||
您是一个问答聊天机器人的专业评估系统。
|
||||
|
||||
您将获得以下信息:
|
||||
|
||||
- 用户查询,
|
||||
- 生成的回答,
|
||||
|
||||
也可能提供一个参考答案作为评估的依据。
|
||||
|
||||
您的任务是判断生成回答的相关性和正确性。
|
||||
输出一个代表全面评估的单一分数。
|
||||
您必须在一行中仅返回该分数。
|
||||
不要以其他任何格式返回答案。
|
||||
在单独的一行提供给定分数的理由。
|
||||
|
||||
请遵循以下评分指南:
|
||||
|
||||
- 您的分数必须在1到5之间,其中1是最差,5是最好的。
|
||||
-如果生成的回答与用户查询不相关,您应该给出1分。
|
||||
-如果生成的回答相关但包含错误,您应该给出2到3分之间的分数。
|
||||
-如果生成的回答相关且完全正确,您应该给出4到5分之间的分数。
|
||||
示例响应:
|
||||
4.0
|
||||
生成的回答与参考答案的指标完全相同,但不够精炼。
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_USER_TEMPLATE = """
|
||||
## User Query
|
||||
{query}
|
||||
|
||||
## Reference Answer
|
||||
{reference_answer}
|
||||
|
||||
## Generated Answer
|
||||
{generated_answer}
|
||||
"""
|
||||
|
||||
DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
|
||||
message_templates=[
|
||||
ChatMessage(role=MessageRole.SYSTEM, content=DEFAULT_SYSTEM_TEMPLATE),
|
||||
ChatMessage(role=MessageRole.USER, content=DEFAULT_USER_TEMPLATE),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# 初始化聊天引擎和评估器
|
||||
chat_engine = get_chat_engine()
|
||||
corr_evaluator_qwen = CorrectnessEvaluator()
|
||||
|
||||
# 加载本地问题回答文件
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(script_dir, 'questions_and_answers.json')
|
||||
output_file_path = file_path.replace('.json', '_test.json')
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 异步函数用于评估查询
|
||||
async def evaluate_query(question, answer, index, output_file):
|
||||
response = await chat_engine.astream_chat(question)
|
||||
|
||||
# 检查sources是否为空
|
||||
if response.sources:
|
||||
content_str = str(response.sources[0])
|
||||
else:
|
||||
content_str = "<无回答>"
|
||||
|
||||
result = corr_evaluator_qwen.evaluate(
|
||||
query=question,
|
||||
response=content_str,
|
||||
reference=answer,
|
||||
)
|
||||
|
||||
result_dict = {
|
||||
"编号": index,
|
||||
"问题": question,
|
||||
"答案": answer,
|
||||
"回答": result.response,
|
||||
"得分(1~5)": result.score,
|
||||
"评价": result.feedback
|
||||
}
|
||||
|
||||
with open(output_file, 'a', encoding='utf-8') as f:
|
||||
f.write(json.dumps(result_dict, ensure_ascii=False, indent=4))
|
||||
f.write(',\n')
|
||||
|
||||
# 主异步函数
|
||||
async def main():
|
||||
for index, item in enumerate(data, start=1):
|
||||
await evaluate_query(item['question'], item['answer'], index, output_file_path)
|
||||
|
||||
# 运行主协程
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,55 @@
|
||||
Attribute_Prompt = (
|
||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||
"现在需要你针对数据中属性一列进行提问和回答。"
|
||||
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元),工程总投资(万元)是77469835.590045万元','尖峰及施工基面土石方量,尖峰及施工基面土石方量是8377.6','截止阀的编码,截止阀的编码是F01010203',"
|
||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||
)
|
||||
|
||||
|
||||
Amount_Prompt = (
|
||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||
"现在需要你针对上下文信息中的金额或者合价进行提问和回答。"
|
||||
"问题和回答的示例应该是这种类型的,示例:'项目建设技术服务费的金额,项目建设技术服务费的金额是16855957065.4302','项目后评价费的费率,项目后评价费的费率是0.5','架空输电线路本体工程的金额,架空输电线路本体工程的金额是55105688268.5176','工程静态投资的金额,工程静态投资的金额是715035853336.391'"
|
||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||
)
|
||||
|
||||
|
||||
|
||||
Units_Prompt = (
|
||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||
"现在需要你针对上下文信息来进行单位转化问题提问和回答。"
|
||||
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元)结果用元表示,工程总投资(万元)是774698355900.45元','本体工程(元)结果用万元表示,本体工程(元)是5490494.261046万元'"
|
||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||
)
|
||||
|
||||
Name_Prompt = (
|
||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||
"现在需要你针对上下文信息中的重名问题进行提问和回答。"
|
||||
"问题和回答的示例应该是这种类型的,示例:'专业类型为线路的杆塔工程项目划分的合价,专业类型为线路的杆塔工程项目划分的合价是220969744.905856','专业类型为线路清理的杆塔工程项目划分的合价,电缆工程的合价是0'"
|
||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||
)
|
||||
|
||||
|
||||
All_Amount_Prompt = (
|
||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||
"现在需要你针对上下文信息中的总体金额进行提问和回答。"
|
||||
"问题和回答的示例应该是这种类型的,示例:'架空输电线路本体工程的总体金额,架空输电线路本体工程的总体金额是7.706703','工程静态投资的总体金额,工程静态投资的总体金额是100'"
|
||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
from app.observability import init_observability
|
||||
from app.settings import init_settings
|
||||
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.evaluation import DatasetGenerator
|
||||
|
||||
import prompts
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
# 读取所有文档(即所有表格)
|
||||
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
||||
|
||||
# 定义表格名称和索引的对应关系
|
||||
table_names = {
|
||||
"工程信息表": 0,
|
||||
"其他费用表": 1,
|
||||
"取费表": 2,
|
||||
"项目划分表": 3,
|
||||
"项目划分_费用预览表": 4,
|
||||
"总算表": 5,
|
||||
"工程量表": 6
|
||||
}
|
||||
|
||||
# 定义中文提示词和Python代码中提示词名称的映射
|
||||
prompt_mapping = {
|
||||
"普通属性": "Attribute_Prompt",
|
||||
"金额查询": "Amount_Prompt",
|
||||
"单位换算": "Units_Prompt",
|
||||
"重名项目划分": "Name_Prompt",
|
||||
"总体金额查询": "All_Amount_Prompt"
|
||||
}
|
||||
|
||||
# 定义表格与其对应的查询类别
|
||||
table_prompt_mapping = {
|
||||
"工程信息表": ["普通属性", "单位换算"],
|
||||
"其他费用表": ["金额查询", "单位换算"],
|
||||
"取费表": ["金额查询"],
|
||||
"总算表": ["金额查询", "总体金额查询"],
|
||||
"工程量表": ["普通属性", "重名项目划分"]
|
||||
}
|
||||
|
||||
# 根据表格名称选择特定的表格
|
||||
def select_document(documents, table_name):
|
||||
if table_name not in table_names:
|
||||
raise ValueError(f"未找到名为 '{table_name}' 的表格")
|
||||
index = table_names[table_name]
|
||||
return [documents[index]] # 返回一个包含所选表格的列表
|
||||
|
||||
# 选择提示词
|
||||
def select_prompt(prompt_category):
|
||||
prompt_name = prompt_mapping.get(prompt_category)
|
||||
if not prompt_name:
|
||||
raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
|
||||
try:
|
||||
return getattr(prompts, prompt_name)
|
||||
except AttributeError:
|
||||
raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数")
|
||||
|
||||
# 生成问题和答案
|
||||
def generate_questions_from_document(document, quest_prompt, num_questions):
|
||||
question_generator = DatasetGenerator.from_documents(
|
||||
documents=document,
|
||||
question_gen_query=quest_prompt,
|
||||
num_questions_per_chunk=num_questions
|
||||
)
|
||||
|
||||
eval_questions = question_generator.generate_questions_from_nodes(num_questions)
|
||||
print(eval_questions)
|
||||
|
||||
qa_pairs = []
|
||||
for qa in eval_questions:
|
||||
if ',' in qa:
|
||||
question, answer = qa.split(",", 1)
|
||||
qa_pairs.append({
|
||||
"question": question.strip(),
|
||||
"answer": answer.strip()
|
||||
})
|
||||
else:
|
||||
print(f"无法处理的问题和答案: {qa}")
|
||||
|
||||
return qa_pairs
|
||||
|
||||
# 主函数,控制生成多个表格的问题和使用多个提示词,并将结果合并到一个文件中
|
||||
def main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt):
|
||||
if table_names_input == "all":
|
||||
selected_tables = list(table_prompt_mapping.keys())
|
||||
else:
|
||||
selected_tables = table_names_input.strip('[]').split(',')
|
||||
|
||||
all_results = {}
|
||||
|
||||
for table_name in selected_tables:
|
||||
table_name = table_name.strip() # 去掉前后空格
|
||||
document = select_document(documents, table_name)
|
||||
|
||||
if prompt_categories_input == "all":
|
||||
selected_prompts = table_prompt_mapping[table_name]
|
||||
else:
|
||||
selected_prompts = prompt_categories_input.strip('[]').split(',')
|
||||
selected_prompts = [p.strip() for p in selected_prompts] # 去掉前后空格
|
||||
|
||||
for prompt_category in selected_prompts:
|
||||
if prompt_category not in table_prompt_mapping[table_name]:
|
||||
print(f"跳过表格 '{table_name}' 的提示词 '{prompt_category}',因为该表中不包含该类别的信息")
|
||||
continue
|
||||
|
||||
quest_prompt = select_prompt(prompt_category).format(num_questions_per_chunk=num_questions_per_prompt)
|
||||
qa_pairs = generate_questions_from_document(document, quest_prompt, num_questions_per_prompt)
|
||||
|
||||
label = f"test:{table_name}_{prompt_category}"
|
||||
all_results[label] = qa_pairs
|
||||
|
||||
# 自动生成输出文件名
|
||||
output_file = "combined_test.json"
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(all_results, f, ensure_ascii=False, indent=4)
|
||||
|
||||
print(f"All questions and answers have been saved to '{output_file}'")
|
||||
|
||||
# 获取命令行参数
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 4:
|
||||
print("Usage: python script.py <table_names_input> <prompt_categories_input> <num_questions_per_prompt>")
|
||||
else:
|
||||
table_names_input = sys.argv[1]
|
||||
prompt_categories_input = sys.argv[2]
|
||||
num_questions_per_prompt = int(sys.argv[3])
|
||||
|
||||
main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt)
|
||||
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
import phoenix as px
|
||||
|
||||
|
||||
os.environ['PHOENIX_HOST'] = "0.0.0.0"
|
||||
|
||||
session = px.launch_app(use_temp_dir=False)
|
||||
|
||||
import msvcrt
|
||||
|
||||
Submodule
+1
Submodule webapp added at 77dbc14a64
Reference in New Issue
Block a user