24 Commits

Author SHA1 Message Date
chentianrui e634746a52 修改了单元测试的问题生成代码 2024-09-06 18:43:57 +08:00
chentianrui d12800e14e 修改了单元测试的问题生成代码 2024-09-06 18:40:54 +08:00
chentianrui c1df0d1bba Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-09-05 17:03:29 +08:00
chentianrui 0664952ecd 增加了问题生成脚本 2024-09-05 17:02:42 +08:00
ly 7023b54246 解决xinference内嵌模型类使用问题。由于目前xinference组件的版本和llamaindex最新版有冲突,所以未更新支持xinference的内嵌模型的版本 2024-09-05 12:12:31 +08:00
ly aee6aa3c04 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-09-05 11:36:53 +08:00
chentianrui 680e24c516 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-30 18:40:32 +08:00
chentianrui 6663ee8976 新增加了单元测试 2024-08-30 18:40:21 +08:00
ly 0a5f335981 调整NLTK数据目录和JIEBA字典位置到本项目中,避免重新安装时需要从网上下载 2024-08-30 01:20:29 +08:00
ly 2901bd9eaf 优化导入,解决初始化LLAMAINDEX过程中环境变量没起作用问题 2024-08-30 01:16:35 +08:00
ly 453b3ca55c Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-30 00:00:57 +08:00
wanyaokun 03c4eb1af1 优化ChatCallbackEvent事件代码 2024-08-29 19:52:53 +08:00
wanyaokun 480a1f7fdc 新增工程配置信息 2024-08-29 19:03:38 +08:00
wanyaokun cdc9d84a1e Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-29 19:01:40 +08:00
wanyaokun 50f35bb0c9 优化Web事件代码 2024-08-29 19:00:25 +08:00
chentianrui 4a8c79e83d 参数优化针对问题做出了调整 2024-08-29 15:09:55 +08:00
ly f0afd1a4bb Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-29 12:03:28 +08:00
chentianrui de34c3938c 增加了参数评估 2024-08-29 12:02:53 +08:00
ly eb572eff27 增加加载环境变量功能 2024-08-29 11:54:20 +08:00
chentianrui 2706cf9d5a 更新了依赖包 2024-08-29 11:41:42 +08:00
chentianrui 5fa4752d6e Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-29 11:39:06 +08:00
chentianrui aff1793c4e 新增了参数评估脚本和评分脚本 2024-08-29 11:38:45 +08:00
chentianrui 3ee1ba529f Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-28 18:12:37 +08:00
chentianrui 576a2ae737 增加了评估脚本 2024-08-28 18:12:28 +08:00
22 changed files with 349975 additions and 303 deletions
+6
View File
@@ -1,3 +1,8 @@
JIEBA_DATA=./nltk_data
NLTK_DATA=./nltk_data
SQLITE_DATABASE_URL=sqlite:///./source.db
DATA_SOURCE_CACHE=./restapi
# The Llama Cloud API key. # The Llama Cloud API key.
# LLAMA_CLOUD_API_KEY= # LLAMA_CLOUD_API_KEY=
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
@@ -80,3 +85,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
- You can install any pip package (if it exists) by running a cell with pip install. - You can install any pip package (if it exists) by running a cell with pip install.
" "
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
+6
View File
@@ -1,3 +1,8 @@
JIEBA_DATA=./nltk_data
NLTK_DATA=./nltk_data
SQLITE_DATABASE_URL=sqlite:///./source.db
DATA_SOURCE_CACHE=./restapi
# The Llama Cloud API key. # The Llama Cloud API key.
# LLAMA_CLOUD_API_KEY= # LLAMA_CLOUD_API_KEY=
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
@@ -111,3 +116,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
- You can install any pip package (if it exists) by running a cell with pip install. - You can install any pip package (if it exists) by running a cell with pip install.
" "
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
+232 -229
View File
@@ -3,6 +3,7 @@ import json
import logging import logging
import time import time
from typing import Dict, List, Any, Optional, AsyncGenerator from typing import Dict, List, Any, Optional, AsyncGenerator
from collections import deque
from aiostream import stream from aiostream import stream
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
@@ -13,7 +14,8 @@ from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput from llama_index.core.tools import ToolOutput
from pydantic import BaseModel from pydantic import BaseModel
from app.api.routers.request.base import userMng, conversations,message,parameter from app.api.routers.request.base import userMng, conversations,message,parameter,feedback
from app.api.routers.request.baseConfig import *
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine from app.engine import get_chat_engine
import uuid import uuid
@@ -24,78 +26,138 @@ api_router = r = APIRouter()
v1_router = v = APIRouter() v1_router = v = APIRouter()
class ChatCallbackEvent(BaseModel): class ChatCallbackEvent(BaseModel):
event_type: CBEventType event_type: ChatEventType
payload: Optional[Dict[str, Any]] = None payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_retrieval_message(self) -> dict | None: def get_common_param(self)-> dict:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"根据查询检索到 {len(nodes)} 源文件"
else:
msg = f"查询检索中: '{self.payload.get('query_str')}'"
return { return {
"type": "events", 'event': self.event_type.name,
"data": {"title": msg}, 'conversation_id':self.payload.get("conversation_id"),
} 'message_id': self.payload.get("message_id"),
else: 'created_at': int(time.time()),
return None 'task_id': self.payload.get("task_id")
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
},
} }
def _is_output_serializable(self, output: Any) -> bool: def get_WorkflowStart_param(self) -> dict:
try: params = self.get_common_param()
json.dumps(output) params.update({
return True 'workflow_run_id':self.payload.get('workflow_run_id'),
except TypeError: 'data':{
return False "id": self.payload.get('workflow_run_id'),
"workflow_id": self.payload.get('workflow_id'),
def get_agent_tool_response(self) -> dict | None: "sequence_number": 1709,
response = self.payload.get("response") "inputs": {
if response is not None: "sys.query": self.payload.get('query'),
sources = response.sources "sys.files": [],
for source in sources: "sys.conversation_id": self.payload.get('conversation_id'),
# Return the tool response here to include the toolCall information "sys.user_id": self.payload.get('use_id')
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
}, },
"created_at": int(time.time())
} }
})
return params
def to_response(self): def get_WorkflowFinished_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('workflow_run_id'),
"workflow_id": self.payload.get('workflow_id'),
"sequence_number": 1709,
"status": "succeeded",
"outputs": {
"answer": self.payload.get('response')
},
"error": '',
"elapsed_time": 36.03764106379822,
"total_tokens": 11707,
"total_steps": 10,
"created_by": {
"id": str(uuid.uuid4()),
"user": self.payload.get('use_id')
},
"created_at": int(time.time()),
"finished_at": int(time.time()),
"files": []
}
})
return params
def get_NodeStart_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": self.payload.get('title'),
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
"created_at": 1724398751,
"extras": {}
}
})
return params
def get_NodeFinished_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": self.payload.get('title'),
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
"process_data": '',
"outputs": '',
"status": "succeeded",
"error": '',
"elapsed_time": 0.10402441816404462,
"execution_metadata": '',
"created_at": 1724398751,
"finished_at": 1724398751,
"files": []
}
})
return params
def get_Message_param(self) -> dict:
params = self.get_common_param()
params.update({
'id':self.payload.get('message_id'),
'answer':self.payload.get('answer')
})
return params
def get_MessageEnd_param(self) -> dict:
params = self.get_common_param()
params.update({
'id':self.payload.get('message_id'),
'metadata':self.payload.get('metadata')
})
return params
def to_response(self)-> dict|None:
try: try:
match self.event_type: match self.event_type:
case "retrieve": case "workflow_started":
return self.get_retrieval_message() return self.get_WorkflowStart_param()
case "function_call": case "workflow_finished":
return self.get_tool_message() return self.get_WorkflowFinished_param()
case "agent_step": case "node_started":
return self.get_agent_tool_response() return self.get_NodeStart_param()
case 'node_finished':
return self.get_NodeFinished_param()
case 'message':
return self.get_Message_param()
case 'message_end':
return self.get_MessageEnd_param()
case _: case _:
return None return None
except Exception as e: except Exception as e:
@@ -106,9 +168,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue _aqueue: asyncio.Queue
is_done: bool = False is_done: bool = False
def __init__( def __init__(self,**params):
self,
):
"""Initialize the base callback handler.""" """Initialize the base callback handler."""
ignored_events = [ ignored_events = [
# CBEventType.CHUNKING, # CBEventType.CHUNKING,
@@ -119,6 +179,23 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
] ]
super().__init__(ignored_events, ignored_events) super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue() self._aqueue = asyncio.Queue()
self._response:str = ''
self._params:Dict[str,Any] = params
self._nodeStack:deque = deque()
#添加工作流开始事件
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'use_id': data.user,
'query': data.query,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
def on_event_start( def on_event_start(
self, self,
@@ -129,9 +206,21 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
) -> str: ) -> str:
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload)) logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload) self._nodeStack.append(event_id)
if event.to_response() is not None: nindex = self._nodeStack.count() - 1
self._aqueue.put_nowait(event) args:Dict[str,Any] = self._params['ids']
args.update(
{
'nodeid':event_id,
'title':event_type.name,
'index':nindex + 1,
'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else ''
}
)
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_START,payload = args)
if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event)
def on_event_end( def on_event_end(
self, self,
@@ -141,9 +230,25 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload)) logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None: #self.response = payload.get("response","")
self._aqueue.put_nowait(event) args:Dict[str,Any] = self._params['ids']
nodeID = self._nodeStack[-1]
if nodeID == event_id:
nindex = self._nodeStack.count() - 1
args.update(
{
'nodeid':event_id,
'title':event_type.name,
'index':nindex + 1,
'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else ''
}
)
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_FINISHED,payload = args)
if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event)
self._nodeStack.pop()
def start_trace(self, trace_id: Optional[str] = None) -> None: def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op.""" """No-op."""
@@ -156,6 +261,23 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
) -> None: ) -> None:
"""No-op.""" """No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map)) logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'response':self._response,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
args:Dict[str,Any] = self._params['ids']
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
if msgEnt_event.to_response() is not None:
self._aqueue.put_nowait(msgEnt_event)
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]: async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
while not self._aqueue.empty() or not self.is_done: while not self._aqueue.empty() or not self.is_done:
@@ -173,95 +295,26 @@ class IDManager:
"workflow_id": str(uuid.uuid4()) "workflow_id": str(uuid.uuid4())
} }
class DifyChatResponseEvent(BaseModel):
event: str
conversation_id: str
message_id: str
created_at: int = int(time.time())
task_id: str
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
"sequence_number": 1709,
"inputs": {
"sys.query": args['query'],
"sys.files": [],
"sys.conversation_id": args['conversation_id'],
"sys.user_id": args['use_id']
},
"created_at": int(time.time())
}
args['event'] = 'workflow_started'
super().__init__(**args)
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['event'] = 'workflow_finished'
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
"sequence_number": 1709,
"status": "succeeded",
"outputs": {
"answer": args['response']
},
"error": '',
"elapsed_time": 36.03764106379822,
"total_tokens": 11707,
"total_steps": 10,
"created_by": {
"id": str(uuid.uuid4()),
"user": args['use_id']
},
"created_at": int(time.time()),
"finished_at": int(time.time()),
"files": []
}
super().__init__(**args)
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
id:str
answer:str
def __init__(self,**args):
args['id'] = args['message_id']
args['event'] = 'message'
super().__init__(**args)
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
id:str
metadata:Dict[str,Any] = {}
def __init__(self,**args):
args['id'] = args['message_id']
args['event'] = 'message_end'
super().__init__(**args)
class ChatStreamResponse(StreamingResponse): class ChatStreamResponse(StreamingResponse):
TEXT_PREFIX = "data: " TEXT_PREFIX = "data: "
DATA_PREFIX = "data: " DATA_PREFIX = "data: "
ids:Dict[str,Any] = {}
data:ChatRequestData = None
@classmethod @classmethod
def convert_text(cls, token: str): def convert_Message(cls, token: str):
# Escape newlines and double quotes to avoid breaking the stream params = cls.ids
#token = json.dumps(token) params.update({
'answer':token,
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}" 'conversation_id':cls.data.conversation_id
return "\n" })
event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE,payload = params)
@classmethod data_str = json.dumps(event.to_response())
def convert_data(cls, data: dict):
data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}{data_str}\n\n" return f"{cls.DATA_PREFIX}{data_str}\n\n"
@classmethod @classmethod
def convert_event(cls, event: DifyChatResponseEvent): def convert_Event(cls, data: dict):
data_str = json.dumps(event.dict()) data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}{data_str}\n\n" return f"{cls.DATA_PREFIX}{data_str}\n\n"
def __init__( def __init__(
@@ -269,8 +322,11 @@ class ChatStreamResponse(StreamingResponse):
request: Request, request: Request,
event_handler: ChatEventCallbackHandler, event_handler: ChatEventCallbackHandler,
response: StreamingAgentChatResponse, response: StreamingAgentChatResponse,
data: ChatRequestData data: ChatRequestData,
ids:Dict[str,Any]
): ):
ChatStreamResponse.ids = ids
ChatStreamResponse.data = data
content = ChatStreamResponse.content_generator( content = ChatStreamResponse.content_generator(
request, event_handler, response, data request, event_handler, response, data
) )
@@ -284,41 +340,26 @@ class ChatStreamResponse(StreamingResponse):
response: StreamingAgentChatResponse, response: StreamingAgentChatResponse,
data: ChatRequestData data: ChatRequestData
): ):
ids = IDManager().createID()
# Yield the text response # Yield the text response
async def _chat_response_generator(): async def _chat_response_generator():
final_response = "" final_response = ""
async for token in response.async_response_gen(): async for token in response.async_response_gen():
final_response += token final_response += token
args = ids yield ChatStreamResponse.convert_Message(token)
args['answer'] = token
args['conversation_id'] = data.conversation_id
event = Message_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(event)
#yield ChatStreamResponse.convert_text(token)
# 存储消息历史 # 存储消息历史
message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response) message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response)
# the text_generator is the leading stream, once it's finished, also finish the event stream # the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True event_handler.is_done = True
# 发送工作流结束事件
args = ids
args['response'] = final_response
args['conversation_id'] = data.conversation_id
wf_event = Workflow_finished_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(wf_event)
msgEnt_event = MessageEnd_DifyChatResponseEvent(**ids)
yield ChatStreamResponse.convert_event(msgEnt_event)
# Yield the events from the event handler # Yield the events from the event handler
async def _event_generator(): async def _event_generator():
async for event in event_handler.async_event_gen(): async for event in event_handler.async_event_gen():
event_response = event.to_response() event_response = event.to_response()
if event_response is not None: if event_response is not None:
yield ChatStreamResponse.convert_text("") yield ChatStreamResponse.convert_Event(event_response)
combine = stream.merge(_chat_response_generator(), _event_generator()) combine = stream.merge(_chat_response_generator(), _event_generator())
is_stream_started = False is_stream_started = False
@@ -327,25 +368,11 @@ class ChatStreamResponse(StreamingResponse):
if not is_stream_started: if not is_stream_started:
is_stream_started = True is_stream_started = True
# 发送工作流开始事件
args = ids
args['use_id'] = data.user
args['query'] = data.query
args['conversation_id'] = data.conversation_id
wf_event = Workflow_started_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(wf_event)
# Stream a blank message to start the stream
# 发送一个空消息事件
#yield ChatStreamResponse.convert_text("")
yield output yield output
if await request.is_disconnected(): if await request.is_disconnected():
break break
@v.post("/chat-messages") @v.post("/chat-messages")
async def post_conversations(request: Request, data: ChatRequestData): async def post_conversations(request: Request, data: ChatRequestData):
userMng.findNoExistCreate(data.user) userMng.findNoExistCreate(data.user)
@@ -365,14 +392,15 @@ async def post_conversations(request: Request, data: ChatRequestData):
chat_engine = get_chat_engine(filters=filters, params=params) chat_engine = get_chat_engine(filters=filters, params=params)
# 启动聊天事件监听 # 启动聊天事件监听
event_handler = ChatEventCallbackHandler() ids = IDManager().createID()
event_handler = ChatEventCallbackHandler(ids = ids,data = data)
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
# 执行异步聊天 # 执行异步聊天
response = await chat_engine.astream_chat(data.query) response = await chat_engine.astream_chat(data.query)
# 返回异步消息回应 # 返回异步消息回应
return ChatStreamResponse(request, event_handler, response, data) return ChatStreamResponse(request, event_handler, response, data,ids)
@v.get("/messages") @v.get("/messages")
async def query_messages(user:str, conversation_id:str): async def query_messages(user:str, conversation_id:str):
@@ -388,8 +416,9 @@ async def query_messages(user:str, conversation_id:str):
for record in records: for record in records:
res = record.dict() res = record.dict()
feeds = feedback().query(res['id'])
res["message_files"] = [] res["message_files"] = []
res["feedback"] = '' res["feedback"] = {'rating':feeds['rating'] } if feeds != None else ''
res["retriever_resources"] = [] res["retriever_resources"] = []
res["created_at"] = 1723444905 res["created_at"] = 1723444905
res["agent_thoughts"] = [] res["agent_thoughts"] = []
@@ -440,48 +469,22 @@ async def query_conversations(user:str, first_id:str = None, limit:str = None, p
async def query_parameters(user:str): async def query_parameters(user:str):
params = parameter().get(user) params = parameter().get(user)
if len(params) == 0: if len(params) == 0:
params = { params = BaseConfig().ParamterCfg()
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
"suggested_questions": [],
"suggested_questions_after_answer": {
"enabled": False
},
"speech_to_text": {
"enabled": False
},
"text_to_speech": {
"enabled": False,
"language": "",
"voice": ""
},
"retriever_resource": {
"enabled": True
},
"annotation_reply": {
"enabled": False
},
"more_like_this": {
"enabled": False
},
"user_input_form": [],
"sensitive_word_avoidance": {
"enabled": False
},
"file_upload": {
"image": {
"enabled": False,
"number_limits": 3,
"transfer_methods": [
"remote_url"
]
}
},
"system_parameters": {
"image_file_size_limit": "10"
}
}
return params return params
@v.post("/messages/{message_id}/feedbacks")
async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]):
if params['rating'] =='null':
feedback().delete(message_id)
else:
condition = {'id':message_id}
results = message().query(**condition)
if len(results) > 0:
result = results[0]
feedback().add(message_id=message_id,query=result['query'],
answer=result['answer'],rating=params['rating'])
@r.post("") @r.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]: def upload_file(request: ChatFileUploadRequest) -> List[str]:
pass pass
+32 -2
View File
@@ -25,7 +25,7 @@ class conversations:
return None return None
def add(self,id:str, user_id:str, name:str): def add(self,id:str, user_id:str, name:str):
template = BaseConfig.ConversationCfg template = BaseConfig().ConversationCfg()
template['id'] = id template['id'] = id
template['user_id'] = user_id template['user_id'] = user_id
template['name'] = name template['name'] = name
@@ -111,7 +111,7 @@ class message:
return datas return datas
def add(self,user_id:str,conversation_id:str,query:str,answer:str): def add(self,user_id:str,conversation_id:str,query:str,answer:str):
template = BaseConfig.MessageCfg template = BaseConfig.MessageCfg()
template['id'] = str(uuid.uuid4()) template['id'] = str(uuid.uuid4())
template['user_id'] = user_id template['user_id'] = user_id
template['conversation_id'] = conversation_id template['conversation_id'] = conversation_id
@@ -122,4 +122,34 @@ class message:
def delete(self,user_id:str): def delete(self,user_id:str):
dbManage.delete(self._tableName,user_id = user_id) dbManage.delete(self._tableName,user_id = user_id)
def query(self,**condition):
results = []
records = dbManage.query(self._tableName,**condition)
for record in records:
results.append(record.dict())
return results
class feedback:
def __init__(self) -> None:
self._tableName = 'feedbacks'
dbManage.createTable(self._tableName)
def add(self,message_id:str,query:str,answer:str,rating:str):
record = {
'message_id': message_id,
'query': query,
'answer': answer,
'rating': rating,
}
dbManage.addRecord(self._tableName,record)
def delete(self,message_id:str):
cond = {'message_id':message_id}
dbManage.delete(self._tableName,**cond)
def query(self,message_id:str):
cond = {'message_id':message_id}
records = dbManage.query(self._tableName,**cond)
if len(records) > 0:
return records[0].dict()
return None
+26 -8
View File
@@ -1,8 +1,15 @@
from pydantic import BaseModel
import os
from enum import Enum
class BaseConfig: class BaseConfig(BaseModel):
ParamterCfg = { projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
"suggested_questions": [], def ParamterCfg(self):
questions = os.getenv("CONVERSATION_STARTERS", "dev")
return{
"opening_statement": self.projectInfo,
"suggested_questions": questions.split('\n'),
"suggested_questions_after_answer": { "suggested_questions_after_answer": {
"enabled": False "enabled": False
}, },
@@ -41,18 +48,20 @@ class BaseConfig:
} }
} }
ConversationCfg = { def ConversationCfg(self):
return{
"id": "", "id": "",
'user_id':'', 'user_id':'',
"name": "", "name": "",
"inputs": {}, "inputs": {},
"status": "normal", "status": "normal",
"introduction": ParamterCfg['opening_statement'], "introduction": self.projectInfo,
"created_at":'' "created_at":''
} }
@classmethod
MessageCfg = { def MessageCfg(cls):
return {
"id": "", "id": "",
'user_id':'', 'user_id':'',
"conversation_id": "", "conversation_id": "",
@@ -60,3 +69,12 @@ class BaseConfig:
"query": "", "query": "",
"answer": "" "answer": ""
} }
class ChatEventType(str, Enum):
WORKFLOW_START = "workflow_started"
WORKFLOW_FINISHED = "workflow_finished"
NODE_START = "node_started"
NODE_FINISHED = "node_finished"
MESSAGE = "message"
MESSAGE_END = "message_end"
+22 -9
View File
@@ -2,7 +2,7 @@ import os
from typing import Dict, List, Any from typing import Dict, List, Any
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import create_engine, Column, String, Integer, JSON from sqlalchemy import create_engine, Column, String, Integer, JSON,Float
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.orm import sessionmaker, declarative_base
@@ -24,10 +24,6 @@ class ConversationOrm(Base):
if 'name' in data: if 'name' in data:
self.name = data['name'] self.name = data['name']
class UserOrm(Base): class UserOrm(Base):
__tablename__ = "user" __tablename__ = "user"
@@ -51,6 +47,14 @@ class MessagesOrm(Base):
query = Column(String) query = Column(String)
answer = Column(String) answer = Column(String)
class FeedBackOrm(Base):
__tablename__ = "feedbacks"
message_id = Column(String,primary_key=True)
query = Column(String)
answer = Column(String)
rating = Column(String)
#数据结构 #数据结构
class ConversationModel(BaseModel): class ConversationModel(BaseModel):
id: str id: str
@@ -61,7 +65,6 @@ class ConversationModel(BaseModel):
created_at: int created_at: int
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
@@ -73,7 +76,6 @@ class UserModel(BaseModel):
createtime: str createtime: str
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
@@ -86,7 +88,6 @@ class ParametersModel(BaseModel):
value : Dict[str, Any] value : Dict[str, Any]
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
@@ -101,13 +102,25 @@ class MessagesModel(BaseModel):
answer : str answer : str
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
def orm(cls): def orm(cls):
return MessagesOrm return MessagesOrm
class FeedBackModel(BaseModel):
message_id :str
query :str
answer :str
rating :str
class Config:
from_attributes=True
@classmethod
def orm(cls):
return FeedBackOrm
class DBManager: class DBManager:
def __init__(self) -> None: def __init__(self) -> None:
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL") DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
+3 -1
View File
@@ -1,7 +1,7 @@
from typing import Dict, Any from typing import Dict, Any
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional
class ChatRequestData(BaseModel): class ChatRequestData(BaseModel):
inputs: Dict[str,Any] inputs: Dict[str,Any]
@@ -13,3 +13,5 @@ class ChatRequestData(BaseModel):
class ChatFileUploadRequest(BaseModel): class ChatFileUploadRequest(BaseModel):
base64: str base64: str
+5 -1
View File
@@ -1,3 +1,4 @@
import os
from typing import Any, Dict, List, Union, Callable, NamedTuple from typing import Any, Dict, List, Union, Callable, NamedTuple
from bm25s.tokenization import * from bm25s.tokenization import *
@@ -8,9 +9,12 @@ except ImportError:
def tqdm(iterable, *args, **kwargs): def tqdm(iterable, *args, **kwargs):
return iterable return iterable
import jieba
jiebapath = os.environ.get("JIEBA_DATA", "")
jieba.set_dictionary(os.path.join(jiebapath, 'dict.txt')) #设置字典
jieba.initialize() #初始化jeiba
def chinese_tokenizer(text: str) -> List[str]: def chinese_tokenizer(text: str) -> List[str]:
import jieba
from nltk.corpus import stopwords from nltk.corpus import stopwords
tokens = jieba.lcut(text) tokens = jieba.lcut(text)
return [token for token in tokens if token not in stopwords.words('chinese')] return [token for token in tokens if token not in stopwords.words('chinese')]
+1 -2
View File
@@ -3,11 +3,10 @@ from typing import Dict
from llama_index.core.constants import DEFAULT_TEMPERATURE from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings from llama_index.core.settings import Settings
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
from llama_index.llms.xinference import Xinference from llama_index.llms.xinference import Xinference
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
def get_node_postprocessors(): def get_node_postprocessors():
rerank_enabled = os.getenv("RERANK_ENABLED").title() rerank_enabled = os.getenv("RERANK_ENABLED").title()
-2
View File
@@ -1,7 +1,5 @@
from dotenv import load_dotenv from dotenv import load_dotenv
from llama_index.core.node_parser import SentenceSplitter
load_dotenv() load_dotenv()
import logging import logging
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large Load Diff
Binary file not shown.
+8 -1
View File
@@ -17,7 +17,7 @@ aiostream = "^0.6.2"
llama-index = "0.10.63" llama-index = "0.10.63"
cachetools = "^5.3.3" cachetools = "^5.3.3"
protobuf = "4.25.4" protobuf = "4.25.4"
nltk = "^3.8.2" nltk = "^3.9.1"
jieba = "^0.42.1" jieba = "^0.42.1"
#arize-phoenix = "^4.12.0" #arize-phoenix = "^4.12.0"
@@ -35,6 +35,7 @@ chroma="^0.2.0"
llama-index-vector-stores-chroma = "^0.1.10" llama-index-vector-stores-chroma = "^0.1.10"
llama-index-readers-json = "^0.1.5" llama-index-readers-json = "^0.1.5"
llama-index-retrievers-bm25 = "^0.2.2" llama-index-retrievers-bm25 = "^0.2.2"
llama-index-experimental = "^0.2.0"
duckduckgo_search = "^6.2.6" duckduckgo_search = "^6.2.6"
@@ -62,6 +63,12 @@ version = "^0.8"
version = "0.0.7" version = "0.0.7"
[[tool.poetry.source]]
name = "mirrors"
url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
priority = "default"
[build-system] [build-system]
requires = [ "poetry-core" ] requires = [ "poetry-core" ]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
+138
View File
@@ -0,0 +1,138 @@
import nest_asyncio
nest_asyncio.apply()
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import VectorStoreIndex
from llama_index.core.evaluation import (
FaithfulnessEvaluator,
DatasetGenerator,
CorrectnessEvaluator,
SemanticSimilarityEvaluator,
)
from llama_index.experimental.param_tuner import ParamTuner
from llama_index.experimental.param_tuner.base import RunResult
from llama_index.llms.openai import OpenAI
import asyncio
# 初始化环境
from app.observability import init_observability
from app.settings import init_settings
from dotenv import load_dotenv
load_dotenv()
init_settings()
init_observability()
# Load the documents
# NOTE(review): hard-coded absolute Windows path; the script only runs on a
# machine that has this directory.
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
# Parameter dictionary: candidate values for each tunable parameter
param_dict = {
    "chunk_size": [512, 1024],
    "top_k": [1, 5],
    "temperature": [0.1, 1.0]
}
# Helper function
def _build_index(chunk_size, documents):
    """Build a vector index over ``documents``.

    Args:
        chunk_size: Forwarded to ``SentenceSplitter(chunk_size=...)``.
        documents: Documents handed to ``VectorStoreIndex.from_documents``.

    Returns:
        The ``VectorStoreIndex`` built with the sentence splitter as its only
        transformation.
    """
    sentence_splitter = SentenceSplitter(chunk_size=chunk_size)
    return VectorStoreIndex.from_documents(
        documents,
        transformations=[sentence_splitter],
    )
# Evaluation function
def evaluate_query_engine(query_engine, questions):
    """Run the asynchronous evaluation to completion from synchronous code.

    Args:
        query_engine: Query engine passed on to the async evaluator.
        questions: Questions passed on to the async evaluator.

    Returns:
        Tuple ``(correct, total)`` as returned by
        ``_evaluate_query_engine_async``.
    """
    # asyncio.run owns the loop's lifecycle. The previous
    # get_event_loop().run_until_complete() pair relied on
    # asyncio.get_event_loop() creating a loop implicitly, which is deprecated.
    correct, total = asyncio.run(
        _evaluate_query_engine_async(query_engine, questions)
    )
    return correct, total
async def _evaluate_query_engine_async(query_engine, questions):
    """Query the engine with every question concurrently and count passes.

    Args:
        query_engine: Object exposing an awaitable ``aquery(question)``.
        questions: Iterable of questions to send.

    Returns:
        Tuple ``(total_correct, total)``: how many responses the
        ``FaithfulnessEvaluator`` marked as passing, and how many responses
        were gathered.
    """
    results = await asyncio.gather(*(query_engine.aquery(q) for q in questions))
    if not results:
        # Nothing to score; skip building an evaluator.
        return 0, 0

    # One evaluator serves every response; the previous version built a new
    # FaithfulnessEvaluator for each result inside the loop.
    evaluator = FaithfulnessEvaluator()
    total_correct = sum(
        1 for r in results if evaluator.evaluate_response(response=r).passing
    )
    return total_correct, len(results)
# Generate the evaluation questions
question_generator = DatasetGenerator.from_documents(documents)
# NOTE(review): the argument passed here is 1, although the earlier comment
# spoke of generating 10 questions; confirm which count is intended.
eval_questions = question_generator.generate_questions_from_nodes(1)

# Print the generated questions, numbered from 1
for number, generated in enumerate(eval_questions, 1):
    print(f"问题 {number}: {generated}")
# Objective function
def objective_function(params_dict, documents, questions):
    """Score one parameter combination by faithfulness of the answers.

    Args:
        params_dict: Mapping with the keys ``"chunk_size"``, ``"top_k"`` and
            ``"temperature"``.
        documents: Documents to index.
        questions: Questions to ask the query engine.

    Returns:
        A ``RunResult`` whose ``score`` is the share of ``questions`` whose
        response passed the ``FaithfulnessEvaluator`` (0 when there are no
        questions), whose ``params`` is ``params_dict``, and whose
        ``metadata["question_answers"]`` lists the ``(question, answer)``
        pairs that were collected.
    """
    chunk_size = params_dict["chunk_size"]
    top_k = params_dict["top_k"]
    temperature = params_dict["temperature"]

    # Build the index
    vector_index = _build_index(chunk_size, documents)

    # Query engine
    # NOTE(review): temperature is forwarded as a plain keyword argument;
    # confirm that it actually reaches the LLM, otherwise tuning it is a no-op.
    query_engine = vector_index.as_query_engine(
        similarity_top_k=top_k, temperature=temperature
    )

    # Evaluate the query engine. The evaluator is created once instead of once
    # per question.
    evaluator = FaithfulnessEvaluator()
    correct, total = 0, len(questions)
    question_answers = []  # collects the (question, answer) pairs
    for question in questions:
        response = query_engine.query(question)
        if response is None:
            continue
        question_answers.append((question, response.response))
        eval_result = evaluator.evaluate_response(response=response, query_str=question)
        if eval_result.passing:
            correct += 1

    # Compute the score
    score = correct / total if total > 0 else 0

    # RunResult declares score, params and metadata. Passing question_answers
    # as its own keyword does not store it on the result, so it goes into
    # metadata instead.
    return RunResult(
        score=score,
        params=params_dict,
        metadata={"question_answers": question_answers},
    )
# Create the ParamTuner instance
param_tuner = ParamTuner(
    # Each trial scores one parameter combination against the shared
    # documents and generated questions.
    param_fn=lambda params_dict: objective_function(
        params_dict, documents, eval_questions
    ),
    param_dict=param_dict,
    show_progress=True,
)

# Run the tuning and pick out the best run
results = param_tuner.tune()
best_result = results.best_run_result
best_params = best_result.params
best_top_k = best_params["top_k"]
best_chunk_size = best_params["chunk_size"]
best_temperature = best_params["temperature"]

# Report the best score and the parameters that produced it
for report_line in (
    f"得分: {best_result.score}",
    f"Top-k: {best_top_k}",
    f"文本块大小: {best_chunk_size}",
    f"温度: {best_temperature}",
):
    print(report_line)
# Run the query engine again with the best parameters and print the
# questions together with their answers
best_vector_index = _build_index(best_chunk_size, documents)
best_query_engine = best_vector_index.as_query_engine(
    similarity_top_k=best_top_k,
    temperature=best_temperature,
)

best_question_answers = []
for question in eval_questions:
    response = best_query_engine.query(question)
    if response is None:
        # No response object for this question; leave it out of the report.
        continue
    best_question_answers.append((question, response.response))

# Print the questions and answers obtained with the best parameters
for number, pair in enumerate(best_question_answers, 1):
    asked, replied = pair
    print(f"最佳参数 - 问题 {number}: {asked}\n答案: {replied}\n")
+81
View File
@@ -0,0 +1,81 @@
from app.observability import init_observability
from app.settings import init_settings
from dotenv import load_dotenv
import nest_asyncio
nest_asyncio.apply()
load_dotenv()
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import (
VectorStoreIndex,
SimpleDirectoryReader,
Response,
)
from llama_index.core.evaluation import (
FaithfulnessEvaluator,
DatasetGenerator,
CorrectnessEvaluator,
SemanticSimilarityEvaluator,)
init_settings()
init_observability()
faith_evaluator_qwen = FaithfulnessEvaluator()  # faithfulness evaluation
corr_evaluator_qwen = CorrectnessEvaluator()  # correctness evaluation
Seman_evaluator_qwen = SemanticSimilarityEvaluator()  # embedding-similarity evaluation
# NOTE(review): hard-coded absolute Windows path; the script only runs on a
# machine that has this directory.
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
# Index the documents, split by a SentenceSplitter with chunk_size=512.
splitter = SentenceSplitter(chunk_size=512)
vector_index = VectorStoreIndex.from_documents(
    documents, transformations=[splitter],
)
# # Run the evaluation (disabled single-question example)
# query_engine = vector_index.as_query_engine()
# response_vector = query_engine.query("工程监理费的金额是多少?")
# eval_result = evaluator_qwen.evaluate_response(response=response_vector)
# print(response_vector)
# print(eval_result)
# Generate evaluation questions from the loaded documents and print them.
question_generator = DatasetGenerator.from_documents(documents)
eval_questions = question_generator.generate_questions_from_nodes(5)
print(eval_questions)
import asyncio
async def evaluate_query_engine_async(query_engine, questions):
    """Ask every question concurrently and count the faithful responses.

    Args:
        query_engine: Object exposing an awaitable ``aquery(question)``.
        questions: Iterable of questions to send.

    Returns:
        Tuple ``(passed, total)``: the number of responses that
        ``faith_evaluator_qwen`` marked as passing and the number of
        responses gathered.
    """
    pending = [query_engine.aquery(question) for question in questions]
    responses = await asyncio.gather(*pending)

    passed = 0
    for item in responses:
        if faith_evaluator_qwen.evaluate_response(response=item).passing:
            passed += 1
    return passed, len(responses)
def evaluate_query_engine(query_engine, questions):
    """Run the asynchronous evaluation to completion from synchronous code.

    Args:
        query_engine: Query engine passed on to the async evaluator.
        questions: Questions passed on to the async evaluator.

    Returns:
        Tuple ``(correct, total)`` as returned by
        ``evaluate_query_engine_async``.
    """
    # asyncio.run owns the loop's lifecycle. The previous
    # get_event_loop().run_until_complete() pair relied on
    # asyncio.get_event_loop() creating a loop implicitly, which is deprecated.
    correct, total = asyncio.run(
        evaluate_query_engine_async(query_engine, questions)
    )
    return correct, total
# Score the default vector query engine with evaluate_query_engine, using at
# most the first five generated questions, and print "passed/total".
vector_query_engine = vector_index.as_query_engine()
correct, total = evaluate_query_engine(vector_query_engine, eval_questions[:5])
print(f"score: {correct}/{total}")
+121
View File
@@ -0,0 +1,121 @@
from dotenv import load_dotenv
load_dotenv()
from llama_index.core.evaluation import CorrectnessEvaluator
from app.engine import get_chat_engine
from app.engine.index import get_index
from app.observability import init_observability
from app.settings import init_settings
init_settings()
init_observability()
index = get_index()
import os
import json
import asyncio
import nest_asyncio
nest_asyncio.apply()
from llama_index.core.prompts import (
ChatMessage,
ChatPromptTemplate,
MessageRole
)
# System prompt (Chinese) for a correctness judge: it asks for a single score
# from 1 to 5 on one line, followed by the reasoning on a separate line.
DEFAULT_SYSTEM_TEMPLATE = """
您是一个问答聊天机器人的专业评估系统。
您将获得以下信息:
- 用户查询,
- 生成的回答,
也可能提供一个参考答案作为评估的依据。
您的任务是判断生成回答的相关性和正确性。
输出一个代表全面评估的单一分数。
您必须在一行中仅返回该分数。
不要以其他任何格式返回答案。
在单独的一行提供给定分数的理由。
请遵循以下评分指南:
- 您的分数必须在1到5之间,其中1是最差,5是最好的。
-如果生成的回答与用户查询不相关,您应该给出1分。
-如果生成的回答相关但包含错误,您应该给出2到3分之间的分数。
-如果生成的回答相关且完全正确,您应该给出4到5分之间的分数。
示例响应:
4.0
生成的回答与参考答案的指标完全相同,但不够精炼。
"""
# User-message template with the placeholders {query}, {reference_answer}
# and {generated_answer}.
DEFAULT_USER_TEMPLATE = """
## User Query
{query}
## Reference Answer
{reference_answer}
## Generated Answer
{generated_answer}
"""
# Chat prompt made of the system template followed by the user template.
# NOTE(review): nothing in this script references DEFAULT_EVAL_TEMPLATE after
# it is built; confirm whether it should be handed to the evaluator.
DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
    message_templates=[
        ChatMessage(role=MessageRole.SYSTEM, content=DEFAULT_SYSTEM_TEMPLATE),
        ChatMessage(role=MessageRole.USER, content=DEFAULT_USER_TEMPLATE),
    ]
)
# Initialize the chat engine and the evaluator
chat_engine = get_chat_engine()
# Hand the evaluator the Chinese prompt built in DEFAULT_EVAL_TEMPLATE.
# Previously the template was constructed but never used, so the evaluator
# fell back to its built-in default prompt.
corr_evaluator_qwen = CorrectnessEvaluator(eval_template=DEFAULT_EVAL_TEMPLATE)
# Load the local question/answer file that sits next to this script
script_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(script_dir, 'questions_and_answers.json')
# Derive "<name>_test.json" from the file name only. str.replace('.json', ...)
# would also rewrite any other ".json" occurring in the absolute path, for
# example inside a directory name.
path_stem, path_ext = os.path.splitext(file_path)
output_file_path = f"{path_stem}_test{path_ext}"
with open(file_path, 'r', encoding='utf-8') as f:
    data = json.load(f)
# Async function that evaluates a single query
async def evaluate_query(question, answer, index, output_file):
    """Ask the chat engine one question, score the result and append it to a file.

    Args:
        question: The question sent to the chat engine.
        answer: The reference answer used by the correctness evaluator.
        index: Sequence number stored with the record.
        output_file: Path of the file the record is appended to.

    Returns:
        The record that was written, as a dict.
    """
    response = await chat_engine.astream_chat(question)

    # The text that gets scored is the first entry of response.sources, or a
    # placeholder when there are no sources.
    # NOTE(review): this scores the first source rather than the generated
    # reply text; confirm that this is intended.
    content_str = str(response.sources[0]) if response.sources else "<无回答>"

    # Await the async evaluator API instead of calling the blocking
    # evaluate() from inside a coroutine.
    result = await corr_evaluator_qwen.aevaluate(
        query=question,
        response=content_str,
        reference=answer,
    )

    result_dict = {
        "编号": index,
        "问题": question,
        "答案": answer,
        "回答": result.response,
        "得分(1~5)": result.score,
        "评价": result.feedback
    }

    # NOTE(review): records are appended as "<object>,\n", so the output file
    # as a whole is not a valid JSON document; the format is kept unchanged.
    with open(output_file, 'a', encoding='utf-8') as f:
        f.write(json.dumps(result_dict, ensure_ascii=False, indent=4))
        f.write(',\n')
    return result_dict
# Main async function
async def main():
    """Evaluate every loaded question/answer record in order, numbered from 1."""
    for position, record in enumerate(data, 1):
        await evaluate_query(
            record['question'],
            record['answer'],
            position,
            output_file_path,
        )


# Run the main coroutine
asyncio.run(main())
+55
View File
@@ -0,0 +1,55 @@
# Question-generation prompt about the attribute column of the data.
# Contains the {num_questions_per_chunk} placeholder.
# NOTE(review): the examples separate question and answer with a full-width
# comma, while the stated format ('问题, 回答') uses an ASCII comma.
Attribute_Prompt = (
    "你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
    "你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
    "现在需要你针对数据中属性一列进行提问和回答。"
    "问题和回答的示例应该是这种类型的,示例:'工程总投资(万元),工程总投资(万元)是77469835.590045万元','尖峰及施工基面土石方量,尖峰及施工基面土石方量是8377.6','截止阀的编码,截止阀的编码是F01010203',"
    "你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
    "这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
    "不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
# Question-generation prompt about amounts or total prices in the context.
# Contains the {num_questions_per_chunk} placeholder.
Amount_Prompt = (
    "你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
    "你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
    "现在需要你针对上下文信息中的金额或者合价进行提问和回答。"
    "问题和回答的示例应该是这种类型的,示例:'项目建设技术服务费的金额,项目建设技术服务费的金额是16855957065.4302','项目后评价费的费率,项目后评价费的费率是0.5','架空输电线路本体工程的金额,架空输电线路本体工程的金额是55105688268.5176','工程静态投资的金额,工程静态投资的金额是715035853336.391'"
    "你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
    "这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
    "不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
# Question-generation prompt about unit conversion.
# Contains the {num_questions_per_chunk} placeholder.
Units_Prompt = (
    "你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
    "你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
    "现在需要你针对上下文信息来进行单位转化问题提问和回答。"
    "问题和回答的示例应该是这种类型的,示例:'工程总投资(万元)结果用元表示,工程总投资(万元)是774698355900.45元','本体工程(元)结果用万元表示,本体工程(元)是5490494.261046万元'"
    "你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
    "这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
    "不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
# Question-generation prompt about items that share the same name.
# Contains the {num_questions_per_chunk} placeholder.
# NOTE(review): in the second example the answer names a different item than
# the question; confirm that this is intended.
Name_Prompt = (
    "你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
    "你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
    "现在需要你针对上下文信息中的重名问题进行提问和回答。"
    "问题和回答的示例应该是这种类型的,示例:'专业类型为线路的杆塔工程项目划分的合价,专业类型为线路的杆塔工程项目划分的合价是220969744.905856','专业类型为线路清理的杆塔工程项目划分的合价,电缆工程的合价是0'"
    "你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
    "这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
    "不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
# Question-generation prompt about overall amounts in the context.
# Contains the {num_questions_per_chunk} placeholder.
All_Amount_Prompt = (
    "你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
    "你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
    "现在需要你针对上下文信息中的总体金额进行提问和回答。"
    "问题和回答的示例应该是这种类型的,示例:'架空输电线路本体工程的总体金额,架空输电线路本体工程的总体金额是7.706703','工程静态投资的总体金额,工程静态投资的总体金额是100'"
    "你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
    "这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
    "不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
+144
View File
@@ -0,0 +1,144 @@
from dotenv import load_dotenv
load_dotenv()
import json
import sys
from app.observability import init_observability
from app.settings import init_settings
import nest_asyncio
nest_asyncio.apply()
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import SimpleDirectoryReader
from llama_index.core.evaluation import DatasetGenerator
import prompts
init_settings()
init_observability()
# Load all documents (i.e. all tables)
# NOTE(review): hard-coded absolute Windows path; the script only runs on a
# machine that has this directory.
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
# Map each table name to its position in `documents`
# NOTE(review): presumably the reader returns the tables in exactly this
# order; verify against the contents of the data directory.
table_names = {
    "工程信息表": 0,
    "其他费用表": 1,
    "取费表": 2,
    "项目划分表": 3,
    "项目划分_费用预览表": 4,
    "总算表": 5,
    "工程量表": 6
}
# Map each Chinese prompt-category label to the name of the prompt constant
# looked up on the `prompts` module
prompt_mapping = {
    "普通属性": "Attribute_Prompt",
    "金额查询": "Amount_Prompt",
    "单位换算": "Units_Prompt",
    "重名项目划分": "Name_Prompt",
    "总体金额查询": "All_Amount_Prompt"
}
# Map each table to the query categories that apply to it
# NOTE(review): two tables listed in table_names ("项目划分表" and
# "项目划分_费用预览表") have no entry here; confirm this is intended.
table_prompt_mapping = {
    "工程信息表": ["普通属性", "单位换算"],
    "其他费用表": ["金额查询", "单位换算"],
    "取费表": ["金额查询"],
    "总算表": ["金额查询", "总体金额查询"],
    "工程量表": ["普通属性", "重名项目划分"]
}
# Pick one specific table by its name
def select_document(documents, table_name):
    """Return the document that corresponds to ``table_name``.

    Args:
        documents: Sequence of loaded documents, indexed by position.
        table_name: A key of ``table_names``.

    Returns:
        A one-element list holding the selected document.

    Raises:
        ValueError: If ``table_name`` is not a key of ``table_names``.
    """
    position = table_names.get(table_name)
    if position is None:
        raise ValueError(f"未找到名为 '{table_name}' 的表格")
    # Wrap the single document in a list for the caller.
    return [documents[position]]
# Pick the prompt text
def select_prompt(prompt_category):
    """Return the prompt constant registered for ``prompt_category``.

    Args:
        prompt_category: A key of ``prompt_mapping``.

    Returns:
        The attribute of the ``prompts`` module named by the mapping.

    Raises:
        ValueError: If the category is unknown, or if the ``prompts`` module
            has no attribute with the mapped name.
    """
    prompt_name = prompt_mapping.get(prompt_category)
    if prompt_name:
        try:
            prompt_text = getattr(prompts, prompt_name)
        except AttributeError:
            raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数")
        return prompt_text
    raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
# Generate questions and answers
def generate_questions_from_document(document, quest_prompt, num_questions):
    """Generate question/answer pairs from ``document`` using ``quest_prompt``.

    Each generated item is expected to hold a question and its answer
    separated by a comma. Both the ASCII comma and the full-width comma are
    accepted, and the item is split at whichever occurs first. The prompt
    examples use the full-width form, which the previous ASCII-only split
    could not handle.

    Args:
        document: Documents passed to ``DatasetGenerator.from_documents``.
        quest_prompt: Passed as ``question_gen_query``.
        num_questions: Used both as ``num_questions_per_chunk`` and as the
            argument of ``generate_questions_from_nodes``.

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts. Items without
        any separator are reported on stdout and left out.
    """
    question_generator = DatasetGenerator.from_documents(
        documents=document,
        question_gen_query=quest_prompt,
        num_questions_per_chunk=num_questions,
    )
    eval_questions = question_generator.generate_questions_from_nodes(num_questions)
    print(eval_questions)

    qa_pairs = []
    for qa in eval_questions:
        # Positions of the first ASCII and the first full-width comma, if any.
        cut_points = [pos for pos in (qa.find(","), qa.find(",")) if pos != -1]
        if not cut_points:
            print(f"无法处理的问题和答案: {qa}")
            continue
        cut = min(cut_points)
        qa_pairs.append({
            "question": qa[:cut].strip(),
            "answer": qa[cut + 1:].strip(),
        })
    return qa_pairs
# Main function: generates questions for several tables with several prompts
# and merges all results into a single file
def main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt):
    """Generate QA pairs for the selected tables and prompts and save them.

    Args:
        documents: The loaded documents, passed on to ``select_document``.
        table_names_input: ``"all"`` or a comma-separated list of table
            names, optionally wrapped in ``[]``.
        prompt_categories_input: ``"all"`` or a comma-separated list of
            prompt categories, optionally wrapped in ``[]``.
        num_questions_per_prompt: Number handed to the prompt template and to
            the question generator.

    The results are written to ``combined_test.json`` in the current working
    directory, keyed by ``"test:<table>_<category>"``.
    """
    if table_names_input == "all":
        selected_tables = list(table_prompt_mapping.keys())
    else:
        selected_tables = table_names_input.strip('[]').split(',')

    all_results = {}

    for table_name in selected_tables:
        table_name = table_name.strip()  # drop surrounding whitespace
        document = select_document(documents, table_name)

        # Some known tables have no query categories configured. Look them up
        # with a default so they are skipped instead of raising KeyError.
        allowed_prompts = table_prompt_mapping.get(table_name, [])
        if not allowed_prompts:
            print(f"跳过表格 '{table_name}',因为没有为该表配置查询类别")
            continue

        if prompt_categories_input == "all":
            selected_prompts = allowed_prompts
        else:
            selected_prompts = [
                p.strip() for p in prompt_categories_input.strip('[]').split(',')
            ]

        for prompt_category in selected_prompts:
            if prompt_category not in allowed_prompts:
                print(f"跳过表格 '{table_name}' 的提示词 '{prompt_category}',因为该表中不包含该类别的信息")
                continue

            quest_prompt = select_prompt(prompt_category).format(num_questions_per_chunk=num_questions_per_prompt)
            qa_pairs = generate_questions_from_document(document, quest_prompt, num_questions_per_prompt)

            label = f"test:{table_name}_{prompt_category}"
            all_results[label] = qa_pairs

    # Fixed output file name
    output_file = "combined_test.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=4)

    print(f"All questions and answers have been saved to '{output_file}'")
# Read the command-line arguments
if __name__ == "__main__":
    # Exactly three arguments are expected after the script name.
    if len(sys.argv) == 4:
        table_names_input, prompt_categories_input = sys.argv[1:3]
        num_questions_per_prompt = int(sys.argv[3])
        main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt)
    else:
        print("Usage: python script.py <table_names_input> <prompt_categories_input> <num_questions_per_prompt>")
+3 -2
View File
@@ -1,9 +1,10 @@
import os import os
from dotenv import load_dotenv
load_dotenv()
import phoenix as px import phoenix as px
os.environ['PHOENIX_HOST'] = "0.0.0.0"
session = px.launch_app(use_temp_dir=False) session = px.launch_app(use_temp_dir=False)
import msvcrt import msvcrt