增加了问题生成脚本
This commit is contained in:
@@ -1,5 +1,20 @@
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from llama_index.core.evaluation import CorrectnessEvaluator
|
||||
from app.engine import get_chat_engine
|
||||
from app.engine.index import get_index
|
||||
from app.observability import init_observability
|
||||
from app.settings import init_settings
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
index = get_index()
|
||||
|
||||
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
@@ -55,31 +70,14 @@ DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
|
||||
]
|
||||
)
|
||||
|
||||
from app.api.routers.models import ChatData, Message
|
||||
from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore
|
||||
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
|
||||
from llama_index.core.evaluation import CorrectnessEvaluator
|
||||
from app.engine import get_chat_engine
|
||||
from app.api.routers.chat import generate_filters
|
||||
from app.engine.index import get_index
|
||||
from app.observability import init_observability
|
||||
from app.settings import init_settings
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
index = get_index()
|
||||
|
||||
# 初始化聊天引擎和评估器
|
||||
chat_engine = get_chat_engine()
|
||||
corr_evaluator_qwen = CorrectnessEvaluator()
|
||||
|
||||
# 加载本地问题回答文件
|
||||
file_path = 'D:/LLM_model/text2sql/zjdataai-app-test/backend/unit_test/test.json'
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(script_dir, 'questions_and_answers.json')
|
||||
output_file_path = file_path.replace('.json', '_test.json')
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
@@ -88,8 +86,13 @@ with open(file_path, 'r', encoding='utf-8') as f:
|
||||
# 异步函数用于评估查询
|
||||
async def evaluate_query(question, answer, index, output_file):
|
||||
response = await chat_engine.astream_chat(question)
|
||||
content_str = str(response.sources[0])
|
||||
|
||||
|
||||
# 检查sources是否为空
|
||||
if response.sources:
|
||||
content_str = str(response.sources[0])
|
||||
else:
|
||||
content_str = "<无回答>"
|
||||
|
||||
result = corr_evaluator_qwen.evaluate(
|
||||
query=question,
|
||||
response=content_str,
|
||||
@@ -101,13 +104,13 @@ async def evaluate_query(question, answer, index, output_file):
|
||||
"问题": question,
|
||||
"答案": answer,
|
||||
"回答": result.response,
|
||||
"得分(0~5)": result.score,
|
||||
"得分(1~5)": result.score,
|
||||
"评价": result.feedback
|
||||
}
|
||||
|
||||
with open(output_file, 'a', encoding='utf-8') as f:
|
||||
f.write(json.dumps(result_dict, ensure_ascii=False, indent=4))
|
||||
f.write(',')
|
||||
f.write(',\n')
|
||||
|
||||
# 主异步函数
|
||||
async def main():
|
||||
|
||||
Reference in New Issue
Block a user