#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: extract_wikijs_nouns.py
Author: oyyz
Description: Extract domain-specific terms (nouns) from Wiki.js documents.
"""
import concurrent.futures
import datetime
import json
import logging
import os
import re
import threading
from threading import Semaphore
from typing import List

from dotenv import load_dotenv
from langchain.output_parsers import PydanticOutputParser

from rag2_0.intent_recognition.DataModels import Term, TermList
from rag2_0.tool.html_to_md import convert_html_to_md
from rag2_0.tool.ModelTool import OpenAiLLM
from rag2_0.tool.WikijsTool import WikijsTool

# Load environment variables (e.g. from a local .env file).
load_dotenv()

# Model used when the caller passes no model name (matches the __init__ default).
DEFAULT_MODEL_NAME = "gpt-3.5-turbo"

# Concurrency used until process_all_documents() installs its own limit.
DEFAULT_MAX_CONCURRENCY = 5

# Wiki path prefixes processed when the caller does not supply any.
# Other product prefixes this script has been pointed at:
#   "技改检修计价通(2020)", "西藏造价软件(2023)",
#   "新型储能电站建设计价通C1(2024)", "配网造价软件(2022)"
DEFAULT_PATH_PREFIXES = ("主网电力建设计价通(2018)",)

# An intermediate result file is written every this many processed documents.
CHECKPOINT_EVERY_N_DOCS = 10

# Matches the "<hN" / "</hN" part of an HTML heading tag (N in 1..6), with or
# without attributes after it.
_HEADING_TAG_RE = re.compile(r"(</?h)([1-6])(?=[\s/>])")

extract_wiki_nouns_prompt = """
我在完善我的专业词库,请从提供的电力行业造价软件相关文本中提取关键词,要求如下:
一、提取范围
1. 核心功能模块
(例:多工程批量计价、材机数据反算、变电工程智能组价、架空线路地形系数计算)
2、软件功能及界面名称(包括:界面页签、功能按钮、功能名称等)
(例:新建工程量清单、导出工程量清单等)
3. 业务专用术语
(例:装置性材料、甲供材保管费、施工降效补偿、电缆头试验配套费)
4. 计价标准体系
(例:预规2020版、电网检修定额2015版、配网工程概算定额)
二、提取规则
1. 识别核心功能名称(如"多工程批量设置工程量、工程设置密码")
2. 提取业务专用名词(如"主材卸车保管费")
3. 标注关联术语的对应关系(如"市场价"与"市场价格"互为同义词)
4. 包含定额标准相关术语(如"预规2020版")
5. 复合型术语需保持完整
√ 正确:"地形增加系数批量设置"
× 错误:"地形"、"系数"、"设置"
6. 总结生成关键词解释
关键词:编制依据
描述:造价文件编制基准规范
7. 软件的特定版本号不作为关键词
三、输出格式:
{output_format}
四、输入内容:
{content}
"""


def _demote_headings(html):
    """Return *html* with every h1..h6 tag shifted one level down (h1 -> h2, ...).

    Done in a single regex pass so a tag is shifted exactly once, and so that
    opening tags carrying attributes (``<h1 id="x">``) are shifted together
    with their closing tags instead of being left mismatched.
    """
    return _HEADING_TAG_RE.sub(
        lambda m: f"{m.group(1)}{int(m.group(2)) + 1}", html
    )


class WikijsNounsExtractor:
    """Extract domain-specific terms from Wiki.js documents using an LLM."""

    def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo"):
        """
        Initialise the term extractor.

        Args:
            api_key: API key; forwarded to the LLM wrapper only when provided.
            base_url: API base URL; forwarded to the LLM wrapper only when provided.
            model_name: Name of the model to use. ``None`` or an empty string
                (e.g. an unset environment variable passed straight through)
                falls back to the default model.
        """
        # Callers such as main() pass os.getenv(...) directly, which is None
        # when the variable is unset; do not let that override the default.
        model_name = model_name or DEFAULT_MODEL_NAME

        # Keep the configuration for later inspection.
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

        # Build the LLM client.
        llm_params = {
            "temperature": 0.6,
            "model": model_name,
        }
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url
        self.llm = OpenAiLLM(**llm_params)

        # Parser that turns the LLM reply into a TermList.
        self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)

        # Semaphore limiting concurrent document processing; replaced by
        # process_all_documents() with the requested concurrency.
        self.semaphore = Semaphore(DEFAULT_MAX_CONCURRENCY)
        # Lock protecting the shared result lists, counters and checkpoint files.
        self.lock = threading.Lock()
        # Directory for result files; replaced by process_all_documents().
        self.output_dir = "extracted_nouns"
        # Number of documents processed so far, per path prefix.
        self._doc_counts = {}

    def _convert_html_to_md(self, content, title):
        """Convert an HTML document body to Markdown headed by its title.

        All headings in *content* are demoted one level so that the page
        title, emitted as ``# title``, is the only level-1 heading.
        """
        options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
        demoted_html = _demote_headings(content)
        markdown_content = convert_html_to_md(demoted_html, "", **options)
        return f"# {title}\n\n{markdown_content}"

    def extract_from_document(self, doc_info: dict) -> List[Term]:
        """Extract terms from a single document.

        Args:
            doc_info: Mapping with at least the keys ``content`` (HTML) and
                ``title``.

        Returns:
            The terms parsed from the LLM reply, or an empty list when
            building the prompt, calling the LLM or parsing its reply fails
            (the error is logged, not raised).
        """
        try:
            markdown_content = self._convert_html_to_md(doc_info['content'], doc_info["title"])

            # Fill in the format instructions first and the document text
            # last, so placeholder-looking text inside the document is never
            # rescanned and substituted.
            formatted_prompt = extract_wiki_nouns_prompt.replace(
                "{output_format}", self.terms_list_parser.get_format_instructions()
            )
            formatted_prompt = formatted_prompt.replace("{content}", markdown_content)
        except Exception as e:
            logging.error(f"提取专业名词时出错: {str(e)}", exc_info=True)
            return []

        try:
            response = self.llm.invoke(formatted_prompt)
            parsed_output = self.terms_list_parser.parse(response.content)
            return parsed_output.terms
        except Exception as e:
            logging.error(f"解析LLM响应时出错: {str(e)}", exc_info=True)
            return []

    def _output_file_for(self, output_dir, prefix):
        """Return the result file path for *prefix* inside *output_dir*.

        The file name is the prefix up to (not including) its first ``(``.
        """
        return os.path.join(output_dir, f"{prefix.split('(')[0]}_nouns.json")

    def _process_document(self, doc, path_terms):
        """Process one document and append its terms to ``path_terms``.

        Args:
            doc: Mapping with the keys ``path`` and ``id``.
            path_terms: Mapping of path prefix -> shared list of term dicts.

        Returns:
            The matching path prefix, or ``None`` when the document matches
            no prefix, has no content, or processing failed.
        """
        try:
            with self.semaphore:
                # Find the first configured prefix this document falls under.
                path_prefix = next(
                    (prefix for prefix in path_terms if doc['path'].startswith(prefix)),
                    None,
                )
                if not path_prefix:
                    return None

                doc_info = WikijsTool.query_doc_info(doc['id'])
                if not doc_info or not doc_info.get('content'):
                    return None

                terms = self.extract_from_document(doc_info)
                terms_dicts = [
                    {"name": term.name, "synonymous": term.synonymous, "description": term.description}
                    for term in terms
                ]

                with self.lock:
                    path_terms[path_prefix].extend(terms_dicts)
                    logging.info(f"文档 {doc['path']} 处理完成,提取了 {len(terms)} 个专业名词")

                    # Checkpoint on the number of processed *documents*; the
                    # length of the term list says nothing about how many
                    # documents have been handled.
                    processed = self._doc_counts.get(path_prefix, 0) + 1
                    self._doc_counts[path_prefix] = processed
                    if processed % CHECKPOINT_EVERY_N_DOCS == 0:
                        # Written while holding the lock so the list cannot
                        # change mid-dump and writers cannot interleave.
                        self._save_terms_to_file(
                            path_terms[path_prefix],
                            self._output_file_for(self.output_dir, path_prefix),
                        )
                        logging.info(f"已处理 {path_prefix} 的文档数达到 {processed} 个,已保存中间结果")

                return path_prefix
        except Exception as e:
            logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}", exc_info=True)
            return None

    def process_all_documents(self, output_dir: str = "extracted_nouns", max_concurrency: int = 5,
                              path_prefixes: List[str] = None):
        """Process every matching Wiki.js document using a thread pool.

        Args:
            output_dir: Directory the per-prefix JSON files are written to;
                created if missing.
            max_concurrency: Maximum number of documents processed at once.
            path_prefixes: Document path prefixes to process. Defaults to
                ``DEFAULT_PATH_PREFIXES`` when omitted or empty.
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Fresh concurrency limit and document counters for this run.
        self.semaphore = Semaphore(max_concurrency)
        self._doc_counts = {}

        all_docs = WikijsTool.get_all_documents()

        if isinstance(path_prefixes, str):
            # A bare string would otherwise be iterated character by character.
            path_prefixes = [path_prefixes]
        prefixes = list(path_prefixes) if path_prefixes else list(DEFAULT_PATH_PREFIXES)

        # One result list per prefix.
        path_terms = {prefix: [] for prefix in prefixes}

        # Keep only documents under one of the requested prefixes.
        prefix_tuple = tuple(prefixes)
        filtered_docs = [doc for doc in all_docs if doc['path'].startswith(prefix_tuple)]

        logging.info(f"开始使用线程池处理 {len(filtered_docs)} 个文档...")

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            futures = [
                executor.submit(self._process_document, doc, path_terms)
                for doc in filtered_docs
            ]

            # Wait for every task, logging progress every tenth completion.
            for i, future in enumerate(concurrent.futures.as_completed(futures)):
                try:
                    future.result()
                    if i % 10 == 0:
                        logging.info(f"已完成 {i+1}/{len(futures)} 个文档的处理")
                except Exception as e:
                    logging.error(f"处理文档时出错: {str(e)}", exc_info=True)

        # Write the final result, one file per prefix.
        for prefix, terms in path_terms.items():
            output_file = self._output_file_for(output_dir, prefix)
            self._save_terms_to_file(terms, output_file)
            logging.info(f"{prefix} 处理完成,共提取 {len(terms)} 个专业名词,已保存到 {output_file}")

    def _save_terms_to_file(self, terms, output_file):
        """Write *terms* to *output_file* as pretty-printed UTF-8 JSON."""
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(terms, f, ensure_ascii=False, indent=2)


def main():
    """Run the extraction with settings taken from environment variables."""
    api_key = os.getenv("OPENAI_API_KEY")
    base_url = os.getenv("OPENAI_API_BASE")
    # os.environ["LLM_MODEL_NAME"] = "Qwen/Qwen2.5-72B-Instruct-128K"
    extractor = WikijsNounsExtractor(api_key=api_key, base_url=base_url, model_name=os.getenv("LLM_MODEL_NAME"))

    # Results go to <this file>/../../data/wiki_extracted_nouns.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(current_dir, "..", "..", "data", "wiki_extracted_nouns")

    extractor.process_all_documents(output_dir=output_dir, max_concurrency=2)


if __name__ == "__main__":
    # Configure log formatting; only a console handler is attached here.
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    date_format = '%Y-%m-%d %H:%M:%S'

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter(log_format, date_format))

    # Attach the handler to the root logger.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(console_handler)

    main()