Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 72ddf46fc7 | |||
| f57c0c84ef | |||
| 9ee24627c2 | |||
| 88761a5d10 |
@@ -1,8 +1,3 @@
|
|||||||
JIEBA_DATA=./nltk_data
|
|
||||||
NLTK_DATA=./nltk_data
|
|
||||||
SQLITE_DATABASE_URL=sqlite:///./source.db
|
|
||||||
DATA_SOURCE_CACHE=./restapi
|
|
||||||
|
|
||||||
# The Llama Cloud API key.
|
# The Llama Cloud API key.
|
||||||
# LLAMA_CLOUD_API_KEY=
|
# LLAMA_CLOUD_API_KEY=
|
||||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||||
@@ -85,4 +80,3 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
|
|||||||
- You can install any pip package (if it exists) by running a cell with pip install.
|
- You can install any pip package (if it exists) by running a cell with pip install.
|
||||||
"
|
"
|
||||||
|
|
||||||
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
|
|
||||||
@@ -1,8 +1,3 @@
|
|||||||
JIEBA_DATA=./nltk_data
|
|
||||||
NLTK_DATA=./nltk_data
|
|
||||||
SQLITE_DATABASE_URL=sqlite:///./source.db
|
|
||||||
DATA_SOURCE_CACHE=./restapi
|
|
||||||
|
|
||||||
# The Llama Cloud API key.
|
# The Llama Cloud API key.
|
||||||
# LLAMA_CLOUD_API_KEY=
|
# LLAMA_CLOUD_API_KEY=
|
||||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||||
@@ -116,4 +111,3 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
|
|||||||
- You can install any pip package (if it exists) by running a cell with pip install.
|
- You can install any pip package (if it exists) by running a cell with pip install.
|
||||||
"
|
"
|
||||||
|
|
||||||
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
|
|
||||||
+230
-233
@@ -3,7 +3,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
from aiostream import stream
|
from aiostream import stream
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
@@ -14,8 +13,7 @@ from llama_index.core.callbacks import CBEventType
|
|||||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||||
from llama_index.core.tools import ToolOutput
|
from llama_index.core.tools import ToolOutput
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from app.api.routers.request.base import userMng, conversations,message,parameter,feedback
|
from app.api.routers.request.base import userMng, conversations,message,parameter
|
||||||
from app.api.routers.request.baseConfig import *
|
|
||||||
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
|
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
|
||||||
from app.engine import get_chat_engine
|
from app.engine import get_chat_engine
|
||||||
import uuid
|
import uuid
|
||||||
@@ -26,138 +24,78 @@ api_router = r = APIRouter()
|
|||||||
v1_router = v = APIRouter()
|
v1_router = v = APIRouter()
|
||||||
|
|
||||||
class ChatCallbackEvent(BaseModel):
|
class ChatCallbackEvent(BaseModel):
|
||||||
event_type: ChatEventType
|
event_type: CBEventType
|
||||||
payload: Optional[Dict[str, Any]] = None
|
payload: Optional[Dict[str, Any]] = None
|
||||||
|
event_id: str = ""
|
||||||
|
|
||||||
def get_common_param(self)-> dict:
|
def get_retrieval_message(self) -> dict | None:
|
||||||
return {
|
if self.payload:
|
||||||
'event': self.event_type.name,
|
nodes = self.payload.get("nodes")
|
||||||
'conversation_id':self.payload.get("conversation_id"),
|
if nodes:
|
||||||
'message_id': self.payload.get("message_id"),
|
msg = f"根据查询检索到 {len(nodes)} 源文件"
|
||||||
'created_at': int(time.time()),
|
else:
|
||||||
'task_id': self.payload.get("task_id")
|
msg = f"查询检索中: '{self.payload.get('query_str')}'"
|
||||||
}
|
return {
|
||||||
|
"type": "events",
|
||||||
|
"data": {"title": msg},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def get_WorkflowStart_param(self) -> dict:
|
def get_tool_message(self) -> dict | None:
|
||||||
params = self.get_common_param()
|
func_call_args = self.payload.get("function_call")
|
||||||
params.update({
|
if func_call_args is not None and "tool" in self.payload:
|
||||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
tool = self.payload.get("tool")
|
||||||
'data':{
|
return {
|
||||||
"id": self.payload.get('workflow_run_id'),
|
"type": "events",
|
||||||
"workflow_id": self.payload.get('workflow_id'),
|
"data": {
|
||||||
"sequence_number": 1709,
|
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
|
||||||
"inputs": {
|
|
||||||
"sys.query": self.payload.get('query'),
|
|
||||||
"sys.files": [],
|
|
||||||
"sys.conversation_id": self.payload.get('conversation_id'),
|
|
||||||
"sys.user_id": self.payload.get('use_id')
|
|
||||||
},
|
},
|
||||||
"created_at": int(time.time())
|
|
||||||
}
|
}
|
||||||
})
|
|
||||||
return params
|
|
||||||
|
|
||||||
def get_WorkflowFinished_param(self) -> dict:
|
def _is_output_serializable(self, output: Any) -> bool:
|
||||||
params = self.get_common_param()
|
try:
|
||||||
params.update({
|
json.dumps(output)
|
||||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
return True
|
||||||
'data':{
|
except TypeError:
|
||||||
"id": self.payload.get('workflow_run_id'),
|
return False
|
||||||
"workflow_id": self.payload.get('workflow_id'),
|
|
||||||
"sequence_number": 1709,
|
|
||||||
"status": "succeeded",
|
|
||||||
"outputs": {
|
|
||||||
"answer": self.payload.get('response')
|
|
||||||
},
|
|
||||||
"error": '',
|
|
||||||
"elapsed_time": 36.03764106379822,
|
|
||||||
"total_tokens": 11707,
|
|
||||||
"total_steps": 10,
|
|
||||||
"created_by": {
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"user": self.payload.get('use_id')
|
|
||||||
},
|
|
||||||
"created_at": int(time.time()),
|
|
||||||
"finished_at": int(time.time()),
|
|
||||||
"files": []
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return params
|
|
||||||
|
|
||||||
def get_NodeStart_param(self) -> dict:
|
|
||||||
params = self.get_common_param()
|
|
||||||
params.update({
|
|
||||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
|
||||||
'data':{
|
|
||||||
"id": self.payload.get('nodeid'),
|
|
||||||
"node_id": self.payload.get('nodeid'),
|
|
||||||
"node_type": "http-request",
|
|
||||||
"title": self.payload.get('title'),
|
|
||||||
"index": self.payload.get('index'),
|
|
||||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
|
||||||
"inputs": '',
|
|
||||||
"created_at": 1724398751,
|
|
||||||
"extras": {}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return params
|
|
||||||
|
|
||||||
def get_NodeFinished_param(self) -> dict:
|
def get_agent_tool_response(self) -> dict | None:
|
||||||
params = self.get_common_param()
|
response = self.payload.get("response")
|
||||||
params.update({
|
if response is not None:
|
||||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
sources = response.sources
|
||||||
'data':{
|
for source in sources:
|
||||||
"id": self.payload.get('nodeid'),
|
# Return the tool response here to include the toolCall information
|
||||||
"node_id": self.payload.get('nodeid'),
|
if isinstance(source, ToolOutput):
|
||||||
"node_type": "http-request",
|
if self._is_output_serializable(source.raw_output):
|
||||||
"title": self.payload.get('title'),
|
output = source.raw_output
|
||||||
"index": self.payload.get('index'),
|
else:
|
||||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
output = source.content
|
||||||
"inputs": '',
|
|
||||||
"process_data": '',
|
|
||||||
"outputs": '',
|
|
||||||
"status": "succeeded",
|
|
||||||
"error": '',
|
|
||||||
"elapsed_time": 0.10402441816404462,
|
|
||||||
"execution_metadata": '',
|
|
||||||
"created_at": 1724398751,
|
|
||||||
"finished_at": 1724398751,
|
|
||||||
"files": []
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return params
|
|
||||||
|
|
||||||
def get_Message_param(self) -> dict:
|
return {
|
||||||
params = self.get_common_param()
|
"type": "tools",
|
||||||
params.update({
|
"data": {
|
||||||
'id':self.payload.get('message_id'),
|
"toolOutput": {
|
||||||
'answer':self.payload.get('answer')
|
"output": output,
|
||||||
})
|
"isError": source.is_error,
|
||||||
return params
|
},
|
||||||
|
"toolCall": {
|
||||||
def get_MessageEnd_param(self) -> dict:
|
"id": None, # There is no tool id in the ToolOutput
|
||||||
params = self.get_common_param()
|
"name": source.tool_name,
|
||||||
params.update({
|
"input": source.raw_input,
|
||||||
'id':self.payload.get('message_id'),
|
},
|
||||||
'metadata':self.payload.get('metadata')
|
},
|
||||||
})
|
}
|
||||||
return params
|
|
||||||
|
|
||||||
def to_response(self)-> dict|None:
|
def to_response(self):
|
||||||
try:
|
try:
|
||||||
match self.event_type:
|
match self.event_type:
|
||||||
case "workflow_started":
|
case "retrieve":
|
||||||
return self.get_WorkflowStart_param()
|
return self.get_retrieval_message()
|
||||||
case "workflow_finished":
|
case "function_call":
|
||||||
return self.get_WorkflowFinished_param()
|
return self.get_tool_message()
|
||||||
case "node_started":
|
case "agent_step":
|
||||||
return self.get_NodeStart_param()
|
return self.get_agent_tool_response()
|
||||||
case 'node_finished':
|
|
||||||
return self.get_NodeFinished_param()
|
|
||||||
case 'message':
|
|
||||||
return self.get_Message_param()
|
|
||||||
case 'message_end':
|
|
||||||
return self.get_MessageEnd_param()
|
|
||||||
case _:
|
case _:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -168,7 +106,9 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
_aqueue: asyncio.Queue
|
_aqueue: asyncio.Queue
|
||||||
is_done: bool = False
|
is_done: bool = False
|
||||||
|
|
||||||
def __init__(self,**params):
|
def __init__(
|
||||||
|
self,
|
||||||
|
):
|
||||||
"""Initialize the base callback handler."""
|
"""Initialize the base callback handler."""
|
||||||
ignored_events = [
|
ignored_events = [
|
||||||
# CBEventType.CHUNKING,
|
# CBEventType.CHUNKING,
|
||||||
@@ -179,23 +119,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
]
|
]
|
||||||
super().__init__(ignored_events, ignored_events)
|
super().__init__(ignored_events, ignored_events)
|
||||||
self._aqueue = asyncio.Queue()
|
self._aqueue = asyncio.Queue()
|
||||||
self._response:str = ''
|
|
||||||
self._params:Dict[str,Any] = params
|
|
||||||
self._nodeStack:deque = deque()
|
|
||||||
|
|
||||||
#添加工作流开始事件
|
|
||||||
data:ChatRequestData = self._params['data']
|
|
||||||
args:Dict[str,Any] = self._params['ids']
|
|
||||||
args.update(
|
|
||||||
{
|
|
||||||
'use_id': data.user,
|
|
||||||
'query': data.query,
|
|
||||||
'conversation_id': data.conversation_id
|
|
||||||
}
|
|
||||||
)
|
|
||||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
|
|
||||||
if wf_event.to_response() is not None:
|
|
||||||
self._aqueue.put_nowait(wf_event)
|
|
||||||
|
|
||||||
def on_event_start(
|
def on_event_start(
|
||||||
self,
|
self,
|
||||||
@@ -206,21 +129,9 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
) -> str:
|
) -> str:
|
||||||
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||||
|
|
||||||
self._nodeStack.append(event_id)
|
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||||
nindex = self._nodeStack.count() - 1
|
if event.to_response() is not None:
|
||||||
args:Dict[str,Any] = self._params['ids']
|
self._aqueue.put_nowait(event)
|
||||||
args.update(
|
|
||||||
{
|
|
||||||
'nodeid':event_id,
|
|
||||||
'title':event_type.name,
|
|
||||||
'index':nindex + 1,
|
|
||||||
'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else ''
|
|
||||||
}
|
|
||||||
)
|
|
||||||
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_START,payload = args)
|
|
||||||
if nd_event.to_response() is not None:
|
|
||||||
self._aqueue.put_nowait(nd_event)
|
|
||||||
|
|
||||||
|
|
||||||
def on_event_end(
|
def on_event_end(
|
||||||
self,
|
self,
|
||||||
@@ -230,25 +141,9 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||||
|
event = ChatCallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||||
#self.response = payload.get("response","")
|
if event.to_response() is not None:
|
||||||
args:Dict[str,Any] = self._params['ids']
|
self._aqueue.put_nowait(event)
|
||||||
nodeID = self._nodeStack[-1]
|
|
||||||
if nodeID == event_id:
|
|
||||||
nindex = self._nodeStack.count() - 1
|
|
||||||
args.update(
|
|
||||||
{
|
|
||||||
'nodeid':event_id,
|
|
||||||
'title':event_type.name,
|
|
||||||
'index':nindex + 1,
|
|
||||||
'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else ''
|
|
||||||
}
|
|
||||||
)
|
|
||||||
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_FINISHED,payload = args)
|
|
||||||
if nd_event.to_response() is not None:
|
|
||||||
self._aqueue.put_nowait(nd_event)
|
|
||||||
self._nodeStack.pop()
|
|
||||||
|
|
||||||
|
|
||||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||||
"""No-op."""
|
"""No-op."""
|
||||||
@@ -261,23 +156,6 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""No-op."""
|
"""No-op."""
|
||||||
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||||
data:ChatRequestData = self._params['data']
|
|
||||||
args:Dict[str,Any] = self._params['ids']
|
|
||||||
args.update(
|
|
||||||
{
|
|
||||||
'response':self._response,
|
|
||||||
'conversation_id': data.conversation_id
|
|
||||||
}
|
|
||||||
)
|
|
||||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
|
|
||||||
if wf_event.to_response() is not None:
|
|
||||||
self._aqueue.put_nowait(wf_event)
|
|
||||||
|
|
||||||
|
|
||||||
args:Dict[str,Any] = self._params['ids']
|
|
||||||
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
|
|
||||||
if msgEnt_event.to_response() is not None:
|
|
||||||
self._aqueue.put_nowait(msgEnt_event)
|
|
||||||
|
|
||||||
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
||||||
while not self._aqueue.empty() or not self.is_done:
|
while not self._aqueue.empty() or not self.is_done:
|
||||||
@@ -295,26 +173,95 @@ class IDManager:
|
|||||||
"workflow_id": str(uuid.uuid4())
|
"workflow_id": str(uuid.uuid4())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class DifyChatResponseEvent(BaseModel):
|
||||||
|
event: str
|
||||||
|
conversation_id: str
|
||||||
|
message_id: str
|
||||||
|
created_at: int = int(time.time())
|
||||||
|
task_id: str
|
||||||
|
|
||||||
|
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||||
|
workflow_run_id:str
|
||||||
|
data:Dict[str,Any]
|
||||||
|
def __init__(self,**args):
|
||||||
|
args['data'] = {
|
||||||
|
"id": args['workflow_run_id'],
|
||||||
|
"workflow_id": args['workflow_id'],
|
||||||
|
"sequence_number": 1709,
|
||||||
|
"inputs": {
|
||||||
|
"sys.query": args['query'],
|
||||||
|
"sys.files": [],
|
||||||
|
"sys.conversation_id": args['conversation_id'],
|
||||||
|
"sys.user_id": args['use_id']
|
||||||
|
},
|
||||||
|
"created_at": int(time.time())
|
||||||
|
}
|
||||||
|
args['event'] = 'workflow_started'
|
||||||
|
super().__init__(**args)
|
||||||
|
|
||||||
|
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||||
|
workflow_run_id:str
|
||||||
|
data:Dict[str,Any]
|
||||||
|
def __init__(self,**args):
|
||||||
|
args['event'] = 'workflow_finished'
|
||||||
|
args['data'] = {
|
||||||
|
"id": args['workflow_run_id'],
|
||||||
|
"workflow_id": args['workflow_id'],
|
||||||
|
"sequence_number": 1709,
|
||||||
|
"status": "succeeded",
|
||||||
|
"outputs": {
|
||||||
|
"answer": args['response']
|
||||||
|
},
|
||||||
|
"error": '',
|
||||||
|
"elapsed_time": 36.03764106379822,
|
||||||
|
"total_tokens": 11707,
|
||||||
|
"total_steps": 10,
|
||||||
|
"created_by": {
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"user": args['use_id']
|
||||||
|
},
|
||||||
|
"created_at": int(time.time()),
|
||||||
|
"finished_at": int(time.time()),
|
||||||
|
"files": []
|
||||||
|
}
|
||||||
|
super().__init__(**args)
|
||||||
|
|
||||||
|
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||||
|
id:str
|
||||||
|
answer:str
|
||||||
|
def __init__(self,**args):
|
||||||
|
args['id'] = args['message_id']
|
||||||
|
args['event'] = 'message'
|
||||||
|
super().__init__(**args)
|
||||||
|
|
||||||
|
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
|
||||||
|
id:str
|
||||||
|
metadata:Dict[str,Any] = {}
|
||||||
|
def __init__(self,**args):
|
||||||
|
args['id'] = args['message_id']
|
||||||
|
args['event'] = 'message_end'
|
||||||
|
super().__init__(**args)
|
||||||
|
|
||||||
class ChatStreamResponse(StreamingResponse):
|
class ChatStreamResponse(StreamingResponse):
|
||||||
TEXT_PREFIX = "data: "
|
TEXT_PREFIX = "data: "
|
||||||
DATA_PREFIX = "data: "
|
DATA_PREFIX = "data: "
|
||||||
ids:Dict[str,Any] = {}
|
|
||||||
data:ChatRequestData = None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_Message(cls, token: str):
|
def convert_text(cls, token: str):
|
||||||
params = cls.ids
|
# Escape newlines and double quotes to avoid breaking the stream
|
||||||
params.update({
|
#token = json.dumps(token)
|
||||||
'answer':token,
|
|
||||||
'conversation_id':cls.data.conversation_id
|
#return f"data: {{"event": "message", "conversation_id": "80d85523-de92-4b9d-aca0-c48a5eacb068", "message_id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "created_at": 1724406492, "task_id": "802f3064-030d-42ac-a882-0e1293712d04", "id": "16a06b1b-a89b-49c0-bc15-123bd999f6d6", "answer": "{token}"}}"
|
||||||
})
|
return "\n"
|
||||||
event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE,payload = params)
|
|
||||||
data_str = json.dumps(event.to_response())
|
@classmethod
|
||||||
|
def convert_data(cls, data: dict):
|
||||||
|
data_str = json.dumps(data)
|
||||||
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_Event(cls, data: dict):
|
def convert_event(cls, event: DifyChatResponseEvent):
|
||||||
data_str = json.dumps(data)
|
data_str = json.dumps(event.dict())
|
||||||
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -322,11 +269,8 @@ class ChatStreamResponse(StreamingResponse):
|
|||||||
request: Request,
|
request: Request,
|
||||||
event_handler: ChatEventCallbackHandler,
|
event_handler: ChatEventCallbackHandler,
|
||||||
response: StreamingAgentChatResponse,
|
response: StreamingAgentChatResponse,
|
||||||
data: ChatRequestData,
|
data: ChatRequestData
|
||||||
ids:Dict[str,Any]
|
|
||||||
):
|
):
|
||||||
ChatStreamResponse.ids = ids
|
|
||||||
ChatStreamResponse.data = data
|
|
||||||
content = ChatStreamResponse.content_generator(
|
content = ChatStreamResponse.content_generator(
|
||||||
request, event_handler, response, data
|
request, event_handler, response, data
|
||||||
)
|
)
|
||||||
@@ -340,26 +284,41 @@ class ChatStreamResponse(StreamingResponse):
|
|||||||
response: StreamingAgentChatResponse,
|
response: StreamingAgentChatResponse,
|
||||||
data: ChatRequestData
|
data: ChatRequestData
|
||||||
):
|
):
|
||||||
|
ids = IDManager().createID()
|
||||||
# Yield the text response
|
# Yield the text response
|
||||||
async def _chat_response_generator():
|
async def _chat_response_generator():
|
||||||
final_response = ""
|
final_response = ""
|
||||||
async for token in response.async_response_gen():
|
async for token in response.async_response_gen():
|
||||||
final_response += token
|
final_response += token
|
||||||
yield ChatStreamResponse.convert_Message(token)
|
args = ids
|
||||||
|
args['answer'] = token
|
||||||
|
args['conversation_id'] = data.conversation_id
|
||||||
|
event = Message_DifyChatResponseEvent(**args)
|
||||||
|
yield ChatStreamResponse.convert_event(event)
|
||||||
|
#yield ChatStreamResponse.convert_text(token)
|
||||||
|
|
||||||
# 存储消息历史
|
# 存储消息历史
|
||||||
message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response)
|
message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response)
|
||||||
|
|
||||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||||
event_handler.is_done = True
|
event_handler.is_done = True
|
||||||
|
# 发送工作流结束事件
|
||||||
|
args = ids
|
||||||
|
args['response'] = final_response
|
||||||
|
args['conversation_id'] = data.conversation_id
|
||||||
|
wf_event = Workflow_finished_DifyChatResponseEvent(**args)
|
||||||
|
yield ChatStreamResponse.convert_event(wf_event)
|
||||||
|
|
||||||
|
msgEnt_event = MessageEnd_DifyChatResponseEvent(**ids)
|
||||||
|
yield ChatStreamResponse.convert_event(msgEnt_event)
|
||||||
|
|
||||||
|
|
||||||
# Yield the events from the event handler
|
# Yield the events from the event handler
|
||||||
async def _event_generator():
|
async def _event_generator():
|
||||||
async for event in event_handler.async_event_gen():
|
async for event in event_handler.async_event_gen():
|
||||||
event_response = event.to_response()
|
event_response = event.to_response()
|
||||||
if event_response is not None:
|
if event_response is not None:
|
||||||
yield ChatStreamResponse.convert_Event(event_response)
|
yield ChatStreamResponse.convert_text("")
|
||||||
|
|
||||||
combine = stream.merge(_chat_response_generator(), _event_generator())
|
combine = stream.merge(_chat_response_generator(), _event_generator())
|
||||||
is_stream_started = False
|
is_stream_started = False
|
||||||
@@ -368,11 +327,25 @@ class ChatStreamResponse(StreamingResponse):
|
|||||||
if not is_stream_started:
|
if not is_stream_started:
|
||||||
is_stream_started = True
|
is_stream_started = True
|
||||||
|
|
||||||
|
# 发送工作流开始事件
|
||||||
|
args = ids
|
||||||
|
args['use_id'] = data.user
|
||||||
|
args['query'] = data.query
|
||||||
|
args['conversation_id'] = data.conversation_id
|
||||||
|
wf_event = Workflow_started_DifyChatResponseEvent(**args)
|
||||||
|
yield ChatStreamResponse.convert_event(wf_event)
|
||||||
|
|
||||||
|
# Stream a blank message to start the stream
|
||||||
|
# 发送一个空消息事件
|
||||||
|
#yield ChatStreamResponse.convert_text("")
|
||||||
|
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@v.post("/chat-messages")
|
@v.post("/chat-messages")
|
||||||
async def post_conversations(request: Request, data: ChatRequestData):
|
async def post_conversations(request: Request, data: ChatRequestData):
|
||||||
userMng.findNoExistCreate(data.user)
|
userMng.findNoExistCreate(data.user)
|
||||||
@@ -392,15 +365,14 @@ async def post_conversations(request: Request, data: ChatRequestData):
|
|||||||
chat_engine = get_chat_engine(filters=filters, params=params)
|
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||||
|
|
||||||
# 启动聊天事件监听
|
# 启动聊天事件监听
|
||||||
ids = IDManager().createID()
|
event_handler = ChatEventCallbackHandler()
|
||||||
event_handler = ChatEventCallbackHandler(ids = ids,data = data)
|
|
||||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||||
|
|
||||||
# 执行异步聊天
|
# 执行异步聊天
|
||||||
response = await chat_engine.astream_chat(data.query)
|
response = await chat_engine.astream_chat(data.query)
|
||||||
|
|
||||||
# 返回异步消息回应
|
# 返回异步消息回应
|
||||||
return ChatStreamResponse(request, event_handler, response, data,ids)
|
return ChatStreamResponse(request, event_handler, response, data)
|
||||||
|
|
||||||
@v.get("/messages")
|
@v.get("/messages")
|
||||||
async def query_messages(user:str, conversation_id:str):
|
async def query_messages(user:str, conversation_id:str):
|
||||||
@@ -416,9 +388,8 @@ async def query_messages(user:str, conversation_id:str):
|
|||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
res = record.dict()
|
res = record.dict()
|
||||||
feeds = feedback().query(res['id'])
|
|
||||||
res["message_files"] = []
|
res["message_files"] = []
|
||||||
res["feedback"] = {'rating':feeds['rating'] } if feeds != None else ''
|
res["feedback"] = ''
|
||||||
res["retriever_resources"] = []
|
res["retriever_resources"] = []
|
||||||
res["created_at"] = 1723444905
|
res["created_at"] = 1723444905
|
||||||
res["agent_thoughts"] = []
|
res["agent_thoughts"] = []
|
||||||
@@ -469,22 +440,48 @@ async def query_conversations(user:str, first_id:str = None, limit:str = None, p
|
|||||||
async def query_parameters(user:str):
|
async def query_parameters(user:str):
|
||||||
params = parameter().get(user)
|
params = parameter().get(user)
|
||||||
if len(params) == 0:
|
if len(params) == 0:
|
||||||
params = BaseConfig().ParamterCfg()
|
params = {
|
||||||
|
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
|
||||||
|
"suggested_questions": [],
|
||||||
|
"suggested_questions_after_answer": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"speech_to_text": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"text_to_speech": {
|
||||||
|
"enabled": False,
|
||||||
|
"language": "",
|
||||||
|
"voice": ""
|
||||||
|
},
|
||||||
|
"retriever_resource": {
|
||||||
|
"enabled": True
|
||||||
|
},
|
||||||
|
"annotation_reply": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"more_like_this": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"user_input_form": [],
|
||||||
|
"sensitive_word_avoidance": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"file_upload": {
|
||||||
|
"image": {
|
||||||
|
"enabled": False,
|
||||||
|
"number_limits": 3,
|
||||||
|
"transfer_methods": [
|
||||||
|
"remote_url"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"system_parameters": {
|
||||||
|
"image_file_size_limit": "10"
|
||||||
|
}
|
||||||
|
}
|
||||||
return params
|
return params
|
||||||
|
|
||||||
@v.post("/messages/{message_id}/feedbacks")
|
|
||||||
async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]):
|
|
||||||
if params['rating'] =='null':
|
|
||||||
feedback().delete(message_id)
|
|
||||||
else:
|
|
||||||
condition = {'id':message_id}
|
|
||||||
results = message().query(**condition)
|
|
||||||
if len(results) > 0:
|
|
||||||
result = results[0]
|
|
||||||
feedback().add(message_id=message_id,query=result['query'],
|
|
||||||
answer=result['answer'],rating=params['rating'])
|
|
||||||
|
|
||||||
@r.post("")
|
@r.post("")
|
||||||
def upload_file(request: ChatFileUploadRequest) -> List[str]:
|
def upload_file(request: ChatFileUploadRequest) -> List[str]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ class conversations:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def add(self,id:str, user_id:str, name:str):
|
def add(self,id:str, user_id:str, name:str):
|
||||||
template = BaseConfig().ConversationCfg()
|
template = BaseConfig.ConversationCfg
|
||||||
template['id'] = id
|
template['id'] = id
|
||||||
template['user_id'] = user_id
|
template['user_id'] = user_id
|
||||||
template['name'] = name
|
template['name'] = name
|
||||||
@@ -111,7 +111,7 @@ class message:
|
|||||||
return datas
|
return datas
|
||||||
|
|
||||||
def add(self,user_id:str,conversation_id:str,query:str,answer:str):
|
def add(self,user_id:str,conversation_id:str,query:str,answer:str):
|
||||||
template = BaseConfig.MessageCfg()
|
template = BaseConfig.MessageCfg
|
||||||
template['id'] = str(uuid.uuid4())
|
template['id'] = str(uuid.uuid4())
|
||||||
template['user_id'] = user_id
|
template['user_id'] = user_id
|
||||||
template['conversation_id'] = conversation_id
|
template['conversation_id'] = conversation_id
|
||||||
@@ -122,34 +122,4 @@ class message:
|
|||||||
def delete(self,user_id:str):
|
def delete(self,user_id:str):
|
||||||
dbManage.delete(self._tableName,user_id = user_id)
|
dbManage.delete(self._tableName,user_id = user_id)
|
||||||
|
|
||||||
def query(self,**condition):
|
|
||||||
results = []
|
|
||||||
records = dbManage.query(self._tableName,**condition)
|
|
||||||
for record in records:
|
|
||||||
results.append(record.dict())
|
|
||||||
return results
|
|
||||||
|
|
||||||
class feedback:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._tableName = 'feedbacks'
|
|
||||||
dbManage.createTable(self._tableName)
|
|
||||||
|
|
||||||
def add(self,message_id:str,query:str,answer:str,rating:str):
|
|
||||||
record = {
|
|
||||||
'message_id': message_id,
|
|
||||||
'query': query,
|
|
||||||
'answer': answer,
|
|
||||||
'rating': rating,
|
|
||||||
}
|
|
||||||
dbManage.addRecord(self._tableName,record)
|
|
||||||
|
|
||||||
def delete(self,message_id:str):
|
|
||||||
cond = {'message_id':message_id}
|
|
||||||
dbManage.delete(self._tableName,**cond)
|
|
||||||
|
|
||||||
def query(self,message_id:str):
|
|
||||||
cond = {'message_id':message_id}
|
|
||||||
records = dbManage.query(self._tableName,**cond)
|
|
||||||
if len(records) > 0:
|
|
||||||
return records[0].dict()
|
|
||||||
return None
|
|
||||||
@@ -1,80 +1,62 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
import os
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
class BaseConfig(BaseModel):
|
class BaseConfig:
|
||||||
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
|
ParamterCfg = {
|
||||||
|
"opening_statement": "您好,我是配网D3造价软件小助手,您可以问我有关配网造价软件的相关问题!",
|
||||||
def ParamterCfg(self):
|
"suggested_questions": [],
|
||||||
questions = os.getenv("CONVERSATION_STARTERS", "dev")
|
"suggested_questions_after_answer": {
|
||||||
return{
|
"enabled": False
|
||||||
"opening_statement": self.projectInfo,
|
},
|
||||||
"suggested_questions": questions.split('\n'),
|
"speech_to_text": {
|
||||||
"suggested_questions_after_answer": {
|
"enabled": False
|
||||||
"enabled": False
|
},
|
||||||
},
|
"text_to_speech": {
|
||||||
"speech_to_text": {
|
"enabled": False,
|
||||||
"enabled": False
|
"language": "",
|
||||||
},
|
"voice": ""
|
||||||
"text_to_speech": {
|
},
|
||||||
|
"retriever_resource": {
|
||||||
|
"enabled": True
|
||||||
|
},
|
||||||
|
"annotation_reply": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"more_like_this": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"user_input_form": [],
|
||||||
|
"sensitive_word_avoidance": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"file_upload": {
|
||||||
|
"image": {
|
||||||
"enabled": False,
|
"enabled": False,
|
||||||
"language": "",
|
"number_limits": 3,
|
||||||
"voice": ""
|
"transfer_methods": [
|
||||||
},
|
"remote_url"
|
||||||
"retriever_resource": {
|
]
|
||||||
"enabled": True
|
|
||||||
},
|
|
||||||
"annotation_reply": {
|
|
||||||
"enabled": False
|
|
||||||
},
|
|
||||||
"more_like_this": {
|
|
||||||
"enabled": False
|
|
||||||
},
|
|
||||||
"user_input_form": [],
|
|
||||||
"sensitive_word_avoidance": {
|
|
||||||
"enabled": False
|
|
||||||
},
|
|
||||||
"file_upload": {
|
|
||||||
"image": {
|
|
||||||
"enabled": False,
|
|
||||||
"number_limits": 3,
|
|
||||||
"transfer_methods": [
|
|
||||||
"remote_url"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"system_parameters": {
|
|
||||||
"image_file_size_limit": "10"
|
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"system_parameters": {
|
||||||
|
"image_file_size_limit": "10"
|
||||||
}
|
}
|
||||||
|
}
|
||||||
def ConversationCfg(self):
|
|
||||||
return{
|
|
||||||
"id": "",
|
|
||||||
'user_id':'',
|
|
||||||
"name": "",
|
|
||||||
"inputs": {},
|
|
||||||
"status": "normal",
|
|
||||||
"introduction": self.projectInfo,
|
|
||||||
"created_at":''
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
ConversationCfg = {
|
||||||
def MessageCfg(cls):
|
"id": "",
|
||||||
return {
|
'user_id':'',
|
||||||
|
"name": "",
|
||||||
|
"inputs": {},
|
||||||
|
"status": "normal",
|
||||||
|
"introduction": ParamterCfg['opening_statement'],
|
||||||
|
"created_at":''
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MessageCfg = {
|
||||||
"id": "",
|
"id": "",
|
||||||
'user_id':'',
|
'user_id':'',
|
||||||
"conversation_id": "",
|
"conversation_id": "",
|
||||||
"inputs": {},
|
"inputs": {},
|
||||||
"query": "",
|
"query": "",
|
||||||
"answer": ""
|
"answer": ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ChatEventType(str, Enum):
|
|
||||||
WORKFLOW_START = "workflow_started"
|
|
||||||
WORKFLOW_FINISHED = "workflow_finished"
|
|
||||||
NODE_START = "node_started"
|
|
||||||
NODE_FINISHED = "node_finished"
|
|
||||||
MESSAGE = "message"
|
|
||||||
MESSAGE_END = "message_end"
|
|
||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import create_engine, Column, String, Integer, JSON,Float
|
from sqlalchemy import create_engine, Column, String, Integer, JSON
|
||||||
from sqlalchemy.engine.reflection import Inspector
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||||
|
|
||||||
@@ -24,6 +24,10 @@ class ConversationOrm(Base):
|
|||||||
if 'name' in data:
|
if 'name' in data:
|
||||||
self.name = data['name']
|
self.name = data['name']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class UserOrm(Base):
|
class UserOrm(Base):
|
||||||
__tablename__ = "user"
|
__tablename__ = "user"
|
||||||
|
|
||||||
@@ -47,14 +51,6 @@ class MessagesOrm(Base):
|
|||||||
query = Column(String)
|
query = Column(String)
|
||||||
answer = Column(String)
|
answer = Column(String)
|
||||||
|
|
||||||
class FeedBackOrm(Base):
|
|
||||||
__tablename__ = "feedbacks"
|
|
||||||
|
|
||||||
message_id = Column(String,primary_key=True)
|
|
||||||
query = Column(String)
|
|
||||||
answer = Column(String)
|
|
||||||
rating = Column(String)
|
|
||||||
|
|
||||||
#数据结构
|
#数据结构
|
||||||
class ConversationModel(BaseModel):
|
class ConversationModel(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
@@ -65,6 +61,7 @@ class ConversationModel(BaseModel):
|
|||||||
created_at: int
|
created_at: int
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
#orm_mode = True
|
||||||
from_attributes=True
|
from_attributes=True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -76,6 +73,7 @@ class UserModel(BaseModel):
|
|||||||
createtime: str
|
createtime: str
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
#orm_mode = True
|
||||||
from_attributes=True
|
from_attributes=True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -88,6 +86,7 @@ class ParametersModel(BaseModel):
|
|||||||
value : Dict[str, Any]
|
value : Dict[str, Any]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
#orm_mode = True
|
||||||
from_attributes=True
|
from_attributes=True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -102,25 +101,13 @@ class MessagesModel(BaseModel):
|
|||||||
answer : str
|
answer : str
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
#orm_mode = True
|
||||||
from_attributes=True
|
from_attributes=True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def orm(cls):
|
def orm(cls):
|
||||||
return MessagesOrm
|
return MessagesOrm
|
||||||
|
|
||||||
class FeedBackModel(BaseModel):
|
|
||||||
message_id :str
|
|
||||||
query :str
|
|
||||||
answer :str
|
|
||||||
rating :str
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
from_attributes=True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def orm(cls):
|
|
||||||
return FeedBackOrm
|
|
||||||
|
|
||||||
class DBManager:
|
class DBManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
|
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
class ChatRequestData(BaseModel):
|
class ChatRequestData(BaseModel):
|
||||||
inputs: Dict[str,Any]
|
inputs: Dict[str,Any]
|
||||||
@@ -12,6 +12,4 @@ class ChatRequestData(BaseModel):
|
|||||||
conversation_id: str = None
|
conversation_id: str = None
|
||||||
|
|
||||||
class ChatFileUploadRequest(BaseModel):
|
class ChatFileUploadRequest(BaseModel):
|
||||||
base64: str
|
base64: str
|
||||||
|
|
||||||
|
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
from typing import Any, Dict, List, Union, Callable, NamedTuple
|
from typing import Any, Dict, List, Union, Callable, NamedTuple
|
||||||
from bm25s.tokenization import *
|
from bm25s.tokenization import *
|
||||||
|
|
||||||
@@ -9,12 +8,9 @@ except ImportError:
|
|||||||
def tqdm(iterable, *args, **kwargs):
|
def tqdm(iterable, *args, **kwargs):
|
||||||
return iterable
|
return iterable
|
||||||
|
|
||||||
import jieba
|
|
||||||
jiebapath = os.environ.get("JIEBA_DATA", "")
|
|
||||||
jieba.set_dictionary(os.path.join(jiebapath, 'dict.txt')) #设置字典
|
|
||||||
jieba.initialize() #初始化jeiba
|
|
||||||
|
|
||||||
def chinese_tokenizer(text: str) -> List[str]:
|
def chinese_tokenizer(text: str) -> List[str]:
|
||||||
|
import jieba
|
||||||
from nltk.corpus import stopwords
|
from nltk.corpus import stopwords
|
||||||
tokens = jieba.lcut(text)
|
tokens = jieba.lcut(text)
|
||||||
return [token for token in tokens if token not in stopwords.words('chinese')]
|
return [token for token in tokens if token not in stopwords.words('chinese')]
|
||||||
|
|||||||
@@ -3,10 +3,11 @@ from typing import Dict
|
|||||||
|
|
||||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||||
from llama_index.core.settings import Settings
|
from llama_index.core.settings import Settings
|
||||||
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
|
||||||
from llama_index.llms.xinference import Xinference
|
from llama_index.llms.xinference import Xinference
|
||||||
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
||||||
|
|
||||||
|
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||||
|
|
||||||
|
|
||||||
def get_node_postprocessors():
|
def get_node_postprocessors():
|
||||||
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
||||||
@@ -231,4 +232,4 @@ def init_mistral():
|
|||||||
#
|
#
|
||||||
# Settings.llm = MistralAI(model=os.getenv("MODEL"))
|
# Settings.llm = MistralAI(model=os.getenv("MODEL"))
|
||||||
# Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
|
# Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
|
||||||
pass
|
pass
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
-349046
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
@@ -17,7 +17,7 @@ aiostream = "^0.6.2"
|
|||||||
llama-index = "0.10.63"
|
llama-index = "0.10.63"
|
||||||
cachetools = "^5.3.3"
|
cachetools = "^5.3.3"
|
||||||
protobuf = "4.25.4"
|
protobuf = "4.25.4"
|
||||||
nltk = "^3.9.1"
|
nltk = "^3.8.2"
|
||||||
jieba = "^0.42.1"
|
jieba = "^0.42.1"
|
||||||
|
|
||||||
#arize-phoenix = "^4.12.0"
|
#arize-phoenix = "^4.12.0"
|
||||||
@@ -35,7 +35,6 @@ chroma="^0.2.0"
|
|||||||
llama-index-vector-stores-chroma = "^0.1.10"
|
llama-index-vector-stores-chroma = "^0.1.10"
|
||||||
llama-index-readers-json = "^0.1.5"
|
llama-index-readers-json = "^0.1.5"
|
||||||
llama-index-retrievers-bm25 = "^0.2.2"
|
llama-index-retrievers-bm25 = "^0.2.2"
|
||||||
llama-index-experimental = "^0.2.0"
|
|
||||||
|
|
||||||
duckduckgo_search = "^6.2.6"
|
duckduckgo_search = "^6.2.6"
|
||||||
|
|
||||||
@@ -63,12 +62,6 @@ version = "^0.8"
|
|||||||
version = "0.0.7"
|
version = "0.0.7"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
[[tool.poetry.source]]
|
|
||||||
name = "mirrors"
|
|
||||||
url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
|
|
||||||
priority = "default"
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = [ "poetry-core" ]
|
requires = [ "poetry-core" ]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
import nest_asyncio
|
|
||||||
nest_asyncio.apply()
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
from llama_index.core import VectorStoreIndex
|
|
||||||
from llama_index.core.evaluation import (
|
|
||||||
FaithfulnessEvaluator,
|
|
||||||
DatasetGenerator,
|
|
||||||
CorrectnessEvaluator,
|
|
||||||
SemanticSimilarityEvaluator,
|
|
||||||
)
|
|
||||||
from llama_index.experimental.param_tuner import ParamTuner
|
|
||||||
from llama_index.experimental.param_tuner.base import RunResult
|
|
||||||
from llama_index.llms.openai import OpenAI
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# 初始化环境
|
|
||||||
from app.observability import init_observability
|
|
||||||
from app.settings import init_settings
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
init_settings()
|
|
||||||
init_observability()
|
|
||||||
|
|
||||||
# 读取文档
|
|
||||||
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
|
||||||
|
|
||||||
# 参数字典
|
|
||||||
param_dict = {
|
|
||||||
"chunk_size": [512, 1024],
|
|
||||||
"top_k": [1, 5],
|
|
||||||
"temperature": [0.1, 1.0]
|
|
||||||
}
|
|
||||||
|
|
||||||
# 辅助函数
|
|
||||||
def _build_index(chunk_size, documents):
|
|
||||||
# 构建索引
|
|
||||||
splitter = SentenceSplitter(chunk_size=chunk_size)
|
|
||||||
vector_index = VectorStoreIndex.from_documents(
|
|
||||||
documents, transformations=[splitter],
|
|
||||||
)
|
|
||||||
return vector_index
|
|
||||||
|
|
||||||
# 评估函数
|
|
||||||
def evaluate_query_engine(query_engine, questions):
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
correct, total = loop.run_until_complete(_evaluate_query_engine_async(query_engine, questions))
|
|
||||||
return correct, total
|
|
||||||
|
|
||||||
async def _evaluate_query_engine_async(query_engine, questions):
|
|
||||||
c = [query_engine.aquery(q) for q in questions]
|
|
||||||
gathering_future = asyncio.gather(*c)
|
|
||||||
results = await gathering_future
|
|
||||||
|
|
||||||
total_correct = 0
|
|
||||||
for r in results:
|
|
||||||
eval_result = (
|
|
||||||
1 if FaithfulnessEvaluator().evaluate_response(response=r).passing else 0
|
|
||||||
)
|
|
||||||
total_correct += eval_result
|
|
||||||
|
|
||||||
return total_correct, len(results)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 生成问题
|
|
||||||
question_generator = DatasetGenerator.from_documents(documents)
|
|
||||||
eval_questions = question_generator.generate_questions_from_nodes(1) # 假设生成10个问题
|
|
||||||
|
|
||||||
# 打印生成的问题
|
|
||||||
for i, q in enumerate(eval_questions, start=1):
|
|
||||||
print(f"问题 {i}: {q}")
|
|
||||||
|
|
||||||
# 目标函数
|
|
||||||
def objective_function(params_dict, documents, questions):
|
|
||||||
chunk_size = params_dict["chunk_size"]
|
|
||||||
top_k = params_dict["top_k"]
|
|
||||||
temperature = params_dict["temperature"]
|
|
||||||
|
|
||||||
# 构建索引
|
|
||||||
vector_index = _build_index(chunk_size, documents)
|
|
||||||
|
|
||||||
# 查询引擎
|
|
||||||
query_engine = vector_index.as_query_engine(
|
|
||||||
similarity_top_k=top_k, temperature=temperature
|
|
||||||
)
|
|
||||||
|
|
||||||
# 评估查询引擎
|
|
||||||
correct, total = 0, len(questions)
|
|
||||||
question_answers = [] # 添加列表来收集问题和答案
|
|
||||||
|
|
||||||
for question in questions:
|
|
||||||
response = query_engine.query(question)
|
|
||||||
if response is not None:
|
|
||||||
question_answers.append((question, response.response))
|
|
||||||
eval_result = FaithfulnessEvaluator().evaluate_response(response=response, query_str=question)
|
|
||||||
if eval_result.passing:
|
|
||||||
correct += 1
|
|
||||||
|
|
||||||
# 计算分数
|
|
||||||
score = correct / total if total > 0 else 0
|
|
||||||
return RunResult(score=score, params=params_dict, question_answers=question_answers)
|
|
||||||
|
|
||||||
# 创建 ParamTuner 实例
|
|
||||||
param_tuner = ParamTuner(
|
|
||||||
param_fn=lambda params_dict: objective_function(params_dict, documents, eval_questions),
|
|
||||||
param_dict=param_dict,
|
|
||||||
show_progress=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 调用 tune 方法
|
|
||||||
results = param_tuner.tune()
|
|
||||||
best_result = results.best_run_result
|
|
||||||
best_top_k = best_result.params["top_k"]
|
|
||||||
best_chunk_size = best_result.params["chunk_size"]
|
|
||||||
best_temperature = best_result.params["temperature"]
|
|
||||||
print(f"得分: {best_result.score}")
|
|
||||||
print(f"Top-k: {best_top_k}")
|
|
||||||
print(f"文本块大小: {best_chunk_size}")
|
|
||||||
print(f"温度: {best_temperature}")
|
|
||||||
|
|
||||||
# 使用最佳参数再次运行查询引擎,并打印问题与答案
|
|
||||||
best_vector_index = _build_index(best_chunk_size, documents)
|
|
||||||
best_query_engine = best_vector_index.as_query_engine(
|
|
||||||
similarity_top_k=best_top_k, temperature=best_temperature
|
|
||||||
)
|
|
||||||
|
|
||||||
best_question_answers = []
|
|
||||||
for question in eval_questions:
|
|
||||||
response = best_query_engine.query(question)
|
|
||||||
if response is not None:
|
|
||||||
best_question_answers.append((question, response.response))
|
|
||||||
|
|
||||||
# 打印最佳参数下的问题与答案
|
|
||||||
for i, (question, answer) in enumerate(best_question_answers, start=1):
|
|
||||||
print(f"最佳参数 - 问题 {i}: {question}\n答案: {answer}\n")
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
from app.observability import init_observability
|
|
||||||
from app.settings import init_settings
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
import nest_asyncio
|
|
||||||
nest_asyncio.apply()
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
from llama_index.core import (
|
|
||||||
VectorStoreIndex,
|
|
||||||
SimpleDirectoryReader,
|
|
||||||
Response,
|
|
||||||
)
|
|
||||||
from llama_index.core.evaluation import (
|
|
||||||
FaithfulnessEvaluator,
|
|
||||||
DatasetGenerator,
|
|
||||||
CorrectnessEvaluator,
|
|
||||||
SemanticSimilarityEvaluator,)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
init_settings()
|
|
||||||
init_observability()
|
|
||||||
|
|
||||||
faith_evaluator_qwen = FaithfulnessEvaluator() #诚实度评测
|
|
||||||
corr_evaluator_qwen = CorrectnessEvaluator() #准确率评测
|
|
||||||
Seman_evaluator_qwen = SemanticSimilarityEvaluator()#嵌入相似度评估
|
|
||||||
|
|
||||||
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
|
||||||
|
|
||||||
splitter = SentenceSplitter(chunk_size=512)
|
|
||||||
|
|
||||||
|
|
||||||
vector_index = VectorStoreIndex.from_documents(
|
|
||||||
documents, transformations=[splitter],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# # 运行评估
|
|
||||||
# query_engine = vector_index.as_query_engine()
|
|
||||||
# response_vector = query_engine.query("工程监理费的金额是多少?")
|
|
||||||
# eval_result = evaluator_qwen.evaluate_response(response=response_vector)
|
|
||||||
|
|
||||||
# print(response_vector)
|
|
||||||
# print(eval_result)
|
|
||||||
|
|
||||||
|
|
||||||
question_generator = DatasetGenerator.from_documents(documents)
|
|
||||||
eval_questions = question_generator.generate_questions_from_nodes(5)
|
|
||||||
print(eval_questions)
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
async def evaluate_query_engine_async(query_engine, questions):
|
|
||||||
c = [query_engine.aquery(q) for q in questions]
|
|
||||||
gathering_future = asyncio.gather(*c)
|
|
||||||
results = await gathering_future
|
|
||||||
#print(results)
|
|
||||||
|
|
||||||
total_correct = 0
|
|
||||||
for r in results:
|
|
||||||
eval_result = (
|
|
||||||
1 if faith_evaluator_qwen.evaluate_response(response=r).passing else 0
|
|
||||||
)
|
|
||||||
total_correct += eval_result
|
|
||||||
|
|
||||||
return total_correct, len(results)
|
|
||||||
|
|
||||||
def evaluate_query_engine(query_engine, questions):
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
correct, total = loop.run_until_complete(evaluate_query_engine_async(query_engine, questions))
|
|
||||||
return correct, total
|
|
||||||
|
|
||||||
# 使用 evaluate_query_engine 函数
|
|
||||||
vector_query_engine = vector_index.as_query_engine()
|
|
||||||
correct, total = evaluate_query_engine(vector_query_engine, eval_questions[:5])
|
|
||||||
|
|
||||||
print(f"score: {correct}/{total}")
|
|
||||||
@@ -1,121 +0,0 @@
|
|||||||
from dotenv import load_dotenv
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
from llama_index.core.evaluation import CorrectnessEvaluator
|
|
||||||
from app.engine import get_chat_engine
|
|
||||||
from app.engine.index import get_index
|
|
||||||
from app.observability import init_observability
|
|
||||||
from app.settings import init_settings
|
|
||||||
|
|
||||||
init_settings()
|
|
||||||
init_observability()
|
|
||||||
|
|
||||||
index = get_index()
|
|
||||||
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import asyncio
|
|
||||||
import nest_asyncio
|
|
||||||
nest_asyncio.apply()
|
|
||||||
from llama_index.core.prompts import (
|
|
||||||
ChatMessage,
|
|
||||||
ChatPromptTemplate,
|
|
||||||
MessageRole
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_SYSTEM_TEMPLATE = """
|
|
||||||
您是一个问答聊天机器人的专业评估系统。
|
|
||||||
|
|
||||||
您将获得以下信息:
|
|
||||||
|
|
||||||
- 用户查询,
|
|
||||||
- 生成的回答,
|
|
||||||
|
|
||||||
也可能提供一个参考答案作为评估的依据。
|
|
||||||
|
|
||||||
您的任务是判断生成回答的相关性和正确性。
|
|
||||||
输出一个代表全面评估的单一分数。
|
|
||||||
您必须在一行中仅返回该分数。
|
|
||||||
不要以其他任何格式返回答案。
|
|
||||||
在单独的一行提供给定分数的理由。
|
|
||||||
|
|
||||||
请遵循以下评分指南:
|
|
||||||
|
|
||||||
- 您的分数必须在1到5之间,其中1是最差,5是最好的。
|
|
||||||
-如果生成的回答与用户查询不相关,您应该给出1分。
|
|
||||||
-如果生成的回答相关但包含错误,您应该给出2到3分之间的分数。
|
|
||||||
-如果生成的回答相关且完全正确,您应该给出4到5分之间的分数。
|
|
||||||
示例响应:
|
|
||||||
4.0
|
|
||||||
生成的回答与参考答案的指标完全相同,但不够精炼。
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
DEFAULT_USER_TEMPLATE = """
|
|
||||||
## User Query
|
|
||||||
{query}
|
|
||||||
|
|
||||||
## Reference Answer
|
|
||||||
{reference_answer}
|
|
||||||
|
|
||||||
## Generated Answer
|
|
||||||
{generated_answer}
|
|
||||||
"""
|
|
||||||
|
|
||||||
DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
|
|
||||||
message_templates=[
|
|
||||||
ChatMessage(role=MessageRole.SYSTEM, content=DEFAULT_SYSTEM_TEMPLATE),
|
|
||||||
ChatMessage(role=MessageRole.USER, content=DEFAULT_USER_TEMPLATE),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# 初始化聊天引擎和评估器
|
|
||||||
chat_engine = get_chat_engine()
|
|
||||||
corr_evaluator_qwen = CorrectnessEvaluator()
|
|
||||||
|
|
||||||
# 加载本地问题回答文件
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
file_path = os.path.join(script_dir, 'questions_and_answers.json')
|
|
||||||
output_file_path = file_path.replace('.json', '_test.json')
|
|
||||||
|
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
# 异步函数用于评估查询
|
|
||||||
async def evaluate_query(question, answer, index, output_file):
|
|
||||||
response = await chat_engine.astream_chat(question)
|
|
||||||
|
|
||||||
# 检查sources是否为空
|
|
||||||
if response.sources:
|
|
||||||
content_str = str(response.sources[0])
|
|
||||||
else:
|
|
||||||
content_str = "<无回答>"
|
|
||||||
|
|
||||||
result = corr_evaluator_qwen.evaluate(
|
|
||||||
query=question,
|
|
||||||
response=content_str,
|
|
||||||
reference=answer,
|
|
||||||
)
|
|
||||||
|
|
||||||
result_dict = {
|
|
||||||
"编号": index,
|
|
||||||
"问题": question,
|
|
||||||
"答案": answer,
|
|
||||||
"回答": result.response,
|
|
||||||
"得分(1~5)": result.score,
|
|
||||||
"评价": result.feedback
|
|
||||||
}
|
|
||||||
|
|
||||||
with open(output_file, 'a', encoding='utf-8') as f:
|
|
||||||
f.write(json.dumps(result_dict, ensure_ascii=False, indent=4))
|
|
||||||
f.write(',\n')
|
|
||||||
|
|
||||||
# 主异步函数
|
|
||||||
async def main():
|
|
||||||
for index, item in enumerate(data, start=1):
|
|
||||||
await evaluate_query(item['question'], item['answer'], index, output_file_path)
|
|
||||||
|
|
||||||
# 运行主协程
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,55 +0,0 @@
|
|||||||
Attribute_Prompt = (
|
|
||||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
|
||||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
|
||||||
"现在需要你针对数据中属性一列进行提问和回答。"
|
|
||||||
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元),工程总投资(万元)是77469835.590045万元','尖峰及施工基面土石方量,尖峰及施工基面土石方量是8377.6','截止阀的编码,截止阀的编码是F01010203',"
|
|
||||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
|
||||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
|
||||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
Amount_Prompt = (
|
|
||||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
|
||||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
|
||||||
"现在需要你针对上下文信息中的金额或者合价进行提问和回答。"
|
|
||||||
"问题和回答的示例应该是这种类型的,示例:'项目建设技术服务费的金额,项目建设技术服务费的金额是16855957065.4302','项目后评价费的费率,项目后评价费的费率是0.5','架空输电线路本体工程的金额,架空输电线路本体工程的金额是55105688268.5176','工程静态投资的金额,工程静态投资的金额是715035853336.391'"
|
|
||||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
|
||||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
|
||||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Units_Prompt = (
|
|
||||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
|
||||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
|
||||||
"现在需要你针对上下文信息来进行单位转化问题提问和回答。"
|
|
||||||
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元)结果用元表示,工程总投资(万元)是774698355900.45元','本体工程(元)结果用万元表示,本体工程(元)是5490494.261046万元'"
|
|
||||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
|
||||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
|
||||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
|
||||||
)
|
|
||||||
|
|
||||||
Name_Prompt = (
|
|
||||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
|
||||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
|
||||||
"现在需要你针对上下文信息中的重名问题进行提问和回答。"
|
|
||||||
"问题和回答的示例应该是这种类型的,示例:'专业类型为线路的杆塔工程项目划分的合价,专业类型为线路的杆塔工程项目划分的合价是220969744.905856','专业类型为线路清理的杆塔工程项目划分的合价,电缆工程的合价是0'"
|
|
||||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
|
||||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
|
||||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
All_Amount_Prompt = (
|
|
||||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
|
||||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
|
||||||
"现在需要你针对上下文信息中的总体金额进行提问和回答。"
|
|
||||||
"问题和回答的示例应该是这种类型的,示例:'架空输电线路本体工程的总体金额,架空输电线路本体工程的总体金额是7.706703','工程静态投资的总体金额,工程静态投资的总体金额是100'"
|
|
||||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
|
||||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
|
||||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,144 +0,0 @@
|
|||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
from app.observability import init_observability
|
|
||||||
from app.settings import init_settings
|
|
||||||
|
|
||||||
import nest_asyncio
|
|
||||||
nest_asyncio.apply()
|
|
||||||
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
|
||||||
from llama_index.core.evaluation import DatasetGenerator
|
|
||||||
|
|
||||||
import prompts
|
|
||||||
|
|
||||||
init_settings()
|
|
||||||
init_observability()
|
|
||||||
|
|
||||||
# 读取所有文档(即所有表格)
|
|
||||||
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
|
||||||
|
|
||||||
# 定义表格名称和索引的对应关系
|
|
||||||
table_names = {
|
|
||||||
"工程信息表": 0,
|
|
||||||
"其他费用表": 1,
|
|
||||||
"取费表": 2,
|
|
||||||
"项目划分表": 3,
|
|
||||||
"项目划分_费用预览表": 4,
|
|
||||||
"总算表": 5,
|
|
||||||
"工程量表": 6
|
|
||||||
}
|
|
||||||
|
|
||||||
# 定义中文提示词和Python代码中提示词名称的映射
|
|
||||||
prompt_mapping = {
|
|
||||||
"普通属性": "Attribute_Prompt",
|
|
||||||
"金额查询": "Amount_Prompt",
|
|
||||||
"单位换算": "Units_Prompt",
|
|
||||||
"重名项目划分": "Name_Prompt",
|
|
||||||
"总体金额查询": "All_Amount_Prompt"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 定义表格与其对应的查询类别
|
|
||||||
table_prompt_mapping = {
|
|
||||||
"工程信息表": ["普通属性", "单位换算"],
|
|
||||||
"其他费用表": ["金额查询", "单位换算"],
|
|
||||||
"取费表": ["金额查询"],
|
|
||||||
"总算表": ["金额查询", "总体金额查询"],
|
|
||||||
"工程量表": ["普通属性", "重名项目划分"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# 根据表格名称选择特定的表格
|
|
||||||
def select_document(documents, table_name):
|
|
||||||
if table_name not in table_names:
|
|
||||||
raise ValueError(f"未找到名为 '{table_name}' 的表格")
|
|
||||||
index = table_names[table_name]
|
|
||||||
return [documents[index]] # 返回一个包含所选表格的列表
|
|
||||||
|
|
||||||
# 选择提示词
|
|
||||||
def select_prompt(prompt_category):
|
|
||||||
prompt_name = prompt_mapping.get(prompt_category)
|
|
||||||
if not prompt_name:
|
|
||||||
raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
|
|
||||||
try:
|
|
||||||
return getattr(prompts, prompt_name)
|
|
||||||
except AttributeError:
|
|
||||||
raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数")
|
|
||||||
|
|
||||||
# 生成问题和答案
|
|
||||||
def generate_questions_from_document(document, quest_prompt, num_questions):
|
|
||||||
question_generator = DatasetGenerator.from_documents(
|
|
||||||
documents=document,
|
|
||||||
question_gen_query=quest_prompt,
|
|
||||||
num_questions_per_chunk=num_questions
|
|
||||||
)
|
|
||||||
|
|
||||||
eval_questions = question_generator.generate_questions_from_nodes(num_questions)
|
|
||||||
print(eval_questions)
|
|
||||||
|
|
||||||
qa_pairs = []
|
|
||||||
for qa in eval_questions:
|
|
||||||
if ',' in qa:
|
|
||||||
question, answer = qa.split(",", 1)
|
|
||||||
qa_pairs.append({
|
|
||||||
"question": question.strip(),
|
|
||||||
"answer": answer.strip()
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
print(f"无法处理的问题和答案: {qa}")
|
|
||||||
|
|
||||||
return qa_pairs
|
|
||||||
|
|
||||||
# 主函数,控制生成多个表格的问题和使用多个提示词,并将结果合并到一个文件中
|
|
||||||
def main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt):
|
|
||||||
if table_names_input == "all":
|
|
||||||
selected_tables = list(table_prompt_mapping.keys())
|
|
||||||
else:
|
|
||||||
selected_tables = table_names_input.strip('[]').split(',')
|
|
||||||
|
|
||||||
all_results = {}
|
|
||||||
|
|
||||||
for table_name in selected_tables:
|
|
||||||
table_name = table_name.strip() # 去掉前后空格
|
|
||||||
document = select_document(documents, table_name)
|
|
||||||
|
|
||||||
if prompt_categories_input == "all":
|
|
||||||
selected_prompts = table_prompt_mapping[table_name]
|
|
||||||
else:
|
|
||||||
selected_prompts = prompt_categories_input.strip('[]').split(',')
|
|
||||||
selected_prompts = [p.strip() for p in selected_prompts] # 去掉前后空格
|
|
||||||
|
|
||||||
for prompt_category in selected_prompts:
|
|
||||||
if prompt_category not in table_prompt_mapping[table_name]:
|
|
||||||
print(f"跳过表格 '{table_name}' 的提示词 '{prompt_category}',因为该表中不包含该类别的信息")
|
|
||||||
continue
|
|
||||||
|
|
||||||
quest_prompt = select_prompt(prompt_category).format(num_questions_per_chunk=num_questions_per_prompt)
|
|
||||||
qa_pairs = generate_questions_from_document(document, quest_prompt, num_questions_per_prompt)
|
|
||||||
|
|
||||||
label = f"test:{table_name}_{prompt_category}"
|
|
||||||
all_results[label] = qa_pairs
|
|
||||||
|
|
||||||
# 自动生成输出文件名
|
|
||||||
output_file = "combined_test.json"
|
|
||||||
|
|
||||||
with open(output_file, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(all_results, f, ensure_ascii=False, indent=4)
|
|
||||||
|
|
||||||
print(f"All questions and answers have been saved to '{output_file}'")
|
|
||||||
|
|
||||||
# 获取命令行参数
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if len(sys.argv) != 4:
|
|
||||||
print("Usage: python script.py <table_names_input> <prompt_categories_input> <num_questions_per_prompt>")
|
|
||||||
else:
|
|
||||||
table_names_input = sys.argv[1]
|
|
||||||
prompt_categories_input = sys.argv[2]
|
|
||||||
num_questions_per_prompt = int(sys.argv[3])
|
|
||||||
|
|
||||||
main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt)
|
|
||||||
@@ -1,10 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
import phoenix as px
|
import phoenix as px
|
||||||
|
|
||||||
|
|
||||||
|
os.environ['PHOENIX_HOST'] = "0.0.0.0"
|
||||||
|
|
||||||
session = px.launch_app(use_temp_dir=False)
|
session = px.launch_app(use_temp_dir=False)
|
||||||
|
|
||||||
import msvcrt
|
import msvcrt
|
||||||
|
|||||||
Reference in New Issue
Block a user