上传问题改写、意图识别模块代码

This commit is contained in:
2025-05-27 09:48:03 +08:00
commit 99017f0cb0
66 changed files with 111493 additions and 0 deletions
+282
View File
@@ -0,0 +1,282 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: extract_wikijs_nouns.py
Author: oyyz
Description: 从 Wikijs 文档中提取专业名词
"""
import os
from typing import List
from dotenv import load_dotenv
from langchain.output_parsers import PydanticOutputParser
from rag2_0.tool.WikijsTool import WikijsTool
from rag2_0.intent_recognition.DataModels import Term, TermList
from rag2_0.tool.html_to_md import convert_html_to_md
from rag2_0.tool.ModelTool import OpenAiLLM
import json
import datetime
import logging
import threading
import concurrent.futures
from threading import Semaphore
# Load environment variables (python-dotenv's load_dotenv) before any
# os.getenv lookups below.
load_dotenv()
# Prompt template for term extraction. It carries two placeholders,
# `{output_format}` and `{content}`, which extract_from_document fills in
# with plain str.replace (not str.format), so other braces in the text are
# left untouched. The prompt text itself is runtime data sent to the model
# and is intentionally kept in Chinese.
extract_wiki_nouns_prompt="""
我在完善我的专业词库,请从提供的电力行业造价软件相关文本中提取关键词,要求如下:
一、提取范围
1. 核心功能模块
(例:多工程批量计价、材机数据反算、变电工程智能组价、架空线路地形系数计算)
2、软件功能及界面名称(包括:界面页签、功能按钮、功能名称等)
(例:新建工程量清单、导出工程量清单等)
3. 业务专用术语
(例:装置性材料、甲供材保管费、施工降效补偿、电缆头试验配套费)
4. 计价标准体系
(例:预规2020版、电网检修定额2015版、配网工程概算定额)
二、提取规则
1. 识别核心功能名称(如"多工程批量设置工程量、工程设置密码"
2. 提取业务专用名词(如"主材卸车保管费"
3. 标注关联术语的对应关系(如"市场价""市场价格"互为同义词)
4. 包含定额标准相关术语(如"预规2020版"
5. 复合型术语需保持完整
√ 正确:"地形增加系数批量设置"
× 错误:"地形""系数""设置"
6. 总结生成关键词解释
关键词:编制依据
描述:造价文件编制基准规范
7. 软件的特定版本号不作为关键词
三、输出格式:
{output_format}
四、输入内容:
{content}
"""
class WikijsNounsExtractor:
    """Extract domain-specific terms from Wikijs documents with an LLM.

    Typical use is ``WikijsNounsExtractor(...).process_all_documents(...)``,
    which fetches the wiki documents, keeps those whose path starts with one
    of the configured prefixes, asks the LLM for a term list per document and
    writes one ``<product>_nouns.json`` file per prefix.
    """

    # Wiki path prefixes processed when the caller does not supply any.
    # Other product prefixes that appeared (disabled) in this script:
    #   "技改检修计价通(2020", "西藏造价软件(2023",
    #   "新型储能电站建设计价通C12024", "配网造价软件(2022"
    DEFAULT_PATH_PREFIXES = ("主网电力建设计价通(2018",)

    # An intermediate result file is written every this many processed
    # documents (counted per path prefix).
    CHECKPOINT_EVERY = 10

    def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo"):
        """Create the extractor and its LLM client.

        Args:
            api_key: API key; only forwarded to the LLM wrapper when truthy.
            base_url: API base URL; only forwarded when truthy.
            model_name: Name of the model to use.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

        llm_params = {
            "temperature": 0.6,
            "model": model_name,
        }
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url
        self.llm = OpenAiLLM(**llm_params)

        # Parser that turns the raw LLM answer into a TermList.
        self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)

        # Both are set by process_all_documents before any worker runs.
        self.semaphore = None
        self.output_dir = None

        # Guards path_terms, the per-prefix document counters and the
        # checkpoint file writes, all of which are shared between workers.
        self.lock = threading.Lock()
        self._doc_counts = {}

    @staticmethod
    def _file_stem(path_prefix: str) -> str:
        """Return the part of *path_prefix* before its first parenthesis.

        Both ASCII ``(`` and full-width ``(`` are treated as the separator.
        Falls back to the whole prefix when nothing precedes the parenthesis.
        """
        stem = path_prefix.replace("(", "(").split("(", 1)[0].strip()
        return stem or path_prefix

    def _output_file(self, path_prefix: str) -> str:
        """Return the JSON result path for *path_prefix* inside ``self.output_dir``."""
        return os.path.join(self.output_dir, f"{self._file_stem(path_prefix)}_nouns.json")

    def _convert_html_to_md(self, content, title):
        """Convert an HTML document body to Markdown headed by ``# title``.

        Every ``hN>`` tag fragment is shifted one level down (h1 -> h2, ...,
        h6 -> h7) before conversion, so the title added here is the only
        top-level heading in the result.
        """
        options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
        # Order matters: start with the deepest level so an already shifted
        # tag is never shifted a second time.
        new_content = content
        for level in range(6, 0, -1):
            new_content = new_content.replace(f"h{level}>", f"h{level + 1}>")
        markdown_content = convert_html_to_md(new_content, "", **options)
        return f"# {title}\n\n{markdown_content}"

    def extract_from_document(self, doc_info: dict) -> "List[Term]":
        """Extract terms from a single document.

        Args:
            doc_info: Mapping with at least ``content`` (HTML) and ``title``.

        Returns:
            The parsed terms, or an empty list when conversion, the LLM call
            or parsing fails (the failure is logged, never raised).
        """
        response = None
        try:
            markdown_content = self._convert_html_to_md(doc_info['content'], doc_info["title"])
            # Plain replace instead of str.format: the document text and the
            # format instructions may themselves contain braces.
            formatted_prompt = extract_wiki_nouns_prompt.replace("{content}", markdown_content)
            formatted_prompt = formatted_prompt.replace(
                "{output_format}", self.terms_list_parser.get_format_instructions()
            )
            try:
                response = self.llm.invoke(formatted_prompt)
                parsed_output = self.terms_list_parser.parse(response.content)
                return parsed_output.terms
            except Exception as e:
                logging.error(f"解析LLM响应时出错: {str(e)}")
                # The LLM call itself may have failed, in which case there
                # is no raw response to report.
                if response is not None:
                    logging.error(f"原始响应: {getattr(response, 'content', response)}")
                return []
        except Exception as e:
            logging.error(f"提取专业名词时出错: {str(e)}")
            return []

    def _process_document(self, doc, path_terms):
        """Worker: fetch one document, extract its terms and record them.

        Args:
            doc: Document summary with ``path`` and ``id`` keys.
            path_terms: Shared mapping ``prefix -> list of term dicts``; the
                extracted terms are appended to the list of the first prefix
                that ``doc['path']`` starts with.

        Returns:
            The matched prefix, or ``None`` when the document was skipped or
            processing failed (failures are logged, never raised).
        """
        try:
            path_prefix = next(
                (prefix for prefix in path_terms if doc['path'].startswith(prefix)),
                None,
            )
            if not path_prefix:
                return None

            # Limit the number of concurrent wiki/LLM requests.
            with self.semaphore:
                doc_info = WikijsTool.query_doc_info(doc['id'])
                if not doc_info or not doc_info.get('content'):
                    return None

                terms = self.extract_from_document(doc_info)
                terms_dicts = [
                    {"name": term.name, "synonymous": term.synonymous, "description": term.description}
                    for term in terms
                ]

                with self.lock:
                    path_terms[path_prefix].extend(terms_dicts)
                    processed = self._doc_counts.get(path_prefix, 0) + 1
                    self._doc_counts[path_prefix] = processed
                    checkpoint_saved = False
                    if processed % self.CHECKPOINT_EVERY == 0:
                        # Written while holding the lock so no other worker
                        # mutates the list or the file during the dump.
                        try:
                            self._save_terms_to_file(path_terms[path_prefix], self._output_file(path_prefix))
                            checkpoint_saved = True
                        except OSError as e:
                            # A failed checkpoint must not discard the terms
                            # already collected in memory.
                            logging.error(f"保存 {path_prefix} 的中间结果失败: {str(e)}")

                logging.info(f"文档 {doc['path']} 处理完成,提取了 {len(terms)} 个专业名词")
                if checkpoint_saved:
                    logging.info(f"已处理 {path_prefix} 的文档数达到 {processed} 个,已保存中间结果")
                return path_prefix
        except Exception as e:
            logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}")
            return None

    def process_all_documents(self, output_dir: str = "extracted_nouns", max_concurrency: int = 5,
                              path_prefixes: List[str] = None):
        """Process every matching wiki document in a thread pool.

        Args:
            output_dir: Directory for the result files; created if missing.
            max_concurrency: Number of worker threads and the cap on
                concurrent requests.
            path_prefixes: Wiki path prefixes to process. Defaults to
                ``DEFAULT_PATH_PREFIXES`` when ``None``.
        """
        if path_prefixes is None:
            path_prefixes = list(self.DEFAULT_PATH_PREFIXES)
        else:
            path_prefixes = list(path_prefixes)

        self.output_dir = output_dir
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)

        self.semaphore = Semaphore(max_concurrency)
        self._doc_counts = {prefix: 0 for prefix in path_prefixes}

        all_docs = WikijsTool.get_all_documents()

        # One result list per prefix, shared with the workers.
        path_terms = {prefix: [] for prefix in path_prefixes}

        wanted = tuple(path_prefixes)
        filtered_docs = [doc for doc in all_docs if doc['path'].startswith(wanted)]
        logging.info(f"开始使用线程池处理 {len(filtered_docs)} 个文档...")

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            futures = [
                executor.submit(self._process_document, doc, path_terms)
                for doc in filtered_docs
            ]
            for i, future in enumerate(concurrent.futures.as_completed(futures)):
                try:
                    future.result()
                    if i % 10 == 0:
                        logging.info(f"已完成 {i+1}/{len(futures)} 个文档的处理")
                except Exception as e:
                    logging.error(f"处理文档时出错: {str(e)}")

        # All workers are done here, so the final dump needs no locking.
        for prefix, terms in path_terms.items():
            output_file = self._output_file(prefix)
            self._save_terms_to_file(terms, output_file)
            logging.info(f"{prefix} 处理完成,共提取 {len(terms)} 个专业名词,已保存到 {output_file}")

    def _save_terms_to_file(self, terms, output_file):
        """Write *terms* to *output_file* as indented UTF-8 JSON (non-ASCII kept as is)."""
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(terms, f, ensure_ascii=False, indent=2)
def main():
    """Build an extractor from environment settings and run it over the wiki.

    Reads OPENAI_API_KEY, OPENAI_API_BASE and LLM_MODEL_NAME from the
    environment and writes the results to ``../../data/wiki_extracted_nouns``
    relative to this file, using a concurrency of 2.
    """
    # To pin a model for a run, set LLM_MODEL_NAME beforehand, e.g.
    # os.environ["LLM_MODEL_NAME"] = "Qwen/Qwen2.5-72B-Instruct-128K"
    extractor = WikijsNounsExtractor(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_API_BASE"),
        model_name=os.getenv("LLM_MODEL_NAME"),
    )
    here = os.path.dirname(os.path.abspath(__file__))
    target_dir = os.path.join(here, "..", "..", "data", "wiki_extracted_nouns")
    extractor.process_all_documents(output_dir=target_dir, max_concurrency=2)
if __name__ == "__main__":
    # Script entry: send INFO-and-above log records to the console, then run
    # main(). No file handler is installed here.
    current_dir = os.path.dirname(os.path.abspath(__file__))  # not used below
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        '%Y-%m-%d %H:%M:%S',
    )

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)

    # Attach the handler to the root logger so the module-level logging.*
    # calls used throughout this file are emitted.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(console_handler)

    main()