上传问题改写、意图识别模块代码
This commit is contained in:
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
提问内容补全工具
|
||||
|
||||
此模块用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
|
||||
并将原提问和补全后的提问保存到新的Excel文件中。
|
||||
|
||||
用法示例:
|
||||
completer = QuestionCompleter(input_path="历史提问数据(dislike).xlsx", output_path="补全后的提问数据.xlsx")
|
||||
completer.process()
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
import concurrent.futures
|
||||
from threading import Lock
|
||||
|
||||
class RewriteQuery(BaseModel):
|
||||
rewrite_query:str = Field(description="补全后的提问")
|
||||
software_name:str = Field(description="软件名称")
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
class QuestionCompleter:
|
||||
"""
|
||||
提问内容补全工具类
|
||||
|
||||
用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
|
||||
并将原提问和补全后的提问保存到新的Excel文件中。
|
||||
"""
|
||||
|
||||
def __init__(self, input_path="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
|
||||
output_path="/data/Rag2_0/data/excel/历史提问数据(dislike)_补全后的提问数据.xlsx",
|
||||
question_column="提问", answer_column="回答", max_workers=10):
|
||||
"""
|
||||
初始化提问内容补全工具
|
||||
|
||||
参数:
|
||||
input_path (str): 输入Excel文件路径
|
||||
output_path (str): 输出Excel文件路径
|
||||
question_column (str): 提问列的名称
|
||||
answer_column (str): 回答列的名称
|
||||
max_workers (int): 最大线程数
|
||||
"""
|
||||
self.input_path = input_path
|
||||
self.output_path = output_path
|
||||
self.question_column = question_column
|
||||
self.answer_column = answer_column
|
||||
self.max_workers = max_workers
|
||||
self.rewrite_query_parser = PydanticOutputParser(pydantic_object=RewriteQuery)
|
||||
self.lock = Lock() # 添加线程锁
|
||||
|
||||
# 初始化LLM
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = os.getenv("OPENAI_API_BASE")
|
||||
self.model = os.getenv("LLM_MODEL_NAME")
|
||||
|
||||
if not all([self.api_key, self.base_url, self.model]):
|
||||
raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")
|
||||
|
||||
self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model)
|
||||
|
||||
# 读取Excel文件
|
||||
try:
|
||||
self.df = pd.read_excel(self.input_path)
|
||||
print(f"成功读取Excel文件: {self.input_path}")
|
||||
print(f"共有 {len(self.df)} 条记录")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"读取Excel文件失败: {str(e)}")
|
||||
|
||||
# 检查列是否存在
|
||||
if self.question_column not in self.df.columns:
|
||||
raise ValueError(f"Excel文件中不存在列: {self.question_column}")
|
||||
if self.answer_column not in self.df.columns:
|
||||
raise ValueError(f"Excel文件中不存在列: {self.answer_column}")
|
||||
|
||||
def create_completion_prompt(self, question, answer):
|
||||
"""
|
||||
创建用于补全提问的prompt
|
||||
|
||||
参数:
|
||||
question (str): 原始提问
|
||||
answer (str): 对应的回答
|
||||
|
||||
返回:
|
||||
str: 格式化的prompt
|
||||
"""
|
||||
prompt = f"""
|
||||
1、判断提问中是否缺少软件名称,如果不缺少,则直接返回原始提问
|
||||
2、如果缺少软件名称,则根据回答中的软件名称,补全提问
|
||||
3、补全后的提问需要保持问题原有意图不变
|
||||
|
||||
4、软件名称包括:
|
||||
配网D3软件(配网工程计价通D3)
|
||||
西藏Z1软件(西藏电力工程计价通Z1)
|
||||
主网计价通软件(电力建设计价通)
|
||||
技改检修工程计价通T1软件(技改检修工程计价通T1)
|
||||
技改检修清单计价通T1软件(技改检修清单计价通T1)
|
||||
储能C1软件(新型储能电站建设计价通C1)
|
||||
如果没有包含上述软件名称,则直接返回原始提问,software_name为空字符串
|
||||
{{
|
||||
"rewrite_query": "xxx",
|
||||
"software_name": ""
|
||||
}}
|
||||
|
||||
原始提问:{question}
|
||||
系统回答:{answer}
|
||||
|
||||
输出格式:
|
||||
{self.rewrite_query_parser.get_format_instructions()}
|
||||
|
||||
示例:
|
||||
例如,如果输入是:
|
||||
提问:这个软件怎么用?
|
||||
回答:Photoshop的使用方法是...
|
||||
|
||||
那么输出会是:
|
||||
{{
|
||||
"rewrite_query": "Photoshop这个软件怎么用?",
|
||||
"software_name": "Photoshop"
|
||||
}}
|
||||
|
||||
或者如果提问已经包含软件名称:
|
||||
提问:Photoshop怎么用?
|
||||
回答:Photoshop的使用方法是...
|
||||
|
||||
那么输出会是:
|
||||
{{
|
||||
"rewrite_query": "Photoshop怎么用?",
|
||||
"software_name": "Photoshop"
|
||||
}}
|
||||
|
||||
"""
|
||||
return prompt
|
||||
|
||||
def complete_question(self, question, answer):
|
||||
"""
|
||||
调用LLM补全提问内容
|
||||
|
||||
参数:
|
||||
question (str): 原始提问
|
||||
answer (str): 对应的回答
|
||||
|
||||
返回:
|
||||
str: 补全后的提问,如果补全失败则返回原始提问
|
||||
"""
|
||||
# 如果提问或回答为空,直接返回原始提问
|
||||
if pd.isna(question) or question.strip() == "" or pd.isna(answer) or answer.strip() == "":
|
||||
return question, ""
|
||||
|
||||
try:
|
||||
prompt = self.create_completion_prompt(question, answer)
|
||||
response = self.llm.invoke(prompt)
|
||||
completed_question = self.rewrite_query_parser.parse(response.content)
|
||||
return completed_question.rewrite_query, completed_question.software_name
|
||||
except Exception as e:
|
||||
print(f"补全提问失败: {str(e)}")
|
||||
return question, ""
|
||||
|
||||
def process_row(self, row):
|
||||
"""
|
||||
处理单行数据
|
||||
|
||||
参数:
|
||||
row: DataFrame中的一行
|
||||
|
||||
返回:
|
||||
dict: 处理结果
|
||||
"""
|
||||
original_question = row[self.question_column]
|
||||
answer = row[self.answer_column]
|
||||
|
||||
# 调用LLM补全提问
|
||||
completed_question, software_name = self.complete_question(original_question, answer)
|
||||
|
||||
# 创建结果字典
|
||||
result = {
|
||||
"原始提问": original_question,
|
||||
"补全后的提问": completed_question,
|
||||
"软件名称": software_name
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def process(self):
|
||||
"""
|
||||
使用多线程处理所有提问并补全内容
|
||||
|
||||
读取Excel文件中的提问和回答,调用LLM补全提问内容,
|
||||
并将原提问和补全后的提问保存到新的Excel文件中
|
||||
"""
|
||||
results = []
|
||||
total = len(self.df)
|
||||
|
||||
# 使用进度条显示总体进度
|
||||
with tqdm(total=total, desc="补全提问") as pbar:
|
||||
# 创建线程池
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
# 提交所有任务
|
||||
future_to_idx = {executor.submit(self.process_row, self.df.iloc[idx]): idx for idx in range(total)}
|
||||
|
||||
# 处理完成的任务
|
||||
for future in concurrent.futures.as_completed(future_to_idx):
|
||||
result = future.result()
|
||||
with self.lock:
|
||||
results.append(result)
|
||||
pbar.update(1)
|
||||
|
||||
# 将结果转换为DataFrame并保存
|
||||
results_df = pd.DataFrame(results)
|
||||
results_df.to_excel(self.output_path, index=False)
|
||||
print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}")
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='补全Excel文件中的提问内容')
|
||||
parser.add_argument('-i', '--input', type=str, default="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
|
||||
help='输入Excel文件路径')
|
||||
parser.add_argument('-o', '--output', type=str, default="/data/Rag2_0/data/excel/补全后的提问数据.xlsx",
|
||||
help='输出Excel文件路径')
|
||||
parser.add_argument('-q', '--question', type=str, default="提问",
|
||||
help='提问列的名称')
|
||||
parser.add_argument('-a', '--answer', type=str, default="回答",
|
||||
help='回答列的名称')
|
||||
parser.add_argument('-w', '--workers', type=int, default=50,
|
||||
help='最大线程数')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建提问补全工具实例
|
||||
completer = QuestionCompleter(
|
||||
input_path=args.input,
|
||||
output_path=args.output,
|
||||
question_column=args.question,
|
||||
answer_column=args.answer,
|
||||
max_workers=args.workers
|
||||
)
|
||||
|
||||
# 执行处理
|
||||
completer.process()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: extract_wikijs_nouns.py
|
||||
Author: oyyz
|
||||
Description: 从 Wikijs 文档中提取专业名词
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
from dotenv import load_dotenv
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from rag2_0.tool.WikijsTool import WikijsTool
|
||||
from rag2_0.intent_recognition.DataModels import Term, TermList
|
||||
from rag2_0.tool.html_to_md import convert_html_to_md
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
import json
|
||||
import datetime
|
||||
import logging
|
||||
import threading
|
||||
import concurrent.futures
|
||||
from threading import Semaphore
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
extract_wiki_nouns_prompt="""
|
||||
我在完善我的专业词库,请从提供的电力行业造价软件相关文本中提取关键词,要求如下:
|
||||
|
||||
一、提取范围
|
||||
1. 核心功能模块
|
||||
(例:多工程批量计价、材机数据反算、变电工程智能组价、架空线路地形系数计算)
|
||||
2、软件功能及界面名称(包括:界面页签、功能按钮、功能名称等)
|
||||
(例:新建工程量清单、导出工程量清单等)
|
||||
3. 业务专用术语
|
||||
(例:装置性材料、甲供材保管费、施工降效补偿、电缆头试验配套费)
|
||||
4. 计价标准体系
|
||||
(例:预规2020版、电网检修定额2015版、配网工程概算定额)
|
||||
|
||||
|
||||
二、提取规则
|
||||
1. 识别核心功能名称(如"多工程批量设置工程量、工程设置密码")
|
||||
2. 提取业务专用名词(如"主材卸车保管费")
|
||||
3. 标注关联术语的对应关系(如"市场价"与"市场价格"互为同义词)
|
||||
4. 包含定额标准相关术语(如"预规2020版")
|
||||
5. 复合型术语需保持完整
|
||||
√ 正确:"地形增加系数批量设置"
|
||||
× 错误:"地形"、"系数"、"设置"
|
||||
6. 总结生成关键词解释
|
||||
关键词:编制依据
|
||||
描述:造价文件编制基准规范
|
||||
|
||||
7. 软件的特定版本号不作为关键词
|
||||
|
||||
三、输出格式:
|
||||
{output_format}
|
||||
|
||||
四、输入内容:
|
||||
{content}
|
||||
"""
|
||||
|
||||
|
||||
class WikijsNounsExtractor:
|
||||
"""从 Wikijs 文档中提取专业名词"""
|
||||
|
||||
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo"):
|
||||
"""
|
||||
初始化专业名词提取器
|
||||
|
||||
Args:
|
||||
api_key: API密钥,如果为None则从环境变量获取
|
||||
base_url: API基础URL,如果为None则使用默认URL
|
||||
model_name: 要使用的模型名称
|
||||
"""
|
||||
# 保存参数
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.model_name = model_name
|
||||
|
||||
# 初始化LLM
|
||||
llm_params = {
|
||||
"temperature": 0.6,
|
||||
"model": model_name
|
||||
}
|
||||
|
||||
if api_key:
|
||||
llm_params["api_key"] = api_key
|
||||
|
||||
if base_url:
|
||||
llm_params["base_url"] = base_url
|
||||
|
||||
self.llm = OpenAiLLM(**llm_params)
|
||||
|
||||
# 准备术语列表解析器
|
||||
self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
|
||||
|
||||
# 信号量,限制并发请求数量
|
||||
self.semaphore = None
|
||||
|
||||
# 线程锁,用于保护共享资源
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def _convert_html_to_md(self, content, title):
|
||||
"""HTML转Markdown"""
|
||||
options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
|
||||
new_content = (content.replace("h6>", "h7>")
|
||||
.replace("h5>", "h6>")
|
||||
.replace("h4>", "h5>")
|
||||
.replace("h3>", "h4>")
|
||||
.replace("h2>", "h3>")
|
||||
.replace("h1>", "h2>"))
|
||||
# 将HTML内容转换为Markdown
|
||||
markdown_content = convert_html_to_md(new_content, "", **options)
|
||||
markdown_content = f"# {title}\n\n{markdown_content}"
|
||||
return markdown_content
|
||||
|
||||
def extract_from_document(self, doc_info: dict) -> List[Term]:
|
||||
"""从单个文档中提取专业名词"""
|
||||
try:
|
||||
# 使用LLM调用处理文档
|
||||
content = doc_info['content']
|
||||
title = doc_info["title"]
|
||||
|
||||
# 转换HTML到Markdown
|
||||
markdown_content = self._convert_html_to_md(content, title)
|
||||
|
||||
# 准备提示词
|
||||
formatted_prompt = extract_wiki_nouns_prompt.replace("{content}", markdown_content)
|
||||
formatted_prompt = formatted_prompt.replace("{output_format}", self.terms_list_parser.get_format_instructions())
|
||||
|
||||
try:
|
||||
# 调用LLM
|
||||
response = self.llm.invoke(formatted_prompt)
|
||||
# 使用Pydantic解析器解析结果
|
||||
parsed_output = self.terms_list_parser.parse(response.content)
|
||||
return parsed_output.terms
|
||||
except Exception as e:
|
||||
logging.error(f"解析LLM响应时出错: {str(e)}")
|
||||
logging.error(f"原始响应: {response.content}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logging.error(f"提取专业名词时出错: {str(e)}")
|
||||
return []
|
||||
|
||||
def _process_document(self, doc, path_terms):
|
||||
"""处理单个文档"""
|
||||
try:
|
||||
# 获取信号量
|
||||
with self.semaphore:
|
||||
# 检查文档路径是否在我们要处理的路径中
|
||||
path_prefix = None
|
||||
for prefix in path_terms.keys():
|
||||
if doc['path'].startswith(prefix):
|
||||
path_prefix = prefix
|
||||
break
|
||||
|
||||
# 如果不在要处理的路径中,则跳过
|
||||
if not path_prefix:
|
||||
return None
|
||||
|
||||
# 获取文档详细信息
|
||||
doc_info = WikijsTool.query_doc_info(doc['id'])
|
||||
if not doc_info or not doc_info.get('content'):
|
||||
return None
|
||||
|
||||
# 提取专业名词
|
||||
terms = self.extract_from_document(doc_info)
|
||||
|
||||
# 将提取的术语添加到对应路径的结果列表中
|
||||
terms_dicts = [{"name": term.name, "synonymous": term.synonymous, "description": term.description} for term in terms]
|
||||
|
||||
with self.lock:
|
||||
path_terms[path_prefix].extend(terms_dicts)
|
||||
logging.info(f"文档 {doc['path']} 处理完成,提取了 {len(terms)} 个专业名词")
|
||||
|
||||
# 每处理10个文档保存一次中间结果
|
||||
current_count = len(path_terms[path_prefix])
|
||||
if current_count % 10 == 0:
|
||||
# 使用锁保护文件IO
|
||||
self._save_terms_to_file(path_terms[path_prefix], os.path.join(self.output_dir, f"{path_prefix.split('(')[0]}_nouns.json"))
|
||||
logging.info(f"已处理 {path_prefix} 的文档数达到 {current_count//10*10} 个,已保存中间结果")
|
||||
|
||||
return path_prefix
|
||||
except Exception as e:
|
||||
logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}")
|
||||
return None
|
||||
|
||||
def process_all_documents(self, output_dir: str = "extracted_nouns", max_concurrency: int = 5):
|
||||
"""使用线程池处理所有文档"""
|
||||
# 保存输出目录
|
||||
self.output_dir = output_dir
|
||||
|
||||
# 创建输出目录
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# 初始化信号量,限制并发请求数
|
||||
self.semaphore = Semaphore(max_concurrency)
|
||||
|
||||
# 获取所有文档
|
||||
all_docs = WikijsTool.get_all_documents()
|
||||
|
||||
# 要处理的路径前缀
|
||||
# path_prefixes = [
|
||||
# "技改检修计价通(2020)",
|
||||
# "西藏造价软件(2023)",
|
||||
# "新型储能电站建设计价通C1(2024)",
|
||||
# "配网造价软件(2022)",
|
||||
# ]
|
||||
path_prefixes = [
|
||||
"主网电力建设计价通(2018)",
|
||||
]
|
||||
# 为每个路径创建单独的结果列表
|
||||
path_terms = {prefix: [] for prefix in path_prefixes}
|
||||
|
||||
# 过滤出符合路径前缀的文档
|
||||
filtered_docs = []
|
||||
for doc in all_docs:
|
||||
for prefix in path_prefixes:
|
||||
if doc['path'].startswith(prefix):
|
||||
filtered_docs.append(doc)
|
||||
break
|
||||
|
||||
logging.info(f"开始使用线程池处理 {len(filtered_docs)} 个文档...")
|
||||
|
||||
# 使用线程池处理所有文档
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
futures = []
|
||||
for doc in filtered_docs:
|
||||
future = executor.submit(self._process_document, doc, path_terms)
|
||||
futures.append(future)
|
||||
|
||||
# 等待所有任务完成
|
||||
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
||||
try:
|
||||
prefix = future.result()
|
||||
if i % 10 == 0:
|
||||
logging.info(f"已完成 {i+1}/{len(futures)} 个文档的处理")
|
||||
except Exception as e:
|
||||
logging.error(f"处理文档时出错: {str(e)}")
|
||||
|
||||
# 保存最终结果
|
||||
for prefix, terms in path_terms.items():
|
||||
# 为每个路径保存单独的文件
|
||||
output_file = os.path.join(output_dir, f"{prefix.split('(')[0]}_nouns.json")
|
||||
self._save_terms_to_file(terms, output_file)
|
||||
logging.info(f"{prefix} 处理完成,共提取 {len(terms)} 个专业名词,已保存到 {output_file}")
|
||||
|
||||
def _save_terms_to_file(self, terms, output_file):
|
||||
"""保存术语列表到文件"""
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(terms, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def main():
|
||||
# 从环境变量获取配置
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
|
||||
# os.environ["LLM_MODEL_NAME"] = "Qwen/Qwen2.5-72B-Instruct-128K"
|
||||
|
||||
extractor = WikijsNounsExtractor(api_key=api_key, base_url=base_url, model_name=os.getenv("LLM_MODEL_NAME"))
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
output_dir = os.path.join(current_dir, "..", "..", "data", "wiki_extracted_nouns")
|
||||
extractor.process_all_documents(output_dir=output_dir, max_concurrency=2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置日志输出到文件,并设置格式
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
date_format = '%Y-%m-%d %H:%M:%S'
|
||||
|
||||
# 创建一个控制台处理器
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setFormatter(logging.Formatter(log_format, date_format))
|
||||
|
||||
# 获取根日志记录器并添加处理器
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(logging.INFO)
|
||||
root_logger.addHandler(console_handler)
|
||||
main()
|
||||
@@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: intent_recognition_example.py
|
||||
Date: 2025-05-14
|
||||
Description: 意图识别和问题改写示例
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.intent_recognition import IntentRecognizer
|
||||
import pandas as pd
|
||||
import logging
|
||||
import json
|
||||
import concurrent.futures
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# 读取Excel文件中的提问数据
|
||||
def load_questions_from_excel(file_path=None):
|
||||
"""
|
||||
从Excel文件中读取提问数据
|
||||
|
||||
Args:
|
||||
file_path: Excel文件路径,如果为None则使用默认路径
|
||||
|
||||
Returns:
|
||||
提问列表
|
||||
"""
|
||||
|
||||
try:
|
||||
# 读取Excel文件的第一列数据
|
||||
df = pd.read_excel(file_path)
|
||||
questions = df.iloc[:, 0].tolist() # 获取第一列数据
|
||||
logging.info(f"成功从{file_path}读取了{len(questions)}条提问")
|
||||
return questions
|
||||
except Exception as e:
|
||||
logging.error(f"读取Excel文件时出错: {e}")
|
||||
return []
|
||||
|
||||
def process_query(recognizer, query):
|
||||
"""
|
||||
处理单个查询,支持重试机制
|
||||
|
||||
Args:
|
||||
recognizer: 意图识别器实例
|
||||
query: 查询字符串
|
||||
|
||||
Returns:
|
||||
处理结果字典
|
||||
"""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# 如果是重试,添加重试信息到日志
|
||||
classification, keywords, rewrite, query_keys = recognizer.process_query(query)
|
||||
|
||||
# 将keywords对象转换为字符串
|
||||
keywords_str = ""
|
||||
if keywords and keywords.terms:
|
||||
term_details = []
|
||||
for term in keywords.terms:
|
||||
term_info = {
|
||||
"名称": term.name,
|
||||
"同义词": ";".join(term.synonymous) if term.synonymous else "",
|
||||
"描述": term.description
|
||||
}
|
||||
term_details.append(term_info)
|
||||
|
||||
# 将term_details转换为JSON字符串,确保中文正确显示
|
||||
keywords_str = json.dumps(term_details, ensure_ascii=False, indent=2)
|
||||
|
||||
# 处理成功,返回结果
|
||||
return {
|
||||
"提问": query,
|
||||
"问题拆解": query_keys,
|
||||
"一级分类": classification.vertical_classification,
|
||||
"二级分类": classification.sub_classification,
|
||||
"问题改写": rewrite.rewrite,
|
||||
"检索的关键词": keywords_str
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
|
||||
# 如果已经重试了最大次数,则记录错误并返回错误结果
|
||||
if retry_count > max_retries:
|
||||
logging.error(f"处理问题 '{query}' 时出错: {e.__class__}{e}")
|
||||
return {
|
||||
"提问": query,
|
||||
"一级分类": "处理出错",
|
||||
"二级分类": "处理出错",
|
||||
"问题改写": "处理出错",
|
||||
"检索的关键词": f"重试 {max_retries} 次后失败: {str(e)}"
|
||||
}
|
||||
else:
|
||||
# 可以在这里添加延迟,避免过快重试
|
||||
time.sleep(10 * retry_count)
|
||||
|
||||
examples_query = """下载软件在哪下载?"""
|
||||
|
||||
def main():
|
||||
"""
|
||||
意图识别和问题改写示例
|
||||
"""
|
||||
|
||||
# 从环境变量中获取配置
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
|
||||
# 初始化意图识别器
|
||||
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
|
||||
# 读取提问数据
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条提问数据.xlsx")
|
||||
examples = load_questions_from_excel(data_file)
|
||||
# examples = examples_query.split("\n")
|
||||
max_workers = 20
|
||||
logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")
|
||||
|
||||
# 创建一个与输入顺序相同的结果列表
|
||||
results = [None] * len(examples)
|
||||
|
||||
# 使用线程池进行并发处理
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交所有任务并记录它们的索引
|
||||
future_to_index = {}
|
||||
for idx, query in enumerate(examples):
|
||||
future = executor.submit(process_query, recognizer, query)
|
||||
future_to_index[future] = idx
|
||||
|
||||
# 使用tqdm显示进度条
|
||||
for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(examples), desc="处理进度"):
|
||||
idx = future_to_index[future]
|
||||
result = future.result()
|
||||
# 将结果放在与输入相同的位置
|
||||
results[idx] = result
|
||||
|
||||
# 将结果保存到Excel文件
|
||||
results_df = pd.DataFrame(results)
|
||||
|
||||
output_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条提问数据_重写结果.xlsx")
|
||||
|
||||
# 使用ExcelWriter设置格式
|
||||
with pd.ExcelWriter(output_file, engine='xlsxwriter') as writer:
|
||||
results_df.to_excel(writer, index=False, sheet_name='Sheet1')
|
||||
|
||||
# 获取工作簿和工作表对象
|
||||
workbook = writer.book
|
||||
worksheet = writer.sheets['Sheet1']
|
||||
|
||||
# 设置列宽(单位:像素)
|
||||
# 定义列宽(厘米转为Excel单位,1cm约等于4.7个Excel单位)
|
||||
worksheet.set_column('A:A', 60) # 提问列 60个Excel单位
|
||||
worksheet.set_column('B:B', 20) # 问题拆解 20个Excel单位
|
||||
worksheet.set_column('C:C', 20) # 一级分类 20个Excel单位
|
||||
worksheet.set_column('D:D', 20) # 二级分类 20个Excel单位
|
||||
worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位
|
||||
worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位
|
||||
|
||||
# 设置所有行高为20磅
|
||||
for i in range(len(results_df) + 1): # +1 是为了包括表头
|
||||
worksheet.set_row(i, 20)
|
||||
|
||||
logging.info(f"处理完成,结果已保存至: {output_file}")
|
||||
|
||||
def setup_logging():
|
||||
# 配置日志输出到控制台
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler() # 添加控制台处理器
|
||||
]
|
||||
)
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup_logging()
|
||||
logging.info("意图识别示例程序开始运行...")
|
||||
main()
|
||||
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
答案正确性评判工具
|
||||
|
||||
此模块用于评判问题的新旧回答是否正确,通过与标准答案(Wiki内容)进行比较,
|
||||
或者在没有标准答案的情况下比较新旧回答的差异。
|
||||
|
||||
用法示例:
|
||||
judge = AnswerCorrectnessJudge()
|
||||
judge.process()
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from urllib.parse import unquote
|
||||
from rag2_0.tool.WikijsTool import WikijsTool
|
||||
from rag2_0.tool.html_to_md import convert_html_to_md
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
load_dotenv()
|
||||
|
||||
class AnswerCorrectnessJudge:
|
||||
"""
|
||||
答案正确性评判工具类
|
||||
|
||||
用于评估问题的新旧回答是否正确,可以通过与标准答案(Wiki内容)进行比较,
|
||||
或者在没有标准答案的情况下比较新旧回答的差异。
|
||||
"""
|
||||
|
||||
def __init__(self, wiki_excel_path="/data/Rag2_0/data/excel/部分提问_软件名称明确.xlsx",
|
||||
answer_excel_path="/data/Rag2_0/data/excel/主网软件提问_对比结果.xlsx",
|
||||
output_path="/data/Rag2_0/data/excel/主网软件提问回答_判断结果.xlsx"):
|
||||
"""
|
||||
初始化答案正确性评判工具
|
||||
|
||||
参数:
|
||||
wiki_excel_path (str): Wiki Excel文件路径
|
||||
answer_excel_path (str): 答案对比Excel文件路径
|
||||
output_path (str): 输出Excel文件路径
|
||||
"""
|
||||
self.wiki_excel_path = wiki_excel_path
|
||||
self.answer_excel_path = answer_excel_path
|
||||
self.output_path = output_path
|
||||
|
||||
# 读取Excel文件
|
||||
self.wiki_excel = pd.read_excel(self.wiki_excel_path)
|
||||
self.answer_excel = pd.read_excel(self.answer_excel_path)
|
||||
|
||||
# 初始化LLM
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = os.getenv("OPENAI_API_BASE")
|
||||
self.model = os.getenv("LLM_MODEL_NAME")
|
||||
|
||||
if not all([self.api_key, self.base_url, self.model]):
|
||||
raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")
|
||||
|
||||
self.openai_llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model)
|
||||
|
||||
def find_wiki_link(self, query) -> str | None:
|
||||
"""
|
||||
根据查询(对应wiki_excel中的新提问列)找出对应的词条链接
|
||||
|
||||
参数:
|
||||
query (str): 查询内容,对应wiki_excel中的新提问列
|
||||
|
||||
返回:
|
||||
str: 对应的词条链接,如果没有找到则返回None
|
||||
"""
|
||||
# 确保query不为空
|
||||
if not query or pd.isna(query):
|
||||
return None
|
||||
|
||||
# 在"新提问"列中查找匹配的行
|
||||
matched_rows = self.wiki_excel[self.wiki_excel['新提问'] == query]
|
||||
|
||||
# 如果找到了匹配的行,返回对应的词条链接
|
||||
if not matched_rows.empty:
|
||||
return matched_rows.iloc[0]['对应词条链接']
|
||||
|
||||
# 如果没有完全匹配,尝试部分匹配
|
||||
# 去除软件名称部分(如果有)
|
||||
query_parts = query.split(',', 1)
|
||||
if len(query_parts) > 1:
|
||||
clean_query = query_parts[1].strip()
|
||||
|
||||
# 在"提问"列中查找包含清理后查询的行
|
||||
for idx, row in self.wiki_excel.iterrows():
|
||||
if pd.notna(row['提问']) and clean_query in row['提问']:
|
||||
return row['对应词条链接']
|
||||
|
||||
return None
|
||||
|
||||
def get_wiki_content(self, link) -> str:
|
||||
"""
|
||||
获取词条链接的内容
|
||||
|
||||
参数:
|
||||
link (str): 词条链接
|
||||
|
||||
返回:
|
||||
str: 链接内容,如果获取失败则返回错误信息
|
||||
"""
|
||||
try:
|
||||
if not link or pd.isna(link):
|
||||
return "链接为空或无效"
|
||||
# 移除域名部分,只保留路径
|
||||
path = link.split('/', 3)[-1]
|
||||
decoded_path = unquote(path)
|
||||
path_parts = decoded_path.split('/')
|
||||
doc_path = "/".join(path_parts[1:])
|
||||
wiki_doc = WikijsTool.get_all_doc_by_path(path=doc_path, path_is_dir=False)
|
||||
html_content = WikijsTool.query_doc_info(wiki_doc[0]["id"]).get('content')
|
||||
if not html_content:
|
||||
return "获取内容失败"
|
||||
|
||||
options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
|
||||
new_content = (html_content.replace("h6>", "h7>")
|
||||
.replace("h5>", "h6>")
|
||||
.replace("h4>", "h5>")
|
||||
.replace("h3>", "h4>")
|
||||
.replace("h2>", "h3>")
|
||||
.replace("h1>", "h2>"))
|
||||
# 将HTML内容转换为Markdown
|
||||
markdown_content = convert_html_to_md(new_content, "", **options)
|
||||
markdown_content = f"# {path_parts[-1]}\n\n{markdown_content}"
|
||||
return markdown_content
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
|
||||
|
||||
def create_prompt(self, standard_answer: str, answer_to_check: str) -> str:
|
||||
"""
|
||||
创建用于评判答案的prompt
|
||||
|
||||
参数:
|
||||
standard_answer (str): 标准答案
|
||||
answer_to_check (str): 需要检查的答案
|
||||
|
||||
返回:
|
||||
str: 格式化的prompt
|
||||
"""
|
||||
return f"""请作为一个专业的答案评判专家,评估以下回答与标准答案的匹配程度。
|
||||
|
||||
标准答案:
|
||||
{standard_answer}
|
||||
|
||||
待评估的回答:
|
||||
{answer_to_check}
|
||||
|
||||
请仔细分析两个答案的内容,并给出你的判断。只需要回答"正确"或"错误",不需要其他解释。
|
||||
如果待评估的回答与标准答案在核心内容和关键信息(步骤)上一致,即使表达方式不同,也应判定为"正确"。
|
||||
如果待评估的回答存在明显的错误信息或重要信息缺失,应判定为"错误"。
|
||||
|
||||
请严格按以下格式输出:【正确】或【错误】:"""
|
||||
|
||||
def judge_old_answer(self, standard_answer: str, old_answer: str) -> bool | None:
|
||||
"""
|
||||
调用LLM判断旧回答是否正确
|
||||
|
||||
参数:
|
||||
standard_answer (str): 标准答案(来自Wiki)
|
||||
old_answer (str): 旧流程的回答
|
||||
|
||||
返回:
|
||||
bool | None: 判断结果,True表示正确,False表示错误,None表示判断失败
|
||||
"""
|
||||
prompt = self.create_prompt(standard_answer, old_answer)
|
||||
try:
|
||||
response = self.openai_llm.invoke(prompt)
|
||||
return "正确" in response.content
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def judge_new_answer(self, standard_answer: str, new_answer: str) -> bool | None:
|
||||
"""
|
||||
调用LLM判断新回答是否正确
|
||||
|
||||
参数:
|
||||
standard_answer (str): 标准答案(来自Wiki)
|
||||
new_answer (str): 新流程的回答
|
||||
|
||||
返回:
|
||||
bool | None: 判断结果,True表示正确,False表示错误,None表示判断失败
|
||||
"""
|
||||
prompt = self.create_prompt(standard_answer, new_answer)
|
||||
try:
|
||||
response = self.openai_llm.invoke(prompt)
|
||||
return "正确" in response.content
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def judge_by_standard_answer(self, standard_answer: str, old_answer: str, new_answer: str) -> str | None:
|
||||
"""
|
||||
综合判断新旧回答的正确性
|
||||
|
||||
参数:
|
||||
standard_answer (str): 标准答案(来自Wiki)
|
||||
old_answer (str): 旧流程的回答
|
||||
new_answer (str): 新流程的回答
|
||||
|
||||
返回:
|
||||
str | None: 包含新旧回答判断结果的字符串,None表示判断失败
|
||||
"""
|
||||
old_result = self.judge_old_answer(standard_answer, old_answer)
|
||||
new_result = self.judge_new_answer(standard_answer, new_answer)
|
||||
if old_result is None or new_result is None:
|
||||
return None
|
||||
if new_result and old_result:
|
||||
return "新旧答案均正确"
|
||||
elif new_result and not old_result:
|
||||
return "新答案正确"
|
||||
elif not new_result and old_result:
|
||||
return "旧答案正确"
|
||||
else:
|
||||
return "新旧答案均错误"
|
||||
|
||||
def judge_answer_diff(self, old_answer: str, new_answer: str) -> str | None:
|
||||
"""
|
||||
判断新旧回答是否存在较大差异
|
||||
|
||||
参数:
|
||||
old_answer (str): 旧流程的回答
|
||||
new_answer (str): 新流程的回答
|
||||
|
||||
返回:
|
||||
str | None: 差异判断结果,None表示判断失败
|
||||
"""
|
||||
|
||||
prompt = f"""请判断以下两个回答是否存在较大差异:
|
||||
|
||||
旧回答: {old_answer}
|
||||
|
||||
新回答: {new_answer}
|
||||
|
||||
主要是关键步骤、关键信息、或者关键主体的差异
|
||||
请仅回答"存在较大差异"或"差异较小"。"""
|
||||
|
||||
try:
|
||||
response = self.openai_llm.invoke(prompt)
|
||||
return "无法判断,新老答案差异较大" if "存在较大差异" in response.content else "无法判断,新老答案基本相同"
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def process(self):
|
||||
"""
|
||||
处理所有问题并评判答案正确性
|
||||
|
||||
读取Excel文件中的问题和答案,进行评判,并将结果保存到输出Excel文件
|
||||
"""
|
||||
# 创建结果列表
|
||||
results = []
|
||||
|
||||
# 读取Excel文件
|
||||
for idx, row in tqdm(self.answer_excel.iterrows(), total=len(self.answer_excel), desc="处理问题"):
|
||||
query = row["问题"]
|
||||
old_answer = row["旧流程答案"]
|
||||
new_answer = row["新流程答案"]
|
||||
standard_answer = ""
|
||||
|
||||
try:
|
||||
wiki_url = self.find_wiki_link(query)
|
||||
if wiki_url and not pd.isna(wiki_url):
|
||||
standard_answer = self.get_wiki_content(wiki_url)
|
||||
except Exception as e:
|
||||
print(f"处理问题 '{query}' 时发生错误: {str(e)}")
|
||||
|
||||
if standard_answer:
|
||||
# 判断答案正确性
|
||||
judge_result = self.judge_by_standard_answer(standard_answer, old_answer, new_answer)
|
||||
else:
|
||||
judge_result = self.judge_answer_diff(old_answer, new_answer)
|
||||
|
||||
if judge_result is None:
|
||||
judge_result = ""
|
||||
|
||||
results.append({
|
||||
"问题": query,
|
||||
"旧流程答案": old_answer,
|
||||
"新流程答案": new_answer,
|
||||
"判断结果": judge_result
|
||||
})
|
||||
|
||||
# 将结果转换为DataFrame并保存
|
||||
results_df = pd.DataFrame(results)
|
||||
results_df.to_excel(self.output_path, index=False)
|
||||
print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}")
|
||||
|
||||
# 测试函数
|
||||
if __name__ == "__main__":
|
||||
# 创建答案正确性评判工具实例
|
||||
judge = AnswerCorrectnessJudge()
|
||||
# 执行处理
|
||||
judge.process()
|
||||
@@ -0,0 +1,615 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
完整性问题判断工具
|
||||
|
||||
此脚本用于读取Excel文件中的问题,调用LLM判断问题是否完整,并将结果保存到Excel文件中。
|
||||
|
||||
用法示例:
|
||||
python judge_query_full.py -i "问题数据.xlsx" -o "完整问题结果.xlsx" -w 50 -c 0
|
||||
|
||||
命令行参数:
|
||||
-i, --input: 输入Excel文件路径
|
||||
-o, --output: 输出Excel文件路径
|
||||
-w, --workers: 并发处理的最大线程数
|
||||
-c, --column: 要处理的问题所在列的索引(从0开始)
|
||||
-t, --test: 测试单个问题,不处理Excel文件
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from rag2_0.tool.APIKeyManager import APIKeyManager
|
||||
from openpyxl.utils import get_column_letter
|
||||
from openpyxl.styles import Alignment, PatternFill, Font, Border, Side
|
||||
from tqdm import tqdm
|
||||
import concurrent.futures
|
||||
import threading
|
||||
|
||||
# 默认设置
|
||||
DEFAULT_EXCEL_PATH = r"/data/Rag2_0/data/excel/7000条对话数据.xlsx"
|
||||
DEFAULT_OUTPUT_PATH = r"/data/Rag2_0/data/excel/7000条对话数据_完整问题结果.xlsx"
|
||||
DEFAULT_MAX_WORKERS = 50
|
||||
|
||||
|
||||
class QueryCompletenessJudge:
|
||||
"""
|
||||
问题完整性判断工具类
|
||||
|
||||
用于评估问题是否完整,并将结果保存到Excel文件中。
|
||||
可以批量处理Excel文件中的问题,也可以测试单个问题。
|
||||
"""
|
||||
|
||||
def __init__(self, input_path=DEFAULT_EXCEL_PATH, output_path=DEFAULT_OUTPUT_PATH,
|
||||
max_workers=DEFAULT_MAX_WORKERS, column_index=0):
|
||||
"""
|
||||
初始化问题完整性判断工具
|
||||
|
||||
参数:
|
||||
input_path (str): 输入Excel文件路径
|
||||
output_path (str): 输出Excel文件路径
|
||||
max_workers (int): 并发处理的最大线程数
|
||||
column_index (int): 要处理的问题所在列的索引(从0开始)
|
||||
"""
|
||||
self.input_path = input_path
|
||||
self.output_path = output_path
|
||||
self.max_workers = max_workers
|
||||
self.column_index = column_index
|
||||
self.llm_client = self._create_llm_client()
|
||||
|
||||
def _extract_json_from_response(self, full_answer):
|
||||
"""
|
||||
从LLM响应中提取JSON部分
|
||||
|
||||
参数:
|
||||
full_answer (str): LLM的完整响应文本
|
||||
|
||||
返回:
|
||||
dict: 解析后的JSON对象,如果解析失败则返回None
|
||||
"""
|
||||
# 尝试从回答中提取JSON部分
|
||||
json_match = re.search(r'```json\s*(.*?)\s*```', full_answer, re.DOTALL)
|
||||
if json_match:
|
||||
json_str = json_match.group(1)
|
||||
else:
|
||||
# 如果没有找到```json```格式,尝试寻找普通的JSON对象
|
||||
json_match = re.search(r'({[\s\S]*"is_complete"[\s\S]*})', full_answer)
|
||||
if json_match:
|
||||
json_str = json_match.group(1)
|
||||
else:
|
||||
# 如果仍然没有找到,返回None
|
||||
return None
|
||||
|
||||
try:
|
||||
# 解析JSON
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def _create_llm_prompt(self, question):
|
||||
"""
|
||||
创建LLM提示词
|
||||
|
||||
参数:
|
||||
question (str): 需要判断完整性的问题
|
||||
|
||||
返回:
|
||||
str: 格式化后的提示词
|
||||
"""
|
||||
return f"""你是一个电力造价行业专家,用户正在使用电力造价软件,并提出了相关问题。请分析以下问题是否完整。
|
||||
|
||||
问题:{question}
|
||||
|
||||
首先,分析这个问题的结构和内容,思考它是否包含足够的信息来表达清晰的意图。
|
||||
考虑以下几点:
|
||||
1. 问题是否有明确的核心意图,不需要面面俱到
|
||||
2. 问题是否缺少必要的上下文
|
||||
3. **问题如果涉及软件相关,则只需要包含:软件名称、软件功能或软件目的即可**
|
||||
|
||||
|
||||
在你的分析之后,请用JSON格式给出最终结论,格式如下:
|
||||
```json
|
||||
{{
|
||||
"is_complete": true或false,
|
||||
"reason": "判断原因的简要说明",
|
||||
"confidence": 0到100之间的数值,表示你对判断的置信度
|
||||
}}
|
||||
```
|
||||
|
||||
请确保JSON格式正确,以便于程序解析。"""
|
||||
|
||||
def _create_llm_client(self, api_key=None):
|
||||
"""
|
||||
创建LLM客户端
|
||||
|
||||
参数:
|
||||
api_key (str, optional): API密钥,如果为None则从APIKeyManager获取
|
||||
|
||||
返回:
|
||||
OpenAiLLM: LLM客户端实例
|
||||
"""
|
||||
if api_key is None:
|
||||
api_key = APIKeyManager.get_api_key()
|
||||
|
||||
return OpenAiLLM(
|
||||
api_key=api_key,
|
||||
base_url="https://api.siliconflow.cn/v1", # 可以根据实际情况修改
|
||||
model="deepseek-ai/DeepSeek-V3", # 可以根据实际情况修改
|
||||
temperature=0.2,
|
||||
max_tokens=100
|
||||
)
|
||||
|
||||
def is_question_complete(self, question):
|
||||
"""
|
||||
调用LLM判断问题是否完整
|
||||
|
||||
参数:
|
||||
question (str): 需要判断的问题
|
||||
|
||||
返回:
|
||||
tuple: (bool, str) - 是否完整的布尔值和LLM的详细回复
|
||||
"""
|
||||
# 最大重试次数
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
retry_delay = 2 # 重试延迟,单位:秒
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# 创建提示词
|
||||
prompt = self._create_llm_prompt(question)
|
||||
|
||||
# 使用OpenAiLLM调用模型
|
||||
response = self.llm_client.invoke(prompt)
|
||||
|
||||
# 处理可能的响应格式
|
||||
if hasattr(response, 'content'):
|
||||
full_answer = response.content
|
||||
else:
|
||||
# 如果response是字符串
|
||||
full_answer = str(response)
|
||||
|
||||
# 提取JSON部分
|
||||
result = self._extract_json_from_response(full_answer)
|
||||
|
||||
if result:
|
||||
is_complete = result.get("is_complete", False)
|
||||
return is_complete, full_answer
|
||||
else:
|
||||
# 如果没有找到或解析失败,使用简单判断
|
||||
is_complete = "完整" in full_answer[:100]
|
||||
return is_complete, full_answer
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count <= max_retries:
|
||||
# 非最后一次重试,打印错误并继续
|
||||
time.sleep(retry_delay)
|
||||
# 每次重试增加延迟时间,避免频繁失败
|
||||
retry_delay *= 2
|
||||
else:
|
||||
# 已达到最大重试次数,返回错误
|
||||
print(f"错误: 经过 {max_retries} 次重试后仍然失败: {str(e)}")
|
||||
return False, f"错误: 经过 {max_retries} 次重试后仍然失败: {str(e)}"
|
||||
|
||||
# 不应该到达这里,但为了代码完整性添加
|
||||
return False, "未知错误:重试机制逻辑错误"
|
||||
|
||||
def _process_question(self, args, complete_questions, progress_counter, progress_lock, complete_questions_lock, pbar):
|
||||
"""
|
||||
处理单个问题并更新进度
|
||||
|
||||
参数:
|
||||
args (tuple): 包含问题索引、问题内容、LLM客户端和总问题数的元组
|
||||
complete_questions (list): 存储完整问题的列表
|
||||
progress_counter (dict): 进度计数器
|
||||
progress_lock (threading.Lock): 进度锁
|
||||
complete_questions_lock (threading.Lock): 完整问题列表锁
|
||||
pbar (tqdm): 进度条对象
|
||||
"""
|
||||
index, question, llm_client, total_questions = args
|
||||
|
||||
# 跳过空问题
|
||||
if pd.isna(question) or question.strip() == "":
|
||||
with progress_lock:
|
||||
progress_counter["processed"] += 1
|
||||
pbar.update(1)
|
||||
return None
|
||||
|
||||
# 调用LLM判断问题是否完整
|
||||
is_complete, full_answer = self.is_question_complete(question)
|
||||
|
||||
if is_complete:
|
||||
# 从答案中提取JSON
|
||||
parsed_json = self._extract_json_from_response(full_answer)
|
||||
|
||||
if parsed_json:
|
||||
# 构造包含解析出的JSON信息的结果
|
||||
result = {
|
||||
"问题": question,
|
||||
"LLM回复": full_answer,
|
||||
"完整性": "完整" if parsed_json.get("is_complete", False) else "不完整",
|
||||
"原因": parsed_json.get("reason", "未提供"),
|
||||
"置信度": parsed_json.get("confidence", 0)
|
||||
}
|
||||
|
||||
# 更新计数
|
||||
with progress_lock:
|
||||
if result["完整性"] == "完整":
|
||||
progress_counter["complete"] += 1
|
||||
else:
|
||||
progress_counter["incomplete"] += 1
|
||||
else:
|
||||
# JSON解析失败,只保存原始回答
|
||||
result = {
|
||||
"问题": question,
|
||||
"LLM回复": full_answer,
|
||||
"完整性": "完整"
|
||||
}
|
||||
|
||||
# 更新计数
|
||||
with progress_lock:
|
||||
progress_counter["complete"] += 1
|
||||
|
||||
with complete_questions_lock:
|
||||
complete_questions.append(result)
|
||||
else:
|
||||
with progress_lock:
|
||||
progress_counter["incomplete"] += 1
|
||||
# 更新进度条
|
||||
with progress_lock:
|
||||
progress_counter["processed"] += 1
|
||||
# 更新进度条描述
|
||||
pbar.set_postfix(
|
||||
完整=progress_counter["complete"],
|
||||
不完整=progress_counter["incomplete"],
|
||||
完整率=f"{progress_counter['complete']/max(1, progress_counter['processed']):.1%}"
|
||||
)
|
||||
pbar.update(1)
|
||||
|
||||
def _shorten_response(self, response):
|
||||
"""
|
||||
截断LLM响应,提取重要信息
|
||||
|
||||
参数:
|
||||
response (str): 原始LLM响应
|
||||
|
||||
返回:
|
||||
str: 截断后的响应
|
||||
"""
|
||||
# 保留思考过程的前200个字符和JSON部分
|
||||
json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
|
||||
if json_match:
|
||||
json_part = json_match.group(0)
|
||||
prefix = response[:200] + "..." if len(response) > 200 else response
|
||||
return f"{prefix}\n\n{json_part}"
|
||||
return response[:500] + "..." if len(response) > 500 else response
|
||||
|
||||
def _prepare_excel_dataframe(self, complete_questions):
|
||||
"""
|
||||
将结果处理为DataFrame用于Excel输出
|
||||
|
||||
参数:
|
||||
complete_questions (list): 完整问题列表
|
||||
|
||||
返回:
|
||||
pandas.DataFrame: 处理后的DataFrame
|
||||
"""
|
||||
# 将结果列表转换为DataFrame
|
||||
result_df = pd.DataFrame(complete_questions)
|
||||
|
||||
# 处理LLM回复列,截取一定长度以避免Excel单元格过大
|
||||
if "LLM回复" in result_df.columns:
|
||||
result_df["LLM回复"] = result_df["LLM回复"].apply(self._shorten_response)
|
||||
|
||||
# 调整列的顺序,确保重要列在前面
|
||||
column_order = ["问题", "完整性", "置信度", "原因", "LLM回复"]
|
||||
# 过滤掉不存在的列
|
||||
column_order = [col for col in column_order if col in result_df.columns]
|
||||
# 确保所有剩余的列也被包含
|
||||
for col in result_df.columns:
|
||||
if col not in column_order:
|
||||
column_order.append(col)
|
||||
|
||||
# 重新排序列
|
||||
return result_df[column_order]
|
||||
|
||||
def _set_excel_column_widths(self, worksheet):
|
||||
"""
|
||||
设置Excel列宽
|
||||
|
||||
参数:
|
||||
worksheet (openpyxl.worksheet.worksheet.Worksheet): Excel工作表
|
||||
"""
|
||||
for col in range(1, worksheet.max_column + 1):
|
||||
col_letter = get_column_letter(col)
|
||||
column_name = worksheet[f"{col_letter}1"].value
|
||||
|
||||
if column_name == "问题":
|
||||
worksheet.column_dimensions[col_letter].width = 40
|
||||
elif column_name == "LLM回复":
|
||||
worksheet.column_dimensions[col_letter].width = 60
|
||||
elif column_name == "原因":
|
||||
worksheet.column_dimensions[col_letter].width = 30
|
||||
elif column_name == "完整性":
|
||||
worksheet.column_dimensions[col_letter].width = 10
|
||||
elif column_name == "置信度":
|
||||
worksheet.column_dimensions[col_letter].width = 10
|
||||
else:
|
||||
worksheet.column_dimensions[col_letter].width = 15
|
||||
|
||||
def _apply_excel_cell_styles(self, worksheet):
|
||||
"""
|
||||
应用单元格样式
|
||||
|
||||
参数:
|
||||
worksheet (openpyxl.worksheet.worksheet.Worksheet): Excel工作表
|
||||
|
||||
返回:
|
||||
openpyxl.styles.Border: 边框样式,用于统计信息
|
||||
"""
|
||||
# 定义样式
|
||||
header_fill = PatternFill(start_color="DDEBF7", end_color="DDEBF7", fill_type="solid")
|
||||
header_font = Font(bold=True)
|
||||
wrap_alignment = Alignment(wrap_text=True, vertical="top")
|
||||
border = Border(
|
||||
left=Side(style='thin'),
|
||||
right=Side(style='thin'),
|
||||
top=Side(style='thin'),
|
||||
bottom=Side(style='thin')
|
||||
)
|
||||
|
||||
# 应用样式到每个单元格
|
||||
for row in worksheet.iter_rows(min_row=1, max_row=worksheet.max_row, min_col=1, max_col=worksheet.max_column):
|
||||
for cell in row:
|
||||
cell.alignment = wrap_alignment
|
||||
cell.border = border
|
||||
|
||||
# 为标题行应用特殊样式
|
||||
if cell.row == 1:
|
||||
cell.fill = header_fill
|
||||
cell.font = header_font
|
||||
|
||||
# 为完整性列应用条件格式
|
||||
if cell.row > 1: # 跳过标题行
|
||||
column_name = worksheet.cell(row=1, column=cell.column).value
|
||||
if column_name == "完整性":
|
||||
if cell.value == "完整":
|
||||
cell.fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")
|
||||
else:
|
||||
cell.fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")
|
||||
|
||||
return border # 返回边框样式以便在统计信息中重用
|
||||
|
||||
def _add_statistics_to_excel(self, worksheet, complete_questions, total_rows, total_questions, border):
|
||||
"""
|
||||
添加统计信息到Excel表格
|
||||
|
||||
参数:
|
||||
worksheet (openpyxl.worksheet.worksheet.Worksheet): Excel工作表
|
||||
complete_questions (list): 完整问题列表
|
||||
total_rows (int): 总行数
|
||||
total_questions (int): 总问题数
|
||||
border (openpyxl.styles.Border): 边框样式
|
||||
|
||||
返回:
|
||||
int: 完整问题数量
|
||||
"""
|
||||
# 计算统计数据
|
||||
complete_count = sum(1 for item in complete_questions if item.get("完整性") == "完整")
|
||||
incomplete_count = total_rows - complete_count
|
||||
|
||||
# 添加统计行
|
||||
worksheet.append([""]) # 空行
|
||||
|
||||
stat_row = worksheet.max_row + 1
|
||||
worksheet.cell(row=stat_row, column=1, value="统计信息")
|
||||
worksheet.cell(row=stat_row, column=1).font = Font(bold=True)
|
||||
|
||||
worksheet.cell(row=stat_row+1, column=1, value="总问题数")
|
||||
worksheet.cell(row=stat_row+1, column=2, value=total_rows)
|
||||
|
||||
worksheet.cell(row=stat_row+2, column=1, value="完整问题数")
|
||||
worksheet.cell(row=stat_row+2, column=2, value=complete_count)
|
||||
worksheet.cell(row=stat_row+2, column=2).fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")
|
||||
|
||||
worksheet.cell(row=stat_row+3, column=1, value="不完整问题数")
|
||||
worksheet.cell(row=stat_row+3, column=2, value=incomplete_count)
|
||||
worksheet.cell(row=stat_row+3, column=2).fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")
|
||||
|
||||
worksheet.cell(row=stat_row+4, column=1, value="完整问题比例")
|
||||
worksheet.cell(row=stat_row+4, column=2, value=f"{complete_count/total_rows:.2%}" if total_rows > 0 else "0%")
|
||||
|
||||
# 应用边框到统计行
|
||||
for r in range(stat_row, stat_row+5):
|
||||
for c in range(1, 3):
|
||||
worksheet.cell(row=r, column=c).border = border
|
||||
|
||||
return complete_count
|
||||
|
||||
def save_results_to_excel(self, complete_questions, total_questions):
|
||||
"""
|
||||
将结果保存到Excel文件
|
||||
|
||||
参数:
|
||||
complete_questions (list): 完整问题列表
|
||||
total_questions (int): 总问题数
|
||||
"""
|
||||
if not complete_questions:
|
||||
print(f"没有找到完整的问题。")
|
||||
return
|
||||
|
||||
# 准备数据
|
||||
result_df = self._prepare_excel_dataframe(complete_questions)
|
||||
total_rows = len(result_df)
|
||||
|
||||
# 保存到Excel文件
|
||||
result_df.to_excel(self.output_path, index=False, engine='openpyxl')
|
||||
|
||||
# 应用Excel样式
|
||||
from openpyxl import load_workbook
|
||||
wb = load_workbook(self.output_path)
|
||||
ws = wb.active
|
||||
|
||||
# 设置列宽
|
||||
self._set_excel_column_widths(ws)
|
||||
|
||||
# 应用单元格样式
|
||||
border = self._apply_excel_cell_styles(ws)
|
||||
|
||||
# 添加统计信息
|
||||
complete_count = self._add_statistics_to_excel(ws, complete_questions, total_rows, total_questions, border)
|
||||
|
||||
# 保存样式化的工作簿
|
||||
wb.save(self.output_path)
|
||||
|
||||
# 输出结果统计
|
||||
print(f"处理完成。共有{complete_count}/{total_questions}个完整问题被保存到 {self.output_path}")
|
||||
print(f"完整问题比例: {complete_count/total_questions:.2%}" if total_questions > 0 else "完整问题比例: 0%")
|
||||
|
||||
def process_excel_file(self):
|
||||
"""
|
||||
处理Excel文件中的问题
|
||||
|
||||
读取Excel文件,判断问题完整性,并将结果保存到输出Excel文件
|
||||
"""
|
||||
# 确保Excel文件存在
|
||||
if not os.path.exists(self.input_path):
|
||||
print(f"错误: 找不到Excel文件 '{self.input_path}'")
|
||||
return
|
||||
|
||||
# 读取Excel文件
|
||||
print(f"正在读取Excel文件: {self.input_path}")
|
||||
try:
|
||||
df = pd.read_excel(self.input_path)
|
||||
except Exception as e:
|
||||
print(f"读取Excel文件时出错: {e}")
|
||||
return
|
||||
|
||||
# 检查列数据
|
||||
if len(df.columns) <= self.column_index:
|
||||
print(f"错误: Excel文件没有足够的列,请求索引 {self.column_index},但只有 {len(df.columns)} 列")
|
||||
return
|
||||
|
||||
# 获取目标列名称
|
||||
target_col = df.columns[self.column_index]
|
||||
print(f"目标列名称: {target_col}")
|
||||
|
||||
# 准备存储完整问题的列表
|
||||
complete_questions = []
|
||||
total_questions = len(df)
|
||||
|
||||
print(f"总共有{total_questions}个问题需要判断")
|
||||
|
||||
# 用于线程安全的列表操作和进度计数
|
||||
complete_questions_lock = threading.Lock()
|
||||
progress_counter = {"processed": 0, "complete": 0, "incomplete": 0}
|
||||
progress_lock = threading.Lock()
|
||||
|
||||
# 准备问题列表
|
||||
questions = [(i, str(row[target_col]), self.llm_client, total_questions)
|
||||
for i, row in df.iterrows()]
|
||||
|
||||
# 记录开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 使用tqdm创建进度条
|
||||
print(f"开始处理问题,使用 {self.max_workers} 个并发线程...")
|
||||
with tqdm(total=total_questions, desc="处理问题", unit="问题") as pbar:
|
||||
# 使用线程池并发处理
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
# 提交所有任务
|
||||
futures = [executor.submit(
|
||||
self._process_question,
|
||||
args,
|
||||
complete_questions,
|
||||
progress_counter,
|
||||
progress_lock,
|
||||
complete_questions_lock,
|
||||
pbar
|
||||
) for args in questions]
|
||||
|
||||
# 等待所有任务完成
|
||||
concurrent.futures.wait(futures)
|
||||
|
||||
# 计算总处理时间
|
||||
processing_time = time.time() - start_time
|
||||
print(f"处理完成,耗时: {processing_time:.2f}秒,平均每问题: {processing_time/total_questions:.2f}秒")
|
||||
|
||||
# 将完整问题保存到Excel文件
|
||||
self.save_results_to_excel(complete_questions, total_questions)
|
||||
|
||||
def test_single_question(self, question):
|
||||
"""
|
||||
测试单个问题的完整性
|
||||
|
||||
参数:
|
||||
question (str): 要测试的问题
|
||||
"""
|
||||
print(f"问题: {question}")
|
||||
print("正在调用LLM判断问题是否完整...")
|
||||
|
||||
# 调用LLM判断问题是否完整
|
||||
is_complete, full_answer = self.is_question_complete(question)
|
||||
|
||||
# 从答案中提取JSON
|
||||
parsed_json = self._extract_json_from_response(full_answer)
|
||||
|
||||
print("\n==== LLM回复 ====")
|
||||
print(full_answer)
|
||||
print("================\n")
|
||||
|
||||
if parsed_json:
|
||||
print(f"判断结果: {'完整' if parsed_json.get('is_complete', False) else '不完整'}")
|
||||
print(f"判断原因: {parsed_json.get('reason', '未提供')}")
|
||||
print(f"置信度: {parsed_json.get('confidence', 0)}%")
|
||||
else:
|
||||
print(f"判断结果: {'完整' if is_complete else '不完整'} (简单判断)")
|
||||
print("无法从回复中提取JSON结构化数据")
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
"""解析命令行参数"""
|
||||
parser = argparse.ArgumentParser(description='判断Excel文件中的问题是否完整')
|
||||
parser.add_argument('-i', '--input', type=str, default=DEFAULT_EXCEL_PATH,
|
||||
help=f'输入Excel文件路径 (默认: {DEFAULT_EXCEL_PATH})')
|
||||
parser.add_argument('-o', '--output', type=str, default=DEFAULT_OUTPUT_PATH,
|
||||
help=f'输出Excel文件路径 (默认: {DEFAULT_OUTPUT_PATH})')
|
||||
parser.add_argument('-w', '--workers', type=int, default=DEFAULT_MAX_WORKERS,
|
||||
help=f'并发处理的最大线程数 (默认: {DEFAULT_MAX_WORKERS})')
|
||||
parser.add_argument('-c', '--column', type=int, default=0,
|
||||
help='要处理的问题所在列的索引 (默认: 0,即第一列)')
|
||||
parser.add_argument('-t', '--test', type=str,
|
||||
help='测试单个问题,不处理Excel文件')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
args = parse_arguments()
|
||||
|
||||
# 创建问题完整性判断工具实例
|
||||
judge = QueryCompletenessJudge(
|
||||
input_path=args.input,
|
||||
output_path=args.output,
|
||||
max_workers=args.workers,
|
||||
column_index=args.column
|
||||
)
|
||||
# 如果是测试单个问题
|
||||
if args.test:
|
||||
judge.test_single_question(args.test)
|
||||
return
|
||||
|
||||
# 处理Excel文件
|
||||
judge.process_excel_file()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -0,0 +1,239 @@
|
||||
import pandas as pd
|
||||
from urllib.parse import unquote
|
||||
from rag2_0.tool.WikijsTool import WikijsTool
|
||||
from rag2_0.tool.html_to_md import convert_html_to_md
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from rag2_0.dify.dify_tool import DifyTool
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
load_dotenv()
|
||||
|
||||
class ContentSource(BaseModel):
|
||||
score:int = Field(description="相关性分数")
|
||||
reason:str = Field(description="评分理由")
|
||||
|
||||
class RetrieveContentScoreJudge:
|
||||
"""
|
||||
检索内容相关性评分工具类
|
||||
|
||||
用于评估检索内容与问题之间的相关性,并计算相关性分数
|
||||
"""
|
||||
|
||||
def __init__(self, wiki_excel_path, answer_excel_path, output_path=None):
|
||||
"""
|
||||
初始化评分工具类
|
||||
|
||||
参数:
|
||||
wiki_excel_path (str): Wiki Excel文件路径
|
||||
answer_excel_path (str): 回答Excel文件路径
|
||||
output_path (str, optional): 输出Excel文件路径,默认为None
|
||||
"""
|
||||
self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
|
||||
if os.path.exists(wiki_excel_path):
|
||||
self.wiki_excel = pd.read_excel(wiki_excel_path)
|
||||
else:
|
||||
self.wiki_excel = None
|
||||
self.answer_excel = pd.read_excel(answer_excel_path)
|
||||
self.output_path = output_path or "/data/Rag2_0/data/excel/dify问答_检索内容评分.xlsx"
|
||||
|
||||
# 从环境变量中获取OpenAI的配置
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = os.getenv("OPENAI_API_BASE")
|
||||
self.model_name = os.getenv("LLM_MODEL_NAME")
|
||||
|
||||
if not all([self.api_key, self.base_url, self.model_name]):
|
||||
raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")
|
||||
|
||||
self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model_name)
|
||||
|
||||
def find_wiki_link(self, query) -> str | None:
|
||||
"""
|
||||
根据查询(对应wiki_excel中的新提问列)找出对应的词条链接
|
||||
|
||||
参数:
|
||||
query (str): 查询内容,对应wiki_excel中的新提问列
|
||||
|
||||
返回:
|
||||
str: 对应的词条链接,如果没有找到则返回None
|
||||
"""
|
||||
# 确保query不为空
|
||||
if not query or pd.isna(query):
|
||||
return None
|
||||
if self.wiki_excel is None:
|
||||
return None
|
||||
# 在"新提问"列中查找匹配的行
|
||||
matched_rows = self.wiki_excel[self.wiki_excel['新提问'] == query]
|
||||
|
||||
# 如果找到了匹配的行,返回对应的词条链接
|
||||
if not matched_rows.empty:
|
||||
return matched_rows.iloc[0]['对应词条链接']
|
||||
|
||||
# 如果没有完全匹配,尝试部分匹配
|
||||
# 去除软件名称部分(如果有)
|
||||
query_parts = query.split(',', 1)
|
||||
if len(query_parts) > 1:
|
||||
clean_query = query_parts[1].strip()
|
||||
|
||||
# 在"提问"列中查找包含清理后查询的行
|
||||
for idx, row in self.wiki_excel.iterrows():
|
||||
if pd.notna(row['提问']) and clean_query in row['提问']:
|
||||
return row['对应词条链接']
|
||||
|
||||
return None
|
||||
|
||||
def get_wiki_title(self, link) -> str | None:
|
||||
"""
|
||||
获取词条标题
|
||||
|
||||
参数:
|
||||
link (str): 词条链接
|
||||
|
||||
返回:
|
||||
str: 词条标题,如果获取失败则返回None
|
||||
"""
|
||||
try:
|
||||
if not link or pd.isna(link):
|
||||
return None
|
||||
# 移除域名部分,只保留路径
|
||||
path = link.split('/', 3)[-1]
|
||||
decoded_path = unquote(path)
|
||||
path_parts = decoded_path.split('/')
|
||||
return path_parts[-1]
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
|
||||
|
||||
def calculate_score(self, answer:str, content:str) -> int:
|
||||
"""
|
||||
使用OpenAiLLM通过LLM判断answer与content之间的相关性分数
|
||||
|
||||
参数:
|
||||
answer (str): 用户问题
|
||||
content (str): 检索内容
|
||||
|
||||
返回:
|
||||
int: 相关性分数,1-10分,10代表完全相关,1代表完全不相关;-1表示评分失败
|
||||
"""
|
||||
try:
|
||||
prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。
|
||||
|
||||
【评分标准】
|
||||
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
|
||||
8-9分:高度相关,核心要素匹配但存在少量信息缺失
|
||||
6-7分:部分相关,涉及相同主题但存在重要信息缺失
|
||||
4-5分:弱相关,仅次要信息点匹配
|
||||
1-3分:完全不相关或信息冲突
|
||||
|
||||
【评估维度】
|
||||
1. 主题一致性:核心主题/意图的匹配程度
|
||||
2. 内容覆盖度:是否涵盖query的关键要素
|
||||
3. 信息准确性:是否存在矛盾/错误信息
|
||||
4. 细节丰富度:是否提供query要求的详细信息
|
||||
|
||||
【输出格式】
|
||||
{{
|
||||
"score": 评分,
|
||||
"reason": "简明扼要的评分理由(中文)"
|
||||
}}
|
||||
|
||||
【示例】
|
||||
query: "新冠疫苗的常见副作用"
|
||||
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
|
||||
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}
|
||||
|
||||
现在评估:
|
||||
query: "{answer}"
|
||||
content: "{content}"
|
||||
"""
|
||||
|
||||
response = self.llm.invoke(user_prompt=prompt, need_retry=True)
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
parsed_output = self.content_source_parser.parse(response.content)
|
||||
return parsed_output.score
|
||||
except Exception as e:
|
||||
return -1
|
||||
except Exception as e:
|
||||
return -1
|
||||
|
||||
def get_retrieve_info(self, query:str, outputs:dict) -> tuple:
|
||||
"""
|
||||
获取检索信息并计算分数
|
||||
|
||||
参数:
|
||||
query (str): 用户问题
|
||||
outputs (dict): 检索输出结果
|
||||
|
||||
返回:
|
||||
tuple: (检索内容列表, 最高分, 最低分, 平均分)
|
||||
"""
|
||||
max_score = 0
|
||||
min_score = 10
|
||||
total_score = 0
|
||||
valid_scores = 0
|
||||
retrieve_content = []
|
||||
for result in outputs["result"]:
|
||||
content = result["content"].strip()
|
||||
score = self.calculate_score(answer=query, content=content)
|
||||
if score != -1:
|
||||
max_score = max(max_score, score)
|
||||
min_score = min(min_score, score)
|
||||
total_score += score
|
||||
valid_scores += 1
|
||||
content_title = content.split("\n")[0]
|
||||
if content_title:
|
||||
retrieve_content.append(content_title + f"--得分({score}分)")
|
||||
avg_score = total_score / valid_scores if valid_scores > 0 else 0
|
||||
return retrieve_content, max_score, min_score, avg_score
|
||||
|
||||
def process(self):
|
||||
"""
|
||||
处理所有问题并评估检索内容相关性
|
||||
|
||||
遍历answer_excel中的所有问题,计算检索内容与问题的相关性分数,
|
||||
并更新Excel文件
|
||||
"""
|
||||
for idx, row in tqdm(self.answer_excel.iterrows(), total=len(self.answer_excel), desc="处理问题评分中"):
|
||||
query = row["问题"]
|
||||
link = self.find_wiki_link(query)
|
||||
answer_title = self.get_wiki_title(link)
|
||||
retrieve_content = []
|
||||
max_score = 0
|
||||
min_score = 0
|
||||
avg_score = 0 # 初始化平均分
|
||||
rewrite_query=""
|
||||
message_info = DifyTool.get_message_debug_info(appid="ccf92b97-2789-4a3f-90e0-135a869a37c5", query=query)
|
||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||
if workflow_node["title"] == "知识检索结果后处理":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
retrieve_content, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
|
||||
elif workflow_node["title"] == "问题优化结果解析":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
rewrite_query = outputs["optimize_query"]
|
||||
|
||||
# 更新 answer_excel 中的词条内容
|
||||
self.answer_excel.at[idx, "答案词条"] = answer_title if answer_title else ""
|
||||
self.answer_excel.at[idx, "问题改写"] = rewrite_query
|
||||
self.answer_excel.at[idx, "检索得到词条"] = "\n".join(retrieve_content) if retrieve_content else "未检索知识库"
|
||||
self.answer_excel.at[idx, "最大得分"] = max_score
|
||||
self.answer_excel.at[idx, "最小得分"] = min_score
|
||||
self.answer_excel.at[idx, "平均得分"] = avg_score
|
||||
|
||||
# 保存结果到Excel文件
|
||||
self.answer_excel.to_excel(self.output_path, index=False)
|
||||
print(f"结果已保存到 {self.output_path}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建评分工具实例
|
||||
judge = RetrieveContentScoreJudge(
|
||||
wiki_excel_path="/data/Rag2_0/data/excel/400条人工标注-部分提问_软件名称明确.xlsx",
|
||||
answer_excel_path="/data/Rag2_0/data/excel/主网软件提问_回答内容评判.xlsx",
|
||||
output_path="/data/Rag2_0/data/excel/dify问答_检索内容评分.xlsx"
|
||||
)
|
||||
# 执行处理
|
||||
judge.process()
|
||||
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: merge_nouns_with_llm.py
|
||||
Description: 合并多个nouns.json中的同名专业名词,利用LLM生成唯一合并结果
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import glob
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import defaultdict
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from rag2_0.intent_recognition.DataModels import Term
|
||||
import logging
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
class TermMerger:
|
||||
"""专业名词合并类,用于合并多个数据源中的同名专业名词"""
|
||||
|
||||
def __init__(self, input_dir=None, output_path=None, max_workers=3):
|
||||
"""初始化名词合并器
|
||||
|
||||
Args:
|
||||
input_dir: 包含nouns.json文件的目录路径
|
||||
output_path: 合并结果的输出文件路径
|
||||
max_workers: 线程池最大工作线程数
|
||||
"""
|
||||
self.EXTRACTED_NOUNS_DIR = input_dir
|
||||
self.OUTPUT_PATH = output_path
|
||||
self.MAX_WORKERS = max_workers
|
||||
self.terms_parser = PydanticOutputParser(pydantic_object=Term)
|
||||
self.MERGE_PROMPT = '''
|
||||
请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:
|
||||
- 同义词(synonymous)去重合并
|
||||
- 描述(description)合并为更完整、简明的描述
|
||||
- 保持输出格式为:
|
||||
{output_format}
|
||||
原始条目:
|
||||
{items}
|
||||
'''
|
||||
# 配置LLM
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
llm_params = {"temperature": 0.3, "model": model_name}
|
||||
if api_key:
|
||||
llm_params["api_key"] = api_key
|
||||
if base_url:
|
||||
llm_params["base_url"] = base_url
|
||||
self.llm = OpenAiLLM(**llm_params)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
def load_all_terms(self):
|
||||
"""读取目录下所有nouns.json,返回所有Term列表"""
|
||||
all_terms = []
|
||||
for file in glob.glob(os.path.join(self.EXTRACTED_NOUNS_DIR, '*_nouns.json')):
|
||||
with open(file, 'r', encoding='utf-8') as f:
|
||||
try:
|
||||
file_terms = json.load(f)
|
||||
new_terms = [{"name": term["name"].upper(), "synonymous": term["synonymous"], "description": term["description"]} for term in file_terms]
|
||||
all_terms.extend(new_terms)
|
||||
logging.info(f"加载{file},共{len(new_terms)}条")
|
||||
except Exception as e:
|
||||
logging.warning(f"读取{file}失败: {e}")
|
||||
|
||||
# 加载suffix_keywords.json文件
|
||||
suffix_keywords_path = os.path.join(os.path.dirname(os.path.dirname(self.EXTRACTED_NOUNS_DIR)), 'data', 'nouns', 'suffix_keywords.json')
|
||||
if os.path.exists(suffix_keywords_path):
|
||||
try:
|
||||
with open(suffix_keywords_path, 'r', encoding='utf-8') as f:
|
||||
suffix_terms = json.load(f)
|
||||
suffix_terms = [{"name": term["name"].upper(), "synonymous": "", "description": ""} for term in suffix_terms]
|
||||
all_terms.extend(suffix_terms)
|
||||
logging.info(f"加载{suffix_keywords_path},共{len(suffix_terms)}条")
|
||||
except Exception as e:
|
||||
logging.warning(f"读取{suffix_keywords_path}失败: {e}")
|
||||
|
||||
return all_terms
|
||||
|
||||
def group_terms_by_name(self, terms):
|
||||
"""按name聚合Term"""
|
||||
name2terms = defaultdict(list)
|
||||
for term in terms:
|
||||
name = term.get('name', '').strip()
|
||||
if name:
|
||||
name2terms[name].append(term)
|
||||
return name2terms
|
||||
|
||||
def merge_terms_with_llm(self, name, term_list):
|
||||
"""调用LLM合并同名Term,失败最多重试三次"""
|
||||
items = json.dumps(term_list, ensure_ascii=False)
|
||||
prompt = self.MERGE_PROMPT.format(name=name, items=items, output_format=self.terms_parser.get_format_instructions())
|
||||
|
||||
max_retries = 3
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
response = self.llm.invoke(prompt, False)
|
||||
parsed_output = self.terms_parser.parse(response.content)
|
||||
return {"name": parsed_output.name, "synonymous": parsed_output.synonymous, "description": parsed_output.description}
|
||||
except Exception as e:
|
||||
if attempt == max_retries:
|
||||
logging.warning(f"解析LLM合并结果失败: {e}")
|
||||
return None
|
||||
else:
|
||||
time.sleep(10*attempt)
|
||||
|
||||
def process_term(self, name_terms_tuple):
|
||||
"""处理单个词条,用于线程池并行处理"""
|
||||
name, term_list = name_terms_tuple
|
||||
try:
|
||||
merged = self.merge_terms_with_llm(name, term_list)
|
||||
if merged:
|
||||
return merged
|
||||
else:
|
||||
return term_list[0]
|
||||
except Exception as e:
|
||||
logging.error(f"处理词条 {name} 时出错: {e}")
|
||||
return term_list[0]
|
||||
|
||||
def merge(self):
|
||||
"""合并所有词条的入口方法"""
|
||||
# 1. 读取所有术语
|
||||
all_terms = self.load_all_terms()
|
||||
logging.info(f"共加载{len(all_terms)}条术语")
|
||||
|
||||
# 2. 按名称聚合
|
||||
name2terms = self.group_terms_by_name(all_terms)
|
||||
logging.info(f"共{len(name2terms)}个唯一名词")
|
||||
|
||||
# 3. 使用线程池并行处理
|
||||
merged_terms = []
|
||||
items_to_process = []
|
||||
|
||||
# 先处理只有一个条目的词条(不需要合并)
|
||||
for name, term_list in name2terms.items():
|
||||
if len(term_list) == 1:
|
||||
merged_terms.append(term_list[0])
|
||||
else:
|
||||
items_to_process.append((name, term_list))
|
||||
|
||||
logging.info(f"共{len(merged_terms)}个单一条目,{len(items_to_process)}个需要合并的条目")
|
||||
|
||||
# 只对需要合并的词条使用线程池处理
|
||||
if items_to_process:
|
||||
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
|
||||
# 使用tqdm显示进度
|
||||
for result in tqdm(executor.map(self.process_term, items_to_process), total=len(items_to_process)):
|
||||
merged_terms.append(result)
|
||||
|
||||
# 4. 保存合并结果
|
||||
os.makedirs(os.path.dirname(self.OUTPUT_PATH), exist_ok=True)
|
||||
with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
|
||||
json.dump(merged_terms, f, ensure_ascii=False, indent=2)
|
||||
logging.info(f"合并后结果已保存到: {self.OUTPUT_PATH}")
|
||||
|
||||
return merged_terms
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数,创建TermMerger实例并执行合并"""
|
||||
|
||||
cur_path = os.path.dirname(__file__)
|
||||
input_dir = os.path.abspath(os.path.join(cur_path, '../../data/wiki_extracted_nouns'))
|
||||
output_path = os.path.join(cur_path, "..", "..", "data", "nouns", 'merged_nouns.json')
|
||||
merger = TermMerger(input_dir=input_dir, output_path=output_path, max_workers=2)
|
||||
merger.merge()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
main()
|
||||
@@ -0,0 +1,408 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: validate_excel_data_batch.py
|
||||
Description: 使用LLM批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写是否正确
|
||||
"""
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import concurrent.futures
|
||||
from tqdm import tqdm
|
||||
from dotenv import load_dotenv
|
||||
from langchain_openai import ChatOpenAI
|
||||
from rag2_0.intent_recognition.PromptTemplates import classification
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
|
||||
class ExcelDataValidator:
|
||||
"""Excel数据验证类,用于批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写"""
|
||||
|
||||
def __init__(self, input_file=None, output_file=None, workers=4, batch_size=10):
|
||||
"""
|
||||
初始化验证器
|
||||
|
||||
Args:
|
||||
input_file: 输入Excel文件路径
|
||||
output_file: 输出结果Excel文件路径
|
||||
workers: 并行工作线程数
|
||||
batch_size: 每批处理的行数
|
||||
"""
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
self.input_file = input_file
|
||||
self.output_file = output_file
|
||||
self.workers = workers
|
||||
self.batch_size = batch_size
|
||||
self.df = None
|
||||
|
||||
# 设置日志
|
||||
self.setup_logging()
|
||||
|
||||
def setup_logging(self):
|
||||
"""配置日志输出"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
|
||||
def load_data_from_excel(self, file_path=None):
|
||||
"""
|
||||
从Excel文件中读取数据
|
||||
|
||||
Args:
|
||||
file_path: Excel文件路径,如不提供则使用初始化时的路径
|
||||
|
||||
Returns:
|
||||
DataFrame对象
|
||||
"""
|
||||
file_path = file_path or self.input_file
|
||||
if not file_path:
|
||||
logging.error("未指定输入文件路径")
|
||||
return None
|
||||
|
||||
try:
|
||||
df = pd.read_excel(file_path)
|
||||
required_columns = ["提问", "问题拆解", "一级分类", "二级分类", "问题改写", "检索的关键词"]
|
||||
for col in required_columns:
|
||||
if col not in df.columns:
|
||||
logging.error(f"缺少必要的列: {col}")
|
||||
return None
|
||||
logging.info(f"成功从{file_path}读取了{len(df)}条数据")
|
||||
self.df = df
|
||||
return df
|
||||
except Exception as e:
|
||||
logging.error(f"读取Excel文件时出错: {e}")
|
||||
return None
|
||||
|
||||
def validate_classification(self, llm, query, vertical_class, sub_class):
|
||||
"""
|
||||
验证问题分类是否正确
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
query: 原始问题
|
||||
vertical_class: 一级分类
|
||||
sub_class: 二级分类
|
||||
|
||||
Returns:
|
||||
(bool, str): 是否正确,错误原因(如果有)
|
||||
"""
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用,也可能涉及电力造价专业知识。我对用户问题进行了分类,请评估以下问题分类是否正确。
|
||||
|
||||
我目前总共有以下分类:
|
||||
{classification}
|
||||
|
||||
问题的分类情况如下:
|
||||
原始问题: {query}
|
||||
一级分类: {vertical_class}
|
||||
二级分类: {sub_class}
|
||||
|
||||
请从专业角度分析这个分类是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = response.content.strip()
|
||||
|
||||
if result.startswith("正确"):
|
||||
return True, ""
|
||||
else:
|
||||
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
|
||||
return False, error_reason
|
||||
except Exception as e:
|
||||
logging.warning(f"验证问题分类时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}"
|
||||
|
||||
def validate_query_keys(self, llm, query, query_keys):
|
||||
"""
|
||||
验证问题拆解是否正确
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
query: 原始问题
|
||||
query_keys: 问题拆解
|
||||
|
||||
Returns:
|
||||
(bool, str): 是否正确,错误原因(如果有)
|
||||
"""
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了拆解,请评估以下问题拆解是否正确。
|
||||
|
||||
原始问题: {query}
|
||||
问题拆解: {query_keys}
|
||||
|
||||
问题拆解应该准确提取原始问题中的关键词和信息。请分析这个拆解是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = response.content.strip()
|
||||
|
||||
if result.startswith("正确"):
|
||||
return True, ""
|
||||
else:
|
||||
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
|
||||
return False, error_reason
|
||||
except Exception as e:
|
||||
logging.warning(f"验证问题拆解时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}"
|
||||
|
||||
def validate_keywords(self, llm, query, query_keys, keywords_str):
|
||||
"""
|
||||
验证检索关键词是否准确
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
query: 原始问题
|
||||
query_keys: 问题拆解
|
||||
keywords_str: 检索关键词(JSON字符串)
|
||||
|
||||
Returns:
|
||||
(bool, str): 是否正确,错误原因(如果有)
|
||||
"""
|
||||
# 解析关键词JSON
|
||||
try:
|
||||
if isinstance(keywords_str, str) and keywords_str.strip():
|
||||
keywords = json.loads(keywords_str)
|
||||
else:
|
||||
keywords = []
|
||||
except:
|
||||
keywords = keywords_str
|
||||
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。通过问题检索出了一些关键词,请评估这些关键词是否准确,是否与问题相关
|
||||
原始问题: {query}
|
||||
问题拆解: {query_keys}
|
||||
检索关键词: {keywords}
|
||||
|
||||
检索关键词应该准确反映问题中需要检索的关键概念和术语。请分析这些关键词是否准确、完整。只需返回"正确"或"错误:原因",不需要其他解释。"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = response.content.strip()
|
||||
|
||||
if result.startswith("正确"):
|
||||
return True, ""
|
||||
else:
|
||||
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
|
||||
return False, error_reason
|
||||
except Exception as e:
|
||||
logging.warning(f"验证检索关键词时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}"
|
||||
|
||||
def validate_rewrite(self, llm, query, rewrite):
|
||||
"""
|
||||
验证问题改写是否正确
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
query: 原始问题
|
||||
rewrite: 问题改写
|
||||
|
||||
Returns:
|
||||
(bool, str): 是否正确,错误原因(如果有)
|
||||
"""
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了改写,请评估以下问题改写是否正确。
|
||||
|
||||
原始问题: {query}
|
||||
问题改写: {rewrite}
|
||||
|
||||
问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = response.content.strip()
|
||||
|
||||
if result.startswith("正确"):
|
||||
return True, ""
|
||||
else:
|
||||
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
|
||||
return False, error_reason
|
||||
except Exception as e:
|
||||
logging.warning(f"验证问题改写时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}"
|
||||
|
||||
def validate_row(self, llm, row_data):
|
||||
"""
|
||||
按顺序验证一行数据中的各个环节
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
row_data: (index, row)元组
|
||||
|
||||
Returns:
|
||||
(index, is_all_correct, error_phase, error_reason): 行索引,是否全部正确,错误环节,错误原因
|
||||
"""
|
||||
index, row = row_data
|
||||
query = row["提问"]
|
||||
query_keys = row["问题拆解"]
|
||||
vertical_class = row["一级分类"]
|
||||
sub_class = row["二级分类"]
|
||||
rewrite = row["问题改写"]
|
||||
keywords_str = row["检索的关键词"]
|
||||
|
||||
try:
|
||||
# 1. 验证问题分类
|
||||
is_correct, error_reason = self.validate_classification(llm, query, vertical_class, sub_class)
|
||||
if not is_correct:
|
||||
return index, False, "问题分类", error_reason
|
||||
|
||||
# 2. 验证问题拆解
|
||||
is_correct, error_reason = self.validate_query_keys(llm, query, query_keys)
|
||||
if not is_correct:
|
||||
return index, False, "问题拆解", error_reason
|
||||
|
||||
# 3. 验证检索关键词
|
||||
is_correct, error_reason = self.validate_keywords(llm, query, query_keys, keywords_str)
|
||||
if not is_correct:
|
||||
return index, False, "关键词检索", error_reason
|
||||
|
||||
# 4. 验证问题改写
|
||||
is_correct, error_reason = self.validate_rewrite(llm, query, rewrite)
|
||||
if not is_correct:
|
||||
return index, False, "问题改写", error_reason
|
||||
|
||||
return index, True, "", ""
|
||||
except Exception as e:
|
||||
error_msg = f"处理行 {index} 时发生错误: {str(e)}"
|
||||
logging.error(error_msg)
|
||||
return index, False, "处理错误", error_msg
|
||||
|
||||
def process_batch(self, llm, batch_data):
|
||||
"""处理一批数据"""
|
||||
results = []
|
||||
for row_data in batch_data:
|
||||
results.append(self.validate_row(llm, row_data))
|
||||
return results
|
||||
|
||||
def create_llm_instances(self, count):
|
||||
"""创建多个LLM实例"""
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
|
||||
llm_params = {"temperature": 0.7, "model": model_name}
|
||||
if api_key:
|
||||
llm_params["api_key"] = api_key
|
||||
if base_url:
|
||||
llm_params["base_url"] = base_url
|
||||
|
||||
return [OpenAiLLM(**llm_params) for _ in range(count)]
|
||||
|
||||
def validate(self, input_file=None, output_file=None, workers=None, batch_size=None):
|
||||
"""
|
||||
执行验证过程
|
||||
|
||||
Args:
|
||||
input_file: 输入Excel文件路径
|
||||
output_file: 输出结果Excel文件路径
|
||||
workers: 并行工作线程数
|
||||
batch_size: 每批处理的行数
|
||||
|
||||
Returns:
|
||||
验证后的DataFrame
|
||||
"""
|
||||
input_file = input_file or self.input_file
|
||||
output_file = output_file or self.output_file
|
||||
workers = workers or self.workers
|
||||
batch_size = batch_size or self.batch_size
|
||||
|
||||
# 读取数据
|
||||
df = self.load_data_from_excel(input_file)
|
||||
if df is None:
|
||||
return None
|
||||
|
||||
# 添加验证结果列
|
||||
df["验证结果"] = ""
|
||||
df["错误环节"] = ""
|
||||
df["错误原因"] = ""
|
||||
|
||||
# 准备数据批次
|
||||
all_rows = list(df.iterrows())
|
||||
batches = [all_rows[i:i+batch_size] for i in range(0, len(all_rows), batch_size)]
|
||||
|
||||
# 创建多个LLM实例
|
||||
llm_instances = self.create_llm_instances(min(workers, len(batches)))
|
||||
|
||||
# 使用线程池处理数据
|
||||
all_results = []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
# 为每个批次分配一个LLM实例
|
||||
future_to_batch = {
|
||||
executor.submit(self.process_batch, llm_instances[i % len(llm_instances)], batch):
|
||||
i for i, batch in enumerate(batches)
|
||||
}
|
||||
|
||||
# 使用tqdm显示进度条
|
||||
for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(batches), desc="批次处理进度"):
|
||||
batch_results = future.result()
|
||||
all_results.extend(batch_results)
|
||||
|
||||
# 按行索引排序结果,确保与原始数据顺序一致
|
||||
all_results.sort(key=lambda x: x[0])
|
||||
|
||||
# 将结果填充到DataFrame
|
||||
for index, is_correct, error_phase, error_reason in all_results:
|
||||
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
|
||||
df.at[index, "错误环节"] = error_phase
|
||||
df.at[index, "错误原因"] = error_reason
|
||||
|
||||
# 保存结果
|
||||
if output_file is None:
|
||||
output_file = os.path.join(
|
||||
os.path.dirname(input_file),
|
||||
f"validated_{os.path.basename(input_file)}"
|
||||
)
|
||||
df.to_excel(output_file, index=False)
|
||||
logging.info(f"验证完成,结果已保存至: {output_file}")
|
||||
|
||||
# 输出统计信息
|
||||
self.print_statistics(df)
|
||||
|
||||
return df
|
||||
|
||||
def print_statistics(self, df):
|
||||
"""打印统计信息"""
|
||||
total = len(df)
|
||||
passed = len(df[df["验证结果"] == "通过"])
|
||||
error_stats = df[df["验证结果"] == "不通过"]["错误环节"].value_counts()
|
||||
|
||||
logging.info(f"统计信息: 总计 {total} 条, 通过 {passed} 条, 通过率 {passed/total*100:.2f}%")
|
||||
logging.info("错误环节统计:")
|
||||
for phase, count in error_stats.items():
|
||||
logging.info(f"- {phase}: {count} 条")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
# 解析命令行参数
|
||||
input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "问题分类重写结果")
|
||||
output_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "自动验证_问题分类重写结果.xlsx")
|
||||
parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写")
|
||||
parser.add_argument("--input", "-i", type=str, required=True, help="输入Excel文件路径", default=input_excel)
|
||||
parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
|
||||
parser.add_argument("--workers", "-w", type=int, default=2, help="并行工作线程数")
|
||||
parser.add_argument("--batch-size", "-b", type=int, default=5, help="每批处理的行数")
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建验证器实例并执行验证
|
||||
validator = ExcelDataValidator(
|
||||
input_file=args.input,
|
||||
output_file=args.output,
|
||||
workers=args.workers,
|
||||
batch_size=args.batch_size
|
||||
)
|
||||
validator.validate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: vectorize_save_noun.py
|
||||
Date: 2025-05-15
|
||||
Description: 专业名词向量化和保存的示例程序
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.intent_recognition import ProfessionalNounVectorizer
|
||||
import logging
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
def main():
|
||||
"""
|
||||
主函数:创建索引并保存
|
||||
"""
|
||||
# 指定文件路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
output_dir = os.path.join(current_dir, "..", "..", "data", "nouns")
|
||||
|
||||
# 创建向量化器并指定路径
|
||||
noun_vectorizer = ProfessionalNounVectorizer(
|
||||
output_dir=output_dir
|
||||
)
|
||||
file_paths = [
|
||||
os.path.join(current_dir, "..", "..", "data/nouns/merged_nouns.json"),
|
||||
]
|
||||
# 执行向量化和保存(一步完成)
|
||||
success = noun_vectorizer.vectorize_files_and_save(file_paths)
|
||||
if success:
|
||||
logging.info("✓ 索引创建和保存成功")
|
||||
logging.info(f" 索引保存路径: {os.path.join(output_dir, 'professional_nouns_index')}")
|
||||
else:
|
||||
logging.error("✗ 索引创建失败")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置日志输出到控制台
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(message)s'
|
||||
)
|
||||
main()
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
from dify_client.client import ChatClient, CompletionClient, DifyClient
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,459 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class DifyClient:
|
||||
def __init__(self, api_key, base_url: str = "https://api.dify.ai/v1"):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
def _send_request(self, method, endpoint, json=None, params=None, stream=False):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
response = requests.request(
|
||||
method, url, json=json, params=params, headers=headers, stream=stream, verify=False
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _send_request_with_files(self, method, endpoint, data, files):
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
response = requests.request(
|
||||
method, url, data=data, headers=headers, files=files
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def message_feedback(self, message_id, rating, user):
|
||||
data = {"rating": rating, "user": user}
|
||||
return self._send_request("POST", f"/messages/{message_id}/feedbacks", data)
|
||||
|
||||
def get_application_parameters(self, user):
|
||||
params = {"user": user}
|
||||
return self._send_request("GET", "/parameters", params=params)
|
||||
|
||||
def file_upload(self, user, files):
|
||||
data = {"user": user}
|
||||
return self._send_request_with_files(
|
||||
"POST", "/files/upload", data=data, files=files
|
||||
)
|
||||
|
||||
def text_to_audio(self, text: str, user: str, streaming: bool = False):
|
||||
data = {"text": text, "user": user, "streaming": streaming}
|
||||
return self._send_request("POST", "/text-to-audio", data=data)
|
||||
|
||||
def get_meta(self, user):
|
||||
params = {"user": user}
|
||||
return self._send_request("GET", "/meta", params=params)
|
||||
|
||||
|
||||
class CompletionClient(DifyClient):
|
||||
def create_completion_message(self, inputs, response_mode, user, files=None):
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
"response_mode": response_mode,
|
||||
"user": user,
|
||||
"files": files,
|
||||
}
|
||||
return self._send_request(
|
||||
"POST",
|
||||
"/completion-messages",
|
||||
data,
|
||||
stream=True if response_mode == "streaming" else False,
|
||||
)
|
||||
|
||||
|
||||
class ChatClient(DifyClient):
|
||||
def create_chat_message(
|
||||
self,
|
||||
inputs,
|
||||
query,
|
||||
user,
|
||||
response_mode="blocking",
|
||||
conversation_id=None,
|
||||
files=None,
|
||||
):
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"user": user,
|
||||
"response_mode": response_mode,
|
||||
"files": files,
|
||||
}
|
||||
if conversation_id:
|
||||
data["conversation_id"] = conversation_id
|
||||
|
||||
return self._send_request(
|
||||
"POST",
|
||||
"/chat-messages",
|
||||
data,
|
||||
stream=True if response_mode == "streaming" else False,
|
||||
)
|
||||
|
||||
def get_suggested(self, message_id, user: str):
|
||||
params = {"user": user}
|
||||
return self._send_request(
|
||||
"GET", f"/messages/{message_id}/suggested", params=params
|
||||
)
|
||||
|
||||
def stop_message(self, task_id, user):
|
||||
data = {"user": user}
|
||||
return self._send_request("POST", f"/chat-messages/{task_id}/stop", data)
|
||||
|
||||
def get_conversations(self, user, last_id=None, limit=None, pinned=None):
|
||||
params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned}
|
||||
return self._send_request("GET", "/conversations", params=params)
|
||||
|
||||
def get_conversation_messages(
|
||||
self, user, conversation_id=None, first_id=None, limit=None
|
||||
):
|
||||
params = {"user": user}
|
||||
|
||||
if conversation_id:
|
||||
params["conversation_id"] = conversation_id
|
||||
if first_id:
|
||||
params["first_id"] = first_id
|
||||
if limit:
|
||||
params["limit"] = limit
|
||||
|
||||
return self._send_request("GET", "/messages", params=params)
|
||||
|
||||
def rename_conversation(
|
||||
self, conversation_id, name, auto_generate: bool, user: str
|
||||
):
|
||||
data = {"name": name, "auto_generate": auto_generate, "user": user}
|
||||
return self._send_request(
|
||||
"POST", f"/conversations/{conversation_id}/name", data
|
||||
)
|
||||
|
||||
def delete_conversation(self, conversation_id, user):
|
||||
data = {"user": user}
|
||||
return self._send_request("DELETE", f"/conversations/{conversation_id}", data)
|
||||
|
||||
def audio_to_text(self, audio_file, user):
|
||||
data = {"user": user}
|
||||
files = {"audio_file": audio_file}
|
||||
return self._send_request_with_files("POST", "/audio-to-text", data, files)
|
||||
|
||||
|
||||
class WorkflowClient(DifyClient):
|
||||
def run(
|
||||
self, inputs: dict, response_mode: str = "streaming", user: str = "abc-123"
|
||||
):
|
||||
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
|
||||
return self._send_request("POST", "/workflows/run", data)
|
||||
|
||||
def stop(self, task_id, user):
|
||||
data = {"user": user}
|
||||
return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data)
|
||||
|
||||
def get_result(self, workflow_run_id):
|
||||
return self._send_request("GET", f"/workflows/run/{workflow_run_id}")
|
||||
|
||||
|
||||
class KnowledgeBaseClient(DifyClient):
|
||||
def __init__(
|
||||
self,
|
||||
api_key,
|
||||
base_url: str = "https://api.dify.ai/v1",
|
||||
dataset_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Construct a KnowledgeBaseClient object.
|
||||
|
||||
Args:
|
||||
api_key (str): API key of Dify.
|
||||
base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'.
|
||||
dataset_id (str, optional): ID of the dataset. Defaults to None. You don't need this if you just want to
|
||||
create a new dataset. or list datasets. otherwise you need to set this.
|
||||
"""
|
||||
super().__init__(api_key=api_key, base_url=base_url)
|
||||
self.dataset_id = dataset_id
|
||||
|
||||
def _get_dataset_id(self):
|
||||
if self.dataset_id is None:
|
||||
raise ValueError("dataset_id is not set")
|
||||
return self.dataset_id
|
||||
|
||||
def create_dataset(self, name: str, **kwargs):
|
||||
return self._send_request("POST", "/datasets", {"name": name}, **kwargs)
|
||||
|
||||
def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
|
||||
return self._send_request(
|
||||
"GET", f"/datasets?page={page}&limit={page_size}", **kwargs
|
||||
)
|
||||
|
||||
def create_document_by_text(
|
||||
self, name, text, extra_params: dict | None = None, **kwargs
|
||||
):
|
||||
"""
|
||||
Create a document by text.
|
||||
|
||||
:param name: Name of the document
|
||||
:param text: Text content of the document
|
||||
:param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
|
||||
e.g.
|
||||
{
|
||||
'indexing_technique': 'high_quality',
|
||||
'process_rule': {
|
||||
'rules': {
|
||||
'pre_processing_rules': [
|
||||
{'id': 'remove_extra_spaces', 'enabled': True},
|
||||
{'id': 'remove_urls_emails', 'enabled': True}
|
||||
],
|
||||
'segmentation': {
|
||||
'separator': '\n',
|
||||
'max_tokens': 500
|
||||
}
|
||||
},
|
||||
'mode': 'custom'
|
||||
}
|
||||
}
|
||||
:return: Response from the API
|
||||
"""
|
||||
data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"process_rule": {"mode": "automatic"},
|
||||
"name": name,
|
||||
"text": text,
|
||||
}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
url = f"/datasets/{self._get_dataset_id()}/document/create_by_text"
|
||||
return self._send_request("POST", url, json=data, **kwargs)
|
||||
|
||||
def update_document_by_text(
|
||||
self, document_id, name, text, extra_params: dict | None = None, **kwargs
|
||||
):
|
||||
"""
|
||||
Update a document by text.
|
||||
|
||||
:param document_id: ID of the document
|
||||
:param name: Name of the document
|
||||
:param text: Text content of the document
|
||||
:param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
|
||||
e.g.
|
||||
{
|
||||
'indexing_technique': 'high_quality',
|
||||
'process_rule': {
|
||||
'rules': {
|
||||
'pre_processing_rules': [
|
||||
{'id': 'remove_extra_spaces', 'enabled': True},
|
||||
{'id': 'remove_urls_emails', 'enabled': True}
|
||||
],
|
||||
'segmentation': {
|
||||
'separator': '\n',
|
||||
'max_tokens': 500
|
||||
}
|
||||
},
|
||||
'mode': 'custom'
|
||||
}
|
||||
}
|
||||
:return: Response from the API
|
||||
"""
|
||||
data = {"name": name, "text": text}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
url = (
|
||||
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
|
||||
)
|
||||
return self._send_request("POST", url, json=data, **kwargs)
|
||||
|
||||
def create_document_by_file(
|
||||
self, file_path, original_document_id=None, extra_params: dict | None = None
|
||||
):
|
||||
"""
|
||||
Create a document by file.
|
||||
|
||||
:param file_path: Path to the file
|
||||
:param original_document_id: pass this ID if you want to replace the original document (optional)
|
||||
:param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
|
||||
e.g.
|
||||
{
|
||||
'indexing_technique': 'high_quality',
|
||||
'process_rule': {
|
||||
'rules': {
|
||||
'pre_processing_rules': [
|
||||
{'id': 'remove_extra_spaces', 'enabled': True},
|
||||
{'id': 'remove_urls_emails', 'enabled': True}
|
||||
],
|
||||
'segmentation': {
|
||||
'separator': '\n',
|
||||
'max_tokens': 500
|
||||
}
|
||||
},
|
||||
'mode': 'custom'
|
||||
}
|
||||
}
|
||||
:return: Response from the API
|
||||
"""
|
||||
files = {"file": open(file_path, "rb")}
|
||||
data = {
|
||||
"process_rule": {"mode": "automatic"},
|
||||
"indexing_technique": "high_quality",
|
||||
}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
if original_document_id is not None:
|
||||
data["original_document_id"] = original_document_id
|
||||
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
|
||||
return self._send_request_with_files(
|
||||
"POST", url, {"data": json.dumps(data)}, files
|
||||
)
|
||||
|
||||
def update_document_by_file(
|
||||
self, document_id, file_path, extra_params: dict | None = None
|
||||
):
|
||||
"""
|
||||
Update a document by file.
|
||||
|
||||
:param document_id: ID of the document
|
||||
:param file_path: Path to the file
|
||||
:param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
|
||||
e.g.
|
||||
{
|
||||
'indexing_technique': 'high_quality',
|
||||
'process_rule': {
|
||||
'rules': {
|
||||
'pre_processing_rules': [
|
||||
{'id': 'remove_extra_spaces', 'enabled': True},
|
||||
{'id': 'remove_urls_emails', 'enabled': True}
|
||||
],
|
||||
'segmentation': {
|
||||
'separator': '\n',
|
||||
'max_tokens': 500
|
||||
}
|
||||
},
|
||||
'mode': 'custom'
|
||||
}
|
||||
}
|
||||
:return:
|
||||
"""
|
||||
files = {"file": open(file_path, "rb")}
|
||||
data = {}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
url = (
|
||||
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
|
||||
)
|
||||
return self._send_request_with_files(
|
||||
"POST", url, {"data": json.dumps(data)}, files
|
||||
)
|
||||
|
||||
def batch_indexing_status(self, batch_id: str, **kwargs):
|
||||
"""
|
||||
Get the status of the batch indexing.
|
||||
|
||||
:param batch_id: ID of the batch uploading
|
||||
:return: Response from the API
|
||||
"""
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status"
|
||||
return self._send_request("GET", url, **kwargs)
|
||||
|
||||
def delete_dataset(self):
|
||||
"""
|
||||
Delete this dataset.
|
||||
|
||||
:return: Response from the API
|
||||
"""
|
||||
url = f"/datasets/{self._get_dataset_id()}"
|
||||
return self._send_request("DELETE", url)
|
||||
|
||||
def delete_document(self, document_id):
|
||||
"""
|
||||
Delete a document.
|
||||
|
||||
:param document_id: ID of the document
|
||||
:return: Response from the API
|
||||
"""
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}"
|
||||
return self._send_request("DELETE", url)
|
||||
|
||||
def list_documents(
|
||||
self,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
keyword: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Get a list of documents in this dataset.
|
||||
|
||||
:return: Response from the API
|
||||
"""
|
||||
params = {}
|
||||
if page is not None:
|
||||
params["page"] = page
|
||||
if page_size is not None:
|
||||
params["limit"] = page_size
|
||||
if keyword is not None:
|
||||
params["keyword"] = keyword
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents"
|
||||
return self._send_request("GET", url, params=params, **kwargs)
|
||||
|
||||
def add_segments(self, document_id, segments, **kwargs):
|
||||
"""
|
||||
Add segments to a document.
|
||||
|
||||
:param document_id: ID of the document
|
||||
:param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}]
|
||||
:return: Response from the API
|
||||
"""
|
||||
data = {"segments": segments}
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
|
||||
return self._send_request("POST", url, json=data, **kwargs)
|
||||
|
||||
def query_segments(
|
||||
self,
|
||||
document_id,
|
||||
keyword: str | None = None,
|
||||
status: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Query segments in this document.
|
||||
|
||||
:param document_id: ID of the document
|
||||
:param keyword: query keyword, optional
|
||||
:param status: status of the segment, optional, e.g. completed
|
||||
"""
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
|
||||
params = {}
|
||||
if keyword is not None:
|
||||
params["keyword"] = keyword
|
||||
if status is not None:
|
||||
params["status"] = status
|
||||
if "params" in kwargs:
|
||||
params.update(kwargs["params"])
|
||||
return self._send_request("GET", url, params=params, **kwargs)
|
||||
|
||||
def delete_document_segment(self, document_id, segment_id):
|
||||
"""
|
||||
Delete a segment from a document.
|
||||
|
||||
:param document_id: ID of the document
|
||||
:param segment_id: ID of the segment
|
||||
:return: Response from the API
|
||||
"""
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
|
||||
return self._send_request("DELETE", url)
|
||||
|
||||
def update_document_segment(self, document_id, segment_id, segment_data, **kwargs):
|
||||
"""
|
||||
Update a segment in a document.
|
||||
|
||||
:param document_id: ID of the document
|
||||
:param segment_id: ID of the segment
|
||||
:param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True}
|
||||
:return: Response from the API
|
||||
"""
|
||||
data = {"segment": segment_data}
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
|
||||
return self._send_request("POST", url, json=data, **kwargs)
|
||||
@@ -0,0 +1,215 @@
|
||||
import psycopg2
|
||||
from psycopg2 import sql
|
||||
import os
|
||||
import json
|
||||
from datetime import timezone, timedelta
|
||||
|
||||
class PgSql:
|
||||
"""
|
||||
用于连接和操作 PostgreSQL 数据库的类。
|
||||
|
||||
该类封装了数据库连接、关闭连接以及执行特定查询的方法,
|
||||
主要用于从 Dify 应用相关的表中获取数据。
|
||||
"""
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化 PgSql 实例并建立数据库连接。
|
||||
"""
|
||||
self.connection = None
|
||||
self.connect_sql()
|
||||
|
||||
def connect_sql(self):
|
||||
"""
|
||||
连接到 PostgreSQL 数据库。
|
||||
|
||||
使用预定义的凭据连接到 'dify' 数据库。
|
||||
如果连接失败,会打印错误信息。
|
||||
"""
|
||||
try:
|
||||
# 连接数据库
|
||||
self.connection = psycopg2.connect(
|
||||
user="postgres",
|
||||
password="difyai123456",
|
||||
host="172.20.0.145",
|
||||
port=5432,
|
||||
database="dify"
|
||||
)
|
||||
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
print("Error while connecting to PostgreSQL", error)
|
||||
|
||||
def close_connection(self):
|
||||
"""
|
||||
关闭当前的 PostgreSQL 数据库连接。
|
||||
|
||||
如果存在活动的连接,则关闭它并打印确认信息。
|
||||
"""
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
print("PostgreSQL connection is closed")
|
||||
|
||||
|
||||
def get_appinfo(self, appid:str)->dict | None:
|
||||
"""
|
||||
根据应用 ID 从 'apps' 表中获取应用信息。
|
||||
|
||||
Args:
|
||||
appid: 目标应用的 ID。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是列名,值是对应的应用数据。
|
||||
如果未找到应用或发生错误,则返回 None。
|
||||
"""
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM apps WHERE id = %s
|
||||
""",
|
||||
(appid,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
print("Error while getting tenant_id by appid", error)
|
||||
|
||||
|
||||
def get_messages_info(self, appid:str, query:str)->dict | None:
|
||||
"""
|
||||
根据应用 ID 和查询内容从 'messages' 表中获取消息信息。
|
||||
|
||||
Args:
|
||||
appid: 目标应用的 ID。
|
||||
query: 用户查询的具体内容。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是列名,值是对应的消息数据。
|
||||
如果未找到消息或发生错误,则返回 None。
|
||||
"""
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC
|
||||
""",
|
||||
(appid, query)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
print("Error while getting messages_info", error)
|
||||
|
||||
def get_messages_info_by_id(self, message_id:str)->dict | None:
|
||||
"""
|
||||
根据消息 ID 从 'messages' 表中获取消息信息。
|
||||
"""
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE id = %s
|
||||
""",
|
||||
(message_id, )
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
print("Error while getting messages_info", error)
|
||||
|
||||
def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None:
|
||||
"""
|
||||
根据工作流运行 ID 从 'workflow_node_executions' 表中获取节点执行信息。
|
||||
|
||||
Args:
|
||||
workflow_run_id: 目标工作流运行的 ID。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是列名,值是对应的节点执行数据。
|
||||
如果未找到执行信息或发生错误,则返回 None。
|
||||
"""
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s
|
||||
""",
|
||||
(workflow_run_id,)
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return [dict(zip(colnames, row)) for row in result]
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
print("Error while getting workflow_node_executions_info", error)
|
||||
|
||||
class DifyTool:
|
||||
"""
|
||||
提供用于获取 Dify 应用调试信息的工具类。
|
||||
|
||||
该类利用 PgSql 类从数据库中检索与特定应用和查询相关的
|
||||
应用信息、消息详情以及工作流节点执行情况。
|
||||
"""
|
||||
@staticmethod
|
||||
def get_message_debug_info_id(message_id:str)->dict | None:
|
||||
"""
|
||||
根据消息 ID 从 'messages' 表中获取消息信息。
|
||||
"""
|
||||
dify_pgsql = PgSql()
|
||||
messages_info = dify_pgsql.get_messages_info_by_id(message_id)
|
||||
if not messages_info:
|
||||
return None
|
||||
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||
if not workflow_node_executions_info:
|
||||
return None
|
||||
return {
|
||||
"messages_info": messages_info,
|
||||
"workflow_node_executions_info": workflow_node_executions_info
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_message_debug_info(appid:str, query:str)->dict:
|
||||
"""
|
||||
获取指定应用和查询相关的调试信息。
|
||||
|
||||
此静态方法会创建一个临时的 PgSql 实例来查询数据库,
|
||||
然后聚合应用信息、消息信息和工作流节点执行信息。
|
||||
|
||||
Args:
|
||||
appid: 目标应用的 ID。
|
||||
query: 用户查询的具体内容。
|
||||
|
||||
Returns:
|
||||
一个包含 "appinfo", "messages_info", 和
|
||||
"workflow_node_executions_info"键的字典,分别对应
|
||||
查询到的应用数据、消息数据和节点执行数据。
|
||||
"""
|
||||
dify_pgsql = PgSql()
|
||||
appinfo = dify_pgsql.get_appinfo(appid)
|
||||
if not appinfo:
|
||||
return None
|
||||
messages_info = dify_pgsql.get_messages_info(appid, query)
|
||||
if not messages_info:
|
||||
return None
|
||||
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||
if not workflow_node_executions_info:
|
||||
return None
|
||||
return {
|
||||
"appinfo": appinfo,
|
||||
"messages_info": messages_info,
|
||||
"workflow_node_executions_info": workflow_node_executions_info
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(DifyTool.get_message_debug_info("ccf92b97-2789-4a3f-90e0-135a869a37c5", "电力建设计价通软件,导入结算后没有暂列金怎么办?要手动添加么?"))
|
||||
@@ -0,0 +1,54 @@
|
||||
from flask import Flask, request, Response
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.intent_recognition import IntentRecognizer
|
||||
import json
|
||||
import time
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# 初始化意图识别器
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
|
||||
@app.route('/intent_recognize', methods=['POST'])
|
||||
def intent_recognize():
|
||||
try:
|
||||
data = request.get_json(force=True)
|
||||
query = data.get('query')
|
||||
if not query:
|
||||
return Response(json.dumps({"error": "缺少query参数"}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=400)
|
||||
start_time = time.time()
|
||||
classification, keywords, rewrite, query_keys = recognizer.process_query(query)
|
||||
end_time = time.time()
|
||||
print(f"意图识别耗时: {end_time - start_time:.2f}秒")
|
||||
# keywords对象转为字符串
|
||||
keywords_str = ""
|
||||
if keywords and keywords.terms:
|
||||
term_details = []
|
||||
for term in keywords.terms:
|
||||
term_info = {
|
||||
"名称": term.name,
|
||||
"同义词": ";".join(term.synonymous) if term.synonymous else "",
|
||||
"描述": term.description
|
||||
}
|
||||
term_details.append(term_info)
|
||||
keywords_str = term_details
|
||||
result = {
|
||||
"source_query": query,
|
||||
"source_query_keys": query_keys,
|
||||
"vertical_classification": classification.vertical_classification,
|
||||
"sub_classification": classification.sub_classification,
|
||||
"rewrite_query": rewrite.rewrite,
|
||||
"keywords": keywords_str
|
||||
}
|
||||
return Response(json.dumps(result, ensure_ascii=False), content_type='application/json; charset=utf-8')
|
||||
except Exception as e:
|
||||
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=8001)
|
||||
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
from rag2_0.dify.dify_client import ChatClient, DifyClient
|
||||
import pandas as pd
|
||||
# 使用线程池并发执行
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from tqdm import tqdm
|
||||
from rag2_0.dify.dify_tool import DifyTool
|
||||
import json
|
||||
|
||||
class DifyComparisonTester:
|
||||
"""
|
||||
Dify新旧流程对比测试类,用于比较两个不同流程的问答效果
|
||||
"""
|
||||
def __init__(self, excel_path:str, baseurl:str, old_workflow_api_key:str, new_workflow_api_key:str):
|
||||
"""
|
||||
初始化对比测试器
|
||||
|
||||
Args:
|
||||
excel_path: 包含问题的Excel文件路径
|
||||
baseurl: Dify API的基础URL
|
||||
old_workflow_api_key: 旧流程的API密钥
|
||||
new_workflow_api_key: 新流程的API密钥
|
||||
"""
|
||||
self.excel_path = excel_path
|
||||
self.baseurl = baseurl
|
||||
self.old_workflow_api_key = old_workflow_api_key
|
||||
self.new_workflow_api_key = new_workflow_api_key
|
||||
self.old_chat = ChatClient(api_key=old_workflow_api_key, base_url=baseurl)
|
||||
self.new_chat = ChatClient(api_key=new_workflow_api_key, base_url=baseurl)
|
||||
|
||||
def process_question(self, q:str):
|
||||
"""
|
||||
处理单个问题,并行获取新旧流程的回答
|
||||
|
||||
Args:
|
||||
q: 问题内容
|
||||
|
||||
Returns:
|
||||
dict: 包含问题和两个流程回答的字典
|
||||
"""
|
||||
q="qwqwwq"
|
||||
def get_old_answer():
|
||||
try:
|
||||
return self.old_chat.create_chat_message(inputs={}, query=q, user="AutoTestDifyChat").json()
|
||||
except Exception as e:
|
||||
return f"error: {str(e)}"
|
||||
|
||||
def get_new_answer():
|
||||
try:
|
||||
return self.new_chat.create_chat_message(inputs={}, query=q, user="AutoTestDifyChat").json()
|
||||
except Exception as e:
|
||||
return f"error: {str(e)}"
|
||||
|
||||
# 并行执行old_chat和new_chat
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
future_old = executor.submit(get_old_answer)
|
||||
future_new = executor.submit(get_new_answer)
|
||||
|
||||
old_result = future_old.result()
|
||||
new_result = future_new.result()
|
||||
old_message_id = old_result["message_id"]
|
||||
new_message_id = new_result["message_id"]
|
||||
old_message_info = DifyTool.get_message_debug_info_id(message_id=old_message_id)
|
||||
new_message_info = DifyTool.get_message_debug_info_id(message_id=new_message_id)
|
||||
for workflow_node in new_message_info["workflow_node_executions_info"]:
|
||||
if workflow_node["title"] == "问题优化结果解析":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
rewrite_query = outputs["optimize_query"]
|
||||
old_answer = old_result["answer"]
|
||||
new_answer = new_result["answer"]
|
||||
|
||||
return {"问题": q, "问题改写": rewrite_query, "旧流程答案": old_answer, "新流程答案": new_answer}
|
||||
|
||||
def run_comparison(self):
|
||||
"""
|
||||
运行对比测试,处理所有问题并生成结果Excel
|
||||
|
||||
Returns:
|
||||
str: 输出Excel文件的路径
|
||||
"""
|
||||
# 读取Excel文件中的问题
|
||||
df = pd.read_excel(self.excel_path)
|
||||
questions = df.iloc[:,0].tolist()
|
||||
results = []
|
||||
|
||||
# 按顺序处理问题
|
||||
with tqdm(total=len(questions), desc="处理问题进度") as pbar:
|
||||
for q in questions:
|
||||
result = self.process_question(q)
|
||||
results.append(result)
|
||||
pbar.update(1)
|
||||
|
||||
# 生成输出Excel文件
|
||||
out_path = os.path.join(os.path.dirname(self.excel_path), "dify问答_对比结果.xlsx")
|
||||
df_results = pd.DataFrame(results)
|
||||
|
||||
# 使用ExcelWriter设置格式
|
||||
with pd.ExcelWriter(out_path, engine='xlsxwriter') as writer:
|
||||
df_results.to_excel(writer, index=False, sheet_name='Sheet1')
|
||||
|
||||
# 获取工作簿和工作表对象
|
||||
workbook = writer.book
|
||||
worksheet = writer.sheets['Sheet1']
|
||||
|
||||
# 设置列宽
|
||||
worksheet.set_column('A:A', 50) # 问题列宽 50个Excel单位
|
||||
worksheet.set_column('B:B', 70) # 旧流程答案列宽 70个Excel单位
|
||||
worksheet.set_column('C:C', 70) # 新流程答案列宽 70个Excel单位
|
||||
|
||||
return out_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 定义Excel路径
|
||||
excel_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", ".." ,"data/excel/历史提问数据(dislike)_1000条_软件明确.xlsx")
|
||||
|
||||
if not os.path.exists(excel_path):
|
||||
print(f"错误:Excel文件不存在: {excel_path}")
|
||||
exit(1)
|
||||
|
||||
# Dify API配置
|
||||
baseurl = "http://172.20.0.145/v1"
|
||||
old_workflow_api_key = "app-wUdkWJx5zeOvmvBUZizMoSw3"
|
||||
new_workflow_api_key = "app-Lf1pQ1NVwdMfCRVNTBCOTPHT"
|
||||
|
||||
# 创建测试器并运行
|
||||
tester = DifyComparisonTester(excel_path, baseurl, old_workflow_api_key, new_workflow_api_key)
|
||||
output_file = tester.run_comparison()
|
||||
print(f"对比结果已保存至: {output_file}")
|
||||
|
||||
# 单个问题测试示例
|
||||
# c = DifyChat(baseurl="http://172.20.0.145/v1", api_key="app-LjJaeLoAfqa6aoGzqU9UvxSf")
|
||||
# c.chat("如何新建配电线路工程")
|
||||
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: DataModels.py
|
||||
Author: oyyz
|
||||
Date: 2025-05-13
|
||||
Description: 提取和分类的数据模型
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# 定义输出模型
|
||||
class Term(BaseModel):
|
||||
name: str = Field(description="专业名词")
|
||||
synonymous: List[str] = Field(description="同义词列表")
|
||||
description: str = Field(description="描述信息", default="")
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Term):
|
||||
return self.name == other.name
|
||||
return False
|
||||
|
||||
class TermList(BaseModel):
|
||||
terms: List[Term] = Field(description="专业名词列表")
|
||||
|
||||
class Classification(BaseModel):
|
||||
vertical_classification:str = Field(description="垂直领域一级分类")
|
||||
sub_classification:str = Field(description="一级分类下的二级分类")
|
||||
|
||||
class QueryRewrite(BaseModel):
|
||||
rewrite:str = Field(description="问题改写")
|
||||
@@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: IntentRecognition.py
|
||||
Author: oyyz
|
||||
Date: 2025-05-13
|
||||
Description: 意图分类、改写核心逻辑
|
||||
"""
|
||||
|
||||
import os
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
import json
|
||||
from typing import List, Tuple
|
||||
import re
|
||||
from .PromptTemplates import classification_prompt, query_rewrite_prompt, extract_nouns_prompt, classification_info
|
||||
from .DataModels import Classification, QueryRewrite, Term, TermList
|
||||
from .ProfessionalNounVector import ProfessionalNounRetriever
|
||||
from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM
|
||||
|
||||
|
||||
class IntentRecognizer:
|
||||
"""
|
||||
意图识别和问题改写类
|
||||
"""
|
||||
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo", vector_index_dir: str = None):
|
||||
"""
|
||||
初始化意图识别器
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API密钥,如果为None则从环境变量获取
|
||||
base_url: OpenAI API基础URL,如果为None则使用默认URL
|
||||
model_name: 要使用的模型名称
|
||||
vector_index_dir: 向量索引目录,如果为None则使用默认目录
|
||||
"""
|
||||
# 初始化LLM
|
||||
llm_params = {
|
||||
"temperature": 0.2, # 降低随机性,使结果更确定
|
||||
"model": model_name
|
||||
}
|
||||
|
||||
# 如果提供了API密钥,则使用提供的密钥
|
||||
if api_key:
|
||||
llm_params["api_key"] = api_key
|
||||
|
||||
# 如果提供了自定义URL,则使用提供的URL
|
||||
if base_url:
|
||||
llm_params["base_url"] = base_url
|
||||
|
||||
self.llm = OpenAiLLM(**llm_params)
|
||||
|
||||
# 准备分类解析器
|
||||
self.classification_parser = PydanticOutputParser(pydantic_object=Classification)
|
||||
|
||||
# 准备问题改写解析器
|
||||
self.query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
|
||||
|
||||
# 准备术语列表解析器
|
||||
self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
|
||||
|
||||
# 加载suffix关键词
|
||||
self.suffix_keywords = self._load_suffix_keywords()
|
||||
|
||||
# 初始化向量检索器
|
||||
self.noun_retriever = ProfessionalNounRetriever(api_key=api_key, index_dir=vector_index_dir)
|
||||
|
||||
def _load_suffix_keywords(self, filepath: str = None) -> List[str]:
|
||||
"""
|
||||
加载后缀关键词列表
|
||||
|
||||
Args:
|
||||
filepath: 后缀关键词文件路径,默认为None使用默认路径
|
||||
|
||||
Returns:
|
||||
后缀关键词列表
|
||||
"""
|
||||
try:
|
||||
# 如果未指定路径,使用默认路径
|
||||
if filepath is None:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
filepath = os.path.join(current_dir, "..", "..", "data", "nouns", "suffix_keywords.json")
|
||||
|
||||
# 读取JSON文件
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
suffix_data = json.load(f)
|
||||
|
||||
# 添加额外的固定后缀
|
||||
return suffix_data
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"加载后缀关键词失败: {e}") from e
|
||||
|
||||
def classify_intent(self, query: str, keywords: TermList) -> Classification:
|
||||
"""
|
||||
对用户输入进行意图分类
|
||||
|
||||
Args:
|
||||
content: 用户输入内容
|
||||
keywords: 匹配到的关键词列表
|
||||
rewrite: 重写的问题
|
||||
Returns:
|
||||
分类结果
|
||||
"""
|
||||
formatted_prompt = classification_prompt.replace("{user_input}", query)
|
||||
formatted_prompt = formatted_prompt.replace("{classification_info}", classification_info)
|
||||
formatted_prompt = formatted_prompt.replace("{output_format}", self.classification_parser.get_format_instructions())
|
||||
# 将关键词列表转换为JSON字符串
|
||||
terms_dict = [term.model_dump() for term in keywords.terms]
|
||||
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
|
||||
formatted_prompt = formatted_prompt.replace("{keywords}", keywords_str)
|
||||
# 调用LLM
|
||||
response = self.llm.invoke(formatted_prompt, False)
|
||||
|
||||
# 解析输出
|
||||
try:
|
||||
# 尝试直接解析JSON响应
|
||||
parsed_output = self.classification_parser.parse(response.content.strip())
|
||||
return parsed_output
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"解析分类结果时出错: {e}") from e
|
||||
|
||||
def extract_keywords_with_llm(self, query: str) -> List[Term]:
|
||||
"""
|
||||
使用LLM从用户查询中提取专业关键词
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
|
||||
Returns:
|
||||
提取的术语列表
|
||||
"""
|
||||
# 准备提示词
|
||||
formatted_prompt = extract_nouns_prompt.replace("{content}", query)
|
||||
formatted_prompt = formatted_prompt.replace("{output_format}", self.terms_list_parser.get_format_instructions())
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm.invoke(formatted_prompt, False)
|
||||
|
||||
try:
|
||||
# 尝试使用Pydantic解析器解析TermList
|
||||
parsed_output = self.terms_list_parser.parse(response.content)
|
||||
return parsed_output.terms
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法解析LLM关键词提取响应: {e}") from e
|
||||
|
||||
def match_keywords(self, query: str) -> Tuple[TermList, List[str]]:
|
||||
"""
|
||||
从用户问题中匹配关键词,结合LLM提取和向量检索
|
||||
|
||||
Args:
|
||||
query: 用户问题
|
||||
|
||||
Returns:
|
||||
匹配到的关键词列表
|
||||
"""
|
||||
matched_terms = set() # 存储匹配到的Term对象
|
||||
query_keys=[]
|
||||
# 步骤2: 使用LLM提取查询中的关键词
|
||||
try:
|
||||
extracted_terms = self.extract_keywords_with_llm(query)
|
||||
for term in extracted_terms:
|
||||
query_keys.append(term.name)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"LLM关键词提取失败: {e}") from e
|
||||
|
||||
# 步骤3: 使用向量检索找到相似的专业名词
|
||||
try:
|
||||
# 对matched_terms中的每个关键字进行向量检索
|
||||
for current_key in query_keys:
|
||||
vector_results = self.noun_retriever.query(current_key, top_k=3, use_intersection=True)
|
||||
|
||||
# 添加向量检索结果
|
||||
for result in vector_results:
|
||||
term = Term(
|
||||
name=result.get('name'),
|
||||
synonymous=result.get('synonymous', []),
|
||||
description=result.get('description', '')
|
||||
)
|
||||
matched_terms.add(term)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"向量检索关键词时出错: {e}") from e
|
||||
|
||||
if len(matched_terms) != 0:
|
||||
txts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) + "|" + "描述:" + term.description for term in matched_terms]
|
||||
# txts = [term.name for term in matched_terms]
|
||||
xinference_reranker = XinferenceReRankerModel()
|
||||
rerank_results = xinference_reranker.rerank(query, txts, top_k=5)
|
||||
matched_terms_list = list(matched_terms)
|
||||
matched_terms = [matched_terms_list[result["index"]] for result in rerank_results]
|
||||
# 提取所有Term对象的名称并排序
|
||||
# 将set类型的matched_terms转换为TermList类型
|
||||
term_list = TermList(terms=list(matched_terms))
|
||||
return term_list, query_keys
|
||||
|
||||
def rewrite_query(self, query: str, keywords: TermList) -> QueryRewrite:
|
||||
"""
|
||||
对用户问题进行改写
|
||||
|
||||
Args:
|
||||
query: 用户原始问题
|
||||
keywords: 匹配到的关键词列表
|
||||
|
||||
Returns:
|
||||
改写结果
|
||||
"""
|
||||
# 准备问题改写提示
|
||||
terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms]
|
||||
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
|
||||
formatted_prompt = query_rewrite_prompt.format(query=query, output_format=self.query_rewrite_parser.get_format_instructions(),keywords=keywords_str)
|
||||
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm.invoke(formatted_prompt, False)
|
||||
|
||||
# 解析输出
|
||||
try:
|
||||
# 尝试直接解析JSON响应
|
||||
parsed_output = self.query_rewrite_parser.parse(response.content)
|
||||
return parsed_output
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"解析问题改写结果时出错: {e}") from e
|
||||
|
||||
def judge_define_suffix(self, input_str: str) -> Tuple[bool, List[str]]:
|
||||
"""
|
||||
判断输入字符串是否包含定义的后缀,并返回所有匹配到的后缀名列表
|
||||
|
||||
Args:
|
||||
input_str: 输入字符串
|
||||
|
||||
Returns:
|
||||
Tuple[bool, List[str]]: (是否包含定义的后缀, 匹配到的后缀名列表)
|
||||
"""
|
||||
|
||||
# 构建正则表达式模式,匹配大小写不敏感且前面可能带有.
|
||||
pattern = r'(?:\.?)(' + '|'.join(re.escape(field.get('name')) for field in self.suffix_keywords) + r')'
|
||||
|
||||
# 使用 re.IGNORECASE 标志来忽略大小写,findall找到所有匹配
|
||||
matches = re.finditer(pattern, input_str, re.IGNORECASE)
|
||||
matched_suffixes = [match.group(1) for match in matches]
|
||||
|
||||
return bool(matched_suffixes), matched_suffixes
|
||||
|
||||
def process_query(self, query: str) -> Tuple[Classification, TermList, QueryRewrite, List[str]]:
|
||||
"""
|
||||
处理用户问题的完整流程
|
||||
|
||||
Args:
|
||||
query: 用户原始问题
|
||||
|
||||
Returns:
|
||||
(意图分类结果, 匹配的关键词列表, 问题改写结果)的元组
|
||||
"""
|
||||
# 是否是扩展名
|
||||
# is_suffix, matched_suffixes = self.judge_define_suffix(query)
|
||||
# if is_suffix:
|
||||
# # 将所有匹配到的后缀名作为Term添加到结果中
|
||||
# suffix_terms = []
|
||||
# for suffix in matched_suffixes:
|
||||
# term_dict = next((item for item in self.suffix_keywords if item['name'].lower() == suffix.lower()), None)
|
||||
# if term_dict:
|
||||
# suffix_term = Term(
|
||||
# name=term_dict.get('name'),
|
||||
# synonymous=term_dict.get('synonymous', []),
|
||||
# description=json.dumps(term_dict.get('description', ''), ensure_ascii=False)
|
||||
# )
|
||||
# suffix_terms.append(suffix_term)
|
||||
|
||||
# return Classification(vertical_classification="安装下载", sub_classification="查询"), TermList(terms=suffix_terms), QueryRewrite(rewrite=query), matched_suffixes
|
||||
|
||||
# 步骤1: 匹配关键词
|
||||
keywords_terms, query_keys = self.match_keywords(query)
|
||||
|
||||
# 步骤2: 问题改写
|
||||
rewrite = self.rewrite_query(
|
||||
query=query,
|
||||
keywords=keywords_terms
|
||||
)
|
||||
|
||||
# 步骤3: 进行意图分类
|
||||
classification = self.classify_intent(query, keywords_terms)
|
||||
if classification.vertical_classification == "其他" or classification.sub_classification == "其他":
|
||||
return classification, TermList(terms=[]), QueryRewrite(rewrite=query), []
|
||||
|
||||
if classification.vertical_classification == "闲聊" or classification.sub_classification == "闲聊":
|
||||
return classification, TermList(terms=[]), QueryRewrite(rewrite=query),[]
|
||||
|
||||
# rewrite = QueryRewrite(rewrite=query)
|
||||
return classification, keywords_terms, rewrite, query_keys
|
||||
@@ -0,0 +1,321 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: ProfessionalNounVector.py
|
||||
Date: 2025-05-15
|
||||
Author: oyyz
|
||||
Description: 专业名词向量化和检索的核心逻辑
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from rag2_0.tool.ModelTool import SiliconFlowEmbeddings
|
||||
import logging
|
||||
|
||||
def get_embedding_model(api_key: str = None) -> Embeddings:
|
||||
"""
|
||||
获取嵌入模型
|
||||
|
||||
Args:
|
||||
api_key: API密钥,如果为None则从环境变量获取
|
||||
|
||||
Returns:
|
||||
嵌入模型实例
|
||||
"""
|
||||
if not api_key:
|
||||
api_key = os.getenv("SILICONFLOW_API_KEY", "sk-ftnofbucchwnscojohyxwmfzgaykdxihafnlphohsinftkbr")
|
||||
return SiliconFlowEmbeddings(api_key=api_key)
|
||||
|
||||
|
||||
class ProfessionalNounVectorizer:
|
||||
"""专业名词向量化和保存类"""
|
||||
|
||||
def __init__(self,
|
||||
embedding_model: Optional[Embeddings] = None,
|
||||
api_key: str = None,
|
||||
output_dir: str = None):
|
||||
"""
|
||||
初始化向量化器
|
||||
|
||||
Args:
|
||||
embedding_model: 嵌入模型,如果为None则使用默认模型
|
||||
api_key: SiliconFlow API密钥,仅在embedding_model为None时使用
|
||||
|
||||
output_dir: 索引输出目录,默认为None使用默认路径
|
||||
"""
|
||||
# 设置嵌入模型
|
||||
if embedding_model:
|
||||
self.embedding_model = embedding_model
|
||||
else:
|
||||
self.embedding_model = get_embedding_model(api_key)
|
||||
|
||||
|
||||
# 设置输出目录
|
||||
self.output_dir = output_dir
|
||||
if self.output_dir is None:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.output_dir = os.path.join(current_dir, "..", "..", "data", "nouns")
|
||||
|
||||
# 设置索引路径
|
||||
self.index_path = os.path.join(self.output_dir, "professional_nouns_index")
|
||||
|
||||
def _loadfile(self, file_paths: List[str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
加载多个专业术语JSON文件并合并
|
||||
|
||||
Args:
|
||||
file_paths: JSON文件路径列表
|
||||
|
||||
Returns:
|
||||
合并后的术语列表
|
||||
"""
|
||||
merged_terms = []
|
||||
|
||||
try:
|
||||
for file_path in file_paths:
|
||||
if not os.path.exists(file_path):
|
||||
logging.warning(f"文件不存在: {file_path}")
|
||||
continue
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
terms_data = json.load(f)
|
||||
|
||||
if isinstance(terms_data, list):
|
||||
merged_terms.extend(terms_data)
|
||||
logging.info(f"从 {file_path} 加载了 {len(terms_data)} 条专业名词")
|
||||
else:
|
||||
logging.warning(f"文件格式错误: {file_path},应为JSON数组")
|
||||
|
||||
logging.info(f"总共加载了 {len(merged_terms)} 条专业名词")
|
||||
return merged_terms
|
||||
except Exception as e:
|
||||
logging.error(f"加载多个文件失败: {e}")
|
||||
return []
|
||||
|
||||
def vectorize_files_and_save(self, file_paths: List[str]) -> bool:
|
||||
"""
|
||||
处理多个文件:加载多个术语文件、创建索引并保存
|
||||
|
||||
Args:
|
||||
file_paths: JSON文件路径列表
|
||||
|
||||
Returns:
|
||||
处理成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
# 加载多个文件的术语
|
||||
terms = self._loadfile(file_paths)
|
||||
|
||||
if not terms:
|
||||
logging.warning("未找到术语数据,退出处理")
|
||||
return False
|
||||
|
||||
# 根据名称去重
|
||||
unique_terms = {}
|
||||
for term in terms:
|
||||
name = term.get("name", "")
|
||||
if name and name not in unique_terms:
|
||||
unique_terms[name] = term
|
||||
|
||||
# 转换回列表
|
||||
deduplicated_terms = list(unique_terms.values())
|
||||
logging.info(f"去重后剩余 {len(deduplicated_terms)} 条专业名词")
|
||||
|
||||
# 准备数据
|
||||
texts, metadatas = self._prepare_terms_for_faiss(deduplicated_terms)
|
||||
|
||||
# 创建索引
|
||||
faiss_index = self._create_index(texts, metadatas)
|
||||
|
||||
# 保存索引
|
||||
self._save_index(faiss_index)
|
||||
|
||||
logging.info("完成多文件专业名词向量化和索引创建")
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"多文件向量化处理失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _prepare_terms_for_faiss(self, terms: List[Dict[str, Any]]) -> Tuple[List[str], List[Dict]]:
|
||||
"""
|
||||
将术语准备为FAISS可用的格式 (内部方法)
|
||||
|
||||
Args:
|
||||
terms: 术语列表
|
||||
|
||||
Returns:
|
||||
格式化的术语文本列表和元数据列表
|
||||
"""
|
||||
texts = []
|
||||
metadatas = []
|
||||
|
||||
for term in terms:
|
||||
name = term["name"]
|
||||
texts.append(name.strip())
|
||||
synonyms = term.get("synonymous", [])
|
||||
description = term.get("description", "")
|
||||
# 记录元数据
|
||||
metadatas.append({
|
||||
"name": name,
|
||||
"synonyms": synonyms,
|
||||
"description": description
|
||||
})
|
||||
|
||||
if len(synonyms) > 0:
|
||||
synonyms_str = ', '.join(synonyms)
|
||||
texts.append(synonyms_str.strip())
|
||||
metadatas.append({
|
||||
"name": name,
|
||||
"synonyms": synonyms,
|
||||
"description": description
|
||||
})
|
||||
|
||||
if len(description) > 0:
|
||||
texts.append(description.strip())
|
||||
metadatas.append({
|
||||
"name": name,
|
||||
"synonyms": synonyms,
|
||||
"description": description
|
||||
})
|
||||
|
||||
return texts, metadatas
|
||||
|
||||
def _create_index(self, texts: List[str], metadatas: List[Dict]) -> FAISS:
|
||||
"""
|
||||
创建FAISS索引 (内部方法)
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
metadatas: 元数据列表
|
||||
|
||||
Returns:
|
||||
FAISS索引
|
||||
"""
|
||||
logging.info(f"正在创建FAISS索引,共 {len(texts)} 条数据...")
|
||||
return FAISS.from_texts(texts=texts, embedding=self.embedding_model, metadatas=metadatas)
|
||||
|
||||
def _save_index(self, faiss_index: FAISS) -> None:
|
||||
"""
|
||||
保存FAISS索引到本地 (内部方法)
|
||||
|
||||
Args:
|
||||
faiss_index: 要保存的FAISS索引
|
||||
"""
|
||||
try:
|
||||
# 确保输出目录存在
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
|
||||
# 如果索引目录已存在,先删除
|
||||
if os.path.exists(self.index_path):
|
||||
shutil.rmtree(self.index_path)
|
||||
|
||||
# 保存FAISS索引
|
||||
faiss_index.save_local(folder_path=self.index_path)
|
||||
logging.info(f"FAISS索引已保存至 {self.index_path}")
|
||||
except Exception as e:
|
||||
logging.error(f"保存FAISS索引失败: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
class ProfessionalNounRetriever:
|
||||
"""专业名词检索类"""
|
||||
|
||||
def __init__(self,
|
||||
embedding_model: Optional[Embeddings] = None,
|
||||
api_key: str = None,
|
||||
index_dir: str = None):
|
||||
"""
|
||||
初始化检索器并加载索引
|
||||
|
||||
Args:
|
||||
embedding_model: 嵌入模型,如果为None则使用默认模型
|
||||
api_key: SiliconFlow API密钥,仅在embedding_model为None时使用
|
||||
index_dir: 索引目录路径,默认为None使用默认路径
|
||||
"""
|
||||
# 设置嵌入模型
|
||||
if embedding_model:
|
||||
self.embedding_model = embedding_model
|
||||
else:
|
||||
self.embedding_model = get_embedding_model(api_key)
|
||||
|
||||
# 设置索引路径
|
||||
self.index_dir = index_dir
|
||||
if self.index_dir is None:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.index_dir = os.path.join(current_dir, "..", "..", "data", "nouns", "professional_nouns_index")
|
||||
|
||||
# 在构造函数中加载索引
|
||||
self.faiss_index = None
|
||||
self._load_index()
|
||||
|
||||
def _load_index(self) -> None:
|
||||
"""
|
||||
从本地加载FAISS索引 (内部方法)
|
||||
"""
|
||||
try:
|
||||
# 加载FAISS索引,启用不安全反序列化(仅用于可信数据源)
|
||||
self.faiss_index = FAISS.load_local(
|
||||
folder_path=self.index_dir,
|
||||
embeddings=self.embedding_model,
|
||||
allow_dangerous_deserialization=True
|
||||
)
|
||||
logging.info(f"成功从 {self.index_dir} 加载FAISS索引")
|
||||
except Exception as e:
|
||||
logging.warning(f"加载FAISS索引失败: {e}")
|
||||
self.faiss_index = None
|
||||
|
||||
def query(self, query_text: str, top_k: int = 5, use_intersection: bool = True) -> List[Dict]:
|
||||
"""
|
||||
查询FAISS索引,获取最相似的专业名词 (唯一对外接口)
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
top_k: 返回的结果数量,默认为5
|
||||
use_intersection: 是否使用三种检索方式的交集,默认为True
|
||||
|
||||
Returns:
|
||||
相似度最高的专业名词列表
|
||||
"""
|
||||
try:
|
||||
# 检查索引是否已加载
|
||||
if self.faiss_index is None:
|
||||
logging.warning("FAISS索引未加载,无法执行查询")
|
||||
return []
|
||||
|
||||
# 使用三种检索方式并取交集
|
||||
retriever1 = self.faiss_index.as_retriever(search_kwargs={"k": top_k})
|
||||
retriever2 = self.faiss_index.as_retriever(
|
||||
search_type="mmr",
|
||||
search_kwargs={"k": top_k, "fetch_k": 3, "lambda_mult": 0.5}
|
||||
)
|
||||
retriever3 = self.faiss_index.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={"score_threshold": 0.5}
|
||||
)
|
||||
|
||||
# 用json.dumps将dict转为字符串,便于取交集
|
||||
set1 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False)
|
||||
for i in retriever1.invoke(query_text))
|
||||
set2 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False)
|
||||
for i in retriever2.invoke(query_text))
|
||||
set3 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False)
|
||||
for i in retriever3.invoke(query_text))
|
||||
|
||||
intersection = set1 | set2 | set3
|
||||
|
||||
# 如果交集为空,使用第一种检索方式的结果
|
||||
if not intersection:
|
||||
logging.warning("三种检索方式无交集,使用普通检索结果")
|
||||
return [json.loads(item) for item in set1]
|
||||
|
||||
# 转回dict
|
||||
return [json.loads(item) for item in intersection]
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"查询FAISS索引失败: {e}")
|
||||
return []
|
||||
@@ -0,0 +1,130 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: PromptTemplates.py
|
||||
Author: oyyz
|
||||
Date: 2025-05-13
|
||||
Description: 提示词模板
|
||||
"""
|
||||
|
||||
extract_nouns_prompt="""
|
||||
【智能关键词提取助手】
|
||||
请根据用户问题自动识别核心关键词,并按照以下规则输出:
|
||||
1. 只输出最终关键词列表,不要解释说明
|
||||
2. 关键词提取范围包括但不限于以下内容:
|
||||
- 软件相关:功能模块/操作步骤/报错提示/扩展名后缀名
|
||||
- 造价专业:费用类型/计算标准/行业规范
|
||||
- 电力工程:项目类型/设备型号/工程阶段
|
||||
3. 自动展开缩写(如将'导excel'转为'Excel导入')
|
||||
4. 严格基于用户问题提取关键词,不要输出与用户问题无关的关键词
|
||||
|
||||
三、输出格式:
|
||||
{output_format}
|
||||
|
||||
四、用户问题:
|
||||
{content}
|
||||
|
||||
"""
|
||||
|
||||
classification_info="""【垂直领域分类】:
|
||||
1. 软件问题 -- 指涉及软件使用、功能询问、软件故障排查等方面的提问或请求。
|
||||
2. 业务问题 -- 指涉及电力造价领域专业知识、造价费用计算等电力造价业务知识
|
||||
3. 安装下载注册 -- 指涉及软件(或插件)安装下载、注册、激活等操作类问题。
|
||||
4. 其他 -- 指与软件或电力造价专业无关的日常对话、问候、感慨、情绪表达等。
|
||||
|
||||
【软件问题包括以下两类】:
|
||||
1. 软件功能:询问软件功能的使用、操作、位置等
|
||||
2. 故障排查:软件运行异常、软件报错、软件显示错误等
|
||||
|
||||
【业务问题包括以下两类】:
|
||||
1. 专业咨询:涉及电力造价规范、工程计价规则问题、行业标准解读等
|
||||
2. 数据问题:涉及电力造价费用、造价指标等
|
||||
|
||||
【安装下载注册包括以下三类】:
|
||||
1. 后缀名查询:询问有关软件后缀名、工程文件扩展名等问题,例如:BDY3是什么文件?、用什么软件打开.BDY3文件?
|
||||
2. 软件锁类:询问软件锁信息、锁注册号查询、许可证查询、锁激活问题等软件锁相关问题
|
||||
3. 安装下载类:安装下载咨询、组件(插件)选择、环境配置等
|
||||
4. 问题排查类:软件安装下载失败、报错,系统兼容性问题等
|
||||
|
||||
【其他】:
|
||||
1. 其他"""
|
||||
|
||||
classification_prompt="""
|
||||
用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容,将其归类为以下垂直领域之一:
|
||||
{classification_info}
|
||||
|
||||
【用户输入】:
|
||||
{user_input}
|
||||
|
||||
【输出格式要求】:
|
||||
{output_format}
|
||||
|
||||
【示例】
|
||||
用户输入1: 技改T1怎样新建工程
|
||||
输出1:
|
||||
{
|
||||
"vertical_classification":"软件咨询",
|
||||
"sub_classification":"软件功能"
|
||||
}
|
||||
"""
|
||||
|
||||
query_rewrite_prompt = """
|
||||
|
||||
你是一名电力造价专业问答优化工程师,负责通过多维度信息整合重构用户问题以提升知识库检索准确率。请严格遵循以下流程处理:
|
||||
|
||||
# 任务处理框架
|
||||
## 第一阶段:输入分析
|
||||
1. 解析基础信息
|
||||
- 原始问题(需保留核心语义):{query}
|
||||
- 关键词集合:{keywords}
|
||||
|
||||
## 第二阶段:语义匹配验证
|
||||
2. 执行关键词校验
|
||||
- 建立意图关联矩阵,验证关键词与原始问题的语义一致性
|
||||
- 若存在≥1个有效关联词 → 进入重构流程
|
||||
- 若无有效关联 → 直接输出原始问题
|
||||
|
||||
## 第三阶段:专业重构
|
||||
3. 术语规范化处理
|
||||
a. 实施术语映射:将口语表达替换为知识库标准术语
|
||||
b. 执行结构优化:
|
||||
- 采用【术语标记】规范标注关键概念
|
||||
- 构建主谓宾明确的问题句式
|
||||
- 保持原问题时态与语态特征
|
||||
|
||||
# 输出规范
|
||||
{output_format}
|
||||
|
||||
# 示范案例库
|
||||
▶ 案例1(有效匹配)
|
||||
输入:
|
||||
原始问题:怎么把旧版西藏定额工程转到Z1新版
|
||||
关键词:【'老版本定额升级', '批量设置定额', '西藏造价软件Z1'】
|
||||
输出:
|
||||
{{"rewrite":"【西藏造价软件Z1】如何执行【老版本定额升级】操作?"}}
|
||||
|
||||
▶ 案例2(无效匹配)
|
||||
输入:
|
||||
原始问题:程序界面文字显示过小如何处理?
|
||||
关键词:【'定额升级', '工程批量导入'】
|
||||
输出:
|
||||
{{"rewrite":"程序界面文字显示过小如何处理?"}}
|
||||
|
||||
# 质量约束条款
|
||||
1. 语义内容保真原则
|
||||
- 禁止修改原问题核心诉求(如转换主语/变更操作对象)
|
||||
- 保留原始问题的限定条件
|
||||
|
||||
2. 术语使用规范
|
||||
- 仅使用检索返回的关键词进行术语替换
|
||||
- 新增术语必须来自关键词集合
|
||||
|
||||
3. 结构优化标准
|
||||
- 问题长度控制在20字内
|
||||
- 必须包含≥1个【标注术语】
|
||||
- 禁止添加解释性语句
|
||||
|
||||
4. 异常处理机制
|
||||
- 当关键词与问题无明显关联时,触发直通输出规则
|
||||
- 出现术语冲突时优先保留原始表述
|
||||
"""
|
||||
@@ -0,0 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from .ProfessionalNounVector import ProfessionalNounVectorizer, ProfessionalNounRetriever
|
||||
from .IntentRecognition import IntentRecognizer
|
||||
from .DataModels import Term, TermList, Classification, QueryRewrite
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,256 @@
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import List, Optional, Dict
|
||||
from threading import Lock
|
||||
|
||||
API_KEY_LIST=[
|
||||
"sk-xxaiabmfhzwwpijuledllkmkzhzwsqeicjxmjwnvriqpwmpk",
|
||||
"sk-lldcprpqjhgdimiwewgbthngfbrazhkiuioubmaatrcpjjum",
|
||||
"sk-bppugibbtvujomvoysnbcdzpcwndxtwrkfvmgbkbzcmobdon",
|
||||
"sk-hnqitgdlfrrnpimcfxigqibstqquintnzpiidsshpajjyxqd",
|
||||
"sk-hrojkkkrrkmsajtnizokbcgexsfggdiqavbtvbayuwqbnmom",
|
||||
"sk-kkdklmnyompoiotzkfqahpayzlkgogfudjkyaebehtsowvid",
|
||||
"sk-sfxzvllifafbyfduupcdtcrjwhdyiyojnksyopnfslurnhsp",
|
||||
"sk-faqirxiszukfswqvzqawxnemqfacrkyurbxxkzwbbujqacdp",
|
||||
"sk-vonaanuueqiczppkntjuphateshrcpqpnvxmwxorkyihjmrb",
|
||||
"sk-qfpeoodgupcukcdstjcxgegwxnuhtxkkrupkogkcvhavxgny",
|
||||
"sk-fsvjnbpfgoadixympaabaukupuhjvbturcbxaqfdzjznemtr",
|
||||
"sk-fltvnbiqntfawjwkfnnhmyfiimzgzxkweqmefcfqkbucwrhi",
|
||||
"sk-oosswdriwyqkglwdigvcxgmcpyplcyowicbaugpizoscevdl",
|
||||
"sk-jswtxhkiralnyiukqimtyuurcaepulxdrfijadtxzrgsajyc",
|
||||
"sk-dcjuhoukdyrbneadtxtnyxzmigkpiqgtqqnreiprxpioftsv",
|
||||
"sk-yrhezyuxjblpaxzzudbowqmvcoxcammupcubghbodolikbdk",
|
||||
"sk-dsgvwpfagmarilmnewwbzhfzlqehburoupjaopucdvybpbdo",
|
||||
"sk-oljjlspuaurtoczyekztiidwtoerugadgepiufclpmrbdfqc",
|
||||
"sk-crgrimubjesthvxuqwedqqdoetljyrgeahxxpctfefgnkpyo",
|
||||
"sk-tubqhwgycxrdhwsqzjopxgeaqpsjdfppckckayvzornaluwq",
|
||||
"sk-amcxlmsdnadptpnehqnkvseolacipztmvovnmxojzohbjjil",
|
||||
"sk-pdyymhshpzmdduwxsezthnrgarnnhgzvmiflbpisfzxkiayt",
|
||||
"sk-qhwoorywmejumyudfxbrkegxtqifsbgcdkmpjckezepgyqnz",
|
||||
"sk-cpoctrgcnstaybeyuieuwjdgeakudhqdnnwdjavjudcbvvem",
|
||||
]
|
||||
|
||||
class APIKeyManager:
|
||||
"""
|
||||
API密钥管理器,用于解析环境变量中的多个API密钥并提供获取接口
|
||||
支持密钥轮转使用
|
||||
"""
|
||||
# 类变量,用于保存单例实例
|
||||
_instance = None
|
||||
_lock = Lock()
|
||||
|
||||
# 密钥使用计数和上次使用时间
|
||||
_key_usage: Dict[str, Dict] = {}
|
||||
# 当前正在使用的密钥索引
|
||||
_current_index = 0
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, env_var_name: str = "OPENAI_API_KEY", separator: str = ";"):
|
||||
"""
|
||||
获取单例实例
|
||||
|
||||
Args:
|
||||
env_var_name: 环境变量名称,默认为'OPENAI_API_KEY'
|
||||
separator: 密钥分隔符,默认为分号
|
||||
|
||||
Returns:
|
||||
APIKeyManager实例
|
||||
"""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls(env_var_name, separator)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_api_key(cls) -> Optional[str]:
|
||||
"""
|
||||
静态方法:获取一个API密钥,使用轮转策略
|
||||
|
||||
Returns:
|
||||
API密钥,如果没有可用的密钥则返回None
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
return instance._get_next_api_key()
|
||||
|
||||
@classmethod
|
||||
def get_random_api_key(cls) -> Optional[str]:
|
||||
"""
|
||||
静态方法:随机获取一个API密钥
|
||||
|
||||
Returns:
|
||||
API密钥,如果没有可用的密钥则返回None
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
return instance._get_random_api_key()
|
||||
|
||||
@classmethod
|
||||
def get_valid_api_keys(cls) -> List[str]:
|
||||
"""
|
||||
静态方法:获取有效的API密钥列表
|
||||
|
||||
Returns:
|
||||
"""
|
||||
# 验证每一个apikey是否有效,无效则删除并打印日志。地址https://api.siliconflow.cn/v1/
|
||||
import requests
|
||||
import logging
|
||||
|
||||
valid_api_keys = []
|
||||
url = "https://api.siliconflow.cn/v1/chat/completions"
|
||||
headers_template = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"messages": [
|
||||
{"role": "user", "content": "ping"}
|
||||
],
|
||||
"max_tokens": 1
|
||||
}
|
||||
for key in API_KEY_LIST:
|
||||
headers = headers_template.copy()
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
try:
|
||||
resp = requests.post(url, headers=headers, json=data, timeout=8)
|
||||
if resp.status_code == 200:
|
||||
valid_api_keys.append(key)
|
||||
else:
|
||||
logging.warning(f"API密钥无效(被移除): {key}, 状态码: {resp.status_code}, 响应: {resp.text}")
|
||||
except Exception as e:
|
||||
logging.warning(f"API密钥验证异常(被移除): {key}, 错误: {e}")
|
||||
return valid_api_keys
|
||||
|
||||
@classmethod
|
||||
def count(cls) -> int:
|
||||
"""
|
||||
静态方法:获取API密钥数量
|
||||
|
||||
Returns:
|
||||
API密钥数量
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
return len(instance.api_keys)
|
||||
|
||||
def __init__(self, env_var_name: str = "OPENAI_API_KEY", separator: str = ";"):
|
||||
"""
|
||||
初始化API密钥管理器
|
||||
|
||||
Args:
|
||||
env_var_name: 环境变量名称,默认为'OPENAI_API_KEY'
|
||||
separator: 密钥分隔符,默认为分号
|
||||
"""
|
||||
self.env_var_name = env_var_name
|
||||
self.separator = separator
|
||||
self.api_keys = self._load_api_keys()
|
||||
|
||||
# 初始化密钥使用统计
|
||||
for key in self.api_keys:
|
||||
if key not in self._key_usage:
|
||||
self._key_usage[key] = {
|
||||
"count": 0,
|
||||
"last_used": 0
|
||||
}
|
||||
|
||||
def _load_api_keys(self) -> List[str]:
|
||||
"""
|
||||
从环境变量加载API密钥
|
||||
|
||||
Returns:
|
||||
API密钥列表
|
||||
"""
|
||||
# api_keys = []
|
||||
# env_value = os.environ.get(self.env_var_name)
|
||||
|
||||
# if env_value:
|
||||
# # 分割环境变量并移除空白字符
|
||||
# keys = [key.strip() for key in env_value.split(self.separator)]
|
||||
# # 过滤掉空字符串
|
||||
# api_keys = [key for key in keys if key]
|
||||
|
||||
# return api_keys
|
||||
return API_KEY_LIST
|
||||
|
||||
def _get_next_api_key(self) -> Optional[str]:
|
||||
"""
|
||||
获取下一个API密钥,使用轮转策略
|
||||
|
||||
Returns:
|
||||
API密钥,如果没有可用的密钥则返回None
|
||||
"""
|
||||
if not self.api_keys:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
# 轮转到下一个密钥
|
||||
self._current_index = (self._current_index + 1) % len(self.api_keys)
|
||||
selected_key = self.api_keys[self._current_index]
|
||||
|
||||
# 更新使用统计
|
||||
self._key_usage[selected_key]["count"] += 1
|
||||
self._key_usage[selected_key]["last_used"] = time.time()
|
||||
|
||||
return selected_key
|
||||
|
||||
def _get_random_api_key(self) -> Optional[str]:
|
||||
"""
|
||||
随机获取一个API密钥
|
||||
|
||||
Returns:
|
||||
API密钥,如果没有可用的密钥则返回None
|
||||
"""
|
||||
if not self.api_keys:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
selected_key = random.choice(self.api_keys)
|
||||
|
||||
# 更新使用统计
|
||||
self._key_usage[selected_key]["count"] += 1
|
||||
self._key_usage[selected_key]["last_used"] = time.time()
|
||||
|
||||
return selected_key
|
||||
|
||||
def get_all_api_keys(self) -> List[str]:
|
||||
"""
|
||||
获取所有API密钥
|
||||
|
||||
Returns:
|
||||
API密钥列表
|
||||
"""
|
||||
return self.api_keys.copy()
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""
|
||||
检查是否有可用的API密钥
|
||||
|
||||
Returns:
|
||||
如果有可用的API密钥则返回True,否则返回False
|
||||
"""
|
||||
return len(self.api_keys) > 0
|
||||
|
||||
def get_usage_stats(self) -> Dict:
|
||||
"""
|
||||
获取密钥使用统计信息
|
||||
|
||||
Returns:
|
||||
密钥使用统计信息
|
||||
"""
|
||||
return self._key_usage.copy()
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
|
||||
# 获取有效的API密钥列表
|
||||
valid_keys = APIKeyManager.get_valid_api_keys()
|
||||
print(f"有效的API密钥列表:\n" + "\n".join(valid_keys))
|
||||
|
||||
# 查看总密钥数
|
||||
print(f"总共有 {APIKeyManager.count()} 个API密钥")
|
||||
|
||||
# 获取实例并查看使用统计
|
||||
instance = APIKeyManager.get_instance()
|
||||
stats = instance.get_usage_stats()
|
||||
for key, data in stats.items():
|
||||
print(f"密钥 {key[:5]}... 使用次数: {data['count']}")
|
||||
@@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: ModelTool.py
|
||||
Date: 2025-05-15
|
||||
Author: oyyz
|
||||
Description: 模型工具类
|
||||
"""
|
||||
|
||||
from openai import OpenAI
|
||||
import httpx
|
||||
import time
|
||||
import logging # 导入 logging 模块
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from typing import List, Any
|
||||
import requests
|
||||
import os
|
||||
import logging
|
||||
from .APIKeyManager import APIKeyManager
|
||||
|
||||
class SiliconFlowEmbeddings(Embeddings):
|
||||
"""SiliconFlow嵌入模型封装"""
|
||||
def __init__(self, api_key: str, model: str = "bge-m3"):
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.url = "http://10.1.16.39:9995/v1/embeddings"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
def _embed(self, input: List[str]) -> List[List[float]]:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": input,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
response = requests.post(self.url, json=payload, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return [item["embedding"] for item in data["data"]]
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return self._embed(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self._embed([text])[0]
|
||||
|
||||
class XinferenceReRankerModel:
|
||||
"""重排模型封装"""
|
||||
|
||||
@staticmethod
|
||||
def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]:
|
||||
"""
|
||||
使用重排序模型对文档进行重新排序
|
||||
|
||||
Args:
|
||||
query: 用户查询文本
|
||||
documents: 需要重新排序的文档列表
|
||||
top_k: 返回排序后的前k个文档
|
||||
|
||||
Returns:
|
||||
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
||||
"""
|
||||
url = "http://10.1.16.39:9995/v1/rerank"
|
||||
|
||||
|
||||
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": os.getenv("RERANKER_MODEL_NAME")}
|
||||
headers = {
|
||||
"Authorization": "Bearer <token>", # 这里需要替换为实际的token
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=params, headers=headers)
|
||||
response.raise_for_status() # 检查响应状态
|
||||
results = response.json()
|
||||
|
||||
# 返回重排序后的文档列表
|
||||
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logging.error(f"重排序请求失败: {str(e)}")
|
||||
return []
|
||||
|
||||
class OpenAiLLM:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if kwargs.get("api_key") == None or kwargs.get("base_url") == None or kwargs.get("model") == None:
|
||||
raise ValueError("api_key, base_url, model 不能为空")
|
||||
|
||||
self._api_key = kwargs.get("api_key")
|
||||
self._url = kwargs.get("base_url")
|
||||
self._model = kwargs.get("model")
|
||||
|
||||
kwargs.pop("api_key")
|
||||
kwargs.pop("base_url")
|
||||
kwargs.pop("model")
|
||||
self._kwargs = kwargs
|
||||
|
||||
def invoke(self, user_prompt="你是谁?", need_retry=True):
|
||||
# 初始化 OpenAI 客户端
|
||||
api_key = APIKeyManager.get_api_key()
|
||||
client = OpenAI(api_key=api_key, base_url=self._url)
|
||||
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
if need_retry:
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# 创建 Completion 请求. 超时120s
|
||||
completion = client.chat.completions.create(
|
||||
model=self._model,
|
||||
messages=[{'role': 'user', 'content': user_prompt}],
|
||||
timeout=httpx.Timeout(300.0),
|
||||
**self._kwargs
|
||||
)
|
||||
return completion.choices[0].message
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count == max_retries:
|
||||
logging.error(f"LLM 重试{max_retries}次后仍然失败: {e}")
|
||||
return ""
|
||||
else:
|
||||
time.sleep(5*retry_count) # 重试前等待1秒
|
||||
else:
|
||||
# 创建 Completion 请求. 超时120s
|
||||
completion = client.chat.completions.create(
|
||||
model=self._model,
|
||||
messages=[{'role': 'user', 'content': user_prompt}],
|
||||
timeout=httpx.Timeout(300.0),
|
||||
**self._kwargs
|
||||
)
|
||||
return completion.choices[0].message
|
||||
|
||||
if __name__ == "__main__":
|
||||
reranker = XinferenceReRankerModel()
|
||||
query = "什么是AI"
|
||||
documents = ["AI是人工智能", "AI是机器学习", "AI是深度学习"]
|
||||
results = reranker.rerank(query, documents)
|
||||
print(results)
|
||||
@@ -0,0 +1,159 @@
|
||||
import os.path
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class WikijsTool:
|
||||
BASE_URL = "http://10.1.16.39:8090/graphql"
|
||||
HEADERS = {
|
||||
"Authorization": "Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJhcGkiOjcsImdycCI6MSwiaWF"
|
||||
"0IjoxNzIzMDIwNzg4LCJleHAiOjE4MTc2OTM1ODgsImF1ZCI6InVybjp3aWtpLmpzIiwiaX"
|
||||
"NzIjoidXJuOndpa2kuanMifQ.NSfE4tB7tkN8yapAs0CgkR-Yll6wc3gO3QGKMAv-TlGxx6A-9fJRmkwhRDTVMj_yPVG6"
|
||||
"NXVy_AZpJtLapRXFGn0cvscsRJxq3fY1KgEyt8wO99jvd8DpNHpHhAIgrtyDelmHsBD2Wb5Ib3WJFsWC6d8Yhm9dkpx6tZ"
|
||||
"vMAlFIKOg6UodMoMIry3YWiPGLaqJPQ0gcKmcnB2tC7sPXIIZnvfb5912GVM0n-4wvWobQnb_tXQuYZf99wH_leXjC_7BK8"
|
||||
"8JSaAmB980i3rBxfejmaJ8E6D48zRxwwPFa0veVjjzRkVqHPwAjl1CXb2HE29pGtNmSEE1kLQVqOZD_ibOwKQ"
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def init_url():
|
||||
# 获取当前文件的路径
|
||||
file_path = Path(__file__).resolve()
|
||||
file_path = os.path.join(file_path.parent, 'wikiconfig.json')
|
||||
if not os.path.exists(file_path):
|
||||
return False
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
data = json.load(file)
|
||||
|
||||
if 'url' in data:
|
||||
WikijsTool.BASE_URL = data['url']
|
||||
|
||||
if 'Authorization' in data:
|
||||
WikijsTool.HEADERS['Authorization'] = data['Authorization']
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_all_documents() -> list[dict]:
|
||||
query = """
|
||||
query Pages {
|
||||
pages {
|
||||
list {
|
||||
path
|
||||
locale
|
||||
title
|
||||
contentType
|
||||
id
|
||||
isPublished
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
# 构建请求数据
|
||||
data = {
|
||||
'query': query,
|
||||
}
|
||||
|
||||
# 发送 POST 请求
|
||||
response = requests.post(WikijsTool.BASE_URL, headers=WikijsTool.HEADERS, json=data)
|
||||
if response.status_code == 200:
|
||||
# 解析数据
|
||||
list_info = json.loads(response.content)['data']['pages']['list']
|
||||
return [item for item in list_info]
|
||||
else:
|
||||
raise ValueError(f"获取文档列表失败,原因:“{response.text}")
|
||||
|
||||
@staticmethod
|
||||
def get_all_doc_by_path(path: str, path_is_dir: bool = True) -> list[dict]:
|
||||
list_document = WikijsTool.get_all_documents()
|
||||
all_document_list = []
|
||||
if path_is_dir:
|
||||
temp_path = path + '/'
|
||||
else:
|
||||
temp_path = path
|
||||
for document_info in list_document:
|
||||
document_path = str(document_info["path"])
|
||||
# 根据路径过滤出对应的所有文档
|
||||
if not document_path.startswith(temp_path):
|
||||
continue
|
||||
|
||||
all_document_list.append(document_info)
|
||||
|
||||
return all_document_list
|
||||
|
||||
@staticmethod
|
||||
def search_document(query_str: str) -> list[dict]:
|
||||
graphql_query = f"""
|
||||
query Pages {{
|
||||
pages {{
|
||||
search(query: "{query_str}") {{
|
||||
results {{
|
||||
id
|
||||
path
|
||||
locale
|
||||
title
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
# 构建请求数据
|
||||
data = {
|
||||
'query': graphql_query,
|
||||
}
|
||||
|
||||
# 发送 POST 请求
|
||||
response = requests.post(WikijsTool.BASE_URL, headers=WikijsTool.HEADERS, json=data)
|
||||
if response.status_code == 200:
|
||||
# 解析数据
|
||||
search_results = json.loads(response.content)['data']['pages']['search']['results']
|
||||
return search_results
|
||||
else:
|
||||
raise ValueError(f"查询文档失败,原因:“{response.text}")
|
||||
|
||||
@staticmethod
|
||||
def query_doc_info(doc_id: int) -> dict:
|
||||
query = """
|
||||
query singlePages($doc_id: Int!) {
|
||||
pages {
|
||||
single(id: $doc_id) {
|
||||
id
|
||||
path
|
||||
title
|
||||
isPublished
|
||||
content
|
||||
contentType
|
||||
isPrivate
|
||||
updatedAt
|
||||
createdAt
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
# 构建请求数据
|
||||
variables = {
|
||||
'doc_id': doc_id,
|
||||
}
|
||||
data = {
|
||||
'query': query,
|
||||
'variables': variables
|
||||
}
|
||||
|
||||
# 发送 POST 请求
|
||||
response = requests.post(WikijsTool.BASE_URL, headers=WikijsTool.HEADERS, json=data)
|
||||
if "errors" in response.text:
|
||||
result = json.loads(response.content)['errors'][0]['message']
|
||||
return {}
|
||||
else:
|
||||
return json.loads(response.content)['data']['pages']['single']
|
||||
|
||||
|
||||
WikijsTool.init_url()
|
||||
if __name__ == "__main__":
|
||||
WikijsTool.query_doc_info(6448)
|
||||
print(WikijsTool.rename_directory("配网知识库/配网造价软件", "配网知识库/配网造价软件1"))
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
from . import custom_markdownify
|
||||
|
||||
convert_html_to_md = custom_markdownify.md
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,491 @@
|
||||
import re
|
||||
from textwrap import fill
|
||||
|
||||
import requests
|
||||
from bs4 import NavigableString
|
||||
from bs4 import BeautifulSoup
|
||||
from markdownify import MarkdownConverter, chomp, UNDERLINED, ATX_CLOSED
|
||||
import copy
|
||||
from . import picture_process
|
||||
|
||||
|
||||
# <br>是否是单元格内部的换行符
|
||||
def judge_br_in_table(el):
|
||||
if el.name in ['td', 'tr']:
|
||||
return True
|
||||
if el.parent is None:
|
||||
return False
|
||||
# 递归父级元素
|
||||
return judge_br_in_table(el.parent)
|
||||
|
||||
|
||||
# 获取div标签中是否为标题,如果是标题则markdown中的返回标题等级
|
||||
def get_markdown_title_level(el):
|
||||
if el.name != 'div' or 'class' not in el.attrs:
|
||||
return ''
|
||||
title_level = ''
|
||||
if 'hdwiki_tmml' in el.attrs['class']:
|
||||
title_level = '## '
|
||||
elif 'hdwiki_tmmll' in el.attrs['class']:
|
||||
title_level = '### '
|
||||
return title_level
|
||||
|
||||
|
||||
def str_is_title(text) -> bool:
|
||||
text = text.strip()
|
||||
pattern = r'^#+'
|
||||
|
||||
# 使用re.search匹配字符串开头的 # 符号
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
# 判断el 是否是图片的DIV标签
|
||||
def is_img_div_tag(el) -> bool:
|
||||
if el is None:
|
||||
return False
|
||||
if el.name != "div":
|
||||
return False
|
||||
class_attr = el.get('class')
|
||||
if class_attr is None:
|
||||
return False
|
||||
if "img" in class_attr or "img_l" in class_attr:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
# 判断div内部是否是纯文本内容,并且display是否为block
|
||||
def is_only_text_div(el) -> bool:
|
||||
if el is None or el.name != "div" or el.text == "":
|
||||
return False
|
||||
|
||||
if el.get("display", "block") != "block":
|
||||
return False
|
||||
|
||||
# div标签下只包含文本
|
||||
if isinstance(el.string, NavigableString):
|
||||
return True
|
||||
|
||||
# 兼容<div><b>1. 版本概述</b> </div> 判断错误问题
|
||||
# 递归获取所有子标签
|
||||
child_tags = el.find_all(recursive=True)
|
||||
for tag in child_tags:
|
||||
if tag.text == "":
|
||||
continue
|
||||
if tag.name in ["table", "td", "img"]:
|
||||
return False
|
||||
if isinstance(tag.string, NavigableString):
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# a标签是否在图片的div标签内部
|
||||
def a_tag_is_in_img(el) -> bool:
|
||||
if el.parent is None:
|
||||
return False
|
||||
if el.name != "a" or el.parent.name != "div":
|
||||
return False
|
||||
|
||||
return is_img_div_tag(el.parent)
|
||||
|
||||
|
||||
class CustomMarkDownConverter(MarkdownConverter):
|
||||
"""
|
||||
创建自定义的换行装换函数
|
||||
"""
|
||||
|
||||
def __init__(self, img_download_path, **options):
|
||||
super().__init__(**options)
|
||||
self.img_download_path = img_download_path
|
||||
|
||||
# 单元格内的换行依旧保持<br>格式
|
||||
def convert_br(self, el, text, convert_as_inline):
|
||||
if judge_br_in_table(el):
|
||||
return "<br/>"
|
||||
|
||||
# 容错处理(文章4696),因bs4解析html错误 导致将 分类图标签 解析到了br标签下导致图片丢失
|
||||
if text.strip():
|
||||
return text + "\n"
|
||||
|
||||
return super().convert_br(el, text, convert_as_inline)
|
||||
|
||||
# 图片div标签 在图片与图片描述之间添加换行
|
||||
@staticmethod
|
||||
def convert_img_div(text):
|
||||
pattern = r'\*\*(.*?)\*\*'
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
start_index = match.start()
|
||||
text = text[:start_index] + "\n" + text[start_index:]
|
||||
return text
|
||||
|
||||
# 装换标题格式
|
||||
def convert_div(self, el, text, convert_as_inline):
|
||||
title_level = get_markdown_title_level(el)
|
||||
if title_level != '':
|
||||
return "\n\n" + title_level + text + '\n\n'
|
||||
|
||||
if is_img_div_tag(el):
|
||||
# 图片与图片描述文字之间掺入换行符
|
||||
return self.convert_img_div(text)
|
||||
|
||||
if is_only_text_div(el):
|
||||
text = "\n\n" + text + "\n\n"
|
||||
|
||||
return text
|
||||
|
||||
# 检查 URL 是否有效的函数
|
||||
@staticmethod
|
||||
def is_valid_url(url):
|
||||
try:
|
||||
response = requests.head(url, allow_redirects=True)
|
||||
return response.status_code == 200
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def try_complete_img_description(img_el):
|
||||
if img_el is None or img_el.name != "img":
|
||||
return
|
||||
|
||||
# 找到父级的div标签
|
||||
img_el_parent_div = None
|
||||
cur_el = img_el
|
||||
while cur_el.parent is not None:
|
||||
if is_img_div_tag(cur_el.parent):
|
||||
img_el_parent_div = cur_el.parent
|
||||
break
|
||||
cur_el = cur_el.parent
|
||||
|
||||
if img_el_parent_div is not None and len(img_el_parent_div.text) != 0:
|
||||
img_el.attrs["alt"] = img_el_parent_div.text
|
||||
return
|
||||
|
||||
# 找到父级的figure标签
|
||||
img_el_parent_div = None
|
||||
cur_el = img_el
|
||||
while cur_el.parent is not None:
|
||||
if cur_el.parent is not None and cur_el.parent.name == 'figure':
|
||||
img_el_parent_div = cur_el.parent
|
||||
break
|
||||
cur_el = cur_el.parent
|
||||
|
||||
if img_el_parent_div is not None and len(img_el_parent_div.text) != 0:
|
||||
img_el.attrs["alt"] = img_el_parent_div.text
|
||||
return
|
||||
|
||||
|
||||
def convert_figcaption(self, el, text, convert_as_inline):
|
||||
return ""
|
||||
|
||||
# 图片后添加空行,图片应该单独在一行后面不接文字(示例文章:6925)
|
||||
def convert_img(self, el, text, convert_as_inline):
|
||||
self.try_complete_img_description(el)
|
||||
img_text = super().convert_img(el, text, convert_as_inline)
|
||||
|
||||
# 5195 出现img标签内出现换行导致 markdown图片显示出现问题
|
||||
img_text = img_text.replace("\r\n", "")
|
||||
img_text = img_text.replace("\n", "")
|
||||
# 空的img标签直接返回空行
|
||||
if img_text == "![]()":
|
||||
return '\n\n'
|
||||
|
||||
# img 标签使用父级超链接标签中的中大图
|
||||
src = el.attrs.get('src', None) or ''
|
||||
if el.parent is not None and el.parent.name == "a":
|
||||
href = el.parent.attrs.get('href', None) or ''
|
||||
href_path = href.rsplit(".", 1)[0]
|
||||
src_path = src.rsplit(".", 1)[0]
|
||||
if href_path + "_s" == src_path:
|
||||
img_text = img_text.replace(src, href)
|
||||
|
||||
if '_s' in img_text:
|
||||
src_path = src.rsplit(".", 1)[0]
|
||||
if src_path.endswith('_s'):
|
||||
original_src_path = src_path[:-2] # 去掉末尾的 '_s'
|
||||
# 构建原始 URL
|
||||
original_url = original_src_path + "." + src.split(".")[-1]
|
||||
if self.is_valid_url(original_url):
|
||||
img_text = img_text.replace(src, original_url)
|
||||
|
||||
# 转换并下载图片
|
||||
return picture_process.process_img_tag(img_text, self.img_download_path)
|
||||
|
||||
@staticmethod
|
||||
def is_img_describe_strong(el) -> bool:
|
||||
if el is None or el.parent is None:
|
||||
return False
|
||||
|
||||
if len(el.contents) == 0:
|
||||
return False
|
||||
|
||||
# if not isinstance(el.contents[0], NavigableString):
|
||||
# return False
|
||||
|
||||
img_list = el.parent.findAll("img")
|
||||
if len(img_list) == 0:
|
||||
return False
|
||||
|
||||
for img_tag in img_list:
|
||||
alt = img_tag.get("alt", None)
|
||||
title = img_tag.get("title", None)
|
||||
if alt is None and title is None:
|
||||
continue
|
||||
|
||||
if alt == el.text or title == el.text:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def convert_b(self, el, text, convert_as_inline):
|
||||
# 如果b 标签下只存在一个标题,则该b不做任何处理,避免对标题进行加粗(示例文章:6925)
|
||||
if len(el.contents) == 1:
|
||||
title_level = get_markdown_title_level(el.contents[0])
|
||||
if title_level != '':
|
||||
return text
|
||||
|
||||
# <b> 标签中存在标题时,不在对内容进行加粗
|
||||
if str_is_title(text):
|
||||
return text
|
||||
|
||||
if self.is_img_describe_strong(el):
|
||||
return ""
|
||||
|
||||
text = text.strip(" \t")
|
||||
suffix = ""
|
||||
if text.endswith("\n"):
|
||||
suffix = " \n"
|
||||
b_text = super().convert_b(el, text, convert_as_inline)
|
||||
|
||||
# 解析完<b> 标签后添加空格。避免出现markdown文档中出现《**1.****版本概述**》(文章2377 4292等)
|
||||
return " " + b_text + suffix + " "
|
||||
|
||||
convert_strong = convert_b
|
||||
|
||||
# 有可能出现<p>之后紧接一个标题hdwiki_tmml 故前后添加换行
|
||||
def convert_p(self, el, text, convert_as_inline):
|
||||
if convert_as_inline:
|
||||
return text
|
||||
if self.options['wrap']:
|
||||
text = fill(text,
|
||||
width=self.options['wrap_width'],
|
||||
break_long_words=False,
|
||||
break_on_hyphens=False)
|
||||
# <p>标签前后换行
|
||||
return '\n\n%s\n\n' % text if text else ''
|
||||
|
||||
def convert_a(self, el, text, convert_as_inline):
|
||||
prefix, suffix, text = chomp(text)
|
||||
if not text:
|
||||
return ''
|
||||
href = el.get('href')
|
||||
if self.is_href_img(href):
|
||||
return text
|
||||
title = el.get('title')
|
||||
# 5195 出现img标签内出现换行导致 markdown图片显示出现问题
|
||||
if title is not None:
|
||||
title = title.replace("\n", "")
|
||||
# For the replacement see #29: text nodes underscores are escaped
|
||||
if (self.options['autolinks']
|
||||
and text.replace(r'\_', '_') == href
|
||||
and not title
|
||||
and not self.options['default_title']):
|
||||
# Shortcut syntax
|
||||
return '<%s>' % href
|
||||
if self.options['default_title'] and not title:
|
||||
title = href
|
||||
title_part = ' "%s"' % title.replace('"', r'\"') if title else ''
|
||||
|
||||
a_tag = '%s[%s](%s%s)%s' % (prefix, text, href, title_part, suffix) if href else text
|
||||
return a_tag
|
||||
|
||||
@staticmethod
|
||||
def is_href_img(href_url) -> bool:
|
||||
if href_url is None:
|
||||
return False
|
||||
file_extension = href_url.split(".")[-1]
|
||||
# 不是图片不处理
|
||||
file_extension = file_extension.lower()
|
||||
if file_extension not in ["jpg", "jpeg", "png", "gif"]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def convert_li(self, el, text, convert_as_inline):
|
||||
# 为空的li标签返回空(文章 4347)
|
||||
if not text.strip():
|
||||
return ""
|
||||
|
||||
li_text = super().convert_li(el, text, convert_as_inline)
|
||||
return li_text
|
||||
|
||||
def convert_td(self, el, text, convert_as_inline):
|
||||
if "\r\n" in text:
|
||||
text = text.replace("\r\n", "<br>")
|
||||
|
||||
if "\n" in text:
|
||||
text = text.replace("\n", "<br>")
|
||||
|
||||
return ' ' + text + ' |'
|
||||
|
||||
def convert_hn(self, n, el, text, convert_as_inline):
|
||||
if convert_as_inline:
|
||||
return text
|
||||
|
||||
style = self.options['heading_style'].lower()
|
||||
text = text.rstrip()
|
||||
if style == UNDERLINED and n <= 2:
|
||||
line = '=' if n == 1 else '-'
|
||||
return self.underline(text, line)
|
||||
hashes = '#' * n
|
||||
hashes = hashes + " "
|
||||
if style == ATX_CLOSED:
|
||||
return '\n\n %s %s %s\n\n' % (hashes, text, hashes)
|
||||
return '\n\n%s %s\n\n' % (hashes, text)
|
||||
|
||||
@staticmethod
|
||||
def convert_thead_table(el, text, cell_name, convert_as_inline):
|
||||
cells = el.find_all(['td', 'th'])
|
||||
is_headrow = all([cell.name == cell_name for cell in cells])
|
||||
overline = ''
|
||||
underline = ''
|
||||
if is_headrow and not el.previous_sibling:
|
||||
# first row and is headline: print headline underline
|
||||
underline += '| ' + ' | '.join(['---'] * len(cells)) + ' |' + '\n'
|
||||
elif (not el.previous_sibling
|
||||
and (el.parent.name == 'table'
|
||||
or (el.parent.name == 'tbody'
|
||||
and not el.parent.previous_sibling))):
|
||||
# first row, not headline, and:
|
||||
# - the parent is table or
|
||||
# - the parent is tbody at the beginning of a table.
|
||||
# print empty headline above this row
|
||||
overline += '| ' + ' | '.join([''] * len(cells)) + ' |' + '\n'
|
||||
overline += '| ' + ' | '.join(['---'] * len(cells)) + ' |' + '\n'
|
||||
return overline + '|' + text + '\n' + underline
|
||||
|
||||
def convert_tr(self, el, text, convert_as_inline):
|
||||
# 解决table标签下存在thead的问题 (文章4061 1976)
|
||||
if el and el.parent and el.parent.name == "thead":
|
||||
return CustomMarkDownConverter.convert_thead_table(el, text, 'td', convert_as_inline)
|
||||
|
||||
# 兼容 table->colgroup、tbody->tr 文章4364
|
||||
if (el and el.parent and el.parent.previousSibling
|
||||
and el.parent.name == "tbody"
|
||||
and el.parent.previousSibling.name == "colgroup"):
|
||||
return CustomMarkDownConverter.convert_thead_table(el, text, 'td', convert_as_inline)
|
||||
|
||||
return super().convert_tr(el, text, convert_as_inline)
|
||||
|
||||
def convert_pre(self, el, text, convert_as_inline):
|
||||
# 文章5192出现pre标签,但内容不是代码。故不额外处理pre标签
|
||||
return text
|
||||
|
||||
def escape(self, text):
|
||||
if not text:
|
||||
return ''
|
||||
if self.options['escape_misc']:
|
||||
# text = re.sub(r'([\\&<`[>~#=+|-])', r'\\\1', text)
|
||||
text = re.sub(r'([\\&<`[>~#%=+|-])', r'\\\1', text)
|
||||
# 以下的转义是不必要的
|
||||
# text = re.sub(r'([0-9])([.)])', r'\1\\\2', text)
|
||||
if self.options['escape_asterisks']:
|
||||
text = text.replace('*', r'\*')
|
||||
if self.options['escape_underscores']:
|
||||
text = text.replace('_', r'\_')
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def convert_span(el, text, convert_as_inline):
|
||||
# 文章3526出现图片后面紧接图片文本的问题。图片文本在span标签内
|
||||
if "style" not in el.attrs:
|
||||
return text
|
||||
|
||||
style_attr = el.attrs['style']
|
||||
|
||||
if style_attr is None:
|
||||
return text
|
||||
style_content = style_attr.split(';')
|
||||
# 遍历style属性内容,找到display的值
|
||||
for item in style_content:
|
||||
if 'display' in item:
|
||||
display_value = item.split(': ')[1] # 获取冒号后的值
|
||||
if display_value == "block" and text != "":
|
||||
return f"\n\n{text}\n\n"
|
||||
return text
|
||||
|
||||
|
||||
def expand_html_table(html) -> tuple[str, bool]:
|
||||
soup = BeautifulSoup(html, 'html.parser')
|
||||
tables = soup.find_all('table')
|
||||
if len(tables) == 0:
|
||||
return html, False
|
||||
for table in tables:
|
||||
# 创建一个二维列表来表示表格
|
||||
table_rows = table.find_all('tr')
|
||||
max_cols = 0
|
||||
for row in table_rows:
|
||||
cols = row.find_all(['td', 'th'])
|
||||
col_count = sum([int(col.get('colspan', 1)) for col in cols])
|
||||
if col_count > max_cols:
|
||||
max_cols = col_count
|
||||
|
||||
# 初始化一个二维列表来存储最终的表格
|
||||
result_table = []
|
||||
for _ in range(len(table_rows)):
|
||||
result_table.append([None] * max_cols)
|
||||
|
||||
# 填充二维列表
|
||||
for r, row in enumerate(table_rows):
|
||||
cols = row.find_all(['td', 'th'])
|
||||
c = 0
|
||||
for col in cols:
|
||||
while result_table[r][c] is not None:
|
||||
c += 1
|
||||
colspan = int(col.get('colspan', 1))
|
||||
rowspan = int(col.get('rowspan', 1))
|
||||
for i in range(rowspan):
|
||||
for j in range(colspan):
|
||||
# 拆分合并单元格时,重复内容
|
||||
result_table[r + i][c + j] = copy.copy(col)
|
||||
# if j == 0 and i == 0:
|
||||
# result_table[r + i][c + j] = copy.copy(col)
|
||||
# else:
|
||||
# result_table[r + i][c + j] = soup.new_tag('td')
|
||||
c += colspan
|
||||
|
||||
# 生成新的表格 HTML
|
||||
new_table = soup.new_tag('table', border="1", cellspacing="0")
|
||||
tbody = soup.new_tag('tbody')
|
||||
new_table.append(tbody)
|
||||
for row in result_table:
|
||||
tr = soup.new_tag('tr')
|
||||
for col in row:
|
||||
if col is not None:
|
||||
td = soup.new_tag(col.name)
|
||||
td.string = col.get_text()
|
||||
tr.append(td)
|
||||
tbody.append(tr)
|
||||
|
||||
# 替换原始HTML中的旧表格
|
||||
table.replace_with(new_table)
|
||||
|
||||
return str(soup), True
|
||||
|
||||
|
||||
# Create shorthand method for conversion
|
||||
def md(html, img_download_path, **options):
|
||||
new_html, result = expand_html_table(html)
|
||||
markdown_content = CustomMarkDownConverter(img_download_path, **options).convert(new_html)
|
||||
# 删除换行符中间的空格
|
||||
temp_txt = re.sub(r'\n\s*\n', '\n\n', markdown_content)
|
||||
# 连续超过3个以上的换行符替换为3个
|
||||
temp_txt = re.sub(r'\n{3,}', '\n\n\n', temp_txt)
|
||||
return temp_txt
|
||||
@@ -0,0 +1,170 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
|
||||
|
||||
def get_img_tag_url(img_tag):
|
||||
|
||||
# 提取图片url的正则表达式模式
|
||||
pattern = r'\!\[.*?\]\((.*?)\)'
|
||||
# 找到第一个匹配的链接
|
||||
match = re.search(pattern, img_tag)
|
||||
if not match:
|
||||
return ""
|
||||
|
||||
# 获取匹配到的链接
|
||||
link = match.group(1)
|
||||
# 第0个为链接
|
||||
link = link.split(" ")[0]
|
||||
return link
|
||||
|
||||
|
||||
# 填充img标签中的图片链接
|
||||
# img_tag ''
|
||||
# img_tag ''
|
||||
def fill_img_url(img_tag):
|
||||
"""
|
||||
填充img标签中的图片链接。
|
||||
|
||||
参数:
|
||||
img_tag (str): 原始的img标签
|
||||
|
||||
返回:
|
||||
tuple: 修改后的img标签和图片的完整链接
|
||||
"""
|
||||
# 一个完整的img标签内删除换行符
|
||||
img_tag = img_tag.replace("\n", "")
|
||||
link = get_img_tag_url(img_tag)
|
||||
if len(link) == 0:
|
||||
return img_tag, ''
|
||||
|
||||
base_url = os.getenv("IMG_URL_PREFIX")
|
||||
if "http:" in link:
|
||||
# 图片为全链接,不替换
|
||||
return img_tag, link
|
||||
elif base_url:
|
||||
# 补全图片链接
|
||||
full_link = urljoin(base_url, link)
|
||||
img_tag = img_tag.replace(link, full_link)
|
||||
return img_tag, full_link
|
||||
else:
|
||||
return img_tag, ''
|
||||
|
||||
|
||||
def download_picture(img_tag, download_path):
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '
|
||||
'Chrome/94.0.4606.71 Safari/537.36 '
|
||||
}
|
||||
img_tag, img_url = fill_img_url(img_tag)
|
||||
if img_url == '':
|
||||
return img_tag
|
||||
# if "_s" in img_tag:
|
||||
# breakpoint()
|
||||
file_name = img_url.split("/")[-1]
|
||||
file_path = os.path.normpath(download_path + "\\" + file_name)
|
||||
file_path = file_path.replace("\\", "/")
|
||||
|
||||
# 文件已经存在时不下载
|
||||
if not os.path.exists(file_path):
|
||||
img_date = requests.get(url=img_url, headers=headers).content
|
||||
logging.info(f"图片下载成功:{img_url}")
|
||||
with open(file_path, 'wb') as fp:
|
||||
fp.write(img_date)
|
||||
|
||||
# img_tag中的url替换为下载的图片路径
|
||||
return img_tag.replace(img_url, file_path)
|
||||
|
||||
|
||||
def download_picture_from_other_url(img_tag, download_path):
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '
|
||||
'Chrome/94.0.4606.71 Safari/537.36 '
|
||||
}
|
||||
img_tag, img_url = fill_img_url(img_tag)
|
||||
# if "_s" in img_tag:
|
||||
# breakpoint()
|
||||
file_name = uuid.uuid4()
|
||||
file_path = os.path.join(download_path, f"{file_name}.png")
|
||||
file_path = os.path.normpath(file_path)
|
||||
# 文件已经存在时不下载
|
||||
if not os.path.exists(file_path):
|
||||
try:
|
||||
img_date = requests.get(url=img_url, headers=headers).content
|
||||
with open(file_path, 'wb') as fp:
|
||||
fp.write(img_date)
|
||||
logging.info(f"图片下载成功:{img_url}")
|
||||
except Exception as e:
|
||||
logging.warning(f"img download error url:{img_url}")
|
||||
return img_tag
|
||||
|
||||
# img_tag中的url替换为下载的图片路径
|
||||
return img_tag.replace(img_url, file_path)
|
||||
|
||||
|
||||
def extract_base64_from_data_uri(data_uri):
|
||||
# 分割字符串以找到 base64 部分
|
||||
parts = data_uri.split(',')
|
||||
if len(parts) == 2 and parts[0].endswith('base64'):
|
||||
# 移除后缀并返回 base64 值
|
||||
return parts[1][:-1]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def picture_base64(img_tag, picture_save_path):
|
||||
# 解码Base64字符串
|
||||
# 
|
||||
base64_str = extract_base64_from_data_uri(img_tag)
|
||||
if picture_save_path is None or picture_save_path == "":
|
||||
return ""
|
||||
# 将图片内容做MD5 用作文件名
|
||||
hash_object = hashlib.md5()
|
||||
hash_object.update(base64_str.encode())
|
||||
img_md5 = hash_object.hexdigest()
|
||||
|
||||
picture_save_path = picture_save_path + "\\%s.png" % img_md5
|
||||
picture_save_path = os.path.normpath(picture_save_path)
|
||||
picture_save_path = picture_save_path.replace("\\", "/")
|
||||
|
||||
# 文件已经存在时不重新保存
|
||||
if not os.path.exists(picture_save_path):
|
||||
decoded_string = base64.b64decode(base64_str)
|
||||
with open(picture_save_path, 'wb') as fp:
|
||||
fp.write(decoded_string)
|
||||
|
||||
# 修改img_tab的图片路径
|
||||
match = re.search("\[(.*?)\]", img_tag)
|
||||
result = ""
|
||||
if match:
|
||||
result = match.group(1)
|
||||
if result == "":
|
||||
return "" % picture_save_path
|
||||
else:
|
||||
return "" % (result, picture_save_path, result)
|
||||
|
||||
|
||||
def process_img_tag(str_img_tag, img_path):
|
||||
# 如果img标签指向的是本地磁盘路径 则忽略该标签返回空
|
||||
if "file:///" in str_img_tag:
|
||||
logging.warning(f"存在非法的链接地址:{str_img_tag}")
|
||||
return ""
|
||||
if img_path is None or img_path == "":
|
||||
return ""
|
||||
|
||||
img_url = get_img_tag_url(str_img_tag)
|
||||
if "data:image/png;base64" in str_img_tag:
|
||||
return picture_base64(str_img_tag, img_path)
|
||||
# (4696等存在指向外部链接的 img标签。 暂时保留不删除)
|
||||
elif "http://" in str_img_tag or "https://" in str_img_tag:
|
||||
return download_picture_from_other_url(str_img_tag, img_path)
|
||||
elif not img_url.startswith("http"):
|
||||
return download_picture(str_img_tag, img_path)
|
||||
else:
|
||||
logging.warning(f"未处理的图片标签:{str_img_tag}")
|
||||
return str_img_tag
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"url":"http://10.1.0.145:8090/graphql",
|
||||
"Authorization":"Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJhcGkiOjEsImdycCI6MSwiaWF0IjoxNzIzNjMxMjcwLCJleHAiOjE4MTgzMDQwNzAsImF1ZCI6InVybjp3aWtpLmpzIiwiaXNzIjoidXJuOndpa2kuanMifQ.g5H1xVMtk7Q3uvrRdtD3aTm49dQkS11cYdDKIwXo7DthOOTGj9DmFO7yILNDU7XFACTZc1Ej6ryguYV_8vGqoc-Rc7LciwvqS_RHDYUKZNKENbv8df9UGDMB-F9DT_airGc1lGJXgVqypxejDL3fY8aRMGXm7GBIlZKY4JTeI2uJZxffgfqKGrOvc3EOtsGgJzKZo4OyQ8UInGtCTiuq6-mLj_Syix_1z52K1tgfnF4E4-rZH_zCD05hUlUMYUV-KWhPkeOEGR5xbRTrulfCvzDD4T0CX4pI-keSKmgVn1HYSSN4o1Tj_l9zsyhUoLRzhzPK29Q3uekIc9obrvCHrg"
|
||||
}
|
||||
Reference in New Issue
Block a user