14 Commits

Author SHA1 Message Date
chentianrui e634746a52 修改了单元测试的问题生成代码 2024-09-06 18:43:57 +08:00
chentianrui d12800e14e 修改了单元测试的问题生成代码 2024-09-06 18:40:54 +08:00
chentianrui c1df0d1bba Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-09-05 17:03:29 +08:00
chentianrui 0664952ecd 增加了问题生成脚本 2024-09-05 17:02:42 +08:00
ly 7023b54246 解决xinference内嵌模型类使用问题。由于目前xinference组件的版本和llamaindex最新版有冲突,所以未更新支持xinference的内嵌模型的版本 2024-09-05 12:12:31 +08:00
ly aee6aa3c04 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-09-05 11:36:53 +08:00
chentianrui 680e24c516 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-30 18:40:32 +08:00
chentianrui 6663ee8976 新增加了单元测试 2024-08-30 18:40:21 +08:00
ly 0a5f335981 调整NLTK数据目录和JIEBA字典位置到本项目中,避免重新安装时需要从网上下载 2024-08-30 01:20:29 +08:00
ly 2901bd9eaf 优化导入,解决初始化LLAMAINDEX过程中环境变量没起作用问题 2024-08-30 01:16:35 +08:00
ly 453b3ca55c Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-30 00:00:57 +08:00
wanyaokun 03c4eb1af1 优化ChatCallbackEvent事件代码 2024-08-29 19:52:53 +08:00
ly f0afd1a4bb Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-29 12:03:28 +08:00
ly eb572eff27 增加加载环境变量功能 2024-08-29 11:54:20 +08:00
16 changed files with 349527 additions and 199 deletions
+5
View File
@@ -1,3 +1,8 @@
JIEBA_DATA=./nltk_data
NLTK_DATA=./nltk_data
SQLITE_DATABASE_URL=sqlite:///./source.db
DATA_SOURCE_CACHE=./restapi
# The Llama Cloud API key.
# LLAMA_CLOUD_API_KEY=
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
+5
View File
@@ -1,3 +1,8 @@
JIEBA_DATA=./nltk_data
NLTK_DATA=./nltk_data
SQLITE_DATABASE_URL=sqlite:///./source.db
DATA_SOURCE_CACHE=./restapi
# The Llama Cloud API key.
# LLAMA_CLOUD_API_KEY=
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
+102 -161
View File
@@ -26,125 +26,48 @@ api_router = r = APIRouter()
v1_router = v = APIRouter()
class ChatCallbackEvent(BaseModel):
event_type: CBEventType
event_type: ChatEventType
payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_retrieval_message(self) -> dict | None:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"根据查询检索到 {len(nodes)} 源文件"
else:
msg = f"查询检索中: '{self.payload.get('query_str')}'"
def get_common_param(self)-> dict:
return {
"type": "events",
"data": {"title": msg},
}
else:
return None
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
},
'event': self.event_type.name,
'conversation_id':self.payload.get("conversation_id"),
'message_id': self.payload.get("message_id"),
'created_at': int(time.time()),
'task_id': self.payload.get("task_id")
}
def _is_output_serializable(self, output: Any) -> bool:
try:
json.dumps(output)
return True
except TypeError:
return False
def get_agent_tool_response(self) -> dict | None:
response = self.payload.get("response")
if response is not None:
sources = response.sources
for source in sources:
# Return the tool response here to include the toolCall information
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
},
}
def to_response(self):
try:
match self.event_type:
case "retrieve":
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case _:
return None
except Exception as e:
logger.error(f"转换回应时间时发生错误,原因: {e}")
return None
class DifyChatResponseEvent(BaseModel):
event: str
conversation_id: str
message_id: str
created_at: int = int(time.time())
task_id: str
def to_response(self):
return self.dict()
class Workflow_started_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'workflow_started'
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
def get_WorkflowStart_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('workflow_run_id'),
"workflow_id": self.payload.get('workflow_id'),
"sequence_number": 1709,
"inputs": {
"sys.query": args['query'],
"sys.query": self.payload.get('query'),
"sys.files": [],
"sys.conversation_id": args['conversation_id'],
"sys.user_id": args['use_id']
"sys.conversation_id": self.payload.get('conversation_id'),
"sys.user_id": self.payload.get('use_id')
},
"created_at": int(time.time())
}
super().__init__(**args)
})
return params
class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'workflow_finished'
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['workflow_run_id'],
"workflow_id": args['workflow_id'],
def get_WorkflowFinished_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('workflow_run_id'),
"workflow_id": self.payload.get('workflow_id'),
"sequence_number": 1709,
"status": "succeeded",
"outputs": {
"answer": args['response']
"answer": self.payload.get('response')
},
"error": '',
"elapsed_time": 36.03764106379822,
@@ -152,60 +75,44 @@ class Workflow_finished_DifyChatResponseEvent(DifyChatResponseEvent):
"total_steps": 10,
"created_by": {
"id": str(uuid.uuid4()),
"user": args['use_id']
"user": self.payload.get('use_id')
},
"created_at": int(time.time()),
"finished_at": int(time.time()),
"files": []
}
super().__init__(**args)
})
return params
class Message_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'message'
id:str
answer:str
def __init__(self,**args):
args['id'] = args['message_id']
super().__init__(**args)
class MessageEnd_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'message_end'
id:str
metadata:Dict[str,Any] = {}
def __init__(self,**args):
args['id'] = args['message_id']
super().__init__(**args)
class Node_started_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'node_started'
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['nodeid'],
"node_id": args['nodeid'],
def get_NodeStart_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": args['title'],
"index": args['index'],
"predecessor_node_id": args['predecessor_node_id'],
"title": self.payload.get('title'),
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
"created_at": 1724398751,
"extras": {}
}
super().__init__(**args)
})
return params
class Node_finished_DifyChatResponseEvent(DifyChatResponseEvent):
event: str = 'node_finished'
workflow_run_id:str
data:Dict[str,Any]
def __init__(self,**args):
args['data'] = {
"id": args['nodeid'],
"node_id": args['nodeid'],
def get_NodeFinished_param(self) -> dict:
params = self.get_common_param()
params.update({
'workflow_run_id':self.payload.get('workflow_run_id'),
'data':{
"id": self.payload.get('nodeid'),
"node_id": self.payload.get('nodeid'),
"node_type": "http-request",
"title": args['title'],
"index": args['index'],
"predecessor_node_id": args['predecessor_node_id'],
"title": self.payload.get('title'),
"index": self.payload.get('index'),
"predecessor_node_id": self.payload.get('predecessor_node_id'),
"inputs": '',
"process_data": '',
"outputs": '',
@@ -217,7 +124,45 @@ class Node_finished_DifyChatResponseEvent(DifyChatResponseEvent):
"finished_at": 1724398751,
"files": []
}
super().__init__(**args)
})
return params
def get_Message_param(self) -> dict:
params = self.get_common_param()
params.update({
'id':self.payload.get('message_id'),
'answer':self.payload.get('answer')
})
return params
def get_MessageEnd_param(self) -> dict:
params = self.get_common_param()
params.update({
'id':self.payload.get('message_id'),
'metadata':self.payload.get('metadata')
})
return params
def to_response(self)-> dict|None:
try:
match self.event_type:
case "workflow_started":
return self.get_WorkflowStart_param()
case "workflow_finished":
return self.get_WorkflowFinished_param()
case "node_started":
return self.get_NodeStart_param()
case 'node_finished':
return self.get_NodeFinished_param()
case 'message':
return self.get_Message_param()
case 'message_end':
return self.get_MessageEnd_param()
case _:
return None
except Exception as e:
logger.error(f"转换回应时间时发生错误,原因: {e}")
return None
class ChatEventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue
@@ -239,9 +184,8 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
self._nodeStack:deque = deque()
#添加工作流开始事件
ids:Dict[str,Any] = self._params['ids']
data:ChatRequestData = self._params['data']
args = ids
args:Dict[str,Any] = self._params['ids']
args.update(
{
'use_id': data.user,
@@ -249,7 +193,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
'conversation_id': data.conversation_id
}
)
wf_event = Workflow_started_DifyChatResponseEvent(**args)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
@@ -264,9 +208,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
self._nodeStack.append(event_id)
nindex = self._nodeStack.count() - 1
ids:Dict[str,Any] = self._params['ids']
args = ids
args:Dict[str,Any] = self._params['ids']
args.update(
{
'nodeid':event_id,
@@ -275,7 +217,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else ''
}
)
nd_event = Node_started_DifyChatResponseEvent(**args)
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_START,payload = args)
if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event)
@@ -302,7 +244,7 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else ''
}
)
nd_event = Node_finished_DifyChatResponseEvent(**args)
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_FINISHED,payload = args)
if nd_event.to_response() is not None:
self._aqueue.put_nowait(nd_event)
self._nodeStack.pop()
@@ -319,22 +261,21 @@ class ChatEventCallbackHandler(BaseCallbackHandler):
) -> None:
"""No-op."""
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
ids:Dict[str,Any] = self._params['ids']
data:ChatRequestData = self._params['data']
args = ids
args:Dict[str,Any] = self._params['ids']
args.update(
{
'response':self._response,
'conversation_id': data.conversation_id
}
)
wf_event = Workflow_finished_DifyChatResponseEvent(**args)
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
if wf_event.to_response() is not None:
self._aqueue.put_nowait(wf_event)
args = ids
msgEnt_event = MessageEnd_DifyChatResponseEvent(**args)
args:Dict[str,Any] = self._params['ids']
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
if msgEnt_event.to_response() is not None:
self._aqueue.put_nowait(msgEnt_event)
@@ -367,8 +308,8 @@ class ChatStreamResponse(StreamingResponse):
'answer':token,
'conversation_id':cls.data.conversation_id
})
event = Message_DifyChatResponseEvent(**params)
data_str = json.dumps(event.dict())
event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE,payload = params)
data_str = json.dumps(event.to_response())
return f"{cls.DATA_PREFIX}{data_str}\n\n"
@classmethod
@@ -1,5 +1,7 @@
from pydantic import BaseModel
import os
from enum import Enum
class BaseConfig(BaseModel):
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
@@ -69,3 +71,10 @@ class BaseConfig(BaseModel):
}
class ChatEventType(str, Enum):
WORKFLOW_START = "workflow_started"
WORKFLOW_FINISHED = "workflow_finished"
NODE_START = "node_started"
NODE_FINISHED = "node_finished"
MESSAGE = "message"
MESSAGE_END = "message_end"
+5 -1
View File
@@ -1,3 +1,4 @@
import os
from typing import Any, Dict, List, Union, Callable, NamedTuple
from bm25s.tokenization import *
@@ -8,9 +9,12 @@ except ImportError:
def tqdm(iterable, *args, **kwargs):
return iterable
import jieba
jiebapath = os.environ.get("JIEBA_DATA", "")
jieba.set_dictionary(os.path.join(jiebapath, 'dict.txt')) #设置字典
jieba.initialize() #初始化jeiba
def chinese_tokenizer(text: str) -> List[str]:
import jieba
from nltk.corpus import stopwords
tokens = jieba.lcut(text)
return [token for token in tokens if token not in stopwords.words('chinese')]
+1 -2
View File
@@ -3,11 +3,10 @@ from typing import Dict
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
from llama_index.llms.xinference import Xinference
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
def get_node_postprocessors():
rerank_enabled = os.getenv("RERANK_ENABLED").title()
-2
View File
@@ -1,7 +1,5 @@
from dotenv import load_dotenv
from llama_index.core.node_parser import SentenceSplitter
load_dotenv()
import logging
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large Load Diff
Binary file not shown.
+121
View File
@@ -0,0 +1,121 @@
from dotenv import load_dotenv
load_dotenv()
from llama_index.core.evaluation import CorrectnessEvaluator
from app.engine import get_chat_engine
from app.engine.index import get_index
from app.observability import init_observability
from app.settings import init_settings
init_settings()
init_observability()
index = get_index()
import os
import json
import asyncio
import nest_asyncio
nest_asyncio.apply()
from llama_index.core.prompts import (
ChatMessage,
ChatPromptTemplate,
MessageRole
)
DEFAULT_SYSTEM_TEMPLATE = """
您是一个问答聊天机器人的专业评估系统。
您将获得以下信息:
- 用户查询,
- 生成的回答,
也可能提供一个参考答案作为评估的依据。
您的任务是判断生成回答的相关性和正确性。
输出一个代表全面评估的单一分数。
您必须在一行中仅返回该分数。
不要以其他任何格式返回答案。
在单独的一行提供给定分数的理由。
请遵循以下评分指南:
- 您的分数必须在1到5之间,其中1是最差,5是最好的。
-如果生成的回答与用户查询不相关,您应该给出1分。
-如果生成的回答相关但包含错误,您应该给出2到3分之间的分数。
-如果生成的回答相关且完全正确,您应该给出4到5分之间的分数。
示例响应:
4.0
生成的回答与参考答案的指标完全相同,但不够精炼。
"""
DEFAULT_USER_TEMPLATE = """
## User Query
{query}
## Reference Answer
{reference_answer}
## Generated Answer
{generated_answer}
"""
DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
message_templates=[
ChatMessage(role=MessageRole.SYSTEM, content=DEFAULT_SYSTEM_TEMPLATE),
ChatMessage(role=MessageRole.USER, content=DEFAULT_USER_TEMPLATE),
]
)
# 初始化聊天引擎和评估器
chat_engine = get_chat_engine()
corr_evaluator_qwen = CorrectnessEvaluator()
# 加载本地问题回答文件
script_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(script_dir, 'questions_and_answers.json')
output_file_path = file_path.replace('.json', '_test.json')
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# 异步函数用于评估查询
async def evaluate_query(question, answer, index, output_file):
response = await chat_engine.astream_chat(question)
# 检查sources是否为空
if response.sources:
content_str = str(response.sources[0])
else:
content_str = "<无回答>"
result = corr_evaluator_qwen.evaluate(
query=question,
response=content_str,
reference=answer,
)
result_dict = {
"编号": index,
"问题": question,
"答案": answer,
"回答": result.response,
"得分(1~5)": result.score,
"评价": result.feedback
}
with open(output_file, 'a', encoding='utf-8') as f:
f.write(json.dumps(result_dict, ensure_ascii=False, indent=4))
f.write(',\n')
# 主异步函数
async def main():
for index, item in enumerate(data, start=1):
await evaluate_query(item['question'], item['answer'], index, output_file_path)
# 运行主协程
asyncio.run(main())
+55
View File
@@ -0,0 +1,55 @@
Attribute_Prompt = (
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
"现在需要你针对数据中属性一列进行提问和回答。"
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元),工程总投资(万元)是77469835.590045万元','尖峰及施工基面土石方量,尖峰及施工基面土石方量是8377.6','截止阀的编码,截止阀的编码是F01010203',"
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
Amount_Prompt = (
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
"现在需要你针对上下文信息中的金额或者合价进行提问和回答。"
"问题和回答的示例应该是这种类型的,示例:'项目建设技术服务费的金额,项目建设技术服务费的金额是16855957065.4302','项目后评价费的费率,项目后评价费的费率是0.5','架空输电线路本体工程的金额,架空输电线路本体工程的金额是55105688268.5176','工程静态投资的金额,工程静态投资的金额是715035853336.391'"
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
Units_Prompt = (
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
"现在需要你针对上下文信息来进行单位转化问题提问和回答。"
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元)结果用元表示,工程总投资(万元)是774698355900.45元','本体工程(元)结果用万元表示,本体工程(元)是5490494.261046万元'"
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
Name_Prompt = (
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
"现在需要你针对上下文信息中的重名问题进行提问和回答。"
"问题和回答的示例应该是这种类型的,示例:'专业类型为线路的杆塔工程项目划分的合价,专业类型为线路的杆塔工程项目划分的合价是220969744.905856','专业类型为线路清理的杆塔工程项目划分的合价,电缆工程的合价是0'"
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
All_Amount_Prompt = (
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
"现在需要你针对上下文信息中的总体金额进行提问和回答。"
"问题和回答的示例应该是这种类型的,示例:'架空输电线路本体工程的总体金额,架空输电线路本体工程的总体金额是7.706703','工程静态投资的总体金额,工程静态投资的总体金额是100'"
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
)
+144
View File
@@ -0,0 +1,144 @@
from dotenv import load_dotenv
load_dotenv()
import json
import sys
from app.observability import init_observability
from app.settings import init_settings
import nest_asyncio
nest_asyncio.apply()
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import SimpleDirectoryReader
from llama_index.core.evaluation import DatasetGenerator
import prompts
init_settings()
init_observability()
# 读取所有文档(即所有表格)
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
# 定义表格名称和索引的对应关系
table_names = {
"工程信息表": 0,
"其他费用表": 1,
"取费表": 2,
"项目划分表": 3,
"项目划分_费用预览表": 4,
"总算表": 5,
"工程量表": 6
}
# 定义中文提示词和Python代码中提示词名称的映射
prompt_mapping = {
"普通属性": "Attribute_Prompt",
"金额查询": "Amount_Prompt",
"单位换算": "Units_Prompt",
"重名项目划分": "Name_Prompt",
"总体金额查询": "All_Amount_Prompt"
}
# 定义表格与其对应的查询类别
table_prompt_mapping = {
"工程信息表": ["普通属性", "单位换算"],
"其他费用表": ["金额查询", "单位换算"],
"取费表": ["金额查询"],
"总算表": ["金额查询", "总体金额查询"],
"工程量表": ["普通属性", "重名项目划分"]
}
# 根据表格名称选择特定的表格
def select_document(documents, table_name):
if table_name not in table_names:
raise ValueError(f"未找到名为 '{table_name}' 的表格")
index = table_names[table_name]
return [documents[index]] # 返回一个包含所选表格的列表
# 选择提示词
def select_prompt(prompt_category):
prompt_name = prompt_mapping.get(prompt_category)
if not prompt_name:
raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
try:
return getattr(prompts, prompt_name)
except AttributeError:
raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数")
# 生成问题和答案
def generate_questions_from_document(document, quest_prompt, num_questions):
question_generator = DatasetGenerator.from_documents(
documents=document,
question_gen_query=quest_prompt,
num_questions_per_chunk=num_questions
)
eval_questions = question_generator.generate_questions_from_nodes(num_questions)
print(eval_questions)
qa_pairs = []
for qa in eval_questions:
if ',' in qa:
question, answer = qa.split(",", 1)
qa_pairs.append({
"question": question.strip(),
"answer": answer.strip()
})
else:
print(f"无法处理的问题和答案: {qa}")
return qa_pairs
# 主函数,控制生成多个表格的问题和使用多个提示词,并将结果合并到一个文件中
def main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt):
if table_names_input == "all":
selected_tables = list(table_prompt_mapping.keys())
else:
selected_tables = table_names_input.strip('[]').split(',')
all_results = {}
for table_name in selected_tables:
table_name = table_name.strip() # 去掉前后空格
document = select_document(documents, table_name)
if prompt_categories_input == "all":
selected_prompts = table_prompt_mapping[table_name]
else:
selected_prompts = prompt_categories_input.strip('[]').split(',')
selected_prompts = [p.strip() for p in selected_prompts] # 去掉前后空格
for prompt_category in selected_prompts:
if prompt_category not in table_prompt_mapping[table_name]:
print(f"跳过表格 '{table_name}' 的提示词 '{prompt_category}',因为该表中不包含该类别的信息")
continue
quest_prompt = select_prompt(prompt_category).format(num_questions_per_chunk=num_questions_per_prompt)
qa_pairs = generate_questions_from_document(document, quest_prompt, num_questions_per_prompt)
label = f"test:{table_name}_{prompt_category}"
all_results[label] = qa_pairs
# 自动生成输出文件名
output_file = "combined_test.json"
with open(output_file, "w", encoding="utf-8") as f:
json.dump(all_results, f, ensure_ascii=False, indent=4)
print(f"All questions and answers have been saved to '{output_file}'")
# 获取命令行参数
if __name__ == "__main__":
if len(sys.argv) != 4:
print("Usage: python script.py <table_names_input> <prompt_categories_input> <num_questions_per_prompt>")
else:
table_names_input = sys.argv[1]
prompt_categories_input = sys.argv[2]
num_questions_per_prompt = int(sys.argv[3])
main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt)
+3 -2
View File
@@ -1,9 +1,10 @@
import os
from dotenv import load_dotenv
load_dotenv()
import phoenix as px
os.environ['PHOENIX_HOST'] = "0.0.0.0"
session = px.launch_app(use_temp_dir=False)
import msvcrt