From 33bc91f0fe67dc070e773f07809057553061ef33 Mon Sep 17 00:00:00 2001 From: ouyangyouzhang Date: Wed, 25 Jun 2025 09:10:28 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=84=8F=E5=9B=BE=E8=AF=86?= =?UTF-8?q?=E5=88=AB=E7=A4=BA=E4=BE=8B=EF=BC=8C=E6=9B=B4=E6=96=B0=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E7=9B=B8=E5=85=B3=E6=80=A7=E5=88=A4=E6=96=AD=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E5=A2=9E=E5=BC=BAExcel=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=AA=8C=E8=AF=81=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=94=B9=E8=BF=9B?= =?UTF-8?q?=E6=97=A5=E5=BF=97=E8=AE=B0=E5=BD=95=EF=BC=8C=E8=B0=83=E6=95=B4?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E4=BB=A5=E6=8F=90=E5=8D=87=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E5=8F=AF=E8=AF=BB=E6=80=A7=E5=92=8C=E7=81=B5=E6=B4=BB=E6=80=A7?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rag2_0/demo/intent_recognition_example.py | 33 ++++++-------- rag2_0/demo/merge_nouns_with_llm.py | 2 +- rag2_0/demo/validate_excel_data_batch.py | 53 +++++++++-------------- rag2_0/tool/ModelTool.py | 2 +- 4 files changed, 35 insertions(+), 55 deletions(-) diff --git a/rag2_0/demo/intent_recognition_example.py b/rag2_0/demo/intent_recognition_example.py index 1cab6d9..759245b 100755 --- a/rag2_0/demo/intent_recognition_example.py +++ b/rag2_0/demo/intent_recognition_example.py @@ -28,7 +28,7 @@ from rag2_0.tool.ModelTool import OpenAiLLM load_dotenv() # 示例查询 -examples_query = """主网电力建设计价通软件, 35kV的软件 土质比例不能一起设置吗""" +examples_query = """ PE2211PK0801是什么软件""" conversation_context="" chat_history=[ { @@ -100,27 +100,23 @@ class QueryRewriteProcessor: "relevance_score": 0.0 } - # 构建文档内容 - doc_contents = [] - for i, doc in enumerate(retrieved_doc[:3]): # 只取前3个文档进行判断 - content = doc.get("content", "") - title = doc.get("title", "") - doc_contents.append(f"文档{i+1}标题: {title}\n文档{i+1}内容: {content}") - - doc_text = "\n\n".join(doc_contents) + doc_text_list = json.dumps(retrieved_doc, ensure_ascii=False, indent=2) class TempModel(BaseModel): - is_relevant: bool = Field(description="是否与用户提问相关") + can_solve_problem: bool = Field(description="是否能解决用户问题") relevance_score: int = Field(description="相关性评分,0-100分") - explanation: str = Field(description="解释各个文档与提问的相关性或不相关性") + explanation: str = Field(description="解释文档是否能解决(回答)提问") - parser = PydanticOutputParser(pydantic_object=TempModel) + class most_relevant_document(BaseModel): + most_relevant_document: TempModel = Field(description="最相关的文档的判断结果") + + parser = PydanticOutputParser(pydantic_object=most_relevant_document) # 构建提示词 - prompt = f"""请判断以下检索文档是否与用户提问相关,并给出相关性评分(0-100分)。 + prompt = f"""请判断以下检索文档列表中是否与用户提问相关,能够解决用户的问题,并给出相关性评分(0-100分)。输出最相关的文档的判断结果。 用户提问: {query} -检索文档: -{doc_text} +检索文档列表: +{doc_text_list} 请按照以下JSON格式返回结果: {parser.get_format_instructions()} @@ -131,10 +127,10 @@ class QueryRewriteProcessor: llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"}) response = llm.invoke(prompt) - result = parser.parse(response.content) + result = parser.parse(response.content).most_relevant_document return { - "is_relevant": result.is_relevant, + "is_relevant": result.can_solve_problem, "relevance_score": result.relevance_score, "explanation": result.explanation } @@ -418,9 +414,6 @@ def main(): # 在调试模式下使用完整的参数 print(json.dumps(processor.process_query( query, - conversation_context=conversation_context, - chat_history=chat_history, - previous_slots=previous_slots, enable_retrieval=enable_retrieval ), ensure_ascii=False, indent=2)) diff --git a/rag2_0/demo/merge_nouns_with_llm.py b/rag2_0/demo/merge_nouns_with_llm.py index 52485f8..84d9d6a 100755 --- a/rag2_0/demo/merge_nouns_with_llm.py +++ b/rag2_0/demo/merge_nouns_with_llm.py @@ -121,7 +121,7 @@ class TermMerger: else: return term_list[0] except Exception as e: - logging.error(f"处理词条 {name} 时出错: {e}") + logging.error(f"处理词条 {name} 时出错: {e}", exc_info=True) return term_list[0] def merge(self): diff --git a/rag2_0/demo/validate_excel_data_batch.py b/rag2_0/demo/validate_excel_data_batch.py index a3868ef..96860ed 100755 --- a/rag2_0/demo/validate_excel_data_batch.py +++ b/rag2_0/demo/validate_excel_data_batch.py @@ -33,7 +33,7 @@ class ValidationResult(BaseModel): class ExcelDataValidator: """Excel数据验证类,用于批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写""" - def __init__(self, input_file=None, output_file=None, workers=4, batch_size=10, debug=False): + def __init__(self, input_file=None, output_file=None, workers=4, debug=False): """ 初始化验证器 @@ -41,7 +41,6 @@ class ExcelDataValidator: input_file: 输入Excel文件路径 output_file: 输出结果Excel文件路径 workers: 并行工作线程数 - batch_size: 每批处理的行数 debug: 是否启用调试模式(串行处理) """ # 加载环境变量 @@ -50,7 +49,6 @@ class ExcelDataValidator: self.input_file = input_file self.output_file = output_file self.workers = workers - self.batch_size = batch_size self.debug = debug self.df = None @@ -86,7 +84,7 @@ class ExcelDataValidator: try: df = pd.read_excel(file_path) - required_columns = ["问题", "问题分类", "问题改写", "槽点信息"] + required_columns = ["问题", "问题分类", "问题改写", "槽位信息", "检索的内容"] for col in required_columns: if col not in df.columns: logging.error(f"缺少必要的列: {col}", exc_info=True) @@ -320,7 +318,7 @@ class ExcelDataValidator: query = row["问题"] query_class = row.get("问题分类", "") rewrite = row.get("问题改写", "") - slot_info = row.get("槽点信息", "") + slot_info = row.get("槽位信息", "") retrieve_content = row.get("检索的内容", "") if self.debug: @@ -359,15 +357,16 @@ class ExcelDataValidator: if len(query_class_list) >= 2: result = self.validate_classification(llm, rewrite, query_class_list[0], query_class_list[1]) if isinstance(result, tuple) and len(result) >= 3: - is_correct, error_reason, confidence_score = result[:3] + is_correct, error_reason, classification_confidence = result[:3] + confidence_score = max(confidence_score, classification_confidence) if self.debug: - logging.info(f" 问题分类验证结果: {'通过' if is_correct else '不通过'}, 置信度: {confidence_score:.2f}") + logging.info(f" 问题分类验证结果: {'通过' if is_correct else '不通过'}, 置信度: {classification_confidence:.2f}") if not is_correct: logging.info(f" 错误原因: {error_reason}") if not is_correct: - return index, False, "问题分类", error_reason, confidence_score + return index, False, "问题分类", error_reason, classification_confidence @@ -416,13 +415,6 @@ class ExcelDataValidator: logging.error(error_msg, exc_info=True) return index, False, "处理错误", error_msg, 0.0 - def process_batch(self, llm, batch_data): - """处理一批数据""" - results = [] - for row_data in batch_data: - results.append(self.validate_row(llm, row_data)) - return results - def create_llm_instances(self, count): """创建多个LLM实例""" api_key = os.getenv("OPENAI_API_KEY") @@ -437,7 +429,7 @@ class ExcelDataValidator: return [OpenAiLLM(**llm_params) for _ in range(count)] - def validate(self, input_file=None, output_file=None, workers=None, batch_size=None, debug=None): + def validate(self, input_file=None, output_file=None, workers=None, debug=None): """ 执行验证过程 @@ -445,7 +437,7 @@ class ExcelDataValidator: input_file: 输入Excel文件路径 output_file: 输出结果Excel文件路径 workers: 并行工作线程数 - batch_size: 每批处理的行数 + batch_size: 每批处理的行数(已弃用,保留参数保持兼容) debug: 是否启用调试模式(串行处理) Returns: @@ -454,7 +446,6 @@ class ExcelDataValidator: input_file = input_file or self.input_file output_file = output_file or self.output_file workers = workers or self.workers - batch_size = batch_size or self.batch_size debug = debug if debug is not None else self.debug # 读取数据 @@ -492,21 +483,20 @@ class ExcelDataValidator: # 输出当前结果 logging.info(f"行 {index} 验证结果: {'通过' if is_correct else '不通过'}, 错误环节: {error_phase}, 错误原因: {error_reason}, 置信度: {confidence_score:.2f}") else: - # 正常模式:并行处理 - batches = [all_rows[i:i+batch_size] for i in range(0, len(all_rows), batch_size)] - llm_instances = self.create_llm_instances(min(workers, len(batches))) + # 正常模式:并行处理,每行单独处理 + llm_instances = self.create_llm_instances(min(workers, len(all_rows))) with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: - # 为每个批次分配一个LLM实例 - future_to_batch = { - executor.submit(self.process_batch, llm_instances[i % len(llm_instances)], batch): - i for i, batch in enumerate(batches) + # 为每行分配一个LLM实例 + future_to_row = { + executor.submit(self.validate_row, llm_instances[i % len(llm_instances)], row_data): + i for i, row_data in enumerate(all_rows) } # 使用tqdm显示进度条 - for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(batches), desc="批次处理进度"): - batch_results = future.result() - all_results.extend(batch_results) + for future in tqdm(concurrent.futures.as_completed(future_to_row), total=len(all_rows), desc="处理进度"): + result = future.result() + all_results.append(result) # 按行索引排序结果,确保与原始数据顺序一致 all_results.sort(key=lambda x: x[0]) @@ -558,16 +548,14 @@ class ExcelDataValidator: def main(): """主函数""" # 解析命令行参数 - input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "1500条点踩软件问题测试_检索结果.xlsx") + input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "1500条点踩软件问题测试_意图分类.xlsx") output_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "自动验证_问题分类重写结果.xlsx") parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写") parser.add_argument("--input", "-i", type=str, help="输入Excel文件路径", default=input_excel) parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel) parser.add_argument("--workers", "-w", type=int, default=20, help="并行工作线程数") - parser.add_argument("--batch-size", "-b", type=int, default=5, help="每批处理的行数") - parser.add_argument("--debug", "-d", action="store_true", help="启用调试模式(串行处理)") - + logging.info(f"输入文件路径: {args.input}, 输出文件路径: {args.output}, 并行工作线程数: {args.workers}") args = parser.parse_args() is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None @@ -576,7 +564,6 @@ def main(): input_file=args.input, output_file=args.output, workers=args.workers, - batch_size=args.batch_size, debug=is_debug ) validator.validate() diff --git a/rag2_0/tool/ModelTool.py b/rag2_0/tool/ModelTool.py index 3af2004..698f434 100755 --- a/rag2_0/tool/ModelTool.py +++ b/rag2_0/tool/ModelTool.py @@ -82,7 +82,7 @@ class SiliconFlowReRankerModel: results = response.json() return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]] except requests.exceptions.RequestException as e: - logging.error(f"重排序请求失败: {str(e)}") + logging.error(f"重排序请求失败: {str(e)}", exc_info=True) return [] class XinferenceReRankerModel: