diff --git a/.vscode/settings.json b/.vscode/settings.json index 242c7c8..a475a1f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,5 @@ { "python.analysis.typeCheckingMode": "off", - "python.analysis.autoImportCompletions": true + "python.analysis.autoImportCompletions": true, + "cursorpyright.analysis.typeCheckingMode": "off" } \ No newline at end of file diff --git a/api_key.txt b/api_key.txt index 7b3d78d..55e1a73 100644 --- a/api_key.txt +++ b/api_key.txt @@ -1,5 +1,3 @@ -sk-qmrkfvvbbfssuoreyvwqawoveyowuvxviqzqknotyweqmuog -sk-jrdzerhmvrtvzawkksowbgkggkubwfquplmrxbdhespqgtis sk-jnnmltwtqwuoyagoogzzeraczmyfxhoairiddgayksqdfnbr sk-eghuepxnbcollzrjwbzqvbnhiiwagkejaclyhvaodeqgwrog sk-poszkbjdmamimconjustnrxxqusuzlryxkrzkpronlenrmen diff --git a/rag2_0/demo/deduplicate_nouns_json.py b/rag2_0/demo/deduplicate_nouns_json.py index 8bcb0f1..e68f2ef 100755 --- a/rag2_0/demo/deduplicate_nouns_json.py +++ b/rag2_0/demo/deduplicate_nouns_json.py @@ -69,7 +69,7 @@ class JsonDeduplicator: logging.info(f"从{self.INPUT_PATH}加载了{len(data)}条记录") return data except Exception as e: - logging.error(f"读取{self.INPUT_PATH}失败: {e}") + logging.error(f"读取{self.INPUT_PATH}失败: {e}", exc_info=True) return [] def group_items_by_key(self, items): @@ -118,7 +118,7 @@ class JsonDeduplicator: # 如果合并失败,返回第一个项目 return item_list[0] except Exception as e: - logging.error(f"处理键 {key} 时出错: {e}") + logging.error(f"处理键 {key} 时出错: {e}", exc_info=True) return item_list[0] def deduplicate(self): diff --git a/rag2_0/demo/extract_wikijs_nouns.py b/rag2_0/demo/extract_wikijs_nouns.py index 1acc9b3..9591b98 100755 --- a/rag2_0/demo/extract_wikijs_nouns.py +++ b/rag2_0/demo/extract_wikijs_nouns.py @@ -135,11 +135,10 @@ class WikijsNounsExtractor: parsed_output = self.terms_list_parser.parse(response.content) return parsed_output.terms except Exception as e: - logging.error(f"解析LLM响应时出错: {str(e)}") - logging.error(f"原始响应: {response.content}") + logging.error(f"解析LLM响应时出错: {str(e)}", exc_info=True) return [] except Exception as e: - logging.error(f"提取专业名词时出错: {str(e)}") + logging.error(f"提取专业名词时出错: {str(e)}", exc_info=True) return [] def _process_document(self, doc, path_terms): @@ -182,7 +181,7 @@ class WikijsNounsExtractor: return path_prefix except Exception as e: - logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}") + logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}", exc_info=True) return None def process_all_documents(self, output_dir: str = "extracted_nouns", max_concurrency: int = 5): @@ -237,7 +236,7 @@ class WikijsNounsExtractor: if i % 10 == 0: logging.info(f"已完成 {i+1}/{len(futures)} 个文档的处理") except Exception as e: - logging.error(f"处理文档时出错: {str(e)}") + logging.error(f"处理文档时出错: {str(e)}", exc_info=True) # 保存最终结果 for prefix, terms in path_terms.items(): diff --git a/rag2_0/demo/intent_recognition_example.py b/rag2_0/demo/intent_recognition_example.py index 6895c70..1cab6d9 100755 --- a/rag2_0/demo/intent_recognition_example.py +++ b/rag2_0/demo/intent_recognition_example.py @@ -27,249 +27,6 @@ from rag2_0.tool.ModelTool import OpenAiLLM # 加载环境变量 load_dotenv() -dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://172.20.0.145/v1") - -def is_retrieved_doc_relevant(query: str, retrieved_doc: List[Dict[str, Any]], api_key: str = None, base_url: str = None, model_name: str = None) -> Dict[str, Any]: - """ - 使用LLM判断检索出的文档是否与用户提问相关 - - Args: - query: 用户提问 - retrieved_doc: 检索出的文档列表 - api_key: API密钥,默认使用环境变量 - base_url: API基础URL,默认使用环境变量 - model_name: 模型名称,默认使用环境变量或默认模型 - - Returns: - 包含相关性判断结果的字典,包括is_relevant(布尔值)和explanation(解释) - """ - # 使用环境变量或参数值 - api_key = api_key or os.getenv("OPENAI_API_KEY") - base_url = base_url or os.getenv("OPENAI_API_BASE") - model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo") - - # 如果没有检索到文档,直接返回不相关 - if not retrieved_doc or len(retrieved_doc) == 0: - return { - "is_relevant": False, - "explanation": "没有检索到任何文档", - "relevance_score": 0.0 - } - - # 构建文档内容 - doc_contents = [] - for i, doc in enumerate(retrieved_doc[:3]): # 只取前3个文档进行判断 - content = doc.get("content", "") - title = doc.get("title", "") - doc_contents.append(f"文档{i+1}标题: {title}\n文档{i+1}内容: {content}") - - doc_text = "\n\n".join(doc_contents) - class TempModel(BaseModel): - is_relevant: bool = Field(description="是否与用户提问相关") - relevance_score: int = Field(description="相关性评分,0-100分") - explanation: str = Field(description="解释各个文档与提问的相关性或不相关性") - - parser = PydanticOutputParser(pydantic_object=TempModel) - # 构建提示词 - prompt = f"""请判断以下检索文档是否与用户提问相关,并给出相关性评分(0-100分)。 - -用户提问: {query} - -检索文档: -{doc_text} - -请按照以下JSON格式返回结果: -{parser.get_format_instructions()} -""" - - try: - # 初始化LLM并调用 - llm = OpenAiLLM(api_key=api_key, base_url=base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"}) - response = llm.invoke(prompt) - - result = parser.parse(response.content) - - return { - "is_relevant": result.is_relevant, - "relevance_score": result.relevance_score, - "explanation": result.explanation - } - except Exception as e: - logging.error(f"判断文档相关性时出错: {str(e)}") - return { - "is_relevant": False, - "explanation": f"判断过程出错: {str(e)}", - "relevance_score": 0.0 - } - -# 读取Excel文件中的提问数据 -def load_questions_from_excel(file_path=None): - """ - 从Excel文件中读取提问数据 - - Args: - file_path: Excel文件路径,如果为None则使用默认路径 - - Returns: - 提问列表 - """ - - try: - # 读取Excel文件的第一列数据 - df = pd.read_excel(file_path) - questions = df.iloc[:, 0].tolist() # 获取第一列数据 - logging.info(f"成功从{file_path}读取了{len(questions)}条提问") - return questions - except Exception as e: - logging.error(f"读取Excel文件时出错: {e}") - return [] - -def process_query(recognizer: IntentRecognizer, query: str, conversation_context: str = "", chat_history: List[Dict[str, str]] = None, previous_slots: Dict[str, str] = None): - """ - 处理单个查询,支持重试机制,并包含槽位填充 - - Args: - recognizer: 意图识别器实例 - query: 查询字符串 - - Returns: - 处理结果字典 - """ - max_retries = 3 - retry_count = 0 - - while retry_count <= max_retries: - try: - # 使用新的process_query_with_slots方法处理查询 - # result = recognizer.process_query_with_slots(query) - result = recognizer.process_query(query, - conversation_context=conversation_context, - chat_history=chat_history, - previous_slots=previous_slots, - enable_query_expansion=True) - # 提取分类信息 - classification = result["classification"] - original_query = result["rewrite"]["rewrite"] - query_list = result["query_expand"]["all"] - soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","") - # 将字典转换为Classification对象 - classification_obj = Classification(**classification) - retrieved_doc=dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name) - - # 判断检索文档是否相关 - relevance_result = {} - if retrieved_doc: - # 获取API密钥和基础URL - api_key = os.getenv("OPENAI_API_KEY") - base_url = os.getenv("OPENAI_API_BASE") - model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo") - # 判断文档相关性 - relevance_result = is_retrieved_doc_relevant(query, retrieved_doc, api_key, base_url, model_name) - else: - retrieved_doc_str = [] - relevance_result = { - "is_relevant": False, - "explanation": "没有检索到文档", - "relevance_score": 0.0 - } - - retrieved_doc_titles=[] - if retrieved_doc: - retrieved_doc_titles=[doc["title"].split("/")[-1] for doc in retrieved_doc] - # 提取槽位填充信息 - slot_filling = result.get("slot_filling", {}) - slot_filling_str = "" - if slot_filling and "filled_data" in slot_filling: - # 格式化槽位填充结果 - slot_filling_str = json.dumps({ - "是否完整": slot_filling.get("is_complete", False), - "缺失槽位": slot_filling.get("missing_slots", {}), - "填充数据": slot_filling.get("filled_data", {}) - }, ensure_ascii=False, indent=2) - - # 处理成功,返回结果 - return { - "提问": query, - "问题分类": f"{classification['vertical_classification']} - {classification['sub_classification']}", - "问题改写": result["rewrite"]["rewrite"], - "槽位填充": slot_filling_str, - "检索的文档": "\n".join(retrieved_doc_titles), - "文档是否相关": "相关" if relevance_result["is_relevant"] else "不相关", - "文档相关性解释": relevance_result["explanation"] - } - except Exception as e: - logging.error(f"处理问题 '{query}' 时出错: ",exc_info=True) - retry_count += 1 - - # 如果已经重试了最大次数,则记录错误并返回错误结果 - if retry_count > max_retries: - return { - "提问": query, - "问题分类": "处理出错", - "问题改写": "处理出错", - "槽位填充": "处理出错", - "检索的文档": f"重试 {max_retries} 次后失败: {str(e)}", - "文档是否相关": "处理出错", - "文档相关性解释": "处理出错" - } - else: - # 可以在这里添加延迟,避免过快重试 - time.sleep(10) - -def save_results_to_excel(results, output_file, is_final=False): - """ - 将结果保存到Excel文件 - - Args: - results: 结果列表 - output_file: 输出文件路径 - is_final: 是否为最终保存,如果是则使用完整文件名,否则添加临时标记 - - Returns: - None - """ - # 过滤掉None值 - valid_results = [r for r in results if r is not None] - - if not valid_results: - logging.warning("没有有效结果可保存") - return - - # 创建DataFrame - results_df = pd.DataFrame(valid_results) - - # 根据是否为最终保存确定文件名 - if not is_final: - file_name, file_ext = os.path.splitext(output_file) - temp_output_file = f"{file_name}_temp{file_ext}" - else: - temp_output_file = output_file - - # 使用ExcelWriter设置格式 - with pd.ExcelWriter(temp_output_file, engine='xlsxwriter') as writer: - results_df.to_excel(writer, index=False, sheet_name='Sheet1') - - # 获取工作簿和工作表对象 - workbook = writer.book - worksheet = writer.sheets['Sheet1'] - - # 设置列宽(单位:像素) - # 定义列宽(厘米转为Excel单位,1cm约等于4.7个Excel单位) - worksheet.set_column('A:A', 60) # 提问列 60个Excel单位 - worksheet.set_column('B:B', 20) # 问题拆解 20个Excel单位 - worksheet.set_column('C:C', 20) # 一级分类 20个Excel单位 - worksheet.set_column('D:D', 20) # 二级分类 20个Excel单位 - worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位 - worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位 - worksheet.set_column('G:G', 80) # 槽位填充 80个Excel单位 - worksheet.set_column('H:H', 60) # 文档相关性 60个Excel单位 - - # 设置所有行高为20磅 - for i in range(len(results_df) + 1): # +1 是为了包括表头 - worksheet.set_row(i, 20) - - logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}") - # 示例查询 examples_query = """主网电力建设计价通软件, 35kV的软件 土质比例不能一起设置吗""" conversation_context="" @@ -296,29 +53,327 @@ previous_slots={ "operation_steps": None } +class QueryRewriteProcessor: + """ + 查询改写处理器,用于意图识别、问题改写和文档检索 + """ + def __init__(self, + api_key: str = None, + base_url: str = None, + model_name: str = None, + dify_api_key: str = "dataset-skLjmPVonjHo119OWNf3kAmY", + dify_base_url: str = "http://172.20.0.145/v1"): + """ + 初始化查询改写处理器 + + Args: + api_key: API密钥,默认使用环境变量 + base_url: API基础URL,默认使用环境变量 + model_name: 模型名称,默认使用环境变量或默认模型 + dify_api_key: Dify API密钥 + dify_base_url: Dify API基础URL + """ + # 初始化意图识别器 + self.api_key = api_key or os.getenv("OPENAI_API_KEY") + self.base_url = base_url or os.getenv("OPENAI_API_BASE") + self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo") + + self.recognizer = IntentRecognizer(api_key=self.api_key, base_url=self.base_url, model_name=self.model_name) + self.dify_query_retrieval = DifyQueryRetrieval(api_key=dify_api_key, base_url=dify_base_url) + + def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + 使用LLM判断检索出的文档是否与用户提问相关 + + Args: + query: 用户提问 + retrieved_doc: 检索出的文档列表 + + Returns: + 包含相关性判断结果的字典,包括is_relevant(布尔值)和explanation(解释) + """ + # 如果没有检索到文档,直接返回不相关 + if not retrieved_doc or len(retrieved_doc) == 0: + return { + "is_relevant": False, + "explanation": "没有检索到任何文档", + "relevance_score": 0.0 + } + + # 构建文档内容 + doc_contents = [] + for i, doc in enumerate(retrieved_doc[:3]): # 只取前3个文档进行判断 + content = doc.get("content", "") + title = doc.get("title", "") + doc_contents.append(f"文档{i+1}标题: {title}\n文档{i+1}内容: {content}") + + doc_text = "\n\n".join(doc_contents) + class TempModel(BaseModel): + is_relevant: bool = Field(description="是否与用户提问相关") + relevance_score: int = Field(description="相关性评分,0-100分") + explanation: str = Field(description="解释各个文档与提问的相关性或不相关性") + + parser = PydanticOutputParser(pydantic_object=TempModel) + # 构建提示词 + prompt = f"""请判断以下检索文档是否与用户提问相关,并给出相关性评分(0-100分)。 + +用户提问: {query} + +检索文档: +{doc_text} + +请按照以下JSON格式返回结果: +{parser.get_format_instructions()} +""" + + try: + # 初始化LLM并调用 + llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"}) + response = llm.invoke(prompt) + + result = parser.parse(response.content) + + return { + "is_relevant": result.is_relevant, + "relevance_score": result.relevance_score, + "explanation": result.explanation + } + except Exception as e: + logging.error(f"判断文档相关性时出错: {str(e)}", exc_info=True) + return { + "is_relevant": False, + "explanation": f"判断过程出错: {str(e)}", + "relevance_score": 0.0 + } + + def load_questions_from_excel(self, file_path=None): + """ + 从Excel文件中读取提问数据 + + Args: + file_path: Excel文件路径,如果为None则使用默认路径 + + Returns: + 提问列表 + """ + try: + # 读取Excel文件的第一列数据 + df = pd.read_excel(file_path) + questions = df.iloc[:, 0].tolist() # 获取第一列数据 + logging.info(f"成功从{file_path}读取了{len(questions)}条提问") + return questions + except Exception as e: + logging.error(f"读取Excel文件时出错: {e}", exc_info=True) + return [] + + def process_query(self, query: str, conversation_context: str = "", chat_history: List[Dict[str, str]] = None, previous_slots: Dict[str, str] = None, enable_retrieval: bool = False): + """ + 处理单个查询,支持重试机制,并包含槽位填充 + + Args: + query: 查询字符串 + conversation_context: 对话上下文 + chat_history: 聊天历史记录 + previous_slots: 之前识别的槽位信息 + enable_retrieval: 是否启用文档检索功能,默认为False + + Returns: + 处理结果字典 + """ + max_retries = 3 + retry_count = 0 + + while retry_count <= max_retries: + try: + # 使用process_query方法处理查询 + result = self.recognizer.process_query(query, + conversation_context=conversation_context, + chat_history=chat_history, + previous_slots=previous_slots, + enable_query_expansion=True) + # 提取分类信息 + classification = result["classification"] + original_query = result["rewrite"]["rewrite"] + query_list = result["query_expand"]["all"] + soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","") + # 将字典转换为Classification对象 + classification_obj = Classification(**classification) + + # 根据enable_retrieval参数决定是否进行文档检索 + retrieved_doc = None + if enable_retrieval: + retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name) + + # 判断检索文档是否相关 + relevance_result = {} + if retrieved_doc: + # 判断文档相关性 + relevance_result = self.is_retrieved_doc_relevant(query, retrieved_doc) + else: + relevance_result = { + "is_relevant": False, + "explanation": "没有检索到文档" if enable_retrieval else "文档检索功能未启用", + "relevance_score": 0.0 + } + + retrieved_doc_titles=[] + if retrieved_doc: + retrieved_doc_titles=[doc["title"].split("/")[-1] for doc in retrieved_doc] + # 提取槽位填充信息 + slot_filling = result.get("slot_filling", {}) + slot_filling_str = "" + if slot_filling and "filled_data" in slot_filling: + # 格式化槽位填充结果 + slot_filling_str = json.dumps({ + "是否完整": slot_filling.get("is_complete", False), + "缺失槽位": slot_filling.get("missing_slots", {}), + "填充数据": slot_filling.get("filled_data", {}) + }, ensure_ascii=False, indent=2) + + # 处理成功,返回结果 + return { + "问题": query, + "问题分类": f"{classification['vertical_classification']} - {classification['sub_classification']}", + "问题改写": result["rewrite"]["rewrite"], + "槽位信息": slot_filling_str, + "检索的文档": "\n".join(retrieved_doc_titles), + "检索的内容": json.dumps(retrieved_doc, ensure_ascii=False, indent=2) if retrieved_doc else "", + "文档是否相关": "相关" if relevance_result["is_relevant"] else "不相关", + "文档相关性解释": relevance_result["explanation"] + } + except Exception as e: + logging.error(f"处理问题 '{query}' 时出错: ",exc_info=True) + retry_count += 1 + + # 如果已经重试了最大次数,则记录错误并返回错误结果 + if retry_count > max_retries: + return { + "问题": query, + "问题分类": "处理出错", + "问题改写": "处理出错", + "槽位信息": "处理出错", + "检索的文档": f"重试 {max_retries} 次后失败: {str(e)}", + "检索的内容":"", + "文档是否相关": "处理出错", + "文档相关性解释": "处理出错" + } + else: + # 可以在这里添加延迟,避免过快重试 + time.sleep(10) + + def save_results_to_excel(self, results, output_file, is_final=False): + """ + 将结果保存到Excel文件 + + Args: + results: 结果列表 + output_file: 输出文件路径 + is_final: 是否为最终保存,如果是则使用完整文件名,否则添加临时标记 + + Returns: + None + """ + # 过滤掉None值 + valid_results = [r for r in results if r is not None] + + if not valid_results: + logging.warning("没有有效结果可保存") + return + + # 创建DataFrame + results_df = pd.DataFrame(valid_results) + + # 根据是否为最终保存确定文件名 + if not is_final: + file_name, file_ext = os.path.splitext(output_file) + temp_output_file = f"{file_name}_temp{file_ext}" + else: + temp_output_file = output_file + + # 使用ExcelWriter设置格式 + with pd.ExcelWriter(temp_output_file, engine='xlsxwriter') as writer: + results_df.to_excel(writer, index=False, sheet_name='Sheet1') + + # 获取工作簿和工作表对象 + workbook = writer.book + worksheet = writer.sheets['Sheet1'] + + # 设置列宽(单位:像素) + # 定义列宽(厘米转为Excel单位,1cm约等于4.7个Excel单位) + worksheet.set_column('A:A', 60) # 提问列 60个Excel单位 + worksheet.set_column('B:B', 20) # 问题拆解 20个Excel单位 + worksheet.set_column('C:C', 20) # 一级分类 20个Excel单位 + worksheet.set_column('D:D', 20) # 二级分类 20个Excel单位 + worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位 + worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位 + worksheet.set_column('G:G', 80) # 槽位填充 80个Excel单位 + worksheet.set_column('H:H', 60) # 文档相关性 60个Excel单位 + + # 设置所有行高为20磅 + for i in range(len(results_df) + 1): # +1 是为了包括表头 + worksheet.set_row(i, 20) + + logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}") + + def process_batch(self, questions: List[str], max_workers: int = 2, enable_retrieval: bool = False, output_file: str = None): + """ + 批量处理多个问题 + + Args: + questions: 问题列表 + max_workers: 并发处理的最大线程数,默认为2 + enable_retrieval: 是否启用文档检索功能,默认为False + output_file: 输出文件路径,如果为None则不保存结果 + + Returns: + 处理结果列表 + """ + logging.info(f"共有 {len(questions)} 个问题需要处理,使用 {max_workers} 个并发线程") + logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}") + + # 创建一个与输入顺序相同的结果列表 + results = [None] * len(questions) + + # 使用线程池进行并发处理 + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # 提交所有任务并记录它们的索引 + future_to_index = {} + for idx, query in enumerate(questions): + if not query or query.strip() == "": + continue + future = executor.submit(self.process_query, query, enable_retrieval=enable_retrieval) + future_to_index[future] = idx + + # 使用tqdm显示进度条 + for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(future_to_index), desc="处理进度"): + idx = future_to_index[future] + result = future.result() + # 将结果放在与输入相同的位置 + results[idx] = result + + # 如果提供了输出文件路径,则保存结果 + if output_file: + self.save_results_to_excel(results, output_file, is_final=True) + + return results + def parse_arguments(): """解析命令行参数""" parser = argparse.ArgumentParser(description='意图识别和问题改写工具') - + input_file="data/excel/1500条点踩软件问题测试.xlsx" + ouput_file="data/excel/1500条点踩软件问题测试_意图分类.xlsx" # 添加数据文件路径参数 - parser.add_argument('--input', '-i', type=str, + parser.add_argument('--input', '-i', type=str, default=input_file, help='输入Excel文件路径,包含待处理的提问数据(第一列)') - parser.add_argument('--output', '-o', type=str, + parser.add_argument('--output', '-o', type=str,default=ouput_file, help='输出Excel文件路径,用于保存处理结果') - # 添加LLM相关参数 - parser.add_argument('--model', '-m', type=str, - help='LLM模型名称,默认使用环境变量中的配置') - parser.add_argument('--api_base', '-a', type=str, - help='API基础URL,默认使用环境变量中的配置') - # 添加处理相关参数 parser.add_argument('--max_workers', '-w', type=int, default=2, help='并发处理的最大线程数,默认为20') - parser.add_argument('--debug', '-d', action='store_true', - help='启用调试模式,使用示例查询而非从文件读取') - parser.add_argument('--query', '-q', type=str, - help='在调试模式下使用的查询字符串') + + parser.add_argument('--enable_retrieval', '-r', action='store_true', + help='是否启用文档检索功能,默认不启用') return parser.parse_args() @@ -332,11 +387,12 @@ def main(): # 从环境变量中获取配置,命令行参数优先 api_key = os.getenv("OPENAI_API_KEY") - base_url = args.api_base if args.api_base else os.getenv("OPENAI_API_BASE") - model_name = args.model if args.model else os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo") + base_url = os.getenv("OPENAI_API_BASE") + model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo") + enable_retrieval = args.enable_retrieval - # 初始化意图识别器 - recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name) + # 初始化查询改写处理器 + processor = QueryRewriteProcessor(api_key=api_key, base_url=base_url, model_name=model_name) # 读取提问数据 current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -344,51 +400,29 @@ def main(): output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "1500条点踩软件问题_槽位(分类)填充结果.xlsx") # 检测是否为调试模式 - is_debug = args.debug or (hasattr(sys, 'gettrace') and sys.gettrace() is not None) - is_debug = False + is_debug =hasattr(sys, 'gettrace') and sys.gettrace() is not None if is_debug: - # 如果提供了查询参数,使用它;否则使用默认示例 - if args.query: - examples = [args.query] - else: - examples = examples_query.strip().split("\n") + examples = examples_query.strip().split("\n") else: - examples = load_questions_from_excel(data_file) + examples = processor.load_questions_from_excel(data_file) if not is_debug: - max_workers = args.max_workers - logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程") - - # 创建一个与输入顺序相同的结果列表 - results = [None] * len(examples) - batch_size = 100 # 每100条保存一次 - - # 使用线程池进行并发处理 - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - # 提交所有任务并记录它们的索引 - future_to_index = {} - for idx, query in enumerate(examples): - future = executor.submit(process_query, recognizer, query) - future_to_index[future] = idx - - # 使用tqdm显示进度条 - completed = 0 - for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(examples), desc="处理进度"): - idx = future_to_index[future] - result = future.result() - # 将结果放在与输入相同的位置 - results[idx] = result - completed += 1 - - # 处理完所有数据后,保存最终结果 - save_results_to_excel(results, output_file, is_final=True) + # 批量处理问题 + results = processor.process_batch(questions=examples, max_workers=args.max_workers, enable_retrieval=enable_retrieval, output_file=output_file) logging.info(f"所有处理完成,最终结果已保存至: {output_file}") else: + logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}") for idx, query in enumerate(examples): if query.strip() == "": continue - # process_query(recognizer, query, conversation_context, chat_history, previous_slots) - print(json.dumps(process_query(recognizer, query), ensure_ascii=False, indent=2)) + # 在调试模式下使用完整的参数 + print(json.dumps(processor.process_query( + query, + conversation_context=conversation_context, + chat_history=chat_history, + previous_slots=previous_slots, + enable_retrieval=enable_retrieval + ), ensure_ascii=False, indent=2)) def setup_logging(): # 配置日志输出到控制台 diff --git a/rag2_0/demo/validate_excel_data_batch.py b/rag2_0/demo/validate_excel_data_batch.py index a6d4bbd..a3868ef 100755 --- a/rag2_0/demo/validate_excel_data_batch.py +++ b/rag2_0/demo/validate_excel_data_batch.py @@ -81,21 +81,21 @@ class ExcelDataValidator: """ file_path = file_path or self.input_file if not file_path: - logging.error("未指定输入文件路径") + logging.error("未指定输入文件路径", exc_info=True) return None try: df = pd.read_excel(file_path) - required_columns = ["问题", "问题分类", "问题改写", "槽点信息", "检索的内容"] + required_columns = ["问题", "问题分类", "问题改写", "槽点信息"] for col in required_columns: if col not in df.columns: - logging.error(f"缺少必要的列: {col}") + logging.error(f"缺少必要的列: {col}", exc_info=True) return None logging.info(f"成功从{file_path}读取了{len(df)}条数据") self.df = df return df except Exception as e: - logging.error(f"读取Excel文件时出错: {e}") + logging.error(f"读取Excel文件时出错: {e}", exc_info=True) return None def validate_classification(self, llm:OpenAiLLM , query:str, vertical_class:str, sub_class:str): @@ -413,10 +413,7 @@ class ExcelDataValidator: return index, True, "", "", confidence_score except Exception as e: error_msg = f"处理行 {index} 时发生错误: {str(e)}" - logging.error(error_msg) - if self.debug: - import traceback - logging.error(traceback.format_exc()) + logging.error(error_msg, exc_info=True) return index, False, "处理错误", error_msg, 0.0 def process_batch(self, llm, batch_data): diff --git a/rag2_0/intent_recognition/IntentRecognition.py b/rag2_0/intent_recognition/IntentRecognition.py index 861a3b6..860297b 100755 --- a/rag2_0/intent_recognition/IntentRecognition.py +++ b/rag2_0/intent_recognition/IntentRecognition.py @@ -635,7 +635,7 @@ class IntentRecognizer: return parsed_output except Exception as e: # 如果解析失败,返回原始查询作为后退提示 - logging.error(f"后退提示生成失败: {e}") + logging.error(f"后退提示生成失败: {e}", exc_info=True) return StepBackPrompt(original_query=query, step_back_query=query) def _generate_follow_up_questions(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> FollowUpQuestions: @@ -672,7 +672,7 @@ class IntentRecognizer: return parsed_output except Exception as e: # 如果解析失败,返回原始查询作为后续问题 - logging.error(f"后续问题生成失败: {e}") + logging.error(f"后续问题生成失败: {e}", exc_info=True) return FollowUpQuestions(original_query=query, follow_up_query=query) def _generate_hypothetical_document(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> HypotheticalDocument: @@ -709,7 +709,7 @@ class IntentRecognizer: return parsed_output except Exception as e: # 如果解析失败,返回空的假设性回答 - logging.error(f"假设性文档生成失败: {e}") + logging.error(f"假设性文档生成失败: {e}", exc_info=True) return HypotheticalDocument(original_query=query, hypothetical_answer="") def _generate_multi_questions(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> MultiQuestions: @@ -746,7 +746,7 @@ class IntentRecognizer: return parsed_output except Exception as e: # 如果解析失败,返回原始查询作为唯一子问题 - logging.error(f"多角度问题生成失败: {e},LLM返回内容:{response.content}") + logging.error(f"多角度问题生成失败: {e}",exc_info=True) return MultiQuestions(original_query=query, sub_questions=[query]) def _run_in_thread(self, func, args=(), kwargs={}): @@ -768,7 +768,7 @@ class IntentRecognizer: result = func(*args, **kwargs) result_container.append(result) except Exception as e: - logging.error(f"线程执行函数 {func.__name__} 时出错: {e}") + logging.error(f"线程执行函数 {func.__name__} 时出错: {e}", exc_info=True) result_container.append(None) thread = threading.Thread(target=thread_target) @@ -866,5 +866,4 @@ class IntentRecognizer: return result except Exception as e: - logging.error(f"process_intent_and_slot error:{e}") raise RuntimeError(f"process_intent_and_slot error:{e}") from e \ No newline at end of file diff --git a/rag2_0/intent_recognition/ProfessionalNounVector.py b/rag2_0/intent_recognition/ProfessionalNounVector.py index 9921db0..31b3d2b 100755 --- a/rag2_0/intent_recognition/ProfessionalNounVector.py +++ b/rag2_0/intent_recognition/ProfessionalNounVector.py @@ -93,7 +93,7 @@ class ProfessionalNounVectorizer: logging.info(f"总共加载了 {len(merged_terms)} 条专业名词") return merged_terms except Exception as e: - logging.error(f"加载多个文件失败: {e}") + logging.error(f"加载多个文件失败: {e}", exc_info=True) return [] def vectorize_files_and_save(self, file_paths: List[str]) -> bool: @@ -139,7 +139,7 @@ class ProfessionalNounVectorizer: logging.info("完成多文件专业名词向量化和索引创建") return True except Exception as e: - logging.error(f"多文件向量化处理失败: {e}") + logging.error(f"多文件向量化处理失败: {e}", exc_info=True) return False def _updata_suffix_item(self)->Tuple[List[str], List[Dict]] : @@ -246,7 +246,7 @@ class ProfessionalNounVectorizer: faiss_index.save_local(folder_path=self.index_path) logging.info(f"FAISS索引已保存至 {self.index_path}") except Exception as e: - logging.error(f"保存FAISS索引失败: {e}") + logging.error(f"保存FAISS索引失败: {e}", exc_info=True) raise e @@ -349,5 +349,5 @@ class ProfessionalNounRetriever: return [json.loads(item) for item in intersection] except Exception as e: - logging.error(f"查询FAISS索引失败: {e}") + logging.error(f"查询FAISS索引失败: {e}", exc_info=True) return [] \ No newline at end of file