#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: extract_wikijs_nouns.py
Author: oyyz
Description: Extract domain-specific terms from Wiki.js documents.
"""

import concurrent.futures
import datetime
import json
import logging
import os
import re
import threading
from threading import Semaphore
from typing import List, Optional

from dotenv import load_dotenv
from langchain.output_parsers import PydanticOutputParser

from rag2_0.intent_recognition.DataModels import Term, TermList
from rag2_0.tool.html_to_md import convert_html_to_md
from rag2_0.tool.ModelTool import OpenAiLLM
from rag2_0.tool.WikijsTool import WikijsTool

# Load environment variables via python-dotenv before main() reads
# OPENAI_API_KEY / OPENAI_API_BASE / LLM_MODEL_NAME with os.getenv.
load_dotenv()

# Prompt template sent to the LLM for term extraction.
# It carries two placeholders that WikijsNounsExtractor.extract_from_document
# fills with str.replace (not str.format, so other braces are left alone):
#   {output_format} - format instructions produced by the output parser
#   {content}       - the Markdown rendering of the document being processed
# The prompt text itself is runtime data and is intentionally kept in Chinese.
extract_wiki_nouns_prompt = """
我在完善我的专业词库,请从提供的电力行业造价软件相关文本中提取关键词,要求如下:

一、提取范围
1. 核心功能模块
(例:多工程批量计价、材机数据反算、变电工程智能组价、架空线路地形系数计算)
2、软件功能及界面名称(包括:界面页签、功能按钮、功能名称等)
(例:新建工程量清单、导出工程量清单等)
3. 业务专用术语
(例:装置性材料、甲供材保管费、施工降效补偿、电缆头试验配套费)
4. 计价标准体系
(例:预规2020版、电网检修定额2015版、配网工程概算定额)


二、提取规则
1. 识别核心功能名称(如"多工程批量设置工程量、工程设置密码")
2. 提取业务专用名词(如"主材卸车保管费")
3. 标注关联术语的对应关系(如"市场价"与"市场价格"互为同义词)
4. 包含定额标准相关术语(如"预规2020版")
5. 复合型术语需保持完整
√ 正确:"地形增加系数批量设置"
× 错误:"地形"、"系数"、"设置"
6. 总结生成关键词解释
关键词:编制依据
描述:造价文件编制基准规范

7. 软件的特定版本号不作为关键词

三、输出格式:
{output_format}

四、输入内容:
{content}
"""


class WikijsNounsExtractor:
    """Extract domain-specific terms from Wiki.js documents with an LLM.

    Documents are fetched through ``WikijsTool``, converted from HTML to
    Markdown, inserted into ``extract_wiki_nouns_prompt`` and sent to the LLM.
    The reply is parsed into ``Term`` objects and written as JSON, one output
    file per configured path prefix.
    """

    # Path prefixes processed when the caller does not pass ``path_prefixes``.
    # Other product prefixes that were previously used with this script:
    #   "技改检修计价通(2020)", "西藏造价软件(2023)",
    #   "新型储能电站建设计价通C1(2024)", "配网造价软件(2022)"
    DEFAULT_PATH_PREFIXES = ("主网电力建设计价通(2018)",)

    # Number of processed documents (per prefix) between intermediate saves.
    CHECKPOINT_INTERVAL = 10

    # Matches the "<hN" / "</hN" start of an HTML heading tag, with or without
    # attributes after the tag name.
    _HEADING_TAG_RE = re.compile(r"<(/?)h([1-6])(?=[\s>/])")

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None,
                 model_name: str = "gpt-3.5-turbo"):
        """
        Initialise the term extractor.

        Args:
            api_key: API key; only forwarded to the LLM wrapper when truthy.
            base_url: API base URL; only forwarded to the LLM wrapper when truthy.
            model_name: Name of the model to use.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

        # Only pass the optional connection settings that were actually given.
        llm_params = {
            "temperature": 0.6,
            "model": model_name,
        }
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url
        self.llm = OpenAiLLM(**llm_params)

        # Parser that turns the LLM reply into a TermList.
        self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)

        # Created in process_all_documents(); limits concurrent document jobs.
        self.semaphore = None

        # Guards the shared per-prefix term lists, counters and file writes.
        self.lock = threading.Lock()

        # Set by process_all_documents(); directory the JSON files go to.
        self.output_dir = None

        # Documents processed so far per path prefix (drives checkpointing).
        self._doc_counts = {}

    @staticmethod
    def _demote_headings(html):
        """Shift every HTML heading one level down (h1 -> h2, ..., h6 -> h7).

        Both opening and closing tags are rewritten in a single pass, so tags
        that carry attributes (``<h1 id="x">``) stay paired with their closing
        tag and no heading is shifted twice.
        """
        return WikijsNounsExtractor._HEADING_TAG_RE.sub(
            lambda m: f"<{m.group(1)}h{int(m.group(2)) + 1}", html
        )

    def _convert_html_to_md(self, content, title):
        """Convert a document's HTML body to Markdown under a ``# title`` heading.

        Headings in the body are demoted by one level first so that the
        document title is the only top-level heading.
        """
        options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
        markdown_content = convert_html_to_md(self._demote_headings(content), "", **options)
        return f"# {title}\n\n{markdown_content}"

    def extract_from_document(self, doc_info: dict) -> List["Term"]:
        """Extract terms from a single document.

        Args:
            doc_info: Mapping with at least ``'content'`` (HTML) and ``'title'``.

        Returns:
            The parsed terms, or an empty list when conversion, the LLM call
            or parsing fails (the failure is logged, not raised).
        """
        try:
            content = doc_info['content']
            title = doc_info["title"]

            markdown_content = self._convert_html_to_md(content, title)

            # Fill the format instructions first and the document text last,
            # so placeholder-like text inside the document is never rewritten.
            formatted_prompt = extract_wiki_nouns_prompt.replace(
                "{output_format}", self.terms_list_parser.get_format_instructions()
            )
            formatted_prompt = formatted_prompt.replace("{content}", markdown_content)

            # Stays None when the LLM call itself fails, so the error handler
            # below never touches an unbound name.
            response = None
            try:
                response = self.llm.invoke(formatted_prompt)
                parsed_output = self.terms_list_parser.parse(response.content)
                return parsed_output.terms
            except Exception as e:
                logging.error(f"解析LLM响应时出错: {str(e)}")
                if response is not None:
                    logging.error(f"原始响应: {response.content}")
                return []
        except Exception as e:
            logging.error(f"提取专业名词时出错: {str(e)}")
            return []

    def _output_path(self, prefix):
        """Return the JSON output path for ``prefix`` inside ``self.output_dir``.

        The file name is the part of the prefix before the first ASCII ``(``
        plus ``_nouns.json``.
        """
        # NOTE(review): the prefixes in DEFAULT_PATH_PREFIXES use full-width
        # parentheses, which this ASCII split does not match, so their year
        # suffix stays in the file name. Left unchanged to keep existing output
        # file names stable - confirm whether stripping it is wanted.
        return os.path.join(self.output_dir, f"{prefix.split('(')[0]}_nouns.json")

    def _process_document(self, doc, path_terms):
        """Fetch one document, extract its terms and record them.

        Args:
            doc: Mapping with the document's ``'path'`` and ``'id'``.
            path_terms: Shared mapping of path prefix -> list of term dicts;
                the extracted terms are appended to the matching list.

        Returns:
            The matching path prefix, or None when the document is skipped
            (no matching prefix, no content) or processing fails.
        """
        try:
            # Bound the number of documents processed at the same time.
            with self.semaphore:
                path_prefix = next(
                    (prefix for prefix in path_terms if doc['path'].startswith(prefix)),
                    None,
                )
                if not path_prefix:
                    return None

                doc_info = WikijsTool.query_doc_info(doc['id'])
                if not doc_info or not doc_info.get('content'):
                    return None

                terms = self.extract_from_document(doc_info)
                terms_dicts = [
                    {"name": term.name, "synonymous": term.synonymous, "description": term.description}
                    for term in terms
                ]

                # The shared list, the counter and the checkpoint file are all
                # updated under the lock so concurrent workers cannot interleave.
                with self.lock:
                    path_terms[path_prefix].extend(terms_dicts)
                    logging.info(f"文档 {doc['path']} 处理完成,提取了 {len(terms)} 个专业名词")

                    # Checkpoint on the number of processed documents, not on
                    # the number of accumulated terms.
                    doc_count = self._doc_counts.get(path_prefix, 0) + 1
                    self._doc_counts[path_prefix] = doc_count
                    if doc_count % self.CHECKPOINT_INTERVAL == 0:
                        self._save_terms_to_file(path_terms[path_prefix], self._output_path(path_prefix))
                        logging.info(f"已处理 {path_prefix} 的文档数达到 {doc_count} 个,已保存中间结果")

                return path_prefix
        except Exception as e:
            logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}")
            return None

    def process_all_documents(self, output_dir: str = "extracted_nouns", max_concurrency: int = 5,
                              path_prefixes: Optional[List[str]] = None):
        """Process every matching Wiki.js document with a thread pool.

        Args:
            output_dir: Directory for the JSON result files; created if missing.
            max_concurrency: Thread pool size and semaphore limit.
            path_prefixes: Document path prefixes to process. Defaults to
                ``DEFAULT_PATH_PREFIXES`` when omitted or empty.
        """
        self.output_dir = output_dir
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)

        self.semaphore = Semaphore(max_concurrency)

        all_docs = WikijsTool.get_all_documents()

        prefixes = list(path_prefixes) if path_prefixes else list(self.DEFAULT_PATH_PREFIXES)
        # One result list and one document counter per prefix.
        path_terms = {prefix: [] for prefix in prefixes}
        self._doc_counts = {prefix: 0 for prefix in prefixes}

        # str.startswith accepts a tuple, so one call checks all prefixes.
        prefix_tuple = tuple(prefixes)
        filtered_docs = [doc for doc in all_docs if doc['path'].startswith(prefix_tuple)]

        logging.info(f"开始使用线程池处理 {len(filtered_docs)} 个文档...")

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            futures = [
                executor.submit(self._process_document, doc, path_terms)
                for doc in filtered_docs
            ]

            # Report progress on every 10th completed future.
            for i, future in enumerate(concurrent.futures.as_completed(futures)):
                try:
                    future.result()
                    if i % 10 == 0:
                        logging.info(f"已完成 {i+1}/{len(futures)} 个文档的处理")
                except Exception as e:
                    logging.error(f"处理文档时出错: {str(e)}")

        # Final save: one file per prefix.
        for prefix, terms in path_terms.items():
            output_file = self._output_path(prefix)
            self._save_terms_to_file(terms, output_file)
            logging.info(f"{prefix} 处理完成,共提取 {len(terms)} 个专业名词,已保存到 {output_file}")

    def _save_terms_to_file(self, terms, output_file):
        """Write ``terms`` to ``output_file`` as UTF-8 JSON.

        The data is written to a sibling temporary file and then moved into
        place, so an interrupted run cannot leave a half-written result file
        behind (this method is also used for the periodic checkpoints).
        """
        tmp_file = f"{output_file}.tmp"
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(terms, f, ensure_ascii=False, indent=2)
        os.replace(tmp_file, output_file)


def main():
    """Run the extraction using settings taken from environment variables.

    Reads OPENAI_API_KEY, OPENAI_API_BASE and LLM_MODEL_NAME, then writes the
    results to ``data/wiki_extracted_nouns`` two directory levels above this
    file, processing at most two documents concurrently.
    """
    extractor_kwargs = {
        "api_key": os.getenv("OPENAI_API_KEY"),
        "base_url": os.getenv("OPENAI_API_BASE"),
    }

    # Only override the model when LLM_MODEL_NAME is set (for example
    # "Qwen/Qwen2.5-72B-Instruct-128K"); passing None would replace the
    # constructor's default model name with None.
    model_name = os.getenv("LLM_MODEL_NAME")
    if model_name:
        extractor_kwargs["model_name"] = model_name

    extractor = WikijsNounsExtractor(**extractor_kwargs)

    current_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(current_dir, "..", "..", "data", "wiki_extracted_nouns")
    extractor.process_all_documents(output_dir=output_dir, max_concurrency=2)


if __name__ == "__main__":
    # Send INFO-and-above log records to the console with a timestamped format.
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    date_format = '%Y-%m-%d %H:%M:%S'

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter(log_format, date_format))

    # Attach the handler to the root logger so module-level logging.* calls use it.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(console_handler)

    main()