29 Commits

Author SHA1 Message Date
chentianrui e634746a52 修改了单元测试的问题生成代码 2024-09-06 18:43:57 +08:00
chentianrui d12800e14e 修改了单元测试的问题生成代码 2024-09-06 18:40:54 +08:00
chentianrui c1df0d1bba Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-09-05 17:03:29 +08:00
chentianrui 0664952ecd 增加了问题生成脚本 2024-09-05 17:02:42 +08:00
ly 7023b54246 解决xinference内嵌模型类使用问题。由于目前xinference组件的版本和llamaindex最新版有冲突,所以未更新支持xinference的内嵌模型的版本 2024-09-05 12:12:31 +08:00
ly aee6aa3c04 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-09-05 11:36:53 +08:00
chentianrui 680e24c516 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-30 18:40:32 +08:00
chentianrui 6663ee8976 新增加了单元测试 2024-08-30 18:40:21 +08:00
ly 0a5f335981 调整NLTK数据目录和JIEBA字典位置到本项目中,避免重新安装时需要从网上下载 2024-08-30 01:20:29 +08:00
ly 2901bd9eaf 优化导入,解决初始化LLAMAINDEX过程中环境变量没起作用问题 2024-08-30 01:16:35 +08:00
ly 453b3ca55c Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-30 00:00:57 +08:00
wanyaokun 03c4eb1af1 优化ChatCallbackEvent事件代码 2024-08-29 19:52:53 +08:00
wanyaokun 480a1f7fdc 新增工程配置信息 2024-08-29 19:03:38 +08:00
wanyaokun cdc9d84a1e Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-29 19:01:40 +08:00
wanyaokun 50f35bb0c9 优化Web事件代码 2024-08-29 19:00:25 +08:00
chentianrui 4a8c79e83d 参数优化针对问题做出了调整 2024-08-29 15:09:55 +08:00
ly f0afd1a4bb Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-29 12:03:28 +08:00
chentianrui de34c3938c 增加了参数评估 2024-08-29 12:02:53 +08:00
ly eb572eff27 增加加载环境变量功能 2024-08-29 11:54:20 +08:00
chentianrui 2706cf9d5a 更新了依赖包 2024-08-29 11:41:42 +08:00
chentianrui 5fa4752d6e Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-29 11:39:06 +08:00
chentianrui aff1793c4e 新增了参数评估脚本和评分脚本 2024-08-29 11:38:45 +08:00
ly 0db159ac89 增加新的前端子模块 2024-08-29 10:48:40 +08:00
ly 131d6ef1d1 完善接口,实现对DIFY前端消息流传输的支持 2024-08-29 08:26:59 +08:00
chentianrui 3ee1ba529f Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-28 18:12:37 +08:00
chentianrui 576a2ae737 增加了评估脚本 2024-08-28 18:12:28 +08:00
ly 9b47e1a6e1 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-28 17:41:52 +08:00
ly 0f09551f5d Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-28 11:49:22 +08:00
ly 56459c164e 配置文件增加UTF8编码格式支持,以免解析中文时出现问题 2024-08-28 08:04:01 +08:00
26 changed files with 350014 additions and 328 deletions
+3
View File
@@ -0,0 +1,3 @@
[submodule "webapp"]
path = webapp
url = https://git.97id.com/ly/webapp.git
+6
View File
@@ -1,3 +1,8 @@
JIEBA_DATA=./nltk_data
NLTK_DATA=./nltk_data
SQLITE_DATABASE_URL=sqlite:///./source.db
DATA_SOURCE_CACHE=./restapi
# The Llama Cloud API key.
# LLAMA_CLOUD_API_KEY=
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
@@ -80,3 +85,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
- You can install any pip package (if it exists) by running a cell with pip install.
"
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
+6
View File
@@ -1,3 +1,8 @@
JIEBA_DATA=./nltk_data
NLTK_DATA=./nltk_data
SQLITE_DATABASE_URL=sqlite:///./source.db
DATA_SOURCE_CACHE=./restapi
# The Llama Cloud API key.
# LLAMA_CLOUD_API_KEY=
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
@@ -111,3 +116,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
- You can install any pip package (if it exists) by running a cell with pip install.
"
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
+259 -245
View File
@@ -1,7 +1,9 @@
import asyncio
import json
import logging
import time
from typing import Dict, List, Any, Optional, AsyncGenerator
from collections import deque
from aiostream import stream
from fastapi import APIRouter, Request
@@ -12,7 +14,8 @@ from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput
from pydantic import BaseModel
from app.api.routers.request.base import userMng, conversations,message,parameter
from app.api.routers.request.base import userMng, conversations,message,parameter,feedback
from app.api.routers.request.baseConfig import *
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine
import uuid
@@ -22,81 +25,139 @@ logger = logging.getLogger("uvicorn")
api_router = r = APIRouter()
v1_router = v = APIRouter()
default_conversation_id = '82e8417f-2c3b-4bb5-ab22-2ad318bbd29a'
class ChatCallbackEvent(BaseModel):
event_type: CBEventType
event_type: ChatEventType
payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_retrieval_message(self) -> dict | None:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"根据查询检索到 {len(nodes)} 源文件"
else:
msg = f"查询检索中: '{self.payload.get('query_str')}'"
return {
"type": "events",
"data": {"title": msg},
}
else:
return None
def get_common_param(self)-> dict:
return {
'event': self.event_type.name,
'conversation_id':self.payload.get("conversation_id"),
'message_id': self.payload.get("message_id"),
'created_at': int(time.time()),
'task_id': self.payload.get("task_id")
}
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
def get_WorkflowStart_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('workflow_run_id'),
"workflow_id": self.payload.get('workflow_id'),
"sequence_number": 1709,
"inputs": {
"sys.query": self.payload.get('query'),
"sys.files": [],
"sys.conversation_id": self.payload.get('conversation_id'),
"sys.user_id": self.payload.get('use_id')
},
"created_at": int(time.time())
}
})
return params
def _is_output_serializable(self, output: Any) -> bool:
try:
json.dumps(output)
return True
except TypeError:
return False
def get_WorkflowFinished_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('workflow_run_id'),
"workflow_id": self.payload.get('workflow_id'),
"sequence_number": 1709,
"status": "succeeded",
"outputs": {
"answer": self.payload.get('response')
},
"error": '',
"elapsed_time": 36.03764106379822,
"total_tokens": 11707,
"total_steps": 10,
"created_by": {
"id": str(uuid.uuid4()),
"user": self.payload.get('use_id')
},
"created_at": int(time.time()),
"finished_at": int(time.time()),
"files": []
}
})
return params
def get_NodeStart_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": self.payload.get('title'),
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
"created_at": 1724398751,
"extras": {}
}
})
return params
def get_agent_tool_response(self) -> dict | None:
response = self.payload.get("response")
if response is not None:
sources = response.sources
for source in sources:
# Return the tool response here to include the toolCall information
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
def get_NodeFinished_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": self.payload.get('title'),
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
"process_data": '',
"outputs": '',
"status": "succeeded",
"error": '',
"elapsed_time": 0.10402441816404462,
"execution_metadata": '',
"created_at": 1724398751,
"finished_at": 1724398751,
"files": []
}
})
return params
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
},
}
def get_Message_param(self) -> dict:
params = self.get_common_param()
params.update({
'id':self.payload.get('message_id'),
'answer':self.payload.get('answer')
})
return params
def get_MessageEnd_param(self) -> dict:
params = self.get_common_param()
params.update({
'id':self.payload.get('message_id'),
'metadata':self.payload.get('metadata')
})
return params
def to_response(self):
def to_response(self)-> dict|None:
try:
match self.event_type:
case "retrieve":
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case "workflow_started":
return self.get_WorkflowStart_param()
case "workflow_finished":
return self.get_WorkflowFinished_param()
case "node_started":
return self.get_NodeStart_param()
case 'node_finished':
return self.get_NodeFinished_param()
case 'message':
return self.get_Message_param()
case 'message_end':
return self.get_MessageEnd_param()
case _:
return None
except Exception as e:
@@ -107,19 +168,34 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue
is_done: bool = False
def __init__(
self,
):
def __init__(self,**params):
"""Initialize the base callback handler."""
ignored_events = [
CBEventType.CHUNKING,
CBEventType.NODE_PARSING,
CBEventType.EMBEDDING,
CBEventType.LLM,
CBEventType.TEMPLATING,
# CBEventType.CHUNKING,
# CBEventType.NODE_PARSING,
# CBEventType.EMBEDDING,
# CBEventType.LLM,
# CBEventType.TEMPLATING,
]
super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue()
self._response:str = ''
self._params:Dict[str,Any] = params
self._nodeStack:deque = deque()
#添加工作流开始事件
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'use_id': data.user,
'query': data.query,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
def on_event_start(
self,
@@ -128,9 +204,23 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
event_id: str = "",
**kwargs: Any,
) -> str:
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
self._nodeStack.append(event_id)
nindex = self._nodeStack.count() - 1
args:Dict[str,Any] = self._params['ids']
args.update(
{
'nodeid':event_id,
'title':event_type.name,
'index':nindex + 1,
'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else ''
}
)
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_START,payload = args)
if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event)
def on_event_end(
self,
@@ -139,12 +229,30 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
event_id: str = "",
**kwargs: Any,
) -> None:
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
#self.response = payload.get("response","")
args:Dict[str,Any] = self._params['ids']
nodeID = self._nodeStack[-1]
if nodeID == event_id:
nindex = self._nodeStack.count() - 1
args.update(
{
'nodeid':event_id,
'title':event_type.name,
'index':nindex + 1,
'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else ''
}
)
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_FINISHED,payload = args)
if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event)
self._nodeStack.pop()
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op."""
logger.info("trace_start:{}\n".format(trace_id))
def end_trace(
self,
@@ -152,6 +260,24 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
data:ChatRequestData = self._params['data']
args:Dict[str,Any] = self._params['ids']
args.update(
{
'response':self._response,
'conversation_id': data.conversation_id
}
)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
args:Dict[str,Any] = self._params['ids']
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
if msgEnt_event.to_response() is not None:
self._aqueue.put_nowait(msgEnt_event)
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
while not self._aqueue.empty() or not self.is_done:
@@ -169,104 +295,38 @@ class IDManager:
"workflow_id": str(uuid.uuid4())
}
class DifyChatResponseEvent(BaseModel):
event: str
conversation_id: str
message_id: str
created_at: int = 1724406492
task_id: str
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
"sequence_number": 1709,
"inputs": {
"sys.query": args['query'],
"sys.files": [],
"sys.conversation_id": args['conversation_id'],
"sys.user_id": args['use_id']
},
"created_at": 1724406492
}
args['event'] = 'workflow_started'
super().__init__(**args)
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['event'] = 'workflow_finished'
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
"sequence_number": 1709,
"status": "succeeded",
"outputs": {
"answer": args['response']
},
"error": '',
"elapsed_time": 36.03764106379822,
"total_tokens": 11707,
"total_steps": 10,
"created_by": {
"id": str(uuid.uuid4()),
"user": args['use_id']
},
"created_at": 1724406492,
"finished_at": 1724406528,
"files": []
}
super().__init__(**args)
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
id:str
answer:str
def __init__(self,**args):
args['id'] = args['message_id']
args['event'] = 'message'
super().__init__(**args)
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
id:str
metadata:Dict[str,Any] = {}
def __init__(self,**args):
args['id'] = args['message_id']
args['event'] = 'message_end'
super().__init__(**args)
class ChatStreamResponse(StreamingResponse):
TEXT_PREFIX = "data:"
DATA_PREFIX = "data:"
TEXT_PREFIX = "data: "
DATA_PREFIX = "data: "
ids:Dict[str,Any] = {}
data:ChatRequestData = None
@classmethod
def convert_text(cls, token: str):
# Escape newlines and double quotes to avoid breaking the stream
token = json.dumps(token)
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}"
return ""
def convert_Message(cls, token: str):
params = cls.ids
params.update({
'answer':token,
'conversation_id':cls.data.conversation_id
})
event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE,payload = params)
data_str = json.dumps(event.to_response())
return f"{cls.DATA_PREFIX}{data_str}\n\n"
@classmethod
def convert_data(cls, data: dict):
def convert_Event(cls, data: dict):
data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}{data_str}\n"
@classmethod
def convert_event(cls, event: DifyChatResponseEvent):
data_str = json.dumps(event.dict())
return f"{cls.DATA_PREFIX}{data_str}\n"
return f"{cls.DATA_PREFIX}{data_str}\n\n"
def __init__(
self,
request: Request,
event_handler: ChatEventCallbackHandler,
response: StreamingAgentChatResponse,
data: ChatRequestData
data: ChatRequestData,
ids:Dict[str,Any]
):
ChatStreamResponse.ids = ids
ChatStreamResponse.data = data
content = ChatStreamResponse.content_generator(
request, event_handler, response, data
)
@@ -280,41 +340,26 @@ class ChatStreamResponse(StreamingResponse):
response: StreamingAgentChatResponse,
data: ChatRequestData
):
ids = IDManager().createID()
# Yield the text response
async def _chat_response_generator():
final_response = ""
async for token in response.async_response_gen():
final_response += token
args = ids
args['answer'] = token
args['conversation_id'] = data.conversation_id
event = Message_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(event)
#yield ChatStreamResponse.convert_text(token)
yield ChatStreamResponse.convert_Message(token)
# 存储消息历史
message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response)
# the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True
# 发送工作流结束事件
args = ids
args['response'] = final_response
args['conversation_id'] = data.conversation_id
wf_event = Workflow_finished_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(wf_event)
msgEnt_event = MessageEnd_DifyChatResponseEvent(**ids)
yield ChatStreamResponse.convert_event(msgEnt_event)
# Yield the events from the event handler
async def _event_generator():
async for event in event_handler.async_event_gen():
event_response = event.to_response()
if event_response is not None:
yield ChatStreamResponse.convert_data(event_response)
yield ChatStreamResponse.convert_Event(event_response)
combine = stream.merge(_chat_response_generator(), _event_generator())
is_stream_started = False
@@ -323,34 +368,20 @@ class ChatStreamResponse(StreamingResponse):
if not is_stream_started:
is_stream_started = True
# 发送工作流开始事件
args = ids
args['use_id'] = data.user
args['query'] = data.query
args['conversation_id'] = data.conversation_id
wf_event = Workflow_started_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(wf_event)
# Stream a blank message to start the stream
# 发送一个空消息事件
#yield ChatStreamResponse.convert_text("")
yield output
if await request.is_disconnected():
break
@v.post("/chat-messages")
async def post_conversations(request: Request, data: ChatRequestData):
userMng.findNoExistCreate(data.user)
data.conversation_id = default_conversation_id if data.conversation_id is None else data.conversation_id
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
conversaObj = conversations()
conversationinfo = conversaObj.get(data.user, data.conversation_id)
conversationinfo = conversaObj.get(data.conversation_id)
if conversationinfo is None:
conversationinfo = conversaObj.add(data.user, "新建会话", data.conversation_id)
conversationinfo = conversaObj.add(data.conversation_id, data.user, "新建会话")
# 生成聊天参数
last_message_content = ChatMessage.from_str(data.query)
@@ -361,24 +392,33 @@ async def post_conversations(request: Request, data: ChatRequestData):
chat_engine = get_chat_engine(filters=filters, params=params)
# 启动聊天事件监听
event_handler = ChatEventCallbackHandler()
ids = IDManager().createID()
event_handler = ChatEventCallbackHandler(ids = ids,data = data)
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
# 执行异步聊天
response = await chat_engine.astream_chat(data.query)
# 返回异步消息回应
return ChatStreamResponse(request, event_handler, response, data)
return ChatStreamResponse(request, event_handler, response, data,ids)
@v.get("/messages")
async def query_messages(user:str, conversation_id:str):
conversation_id = default_conversation_id if conversation_id is None else conversation_id
#conversation_id = default_conversation_id if conversation_id is None else conversation_id
datas = []
records = message().gets(user,conversation_id)
if records is None:
return {
"limit": 20,
"has_more": False,
"data": []
}
for record in records:
res = record.dict()
feeds = feedback().query(res['id'])
res["message_files"] = []
res["feedback"] = ''
res["feedback"] = {'rating':feeds['rating'] } if feeds != None else ''
res["retriever_resources"] = []
res["created_at"] = 1723444905
res["agent_thoughts"] = []
@@ -415,7 +455,7 @@ async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]):
return 'null'
@v.get("/conversations")
async def query_conversations(user:str):
async def query_conversations(user:str, first_id:str = None, limit:str = None, pinned:str = None):
user_id = '' if user is None else user
userMng.findNoExistCreate(user_id)
@@ -429,48 +469,22 @@ async def query_conversations(user:str):
async def query_parameters(user:str):
params = parameter().get(user)
if len(params) == 0:
params = {
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
"suggested_questions": [],
"suggested_questions_after_answer": {
"enabled": False
},
"speech_to_text": {
"enabled": False
},
"text_to_speech": {
"enabled": False,
"language": "",
"voice": ""
},
"retriever_resource": {
"enabled": True
},
"annotation_reply": {
"enabled": False
},
"more_like_this": {
"enabled": False
},
"user_input_form": [],
"sensitive_word_avoidance": {
"enabled": False
},
"file_upload": {
"image": {
"enabled": False,
"number_limits": 3,
"transfer_methods": [
"remote_url"
]
}
},
"system_parameters": {
"image_file_size_limit": "10"
}
}
params = BaseConfig().ParamterCfg()
return params
@v.post("/messages/{message_id}/feedbacks")
async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]):
if params['rating'] =='null':
feedback().delete(message_id)
else:
condition = {'id':message_id}
results = message().query(**condition)
if len(results) > 0:
result = results[0]
feedback().add(message_id=message_id,query=result['query'],
answer=result['answer'],rating=params['rating'])
@r.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]:
pass
pass
+35 -5
View File
@@ -18,14 +18,14 @@ class conversations:
return datas
def get(self,user_id:str,id:str = ''):
records = dbManage.query(self._tableName,user_id = user_id,id=id)
def get(self, id:str):
records = dbManage.query(self._tableName, id=id)
if len(records) >0:
return records[0]
return None
def add(self,user_id:str,name:str,id:str = ''):
template = BaseConfig.ConversationCfg
def add(self,id:str, user_id:str, name:str):
template = BaseConfig().ConversationCfg()
template['id'] = id
template['user_id'] = user_id
template['name'] = name
@@ -111,7 +111,7 @@ class message:
return datas
def add(self,user_id:str,conversation_id:str,query:str,answer:str):
template = BaseConfig.MessageCfg
template = BaseConfig.MessageCfg()
template['id'] = str(uuid.uuid4())
template['user_id'] = user_id
template['conversation_id'] = conversation_id
@@ -122,4 +122,34 @@ class message:
def delete(self,user_id:str):
dbManage.delete(self._tableName,user_id = user_id)
def query(self,**condition):
results = []
records = dbManage.query(self._tableName,**condition)
for record in records:
results.append(record.dict())
return results
class feedback:
def __init__(self) -> None:
self._tableName = 'feedbacks'
dbManage.createTable(self._tableName)
def add(self,message_id:str,query:str,answer:str,rating:str):
record = {
'message_id': message_id,
'query': query,
'answer': answer,
'rating': rating,
}
dbManage.addRecord(self._tableName,record)
def delete(self,message_id:str):
cond = {'message_id':message_id}
dbManage.delete(self._tableName,**cond)
def query(self,message_id:str):
cond = {'message_id':message_id}
records = dbManage.query(self._tableName,**cond)
if len(records) > 0:
return records[0].dict()
return None
+69 -51
View File
@@ -1,62 +1,80 @@
from pydantic import BaseModel
import os
from enum import Enum
class BaseConfig:
ParamterCfg = {
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
"suggested_questions": [],
"suggested_questions_after_answer": {
"enabled": False
},
"speech_to_text": {
"enabled": False
},
"text_to_speech": {
"enabled": False,
"language": "",
"voice": ""
},
"retriever_resource": {
"enabled": True
},
"annotation_reply": {
"enabled": False
},
"more_like_this": {
"enabled": False
},
"user_input_form": [],
"sensitive_word_avoidance": {
"enabled": False
},
"file_upload": {
"image": {
class BaseConfig(BaseModel):
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
def ParamterCfg(self):
questions = os.getenv("CONVERSATION_STARTERS", "dev")
return{
"opening_statement": self.projectInfo,
"suggested_questions": questions.split('\n'),
"suggested_questions_after_answer": {
"enabled": False
},
"speech_to_text": {
"enabled": False
},
"text_to_speech": {
"enabled": False,
"number_limits": 3,
"transfer_methods": [
"remote_url"
]
"language": "",
"voice": ""
},
"retriever_resource": {
"enabled": True
},
"annotation_reply": {
"enabled": False
},
"more_like_this": {
"enabled": False
},
"user_input_form": [],
"sensitive_word_avoidance": {
"enabled": False
},
"file_upload": {
"image": {
"enabled": False,
"number_limits": 3,
"transfer_methods": [
"remote_url"
]
}
},
"system_parameters": {
"image_file_size_limit": "10"
}
},
"system_parameters": {
"image_file_size_limit": "10"
}
}
def ConversationCfg(self):
return{
"id": "",
'user_id':'',
"name": "",
"inputs": {},
"status": "normal",
"introduction": self.projectInfo,
"created_at":''
}
ConversationCfg = {
"id": "",
'user_id':'',
"name": "",
"inputs": {},
"status": "normal",
"introduction": ParamterCfg['opening_statement'],
"created_at":''
}
MessageCfg = {
@classmethod
def MessageCfg(cls):
return {
"id": "",
'user_id':'',
"conversation_id": "",
"inputs": {},
"query": "",
"answer": ""
}
}
class ChatEventType(str, Enum):
WORKFLOW_START = "workflow_started"
WORKFLOW_FINISHED = "workflow_finished"
NODE_START = "node_started"
NODE_FINISHED = "node_finished"
MESSAGE = "message"
MESSAGE_END = "message_end"
+22 -9
View File
@@ -2,7 +2,7 @@ import os
from typing import Dict, List, Any
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, String, Integer, JSON
from sqlalchemy import create_engine, Column, String, Integer, JSON,Float
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import sessionmaker, declarative_base
@@ -24,10 +24,6 @@ class ConversationOrm(Base):
if 'name' in data:
self.name = data['name']
class UserOrm(Base):
__tablename__ = "user"
@@ -51,6 +47,14 @@ class MessagesOrm(Base):
query = Column(String)
answer = Column(String)
class FeedBackOrm(Base):
__tablename__ = "feedbacks"
message_id = Column(String,primary_key=True)
query = Column(String)
answer = Column(String)
rating = Column(String)
#数据结构
class ConversationModel(BaseModel):
id: str
@@ -61,7 +65,6 @@ class ConversationModel(BaseModel):
created_at: int
class Config:
#orm_mode = True
from_attributes=True
@classmethod
@@ -73,7 +76,6 @@ class UserModel(BaseModel):
createtime: str
class Config:
#orm_mode = True
from_attributes=True
@classmethod
@@ -86,7 +88,6 @@ class ParametersModel(BaseModel):
value : Dict[str, Any]
class Config:
#orm_mode = True
from_attributes=True
@classmethod
@@ -101,13 +102,25 @@ class MessagesModel(BaseModel):
answer : str
class Config:
#orm_mode = True
from_attributes=True
@classmethod
def orm(cls):
return MessagesOrm
class FeedBackModel(BaseModel):
message_id :str
query :str
answer :str
rating :str
class Config:
from_attributes=True
@classmethod
def orm(cls):
return FeedBackOrm
class DBManager:
def __init__(self) -> None:
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
+4 -2
View File
@@ -1,7 +1,7 @@
from typing import Dict, Any
from pydantic import BaseModel
from typing import Optional
class ChatRequestData(BaseModel):
inputs: Dict[str,Any]
@@ -12,4 +12,6 @@ class ChatRequestData(BaseModel):
conversation_id: str = None
class ChatFileUploadRequest(BaseModel):
base64: str
base64: str
+1 -1
View File
@@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
def load_configs():
with open("config/loaders.yaml") as f:
with open("config/loaders.yaml",encoding='UTF-8') as f:
configs = yaml.safe_load(f)
return configs
+5 -1
View File
@@ -1,3 +1,4 @@
import os
from typing import Any, Dict, List, Union, Callable, NamedTuple
from bm25s.tokenization import *
@@ -8,9 +9,12 @@ except ImportError:
def tqdm(iterable, *args, **kwargs):
return iterable
import jieba
jiebapath = os.environ.get("JIEBA_DATA", "")
jieba.set_dictionary(os.path.join(jiebapath, 'dict.txt')) #设置字典
jieba.initialize() #初始化jeiba
def chinese_tokenizer(text: str) -> List[str]:
import jieba
from nltk.corpus import stopwords
tokens = jieba.lcut(text)
return [token for token in tokens if token not in stopwords.words('chinese')]
+5 -6
View File
@@ -1,10 +1,9 @@
import os
import yaml
import json
import importlib
from cachetools import cached, LRUCache
from llama_index.core.tools.tool_spec.base import BaseToolSpec
import os
import yaml
from llama_index.core.tools.function_tool import FunctionTool
from llama_index.core.tools.tool_spec.base import BaseToolSpec
class ToolType:
@@ -46,7 +45,7 @@ class ToolFactory:
def from_env() -> list[FunctionTool]:
tools = []
if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f:
with open("config/tools.yaml", "r", encoding='UTF-8') as f:
tool_configs = yaml.safe_load(f)
if tool_configs != None and len(tool_configs.items()) != 0:
for tool_type, config_entries in tool_configs.items():
+2 -3
View File
@@ -3,11 +3,10 @@ from typing import Dict
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
from llama_index.llms.xinference import Xinference
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
def get_node_postprocessors():
rerank_enabled = os.getenv("RERANK_ENABLED").title()
@@ -232,4 +231,4 @@ def init_mistral():
#
# Settings.llm = MistralAI(model=os.getenv("MODEL"))
# Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
pass
pass
-2
View File
@@ -1,7 +1,5 @@
from dotenv import load_dotenv
from llama_index.core.node_parser import SentenceSplitter
load_dotenv()
import logging
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large Load Diff
Binary file not shown.
+8 -1
View File
@@ -17,7 +17,7 @@ aiostream = "^0.6.2"
llama-index = "0.10.63"
cachetools = "^5.3.3"
protobuf = "4.25.4"
nltk = "^3.8.2"
nltk = "^3.9.1"
jieba = "^0.42.1"
#arize-phoenix = "^4.12.0"
@@ -35,6 +35,7 @@ chroma="^0.2.0"
llama-index-vector-stores-chroma = "^0.1.10"
llama-index-readers-json = "^0.1.5"
llama-index-retrievers-bm25 = "^0.2.2"
llama-index-experimental = "^0.2.0"
duckduckgo_search = "^6.2.6"
@@ -62,6 +63,12 @@ version = "^0.8"
version = "0.0.7"
[[tool.poetry.source]]
name = "mirrors"
url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
priority = "default"
[build-system]
requires = [ "poetry-core" ]
build-backend = "poetry.core.masonry.api"
+138
View File
@@ -0,0 +1,138 @@
import nest_asyncio
nest_asyncio.apply()
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import VectorStoreIndex
from llama_index.core.evaluation import (
FaithfulnessEvaluator,
DatasetGenerator,
CorrectnessEvaluator,
SemanticSimilarityEvaluator,
)
from llama_index.experimental.param_tuner import ParamTuner
from llama_index.experimental.param_tuner.base import RunResult
from llama_index.llms.openai import OpenAI
import asyncio
# 初始化环境
from app.observability import init_observability
from app.settings import init_settings
from dotenv import load_dotenv
load_dotenv()
init_settings()
init_observability()
# 读取文档
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
# 参数字典
param_dict = {
"chunk_size": [512, 1024],
"top_k": [1, 5],
"temperature": [0.1, 1.0]
}
# 辅助函数
def _build_index(chunk_size, documents):
# 构建索引
splitter = SentenceSplitter(chunk_size=chunk_size)
vector_index = VectorStoreIndex.from_documents(
documents, transformations=[splitter],
)
return vector_index
# 评估函数
def evaluate_query_engine(query_engine, questions):
loop = asyncio.get_event_loop()
correct, total = loop.run_until_complete(_evaluate_query_engine_async(query_engine, questions))
return correct, total
async def _evaluate_query_engine_async(query_engine, questions):
c = [query_engine.aquery(q) for q in questions]
gathering_future = asyncio.gather(*c)
results = await gathering_future
total_correct = 0
for r in results:
eval_result = (
1 if FaithfulnessEvaluator().evaluate_response(response=r).passing else 0
)
total_correct += eval_result
return total_correct, len(results)
# 生成问题
question_generator = DatasetGenerator.from_documents(documents)
eval_questions = question_generator.generate_questions_from_nodes(1) # 假设生成10个问题
# 打印生成的问题
for i, q in enumerate(eval_questions, start=1):
print(f"问题 {i}: {q}")
# 目标函数
def objective_function(params_dict, documents, questions):
chunk_size = params_dict["chunk_size"]
top_k = params_dict["top_k"]
temperature = params_dict["temperature"]
# 构建索引
vector_index = _build_index(chunk_size, documents)
# 查询引擎
query_engine = vector_index.as_query_engine(
similarity_top_k=top_k, temperature=temperature
)
# 评估查询引擎
correct, total = 0, len(questions)
question_answers = [] # 添加列表来收集问题和答案
for question in questions:
response = query_engine.query(question)
if response is not None:
question_answers.append((question, response.response))
eval_result = FaithfulnessEvaluator().evaluate_response(response=response, query_str=question)
if eval_result.passing:
correct += 1
# 计算分数
score = correct / total if total > 0 else 0
return RunResult(score=score, params=params_dict, question_answers=question_answers)
# 创建 ParamTuner 实例
param_tuner = ParamTuner(
param_fn=lambda params_dict: objective_function(params_dict, documents, eval_questions),
param_dict=param_dict,
show_progress=True,
)
# 调用 tune 方法
results = param_tuner.tune()
best_result = results.best_run_result
best_top_k = best_result.params["top_k"]
best_chunk_size = best_result.params["chunk_size"]
best_temperature = best_result.params["temperature"]
print(f"得分: {best_result.score}")
print(f"Top-k: {best_top_k}")
print(f"文本块大小: {best_chunk_size}")
print(f"温度: {best_temperature}")
# 使用最佳参数再次运行查询引擎,并打印问题与答案
best_vector_index = _build_index(best_chunk_size, documents)
best_query_engine = best_vector_index.as_query_engine(
similarity_top_k=best_top_k, temperature=best_temperature
)
best_question_answers = []
for question in eval_questions:
response = best_query_engine.query(question)
if response is not None:
best_question_answers.append((question, response.response))
# 打印最佳参数下的问题与答案
for i, (question, answer) in enumerate(best_question_answers, start=1):
print(f"最佳参数 - 问题 {i}: {question}\n答案: {answer}\n")
+81
View File
@@ -0,0 +1,81 @@
from app.observability import init_observability
from app.settings import init_settings
from dotenv import load_dotenv
import nest_asyncio
nest_asyncio.apply()
load_dotenv()
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import (
VectorStoreIndex,
SimpleDirectoryReader,
Response,
)
from llama_index.core.evaluation import (
FaithfulnessEvaluator,
DatasetGenerator,
CorrectnessEvaluator,
SemanticSimilarityEvaluator,)
init_settings()
init_observability()
faith_evaluator_qwen = FaithfulnessEvaluator() #诚实度评测
corr_evaluator_qwen = CorrectnessEvaluator() #准确率评测
Seman_evaluator_qwen = SemanticSimilarityEvaluator()#嵌入相似度评估
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
splitter = SentenceSplitter(chunk_size=512)
vector_index = VectorStoreIndex.from_documents(
documents, transformations=[splitter],
)
# # 运行评估
# query_engine = vector_index.as_query_engine()
# response_vector = query_engine.query("工程监理费的金额是多少?")
# eval_result = evaluator_qwen.evaluate_response(response=response_vector)
# print(response_vector)
# print(eval_result)
question_generator = DatasetGenerator.from_documents(documents)
eval_questions = question_generator.generate_questions_from_nodes(5)
print(eval_questions)
import asyncio
async def evaluate_query_engine_async(query_engine, questions):
c = [query_engine.aquery(q) for q in questions]
gathering_future = asyncio.gather(*c)
results = await gathering_future
#print(results)
total_correct = 0
for r in results:
eval_result = (
1 if faith_evaluator_qwen.evaluate_response(response=r).passing else 0
)
total_correct += eval_result
return total_correct, len(results)
def evaluate_query_engine(query_engine, questions):
loop = asyncio.get_event_loop()
correct, total = loop.run_until_complete(evaluate_query_engine_async(query_engine, questions))
return correct, total
# 使用 evaluate_query_engine 函数
vector_query_engine = vector_index.as_query_engine()
correct, total = evaluate_query_engine(vector_query_engine, eval_questions[:5])
print(f"score: {correct}/{total}")
+121
View File
@@ -0,0 +1,121 @@
from dotenv import load_dotenv
load_dotenv()
from llama_index.core.evaluation import CorrectnessEvaluator
from app.engine import get_chat_engine
from app.engine.index import get_index
from app.observability import init_observability
from app.settings import init_settings
init_settings()
init_observability()
index = get_index()
import os
import json
import asyncio
import nest_asyncio
nest_asyncio.apply()
from llama_index.core.prompts import (
ChatMessage,
ChatPromptTemplate,
MessageRole
)
DEFAULT_SYSTEM_TEMPLATE = """
您是一个问答聊天机器人的专业评估系统。
您将获得以下信息:
- 用户查询,
- 生成的回答,
也可能提供一个参考答案作为评估的依据。
您的任务是判断生成回答的相关性和正确性。
输出一个代表全面评估的单一分数。
您必须在一行中仅返回该分数。
不要以其他任何格式返回答案。
在单独的一行提供给定分数的理由。
请遵循以下评分指南:
- 您的分数必须在1到5之间,其中1是最差,5是最好的。
-如果生成的回答与用户查询不相关,您应该给出1分。
-如果生成的回答相关但包含错误,您应该给出2到3分之间的分数。
-如果生成的回答相关且完全正确,您应该给出4到5分之间的分数。
示例响应:
4.0
生成的回答与参考答案的指标完全相同,但不够精炼。
"""
DEFAULT_USER_TEMPLATE = """
## User Query
{query}
## Reference Answer
{reference_answer}
## Generated Answer
{generated_answer}
"""
DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
message_templates=[
ChatMessage(role=MessageRole.SYSTEM, content=DEFAULT_SYSTEM_TEMPLATE),
ChatMessage(role=MessageRole.USER, content=DEFAULT_USER_TEMPLATE),
]
)
# 初始化聊天引擎和评估器
chat_engine = get_chat_engine()
corr_evaluator_qwen = CorrectnessEvaluator()
# 加载本地问题回答文件
script_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(script_dir, 'questions_and_answers.json')
output_file_path = file_path.replace('.json', '_test.json')
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# 异步函数用于评估查询
async def evaluate_query(question, answer, index, output_file):
response = await chat_engine.astream_chat(question)
# 检查sources是否为空
if response.sources:
content_str = str(response.sources[0])
else:
content_str = "<无回答>"
result = corr_evaluator_qwen.evaluate(
query=question,
response=content_str,
reference=answer,
)
result_dict = {
"编号": index,
"问题": question,
"答案": answer,
"回答": result.response,
"得分(1~5)": result.score,
"评价": result.feedback
}
with open(output_file, 'a', encoding='utf-8') as f:
f.write(json.dumps(result_dict, ensure_ascii=False, indent=4))
f.write(',\n')
# 主异步函数
async def main():
for index, item in enumerate(data, start=1):
await evaluate_query(item['question'], item['answer'], index, output_file_path)
# 运行主协程
asyncio.run(main())
+55
View File
@@ -0,0 +1,55 @@
Attribute_Prompt = (
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
"现在需要你针对数据中属性一列进行提问和回答。"
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元),工程总投资(万元)是77469835.590045万元','尖峰及施工基面土石方量,尖峰及施工基面土石方量是8377.6','截止阀的编码,截止阀的编码是F01010203',"
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
Amount_Prompt = (
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
"现在需要你针对上下文信息中的金额或者合价进行提问和回答。"
"问题和回答的示例应该是这种类型的,示例:'项目建设技术服务费的金额,项目建设技术服务费的金额是16855957065.4302','项目后评价费的费率,项目后评价费的费率是0.5','架空输电线路本体工程的金额,架空输电线路本体工程的金额是55105688268.5176','工程静态投资的金额,工程静态投资的金额是715035853336.391'"
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
Units_Prompt = (
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
"现在需要你针对上下文信息来进行单位转化问题提问和回答。"
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元)结果用元表示,工程总投资(万元)是774698355900.45元','本体工程(元)结果用万元表示,本体工程(元)是5490494.261046万元'"
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
Name_Prompt = (
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
"现在需要你针对上下文信息中的重名问题进行提问和回答。"
"问题和回答的示例应该是这种类型的,示例:'专业类型为线路的杆塔工程项目划分的合价,专业类型为线路的杆塔工程项目划分的合价是220969744.905856','专业类型为线路清理的杆塔工程项目划分的合价,电缆工程的合价是0'"
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
All_Amount_Prompt = (
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
"现在需要你针对上下文信息中的总体金额进行提问和回答。"
"问题和回答的示例应该是这种类型的,示例:'架空输电线路本体工程的总体金额,架空输电线路本体工程的总体金额是7.706703','工程静态投资的总体金额,工程静态投资的总体金额是100'"
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
+144
View File
@@ -0,0 +1,144 @@
from dotenv import load_dotenv
load_dotenv()
import json
import sys
from app.observability import init_observability
from app.settings import init_settings
import nest_asyncio
nest_asyncio.apply()
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import SimpleDirectoryReader
from llama_index.core.evaluation import DatasetGenerator
import prompts
init_settings()
init_observability()
# 读取所有文档(即所有表格)
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
# 定义表格名称和索引的对应关系
table_names = {
"工程信息表": 0,
"其他费用表": 1,
"取费表": 2,
"项目划分表": 3,
"项目划分_费用预览表": 4,
"总算表": 5,
"工程量表": 6
}
# 定义中文提示词和Python代码中提示词名称的映射
prompt_mapping = {
"普通属性": "Attribute_Prompt",
"金额查询": "Amount_Prompt",
"单位换算": "Units_Prompt",
"重名项目划分": "Name_Prompt",
"总体金额查询": "All_Amount_Prompt"
}
# 定义表格与其对应的查询类别
table_prompt_mapping = {
"工程信息表": ["普通属性", "单位换算"],
"其他费用表": ["金额查询", "单位换算"],
"取费表": ["金额查询"],
"总算表": ["金额查询", "总体金额查询"],
"工程量表": ["普通属性", "重名项目划分"]
}
# 根据表格名称选择特定的表格
def select_document(documents, table_name):
if table_name not in table_names:
raise ValueError(f"未找到名为 '{table_name}' 的表格")
index = table_names[table_name]
return [documents[index]] # 返回一个包含所选表格的列表
# 选择提示词
def select_prompt(prompt_category):
prompt_name = prompt_mapping.get(prompt_category)
if not prompt_name:
raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
try:
return getattr(prompts, prompt_name)
except AttributeError:
raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数")
# 生成问题和答案
def generate_questions_from_document(document, quest_prompt, num_questions):
question_generator = DatasetGenerator.from_documents(
documents=document,
question_gen_query=quest_prompt,
num_questions_per_chunk=num_questions
)
eval_questions = question_generator.generate_questions_from_nodes(num_questions)
print(eval_questions)
qa_pairs = []
for qa in eval_questions:
if ',' in qa:
question, answer = qa.split(",", 1)
qa_pairs.append({
"question": question.strip(),
"answer": answer.strip()
})
else:
print(f"无法处理的问题和答案: {qa}")
return qa_pairs
# 主函数,控制生成多个表格的问题和使用多个提示词,并将结果合并到一个文件中
def main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt):
if table_names_input == "all":
selected_tables = list(table_prompt_mapping.keys())
else:
selected_tables = table_names_input.strip('[]').split(',')
all_results = {}
for table_name in selected_tables:
table_name = table_name.strip() # 去掉前后空格
document = select_document(documents, table_name)
if prompt_categories_input == "all":
selected_prompts = table_prompt_mapping[table_name]
else:
selected_prompts = prompt_categories_input.strip('[]').split(',')
selected_prompts = [p.strip() for p in selected_prompts] # 去掉前后空格
for prompt_category in selected_prompts:
if prompt_category not in table_prompt_mapping[table_name]:
print(f"跳过表格 '{table_name}' 的提示词 '{prompt_category}',因为该表中不包含该类别的信息")
continue
quest_prompt = select_prompt(prompt_category).format(num_questions_per_chunk=num_questions_per_prompt)
qa_pairs = generate_questions_from_document(document, quest_prompt, num_questions_per_prompt)
label = f"test:{table_name}_{prompt_category}"
all_results[label] = qa_pairs
# 自动生成输出文件名
output_file = "combined_test.json"
with open(output_file, "w", encoding="utf-8") as f:
json.dump(all_results, f, ensure_ascii=False, indent=4)
print(f"All questions and answers have been saved to '{output_file}'")
# 获取命令行参数
if __name__ == "__main__":
if len(sys.argv) != 4:
print("Usage: python script.py <table_names_input> <prompt_categories_input> <num_questions_per_prompt>")
else:
table_names_input = sys.argv[1]
prompt_categories_input = sys.argv[2]
num_questions_per_prompt = int(sys.argv[3])
main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt)
+3 -2
View File
@@ -1,9 +1,10 @@
import os
from dotenv import load_dotenv
load_dotenv()
import phoenix as px
os.environ['PHOENIX_HOST'] = "0.0.0.0"
session = px.launch_app(use_temp_dir=False)
import msvcrt
Submodule
+1
Submodule webapp added at 77dbc14a64