diff --git a/rag2_0/demo/intent_recognition_example.py b/rag2_0/demo/intent_recognition_example.py index 759245b..40ee2b1 100755 --- a/rag2_0/demo/intent_recognition_example.py +++ b/rag2_0/demo/intent_recognition_example.py @@ -106,10 +106,10 @@ class QueryRewriteProcessor: relevance_score: int = Field(description="相关性评分,0-100分") explanation: str = Field(description="解释文档是否能解决(回答)提问") - class most_relevant_document(BaseModel): - most_relevant_document: TempModel = Field(description="最相关的文档的判断结果") + class all_relevant_document(BaseModel): + most_relevant_document: list[TempModel] = Field(description="最相关的文档的判断结果") - parser = PydanticOutputParser(pydantic_object=most_relevant_document) + parser = PydanticOutputParser(pydantic_object=all_relevant_document) # 构建提示词 prompt = f"""请判断以下检索文档列表中是否与用户提问相关,能够解决用户的问题,并给出相关性评分(0-100分)。输出最相关的文档的判断结果。 @@ -119,7 +119,16 @@ class QueryRewriteProcessor: {doc_text_list} 请按照以下JSON格式返回结果: -{parser.get_format_instructions()} +json``` +{{ + "most_relevant_document":[{{ + "can_solve_problem": true, + "relevance_score": 60, + "explanation":"xxxx" + }}] +}} +``` + """ try: @@ -127,12 +136,23 @@ class QueryRewriteProcessor: llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"}) response = llm.invoke(prompt) - result = parser.parse(response.content).most_relevant_document - + result_list = parser.parse(response.content).most_relevant_document + + # 如果列表为空,返回默认的不相关结果 + if not result_list: + return { + "is_relevant": False, + "explanation": "无法解析文档相关性结果", + "relevance_score": 0.0 + } + + # 找出分数最高的文档 + max_score_doc = max(result_list, key=lambda x: x.relevance_score) + return { - "is_relevant": result.can_solve_problem, - "relevance_score": result.relevance_score, - "explanation": result.explanation + "is_relevant": max_score_doc.can_solve_problem, + "relevance_score": max_score_doc.relevance_score, + "explanation": max_score_doc.explanation } except Exception as e: logging.error(f"判断文档相关性时出错: {str(e)}", exc_info=True) @@ -311,7 +331,7 @@ class QueryRewriteProcessor: logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}") - def process_batch(self, questions: List[str], max_workers: int = 2, enable_retrieval: bool = False, output_file: str = None): + def process_batch(self, questions: List[str], max_workers: int = 2, enable_retrieval: bool = False, output_file: str = None, save_interval: int = 100): """ 批量处理多个问题 @@ -320,12 +340,14 @@ class QueryRewriteProcessor: max_workers: 并发处理的最大线程数,默认为2 enable_retrieval: 是否启用文档检索功能,默认为False output_file: 输出文件路径,如果为None则不保存结果 + save_interval: 临时保存的间隔,每处理这么多问题就临时保存一次结果,默认为100 Returns: 处理结果列表 """ logging.info(f"共有 {len(questions)} 个问题需要处理,使用 {max_workers} 个并发线程") logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}") + logging.info(f"每处理 {save_interval} 个问题将临时保存一次结果") # 创建一个与输入顺序相同的结果列表 results = [None] * len(questions) @@ -340,14 +362,26 @@ class QueryRewriteProcessor: future = executor.submit(self.process_query, query, enable_retrieval=enable_retrieval) future_to_index[future] = idx + # 用于跟踪已完成的问题数量 + completed_count = 0 + # 使用tqdm显示进度条 for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(future_to_index), desc="处理进度"): idx = future_to_index[future] result = future.result() # 将结果放在与输入相同的位置 results[idx] = result + + # 增加已完成的问题计数 + completed_count += 1 + + # 检查是否需要临时保存 + if output_file and completed_count % save_interval == 0: + # 临时保存当前结果 + self.save_results_to_excel(results, output_file, is_final=False) + logging.info(f"已临时保存 {completed_count} 个问题的处理结果") - # 如果提供了输出文件路径,则保存结果 + # 如果提供了输出文件路径,则保存最终结果 if output_file: self.save_results_to_excel(results, output_file, is_final=True) @@ -407,14 +441,14 @@ def main(): results = processor.process_batch(questions=examples, max_workers=args.max_workers, enable_retrieval=enable_retrieval, output_file=output_file) logging.info(f"所有处理完成,最终结果已保存至: {output_file}") else: - logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}") + logging.info(f"文档检索功能状态: 已启用") for idx, query in enumerate(examples): if query.strip() == "": continue # 在调试模式下使用完整的参数 print(json.dumps(processor.process_query( query, - enable_retrieval=enable_retrieval + enable_retrieval=True ), ensure_ascii=False, indent=2)) def setup_logging(): @@ -431,5 +465,4 @@ def setup_logging(): if __name__ == "__main__": setup_logging() - logging.info("意图识别示例程序开始运行...") main() \ No newline at end of file