更新意图识别示例,修改文档相关性判断逻辑以支持多个相关文档,增强结果处理和临时保存功能

This commit is contained in:
2025-06-25 10:07:46 +08:00
parent 33bc91f0fe
commit f9174fdbc9
+47 -14
View File
@@ -106,10 +106,10 @@ class QueryRewriteProcessor:
relevance_score: int = Field(description="相关性评分,0-100分")
explanation: str = Field(description="解释文档是否能解决(回答)提问")
class most_relevant_document(BaseModel):
most_relevant_document: TempModel = Field(description="最相关的文档的判断结果")
class all_relevant_document(BaseModel):
most_relevant_document: list[TempModel] = Field(description="最相关的文档的判断结果")
parser = PydanticOutputParser(pydantic_object=most_relevant_document)
parser = PydanticOutputParser(pydantic_object=all_relevant_document)
# 构建提示词
prompt = f"""请判断以下检索文档列表中是否与用户提问相关,能够解决用户的问题,并给出相关性评分(0-100分)。输出最相关的文档的判断结果。
@@ -119,7 +119,16 @@ class QueryRewriteProcessor:
{doc_text_list}
请按照以下JSON格式返回结果:
{parser.get_format_instructions()}
json```
{{
"most_relevant_document":[{{
"can_solve_problem": true,
"relevance_score": 60,
"explanation":"xxxx"
}}]
}}
```
"""
try:
@@ -127,12 +136,23 @@ class QueryRewriteProcessor:
llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"})
response = llm.invoke(prompt)
result = parser.parse(response.content).most_relevant_document
result_list = parser.parse(response.content).most_relevant_document
# 如果列表为空,返回默认的不相关结果
if not result_list:
return {
"is_relevant": False,
"explanation": "无法解析文档相关性结果",
"relevance_score": 0.0
}
# 找出分数最高的文档
max_score_doc = max(result_list, key=lambda x: x.relevance_score)
return {
"is_relevant": result.can_solve_problem,
"relevance_score": result.relevance_score,
"explanation": result.explanation
"is_relevant": max_score_doc.can_solve_problem,
"relevance_score": max_score_doc.relevance_score,
"explanation": max_score_doc.explanation
}
except Exception as e:
logging.error(f"判断文档相关性时出错: {str(e)}", exc_info=True)
@@ -311,7 +331,7 @@ class QueryRewriteProcessor:
logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}")
def process_batch(self, questions: List[str], max_workers: int = 2, enable_retrieval: bool = False, output_file: str = None):
def process_batch(self, questions: List[str], max_workers: int = 2, enable_retrieval: bool = False, output_file: str = None, save_interval: int = 100):
"""
批量处理多个问题
@@ -320,12 +340,14 @@ class QueryRewriteProcessor:
max_workers: 并发处理的最大线程数,默认为2
enable_retrieval: 是否启用文档检索功能,默认为False
output_file: 输出文件路径,如果为None则不保存结果
save_interval: 临时保存的间隔,每处理这么多问题就临时保存一次结果,默认为100
Returns:
处理结果列表
"""
logging.info(f"共有 {len(questions)} 个问题需要处理,使用 {max_workers} 个并发线程")
logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}")
logging.info(f"每处理 {save_interval} 个问题将临时保存一次结果")
# 创建一个与输入顺序相同的结果列表
results = [None] * len(questions)
@@ -340,14 +362,26 @@ class QueryRewriteProcessor:
future = executor.submit(self.process_query, query, enable_retrieval=enable_retrieval)
future_to_index[future] = idx
# 用于跟踪已完成的问题数量
completed_count = 0
# 使用tqdm显示进度条
for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(future_to_index), desc="处理进度"):
idx = future_to_index[future]
result = future.result()
# 将结果放在与输入相同的位置
results[idx] = result
# 增加已完成的问题计数
completed_count += 1
# 检查是否需要临时保存
if output_file and completed_count % save_interval == 0:
# 临时保存当前结果
self.save_results_to_excel(results, output_file, is_final=False)
logging.info(f"已临时保存 {completed_count} 个问题的处理结果")
# 如果提供了输出文件路径,则保存结果
# 如果提供了输出文件路径,则保存最终结果
if output_file:
self.save_results_to_excel(results, output_file, is_final=True)
@@ -407,14 +441,14 @@ def main():
results = processor.process_batch(questions=examples, max_workers=args.max_workers, enable_retrieval=enable_retrieval, output_file=output_file)
logging.info(f"所有处理完成,最终结果已保存至: {output_file}")
else:
logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}")
logging.info(f"文档检索功能状态: 已启用")
for idx, query in enumerate(examples):
if query.strip() == "":
continue
# 在调试模式下使用完整的参数
print(json.dumps(processor.process_query(
query,
enable_retrieval=enable_retrieval
enable_retrieval=True
), ensure_ascii=False, indent=2))
def setup_logging():
@@ -431,5 +465,4 @@ def setup_logging():
if __name__ == "__main__":
setup_logging()
logging.info("意图识别示例程序开始运行...")
main()