上传问题改写、意图识别模块代码

This commit is contained in:
2025-05-27 09:48:03 +08:00
commit 99017f0cb0
66 changed files with 111493 additions and 0 deletions
+250
View File
@@ -0,0 +1,250 @@
"""
提问内容补全工具
此模块用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
并将原提问和补全后的提问保存到新的Excel文件中。
用法示例:
completer = QuestionCompleter(input_path="历史提问数据(dislike).xlsx", output_path="补全后的提问数据.xlsx")
completer.process()
"""
import pandas as pd
from tqdm import tqdm
import os
from dotenv import load_dotenv
from rag2_0.tool.ModelTool import OpenAiLLM
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
import concurrent.futures
from threading import Lock
class RewriteQuery(BaseModel):
rewrite_query:str = Field(description="补全后的提问")
software_name:str = Field(description="软件名称")
# 加载环境变量
load_dotenv()
class QuestionCompleter:
"""
提问内容补全工具类
用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
并将原提问和补全后的提问保存到新的Excel文件中。
"""
def __init__(self, input_path="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
output_path="/data/Rag2_0/data/excel/历史提问数据(dislike)_补全后的提问数据.xlsx",
question_column="提问", answer_column="回答", max_workers=10):
"""
初始化提问内容补全工具
参数:
input_path (str): 输入Excel文件路径
output_path (str): 输出Excel文件路径
question_column (str): 提问列的名称
answer_column (str): 回答列的名称
max_workers (int): 最大线程数
"""
self.input_path = input_path
self.output_path = output_path
self.question_column = question_column
self.answer_column = answer_column
self.max_workers = max_workers
self.rewrite_query_parser = PydanticOutputParser(pydantic_object=RewriteQuery)
self.lock = Lock() # 添加线程锁
# 初始化LLM
self.api_key = os.getenv("OPENAI_API_KEY")
self.base_url = os.getenv("OPENAI_API_BASE")
self.model = os.getenv("LLM_MODEL_NAME")
if not all([self.api_key, self.base_url, self.model]):
raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")
self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model)
# 读取Excel文件
try:
self.df = pd.read_excel(self.input_path)
print(f"成功读取Excel文件: {self.input_path}")
print(f"共有 {len(self.df)} 条记录")
except Exception as e:
raise RuntimeError(f"读取Excel文件失败: {str(e)}")
# 检查列是否存在
if self.question_column not in self.df.columns:
raise ValueError(f"Excel文件中不存在列: {self.question_column}")
if self.answer_column not in self.df.columns:
raise ValueError(f"Excel文件中不存在列: {self.answer_column}")
def create_completion_prompt(self, question, answer):
"""
创建用于补全提问的prompt
参数:
question (str): 原始提问
answer (str): 对应的回答
返回:
str: 格式化的prompt
"""
prompt = f"""
1、判断提问中是否缺少软件名称,如果不缺少,则直接返回原始提问
2、如果缺少软件名称,则根据回答中的软件名称,补全提问
3、补全后的提问需要保持问题原有意图不变
4、软件名称包括:
配网D3软件(配网工程计价通D3)
西藏Z1软件(西藏电力工程计价通Z1)
主网计价通软件(电力建设计价通)
技改检修工程计价通T1软件(技改检修工程计价通T1)
技改检修清单计价通T1软件(技改检修清单计价通T1)
储能C1软件(新型储能电站建设计价通C1)
如果没有包含上述软件名称,则直接返回原始提问,software_name为空字符串
{{
"rewrite_query": "xxx",
"software_name": ""
}}
原始提问:{question}
系统回答:{answer}
输出格式:
{self.rewrite_query_parser.get_format_instructions()}
示例:
例如,如果输入是:
提问:这个软件怎么用?
回答:Photoshop的使用方法是...
那么输出会是:
{{
"rewrite_query": "Photoshop这个软件怎么用?",
"software_name": "Photoshop"
}}
或者如果提问已经包含软件名称:
提问:Photoshop怎么用?
回答:Photoshop的使用方法是...
那么输出会是:
{{
"rewrite_query": "Photoshop怎么用?",
"software_name": "Photoshop"
}}
"""
return prompt
def complete_question(self, question, answer):
"""
调用LLM补全提问内容
参数:
question (str): 原始提问
answer (str): 对应的回答
返回:
str: 补全后的提问,如果补全失败则返回原始提问
"""
# 如果提问或回答为空,直接返回原始提问
if pd.isna(question) or question.strip() == "" or pd.isna(answer) or answer.strip() == "":
return question, ""
try:
prompt = self.create_completion_prompt(question, answer)
response = self.llm.invoke(prompt)
completed_question = self.rewrite_query_parser.parse(response.content)
return completed_question.rewrite_query, completed_question.software_name
except Exception as e:
print(f"补全提问失败: {str(e)}")
return question, ""
def process_row(self, row):
"""
处理单行数据
参数:
row: DataFrame中的一行
返回:
dict: 处理结果
"""
original_question = row[self.question_column]
answer = row[self.answer_column]
# 调用LLM补全提问
completed_question, software_name = self.complete_question(original_question, answer)
# 创建结果字典
result = {
"原始提问": original_question,
"补全后的提问": completed_question,
"软件名称": software_name
}
return result
def process(self):
"""
使用多线程处理所有提问并补全内容
读取Excel文件中的提问和回答,调用LLM补全提问内容,
并将原提问和补全后的提问保存到新的Excel文件中
"""
results = []
total = len(self.df)
# 使用进度条显示总体进度
with tqdm(total=total, desc="补全提问") as pbar:
# 创建线程池
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 提交所有任务
future_to_idx = {executor.submit(self.process_row, self.df.iloc[idx]): idx for idx in range(total)}
# 处理完成的任务
for future in concurrent.futures.as_completed(future_to_idx):
result = future.result()
with self.lock:
results.append(result)
pbar.update(1)
# 将结果转换为DataFrame并保存
results_df = pd.DataFrame(results)
results_df.to_excel(self.output_path, index=False)
print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}")
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='补全Excel文件中的提问内容')
parser.add_argument('-i', '--input', type=str, default="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
help='输入Excel文件路径')
parser.add_argument('-o', '--output', type=str, default="/data/Rag2_0/data/excel/补全后的提问数据.xlsx",
help='输出Excel文件路径')
parser.add_argument('-q', '--question', type=str, default="提问",
help='提问列的名称')
parser.add_argument('-a', '--answer', type=str, default="回答",
help='回答列的名称')
parser.add_argument('-w', '--workers', type=int, default=50,
help='最大线程数')
args = parser.parse_args()
# 创建提问补全工具实例
completer = QuestionCompleter(
input_path=args.input,
output_path=args.output,
question_column=args.question,
answer_column=args.answer,
max_workers=args.workers
)
# 执行处理
completer.process()
if __name__ == "__main__":
main()
+282
View File
@@ -0,0 +1,282 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: extract_wikijs_nouns.py
Author: oyyz
Description: 从 Wikijs 文档中提取专业名词
"""
import os
from typing import List
from dotenv import load_dotenv
from langchain.output_parsers import PydanticOutputParser
from rag2_0.tool.WikijsTool import WikijsTool
from rag2_0.intent_recognition.DataModels import Term, TermList
from rag2_0.tool.html_to_md import convert_html_to_md
from rag2_0.tool.ModelTool import OpenAiLLM
import json
import datetime
import logging
import threading
import concurrent.futures
from threading import Semaphore
# 加载环境变量
load_dotenv()
extract_wiki_nouns_prompt="""
我在完善我的专业词库,请从提供的电力行业造价软件相关文本中提取关键词,要求如下:
一、提取范围
1. 核心功能模块
(例:多工程批量计价、材机数据反算、变电工程智能组价、架空线路地形系数计算)
2、软件功能及界面名称(包括:界面页签、功能按钮、功能名称等)
(例:新建工程量清单、导出工程量清单等)
3. 业务专用术语
(例:装置性材料、甲供材保管费、施工降效补偿、电缆头试验配套费)
4. 计价标准体系
(例:预规2020版、电网检修定额2015版、配网工程概算定额)
二、提取规则
1. 识别核心功能名称(如"多工程批量设置工程量、工程设置密码"
2. 提取业务专用名词(如"主材卸车保管费"
3. 标注关联术语的对应关系(如"市场价""市场价格"互为同义词)
4. 包含定额标准相关术语(如"预规2020版"
5. 复合型术语需保持完整
√ 正确:"地形增加系数批量设置"
× 错误:"地形""系数""设置"
6. 总结生成关键词解释
关键词:编制依据
描述:造价文件编制基准规范
7. 软件的特定版本号不作为关键词
三、输出格式:
{output_format}
四、输入内容:
{content}
"""
class WikijsNounsExtractor:
"""从 Wikijs 文档中提取专业名词"""
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo"):
"""
初始化专业名词提取器
Args:
api_key: API密钥,如果为None则从环境变量获取
base_url: API基础URL,如果为None则使用默认URL
model_name: 要使用的模型名称
"""
# 保存参数
self.api_key = api_key
self.base_url = base_url
self.model_name = model_name
# 初始化LLM
llm_params = {
"temperature": 0.6,
"model": model_name
}
if api_key:
llm_params["api_key"] = api_key
if base_url:
llm_params["base_url"] = base_url
self.llm = OpenAiLLM(**llm_params)
# 准备术语列表解析器
self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
# 信号量,限制并发请求数量
self.semaphore = None
# 线程锁,用于保护共享资源
self.lock = threading.Lock()
def _convert_html_to_md(self, content, title):
"""HTML转Markdown"""
options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
new_content = (content.replace("h6>", "h7>")
.replace("h5>", "h6>")
.replace("h4>", "h5>")
.replace("h3>", "h4>")
.replace("h2>", "h3>")
.replace("h1>", "h2>"))
# 将HTML内容转换为Markdown
markdown_content = convert_html_to_md(new_content, "", **options)
markdown_content = f"# {title}\n\n{markdown_content}"
return markdown_content
def extract_from_document(self, doc_info: dict) -> List[Term]:
"""从单个文档中提取专业名词"""
try:
# 使用LLM调用处理文档
content = doc_info['content']
title = doc_info["title"]
# 转换HTML到Markdown
markdown_content = self._convert_html_to_md(content, title)
# 准备提示词
formatted_prompt = extract_wiki_nouns_prompt.replace("{content}", markdown_content)
formatted_prompt = formatted_prompt.replace("{output_format}", self.terms_list_parser.get_format_instructions())
try:
# 调用LLM
response = self.llm.invoke(formatted_prompt)
# 使用Pydantic解析器解析结果
parsed_output = self.terms_list_parser.parse(response.content)
return parsed_output.terms
except Exception as e:
logging.error(f"解析LLM响应时出错: {str(e)}")
logging.error(f"原始响应: {response.content}")
return []
except Exception as e:
logging.error(f"提取专业名词时出错: {str(e)}")
return []
def _process_document(self, doc, path_terms):
"""处理单个文档"""
try:
# 获取信号量
with self.semaphore:
# 检查文档路径是否在我们要处理的路径中
path_prefix = None
for prefix in path_terms.keys():
if doc['path'].startswith(prefix):
path_prefix = prefix
break
# 如果不在要处理的路径中,则跳过
if not path_prefix:
return None
# 获取文档详细信息
doc_info = WikijsTool.query_doc_info(doc['id'])
if not doc_info or not doc_info.get('content'):
return None
# 提取专业名词
terms = self.extract_from_document(doc_info)
# 将提取的术语添加到对应路径的结果列表中
terms_dicts = [{"name": term.name, "synonymous": term.synonymous, "description": term.description} for term in terms]
with self.lock:
path_terms[path_prefix].extend(terms_dicts)
logging.info(f"文档 {doc['path']} 处理完成,提取了 {len(terms)} 个专业名词")
# 每处理10个文档保存一次中间结果
current_count = len(path_terms[path_prefix])
if current_count % 10 == 0:
# 使用锁保护文件IO
self._save_terms_to_file(path_terms[path_prefix], os.path.join(self.output_dir, f"{path_prefix.split('')[0]}_nouns.json"))
logging.info(f"已处理 {path_prefix} 的文档数达到 {current_count//10*10} 个,已保存中间结果")
return path_prefix
except Exception as e:
logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}")
return None
def process_all_documents(self, output_dir: str = "extracted_nouns", max_concurrency: int = 5):
"""使用线程池处理所有文档"""
# 保存输出目录
self.output_dir = output_dir
# 创建输出目录
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# 初始化信号量,限制并发请求数
self.semaphore = Semaphore(max_concurrency)
# 获取所有文档
all_docs = WikijsTool.get_all_documents()
# 要处理的路径前缀
# path_prefixes = [
# "技改检修计价通(2020",
# "西藏造价软件(2023",
# "新型储能电站建设计价通C12024",
# "配网造价软件(2022",
# ]
path_prefixes = [
"主网电力建设计价通(2018",
]
# 为每个路径创建单独的结果列表
path_terms = {prefix: [] for prefix in path_prefixes}
# 过滤出符合路径前缀的文档
filtered_docs = []
for doc in all_docs:
for prefix in path_prefixes:
if doc['path'].startswith(prefix):
filtered_docs.append(doc)
break
logging.info(f"开始使用线程池处理 {len(filtered_docs)} 个文档...")
# 使用线程池处理所有文档
with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
futures = []
for doc in filtered_docs:
future = executor.submit(self._process_document, doc, path_terms)
futures.append(future)
# 等待所有任务完成
for i, future in enumerate(concurrent.futures.as_completed(futures)):
try:
prefix = future.result()
if i % 10 == 0:
logging.info(f"已完成 {i+1}/{len(futures)} 个文档的处理")
except Exception as e:
logging.error(f"处理文档时出错: {str(e)}")
# 保存最终结果
for prefix, terms in path_terms.items():
# 为每个路径保存单独的文件
output_file = os.path.join(output_dir, f"{prefix.split('')[0]}_nouns.json")
self._save_terms_to_file(terms, output_file)
logging.info(f"{prefix} 处理完成,共提取 {len(terms)} 个专业名词,已保存到 {output_file}")
def _save_terms_to_file(self, terms, output_file):
"""保存术语列表到文件"""
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(terms, f, ensure_ascii=False, indent=2)
def main():
# 从环境变量获取配置
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
# os.environ["LLM_MODEL_NAME"] = "Qwen/Qwen2.5-72B-Instruct-128K"
extractor = WikijsNounsExtractor(api_key=api_key, base_url=base_url, model_name=os.getenv("LLM_MODEL_NAME"))
current_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(current_dir, "..", "..", "data", "wiki_extracted_nouns")
extractor.process_all_documents(output_dir=output_dir, max_concurrency=2)
if __name__ == "__main__":
# 配置日志输出到文件,并设置格式
current_dir = os.path.dirname(os.path.abspath(__file__))
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
date_format = '%Y-%m-%d %H:%M:%S'
# 创建一个控制台处理器
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter(log_format, date_format))
# 获取根日志记录器并添加处理器
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
root_logger.addHandler(console_handler)
main()
+189
View File
@@ -0,0 +1,189 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: intent_recognition_example.py
Date: 2025-05-14
Description: 意图识别和问题改写示例
"""
import os
from dotenv import load_dotenv
from rag2_0.intent_recognition import IntentRecognizer
import pandas as pd
import logging
import json
import concurrent.futures
from tqdm import tqdm
import time
# 加载环境变量
load_dotenv()
# 读取Excel文件中的提问数据
def load_questions_from_excel(file_path=None):
"""
从Excel文件中读取提问数据
Args:
file_path: Excel文件路径,如果为None则使用默认路径
Returns:
提问列表
"""
try:
# 读取Excel文件的第一列数据
df = pd.read_excel(file_path)
questions = df.iloc[:, 0].tolist() # 获取第一列数据
logging.info(f"成功从{file_path}读取了{len(questions)}条提问")
return questions
except Exception as e:
logging.error(f"读取Excel文件时出错: {e}")
return []
def process_query(recognizer, query):
"""
处理单个查询,支持重试机制
Args:
recognizer: 意图识别器实例
query: 查询字符串
Returns:
处理结果字典
"""
max_retries = 3
retry_count = 0
while retry_count <= max_retries:
try:
# 如果是重试,添加重试信息到日志
classification, keywords, rewrite, query_keys = recognizer.process_query(query)
# 将keywords对象转换为字符串
keywords_str = ""
if keywords and keywords.terms:
term_details = []
for term in keywords.terms:
term_info = {
"名称": term.name,
"同义词": ";".join(term.synonymous) if term.synonymous else "",
"描述": term.description
}
term_details.append(term_info)
# 将term_details转换为JSON字符串,确保中文正确显示
keywords_str = json.dumps(term_details, ensure_ascii=False, indent=2)
# 处理成功,返回结果
return {
"提问": query,
"问题拆解": query_keys,
"一级分类": classification.vertical_classification,
"二级分类": classification.sub_classification,
"问题改写": rewrite.rewrite,
"检索的关键词": keywords_str
}
except Exception as e:
retry_count += 1
# 如果已经重试了最大次数,则记录错误并返回错误结果
if retry_count > max_retries:
logging.error(f"处理问题 '{query}' 时出错: {e.__class__}{e}")
return {
"提问": query,
"一级分类": "处理出错",
"二级分类": "处理出错",
"问题改写": "处理出错",
"检索的关键词": f"重试 {max_retries} 次后失败: {str(e)}"
}
else:
# 可以在这里添加延迟,避免过快重试
time.sleep(10 * retry_count)
examples_query = """下载软件在哪下载?"""
def main():
"""
意图识别和问题改写示例
"""
# 从环境变量中获取配置
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
# 初始化意图识别器
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
# 读取提问数据
current_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条提问数据.xlsx")
examples = load_questions_from_excel(data_file)
# examples = examples_query.split("\n")
max_workers = 20
logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")
# 创建一个与输入顺序相同的结果列表
results = [None] * len(examples)
# 使用线程池进行并发处理
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交所有任务并记录它们的索引
future_to_index = {}
for idx, query in enumerate(examples):
future = executor.submit(process_query, recognizer, query)
future_to_index[future] = idx
# 使用tqdm显示进度条
for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(examples), desc="处理进度"):
idx = future_to_index[future]
result = future.result()
# 将结果放在与输入相同的位置
results[idx] = result
# 将结果保存到Excel文件
results_df = pd.DataFrame(results)
output_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条提问数据_重写结果.xlsx")
# 使用ExcelWriter设置格式
with pd.ExcelWriter(output_file, engine='xlsxwriter') as writer:
results_df.to_excel(writer, index=False, sheet_name='Sheet1')
# 获取工作簿和工作表对象
workbook = writer.book
worksheet = writer.sheets['Sheet1']
# 设置列宽(单位:像素)
# 定义列宽(厘米转为Excel单位,1cm约等于4.7个Excel单位)
worksheet.set_column('A:A', 60) # 提问列 60个Excel单位
worksheet.set_column('B:B', 20) # 问题拆解 20个Excel单位
worksheet.set_column('C:C', 20) # 一级分类 20个Excel单位
worksheet.set_column('D:D', 20) # 二级分类 20个Excel单位
worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位
worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位
# 设置所有行高为20磅
for i in range(len(results_df) + 1): # +1 是为了包括表头
worksheet.set_row(i, 20)
logging.info(f"处理完成,结果已保存至: {output_file}")
def setup_logging():
# 配置日志输出到控制台
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler() # 添加控制台处理器
]
)
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
if __name__ == "__main__":
setup_logging()
logging.info("意图识别示例程序开始运行...")
main()
+293
View File
@@ -0,0 +1,293 @@
"""
答案正确性评判工具
此模块用于评判问题的新旧回答是否正确,通过与标准答案(Wiki内容)进行比较,
或者在没有标准答案的情况下比较新旧回答的差异。
用法示例:
judge = AnswerCorrectnessJudge()
judge.process()
"""
import pandas as pd
from urllib.parse import unquote
from rag2_0.tool.WikijsTool import WikijsTool
from rag2_0.tool.html_to_md import convert_html_to_md
from rag2_0.tool.ModelTool import OpenAiLLM
from dotenv import load_dotenv
import os
from tqdm import tqdm
load_dotenv()
class AnswerCorrectnessJudge:
"""
答案正确性评判工具类
用于评估问题的新旧回答是否正确,可以通过与标准答案(Wiki内容)进行比较,
或者在没有标准答案的情况下比较新旧回答的差异。
"""
def __init__(self, wiki_excel_path="/data/Rag2_0/data/excel/部分提问_软件名称明确.xlsx",
answer_excel_path="/data/Rag2_0/data/excel/主网软件提问_对比结果.xlsx",
output_path="/data/Rag2_0/data/excel/主网软件提问回答_判断结果.xlsx"):
"""
初始化答案正确性评判工具
参数:
wiki_excel_path (str): Wiki Excel文件路径
answer_excel_path (str): 答案对比Excel文件路径
output_path (str): 输出Excel文件路径
"""
self.wiki_excel_path = wiki_excel_path
self.answer_excel_path = answer_excel_path
self.output_path = output_path
# 读取Excel文件
self.wiki_excel = pd.read_excel(self.wiki_excel_path)
self.answer_excel = pd.read_excel(self.answer_excel_path)
# 初始化LLM
self.api_key = os.getenv("OPENAI_API_KEY")
self.base_url = os.getenv("OPENAI_API_BASE")
self.model = os.getenv("LLM_MODEL_NAME")
if not all([self.api_key, self.base_url, self.model]):
raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")
self.openai_llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model)
def find_wiki_link(self, query) -> str | None:
"""
根据查询(对应wiki_excel中的新提问列)找出对应的词条链接
参数:
query (str): 查询内容,对应wiki_excel中的新提问列
返回:
str: 对应的词条链接,如果没有找到则返回None
"""
# 确保query不为空
if not query or pd.isna(query):
return None
# 在"新提问"列中查找匹配的行
matched_rows = self.wiki_excel[self.wiki_excel['新提问'] == query]
# 如果找到了匹配的行,返回对应的词条链接
if not matched_rows.empty:
return matched_rows.iloc[0]['对应词条链接']
# 如果没有完全匹配,尝试部分匹配
# 去除软件名称部分(如果有)
query_parts = query.split(',', 1)
if len(query_parts) > 1:
clean_query = query_parts[1].strip()
# 在"提问"列中查找包含清理后查询的行
for idx, row in self.wiki_excel.iterrows():
if pd.notna(row['提问']) and clean_query in row['提问']:
return row['对应词条链接']
return None
def get_wiki_content(self, link) -> str:
"""
获取词条链接的内容
参数:
link (str): 词条链接
返回:
str: 链接内容,如果获取失败则返回错误信息
"""
try:
if not link or pd.isna(link):
return "链接为空或无效"
# 移除域名部分,只保留路径
path = link.split('/', 3)[-1]
decoded_path = unquote(path)
path_parts = decoded_path.split('/')
doc_path = "/".join(path_parts[1:])
wiki_doc = WikijsTool.get_all_doc_by_path(path=doc_path, path_is_dir=False)
html_content = WikijsTool.query_doc_info(wiki_doc[0]["id"]).get('content')
if not html_content:
return "获取内容失败"
options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
new_content = (html_content.replace("h6>", "h7>")
.replace("h5>", "h6>")
.replace("h4>", "h5>")
.replace("h3>", "h4>")
.replace("h2>", "h3>")
.replace("h1>", "h2>"))
# 将HTML内容转换为Markdown
markdown_content = convert_html_to_md(new_content, "", **options)
markdown_content = f"# {path_parts[-1]}\n\n{markdown_content}"
return markdown_content
except Exception as e:
raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
def create_prompt(self, standard_answer: str, answer_to_check: str) -> str:
"""
创建用于评判答案的prompt
参数:
standard_answer (str): 标准答案
answer_to_check (str): 需要检查的答案
返回:
str: 格式化的prompt
"""
return f"""请作为一个专业的答案评判专家,评估以下回答与标准答案的匹配程度。
标准答案:
{standard_answer}
待评估的回答:
{answer_to_check}
请仔细分析两个答案的内容,并给出你的判断。只需要回答"正确""错误",不需要其他解释。
如果待评估的回答与标准答案在核心内容和关键信息(步骤)上一致,即使表达方式不同,也应判定为"正确"
如果待评估的回答存在明显的错误信息或重要信息缺失,应判定为"错误"
请严格按以下格式输出:【正确】或【错误】:"""
def judge_old_answer(self, standard_answer: str, old_answer: str) -> bool | None:
"""
调用LLM判断旧回答是否正确
参数:
standard_answer (str): 标准答案(来自Wiki
old_answer (str): 旧流程的回答
返回:
bool | None: 判断结果,True表示正确,False表示错误,None表示判断失败
"""
prompt = self.create_prompt(standard_answer, old_answer)
try:
response = self.openai_llm.invoke(prompt)
return "正确" in response.content
except Exception as e:
return None
def judge_new_answer(self, standard_answer: str, new_answer: str) -> bool | None:
"""
调用LLM判断新回答是否正确
参数:
standard_answer (str): 标准答案(来自Wiki
new_answer (str): 新流程的回答
返回:
bool | None: 判断结果,True表示正确,False表示错误,None表示判断失败
"""
prompt = self.create_prompt(standard_answer, new_answer)
try:
response = self.openai_llm.invoke(prompt)
return "正确" in response.content
except Exception as e:
return None
def judge_by_standard_answer(self, standard_answer: str, old_answer: str, new_answer: str) -> str | None:
"""
综合判断新旧回答的正确性
参数:
standard_answer (str): 标准答案(来自Wiki
old_answer (str): 旧流程的回答
new_answer (str): 新流程的回答
返回:
str | None: 包含新旧回答判断结果的字符串,None表示判断失败
"""
old_result = self.judge_old_answer(standard_answer, old_answer)
new_result = self.judge_new_answer(standard_answer, new_answer)
if old_result is None or new_result is None:
return None
if new_result and old_result:
return "新旧答案均正确"
elif new_result and not old_result:
return "新答案正确"
elif not new_result and old_result:
return "旧答案正确"
else:
return "新旧答案均错误"
def judge_answer_diff(self, old_answer: str, new_answer: str) -> str | None:
"""
判断新旧回答是否存在较大差异
参数:
old_answer (str): 旧流程的回答
new_answer (str): 新流程的回答
返回:
str | None: 差异判断结果,None表示判断失败
"""
prompt = f"""请判断以下两个回答是否存在较大差异:
旧回答: {old_answer}
新回答: {new_answer}
主要是关键步骤、关键信息、或者关键主体的差异
请仅回答"存在较大差异""差异较小""""
try:
response = self.openai_llm.invoke(prompt)
return "无法判断,新老答案差异较大" if "存在较大差异" in response.content else "无法判断,新老答案基本相同"
except Exception as e:
return None
def process(self):
"""
处理所有问题并评判答案正确性
读取Excel文件中的问题和答案,进行评判,并将结果保存到输出Excel文件
"""
# 创建结果列表
results = []
# 读取Excel文件
for idx, row in tqdm(self.answer_excel.iterrows(), total=len(self.answer_excel), desc="处理问题"):
query = row["问题"]
old_answer = row["旧流程答案"]
new_answer = row["新流程答案"]
standard_answer = ""
try:
wiki_url = self.find_wiki_link(query)
if wiki_url and not pd.isna(wiki_url):
standard_answer = self.get_wiki_content(wiki_url)
except Exception as e:
print(f"处理问题 '{query}' 时发生错误: {str(e)}")
if standard_answer:
# 判断答案正确性
judge_result = self.judge_by_standard_answer(standard_answer, old_answer, new_answer)
else:
judge_result = self.judge_answer_diff(old_answer, new_answer)
if judge_result is None:
judge_result = ""
results.append({
"问题": query,
"旧流程答案": old_answer,
"新流程答案": new_answer,
"判断结果": judge_result
})
# 将结果转换为DataFrame并保存
results_df = pd.DataFrame(results)
results_df.to_excel(self.output_path, index=False)
print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}")
# 测试函数
if __name__ == "__main__":
# 创建答案正确性评判工具实例
judge = AnswerCorrectnessJudge()
# 执行处理
judge.process()
+615
View File
@@ -0,0 +1,615 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
完整性问题判断工具
此脚本用于读取Excel文件中的问题,调用LLM判断问题是否完整,并将结果保存到Excel文件中。
用法示例:
python judge_query_full.py -i "问题数据.xlsx" -o "完整问题结果.xlsx" -w 50 -c 0
命令行参数:
-i, --input: 输入Excel文件路径
-o, --output: 输出Excel文件路径
-w, --workers: 并发处理的最大线程数
-c, --column: 要处理的问题所在列的索引(从0开始)
-t, --test: 测试单个问题,不处理Excel文件
"""
import pandas as pd
import json
import os
import time
import re
import argparse
from pathlib import Path
from rag2_0.tool.ModelTool import OpenAiLLM
from rag2_0.tool.APIKeyManager import APIKeyManager
from openpyxl.utils import get_column_letter
from openpyxl.styles import Alignment, PatternFill, Font, Border, Side
from tqdm import tqdm
import concurrent.futures
import threading
# 默认设置
DEFAULT_EXCEL_PATH = r"/data/Rag2_0/data/excel/7000条对话数据.xlsx"
DEFAULT_OUTPUT_PATH = r"/data/Rag2_0/data/excel/7000条对话数据_完整问题结果.xlsx"
DEFAULT_MAX_WORKERS = 50
class QueryCompletenessJudge:
"""
问题完整性判断工具类
用于评估问题是否完整,并将结果保存到Excel文件中。
可以批量处理Excel文件中的问题,也可以测试单个问题。
"""
def __init__(self, input_path=DEFAULT_EXCEL_PATH, output_path=DEFAULT_OUTPUT_PATH,
max_workers=DEFAULT_MAX_WORKERS, column_index=0):
"""
初始化问题完整性判断工具
参数:
input_path (str): 输入Excel文件路径
output_path (str): 输出Excel文件路径
max_workers (int): 并发处理的最大线程数
column_index (int): 要处理的问题所在列的索引(从0开始)
"""
self.input_path = input_path
self.output_path = output_path
self.max_workers = max_workers
self.column_index = column_index
self.llm_client = self._create_llm_client()
def _extract_json_from_response(self, full_answer):
"""
从LLM响应中提取JSON部分
参数:
full_answer (str): LLM的完整响应文本
返回:
dict: 解析后的JSON对象,如果解析失败则返回None
"""
# 尝试从回答中提取JSON部分
json_match = re.search(r'```json\s*(.*?)\s*```', full_answer, re.DOTALL)
if json_match:
json_str = json_match.group(1)
else:
# 如果没有找到```json```格式,尝试寻找普通的JSON对象
json_match = re.search(r'({[\s\S]*"is_complete"[\s\S]*})', full_answer)
if json_match:
json_str = json_match.group(1)
else:
# 如果仍然没有找到,返回None
return None
try:
# 解析JSON
return json.loads(json_str)
except json.JSONDecodeError:
return None
def _create_llm_prompt(self, question):
"""
创建LLM提示词
参数:
question (str): 需要判断完整性的问题
返回:
str: 格式化后的提示词
"""
return f"""你是一个电力造价行业专家,用户正在使用电力造价软件,并提出了相关问题。请分析以下问题是否完整。
问题:{question}
首先,分析这个问题的结构和内容,思考它是否包含足够的信息来表达清晰的意图。
考虑以下几点:
1. 问题是否有明确的核心意图,不需要面面俱到
2. 问题是否缺少必要的上下文
3. **问题如果涉及软件相关,则只需要包含:软件名称、软件功能或软件目的即可**
在你的分析之后,请用JSON格式给出最终结论,格式如下:
```json
{{
"is_complete": true或false,
"reason": "判断原因的简要说明",
"confidence": 0到100之间的数值,表示你对判断的置信度
}}
```
请确保JSON格式正确,以便于程序解析。"""
def _create_llm_client(self, api_key=None):
"""
创建LLM客户端
参数:
api_key (str, optional): API密钥,如果为None则从APIKeyManager获取
返回:
OpenAiLLM: LLM客户端实例
"""
if api_key is None:
api_key = APIKeyManager.get_api_key()
return OpenAiLLM(
api_key=api_key,
base_url="https://api.siliconflow.cn/v1", # 可以根据实际情况修改
model="deepseek-ai/DeepSeek-V3", # 可以根据实际情况修改
temperature=0.2,
max_tokens=100
)
def is_question_complete(self, question):
"""
调用LLM判断问题是否完整
参数:
question (str): 需要判断的问题
返回:
tuple: (bool, str) - 是否完整的布尔值和LLM的详细回复
"""
# 最大重试次数
max_retries = 3
retry_count = 0
retry_delay = 2 # 重试延迟,单位:秒
while retry_count <= max_retries:
try:
# 创建提示词
prompt = self._create_llm_prompt(question)
# 使用OpenAiLLM调用模型
response = self.llm_client.invoke(prompt)
# 处理可能的响应格式
if hasattr(response, 'content'):
full_answer = response.content
else:
# 如果response是字符串
full_answer = str(response)
# 提取JSON部分
result = self._extract_json_from_response(full_answer)
if result:
is_complete = result.get("is_complete", False)
return is_complete, full_answer
else:
# 如果没有找到或解析失败,使用简单判断
is_complete = "完整" in full_answer[:100]
return is_complete, full_answer
except Exception as e:
retry_count += 1
if retry_count <= max_retries:
# 非最后一次重试,打印错误并继续
time.sleep(retry_delay)
# 每次重试增加延迟时间,避免频繁失败
retry_delay *= 2
else:
# 已达到最大重试次数,返回错误
print(f"错误: 经过 {max_retries} 次重试后仍然失败: {str(e)}")
return False, f"错误: 经过 {max_retries} 次重试后仍然失败: {str(e)}"
# 不应该到达这里,但为了代码完整性添加
return False, "未知错误:重试机制逻辑错误"
def _process_question(self, args, complete_questions, progress_counter, progress_lock, complete_questions_lock, pbar):
"""
处理单个问题并更新进度
参数:
args (tuple): 包含问题索引、问题内容、LLM客户端和总问题数的元组
complete_questions (list): 存储完整问题的列表
progress_counter (dict): 进度计数器
progress_lock (threading.Lock): 进度锁
complete_questions_lock (threading.Lock): 完整问题列表锁
pbar (tqdm): 进度条对象
"""
index, question, llm_client, total_questions = args
# 跳过空问题
if pd.isna(question) or question.strip() == "":
with progress_lock:
progress_counter["processed"] += 1
pbar.update(1)
return None
# 调用LLM判断问题是否完整
is_complete, full_answer = self.is_question_complete(question)
if is_complete:
# 从答案中提取JSON
parsed_json = self._extract_json_from_response(full_answer)
if parsed_json:
# 构造包含解析出的JSON信息的结果
result = {
"问题": question,
"LLM回复": full_answer,
"完整性": "完整" if parsed_json.get("is_complete", False) else "不完整",
"原因": parsed_json.get("reason", "未提供"),
"置信度": parsed_json.get("confidence", 0)
}
# 更新计数
with progress_lock:
if result["完整性"] == "完整":
progress_counter["complete"] += 1
else:
progress_counter["incomplete"] += 1
else:
# JSON解析失败,只保存原始回答
result = {
"问题": question,
"LLM回复": full_answer,
"完整性": "完整"
}
# 更新计数
with progress_lock:
progress_counter["complete"] += 1
with complete_questions_lock:
complete_questions.append(result)
else:
with progress_lock:
progress_counter["incomplete"] += 1
# 更新进度条
with progress_lock:
progress_counter["processed"] += 1
# 更新进度条描述
pbar.set_postfix(
完整=progress_counter["complete"],
不完整=progress_counter["incomplete"],
完整率=f"{progress_counter['complete']/max(1, progress_counter['processed']):.1%}"
)
pbar.update(1)
def _shorten_response(self, response):
"""
截断LLM响应,提取重要信息
参数:
response (str): 原始LLM响应
返回:
str: 截断后的响应
"""
# 保留思考过程的前200个字符和JSON部分
json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
if json_match:
json_part = json_match.group(0)
prefix = response[:200] + "..." if len(response) > 200 else response
return f"{prefix}\n\n{json_part}"
return response[:500] + "..." if len(response) > 500 else response
def _prepare_excel_dataframe(self, complete_questions):
"""
将结果处理为DataFrame用于Excel输出
参数:
complete_questions (list): 完整问题列表
返回:
pandas.DataFrame: 处理后的DataFrame
"""
# 将结果列表转换为DataFrame
result_df = pd.DataFrame(complete_questions)
# 处理LLM回复列,截取一定长度以避免Excel单元格过大
if "LLM回复" in result_df.columns:
result_df["LLM回复"] = result_df["LLM回复"].apply(self._shorten_response)
# 调整列的顺序,确保重要列在前面
column_order = ["问题", "完整性", "置信度", "原因", "LLM回复"]
# 过滤掉不存在的列
column_order = [col for col in column_order if col in result_df.columns]
# 确保所有剩余的列也被包含
for col in result_df.columns:
if col not in column_order:
column_order.append(col)
# 重新排序列
return result_df[column_order]
def _set_excel_column_widths(self, worksheet):
"""
设置Excel列宽
参数:
worksheet (openpyxl.worksheet.worksheet.Worksheet): Excel工作表
"""
for col in range(1, worksheet.max_column + 1):
col_letter = get_column_letter(col)
column_name = worksheet[f"{col_letter}1"].value
if column_name == "问题":
worksheet.column_dimensions[col_letter].width = 40
elif column_name == "LLM回复":
worksheet.column_dimensions[col_letter].width = 60
elif column_name == "原因":
worksheet.column_dimensions[col_letter].width = 30
elif column_name == "完整性":
worksheet.column_dimensions[col_letter].width = 10
elif column_name == "置信度":
worksheet.column_dimensions[col_letter].width = 10
else:
worksheet.column_dimensions[col_letter].width = 15
def _apply_excel_cell_styles(self, worksheet):
"""
应用单元格样式
参数:
worksheet (openpyxl.worksheet.worksheet.Worksheet): Excel工作表
返回:
openpyxl.styles.Border: 边框样式,用于统计信息
"""
# 定义样式
header_fill = PatternFill(start_color="DDEBF7", end_color="DDEBF7", fill_type="solid")
header_font = Font(bold=True)
wrap_alignment = Alignment(wrap_text=True, vertical="top")
border = Border(
left=Side(style='thin'),
right=Side(style='thin'),
top=Side(style='thin'),
bottom=Side(style='thin')
)
# 应用样式到每个单元格
for row in worksheet.iter_rows(min_row=1, max_row=worksheet.max_row, min_col=1, max_col=worksheet.max_column):
for cell in row:
cell.alignment = wrap_alignment
cell.border = border
# 为标题行应用特殊样式
if cell.row == 1:
cell.fill = header_fill
cell.font = header_font
# 为完整性列应用条件格式
if cell.row > 1: # 跳过标题行
column_name = worksheet.cell(row=1, column=cell.column).value
if column_name == "完整性":
if cell.value == "完整":
cell.fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")
else:
cell.fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")
return border # 返回边框样式以便在统计信息中重用
def _add_statistics_to_excel(self, worksheet, complete_questions, total_rows, total_questions, border):
"""
添加统计信息到Excel表格
参数:
worksheet (openpyxl.worksheet.worksheet.Worksheet): Excel工作表
complete_questions (list): 完整问题列表
total_rows (int): 总行数
total_questions (int): 总问题数
border (openpyxl.styles.Border): 边框样式
返回:
int: 完整问题数量
"""
# 计算统计数据
complete_count = sum(1 for item in complete_questions if item.get("完整性") == "完整")
incomplete_count = total_rows - complete_count
# 添加统计行
worksheet.append([""]) # 空行
stat_row = worksheet.max_row + 1
worksheet.cell(row=stat_row, column=1, value="统计信息")
worksheet.cell(row=stat_row, column=1).font = Font(bold=True)
worksheet.cell(row=stat_row+1, column=1, value="总问题数")
worksheet.cell(row=stat_row+1, column=2, value=total_rows)
worksheet.cell(row=stat_row+2, column=1, value="完整问题数")
worksheet.cell(row=stat_row+2, column=2, value=complete_count)
worksheet.cell(row=stat_row+2, column=2).fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")
worksheet.cell(row=stat_row+3, column=1, value="不完整问题数")
worksheet.cell(row=stat_row+3, column=2, value=incomplete_count)
worksheet.cell(row=stat_row+3, column=2).fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")
worksheet.cell(row=stat_row+4, column=1, value="完整问题比例")
worksheet.cell(row=stat_row+4, column=2, value=f"{complete_count/total_rows:.2%}" if total_rows > 0 else "0%")
# 应用边框到统计行
for r in range(stat_row, stat_row+5):
for c in range(1, 3):
worksheet.cell(row=r, column=c).border = border
return complete_count
def save_results_to_excel(self, complete_questions, total_questions):
"""
将结果保存到Excel文件
参数:
complete_questions (list): 完整问题列表
total_questions (int): 总问题数
"""
if not complete_questions:
print(f"没有找到完整的问题。")
return
# 准备数据
result_df = self._prepare_excel_dataframe(complete_questions)
total_rows = len(result_df)
# 保存到Excel文件
result_df.to_excel(self.output_path, index=False, engine='openpyxl')
# 应用Excel样式
from openpyxl import load_workbook
wb = load_workbook(self.output_path)
ws = wb.active
# 设置列宽
self._set_excel_column_widths(ws)
# 应用单元格样式
border = self._apply_excel_cell_styles(ws)
# 添加统计信息
complete_count = self._add_statistics_to_excel(ws, complete_questions, total_rows, total_questions, border)
# 保存样式化的工作簿
wb.save(self.output_path)
# 输出结果统计
print(f"处理完成。共有{complete_count}/{total_questions}个完整问题被保存到 {self.output_path}")
print(f"完整问题比例: {complete_count/total_questions:.2%}" if total_questions > 0 else "完整问题比例: 0%")
def process_excel_file(self):
"""
处理Excel文件中的问题
读取Excel文件,判断问题完整性,并将结果保存到输出Excel文件
"""
# 确保Excel文件存在
if not os.path.exists(self.input_path):
print(f"错误: 找不到Excel文件 '{self.input_path}'")
return
# 读取Excel文件
print(f"正在读取Excel文件: {self.input_path}")
try:
df = pd.read_excel(self.input_path)
except Exception as e:
print(f"读取Excel文件时出错: {e}")
return
# 检查列数据
if len(df.columns) <= self.column_index:
print(f"错误: Excel文件没有足够的列,请求索引 {self.column_index},但只有 {len(df.columns)}")
return
# 获取目标列名称
target_col = df.columns[self.column_index]
print(f"目标列名称: {target_col}")
# 准备存储完整问题的列表
complete_questions = []
total_questions = len(df)
print(f"总共有{total_questions}个问题需要判断")
# 用于线程安全的列表操作和进度计数
complete_questions_lock = threading.Lock()
progress_counter = {"processed": 0, "complete": 0, "incomplete": 0}
progress_lock = threading.Lock()
# 准备问题列表
questions = [(i, str(row[target_col]), self.llm_client, total_questions)
for i, row in df.iterrows()]
# 记录开始时间
start_time = time.time()
# 使用tqdm创建进度条
print(f"开始处理问题,使用 {self.max_workers} 个并发线程...")
with tqdm(total=total_questions, desc="处理问题", unit="问题") as pbar:
# 使用线程池并发处理
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 提交所有任务
futures = [executor.submit(
self._process_question,
args,
complete_questions,
progress_counter,
progress_lock,
complete_questions_lock,
pbar
) for args in questions]
# 等待所有任务完成
concurrent.futures.wait(futures)
# 计算总处理时间
processing_time = time.time() - start_time
print(f"处理完成,耗时: {processing_time:.2f}秒,平均每问题: {processing_time/total_questions:.2f}")
# 将完整问题保存到Excel文件
self.save_results_to_excel(complete_questions, total_questions)
def test_single_question(self, question):
"""
测试单个问题的完整性
参数:
question (str): 要测试的问题
"""
print(f"问题: {question}")
print("正在调用LLM判断问题是否完整...")
# 调用LLM判断问题是否完整
is_complete, full_answer = self.is_question_complete(question)
# 从答案中提取JSON
parsed_json = self._extract_json_from_response(full_answer)
print("\n==== LLM回复 ====")
print(full_answer)
print("================\n")
if parsed_json:
print(f"判断结果: {'完整' if parsed_json.get('is_complete', False) else '不完整'}")
print(f"判断原因: {parsed_json.get('reason', '未提供')}")
print(f"置信度: {parsed_json.get('confidence', 0)}%")
else:
print(f"判断结果: {'完整' if is_complete else '不完整'} (简单判断)")
print("无法从回复中提取JSON结构化数据")
def parse_arguments():
"""解析命令行参数"""
parser = argparse.ArgumentParser(description='判断Excel文件中的问题是否完整')
parser.add_argument('-i', '--input', type=str, default=DEFAULT_EXCEL_PATH,
help=f'输入Excel文件路径 (默认: {DEFAULT_EXCEL_PATH})')
parser.add_argument('-o', '--output', type=str, default=DEFAULT_OUTPUT_PATH,
help=f'输出Excel文件路径 (默认: {DEFAULT_OUTPUT_PATH})')
parser.add_argument('-w', '--workers', type=int, default=DEFAULT_MAX_WORKERS,
help=f'并发处理的最大线程数 (默认: {DEFAULT_MAX_WORKERS})')
parser.add_argument('-c', '--column', type=int, default=0,
help='要处理的问题所在列的索引 (默认: 0,即第一列)')
parser.add_argument('-t', '--test', type=str,
help='测试单个问题,不处理Excel文件')
return parser.parse_args()
def main():
"""主函数"""
args = parse_arguments()
# 创建问题完整性判断工具实例
judge = QueryCompletenessJudge(
input_path=args.input,
output_path=args.output,
max_workers=args.workers,
column_index=args.column
)
# 如果是测试单个问题
if args.test:
judge.test_single_question(args.test)
return
# 处理Excel文件
judge.process_excel_file()
if __name__ == "__main__":
main()
+239
View File
@@ -0,0 +1,239 @@
import pandas as pd
from urllib.parse import unquote
from rag2_0.tool.WikijsTool import WikijsTool
from rag2_0.tool.html_to_md import convert_html_to_md
from rag2_0.tool.ModelTool import OpenAiLLM
from dotenv import load_dotenv
import os
from tqdm import tqdm
from rag2_0.dify.dify_tool import DifyTool
import json
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
load_dotenv()
class ContentSource(BaseModel):
score:int = Field(description="相关性分数")
reason:str = Field(description="评分理由")
class RetrieveContentScoreJudge:
"""
检索内容相关性评分工具类
用于评估检索内容与问题之间的相关性,并计算相关性分数
"""
def __init__(self, wiki_excel_path, answer_excel_path, output_path=None):
"""
初始化评分工具类
参数:
wiki_excel_path (str): Wiki Excel文件路径
answer_excel_path (str): 回答Excel文件路径
output_path (str, optional): 输出Excel文件路径,默认为None
"""
self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
if os.path.exists(wiki_excel_path):
self.wiki_excel = pd.read_excel(wiki_excel_path)
else:
self.wiki_excel = None
self.answer_excel = pd.read_excel(answer_excel_path)
self.output_path = output_path or "/data/Rag2_0/data/excel/dify问答_检索内容评分.xlsx"
# 从环境变量中获取OpenAI的配置
self.api_key = os.getenv("OPENAI_API_KEY")
self.base_url = os.getenv("OPENAI_API_BASE")
self.model_name = os.getenv("LLM_MODEL_NAME")
if not all([self.api_key, self.base_url, self.model_name]):
raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")
self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model_name)
def find_wiki_link(self, query) -> str | None:
"""
根据查询(对应wiki_excel中的新提问列)找出对应的词条链接
参数:
query (str): 查询内容,对应wiki_excel中的新提问列
返回:
str: 对应的词条链接,如果没有找到则返回None
"""
# 确保query不为空
if not query or pd.isna(query):
return None
if self.wiki_excel is None:
return None
# 在"新提问"列中查找匹配的行
matched_rows = self.wiki_excel[self.wiki_excel['新提问'] == query]
# 如果找到了匹配的行,返回对应的词条链接
if not matched_rows.empty:
return matched_rows.iloc[0]['对应词条链接']
# 如果没有完全匹配,尝试部分匹配
# 去除软件名称部分(如果有)
query_parts = query.split(',', 1)
if len(query_parts) > 1:
clean_query = query_parts[1].strip()
# 在"提问"列中查找包含清理后查询的行
for idx, row in self.wiki_excel.iterrows():
if pd.notna(row['提问']) and clean_query in row['提问']:
return row['对应词条链接']
return None
def get_wiki_title(self, link) -> str | None:
"""
获取词条标题
参数:
link (str): 词条链接
返回:
str: 词条标题,如果获取失败则返回None
"""
try:
if not link or pd.isna(link):
return None
# 移除域名部分,只保留路径
path = link.split('/', 3)[-1]
decoded_path = unquote(path)
path_parts = decoded_path.split('/')
return path_parts[-1]
except Exception as e:
raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
def calculate_score(self, answer:str, content:str) -> int:
"""
使用OpenAiLLM通过LLM判断answer与content之间的相关性分数
参数:
answer (str): 用户问题
content (str): 检索内容
返回:
int: 相关性分数,1-10分,10代表完全相关,1代表完全不相关;-1表示评分失败
"""
try:
prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。
【评分标准】
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
8-9分:高度相关,核心要素匹配但存在少量信息缺失
6-7分:部分相关,涉及相同主题但存在重要信息缺失
4-5分:弱相关,仅次要信息点匹配
1-3分:完全不相关或信息冲突
【评估维度】
1. 主题一致性:核心主题/意图的匹配程度
2. 内容覆盖度:是否涵盖query的关键要素
3. 信息准确性:是否存在矛盾/错误信息
4. 细节丰富度:是否提供query要求的详细信息
【输出格式】
{{
"score": 评分,
"reason": "简明扼要的评分理由(中文)"
}}
【示例】
query: "新冠疫苗的常见副作用"
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}
现在评估:
query: "{answer}"
content: "{content}"
"""
response = self.llm.invoke(user_prompt=prompt, need_retry=True)
# 解析JSON响应
try:
parsed_output = self.content_source_parser.parse(response.content)
return parsed_output.score
except Exception as e:
return -1
except Exception as e:
return -1
def get_retrieve_info(self, query:str, outputs:dict) -> tuple:
"""
获取检索信息并计算分数
参数:
query (str): 用户问题
outputs (dict): 检索输出结果
返回:
tuple: (检索内容列表, 最高分, 最低分, 平均分)
"""
max_score = 0
min_score = 10
total_score = 0
valid_scores = 0
retrieve_content = []
for result in outputs["result"]:
content = result["content"].strip()
score = self.calculate_score(answer=query, content=content)
if score != -1:
max_score = max(max_score, score)
min_score = min(min_score, score)
total_score += score
valid_scores += 1
content_title = content.split("\n")[0]
if content_title:
retrieve_content.append(content_title + f"--得分({score}分)")
avg_score = total_score / valid_scores if valid_scores > 0 else 0
return retrieve_content, max_score, min_score, avg_score
def process(self):
"""
处理所有问题并评估检索内容相关性
遍历answer_excel中的所有问题,计算检索内容与问题的相关性分数,
并更新Excel文件
"""
for idx, row in tqdm(self.answer_excel.iterrows(), total=len(self.answer_excel), desc="处理问题评分中"):
query = row["问题"]
link = self.find_wiki_link(query)
answer_title = self.get_wiki_title(link)
retrieve_content = []
max_score = 0
min_score = 0
avg_score = 0 # 初始化平均分
rewrite_query=""
message_info = DifyTool.get_message_debug_info(appid="ccf92b97-2789-4a3f-90e0-135a869a37c5", query=query)
for workflow_node in message_info["workflow_node_executions_info"]:
if workflow_node["title"] == "知识检索结果后处理":
outputs = json.loads(workflow_node["outputs"])
retrieve_content, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
elif workflow_node["title"] == "问题优化结果解析":
outputs = json.loads(workflow_node["outputs"])
rewrite_query = outputs["optimize_query"]
# 更新 answer_excel 中的词条内容
self.answer_excel.at[idx, "答案词条"] = answer_title if answer_title else ""
self.answer_excel.at[idx, "问题改写"] = rewrite_query
self.answer_excel.at[idx, "检索得到词条"] = "\n".join(retrieve_content) if retrieve_content else "未检索知识库"
self.answer_excel.at[idx, "最大得分"] = max_score
self.answer_excel.at[idx, "最小得分"] = min_score
self.answer_excel.at[idx, "平均得分"] = avg_score
# 保存结果到Excel文件
self.answer_excel.to_excel(self.output_path, index=False)
print(f"结果已保存到 {self.output_path}")
if __name__ == "__main__":
# 创建评分工具实例
judge = RetrieveContentScoreJudge(
wiki_excel_path="/data/Rag2_0/data/excel/400条人工标注-部分提问_软件名称明确.xlsx",
answer_excel_path="/data/Rag2_0/data/excel/主网软件提问_回答内容评判.xlsx",
output_path="/data/Rag2_0/data/excel/dify问答_检索内容评分.xlsx"
)
# 执行处理
judge.process()
+178
View File
@@ -0,0 +1,178 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: merge_nouns_with_llm.py
Description: 合并多个nouns.json中的同名专业名词,利用LLM生成唯一合并结果
"""
import os
import json
import glob
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
from rag2_0.tool.ModelTool import OpenAiLLM
from rag2_0.intent_recognition.DataModels import Term
import logging
from langchain.output_parsers import PydanticOutputParser
from tqdm import tqdm
import time
# 加载环境变量
load_dotenv()
class TermMerger:
"""专业名词合并类,用于合并多个数据源中的同名专业名词"""
def __init__(self, input_dir=None, output_path=None, max_workers=3):
"""初始化名词合并器
Args:
input_dir: 包含nouns.json文件的目录路径
output_path: 合并结果的输出文件路径
max_workers: 线程池最大工作线程数
"""
self.EXTRACTED_NOUNS_DIR = input_dir
self.OUTPUT_PATH = output_path
self.MAX_WORKERS = max_workers
self.terms_parser = PydanticOutputParser(pydantic_object=Term)
self.MERGE_PROMPT = '''
请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:
- 同义词(synonymous)去重合并
- 描述(description)合并为更完整、简明的描述
- 保持输出格式为:
{output_format}
原始条目:
{items}
'''
# 配置LLM
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
llm_params = {"temperature": 0.3, "model": model_name}
if api_key:
llm_params["api_key"] = api_key
if base_url:
llm_params["base_url"] = base_url
self.llm = OpenAiLLM(**llm_params)
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def load_all_terms(self):
"""读取目录下所有nouns.json,返回所有Term列表"""
all_terms = []
for file in glob.glob(os.path.join(self.EXTRACTED_NOUNS_DIR, '*_nouns.json')):
with open(file, 'r', encoding='utf-8') as f:
try:
file_terms = json.load(f)
new_terms = [{"name": term["name"].upper(), "synonymous": term["synonymous"], "description": term["description"]} for term in file_terms]
all_terms.extend(new_terms)
logging.info(f"加载{file},共{len(new_terms)}")
except Exception as e:
logging.warning(f"读取{file}失败: {e}")
# 加载suffix_keywords.json文件
suffix_keywords_path = os.path.join(os.path.dirname(os.path.dirname(self.EXTRACTED_NOUNS_DIR)), 'data', 'nouns', 'suffix_keywords.json')
if os.path.exists(suffix_keywords_path):
try:
with open(suffix_keywords_path, 'r', encoding='utf-8') as f:
suffix_terms = json.load(f)
suffix_terms = [{"name": term["name"].upper(), "synonymous": "", "description": ""} for term in suffix_terms]
all_terms.extend(suffix_terms)
logging.info(f"加载{suffix_keywords_path},共{len(suffix_terms)}")
except Exception as e:
logging.warning(f"读取{suffix_keywords_path}失败: {e}")
return all_terms
def group_terms_by_name(self, terms):
"""按name聚合Term"""
name2terms = defaultdict(list)
for term in terms:
name = term.get('name', '').strip()
if name:
name2terms[name].append(term)
return name2terms
def merge_terms_with_llm(self, name, term_list):
"""调用LLM合并同名Term,失败最多重试三次"""
items = json.dumps(term_list, ensure_ascii=False)
prompt = self.MERGE_PROMPT.format(name=name, items=items, output_format=self.terms_parser.get_format_instructions())
max_retries = 3
for attempt in range(1, max_retries + 1):
try:
response = self.llm.invoke(prompt, False)
parsed_output = self.terms_parser.parse(response.content)
return {"name": parsed_output.name, "synonymous": parsed_output.synonymous, "description": parsed_output.description}
except Exception as e:
if attempt == max_retries:
logging.warning(f"解析LLM合并结果失败: {e}")
return None
else:
time.sleep(10*attempt)
def process_term(self, name_terms_tuple):
"""处理单个词条,用于线程池并行处理"""
name, term_list = name_terms_tuple
try:
merged = self.merge_terms_with_llm(name, term_list)
if merged:
return merged
else:
return term_list[0]
except Exception as e:
logging.error(f"处理词条 {name} 时出错: {e}")
return term_list[0]
def merge(self):
"""合并所有词条的入口方法"""
# 1. 读取所有术语
all_terms = self.load_all_terms()
logging.info(f"共加载{len(all_terms)}条术语")
# 2. 按名称聚合
name2terms = self.group_terms_by_name(all_terms)
logging.info(f"{len(name2terms)}个唯一名词")
# 3. 使用线程池并行处理
merged_terms = []
items_to_process = []
# 先处理只有一个条目的词条(不需要合并)
for name, term_list in name2terms.items():
if len(term_list) == 1:
merged_terms.append(term_list[0])
else:
items_to_process.append((name, term_list))
logging.info(f"{len(merged_terms)}个单一条目,{len(items_to_process)}个需要合并的条目")
# 只对需要合并的词条使用线程池处理
if items_to_process:
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
# 使用tqdm显示进度
for result in tqdm(executor.map(self.process_term, items_to_process), total=len(items_to_process)):
merged_terms.append(result)
# 4. 保存合并结果
os.makedirs(os.path.dirname(self.OUTPUT_PATH), exist_ok=True)
with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
json.dump(merged_terms, f, ensure_ascii=False, indent=2)
logging.info(f"合并后结果已保存到: {self.OUTPUT_PATH}")
return merged_terms
def main():
"""主函数,创建TermMerger实例并执行合并"""
cur_path = os.path.dirname(__file__)
input_dir = os.path.abspath(os.path.join(cur_path, '../../data/wiki_extracted_nouns'))
output_path = os.path.join(cur_path, "..", "..", "data", "nouns", 'merged_nouns.json')
merger = TermMerger(input_dir=input_dir, output_path=output_path, max_workers=2)
merger.merge()
if __name__ == "__main__":
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
main()
+408
View File
@@ -0,0 +1,408 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: validate_excel_data_batch.py
Description: 使用LLM批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写是否正确
"""
import os
import pandas as pd
import json
import argparse
import logging
import concurrent.futures
from tqdm import tqdm
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from rag2_0.intent_recognition.PromptTemplates import classification
from rag2_0.tool.ModelTool import OpenAiLLM
class ExcelDataValidator:
"""Excel数据验证类,用于批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写"""
def __init__(self, input_file=None, output_file=None, workers=4, batch_size=10):
"""
初始化验证器
Args:
input_file: 输入Excel文件路径
output_file: 输出结果Excel文件路径
workers: 并行工作线程数
batch_size: 每批处理的行数
"""
# 加载环境变量
load_dotenv()
self.input_file = input_file
self.output_file = output_file
self.workers = workers
self.batch_size = batch_size
self.df = None
# 设置日志
self.setup_logging()
def setup_logging(self):
"""配置日志输出"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler()
]
)
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
def load_data_from_excel(self, file_path=None):
"""
从Excel文件中读取数据
Args:
file_path: Excel文件路径,如不提供则使用初始化时的路径
Returns:
DataFrame对象
"""
file_path = file_path or self.input_file
if not file_path:
logging.error("未指定输入文件路径")
return None
try:
df = pd.read_excel(file_path)
required_columns = ["提问", "问题拆解", "一级分类", "二级分类", "问题改写", "检索的关键词"]
for col in required_columns:
if col not in df.columns:
logging.error(f"缺少必要的列: {col}")
return None
logging.info(f"成功从{file_path}读取了{len(df)}条数据")
self.df = df
return df
except Exception as e:
logging.error(f"读取Excel文件时出错: {e}")
return None
def validate_classification(self, llm, query, vertical_class, sub_class):
"""
验证问题分类是否正确
Args:
llm: LLM模型
query: 原始问题
vertical_class: 一级分类
sub_class: 二级分类
Returns:
(bool, str): 是否正确,错误原因(如果有)
"""
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用,也可能涉及电力造价专业知识。我对用户问题进行了分类,请评估以下问题分类是否正确。
我目前总共有以下分类:
{classification}
问题的分类情况如下:
原始问题: {query}
一级分类: {vertical_class}
二级分类: {sub_class}
请从专业角度分析这个分类是否准确。只需返回"正确""错误:原因",不需要其他解释。"""
try:
response = llm.invoke(prompt)
result = response.content.strip()
if result.startswith("正确"):
return True, ""
else:
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
return False, error_reason
except Exception as e:
logging.warning(f"验证问题分类时出错: {e}")
return False, f"验证过程出错: {str(e)}"
def validate_query_keys(self, llm, query, query_keys):
"""
验证问题拆解是否正确
Args:
llm: LLM模型
query: 原始问题
query_keys: 问题拆解
Returns:
(bool, str): 是否正确,错误原因(如果有)
"""
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了拆解,请评估以下问题拆解是否正确。
原始问题: {query}
问题拆解: {query_keys}
问题拆解应该准确提取原始问题中的关键词和信息。请分析这个拆解是否准确。只需返回"正确""错误:原因",不需要其他解释。"""
try:
response = llm.invoke(prompt)
result = response.content.strip()
if result.startswith("正确"):
return True, ""
else:
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
return False, error_reason
except Exception as e:
logging.warning(f"验证问题拆解时出错: {e}")
return False, f"验证过程出错: {str(e)}"
def validate_keywords(self, llm, query, query_keys, keywords_str):
"""
验证检索关键词是否准确
Args:
llm: LLM模型
query: 原始问题
query_keys: 问题拆解
keywords_str: 检索关键词(JSON字符串)
Returns:
(bool, str): 是否正确,错误原因(如果有)
"""
# 解析关键词JSON
try:
if isinstance(keywords_str, str) and keywords_str.strip():
keywords = json.loads(keywords_str)
else:
keywords = []
except:
keywords = keywords_str
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。通过问题检索出了一些关键词,请评估这些关键词是否准确,是否与问题相关
原始问题: {query}
问题拆解: {query_keys}
检索关键词: {keywords}
检索关键词应该准确反映问题中需要检索的关键概念和术语。请分析这些关键词是否准确、完整。只需返回"正确""错误:原因",不需要其他解释。"""
try:
response = llm.invoke(prompt)
result = response.content.strip()
if result.startswith("正确"):
return True, ""
else:
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
return False, error_reason
except Exception as e:
logging.warning(f"验证检索关键词时出错: {e}")
return False, f"验证过程出错: {str(e)}"
def validate_rewrite(self, llm, query, rewrite):
"""
验证问题改写是否正确
Args:
llm: LLM模型
query: 原始问题
rewrite: 问题改写
Returns:
(bool, str): 是否正确,错误原因(如果有)
"""
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了改写,请评估以下问题改写是否正确。
原始问题: {query}
问题改写: {rewrite}
问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确。只需返回"正确""错误:原因",不需要其他解释。"""
try:
response = llm.invoke(prompt)
result = response.content.strip()
if result.startswith("正确"):
return True, ""
else:
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
return False, error_reason
except Exception as e:
logging.warning(f"验证问题改写时出错: {e}")
return False, f"验证过程出错: {str(e)}"
def validate_row(self, llm, row_data):
"""
按顺序验证一行数据中的各个环节
Args:
llm: LLM模型
row_data: (index, row)元组
Returns:
(index, is_all_correct, error_phase, error_reason): 行索引,是否全部正确,错误环节,错误原因
"""
index, row = row_data
query = row["提问"]
query_keys = row["问题拆解"]
vertical_class = row["一级分类"]
sub_class = row["二级分类"]
rewrite = row["问题改写"]
keywords_str = row["检索的关键词"]
try:
# 1. 验证问题分类
is_correct, error_reason = self.validate_classification(llm, query, vertical_class, sub_class)
if not is_correct:
return index, False, "问题分类", error_reason
# 2. 验证问题拆解
is_correct, error_reason = self.validate_query_keys(llm, query, query_keys)
if not is_correct:
return index, False, "问题拆解", error_reason
# 3. 验证检索关键词
is_correct, error_reason = self.validate_keywords(llm, query, query_keys, keywords_str)
if not is_correct:
return index, False, "关键词检索", error_reason
# 4. 验证问题改写
is_correct, error_reason = self.validate_rewrite(llm, query, rewrite)
if not is_correct:
return index, False, "问题改写", error_reason
return index, True, "", ""
except Exception as e:
error_msg = f"处理行 {index} 时发生错误: {str(e)}"
logging.error(error_msg)
return index, False, "处理错误", error_msg
def process_batch(self, llm, batch_data):
"""处理一批数据"""
results = []
for row_data in batch_data:
results.append(self.validate_row(llm, row_data))
return results
def create_llm_instances(self, count):
"""创建多个LLM实例"""
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
llm_params = {"temperature": 0.7, "model": model_name}
if api_key:
llm_params["api_key"] = api_key
if base_url:
llm_params["base_url"] = base_url
return [OpenAiLLM(**llm_params) for _ in range(count)]
def validate(self, input_file=None, output_file=None, workers=None, batch_size=None):
"""
执行验证过程
Args:
input_file: 输入Excel文件路径
output_file: 输出结果Excel文件路径
workers: 并行工作线程数
batch_size: 每批处理的行数
Returns:
验证后的DataFrame
"""
input_file = input_file or self.input_file
output_file = output_file or self.output_file
workers = workers or self.workers
batch_size = batch_size or self.batch_size
# 读取数据
df = self.load_data_from_excel(input_file)
if df is None:
return None
# 添加验证结果列
df["验证结果"] = ""
df["错误环节"] = ""
df["错误原因"] = ""
# 准备数据批次
all_rows = list(df.iterrows())
batches = [all_rows[i:i+batch_size] for i in range(0, len(all_rows), batch_size)]
# 创建多个LLM实例
llm_instances = self.create_llm_instances(min(workers, len(batches)))
# 使用线程池处理数据
all_results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
# 为每个批次分配一个LLM实例
future_to_batch = {
executor.submit(self.process_batch, llm_instances[i % len(llm_instances)], batch):
i for i, batch in enumerate(batches)
}
# 使用tqdm显示进度条
for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(batches), desc="批次处理进度"):
batch_results = future.result()
all_results.extend(batch_results)
# 按行索引排序结果,确保与原始数据顺序一致
all_results.sort(key=lambda x: x[0])
# 将结果填充到DataFrame
for index, is_correct, error_phase, error_reason in all_results:
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
df.at[index, "错误环节"] = error_phase
df.at[index, "错误原因"] = error_reason
# 保存结果
if output_file is None:
output_file = os.path.join(
os.path.dirname(input_file),
f"validated_{os.path.basename(input_file)}"
)
df.to_excel(output_file, index=False)
logging.info(f"验证完成,结果已保存至: {output_file}")
# 输出统计信息
self.print_statistics(df)
return df
def print_statistics(self, df):
"""打印统计信息"""
total = len(df)
passed = len(df[df["验证结果"] == "通过"])
error_stats = df[df["验证结果"] == "不通过"]["错误环节"].value_counts()
logging.info(f"统计信息: 总计 {total} 条, 通过 {passed} 条, 通过率 {passed/total*100:.2f}%")
logging.info("错误环节统计:")
for phase, count in error_stats.items():
logging.info(f"- {phase}: {count}")
def main():
"""主函数"""
# 解析命令行参数
input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "问题分类重写结果")
output_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "自动验证_问题分类重写结果.xlsx")
parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写")
parser.add_argument("--input", "-i", type=str, required=True, help="输入Excel文件路径", default=input_excel)
parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
parser.add_argument("--workers", "-w", type=int, default=2, help="并行工作线程数")
parser.add_argument("--batch-size", "-b", type=int, default=5, help="每批处理的行数")
args = parser.parse_args()
# 创建验证器实例并执行验证
validator = ExcelDataValidator(
input_file=args.input,
output_file=args.output,
workers=args.workers,
batch_size=args.batch_size
)
validator.validate()
if __name__ == "__main__":
main()
+47
View File
@@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: vectorize_save_noun.py
Date: 2025-05-15
Description: 专业名词向量化和保存的示例程序
"""
import os
import json
from dotenv import load_dotenv
from rag2_0.intent_recognition import ProfessionalNounVectorizer
import logging
# 加载环境变量
load_dotenv()
def main():
"""
主函数:创建索引并保存
"""
# 指定文件路径
current_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(current_dir, "..", "..", "data", "nouns")
# 创建向量化器并指定路径
noun_vectorizer = ProfessionalNounVectorizer(
output_dir=output_dir
)
file_paths = [
os.path.join(current_dir, "..", "..", "data/nouns/merged_nouns.json"),
]
# 执行向量化和保存(一步完成)
success = noun_vectorizer.vectorize_files_and_save(file_paths)
if success:
logging.info("✓ 索引创建和保存成功")
logging.info(f" 索引保存路径: {os.path.join(output_dir, 'professional_nouns_index')}")
else:
logging.error("✗ 索引创建失败")
if __name__ == "__main__":
# 配置日志输出到控制台
logging.basicConfig(
level=logging.INFO,
format='%(message)s'
)
main()
Binary file not shown.
+1
View File
@@ -0,0 +1 @@
from dify_client.client import ChatClient, CompletionClient, DifyClient
+459
View File
@@ -0,0 +1,459 @@
import json
import requests
class DifyClient:
def __init__(self, api_key, base_url: str = "https://api.dify.ai/v1"):
self.api_key = api_key
self.base_url = base_url
def _send_request(self, method, endpoint, json=None, params=None, stream=False):
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
url = f"{self.base_url}{endpoint}"
response = requests.request(
method, url, json=json, params=params, headers=headers, stream=stream, verify=False
)
return response
def _send_request_with_files(self, method, endpoint, data, files):
headers = {"Authorization": f"Bearer {self.api_key}"}
url = f"{self.base_url}{endpoint}"
response = requests.request(
method, url, data=data, headers=headers, files=files
)
return response
def message_feedback(self, message_id, rating, user):
data = {"rating": rating, "user": user}
return self._send_request("POST", f"/messages/{message_id}/feedbacks", data)
def get_application_parameters(self, user):
params = {"user": user}
return self._send_request("GET", "/parameters", params=params)
def file_upload(self, user, files):
data = {"user": user}
return self._send_request_with_files(
"POST", "/files/upload", data=data, files=files
)
def text_to_audio(self, text: str, user: str, streaming: bool = False):
data = {"text": text, "user": user, "streaming": streaming}
return self._send_request("POST", "/text-to-audio", data=data)
def get_meta(self, user):
params = {"user": user}
return self._send_request("GET", "/meta", params=params)
class CompletionClient(DifyClient):
def create_completion_message(self, inputs, response_mode, user, files=None):
data = {
"inputs": inputs,
"response_mode": response_mode,
"user": user,
"files": files,
}
return self._send_request(
"POST",
"/completion-messages",
data,
stream=True if response_mode == "streaming" else False,
)
class ChatClient(DifyClient):
def create_chat_message(
self,
inputs,
query,
user,
response_mode="blocking",
conversation_id=None,
files=None,
):
data = {
"inputs": inputs,
"query": query,
"user": user,
"response_mode": response_mode,
"files": files,
}
if conversation_id:
data["conversation_id"] = conversation_id
return self._send_request(
"POST",
"/chat-messages",
data,
stream=True if response_mode == "streaming" else False,
)
def get_suggested(self, message_id, user: str):
params = {"user": user}
return self._send_request(
"GET", f"/messages/{message_id}/suggested", params=params
)
def stop_message(self, task_id, user):
data = {"user": user}
return self._send_request("POST", f"/chat-messages/{task_id}/stop", data)
def get_conversations(self, user, last_id=None, limit=None, pinned=None):
params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned}
return self._send_request("GET", "/conversations", params=params)
def get_conversation_messages(
self, user, conversation_id=None, first_id=None, limit=None
):
params = {"user": user}
if conversation_id:
params["conversation_id"] = conversation_id
if first_id:
params["first_id"] = first_id
if limit:
params["limit"] = limit
return self._send_request("GET", "/messages", params=params)
def rename_conversation(
self, conversation_id, name, auto_generate: bool, user: str
):
data = {"name": name, "auto_generate": auto_generate, "user": user}
return self._send_request(
"POST", f"/conversations/{conversation_id}/name", data
)
def delete_conversation(self, conversation_id, user):
data = {"user": user}
return self._send_request("DELETE", f"/conversations/{conversation_id}", data)
def audio_to_text(self, audio_file, user):
data = {"user": user}
files = {"audio_file": audio_file}
return self._send_request_with_files("POST", "/audio-to-text", data, files)
class WorkflowClient(DifyClient):
def run(
self, inputs: dict, response_mode: str = "streaming", user: str = "abc-123"
):
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
return self._send_request("POST", "/workflows/run", data)
def stop(self, task_id, user):
data = {"user": user}
return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data)
def get_result(self, workflow_run_id):
return self._send_request("GET", f"/workflows/run/{workflow_run_id}")
class KnowledgeBaseClient(DifyClient):
def __init__(
self,
api_key,
base_url: str = "https://api.dify.ai/v1",
dataset_id: str | None = None,
):
"""
Construct a KnowledgeBaseClient object.
Args:
api_key (str): API key of Dify.
base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'.
dataset_id (str, optional): ID of the dataset. Defaults to None. You don't need this if you just want to
create a new dataset. or list datasets. otherwise you need to set this.
"""
super().__init__(api_key=api_key, base_url=base_url)
self.dataset_id = dataset_id
def _get_dataset_id(self):
if self.dataset_id is None:
raise ValueError("dataset_id is not set")
return self.dataset_id
def create_dataset(self, name: str, **kwargs):
return self._send_request("POST", "/datasets", {"name": name}, **kwargs)
def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
return self._send_request(
"GET", f"/datasets?page={page}&limit={page_size}", **kwargs
)
def create_document_by_text(
self, name, text, extra_params: dict | None = None, **kwargs
):
"""
Create a document by text.
:param name: Name of the document
:param text: Text content of the document
:param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
e.g.
{
'indexing_technique': 'high_quality',
'process_rule': {
'rules': {
'pre_processing_rules': [
{'id': 'remove_extra_spaces', 'enabled': True},
{'id': 'remove_urls_emails', 'enabled': True}
],
'segmentation': {
'separator': '\n',
'max_tokens': 500
}
},
'mode': 'custom'
}
}
:return: Response from the API
"""
data = {
"indexing_technique": "high_quality",
"process_rule": {"mode": "automatic"},
"name": name,
"text": text,
}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = f"/datasets/{self._get_dataset_id()}/document/create_by_text"
return self._send_request("POST", url, json=data, **kwargs)
def update_document_by_text(
self, document_id, name, text, extra_params: dict | None = None, **kwargs
):
"""
Update a document by text.
:param document_id: ID of the document
:param name: Name of the document
:param text: Text content of the document
:param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
e.g.
{
'indexing_technique': 'high_quality',
'process_rule': {
'rules': {
'pre_processing_rules': [
{'id': 'remove_extra_spaces', 'enabled': True},
{'id': 'remove_urls_emails', 'enabled': True}
],
'segmentation': {
'separator': '\n',
'max_tokens': 500
}
},
'mode': 'custom'
}
}
:return: Response from the API
"""
data = {"name": name, "text": text}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = (
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
)
return self._send_request("POST", url, json=data, **kwargs)
def create_document_by_file(
self, file_path, original_document_id=None, extra_params: dict | None = None
):
"""
Create a document by file.
:param file_path: Path to the file
:param original_document_id: pass this ID if you want to replace the original document (optional)
:param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
e.g.
{
'indexing_technique': 'high_quality',
'process_rule': {
'rules': {
'pre_processing_rules': [
{'id': 'remove_extra_spaces', 'enabled': True},
{'id': 'remove_urls_emails', 'enabled': True}
],
'segmentation': {
'separator': '\n',
'max_tokens': 500
}
},
'mode': 'custom'
}
}
:return: Response from the API
"""
files = {"file": open(file_path, "rb")}
data = {
"process_rule": {"mode": "automatic"},
"indexing_technique": "high_quality",
}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
if original_document_id is not None:
data["original_document_id"] = original_document_id
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
return self._send_request_with_files(
"POST", url, {"data": json.dumps(data)}, files
)
def update_document_by_file(
self, document_id, file_path, extra_params: dict | None = None
):
"""
Update a document by file.
:param document_id: ID of the document
:param file_path: Path to the file
:param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
e.g.
{
'indexing_technique': 'high_quality',
'process_rule': {
'rules': {
'pre_processing_rules': [
{'id': 'remove_extra_spaces', 'enabled': True},
{'id': 'remove_urls_emails', 'enabled': True}
],
'segmentation': {
'separator': '\n',
'max_tokens': 500
}
},
'mode': 'custom'
}
}
:return:
"""
files = {"file": open(file_path, "rb")}
data = {}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = (
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
)
return self._send_request_with_files(
"POST", url, {"data": json.dumps(data)}, files
)
def batch_indexing_status(self, batch_id: str, **kwargs):
"""
Get the status of the batch indexing.
:param batch_id: ID of the batch uploading
:return: Response from the API
"""
url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status"
return self._send_request("GET", url, **kwargs)
def delete_dataset(self):
"""
Delete this dataset.
:return: Response from the API
"""
url = f"/datasets/{self._get_dataset_id()}"
return self._send_request("DELETE", url)
def delete_document(self, document_id):
"""
Delete a document.
:param document_id: ID of the document
:return: Response from the API
"""
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}"
return self._send_request("DELETE", url)
def list_documents(
self,
page: int | None = None,
page_size: int | None = None,
keyword: str | None = None,
**kwargs,
):
"""
Get a list of documents in this dataset.
:return: Response from the API
"""
params = {}
if page is not None:
params["page"] = page
if page_size is not None:
params["limit"] = page_size
if keyword is not None:
params["keyword"] = keyword
url = f"/datasets/{self._get_dataset_id()}/documents"
return self._send_request("GET", url, params=params, **kwargs)
def add_segments(self, document_id, segments, **kwargs):
"""
Add segments to a document.
:param document_id: ID of the document
:param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}]
:return: Response from the API
"""
data = {"segments": segments}
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
return self._send_request("POST", url, json=data, **kwargs)
def query_segments(
self,
document_id,
keyword: str | None = None,
status: str | None = None,
**kwargs,
):
"""
Query segments in this document.
:param document_id: ID of the document
:param keyword: query keyword, optional
:param status: status of the segment, optional, e.g. completed
"""
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
params = {}
if keyword is not None:
params["keyword"] = keyword
if status is not None:
params["status"] = status
if "params" in kwargs:
params.update(kwargs["params"])
return self._send_request("GET", url, params=params, **kwargs)
def delete_document_segment(self, document_id, segment_id):
"""
Delete a segment from a document.
:param document_id: ID of the document
:param segment_id: ID of the segment
:return: Response from the API
"""
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
return self._send_request("DELETE", url)
def update_document_segment(self, document_id, segment_id, segment_data, **kwargs):
"""
Update a segment in a document.
:param document_id: ID of the document
:param segment_id: ID of the segment
:param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True}
:return: Response from the API
"""
data = {"segment": segment_data}
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
return self._send_request("POST", url, json=data, **kwargs)
+215
View File
@@ -0,0 +1,215 @@
import psycopg2
from psycopg2 import sql
import os
import json
from datetime import timezone, timedelta
class PgSql:
"""
用于连接和操作 PostgreSQL 数据库的类。
该类封装了数据库连接、关闭连接以及执行特定查询的方法,
主要用于从 Dify 应用相关的表中获取数据。
"""
def __init__(self):
"""
初始化 PgSql 实例并建立数据库连接。
"""
self.connection = None
self.connect_sql()
def connect_sql(self):
"""
连接到 PostgreSQL 数据库。
使用预定义的凭据连接到 'dify' 数据库。
如果连接失败,会打印错误信息。
"""
try:
# 连接数据库
self.connection = psycopg2.connect(
user="postgres",
password="difyai123456",
host="172.20.0.145",
port=5432,
database="dify"
)
except (Exception, psycopg2.Error) as error:
print("Error while connecting to PostgreSQL", error)
def close_connection(self):
"""
关闭当前的 PostgreSQL 数据库连接。
如果存在活动的连接,则关闭它并打印确认信息。
"""
if self.connection:
self.connection.close()
print("PostgreSQL connection is closed")
def get_appinfo(self, appid:str)->dict | None:
"""
根据应用 ID 从 'apps' 表中获取应用信息。
Args:
appid: 目标应用的 ID。
Returns:
一个字典,其中键是列名,值是对应的应用数据。
如果未找到应用或发生错误,则返回 None。
"""
try:
with self.connection.cursor() as cursor:
cursor.execute(
"""
SELECT * FROM apps WHERE id = %s
""",
(appid,)
)
result = cursor.fetchone()
if result:
colnames = [desc[0] for desc in cursor.description]
return dict(zip(colnames, result))
return None
except (Exception, psycopg2.Error) as error:
print("Error while getting tenant_id by appid", error)
def get_messages_info(self, appid:str, query:str)->dict | None:
"""
根据应用 ID 和查询内容从 'messages' 表中获取消息信息。
Args:
appid: 目标应用的 ID。
query: 用户查询的具体内容。
Returns:
一个字典,其中键是列名,值是对应的消息数据。
如果未找到消息或发生错误,则返回 None。
"""
try:
with self.connection.cursor() as cursor:
cursor.execute(
"""
SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC
""",
(appid, query)
)
result = cursor.fetchone()
if result:
colnames = [desc[0] for desc in cursor.description]
return dict(zip(colnames, result))
return None
except (Exception, psycopg2.Error) as error:
print("Error while getting messages_info", error)
def get_messages_info_by_id(self, message_id:str)->dict | None:
"""
根据消息 ID 从 'messages' 表中获取消息信息。
"""
try:
with self.connection.cursor() as cursor:
cursor.execute(
"""
SELECT * FROM messages WHERE id = %s
""",
(message_id, )
)
result = cursor.fetchone()
if result:
colnames = [desc[0] for desc in cursor.description]
return dict(zip(colnames, result))
return None
except (Exception, psycopg2.Error) as error:
print("Error while getting messages_info", error)
def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None:
"""
根据工作流运行 ID 从 'workflow_node_executions' 表中获取节点执行信息。
Args:
workflow_run_id: 目标工作流运行的 ID。
Returns:
一个字典,其中键是列名,值是对应的节点执行数据。
如果未找到执行信息或发生错误,则返回 None。
"""
try:
with self.connection.cursor() as cursor:
cursor.execute(
"""
SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s
""",
(workflow_run_id,)
)
result = cursor.fetchall()
if result:
colnames = [desc[0] for desc in cursor.description]
return [dict(zip(colnames, row)) for row in result]
return None
except (Exception, psycopg2.Error) as error:
print("Error while getting workflow_node_executions_info", error)
class DifyTool:
"""
提供用于获取 Dify 应用调试信息的工具类。
该类利用 PgSql 类从数据库中检索与特定应用和查询相关的
应用信息、消息详情以及工作流节点执行情况。
"""
@staticmethod
def get_message_debug_info_id(message_id:str)->dict | None:
"""
根据消息 ID 从 'messages' 表中获取消息信息。
"""
dify_pgsql = PgSql()
messages_info = dify_pgsql.get_messages_info_by_id(message_id)
if not messages_info:
return None
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
if not workflow_node_executions_info:
return None
return {
"messages_info": messages_info,
"workflow_node_executions_info": workflow_node_executions_info
}
@staticmethod
def get_message_debug_info(appid:str, query:str)->dict:
"""
获取指定应用和查询相关的调试信息。
此静态方法会创建一个临时的 PgSql 实例来查询数据库,
然后聚合应用信息、消息信息和工作流节点执行信息。
Args:
appid: 目标应用的 ID。
query: 用户查询的具体内容。
Returns:
一个包含 "appinfo", "messages_info", 和
"workflow_node_executions_info"键的字典,分别对应
查询到的应用数据、消息数据和节点执行数据。
"""
dify_pgsql = PgSql()
appinfo = dify_pgsql.get_appinfo(appid)
if not appinfo:
return None
messages_info = dify_pgsql.get_messages_info(appid, query)
if not messages_info:
return None
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
if not workflow_node_executions_info:
return None
return {
"appinfo": appinfo,
"messages_info": messages_info,
"workflow_node_executions_info": workflow_node_executions_info
}
if __name__ == "__main__":
print(DifyTool.get_message_debug_info("ccf92b97-2789-4a3f-90e0-135a869a37c5", "电力建设计价通软件,导入结算后没有暂列金怎么办?要手动添加么?"))
+54
View File
@@ -0,0 +1,54 @@
from flask import Flask, request, Response
import os
from dotenv import load_dotenv
from rag2_0.intent_recognition import IntentRecognizer
import json
import time
# 加载环境变量
load_dotenv()
app = Flask(__name__)
# 初始化意图识别器
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
@app.route('/intent_recognize', methods=['POST'])
def intent_recognize():
try:
data = request.get_json(force=True)
query = data.get('query')
if not query:
return Response(json.dumps({"error": "缺少query参数"}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=400)
start_time = time.time()
classification, keywords, rewrite, query_keys = recognizer.process_query(query)
end_time = time.time()
print(f"意图识别耗时: {end_time - start_time:.2f}")
# keywords对象转为字符串
keywords_str = ""
if keywords and keywords.terms:
term_details = []
for term in keywords.terms:
term_info = {
"名称": term.name,
"同义词": ";".join(term.synonymous) if term.synonymous else "",
"描述": term.description
}
term_details.append(term_info)
keywords_str = term_details
result = {
"source_query": query,
"source_query_keys": query_keys,
"vertical_classification": classification.vertical_classification,
"sub_classification": classification.sub_classification,
"rewrite_query": rewrite.rewrite,
"keywords": keywords_str
}
return Response(json.dumps(result, ensure_ascii=False), content_type='application/json; charset=utf-8')
except Exception as e:
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500)
if __name__ == "__main__":
app.run(host="0.0.0.0", port=8001)
+136
View File
@@ -0,0 +1,136 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from rag2_0.dify.dify_client import ChatClient, DifyClient
import pandas as pd
# 使用线程池并发执行
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from rag2_0.dify.dify_tool import DifyTool
import json
class DifyComparisonTester:
"""
Dify新旧流程对比测试类,用于比较两个不同流程的问答效果
"""
def __init__(self, excel_path:str, baseurl:str, old_workflow_api_key:str, new_workflow_api_key:str):
"""
初始化对比测试器
Args:
excel_path: 包含问题的Excel文件路径
baseurl: Dify API的基础URL
old_workflow_api_key: 旧流程的API密钥
new_workflow_api_key: 新流程的API密钥
"""
self.excel_path = excel_path
self.baseurl = baseurl
self.old_workflow_api_key = old_workflow_api_key
self.new_workflow_api_key = new_workflow_api_key
self.old_chat = ChatClient(api_key=old_workflow_api_key, base_url=baseurl)
self.new_chat = ChatClient(api_key=new_workflow_api_key, base_url=baseurl)
def process_question(self, q:str):
"""
处理单个问题,并行获取新旧流程的回答
Args:
q: 问题内容
Returns:
dict: 包含问题和两个流程回答的字典
"""
q="qwqwwq"
def get_old_answer():
try:
return self.old_chat.create_chat_message(inputs={}, query=q, user="AutoTestDifyChat").json()
except Exception as e:
return f"error: {str(e)}"
def get_new_answer():
try:
return self.new_chat.create_chat_message(inputs={}, query=q, user="AutoTestDifyChat").json()
except Exception as e:
return f"error: {str(e)}"
# 并行执行old_chat和new_chat
with ThreadPoolExecutor(max_workers=2) as executor:
future_old = executor.submit(get_old_answer)
future_new = executor.submit(get_new_answer)
old_result = future_old.result()
new_result = future_new.result()
old_message_id = old_result["message_id"]
new_message_id = new_result["message_id"]
old_message_info = DifyTool.get_message_debug_info_id(message_id=old_message_id)
new_message_info = DifyTool.get_message_debug_info_id(message_id=new_message_id)
for workflow_node in new_message_info["workflow_node_executions_info"]:
if workflow_node["title"] == "问题优化结果解析":
outputs = json.loads(workflow_node["outputs"])
rewrite_query = outputs["optimize_query"]
old_answer = old_result["answer"]
new_answer = new_result["answer"]
return {"问题": q, "问题改写": rewrite_query, "旧流程答案": old_answer, "新流程答案": new_answer}
def run_comparison(self):
"""
运行对比测试,处理所有问题并生成结果Excel
Returns:
str: 输出Excel文件的路径
"""
# 读取Excel文件中的问题
df = pd.read_excel(self.excel_path)
questions = df.iloc[:,0].tolist()
results = []
# 按顺序处理问题
with tqdm(total=len(questions), desc="处理问题进度") as pbar:
for q in questions:
result = self.process_question(q)
results.append(result)
pbar.update(1)
# 生成输出Excel文件
out_path = os.path.join(os.path.dirname(self.excel_path), "dify问答_对比结果.xlsx")
df_results = pd.DataFrame(results)
# 使用ExcelWriter设置格式
with pd.ExcelWriter(out_path, engine='xlsxwriter') as writer:
df_results.to_excel(writer, index=False, sheet_name='Sheet1')
# 获取工作簿和工作表对象
workbook = writer.book
worksheet = writer.sheets['Sheet1']
# 设置列宽
worksheet.set_column('A:A', 50) # 问题列宽 50个Excel单位
worksheet.set_column('B:B', 70) # 旧流程答案列宽 70个Excel单位
worksheet.set_column('C:C', 70) # 新流程答案列宽 70个Excel单位
return out_path
if __name__ == "__main__":
# 定义Excel路径
excel_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", ".." ,"data/excel/历史提问数据(dislike)_1000条_软件明确.xlsx")
if not os.path.exists(excel_path):
print(f"错误:Excel文件不存在: {excel_path}")
exit(1)
# Dify API配置
baseurl = "http://172.20.0.145/v1"
old_workflow_api_key = "app-wUdkWJx5zeOvmvBUZizMoSw3"
new_workflow_api_key = "app-Lf1pQ1NVwdMfCRVNTBCOTPHT"
# 创建测试器并运行
tester = DifyComparisonTester(excel_path, baseurl, old_workflow_api_key, new_workflow_api_key)
output_file = tester.run_comparison()
print(f"对比结果已保存至: {output_file}")
# 单个问题测试示例
# c = DifyChat(baseurl="http://172.20.0.145/v1", api_key="app-LjJaeLoAfqa6aoGzqU9UvxSf")
# c.chat("如何新建配电线路工程")
+36
View File
@@ -0,0 +1,36 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: DataModels.py
Author: oyyz
Date: 2025-05-13
Description: 提取和分类的数据模型
"""
from pydantic import BaseModel, Field
from typing import List, Optional
# 定义输出模型
class Term(BaseModel):
name: str = Field(description="专业名词")
synonymous: List[str] = Field(description="同义词列表")
description: str = Field(description="描述信息", default="")
def __hash__(self):
return hash(self.name)
def __eq__(self, other):
if isinstance(other, Term):
return self.name == other.name
return False
class TermList(BaseModel):
terms: List[Term] = Field(description="专业名词列表")
class Classification(BaseModel):
vertical_classification:str = Field(description="垂直领域一级分类")
sub_classification:str = Field(description="一级分类下的二级分类")
class QueryRewrite(BaseModel):
rewrite:str = Field(description="问题改写")
@@ -0,0 +1,289 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: IntentRecognition.py
Author: oyyz
Date: 2025-05-13
Description: 意图分类、改写核心逻辑
"""
import os
from langchain_openai import ChatOpenAI
from langchain.output_parsers import PydanticOutputParser
import json
from typing import List, Tuple
import re
from .PromptTemplates import classification_prompt, query_rewrite_prompt, extract_nouns_prompt, classification_info
from .DataModels import Classification, QueryRewrite, Term, TermList
from .ProfessionalNounVector import ProfessionalNounRetriever
from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM
class IntentRecognizer:
"""
意图识别和问题改写类
"""
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo", vector_index_dir: str = None):
"""
初始化意图识别器
Args:
api_key: OpenAI API密钥,如果为None则从环境变量获取
base_url: OpenAI API基础URL,如果为None则使用默认URL
model_name: 要使用的模型名称
vector_index_dir: 向量索引目录,如果为None则使用默认目录
"""
# 初始化LLM
llm_params = {
"temperature": 0.2, # 降低随机性,使结果更确定
"model": model_name
}
# 如果提供了API密钥,则使用提供的密钥
if api_key:
llm_params["api_key"] = api_key
# 如果提供了自定义URL,则使用提供的URL
if base_url:
llm_params["base_url"] = base_url
self.llm = OpenAiLLM(**llm_params)
# 准备分类解析器
self.classification_parser = PydanticOutputParser(pydantic_object=Classification)
# 准备问题改写解析器
self.query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
# 准备术语列表解析器
self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
# 加载suffix关键词
self.suffix_keywords = self._load_suffix_keywords()
# 初始化向量检索器
self.noun_retriever = ProfessionalNounRetriever(api_key=api_key, index_dir=vector_index_dir)
def _load_suffix_keywords(self, filepath: str = None) -> List[str]:
"""
加载后缀关键词列表
Args:
filepath: 后缀关键词文件路径,默认为None使用默认路径
Returns:
后缀关键词列表
"""
try:
# 如果未指定路径,使用默认路径
if filepath is None:
current_dir = os.path.dirname(os.path.abspath(__file__))
filepath = os.path.join(current_dir, "..", "..", "data", "nouns", "suffix_keywords.json")
# 读取JSON文件
with open(filepath, "r", encoding="utf-8") as f:
suffix_data = json.load(f)
# 添加额外的固定后缀
return suffix_data
except Exception as e:
raise RuntimeError(f"加载后缀关键词失败: {e}") from e
def classify_intent(self, query: str, keywords: TermList) -> Classification:
"""
对用户输入进行意图分类
Args:
content: 用户输入内容
keywords: 匹配到的关键词列表
rewrite: 重写的问题
Returns:
分类结果
"""
formatted_prompt = classification_prompt.replace("{user_input}", query)
formatted_prompt = formatted_prompt.replace("{classification_info}", classification_info)
formatted_prompt = formatted_prompt.replace("{output_format}", self.classification_parser.get_format_instructions())
# 将关键词列表转换为JSON字符串
terms_dict = [term.model_dump() for term in keywords.terms]
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
formatted_prompt = formatted_prompt.replace("{keywords}", keywords_str)
# 调用LLM
response = self.llm.invoke(formatted_prompt, False)
# 解析输出
try:
# 尝试直接解析JSON响应
parsed_output = self.classification_parser.parse(response.content.strip())
return parsed_output
except Exception as e:
raise RuntimeError(f"解析分类结果时出错: {e}") from e
def extract_keywords_with_llm(self, query: str) -> List[Term]:
"""
使用LLM从用户查询中提取专业关键词
Args:
query: 用户查询
Returns:
提取的术语列表
"""
# 准备提示词
formatted_prompt = extract_nouns_prompt.replace("{content}", query)
formatted_prompt = formatted_prompt.replace("{output_format}", self.terms_list_parser.get_format_instructions())
# 调用LLM
response = self.llm.invoke(formatted_prompt, False)
try:
# 尝试使用Pydantic解析器解析TermList
parsed_output = self.terms_list_parser.parse(response.content)
return parsed_output.terms
except Exception as e:
raise RuntimeError(f"无法解析LLM关键词提取响应: {e}") from e
def match_keywords(self, query: str) -> Tuple[TermList, List[str]]:
"""
从用户问题中匹配关键词,结合LLM提取和向量检索
Args:
query: 用户问题
Returns:
匹配到的关键词列表
"""
matched_terms = set() # 存储匹配到的Term对象
query_keys=[]
# 步骤2: 使用LLM提取查询中的关键词
try:
extracted_terms = self.extract_keywords_with_llm(query)
for term in extracted_terms:
query_keys.append(term.name)
except Exception as e:
raise RuntimeError(f"LLM关键词提取失败: {e}") from e
# 步骤3: 使用向量检索找到相似的专业名词
try:
# 对matched_terms中的每个关键字进行向量检索
for current_key in query_keys:
vector_results = self.noun_retriever.query(current_key, top_k=3, use_intersection=True)
# 添加向量检索结果
for result in vector_results:
term = Term(
name=result.get('name'),
synonymous=result.get('synonymous', []),
description=result.get('description', '')
)
matched_terms.add(term)
except Exception as e:
raise RuntimeError(f"向量检索关键词时出错: {e}") from e
if len(matched_terms) != 0:
txts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) + "|" + "描述:" + term.description for term in matched_terms]
# txts = [term.name for term in matched_terms]
xinference_reranker = XinferenceReRankerModel()
rerank_results = xinference_reranker.rerank(query, txts, top_k=5)
matched_terms_list = list(matched_terms)
matched_terms = [matched_terms_list[result["index"]] for result in rerank_results]
# 提取所有Term对象的名称并排序
# 将set类型的matched_terms转换为TermList类型
term_list = TermList(terms=list(matched_terms))
return term_list, query_keys
def rewrite_query(self, query: str, keywords: TermList) -> QueryRewrite:
"""
对用户问题进行改写
Args:
query: 用户原始问题
keywords: 匹配到的关键词列表
Returns:
改写结果
"""
# 准备问题改写提示
terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms]
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
formatted_prompt = query_rewrite_prompt.format(query=query, output_format=self.query_rewrite_parser.get_format_instructions(),keywords=keywords_str)
# 调用LLM
response = self.llm.invoke(formatted_prompt, False)
# 解析输出
try:
# 尝试直接解析JSON响应
parsed_output = self.query_rewrite_parser.parse(response.content)
return parsed_output
except Exception as e:
raise RuntimeError(f"解析问题改写结果时出错: {e}") from e
def judge_define_suffix(self, input_str: str) -> Tuple[bool, List[str]]:
"""
判断输入字符串是否包含定义的后缀,并返回所有匹配到的后缀名列表
Args:
input_str: 输入字符串
Returns:
Tuple[bool, List[str]]: (是否包含定义的后缀, 匹配到的后缀名列表)
"""
# 构建正则表达式模式,匹配大小写不敏感且前面可能带有.
pattern = r'(?:\.?)(' + '|'.join(re.escape(field.get('name')) for field in self.suffix_keywords) + r')'
# 使用 re.IGNORECASE 标志来忽略大小写,findall找到所有匹配
matches = re.finditer(pattern, input_str, re.IGNORECASE)
matched_suffixes = [match.group(1) for match in matches]
return bool(matched_suffixes), matched_suffixes
def process_query(self, query: str) -> Tuple[Classification, TermList, QueryRewrite, List[str]]:
"""
处理用户问题的完整流程
Args:
query: 用户原始问题
Returns:
(意图分类结果, 匹配的关键词列表, 问题改写结果)的元组
"""
# 是否是扩展名
# is_suffix, matched_suffixes = self.judge_define_suffix(query)
# if is_suffix:
# # 将所有匹配到的后缀名作为Term添加到结果中
# suffix_terms = []
# for suffix in matched_suffixes:
# term_dict = next((item for item in self.suffix_keywords if item['name'].lower() == suffix.lower()), None)
# if term_dict:
# suffix_term = Term(
# name=term_dict.get('name'),
# synonymous=term_dict.get('synonymous', []),
# description=json.dumps(term_dict.get('description', ''), ensure_ascii=False)
# )
# suffix_terms.append(suffix_term)
# return Classification(vertical_classification="安装下载", sub_classification="查询"), TermList(terms=suffix_terms), QueryRewrite(rewrite=query), matched_suffixes
# 步骤1: 匹配关键词
keywords_terms, query_keys = self.match_keywords(query)
# 步骤2: 问题改写
rewrite = self.rewrite_query(
query=query,
keywords=keywords_terms
)
# 步骤3: 进行意图分类
classification = self.classify_intent(query, keywords_terms)
if classification.vertical_classification == "其他" or classification.sub_classification == "其他":
return classification, TermList(terms=[]), QueryRewrite(rewrite=query), []
if classification.vertical_classification == "闲聊" or classification.sub_classification == "闲聊":
return classification, TermList(terms=[]), QueryRewrite(rewrite=query),[]
# rewrite = QueryRewrite(rewrite=query)
return classification, keywords_terms, rewrite, query_keys
@@ -0,0 +1,321 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: ProfessionalNounVector.py
Date: 2025-05-15
Author: oyyz
Description: 专业名词向量化和检索的核心逻辑
"""
import os
import json
import shutil
from typing import List, Dict, Any, Tuple, Optional
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import FAISS
from rag2_0.tool.ModelTool import SiliconFlowEmbeddings
import logging
def get_embedding_model(api_key: str = None) -> Embeddings:
"""
获取嵌入模型
Args:
api_key: API密钥,如果为None则从环境变量获取
Returns:
嵌入模型实例
"""
if not api_key:
api_key = os.getenv("SILICONFLOW_API_KEY", "sk-ftnofbucchwnscojohyxwmfzgaykdxihafnlphohsinftkbr")
return SiliconFlowEmbeddings(api_key=api_key)
class ProfessionalNounVectorizer:
"""专业名词向量化和保存类"""
def __init__(self,
embedding_model: Optional[Embeddings] = None,
api_key: str = None,
output_dir: str = None):
"""
初始化向量化器
Args:
embedding_model: 嵌入模型,如果为None则使用默认模型
api_key: SiliconFlow API密钥,仅在embedding_model为None时使用
output_dir: 索引输出目录,默认为None使用默认路径
"""
# 设置嵌入模型
if embedding_model:
self.embedding_model = embedding_model
else:
self.embedding_model = get_embedding_model(api_key)
# 设置输出目录
self.output_dir = output_dir
if self.output_dir is None:
current_dir = os.path.dirname(os.path.abspath(__file__))
self.output_dir = os.path.join(current_dir, "..", "..", "data", "nouns")
# 设置索引路径
self.index_path = os.path.join(self.output_dir, "professional_nouns_index")
def _loadfile(self, file_paths: List[str]) -> List[Dict[str, Any]]:
"""
加载多个专业术语JSON文件并合并
Args:
file_paths: JSON文件路径列表
Returns:
合并后的术语列表
"""
merged_terms = []
try:
for file_path in file_paths:
if not os.path.exists(file_path):
logging.warning(f"文件不存在: {file_path}")
continue
with open(file_path, "r", encoding="utf-8") as f:
terms_data = json.load(f)
if isinstance(terms_data, list):
merged_terms.extend(terms_data)
logging.info(f"{file_path} 加载了 {len(terms_data)} 条专业名词")
else:
logging.warning(f"文件格式错误: {file_path},应为JSON数组")
logging.info(f"总共加载了 {len(merged_terms)} 条专业名词")
return merged_terms
except Exception as e:
logging.error(f"加载多个文件失败: {e}")
return []
def vectorize_files_and_save(self, file_paths: List[str]) -> bool:
"""
处理多个文件:加载多个术语文件、创建索引并保存
Args:
file_paths: JSON文件路径列表
Returns:
处理成功返回True,否则返回False
"""
try:
# 加载多个文件的术语
terms = self._loadfile(file_paths)
if not terms:
logging.warning("未找到术语数据,退出处理")
return False
# 根据名称去重
unique_terms = {}
for term in terms:
name = term.get("name", "")
if name and name not in unique_terms:
unique_terms[name] = term
# 转换回列表
deduplicated_terms = list(unique_terms.values())
logging.info(f"去重后剩余 {len(deduplicated_terms)} 条专业名词")
# 准备数据
texts, metadatas = self._prepare_terms_for_faiss(deduplicated_terms)
# 创建索引
faiss_index = self._create_index(texts, metadatas)
# 保存索引
self._save_index(faiss_index)
logging.info("完成多文件专业名词向量化和索引创建")
return True
except Exception as e:
logging.error(f"多文件向量化处理失败: {e}")
return False
def _prepare_terms_for_faiss(self, terms: List[Dict[str, Any]]) -> Tuple[List[str], List[Dict]]:
"""
将术语准备为FAISS可用的格式 (内部方法)
Args:
terms: 术语列表
Returns:
格式化的术语文本列表和元数据列表
"""
texts = []
metadatas = []
for term in terms:
name = term["name"]
texts.append(name.strip())
synonyms = term.get("synonymous", [])
description = term.get("description", "")
# 记录元数据
metadatas.append({
"name": name,
"synonyms": synonyms,
"description": description
})
if len(synonyms) > 0:
synonyms_str = ', '.join(synonyms)
texts.append(synonyms_str.strip())
metadatas.append({
"name": name,
"synonyms": synonyms,
"description": description
})
if len(description) > 0:
texts.append(description.strip())
metadatas.append({
"name": name,
"synonyms": synonyms,
"description": description
})
return texts, metadatas
def _create_index(self, texts: List[str], metadatas: List[Dict]) -> FAISS:
"""
创建FAISS索引 (内部方法)
Args:
texts: 文本列表
metadatas: 元数据列表
Returns:
FAISS索引
"""
logging.info(f"正在创建FAISS索引,共 {len(texts)} 条数据...")
return FAISS.from_texts(texts=texts, embedding=self.embedding_model, metadatas=metadatas)
def _save_index(self, faiss_index: FAISS) -> None:
"""
保存FAISS索引到本地 (内部方法)
Args:
faiss_index: 要保存的FAISS索引
"""
try:
# 确保输出目录存在
os.makedirs(self.output_dir, exist_ok=True)
# 如果索引目录已存在,先删除
if os.path.exists(self.index_path):
shutil.rmtree(self.index_path)
# 保存FAISS索引
faiss_index.save_local(folder_path=self.index_path)
logging.info(f"FAISS索引已保存至 {self.index_path}")
except Exception as e:
logging.error(f"保存FAISS索引失败: {e}")
raise e
class ProfessionalNounRetriever:
"""专业名词检索类"""
def __init__(self,
embedding_model: Optional[Embeddings] = None,
api_key: str = None,
index_dir: str = None):
"""
初始化检索器并加载索引
Args:
embedding_model: 嵌入模型,如果为None则使用默认模型
api_key: SiliconFlow API密钥,仅在embedding_model为None时使用
index_dir: 索引目录路径,默认为None使用默认路径
"""
# 设置嵌入模型
if embedding_model:
self.embedding_model = embedding_model
else:
self.embedding_model = get_embedding_model(api_key)
# 设置索引路径
self.index_dir = index_dir
if self.index_dir is None:
current_dir = os.path.dirname(os.path.abspath(__file__))
self.index_dir = os.path.join(current_dir, "..", "..", "data", "nouns", "professional_nouns_index")
# 在构造函数中加载索引
self.faiss_index = None
self._load_index()
def _load_index(self) -> None:
"""
从本地加载FAISS索引 (内部方法)
"""
try:
# 加载FAISS索引,启用不安全反序列化(仅用于可信数据源)
self.faiss_index = FAISS.load_local(
folder_path=self.index_dir,
embeddings=self.embedding_model,
allow_dangerous_deserialization=True
)
logging.info(f"成功从 {self.index_dir} 加载FAISS索引")
except Exception as e:
logging.warning(f"加载FAISS索引失败: {e}")
self.faiss_index = None
def query(self, query_text: str, top_k: int = 5, use_intersection: bool = True) -> List[Dict]:
"""
查询FAISS索引,获取最相似的专业名词 (唯一对外接口)
Args:
query_text: 查询文本
top_k: 返回的结果数量,默认为5
use_intersection: 是否使用三种检索方式的交集,默认为True
Returns:
相似度最高的专业名词列表
"""
try:
# 检查索引是否已加载
if self.faiss_index is None:
logging.warning("FAISS索引未加载,无法执行查询")
return []
# 使用三种检索方式并取交集
retriever1 = self.faiss_index.as_retriever(search_kwargs={"k": top_k})
retriever2 = self.faiss_index.as_retriever(
search_type="mmr",
search_kwargs={"k": top_k, "fetch_k": 3, "lambda_mult": 0.5}
)
retriever3 = self.faiss_index.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"score_threshold": 0.5}
)
# 用json.dumps将dict转为字符串,便于取交集
set1 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False)
for i in retriever1.invoke(query_text))
set2 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False)
for i in retriever2.invoke(query_text))
set3 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False)
for i in retriever3.invoke(query_text))
intersection = set1 | set2 | set3
# 如果交集为空,使用第一种检索方式的结果
if not intersection:
logging.warning("三种检索方式无交集,使用普通检索结果")
return [json.loads(item) for item in set1]
# 转回dict
return [json.loads(item) for item in intersection]
except Exception as e:
logging.error(f"查询FAISS索引失败: {e}")
return []
@@ -0,0 +1,130 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: PromptTemplates.py
Author: oyyz
Date: 2025-05-13
Description: 提示词模板
"""
extract_nouns_prompt="""
【智能关键词提取助手】
请根据用户问题自动识别核心关键词,并按照以下规则输出:
1. 只输出最终关键词列表,不要解释说明
2. 关键词提取范围包括但不限于以下内容:
- 软件相关:功能模块/操作步骤/报错提示/扩展名后缀名
- 造价专业:费用类型/计算标准/行业规范
- 电力工程:项目类型/设备型号/工程阶段
3. 自动展开缩写(如将'导excel'转为'Excel导入'
4. 严格基于用户问题提取关键词,不要输出与用户问题无关的关键词
三、输出格式:
{output_format}
四、用户问题:
{content}
"""
classification_info="""【垂直领域分类】:
1. 软件问题 -- 指涉及软件使用、功能询问、软件故障排查等方面的提问或请求。
2. 业务问题 -- 指涉及电力造价领域专业知识、造价费用计算等电力造价业务知识
3. 安装下载注册 -- 指涉及软件(或插件)安装下载、注册、激活等操作类问题。
4. 其他 -- 指与软件或电力造价专业无关的日常对话、问候、感慨、情绪表达等。
【软件问题包括以下两类】:
1. 软件功能:询问软件功能的使用、操作、位置等
2. 故障排查:软件运行异常、软件报错、软件显示错误等
【业务问题包括以下两类】:
1. 专业咨询:涉及电力造价规范、工程计价规则问题、行业标准解读等
2. 数据问题:涉及电力造价费用、造价指标等
【安装下载注册包括以下三类】:
1. 后缀名查询:询问有关软件后缀名、工程文件扩展名等问题,例如:BDY3是什么文件?、用什么软件打开.BDY3文件?
2. 软件锁类:询问软件锁信息、锁注册号查询、许可证查询、锁激活问题等软件锁相关问题
3. 安装下载类:安装下载咨询、组件(插件)选择、环境配置等
4. 问题排查类:软件安装下载失败、报错,系统兼容性问题等
【其他】:
1. 其他"""
classification_prompt="""
用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容,将其归类为以下垂直领域之一:
{classification_info}
【用户输入】:
{user_input}
【输出格式要求】:
{output_format}
【示例】
用户输入1: 技改T1怎样新建工程
输出1:
{
"vertical_classification":"软件咨询",
"sub_classification":"软件功能"
}
"""
query_rewrite_prompt = """
你是一名电力造价专业问答优化工程师,负责通过多维度信息整合重构用户问题以提升知识库检索准确率。请严格遵循以下流程处理:
# 任务处理框架
## 第一阶段:输入分析
1. 解析基础信息
- 原始问题(需保留核心语义){query}
- 关键词集合:{keywords}
## 第二阶段:语义匹配验证
2. 执行关键词校验
- 建立意图关联矩阵,验证关键词与原始问题的语义一致性
- 若存在≥1个有效关联词 → 进入重构流程
- 若无有效关联 → 直接输出原始问题
## 第三阶段:专业重构
3. 术语规范化处理
a. 实施术语映射:将口语表达替换为知识库标准术语
b. 执行结构优化:
- 采用【术语标记】规范标注关键概念
- 构建主谓宾明确的问题句式
- 保持原问题时态与语态特征
# 输出规范
{output_format}
# 示范案例库
▶ 案例1(有效匹配)
输入:
原始问题:怎么把旧版西藏定额工程转到Z1新版
关键词:【'老版本定额升级', '批量设置定额', '西藏造价软件Z1'
输出:
{{"rewrite":"【西藏造价软件Z1】如何执行【老版本定额升级】操作?"}}
▶ 案例2(无效匹配)
输入:
原始问题:程序界面文字显示过小如何处理?
关键词:【'定额升级', '工程批量导入'
输出:
{{"rewrite":"程序界面文字显示过小如何处理?"}}
# 质量约束条款
1. 语义内容保真原则
- 禁止修改原问题核心诉求(如转换主语/变更操作对象)
- 保留原始问题的限定条件
2. 术语使用规范
- 仅使用检索返回的关键词进行术语替换
- 新增术语必须来自关键词集合
3. 结构优化标准
- 问题长度控制在20字内
- 必须包含≥1个【标注术语】
- 禁止添加解释性语句
4. 异常处理机制
- 当关键词与问题无明显关联时,触发直通输出规则
- 出现术语冲突时优先保留原始表述
"""
+5
View File
@@ -0,0 +1,5 @@
#!/usr/bin/env python
from .ProfessionalNounVector import ProfessionalNounVectorizer, ProfessionalNounRetriever
from .IntentRecognition import IntentRecognizer
from .DataModels import Term, TermList, Classification, QueryRewrite
+256
View File
@@ -0,0 +1,256 @@
import os
import random
import time
from typing import List, Optional, Dict
from threading import Lock
API_KEY_LIST=[
"sk-xxaiabmfhzwwpijuledllkmkzhzwsqeicjxmjwnvriqpwmpk",
"sk-lldcprpqjhgdimiwewgbthngfbrazhkiuioubmaatrcpjjum",
"sk-bppugibbtvujomvoysnbcdzpcwndxtwrkfvmgbkbzcmobdon",
"sk-hnqitgdlfrrnpimcfxigqibstqquintnzpiidsshpajjyxqd",
"sk-hrojkkkrrkmsajtnizokbcgexsfggdiqavbtvbayuwqbnmom",
"sk-kkdklmnyompoiotzkfqahpayzlkgogfudjkyaebehtsowvid",
"sk-sfxzvllifafbyfduupcdtcrjwhdyiyojnksyopnfslurnhsp",
"sk-faqirxiszukfswqvzqawxnemqfacrkyurbxxkzwbbujqacdp",
"sk-vonaanuueqiczppkntjuphateshrcpqpnvxmwxorkyihjmrb",
"sk-qfpeoodgupcukcdstjcxgegwxnuhtxkkrupkogkcvhavxgny",
"sk-fsvjnbpfgoadixympaabaukupuhjvbturcbxaqfdzjznemtr",
"sk-fltvnbiqntfawjwkfnnhmyfiimzgzxkweqmefcfqkbucwrhi",
"sk-oosswdriwyqkglwdigvcxgmcpyplcyowicbaugpizoscevdl",
"sk-jswtxhkiralnyiukqimtyuurcaepulxdrfijadtxzrgsajyc",
"sk-dcjuhoukdyrbneadtxtnyxzmigkpiqgtqqnreiprxpioftsv",
"sk-yrhezyuxjblpaxzzudbowqmvcoxcammupcubghbodolikbdk",
"sk-dsgvwpfagmarilmnewwbzhfzlqehburoupjaopucdvybpbdo",
"sk-oljjlspuaurtoczyekztiidwtoerugadgepiufclpmrbdfqc",
"sk-crgrimubjesthvxuqwedqqdoetljyrgeahxxpctfefgnkpyo",
"sk-tubqhwgycxrdhwsqzjopxgeaqpsjdfppckckayvzornaluwq",
"sk-amcxlmsdnadptpnehqnkvseolacipztmvovnmxojzohbjjil",
"sk-pdyymhshpzmdduwxsezthnrgarnnhgzvmiflbpisfzxkiayt",
"sk-qhwoorywmejumyudfxbrkegxtqifsbgcdkmpjckezepgyqnz",
"sk-cpoctrgcnstaybeyuieuwjdgeakudhqdnnwdjavjudcbvvem",
]
class APIKeyManager:
"""
API密钥管理器,用于解析环境变量中的多个API密钥并提供获取接口
支持密钥轮转使用
"""
# 类变量,用于保存单例实例
_instance = None
_lock = Lock()
# 密钥使用计数和上次使用时间
_key_usage: Dict[str, Dict] = {}
# 当前正在使用的密钥索引
_current_index = 0
@classmethod
def get_instance(cls, env_var_name: str = "OPENAI_API_KEY", separator: str = ";"):
"""
获取单例实例
Args:
env_var_name: 环境变量名称,默认为'OPENAI_API_KEY'
separator: 密钥分隔符,默认为分号
Returns:
APIKeyManager实例
"""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = cls(env_var_name, separator)
return cls._instance
@classmethod
def get_api_key(cls) -> Optional[str]:
"""
静态方法:获取一个API密钥,使用轮转策略
Returns:
API密钥,如果没有可用的密钥则返回None
"""
instance = cls.get_instance()
return instance._get_next_api_key()
@classmethod
def get_random_api_key(cls) -> Optional[str]:
"""
静态方法:随机获取一个API密钥
Returns:
API密钥,如果没有可用的密钥则返回None
"""
instance = cls.get_instance()
return instance._get_random_api_key()
@classmethod
def get_valid_api_keys(cls) -> List[str]:
"""
静态方法:获取有效的API密钥列表
Returns:
"""
# 验证每一个apikey是否有效,无效则删除并打印日志。地址https://api.siliconflow.cn/v1/
import requests
import logging
valid_api_keys = []
url = "https://api.siliconflow.cn/v1/chat/completions"
headers_template = {
"Content-Type": "application/json"
}
data = {
"model": "deepseek-ai/DeepSeek-V3",
"messages": [
{"role": "user", "content": "ping"}
],
"max_tokens": 1
}
for key in API_KEY_LIST:
headers = headers_template.copy()
headers["Authorization"] = f"Bearer {key}"
try:
resp = requests.post(url, headers=headers, json=data, timeout=8)
if resp.status_code == 200:
valid_api_keys.append(key)
else:
logging.warning(f"API密钥无效(被移除): {key}, 状态码: {resp.status_code}, 响应: {resp.text}")
except Exception as e:
logging.warning(f"API密钥验证异常(被移除): {key}, 错误: {e}")
return valid_api_keys
@classmethod
def count(cls) -> int:
"""
静态方法:获取API密钥数量
Returns:
API密钥数量
"""
instance = cls.get_instance()
return len(instance.api_keys)
def __init__(self, env_var_name: str = "OPENAI_API_KEY", separator: str = ";"):
"""
初始化API密钥管理器
Args:
env_var_name: 环境变量名称,默认为'OPENAI_API_KEY'
separator: 密钥分隔符,默认为分号
"""
self.env_var_name = env_var_name
self.separator = separator
self.api_keys = self._load_api_keys()
# 初始化密钥使用统计
for key in self.api_keys:
if key not in self._key_usage:
self._key_usage[key] = {
"count": 0,
"last_used": 0
}
def _load_api_keys(self) -> List[str]:
"""
从环境变量加载API密钥
Returns:
API密钥列表
"""
# api_keys = []
# env_value = os.environ.get(self.env_var_name)
# if env_value:
# # 分割环境变量并移除空白字符
# keys = [key.strip() for key in env_value.split(self.separator)]
# # 过滤掉空字符串
# api_keys = [key for key in keys if key]
# return api_keys
return API_KEY_LIST
def _get_next_api_key(self) -> Optional[str]:
"""
获取下一个API密钥,使用轮转策略
Returns:
API密钥,如果没有可用的密钥则返回None
"""
if not self.api_keys:
return None
with self._lock:
# 轮转到下一个密钥
self._current_index = (self._current_index + 1) % len(self.api_keys)
selected_key = self.api_keys[self._current_index]
# 更新使用统计
self._key_usage[selected_key]["count"] += 1
self._key_usage[selected_key]["last_used"] = time.time()
return selected_key
def _get_random_api_key(self) -> Optional[str]:
"""
随机获取一个API密钥
Returns:
API密钥,如果没有可用的密钥则返回None
"""
if not self.api_keys:
return None
with self._lock:
selected_key = random.choice(self.api_keys)
# 更新使用统计
self._key_usage[selected_key]["count"] += 1
self._key_usage[selected_key]["last_used"] = time.time()
return selected_key
def get_all_api_keys(self) -> List[str]:
"""
获取所有API密钥
Returns:
API密钥列表
"""
return self.api_keys.copy()
def is_valid(self) -> bool:
"""
检查是否有可用的API密钥
Returns:
如果有可用的API密钥则返回True,否则返回False
"""
return len(self.api_keys) > 0
def get_usage_stats(self) -> Dict:
"""
获取密钥使用统计信息
Returns:
密钥使用统计信息
"""
return self._key_usage.copy()
# 使用示例
if __name__ == "__main__":
# 获取有效的API密钥列表
valid_keys = APIKeyManager.get_valid_api_keys()
print(f"有效的API密钥列表:\n" + "\n".join(valid_keys))
# 查看总密钥数
print(f"总共有 {APIKeyManager.count()} 个API密钥")
# 获取实例并查看使用统计
instance = APIKeyManager.get_instance()
stats = instance.get_usage_stats()
for key, data in stats.items():
print(f"密钥 {key[:5]}... 使用次数: {data['count']}")
+143
View File
@@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: ModelTool.py
Date: 2025-05-15
Author: oyyz
Description: 模型工具类
"""
from openai import OpenAI
import httpx
import time
import logging # 导入 logging 模块
from langchain.embeddings.base import Embeddings
from typing import List, Any
import requests
import os
import logging
from .APIKeyManager import APIKeyManager
class SiliconFlowEmbeddings(Embeddings):
"""SiliconFlow嵌入模型封装"""
def __init__(self, api_key: str, model: str = "bge-m3"):
self.api_key = api_key
self.model = model
self.url = "http://10.1.16.39:9995/v1/embeddings"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
def _embed(self, input: List[str]) -> List[List[float]]:
payload = {
"model": self.model,
"input": input,
"encoding_format": "float"
}
response = requests.post(self.url, json=payload, headers=self.headers)
response.raise_for_status()
data = response.json()
return [item["embedding"] for item in data["data"]]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return self._embed(texts)
def embed_query(self, text: str) -> List[float]:
return self._embed([text])[0]
class XinferenceReRankerModel:
"""重排模型封装"""
@staticmethod
def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]:
"""
使用重排序模型对文档进行重新排序
Args:
query: 用户查询文本
documents: 需要重新排序的文档列表
top_k: 返回排序后的前k个文档
Returns:
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
"""
url = "http://10.1.16.39:9995/v1/rerank"
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": os.getenv("RERANKER_MODEL_NAME")}
headers = {
"Authorization": "Bearer <token>", # 这里需要替换为实际的token
"Content-Type": "application/json"
}
try:
response = requests.post(url, json=params, headers=headers)
response.raise_for_status() # 检查响应状态
results = response.json()
# 返回重排序后的文档列表
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
except requests.exceptions.RequestException as e:
logging.error(f"重排序请求失败: {str(e)}")
return []
class OpenAiLLM:
def __init__(self, **kwargs):
if kwargs.get("api_key") == None or kwargs.get("base_url") == None or kwargs.get("model") == None:
raise ValueError("api_key, base_url, model 不能为空")
self._api_key = kwargs.get("api_key")
self._url = kwargs.get("base_url")
self._model = kwargs.get("model")
kwargs.pop("api_key")
kwargs.pop("base_url")
kwargs.pop("model")
self._kwargs = kwargs
def invoke(self, user_prompt="你是谁?", need_retry=True):
# 初始化 OpenAI 客户端
api_key = APIKeyManager.get_api_key()
client = OpenAI(api_key=api_key, base_url=self._url)
max_retries = 3
retry_count = 0
if need_retry:
while retry_count < max_retries:
try:
# 创建 Completion 请求. 超时120s
completion = client.chat.completions.create(
model=self._model,
messages=[{'role': 'user', 'content': user_prompt}],
timeout=httpx.Timeout(300.0),
**self._kwargs
)
return completion.choices[0].message
except Exception as e:
retry_count += 1
if retry_count == max_retries:
logging.error(f"LLM 重试{max_retries}次后仍然失败: {e}")
return ""
else:
time.sleep(5*retry_count) # 重试前等待1秒
else:
# 创建 Completion 请求. 超时120s
completion = client.chat.completions.create(
model=self._model,
messages=[{'role': 'user', 'content': user_prompt}],
timeout=httpx.Timeout(300.0),
**self._kwargs
)
return completion.choices[0].message
if __name__ == "__main__":
reranker = XinferenceReRankerModel()
query = "什么是AI"
documents = ["AI是人工智能", "AI是机器学习", "AI是深度学习"]
results = reranker.rerank(query, documents)
print(results)
+159
View File
@@ -0,0 +1,159 @@
import os.path
import requests
import json
import time
from pathlib import Path
class WikijsTool:
BASE_URL = "http://10.1.16.39:8090/graphql"
HEADERS = {
"Authorization": "Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJhcGkiOjcsImdycCI6MSwiaWF"
"0IjoxNzIzMDIwNzg4LCJleHAiOjE4MTc2OTM1ODgsImF1ZCI6InVybjp3aWtpLmpzIiwiaX"
"NzIjoidXJuOndpa2kuanMifQ.NSfE4tB7tkN8yapAs0CgkR-Yll6wc3gO3QGKMAv-TlGxx6A-9fJRmkwhRDTVMj_yPVG6"
"NXVy_AZpJtLapRXFGn0cvscsRJxq3fY1KgEyt8wO99jvd8DpNHpHhAIgrtyDelmHsBD2Wb5Ib3WJFsWC6d8Yhm9dkpx6tZ"
"vMAlFIKOg6UodMoMIry3YWiPGLaqJPQ0gcKmcnB2tC7sPXIIZnvfb5912GVM0n-4wvWobQnb_tXQuYZf99wH_leXjC_7BK8"
"8JSaAmB980i3rBxfejmaJ8E6D48zRxwwPFa0veVjjzRkVqHPwAjl1CXb2HE29pGtNmSEE1kLQVqOZD_ibOwKQ"
}
def __init__(self):
pass
@staticmethod
def init_url():
# 获取当前文件的路径
file_path = Path(__file__).resolve()
file_path = os.path.join(file_path.parent, 'wikiconfig.json')
if not os.path.exists(file_path):
return False
with open(file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
if 'url' in data:
WikijsTool.BASE_URL = data['url']
if 'Authorization' in data:
WikijsTool.HEADERS['Authorization'] = data['Authorization']
return True
@staticmethod
def get_all_documents() -> list[dict]:
query = """
query Pages {
pages {
list {
path
locale
title
contentType
id
isPublished
}
}
}
"""
# 构建请求数据
data = {
'query': query,
}
# 发送 POST 请求
response = requests.post(WikijsTool.BASE_URL, headers=WikijsTool.HEADERS, json=data)
if response.status_code == 200:
# 解析数据
list_info = json.loads(response.content)['data']['pages']['list']
return [item for item in list_info]
else:
raise ValueError(f"获取文档列表失败,原因:“{response.text}")
@staticmethod
def get_all_doc_by_path(path: str, path_is_dir: bool = True) -> list[dict]:
list_document = WikijsTool.get_all_documents()
all_document_list = []
if path_is_dir:
temp_path = path + '/'
else:
temp_path = path
for document_info in list_document:
document_path = str(document_info["path"])
# 根据路径过滤出对应的所有文档
if not document_path.startswith(temp_path):
continue
all_document_list.append(document_info)
return all_document_list
@staticmethod
def search_document(query_str: str) -> list[dict]:
graphql_query = f"""
query Pages {{
pages {{
search(query: "{query_str}") {{
results {{
id
path
locale
title
}}
}}
}}
}}
"""
# 构建请求数据
data = {
'query': graphql_query,
}
# 发送 POST 请求
response = requests.post(WikijsTool.BASE_URL, headers=WikijsTool.HEADERS, json=data)
if response.status_code == 200:
# 解析数据
search_results = json.loads(response.content)['data']['pages']['search']['results']
return search_results
else:
raise ValueError(f"查询文档失败,原因:“{response.text}")
@staticmethod
def query_doc_info(doc_id: int) -> dict:
query = """
query singlePages($doc_id: Int!) {
pages {
single(id: $doc_id) {
id
path
title
isPublished
content
contentType
isPrivate
updatedAt
createdAt
}
}
}
"""
# 构建请求数据
variables = {
'doc_id': doc_id,
}
data = {
'query': query,
'variables': variables
}
# 发送 POST 请求
response = requests.post(WikijsTool.BASE_URL, headers=WikijsTool.HEADERS, json=data)
if "errors" in response.text:
result = json.loads(response.content)['errors'][0]['message']
return {}
else:
return json.loads(response.content)['data']['pages']['single']
WikijsTool.init_url()
if __name__ == "__main__":
WikijsTool.query_doc_info(6448)
print(WikijsTool.rename_directory("配网知识库/配网造价软件", "配网知识库/配网造价软件1"))
View File
Binary file not shown.
Binary file not shown.
Binary file not shown.
+3
View File
@@ -0,0 +1,3 @@
from . import custom_markdownify
convert_html_to_md = custom_markdownify.md
@@ -0,0 +1,491 @@
import re
from textwrap import fill
import requests
from bs4 import NavigableString
from bs4 import BeautifulSoup
from markdownify import MarkdownConverter, chomp, UNDERLINED, ATX_CLOSED
import copy
from . import picture_process
# <br>是否是单元格内部的换行符
def judge_br_in_table(el):
if el.name in ['td', 'tr']:
return True
if el.parent is None:
return False
# 递归父级元素
return judge_br_in_table(el.parent)
# 获取div标签中是否为标题,如果是标题则markdown中的返回标题等级
def get_markdown_title_level(el):
if el.name != 'div' or 'class' not in el.attrs:
return ''
title_level = ''
if 'hdwiki_tmml' in el.attrs['class']:
title_level = '## '
elif 'hdwiki_tmmll' in el.attrs['class']:
title_level = '### '
return title_level
def str_is_title(text) -> bool:
text = text.strip()
pattern = r'^#+'
# 使用re.search匹配字符串开头的 # 符号
match = re.search(pattern, text)
if match:
return True
else:
return False
# 判断el 是否是图片的DIV标签
def is_img_div_tag(el) -> bool:
if el is None:
return False
if el.name != "div":
return False
class_attr = el.get('class')
if class_attr is None:
return False
if "img" in class_attr or "img_l" in class_attr:
return True
else:
return False
# 判断div内部是否是纯文本内容,并且display是否为block
def is_only_text_div(el) -> bool:
if el is None or el.name != "div" or el.text == "":
return False
if el.get("display", "block") != "block":
return False
# div标签下只包含文本
if isinstance(el.string, NavigableString):
return True
# 兼容<div><b>1.&nbsp;版本概述</b>&nbsp;</div> 判断错误问题
# 递归获取所有子标签
child_tags = el.find_all(recursive=True)
for tag in child_tags:
if tag.text == "":
continue
if tag.name in ["table", "td", "img"]:
return False
if isinstance(tag.string, NavigableString):
continue
else:
return False
return True
# a标签是否在图片的div标签内部
def a_tag_is_in_img(el) -> bool:
if el.parent is None:
return False
if el.name != "a" or el.parent.name != "div":
return False
return is_img_div_tag(el.parent)
class CustomMarkDownConverter(MarkdownConverter):
"""
创建自定义的换行装换函数
"""
def __init__(self, img_download_path, **options):
super().__init__(**options)
self.img_download_path = img_download_path
# 单元格内的换行依旧保持<br>格式
def convert_br(self, el, text, convert_as_inline):
if judge_br_in_table(el):
return "<br/>"
# 容错处理(文章4696),因bs4解析html错误 导致将 分类图标签 解析到了br标签下导致图片丢失
if text.strip():
return text + "\n"
return super().convert_br(el, text, convert_as_inline)
# 图片div标签 在图片与图片描述之间添加换行
@staticmethod
def convert_img_div(text):
pattern = r'\*\*(.*?)\*\*'
match = re.search(pattern, text)
if match:
start_index = match.start()
text = text[:start_index] + "\n" + text[start_index:]
return text
# 装换标题格式
def convert_div(self, el, text, convert_as_inline):
title_level = get_markdown_title_level(el)
if title_level != '':
return "\n\n" + title_level + text + '\n\n'
if is_img_div_tag(el):
# 图片与图片描述文字之间掺入换行符
return self.convert_img_div(text)
if is_only_text_div(el):
text = "\n\n" + text + "\n\n"
return text
# 检查 URL 是否有效的函数
@staticmethod
def is_valid_url(url):
try:
response = requests.head(url, allow_redirects=True)
return response.status_code == 200
except requests.RequestException:
return False
@staticmethod
def try_complete_img_description(img_el):
if img_el is None or img_el.name != "img":
return
# 找到父级的div标签
img_el_parent_div = None
cur_el = img_el
while cur_el.parent is not None:
if is_img_div_tag(cur_el.parent):
img_el_parent_div = cur_el.parent
break
cur_el = cur_el.parent
if img_el_parent_div is not None and len(img_el_parent_div.text) != 0:
img_el.attrs["alt"] = img_el_parent_div.text
return
# 找到父级的figure标签
img_el_parent_div = None
cur_el = img_el
while cur_el.parent is not None:
if cur_el.parent is not None and cur_el.parent.name == 'figure':
img_el_parent_div = cur_el.parent
break
cur_el = cur_el.parent
if img_el_parent_div is not None and len(img_el_parent_div.text) != 0:
img_el.attrs["alt"] = img_el_parent_div.text
return
def convert_figcaption(self, el, text, convert_as_inline):
return ""
# 图片后添加空行,图片应该单独在一行后面不接文字(示例文章:6925)
def convert_img(self, el, text, convert_as_inline):
self.try_complete_img_description(el)
img_text = super().convert_img(el, text, convert_as_inline)
# 5195 出现img标签内出现换行导致 markdown图片显示出现问题
img_text = img_text.replace("\r\n", "")
img_text = img_text.replace("\n", "")
# 空的img标签直接返回空行
if img_text == "![]()":
return '\n\n'
# img 标签使用父级超链接标签中的中大图
src = el.attrs.get('src', None) or ''
if el.parent is not None and el.parent.name == "a":
href = el.parent.attrs.get('href', None) or ''
href_path = href.rsplit(".", 1)[0]
src_path = src.rsplit(".", 1)[0]
if href_path + "_s" == src_path:
img_text = img_text.replace(src, href)
if '_s' in img_text:
src_path = src.rsplit(".", 1)[0]
if src_path.endswith('_s'):
original_src_path = src_path[:-2] # 去掉末尾的 '_s'
# 构建原始 URL
original_url = original_src_path + "." + src.split(".")[-1]
if self.is_valid_url(original_url):
img_text = img_text.replace(src, original_url)
# 转换并下载图片
return picture_process.process_img_tag(img_text, self.img_download_path)
@staticmethod
def is_img_describe_strong(el) -> bool:
if el is None or el.parent is None:
return False
if len(el.contents) == 0:
return False
# if not isinstance(el.contents[0], NavigableString):
# return False
img_list = el.parent.findAll("img")
if len(img_list) == 0:
return False
for img_tag in img_list:
alt = img_tag.get("alt", None)
title = img_tag.get("title", None)
if alt is None and title is None:
continue
if alt == el.text or title == el.text:
return True
return False
def convert_b(self, el, text, convert_as_inline):
# 如果b 标签下只存在一个标题,则该b不做任何处理,避免对标题进行加粗(示例文章:6925)
if len(el.contents) == 1:
title_level = get_markdown_title_level(el.contents[0])
if title_level != '':
return text
# <b> 标签中存在标题时,不在对内容进行加粗
if str_is_title(text):
return text
if self.is_img_describe_strong(el):
return ""
text = text.strip(" \t")
suffix = ""
if text.endswith("\n"):
suffix = " \n"
b_text = super().convert_b(el, text, convert_as_inline)
# 解析完<b> 标签后添加空格。避免出现markdown文档中出现《**1.****版本概述**》(文章2377 4292等)
return " " + b_text + suffix + " "
convert_strong = convert_b
# 有可能出现<p>之后紧接一个标题hdwiki_tmml 故前后添加换行
def convert_p(self, el, text, convert_as_inline):
if convert_as_inline:
return text
if self.options['wrap']:
text = fill(text,
width=self.options['wrap_width'],
break_long_words=False,
break_on_hyphens=False)
# <p>标签前后换行
return '\n\n%s\n\n' % text if text else ''
def convert_a(self, el, text, convert_as_inline):
prefix, suffix, text = chomp(text)
if not text:
return ''
href = el.get('href')
if self.is_href_img(href):
return text
title = el.get('title')
# 5195 出现img标签内出现换行导致 markdown图片显示出现问题
if title is not None:
title = title.replace("\n", "")
# For the replacement see #29: text nodes underscores are escaped
if (self.options['autolinks']
and text.replace(r'\_', '_') == href
and not title
and not self.options['default_title']):
# Shortcut syntax
return '<%s>' % href
if self.options['default_title'] and not title:
title = href
title_part = ' "%s"' % title.replace('"', r'\"') if title else ''
a_tag = '%s[%s](%s%s)%s' % (prefix, text, href, title_part, suffix) if href else text
return a_tag
@staticmethod
def is_href_img(href_url) -> bool:
if href_url is None:
return False
file_extension = href_url.split(".")[-1]
# 不是图片不处理
file_extension = file_extension.lower()
if file_extension not in ["jpg", "jpeg", "png", "gif"]:
return False
return True
def convert_li(self, el, text, convert_as_inline):
# 为空的li标签返回空(文章 4347)
if not text.strip():
return ""
li_text = super().convert_li(el, text, convert_as_inline)
return li_text
def convert_td(self, el, text, convert_as_inline):
if "\r\n" in text:
text = text.replace("\r\n", "<br>")
if "\n" in text:
text = text.replace("\n", "<br>")
return ' ' + text + ' |'
def convert_hn(self, n, el, text, convert_as_inline):
if convert_as_inline:
return text
style = self.options['heading_style'].lower()
text = text.rstrip()
if style == UNDERLINED and n <= 2:
line = '=' if n == 1 else '-'
return self.underline(text, line)
hashes = '#' * n
hashes = hashes + " "
if style == ATX_CLOSED:
return '\n\n %s %s %s\n\n' % (hashes, text, hashes)
return '\n\n%s %s\n\n' % (hashes, text)
@staticmethod
def convert_thead_table(el, text, cell_name, convert_as_inline):
cells = el.find_all(['td', 'th'])
is_headrow = all([cell.name == cell_name for cell in cells])
overline = ''
underline = ''
if is_headrow and not el.previous_sibling:
# first row and is headline: print headline underline
underline += '| ' + ' | '.join(['---'] * len(cells)) + ' |' + '\n'
elif (not el.previous_sibling
and (el.parent.name == 'table'
or (el.parent.name == 'tbody'
and not el.parent.previous_sibling))):
# first row, not headline, and:
# - the parent is table or
# - the parent is tbody at the beginning of a table.
# print empty headline above this row
overline += '| ' + ' | '.join([''] * len(cells)) + ' |' + '\n'
overline += '| ' + ' | '.join(['---'] * len(cells)) + ' |' + '\n'
return overline + '|' + text + '\n' + underline
def convert_tr(self, el, text, convert_as_inline):
# 解决table标签下存在thead的问题 (文章4061 1976)
if el and el.parent and el.parent.name == "thead":
return CustomMarkDownConverter.convert_thead_table(el, text, 'td', convert_as_inline)
# 兼容 table->colgroup、tbody->tr 文章4364
if (el and el.parent and el.parent.previousSibling
and el.parent.name == "tbody"
and el.parent.previousSibling.name == "colgroup"):
return CustomMarkDownConverter.convert_thead_table(el, text, 'td', convert_as_inline)
return super().convert_tr(el, text, convert_as_inline)
def convert_pre(self, el, text, convert_as_inline):
# 文章5192出现pre标签,但内容不是代码。故不额外处理pre标签
return text
def escape(self, text):
if not text:
return ''
if self.options['escape_misc']:
# text = re.sub(r'([\\&<`[>~#=+|-])', r'\\\1', text)
text = re.sub(r'([\\&<`[>~#%=+|-])', r'\\\1', text)
# 以下的转义是不必要的
# text = re.sub(r'([0-9])([.)])', r'\1\\\2', text)
if self.options['escape_asterisks']:
text = text.replace('*', r'\*')
if self.options['escape_underscores']:
text = text.replace('_', r'\_')
return text
@staticmethod
def convert_span(el, text, convert_as_inline):
# 文章3526出现图片后面紧接图片文本的问题。图片文本在span标签内
if "style" not in el.attrs:
return text
style_attr = el.attrs['style']
if style_attr is None:
return text
style_content = style_attr.split(';')
# 遍历style属性内容,找到display的值
for item in style_content:
if 'display' in item:
display_value = item.split(': ')[1] # 获取冒号后的值
if display_value == "block" and text != "":
return f"\n\n{text}\n\n"
return text
def expand_html_table(html) -> tuple[str, bool]:
soup = BeautifulSoup(html, 'html.parser')
tables = soup.find_all('table')
if len(tables) == 0:
return html, False
for table in tables:
# 创建一个二维列表来表示表格
table_rows = table.find_all('tr')
max_cols = 0
for row in table_rows:
cols = row.find_all(['td', 'th'])
col_count = sum([int(col.get('colspan', 1)) for col in cols])
if col_count > max_cols:
max_cols = col_count
# 初始化一个二维列表来存储最终的表格
result_table = []
for _ in range(len(table_rows)):
result_table.append([None] * max_cols)
# 填充二维列表
for r, row in enumerate(table_rows):
cols = row.find_all(['td', 'th'])
c = 0
for col in cols:
while result_table[r][c] is not None:
c += 1
colspan = int(col.get('colspan', 1))
rowspan = int(col.get('rowspan', 1))
for i in range(rowspan):
for j in range(colspan):
# 拆分合并单元格时,重复内容
result_table[r + i][c + j] = copy.copy(col)
# if j == 0 and i == 0:
# result_table[r + i][c + j] = copy.copy(col)
# else:
# result_table[r + i][c + j] = soup.new_tag('td')
c += colspan
# 生成新的表格 HTML
new_table = soup.new_tag('table', border="1", cellspacing="0")
tbody = soup.new_tag('tbody')
new_table.append(tbody)
for row in result_table:
tr = soup.new_tag('tr')
for col in row:
if col is not None:
td = soup.new_tag(col.name)
td.string = col.get_text()
tr.append(td)
tbody.append(tr)
# 替换原始HTML中的旧表格
table.replace_with(new_table)
return str(soup), True
# Create shorthand method for conversion
def md(html, img_download_path, **options):
new_html, result = expand_html_table(html)
markdown_content = CustomMarkDownConverter(img_download_path, **options).convert(new_html)
# 删除换行符中间的空格
temp_txt = re.sub(r'\n\s*\n', '\n\n', markdown_content)
# 连续超过3个以上的换行符替换为3个
temp_txt = re.sub(r'\n{3,}', '\n\n\n', temp_txt)
return temp_txt
+170
View File
@@ -0,0 +1,170 @@
import base64
import hashlib
import logging
import os
import re
import uuid
from urllib.parse import urljoin
import requests
def get_img_tag_url(img_tag):
# 提取图片url的正则表达式模式
pattern = r'\!\[.*?\]\((.*?)\)'
# 找到第一个匹配的链接
match = re.search(pattern, img_tag)
if not match:
return ""
# 获取匹配到的链接
link = match.group(1)
# 第0个为链接
link = link.split(" ")[0]
return link
# 填充img标签中的图片链接
# img_tag '![1](http://wiki.jxbw.com/hdwiki/uploads/202303/1679471232U4iPCjtm_s.jpg "1")'
# img_tag '![1](uploads/202303/1679471232U4iPCj6tm_s.jpg "1")'
def fill_img_url(img_tag):
"""
填充img标签中的图片链接。
参数:
img_tag (str): 原始的img标签
返回:
tuple: 修改后的img标签和图片的完整链接
"""
# 一个完整的img标签内删除换行符
img_tag = img_tag.replace("\n", "")
link = get_img_tag_url(img_tag)
if len(link) == 0:
return img_tag, ''
base_url = os.getenv("IMG_URL_PREFIX")
if "http:" in link:
# 图片为全链接,不替换
return img_tag, link
elif base_url:
# 补全图片链接
full_link = urljoin(base_url, link)
img_tag = img_tag.replace(link, full_link)
return img_tag, full_link
else:
return img_tag, ''
def download_picture(img_tag, download_path):
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '
'Chrome/94.0.4606.71 Safari/537.36 '
}
img_tag, img_url = fill_img_url(img_tag)
if img_url == '':
return img_tag
# if "_s" in img_tag:
# breakpoint()
file_name = img_url.split("/")[-1]
file_path = os.path.normpath(download_path + "\\" + file_name)
file_path = file_path.replace("\\", "/")
# 文件已经存在时不下载
if not os.path.exists(file_path):
img_date = requests.get(url=img_url, headers=headers).content
logging.info(f"图片下载成功:{img_url}")
with open(file_path, 'wb') as fp:
fp.write(img_date)
# img_tag中的url替换为下载的图片路径
return img_tag.replace(img_url, file_path)
def download_picture_from_other_url(img_tag, download_path):
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '
'Chrome/94.0.4606.71 Safari/537.36 '
}
img_tag, img_url = fill_img_url(img_tag)
# if "_s" in img_tag:
# breakpoint()
file_name = uuid.uuid4()
file_path = os.path.join(download_path, f"{file_name}.png")
file_path = os.path.normpath(file_path)
# 文件已经存在时不下载
if not os.path.exists(file_path):
try:
img_date = requests.get(url=img_url, headers=headers).content
with open(file_path, 'wb') as fp:
fp.write(img_date)
logging.info(f"图片下载成功:{img_url}")
except Exception as e:
logging.warning(f"img download error url:{img_url}")
return img_tag
# img_tag中的url替换为下载的图片路径
return img_tag.replace(img_url, file_path)
def extract_base64_from_data_uri(data_uri):
# 分割字符串以找到 base64 部分
parts = data_uri.split(',')
if len(parts) == 2 and parts[0].endswith('base64'):
# 移除后缀并返回 base64 值
return parts[1][:-1]
else:
return None
def picture_base64(img_tag, picture_save_path):
# 解码Base64字符串
# ![](data:image/png;base64,1679471232U4iPCj6tm_s)
base64_str = extract_base64_from_data_uri(img_tag)
if picture_save_path is None or picture_save_path == "":
return "![picture](空)"
# 将图片内容做MD5 用作文件名
hash_object = hashlib.md5()
hash_object.update(base64_str.encode())
img_md5 = hash_object.hexdigest()
picture_save_path = picture_save_path + "\\%s.png" % img_md5
picture_save_path = os.path.normpath(picture_save_path)
picture_save_path = picture_save_path.replace("\\", "/")
# 文件已经存在时不重新保存
if not os.path.exists(picture_save_path):
decoded_string = base64.b64decode(base64_str)
with open(picture_save_path, 'wb') as fp:
fp.write(decoded_string)
# 修改img_tab的图片路径
match = re.search("\[(.*?)\]", img_tag)
result = ""
if match:
result = match.group(1)
if result == "":
return "![picture](%s)" % picture_save_path
else:
return "![%s](%s \"%s\")" % (result, picture_save_path, result)
def process_img_tag(str_img_tag, img_path):
# 如果img标签指向的是本地磁盘路径 则忽略该标签返回空
if "file:///" in str_img_tag:
logging.warning(f"存在非法的链接地址:{str_img_tag}")
return ""
if img_path is None or img_path == "":
return "![picture](空)"
img_url = get_img_tag_url(str_img_tag)
if "data:image/png;base64" in str_img_tag:
return picture_base64(str_img_tag, img_path)
# (4696等存在指向外部链接的 img标签。 暂时保留不删除)
elif "http://" in str_img_tag or "https://" in str_img_tag:
return download_picture_from_other_url(str_img_tag, img_path)
elif not img_url.startswith("http"):
return download_picture(str_img_tag, img_path)
else:
logging.warning(f"未处理的图片标签:{str_img_tag}")
return str_img_tag
+4
View File
@@ -0,0 +1,4 @@
{
"url":"http://10.1.0.145:8090/graphql",
"Authorization":"Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJhcGkiOjEsImdycCI6MSwiaWF0IjoxNzIzNjMxMjcwLCJleHAiOjE4MTgzMDQwNzAsImF1ZCI6InVybjp3aWtpLmpzIiwiaXNzIjoidXJuOndpa2kuanMifQ.g5H1xVMtk7Q3uvrRdtD3aTm49dQkS11cYdDKIwXo7DthOOTGj9DmFO7yILNDU7XFACTZc1Ej6ryguYV_8vGqoc-Rc7LciwvqS_RHDYUKZNKENbv8df9UGDMB-F9DT_airGc1lGJXgVqypxejDL3fY8aRMGXm7GBIlZKY4JTeI2uJZxffgfqKGrOvc3EOtsGgJzKZo4OyQ8UInGtCTiuq6-mLj_Syix_1z52K1tgfnF4E4-rZH_zCD05hUlUMYUV-KWhPkeOEGR5xbRTrulfCvzDD4T0CX4pI-keSKmgVn1HYSSN4o1Tj_l9zsyhUoLRzhzPK29Q3uekIc9obrvCHrg"
}