更新意图识别示例,修改文档相关性判断逻辑以支持多个相关文档,增强结果处理和临时保存功能
This commit is contained in:
@@ -106,10 +106,10 @@ class QueryRewriteProcessor:
|
||||
relevance_score: int = Field(description="相关性评分,0-100分")
|
||||
explanation: str = Field(description="解释文档是否能解决(回答)提问")
|
||||
|
||||
class most_relevant_document(BaseModel):
|
||||
most_relevant_document: TempModel = Field(description="最相关的文档的判断结果")
|
||||
class all_relevant_document(BaseModel):
|
||||
most_relevant_document: list[TempModel] = Field(description="最相关的文档的判断结果")
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=most_relevant_document)
|
||||
parser = PydanticOutputParser(pydantic_object=all_relevant_document)
|
||||
# 构建提示词
|
||||
prompt = f"""请判断以下检索文档列表中是否与用户提问相关,能够解决用户的问题,并给出相关性评分(0-100分)。输出最相关的文档的判断结果。
|
||||
|
||||
@@ -119,7 +119,16 @@ class QueryRewriteProcessor:
|
||||
{doc_text_list}
|
||||
|
||||
请按照以下JSON格式返回结果:
|
||||
{parser.get_format_instructions()}
|
||||
json```
|
||||
{{
|
||||
"most_relevant_document":[{{
|
||||
"can_solve_problem": true,
|
||||
"relevance_score": 60,
|
||||
"explanation":"xxxx"
|
||||
}}]
|
||||
}}
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -127,12 +136,23 @@ class QueryRewriteProcessor:
|
||||
llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"})
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
result = parser.parse(response.content).most_relevant_document
|
||||
result_list = parser.parse(response.content).most_relevant_document
|
||||
|
||||
# 如果列表为空,返回默认的不相关结果
|
||||
if not result_list:
|
||||
return {
|
||||
"is_relevant": False,
|
||||
"explanation": "无法解析文档相关性结果",
|
||||
"relevance_score": 0.0
|
||||
}
|
||||
|
||||
# 找出分数最高的文档
|
||||
max_score_doc = max(result_list, key=lambda x: x.relevance_score)
|
||||
|
||||
return {
|
||||
"is_relevant": result.can_solve_problem,
|
||||
"relevance_score": result.relevance_score,
|
||||
"explanation": result.explanation
|
||||
"is_relevant": max_score_doc.can_solve_problem,
|
||||
"relevance_score": max_score_doc.relevance_score,
|
||||
"explanation": max_score_doc.explanation
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"判断文档相关性时出错: {str(e)}", exc_info=True)
|
||||
@@ -311,7 +331,7 @@ class QueryRewriteProcessor:
|
||||
|
||||
logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}")
|
||||
|
||||
def process_batch(self, questions: List[str], max_workers: int = 2, enable_retrieval: bool = False, output_file: str = None):
|
||||
def process_batch(self, questions: List[str], max_workers: int = 2, enable_retrieval: bool = False, output_file: str = None, save_interval: int = 100):
|
||||
"""
|
||||
批量处理多个问题
|
||||
|
||||
@@ -320,12 +340,14 @@ class QueryRewriteProcessor:
|
||||
max_workers: 并发处理的最大线程数,默认为2
|
||||
enable_retrieval: 是否启用文档检索功能,默认为False
|
||||
output_file: 输出文件路径,如果为None则不保存结果
|
||||
save_interval: 临时保存的间隔,每处理这么多问题就临时保存一次结果,默认为100
|
||||
|
||||
Returns:
|
||||
处理结果列表
|
||||
"""
|
||||
logging.info(f"共有 {len(questions)} 个问题需要处理,使用 {max_workers} 个并发线程")
|
||||
logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}")
|
||||
logging.info(f"每处理 {save_interval} 个问题将临时保存一次结果")
|
||||
|
||||
# 创建一个与输入顺序相同的结果列表
|
||||
results = [None] * len(questions)
|
||||
@@ -340,6 +362,9 @@ class QueryRewriteProcessor:
|
||||
future = executor.submit(self.process_query, query, enable_retrieval=enable_retrieval)
|
||||
future_to_index[future] = idx
|
||||
|
||||
# 用于跟踪已完成的问题数量
|
||||
completed_count = 0
|
||||
|
||||
# 使用tqdm显示进度条
|
||||
for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(future_to_index), desc="处理进度"):
|
||||
idx = future_to_index[future]
|
||||
@@ -347,7 +372,16 @@ class QueryRewriteProcessor:
|
||||
# 将结果放在与输入相同的位置
|
||||
results[idx] = result
|
||||
|
||||
# 如果提供了输出文件路径,则保存结果
|
||||
# 增加已完成的问题计数
|
||||
completed_count += 1
|
||||
|
||||
# 检查是否需要临时保存
|
||||
if output_file and completed_count % save_interval == 0:
|
||||
# 临时保存当前结果
|
||||
self.save_results_to_excel(results, output_file, is_final=False)
|
||||
logging.info(f"已临时保存 {completed_count} 个问题的处理结果")
|
||||
|
||||
# 如果提供了输出文件路径,则保存最终结果
|
||||
if output_file:
|
||||
self.save_results_to_excel(results, output_file, is_final=True)
|
||||
|
||||
@@ -407,14 +441,14 @@ def main():
|
||||
results = processor.process_batch(questions=examples, max_workers=args.max_workers, enable_retrieval=enable_retrieval, output_file=output_file)
|
||||
logging.info(f"所有处理完成,最终结果已保存至: {output_file}")
|
||||
else:
|
||||
logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}")
|
||||
logging.info(f"文档检索功能状态: 已启用")
|
||||
for idx, query in enumerate(examples):
|
||||
if query.strip() == "":
|
||||
continue
|
||||
# 在调试模式下使用完整的参数
|
||||
print(json.dumps(processor.process_query(
|
||||
query,
|
||||
enable_retrieval=enable_retrieval
|
||||
enable_retrieval=True
|
||||
), ensure_ascii=False, indent=2))
|
||||
|
||||
def setup_logging():
|
||||
@@ -431,5 +465,4 @@ def setup_logging():
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup_logging()
|
||||
logging.info("意图识别示例程序开始运行...")
|
||||
main()
|
||||
Reference in New Issue
Block a user