118 lines
3.2 KiB
Python
118 lines
3.2 KiB
Python
import json
|
|
from dotenv import load_dotenv
|
|
import asyncio
|
|
import nest_asyncio
|
|
nest_asyncio.apply()
|
|
from llama_index.core.prompts import (
|
|
ChatMessage,
|
|
ChatPromptTemplate,
|
|
MessageRole
|
|
)
|
|
|
|
DEFAULT_SYSTEM_TEMPLATE = """
|
|
您是一个问答聊天机器人的专业评估系统。
|
|
|
|
您将获得以下信息:
|
|
|
|
- 用户查询,
|
|
- 生成的回答,
|
|
|
|
也可能提供一个参考答案作为评估的依据。
|
|
|
|
您的任务是判断生成回答的相关性和正确性。
|
|
输出一个代表全面评估的单一分数。
|
|
您必须在一行中仅返回该分数。
|
|
不要以其他任何格式返回答案。
|
|
在单独的一行提供给定分数的理由。
|
|
|
|
请遵循以下评分指南:
|
|
|
|
- 您的分数必须在1到5之间,其中1是最差,5是最好的。
|
|
-如果生成的回答与用户查询不相关,您应该给出1分。
|
|
-如果生成的回答相关但包含错误,您应该给出2到3分之间的分数。
|
|
-如果生成的回答相关且完全正确,您应该给出4到5分之间的分数。
|
|
示例响应:
|
|
4.0
|
|
生成的回答与参考答案的指标完全相同,但不够精炼。
|
|
|
|
"""
|
|
|
|
DEFAULT_USER_TEMPLATE = """
|
|
## User Query
|
|
{query}
|
|
|
|
## Reference Answer
|
|
{reference_answer}
|
|
|
|
## Generated Answer
|
|
{generated_answer}
|
|
"""
|
|
|
|
DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
|
|
message_templates=[
|
|
ChatMessage(role=MessageRole.SYSTEM, content=DEFAULT_SYSTEM_TEMPLATE),
|
|
ChatMessage(role=MessageRole.USER, content=DEFAULT_USER_TEMPLATE),
|
|
]
|
|
)
|
|
|
|
from app.api.routers.models import ChatData, Message
|
|
from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore
|
|
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
|
|
from llama_index.core.evaluation import CorrectnessEvaluator
|
|
from app.engine import get_chat_engine
|
|
from app.api.routers.chat import generate_filters
|
|
from app.engine.index import get_index
|
|
from app.observability import init_observability
|
|
from app.settings import init_settings
|
|
|
|
|
|
load_dotenv()
|
|
|
|
|
|
init_settings()
|
|
init_observability()
|
|
|
|
index = get_index()
|
|
|
|
# 初始化聊天引擎和评估器
|
|
chat_engine = get_chat_engine()
|
|
corr_evaluator_qwen = CorrectnessEvaluator()
|
|
|
|
# 加载本地问题回答文件
|
|
file_path = 'D:/LLM_model/text2sql/zjdataai-app-test/backend/unit_test/test.json'
|
|
output_file_path = file_path.replace('.json', '_test.json')
|
|
|
|
with open(file_path, 'r', encoding='utf-8') as f:
|
|
data = json.load(f)
|
|
|
|
# 异步函数用于评估查询
|
|
async def evaluate_query(question, answer, index, output_file):
|
|
response = await chat_engine.astream_chat(question)
|
|
content_str = str(response.sources[0])
|
|
|
|
result = corr_evaluator_qwen.evaluate(
|
|
query=question,
|
|
response=content_str,
|
|
reference=answer,
|
|
)
|
|
|
|
result_dict = {
|
|
"编号": index,
|
|
"问题": question,
|
|
"答案": answer,
|
|
"回答": result.response,
|
|
"得分(0~5)": result.score,
|
|
"评价": result.feedback
|
|
}
|
|
|
|
with open(output_file, 'a', encoding='utf-8') as f:
|
|
f.write(json.dumps(result_dict, ensure_ascii=False, indent=4))
|
|
f.write(',')
|
|
|
|
# 主异步函数
|
|
async def main():
|
|
for index, item in enumerate(data, start=1):
|
|
await evaluate_query(item['question'], item['answer'], index, output_file_path)
|
|
|
|
# 运行主协程
|
|
asyncio.run(main()) |