10 Commits

Author SHA1 Message Date
wanyaokun 480a1f7fdc 新增工程配置信息 2024-08-29 19:03:38 +08:00
wanyaokun cdc9d84a1e Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-29 19:01:40 +08:00
wanyaokun 50f35bb0c9 优化Web事件代码 2024-08-29 19:00:25 +08:00
chentianrui 4a8c79e83d 参数优化针对问题做出了调整 2024-08-29 15:09:55 +08:00
chentianrui de34c3938c 增加了参数评估 2024-08-29 12:02:53 +08:00
chentianrui 2706cf9d5a 更新了依赖包 2024-08-29 11:41:42 +08:00
chentianrui 5fa4752d6e Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-29 11:39:06 +08:00
chentianrui aff1793c4e 新增了参数评估脚本和评分脚本 2024-08-29 11:38:45 +08:00
chentianrui 3ee1ba529f Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-28 18:12:37 +08:00
chentianrui 576a2ae737 增加了评估脚本 2024-08-28 18:12:28 +08:00
10 changed files with 576 additions and 232 deletions
+1
View File
@@ -80,3 +80,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
- You can install any pip package (if it exists) by running a cell with pip install. - You can install any pip package (if it exists) by running a cell with pip install.
" "
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
+1
View File
@@ -111,3 +111,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
- You can install any pip package (if it exists) by running a cell with pip install. - You can install any pip package (if it exists) by running a cell with pip install.
" "
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
+226 -164
View File
@@ -3,6 +3,7 @@ import json
import logging import logging
import time import time
from typing import Dict, List, Any, Optional, AsyncGenerator from typing import Dict, List, Any, Optional, AsyncGenerator
from collections import deque
from aiostream import stream from aiostream import stream
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
@@ -13,7 +14,8 @@ from llama_index.core.callbacks import CBEventType
from llama_index.core.chat_engine.types import StreamingAgentChatResponse from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.tools import ToolOutput from llama_index.core.tools import ToolOutput
from pydantic import BaseModel from pydantic import BaseModel
from app.api.routers.request.base import userMng, conversations,message,parameter from app.api.routers.request.base import userMng, conversations,message,parameter,feedback
from app.api.routers.request.baseConfig import *
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
from app.engine import get_chat_engine from app.engine import get_chat_engine
import uuid import uuid
@@ -102,77 +104,6 @@ class ChatCallbackEvent(BaseModel):
logger.error(f"转换回应时间时发生错误,原因: {e}") logger.error(f"转换回应时间时发生错误,原因: {e}")
return None return None
class ChatEventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue
is_done: bool = False
def __init__(
self,
):
"""Initialize the base callback handler."""
ignored_events = [
# CBEventType.CHUNKING,
# CBEventType.NODE_PARSING,
# CBEventType.EMBEDDING,
# CBEventType.LLM,
# CBEventType.TEMPLATING,
]
super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue()
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> str:
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op."""
logger.info("trace_start:{}\n".format(trace_id))
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
while not self._aqueue.empty() or not self.is_done:
try:
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
except asyncio.TimeoutError:
pass
class IDManager:
def createID(self):
return {
"message_id" : str(uuid.uuid4()),
'task_id':str(uuid.uuid4()),
'workflow_run_id': str(uuid.uuid4()),
"workflow_id": str(uuid.uuid4())
}
class DifyChatResponseEvent(BaseModel): class DifyChatResponseEvent(BaseModel):
event: str event: str
conversation_id: str conversation_id: str
@@ -180,7 +111,11 @@ class DifyChatResponseEvent(BaseModel):
created_at: int = int(time.time()) created_at: int = int(time.time())
task_id: str task_id: str
def to_response(self):
return self.dict()
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent): class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'workflow_started'
workflow_run_id:str workflow_run_id:str
data:Dict[str,Any] data:Dict[str,Any]
def __init__(self,**args): def __init__(self,**args):
@@ -196,14 +131,13 @@ class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
}, },
"created_at": int(time.time()) "created_at": int(time.time())
} }
args['event'] = 'workflow_started'
super().__init__(**args) super().__init__(**args)
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent): class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'workflow_finished'
workflow_run_id:str workflow_run_id:str
data:Dict[str,Any] data:Dict[str,Any]
def __init__(self,**args): def __init__(self,**args):
args['event'] = 'workflow_finished'
args['data'] = { args['data'] = {
"id": args['workflow_run_id'], "id": args['workflow_run_id'],
"workflow_id": args['workflow_id'], "workflow_id": args['workflow_id'],
@@ -227,41 +161,219 @@ class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
super().__init__(**args) super().__init__(**args)
class Message_DifyChatResponseEvent(DifyChatResponseEvent): class Message_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'message'
id:str id:str
answer:str answer:str
def __init__(self,**args): def __init__(self,**args):
args['id'] = args['message_id'] args['id'] = args['message_id']
args['event'] = 'message'
super().__init__(**args) super().__init__(**args)
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent): class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'message_end'
id:str id:str
metadata:Dict[str,Any] = {} metadata:Dict[str,Any] = {}
def __init__(self,**args): def __init__(self,**args):
args['id'] = args['message_id'] args['id'] = args['message_id']
args['event'] = 'message_end'
super().__init__(**args) super().__init__(**args)
class Node_started_DifyChatResponseEvent(DifyChatResponseEvent):
    """Stream event reporting that a workflow node has started.

    The constructor expects ``nodeid``, ``title``, ``index`` and
    ``predecessor_node_id`` among its keyword arguments and folds them into
    the ``data`` payload before delegating to the base event constructor.

    Raises:
        KeyError: if one of the node keys listed above is missing.
    """

    event: str = 'node_started'
    workflow_run_id: str
    data: Dict[str, Any]

    def __init__(self, **args):
        args['data'] = {
            "id": args['nodeid'],
            "node_id": args['nodeid'],
            "node_type": "http-request",
            "title": args['title'],
            "index": args['index'],
            "predecessor_node_id": args['predecessor_node_id'],
            "inputs": '',
            # Stamp the actual start time; this used to be a fixed epoch
            # value, so every node appeared to start at the same instant.
            "created_at": int(time.time()),
            "extras": {}
        }
        super().__init__(**args)
class Node_finished_DifyChatResponseEvent(DifyChatResponseEvent):
    """Stream event reporting that a workflow node has finished.

    The constructor expects ``nodeid``, ``title``, ``index`` and
    ``predecessor_node_id`` among its keyword arguments and folds them into
    the ``data`` payload before delegating to the base event constructor.
    The node is always reported with status ``succeeded``.

    Raises:
        KeyError: if one of the node keys listed above is missing.
    """

    event: str = 'node_finished'
    workflow_run_id: str
    data: Dict[str, Any]

    def __init__(self, **args):
        # The start time of the node is not passed in, so both timestamps
        # use the moment the finish event is built. Previously they were a
        # fixed epoch value.
        now = int(time.time())
        args['data'] = {
            "id": args['nodeid'],
            "node_id": args['nodeid'],
            "node_type": "http-request",
            "title": args['title'],
            "index": args['index'],
            "predecessor_node_id": args['predecessor_node_id'],
            "inputs": '',
            "process_data": '',
            "outputs": '',
            "status": "succeeded",
            "error": '',
            # NOTE(review): placeholder constant, not a measured duration.
            "elapsed_time": 0.10402441816404462,
            "execution_metadata": '',
            "created_at": now,
            "finished_at": now,
            "files": []
        }
        super().__init__(**args)
class ChatEventCallbackHandler(BaseCallbackHandler):
    """Callback handler that queues Dify-style stream events for one chat request.

    On construction it queues a ``workflow_started`` event. Every callback
    event start/end is translated into a ``node_started`` / ``node_finished``
    event, and ``end_trace`` queues ``workflow_finished`` followed by
    ``message_end``. Consumers drain the queue through ``async_event_gen``.

    Keyword Args:
        ids: dict of identifiers for this exchange (the event constructors
            read keys such as ``message_id``, ``task_id``,
            ``workflow_run_id`` and ``workflow_id`` from it).
        data: the ``ChatRequestData`` of the request; ``user``, ``query`` and
            ``conversation_id`` are read from it.
    """

    _aqueue: asyncio.Queue
    is_done: bool = False

    def __init__(self, **params):
        """Initialize the handler and queue the workflow-started event."""
        # No event types are ignored (CHUNKING, NODE_PARSING, EMBEDDING, LLM
        # and TEMPLATING could be listed here to silence them).
        ignored_events: List[CBEventType] = []
        super().__init__(ignored_events, ignored_events)
        self._aqueue = asyncio.Queue()
        self._response: str = ''
        self._params: Dict[str, Any] = params
        # Stack of event ids that have started but not yet ended.
        self._nodeStack: deque = deque()

        data: ChatRequestData = self._params['data']
        # Work on a copy: the ids dict belongs to the caller and is shared
        # with the stream response, so it must not be mutated from here.
        self._base_args: Dict[str, Any] = dict(self._params['ids'])
        self._base_args.update(
            {
                'use_id': data.user,
                'query': data.query,
                'conversation_id': data.conversation_id
            }
        )
        # Queue the workflow-started event.
        self._enqueue(Workflow_started_DifyChatResponseEvent(**self._base_args))

    def _enqueue(self, event) -> None:
        """Queue ``event`` unless it renders to no response."""
        if event.to_response() is not None:
            self._aqueue.put_nowait(event)

    def _node_args(self, event_type: CBEventType, event_id: str, nindex: int) -> Dict[str, Any]:
        """Build constructor arguments for a node event at stack position ``nindex``."""
        node_args = dict(self._base_args)
        node_args.update(
            {
                'nodeid': event_id,
                'title': event_type.name,
                'index': nindex + 1,
                'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else ''
            }
        )
        return node_args

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Push the event on the node stack and queue a node-started event.

        Returns:
            The ``event_id`` that was passed in.
        """
        logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
        self._nodeStack.append(event_id)
        # deque.count() needs an argument; the stack depth is its length.
        nindex = len(self._nodeStack) - 1
        self._enqueue(
            Node_started_DifyChatResponseEvent(**self._node_args(event_type, event_id, nindex))
        )
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Queue a node-finished event when ``event_id`` is the innermost open node.

        End notifications for any other event id (or with an empty stack)
        are ignored.
        """
        logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
        # Guard against an empty stack instead of raising IndexError.
        if not self._nodeStack or self._nodeStack[-1] != event_id:
            return
        nindex = len(self._nodeStack) - 1
        self._enqueue(
            Node_finished_DifyChatResponseEvent(**self._node_args(event_type, event_id, nindex))
        )
        self._nodeStack.pop()

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Log the start of a trace; no event is queued."""
        logger.info("trace_start:{}\n".format(trace_id))

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Log the end of a trace and queue workflow-finished and message-end events."""
        logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
        data: ChatRequestData = self._params['data']
        # NOTE(review): nothing in this class assigns _response after
        # __init__, so 'response' is currently always the empty string.
        final_args = dict(self._base_args)
        final_args.update(
            {
                'response': self._response,
                'conversation_id': data.conversation_id
            }
        )
        self._enqueue(Workflow_finished_DifyChatResponseEvent(**final_args))
        self._enqueue(MessageEnd_DifyChatResponseEvent(**final_args))

    async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
        """Yield queued events until the queue is empty and ``is_done`` is set.

        The queue is polled with a 0.1 s timeout so that a change of
        ``is_done`` is noticed even when no further event arrives.
        """
        while not self._aqueue.empty() or not self.is_done:
            try:
                yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                pass
class IDManager:
    """Generates the identifiers attached to one chat exchange."""

    def createID(self):
        """Return a dict with a fresh UUID4 string for each identifier key.

        Keys, in order: ``message_id``, ``task_id``, ``workflow_run_id``,
        ``workflow_id``.
        """
        id_keys = ("message_id", 'task_id', 'workflow_run_id', "workflow_id")
        return {key: str(uuid.uuid4()) for key in id_keys}
class ChatStreamResponse(StreamingResponse): class ChatStreamResponse(StreamingResponse):
TEXT_PREFIX = "data: " TEXT_PREFIX = "data: "
DATA_PREFIX = "data: " DATA_PREFIX = "data: "
ids:Dict[str,Any] = {}
data:ChatRequestData = None
@classmethod @classmethod
def convert_text(cls, token: str): def convert_Message(cls, token: str):
# Escape newlines and double quotes to avoid breaking the stream params = cls.ids
#token = json.dumps(token) params.update({
'answer':token,
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}" 'conversation_id':cls.data.conversation_id
return "\n" })
event = Message_DifyChatResponseEvent(**params)
@classmethod data_str = json.dumps(event.dict())
def convert_data(cls, data: dict):
data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}{data_str}\n\n" return f"{cls.DATA_PREFIX}{data_str}\n\n"
@classmethod @classmethod
def convert_event(cls, event: DifyChatResponseEvent): def convert_Event(cls, data: dict):
data_str = json.dumps(event.dict()) data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}{data_str}\n\n" return f"{cls.DATA_PREFIX}{data_str}\n\n"
def __init__( def __init__(
@@ -269,8 +381,11 @@ class ChatStreamResponse(StreamingResponse):
request: Request, request: Request,
event_handler: ChatEventCallbackHandler, event_handler: ChatEventCallbackHandler,
response: StreamingAgentChatResponse, response: StreamingAgentChatResponse,
data: ChatRequestData data: ChatRequestData,
ids:Dict[str,Any]
): ):
ChatStreamResponse.ids = ids
ChatStreamResponse.data = data
content = ChatStreamResponse.content_generator( content = ChatStreamResponse.content_generator(
request, event_handler, response, data request, event_handler, response, data
) )
@@ -284,41 +399,26 @@ class ChatStreamResponse(StreamingResponse):
response: StreamingAgentChatResponse, response: StreamingAgentChatResponse,
data: ChatRequestData data: ChatRequestData
): ):
ids = IDManager().createID()
# Yield the text response # Yield the text response
async def _chat_response_generator(): async def _chat_response_generator():
final_response = "" final_response = ""
async for token in response.async_response_gen(): async for token in response.async_response_gen():
final_response += token final_response += token
args = ids yield ChatStreamResponse.convert_Message(token)
args['answer'] = token
args['conversation_id'] = data.conversation_id
event = Message_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(event)
#yield ChatStreamResponse.convert_text(token)
# 存储消息历史 # 存储消息历史
message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response) message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response)
# the text_generator is the leading stream, once it's finished, also finish the event stream # the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True event_handler.is_done = True
# 发送工作流结束事件
args = ids
args['response'] = final_response
args['conversation_id'] = data.conversation_id
wf_event = Workflow_finished_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(wf_event)
msgEnt_event = MessageEnd_DifyChatResponseEvent(**ids)
yield ChatStreamResponse.convert_event(msgEnt_event)
# Yield the events from the event handler # Yield the events from the event handler
async def _event_generator(): async def _event_generator():
async for event in event_handler.async_event_gen(): async for event in event_handler.async_event_gen():
event_response = event.to_response() event_response = event.to_response()
if event_response is not None: if event_response is not None:
yield ChatStreamResponse.convert_text("") yield ChatStreamResponse.convert_Event(event_response)
combine = stream.merge(_chat_response_generator(), _event_generator()) combine = stream.merge(_chat_response_generator(), _event_generator())
is_stream_started = False is_stream_started = False
@@ -327,25 +427,11 @@ class ChatStreamResponse(StreamingResponse):
if not is_stream_started: if not is_stream_started:
is_stream_started = True is_stream_started = True
# 发送工作流开始事件
args = ids
args['use_id'] = data.user
args['query'] = data.query
args['conversation_id'] = data.conversation_id
wf_event = Workflow_started_DifyChatResponseEvent(**args)
yield ChatStreamResponse.convert_event(wf_event)
# Stream a blank message to start the stream
# 发送一个空消息事件
#yield ChatStreamResponse.convert_text("")
yield output yield output
if await request.is_disconnected(): if await request.is_disconnected():
break break
@v.post("/chat-messages") @v.post("/chat-messages")
async def post_conversations(request: Request, data: ChatRequestData): async def post_conversations(request: Request, data: ChatRequestData):
userMng.findNoExistCreate(data.user) userMng.findNoExistCreate(data.user)
@@ -365,14 +451,15 @@ async def post_conversations(request: Request, data: ChatRequestData):
chat_engine = get_chat_engine(filters=filters, params=params) chat_engine = get_chat_engine(filters=filters, params=params)
# 启动聊天事件监听 # 启动聊天事件监听
event_handler = ChatEventCallbackHandler() ids = IDManager().createID()
event_handler = ChatEventCallbackHandler(ids = ids,data = data)
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
# 执行异步聊天 # 执行异步聊天
response = await chat_engine.astream_chat(data.query) response = await chat_engine.astream_chat(data.query)
# 返回异步消息回应 # 返回异步消息回应
return ChatStreamResponse(request, event_handler, response, data) return ChatStreamResponse(request, event_handler, response, data,ids)
@v.get("/messages") @v.get("/messages")
async def query_messages(user:str, conversation_id:str): async def query_messages(user:str, conversation_id:str):
@@ -388,8 +475,9 @@ async def query_messages(user:str, conversation_id:str):
for record in records: for record in records:
res = record.dict() res = record.dict()
feeds = feedback().query(res['id'])
res["message_files"] = [] res["message_files"] = []
res["feedback"] = '' res["feedback"] = {'rating':feeds['rating'] } if feeds != None else ''
res["retriever_resources"] = [] res["retriever_resources"] = []
res["created_at"] = 1723444905 res["created_at"] = 1723444905
res["agent_thoughts"] = [] res["agent_thoughts"] = []
@@ -440,48 +528,22 @@ async def query_conversations(user:str, first_id:str = None, limit:str = None, p
async def query_parameters(user:str): async def query_parameters(user:str):
params = parameter().get(user) params = parameter().get(user)
if len(params) == 0: if len(params) == 0:
params = { params = BaseConfig().ParamterCfg()
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
"suggested_questions": [],
"suggested_questions_after_answer": {
"enabled": False
},
"speech_to_text": {
"enabled": False
},
"text_to_speech": {
"enabled": False,
"language": "",
"voice": ""
},
"retriever_resource": {
"enabled": True
},
"annotation_reply": {
"enabled": False
},
"more_like_this": {
"enabled": False
},
"user_input_form": [],
"sensitive_word_avoidance": {
"enabled": False
},
"file_upload": {
"image": {
"enabled": False,
"number_limits": 3,
"transfer_methods": [
"remote_url"
]
}
},
"system_parameters": {
"image_file_size_limit": "10"
}
}
return params return params
@v.post("/messages/{message_id}/feedbacks")
async def post_feedbacks(request: Request, message_id: str, params: Dict[str, Any]):
    """Store or clear the user's rating for a message.

    Args:
        request: the incoming request (not used by the handler itself).
        message_id: id of the message being rated.
        params: request body; its ``rating`` entry is read.

    A missing rating, ``None`` (a JSON ``null``) or the literal string
    ``'null'`` removes any stored feedback. Any other value is stored with
    the message's query and answer, provided the message exists. Nothing is
    stored when the message cannot be found.
    """
    rating = params.get('rating')
    # A JSON null is decoded to None, not to the string 'null'. The string
    # form is still accepted for clients that send it.
    if rating is None or rating == 'null':
        feedback().delete(message_id)
        return

    results = message().query(id=message_id)
    if not results:
        return
    result = results[0]
    # The feedbacks table appears to key rows by message_id, so drop any
    # earlier rating first so that changing a rating does not collide.
    feedback().delete(message_id)
    feedback().add(message_id=message_id, query=result['query'],
                   answer=result['answer'], rating=rating)
@r.post("") @r.post("")
def upload_file(request: ChatFileUploadRequest) -> List[str]: def upload_file(request: ChatFileUploadRequest) -> List[str]:
pass pass
+32 -2
View File
@@ -25,7 +25,7 @@ class conversations:
return None return None
def add(self,id:str, user_id:str, name:str): def add(self,id:str, user_id:str, name:str):
template = BaseConfig.ConversationCfg template = BaseConfig().ConversationCfg()
template['id'] = id template['id'] = id
template['user_id'] = user_id template['user_id'] = user_id
template['name'] = name template['name'] = name
@@ -111,7 +111,7 @@ class message:
return datas return datas
def add(self,user_id:str,conversation_id:str,query:str,answer:str): def add(self,user_id:str,conversation_id:str,query:str,answer:str):
template = BaseConfig.MessageCfg template = BaseConfig.MessageCfg()
template['id'] = str(uuid.uuid4()) template['id'] = str(uuid.uuid4())
template['user_id'] = user_id template['user_id'] = user_id
template['conversation_id'] = conversation_id template['conversation_id'] = conversation_id
@@ -122,4 +122,34 @@ class message:
def delete(self,user_id:str): def delete(self,user_id:str):
dbManage.delete(self._tableName,user_id = user_id) dbManage.delete(self._tableName,user_id = user_id)
def query(self,**condition):
    """Return every record of this table matching ``condition`` as a list of dicts.

    The keyword arguments are forwarded unchanged to ``dbManage.query``.
    """
    matched = dbManage.query(self._tableName, **condition)
    return [row.dict() for row in matched]
class feedback:
    """Access helper for the ``feedbacks`` table (one rating per message)."""

    def __init__(self) -> None:
        """Remember the table name and ask ``dbManage`` to create the table."""
        self._tableName = 'feedbacks'
        dbManage.createTable(self._tableName)

    def add(self, message_id: str, query: str, answer: str, rating: str):
        """Insert a feedback record for ``message_id``."""
        dbManage.addRecord(
            self._tableName,
            dict(message_id=message_id, query=query, answer=answer, rating=rating),
        )

    def delete(self, message_id: str):
        """Delete the feedback stored for ``message_id``."""
        dbManage.delete(self._tableName, message_id=message_id)

    def query(self, message_id: str):
        """Return the first feedback record for ``message_id`` as a dict, or ``None``."""
        rows = dbManage.query(self._tableName, message_id=message_id)
        return rows[0].dict() if len(rows) > 0 else None
+17 -8
View File
@@ -1,8 +1,13 @@
from pydantic import BaseModel
import os
class BaseConfig(BaseModel):
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
class BaseConfig: def ParamterCfg(self):
ParamterCfg = { questions = os.getenv("CONVERSATION_STARTERS", "dev")
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!", return{
"suggested_questions": [], "opening_statement": self.projectInfo,
"suggested_questions": questions.split('\n'),
"suggested_questions_after_answer": { "suggested_questions_after_answer": {
"enabled": False "enabled": False
}, },
@@ -41,18 +46,20 @@ class BaseConfig:
} }
} }
ConversationCfg = { def ConversationCfg(self):
return{
"id": "", "id": "",
'user_id':'', 'user_id':'',
"name": "", "name": "",
"inputs": {}, "inputs": {},
"status": "normal", "status": "normal",
"introduction": ParamterCfg['opening_statement'], "introduction": self.projectInfo,
"created_at":'' "created_at":''
} }
@classmethod
MessageCfg = { def MessageCfg(cls):
return {
"id": "", "id": "",
'user_id':'', 'user_id':'',
"conversation_id": "", "conversation_id": "",
@@ -60,3 +67,5 @@ class BaseConfig:
"query": "", "query": "",
"answer": "" "answer": ""
} }
+22 -9
View File
@@ -2,7 +2,7 @@ import os
from typing import Dict, List, Any from typing import Dict, List, Any
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import create_engine, Column, String, Integer, JSON from sqlalchemy import create_engine, Column, String, Integer, JSON,Float
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.orm import sessionmaker, declarative_base
@@ -24,10 +24,6 @@ class ConversationOrm(Base):
if 'name' in data: if 'name' in data:
self.name = data['name'] self.name = data['name']
class UserOrm(Base): class UserOrm(Base):
__tablename__ = "user" __tablename__ = "user"
@@ -51,6 +47,14 @@ class MessagesOrm(Base):
query = Column(String) query = Column(String)
answer = Column(String) answer = Column(String)
class FeedBackOrm(Base):
    """ORM mapping for the ``feedbacks`` table: one row per rated message."""

    __tablename__ = "feedbacks"
    # Id of the rated message; primary key, so at most one row per message.
    message_id = Column(String,primary_key=True)
    # Question text stored alongside the rating.
    query = Column(String)
    # Answer text stored alongside the rating.
    answer = Column(String)
    # Rating value as a string.
    rating = Column(String)
#数据结构 #数据结构
class ConversationModel(BaseModel): class ConversationModel(BaseModel):
id: str id: str
@@ -61,7 +65,6 @@ class ConversationModel(BaseModel):
created_at: int created_at: int
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
@@ -73,7 +76,6 @@ class UserModel(BaseModel):
createtime: str createtime: str
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
@@ -86,7 +88,6 @@ class ParametersModel(BaseModel):
value : Dict[str, Any] value : Dict[str, Any]
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
@@ -101,13 +102,25 @@ class MessagesModel(BaseModel):
answer : str answer : str
class Config: class Config:
#orm_mode = True
from_attributes=True from_attributes=True
@classmethod @classmethod
def orm(cls): def orm(cls):
return MessagesOrm return MessagesOrm
class FeedBackModel(BaseModel):
    """Pydantic schema of a feedback record; mirrors the columns of ``FeedBackOrm``."""

    message_id :str
    query :str
    answer :str
    rating :str

    class Config:
        # Allow the model to be populated from attribute-bearing objects.
        from_attributes=True

    @classmethod
    def orm(cls):
        """Return the ORM class that backs this schema."""
        return FeedBackOrm
class DBManager: class DBManager:
def __init__(self) -> None: def __init__(self) -> None:
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL") DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
+3 -1
View File
@@ -1,7 +1,7 @@
from typing import Dict, Any from typing import Dict, Any
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional
class ChatRequestData(BaseModel): class ChatRequestData(BaseModel):
inputs: Dict[str,Any] inputs: Dict[str,Any]
@@ -13,3 +13,5 @@ class ChatRequestData(BaseModel):
class ChatFileUploadRequest(BaseModel): class ChatFileUploadRequest(BaseModel):
base64: str base64: str
+8 -1
View File
@@ -17,7 +17,7 @@ aiostream = "^0.6.2"
llama-index = "0.10.63" llama-index = "0.10.63"
cachetools = "^5.3.3" cachetools = "^5.3.3"
protobuf = "4.25.4" protobuf = "4.25.4"
nltk = "^3.8.2" nltk = "^3.9.1"
jieba = "^0.42.1" jieba = "^0.42.1"
#arize-phoenix = "^4.12.0" #arize-phoenix = "^4.12.0"
@@ -35,6 +35,7 @@ chroma="^0.2.0"
llama-index-vector-stores-chroma = "^0.1.10" llama-index-vector-stores-chroma = "^0.1.10"
llama-index-readers-json = "^0.1.5" llama-index-readers-json = "^0.1.5"
llama-index-retrievers-bm25 = "^0.2.2" llama-index-retrievers-bm25 = "^0.2.2"
llama-index-experimental = "^0.2.0"
duckduckgo_search = "^6.2.6" duckduckgo_search = "^6.2.6"
@@ -62,6 +63,12 @@ version = "^0.8"
version = "0.0.7" version = "0.0.7"
[[tool.poetry.source]]
name = "mirrors"
url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
priority = "default"
[build-system] [build-system]
requires = [ "poetry-core" ] requires = [ "poetry-core" ]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
+138
View File
@@ -0,0 +1,138 @@
import nest_asyncio
nest_asyncio.apply()
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import VectorStoreIndex
from llama_index.core.evaluation import (
FaithfulnessEvaluator,
DatasetGenerator,
CorrectnessEvaluator,
SemanticSimilarityEvaluator,
)
from llama_index.experimental.param_tuner import ParamTuner
from llama_index.experimental.param_tuner.base import RunResult
from llama_index.llms.openai import OpenAI
import asyncio
# 初始化环境
from app.observability import init_observability
from app.settings import init_settings
from dotenv import load_dotenv
load_dotenv()
init_settings()
init_observability()
# 读取文档
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
# Parameter dictionary: candidate values per parameter, passed to ParamTuner below.
param_dict = {
    "chunk_size": [512, 1024],
    "top_k": [1, 5],
    "temperature": [0.1, 1.0]
}
# Helper function
def _build_index(chunk_size, documents):
    """Build a ``VectorStoreIndex`` over ``documents`` split with ``chunk_size``."""
    return VectorStoreIndex.from_documents(
        documents,
        transformations=[SentenceSplitter(chunk_size=chunk_size)],
    )
# Evaluation function
def evaluate_query_engine(query_engine, questions):
    """Run the async evaluation to completion and return ``(correct, total)``."""
    evaluation = _evaluate_query_engine_async(query_engine, questions)
    passed, asked = asyncio.get_event_loop().run_until_complete(evaluation)
    return passed, asked
async def _evaluate_query_engine_async(query_engine, questions):
    """Query the engine for every question concurrently and count passing answers.

    Args:
        query_engine: object exposing an awaitable ``aquery(question)``.
        questions: iterable of question strings.

    Returns:
        ``(total_correct, total)``: the number of responses whose
        faithfulness evaluation has ``passing`` set, and the number of
        responses evaluated.
    """
    results = await asyncio.gather(*(query_engine.aquery(q) for q in questions))
    # One evaluator serves all responses; it used to be rebuilt per response.
    evaluator = FaithfulnessEvaluator()
    total_correct = sum(
        1 for r in results if evaluator.evaluate_response(response=r).passing
    )
    return total_correct, len(results)
# Generate evaluation questions from the loaded documents.
question_generator = DatasetGenerator.from_documents(documents)
eval_questions = question_generator.generate_questions_from_nodes(1)  # the argument passed is 1 (the earlier comment assumed 10 questions)

# Print the generated questions
for i, q in enumerate(eval_questions, start=1):
    print(f"问题 {i}: {q}")
# Objective function
def objective_function(params_dict, documents, questions):
    """Score one parameter combination by the share of faithful answers.

    Args:
        params_dict: mapping with ``chunk_size``, ``top_k`` and ``temperature``.
        documents: documents to index.
        questions: questions to ask the resulting query engine.

    Returns:
        A ``RunResult`` whose score is ``passing answers / len(questions)``
        (0 when there are no questions).
    """
    chunk_size = params_dict["chunk_size"]
    top_k = params_dict["top_k"]
    temperature = params_dict["temperature"]

    # Build the index and the query engine for this combination.
    vector_index = _build_index(chunk_size, documents)
    query_engine = vector_index.as_query_engine(
        similarity_top_k=top_k, temperature=temperature
    )

    # One evaluator serves the whole run; it used to be rebuilt per question.
    evaluator = FaithfulnessEvaluator()
    correct, total = 0, len(questions)
    question_answers = []  # (question, answer) pairs collected along the way
    for question in questions:
        response = query_engine.query(question)
        if response is None:
            continue
        question_answers.append((question, response.response))
        eval_result = evaluator.evaluate_response(response=response, query_str=question)
        if eval_result.passing:
            correct += 1

    # Unanswered questions still count in the denominator.
    score = correct / total if total > 0 else 0

    # NOTE(review): confirm that RunResult keeps a `question_answers` field.
    return RunResult(score=score, params=params_dict, question_answers=question_answers)
# Create the ParamTuner instance; each run scores one parameter combination
# against the module-level documents and generated questions.
param_tuner = ParamTuner(
    param_fn=lambda params_dict: objective_function(params_dict, documents, eval_questions),
    param_dict=param_dict,
    show_progress=True,
)
# Run the tuning
results = param_tuner.tune()
best_result = results.best_run_result
best_top_k = best_result.params["top_k"]
best_chunk_size = best_result.params["chunk_size"]
best_temperature = best_result.params["temperature"]
print(f"得分: {best_result.score}")
print(f"Top-k: {best_top_k}")
print(f"文本块大小: {best_chunk_size}")
print(f"温度: {best_temperature}")
# Run the query engine again with the best parameters and collect the
# questions with their answers.
best_vector_index = _build_index(best_chunk_size, documents)
best_query_engine = best_vector_index.as_query_engine(
    similarity_top_k=best_top_k, temperature=best_temperature
)
best_question_answers = []
for question in eval_questions:
    response = best_query_engine.query(question)
    if response is not None:
        best_question_answers.append((question, response.response))
# Print the questions and answers obtained with the best parameters
for i, (question, answer) in enumerate(best_question_answers, start=1):
    print(f"最佳参数 - 问题 {i}: {question}\n答案: {answer}\n")
+81
View File
@@ -0,0 +1,81 @@
from app.observability import init_observability
from app.settings import init_settings
from dotenv import load_dotenv
import nest_asyncio
nest_asyncio.apply()
load_dotenv()
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import (
VectorStoreIndex,
SimpleDirectoryReader,
Response,
)
from llama_index.core.evaluation import (
FaithfulnessEvaluator,
DatasetGenerator,
CorrectnessEvaluator,
SemanticSimilarityEvaluator,)
init_settings()
init_observability()
faith_evaluator_qwen = FaithfulnessEvaluator()  # faithfulness evaluation
corr_evaluator_qwen = CorrectnessEvaluator()  # correctness evaluation
Seman_evaluator_qwen = SemanticSimilarityEvaluator()  # embedding-similarity evaluation
# NOTE(review): absolute, machine-specific data path.
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
# Index the documents using 512-sized chunks.
splitter = SentenceSplitter(chunk_size=512)
vector_index = VectorStoreIndex.from_documents(
    documents, transformations=[splitter],
)
# # 运行评估
# query_engine = vector_index.as_query_engine()
# response_vector = query_engine.query("工程监理费的金额是多少?")
# eval_result = evaluator_qwen.evaluate_response(response=response_vector)
# print(response_vector)
# print(eval_result)
question_generator = DatasetGenerator.from_documents(documents)
eval_questions = question_generator.generate_questions_from_nodes(5)
print(eval_questions)
import asyncio
async def evaluate_query_engine_async(query_engine, questions):
    """Query the engine for all questions concurrently and count faithful answers.

    Returns ``(total_correct, total)``, where ``total_correct`` is the number
    of responses that the module-level faithfulness evaluator marks as passing.
    """
    pending = [query_engine.aquery(question) for question in questions]
    answers = await asyncio.gather(*pending)
    passed = 0
    for answer in answers:
        if faith_evaluator_qwen.evaluate_response(response=answer).passing:
            passed += 1
    return passed, len(answers)
def evaluate_query_engine(query_engine, questions):
    """Run the async evaluation to completion and return ``(correct, total)``."""
    evaluation = evaluate_query_engine_async(query_engine, questions)
    passed, asked = asyncio.get_event_loop().run_until_complete(evaluation)
    return passed, asked
# Evaluate the default query engine on the first five generated questions
# using evaluate_query_engine, then print the score as passed/total.
vector_query_engine = vector_index.as_query_engine()
correct, total = evaluate_query_engine(vector_query_engine, eval_questions[:5])
print(f"score: {correct}/{total}")