上传问题改写、意图识别模块代码
This commit is contained in:
@@ -0,0 +1,5 @@
|
|||||||
|
OPENAI_API_KEY=sk-xxaiabmfhzwwpijuledllkmkzhzwsqeicjxmjwnvriqpwmpk
|
||||||
|
OPENAI_API_BASE=https://api.siliconflow.cn/v1/
|
||||||
|
LLM_MODEL_NAME=deepseek-ai/DeepSeek-V3
|
||||||
|
|
||||||
|
RERANKER_MODEL_NAME=bge-reranker-v2-m3
|
||||||
Vendored
+24
@@ -0,0 +1,24 @@
|
|||||||
|
{
|
||||||
|
// 使用 IntelliSense 了解相关属性。
|
||||||
|
// 悬停以查看现有属性的描述。
|
||||||
|
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Python 调试程序: 当前文件",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${file}",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "IntentRecognition",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${workspaceFolder}/rag2_0/demo/intent_recognition_example.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
Vendored
+4
@@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"python.analysis.typeCheckingMode": "off",
|
||||||
|
"python.analysis.autoImportCompletions": true
|
||||||
|
}
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Generated
+3035
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,32 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "rag2-0"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = ""
|
||||||
|
authors = ["Your Name <you@example.com>"]
|
||||||
|
readme = "README.md"
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = ">=3.11,<3.13"
|
||||||
|
langchain = "^0.3.25"
|
||||||
|
langchain-openai = "^0.3.16"
|
||||||
|
langchain-community = "^0.3.24"
|
||||||
|
python-dotenv = "^1.1.0"
|
||||||
|
pydantic = "^2.11.4"
|
||||||
|
requests = "^2.32.3"
|
||||||
|
faiss-cpu = "^1.11.0"
|
||||||
|
pandas = "^2.2.3"
|
||||||
|
openpyxl = "^3.1.5"
|
||||||
|
bs4 = "^0.0.2"
|
||||||
|
markdownify = "0.13.1"
|
||||||
|
tqdm = "^4.67.1"
|
||||||
|
xlsxwriter = "^3.2.3"
|
||||||
|
flask = "^3.1.1"
|
||||||
|
psycopg2 = "^2.9.10"
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
|
[[tool.poetry.source]]
|
||||||
|
name = "ali-mirrors"
|
||||||
|
url = "http://mirrors.aliyun.com/pypi/simple/"
|
||||||
|
priority = "primary"
|
||||||
@@ -0,0 +1,250 @@
|
|||||||
|
"""
|
||||||
|
提问内容补全工具
|
||||||
|
|
||||||
|
此模块用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
|
||||||
|
并将原提问和补全后的提问保存到新的Excel文件中。
|
||||||
|
|
||||||
|
用法示例:
|
||||||
|
completer = QuestionCompleter(input_path="历史提问数据(dislike).xlsx", output_path="补全后的提问数据.xlsx")
|
||||||
|
completer.process()
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from langchain.output_parsers import PydanticOutputParser
|
||||||
|
import concurrent.futures
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
class RewriteQuery(BaseModel):
    """Structured LLM reply: the rewritten question and the software it names."""

    # Question text after the rewrite step.
    rewrite_query: str = Field(description="补全后的提问")
    # Software product name reported by the LLM.
    software_name: str = Field(description="软件名称")
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
class QuestionCompleter:
    """Fill in missing software names in recorded questions with an LLM.

    Reads question/answer pairs from an Excel sheet, asks the LLM to rewrite
    each question so that it names the software it refers to, and writes the
    original question, the rewritten question and the detected software name
    to a new Excel file.
    """

    def __init__(self, input_path="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
                 output_path="/data/Rag2_0/data/excel/历史提问数据(dislike)_补全后的提问数据.xlsx",
                 question_column="提问", answer_column="回答", max_workers=10):
        """Load the input sheet and prepare the LLM client.

        Args:
            input_path (str): Path of the Excel file to read.
            output_path (str): Path of the Excel file to write.
            question_column (str): Name of the column holding the questions.
            answer_column (str): Name of the column holding the answers.
            max_workers (int): Maximum number of worker threads.

        Raises:
            ValueError: If a required environment variable is unset or a
                required column is missing from the sheet.
            RuntimeError: If the Excel file cannot be read.
        """
        self.input_path = input_path
        self.output_path = output_path
        self.question_column = question_column
        self.answer_column = answer_column
        self.max_workers = max_workers
        self.rewrite_query_parser = PydanticOutputParser(pydantic_object=RewriteQuery)
        # Kept as an attribute for backward compatibility; process() no longer
        # needs it because only the calling thread touches the result list.
        self.lock = Lock()

        # LLM connection settings come from the environment.
        self.api_key = os.getenv("OPENAI_API_KEY")
        self.base_url = os.getenv("OPENAI_API_BASE")
        self.model = os.getenv("LLM_MODEL_NAME")

        if not all([self.api_key, self.base_url, self.model]):
            raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")

        self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model)

        try:
            self.df = pd.read_excel(self.input_path)
            print(f"成功读取Excel文件: {self.input_path}")
            print(f"共有 {len(self.df)} 条记录")
        except Exception as e:
            # Chain the cause so the underlying I/O or parser error stays visible.
            raise RuntimeError(f"读取Excel文件失败: {str(e)}") from e

        if self.question_column not in self.df.columns:
            raise ValueError(f"Excel文件中不存在列: {self.question_column}")
        if self.answer_column not in self.df.columns:
            raise ValueError(f"Excel文件中不存在列: {self.answer_column}")

    @staticmethod
    def _is_blank(value):
        """Return True for NaN/None cells and for cells that are empty once stripped."""
        # str() keeps numeric cells (which have no .strip()) from raising.
        return bool(pd.isna(value)) or str(value).strip() == ""

    def create_completion_prompt(self, question, answer):
        """Build the prompt that asks the LLM to complete a question.

        Args:
            question (str): Original question.
            answer (str): Answer that was given to the question.

        Returns:
            str: Prompt text, including the parser's format instructions.
        """
        prompt = f"""
1、判断提问中是否缺少软件名称,如果不缺少,则直接返回原始提问
2、如果缺少软件名称,则根据回答中的软件名称,补全提问
3、补全后的提问需要保持问题原有意图不变

4、软件名称包括:
配网D3软件(配网工程计价通D3)
西藏Z1软件(西藏电力工程计价通Z1)
主网计价通软件(电力建设计价通)
技改检修工程计价通T1软件(技改检修工程计价通T1)
技改检修清单计价通T1软件(技改检修清单计价通T1)
储能C1软件(新型储能电站建设计价通C1)
如果没有包含上述软件名称,则直接返回原始提问,software_name为空字符串
{{
"rewrite_query": "xxx",
"software_name": ""
}}

原始提问:{question}
系统回答:{answer}

输出格式:
{self.rewrite_query_parser.get_format_instructions()}

示例:
例如,如果输入是:
提问:这个软件怎么用?
回答:Photoshop的使用方法是...

那么输出会是:
{{
"rewrite_query": "Photoshop这个软件怎么用?",
"software_name": "Photoshop"
}}

或者如果提问已经包含软件名称:
提问:Photoshop怎么用?
回答:Photoshop的使用方法是...

那么输出会是:
{{
"rewrite_query": "Photoshop怎么用?",
"software_name": "Photoshop"
}}

"""
        return prompt

    def complete_question(self, question, answer):
        """Ask the LLM to complete one question.

        Args:
            question: Original question cell (normally a string; may be NaN
                or a non-string value read from Excel).
            answer: Answer cell belonging to the question.

        Returns:
            tuple: ``(completed_question, software_name)``. When either cell
            is blank, or the LLM call or parsing fails, the original question
            is returned unchanged together with an empty software name.
        """
        # Nothing to complete without both a question and an answer.
        if self._is_blank(question) or self._is_blank(answer):
            return question, ""

        try:
            prompt = self.create_completion_prompt(question, answer)
            response = self.llm.invoke(prompt)
            completed_question = self.rewrite_query_parser.parse(response.content)
            return completed_question.rewrite_query, completed_question.software_name
        except Exception as e:
            # Best effort: one failing row must not abort the whole batch.
            print(f"补全提问失败: {str(e)}")
            return question, ""

    def process_row(self, row):
        """Complete the question of a single row.

        Args:
            row: Mapping-like row (e.g. a DataFrame row) that can be indexed
                by the configured question and answer column names.

        Returns:
            dict: Original question, completed question and software name.
        """
        original_question = row[self.question_column]
        answer = row[self.answer_column]

        completed_question, software_name = self.complete_question(original_question, answer)

        return {
            "原始提问": original_question,
            "补全后的提问": completed_question,
            "软件名称": software_name,
        }

    def process(self):
        """Complete every question using a thread pool and save the results.

        Output rows are written in the same order as the input rows,
        regardless of the order in which the worker threads finish.
        """
        total = len(self.df)
        # Pre-sized so each result can be stored at its input position.
        results = [None] * total

        with tqdm(total=total, desc="补全提问") as pbar:
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                future_to_idx = {
                    executor.submit(self.process_row, self.df.iloc[idx]): idx
                    for idx in range(total)
                }

                for future in concurrent.futures.as_completed(future_to_idx):
                    results[future_to_idx[future]] = future.result()
                    pbar.update(1)

        results_df = pd.DataFrame(results)
        results_df.to_excel(self.output_path, index=False)
        print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}")
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: parse the options and run a QuestionCompleter."""
    import argparse

    cli = argparse.ArgumentParser(description='补全Excel文件中的提问内容')
    # (short flag, long flag, value type, default, help text) for every option.
    option_table = (
        ('-i', '--input', str, "/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx", '输入Excel文件路径'),
        ('-o', '--output', str, "/data/Rag2_0/data/excel/补全后的提问数据.xlsx", '输出Excel文件路径'),
        ('-q', '--question', str, "提问", '提问列的名称'),
        ('-a', '--answer', str, "回答", '回答列的名称'),
        ('-w', '--workers', int, 50, '最大线程数'),
    )
    for short_flag, long_flag, value_type, fallback, help_text in option_table:
        cli.add_argument(short_flag, long_flag, type=value_type, default=fallback, help=help_text)

    options = cli.parse_args()

    # Build the tool from the parsed options and run the whole pipeline.
    QuestionCompleter(
        input_path=options.input,
        output_path=options.output,
        question_column=options.question,
        answer_column=options.answer,
        max_workers=options.workers,
    ).process()
|
||||||
|
|
||||||
|
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||||
@@ -0,0 +1,282 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
File: extract_wikijs_nouns.py
|
||||||
|
Author: oyyz
|
||||||
|
Description: 从 Wikijs 文档中提取专业名词
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from langchain.output_parsers import PydanticOutputParser
|
||||||
|
from rag2_0.tool.WikijsTool import WikijsTool
|
||||||
|
from rag2_0.intent_recognition.DataModels import Term, TermList
|
||||||
|
from rag2_0.tool.html_to_md import convert_html_to_md
|
||||||
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||||
|
import json
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import concurrent.futures
|
||||||
|
from threading import Semaphore
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
extract_wiki_nouns_prompt="""
|
||||||
|
我在完善我的专业词库,请从提供的电力行业造价软件相关文本中提取关键词,要求如下:
|
||||||
|
|
||||||
|
一、提取范围
|
||||||
|
1. 核心功能模块
|
||||||
|
(例:多工程批量计价、材机数据反算、变电工程智能组价、架空线路地形系数计算)
|
||||||
|
2、软件功能及界面名称(包括:界面页签、功能按钮、功能名称等)
|
||||||
|
(例:新建工程量清单、导出工程量清单等)
|
||||||
|
3. 业务专用术语
|
||||||
|
(例:装置性材料、甲供材保管费、施工降效补偿、电缆头试验配套费)
|
||||||
|
4. 计价标准体系
|
||||||
|
(例:预规2020版、电网检修定额2015版、配网工程概算定额)
|
||||||
|
|
||||||
|
|
||||||
|
二、提取规则
|
||||||
|
1. 识别核心功能名称(如"多工程批量设置工程量、工程设置密码")
|
||||||
|
2. 提取业务专用名词(如"主材卸车保管费")
|
||||||
|
3. 标注关联术语的对应关系(如"市场价"与"市场价格"互为同义词)
|
||||||
|
4. 包含定额标准相关术语(如"预规2020版")
|
||||||
|
5. 复合型术语需保持完整
|
||||||
|
√ 正确:"地形增加系数批量设置"
|
||||||
|
× 错误:"地形"、"系数"、"设置"
|
||||||
|
6. 总结生成关键词解释
|
||||||
|
关键词:编制依据
|
||||||
|
描述:造价文件编制基准规范
|
||||||
|
|
||||||
|
7. 软件的特定版本号不作为关键词
|
||||||
|
|
||||||
|
三、输出格式:
|
||||||
|
{output_format}
|
||||||
|
|
||||||
|
四、输入内容:
|
||||||
|
{content}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class WikijsNounsExtractor:
|
||||||
|
"""从 Wikijs 文档中提取专业名词"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo"):
|
||||||
|
"""
|
||||||
|
初始化专业名词提取器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API密钥,如果为None则从环境变量获取
|
||||||
|
base_url: API基础URL,如果为None则使用默认URL
|
||||||
|
model_name: 要使用的模型名称
|
||||||
|
"""
|
||||||
|
# 保存参数
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
# 初始化LLM
|
||||||
|
llm_params = {
|
||||||
|
"temperature": 0.6,
|
||||||
|
"model": model_name
|
||||||
|
}
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
llm_params["api_key"] = api_key
|
||||||
|
|
||||||
|
if base_url:
|
||||||
|
llm_params["base_url"] = base_url
|
||||||
|
|
||||||
|
self.llm = OpenAiLLM(**llm_params)
|
||||||
|
|
||||||
|
# 准备术语列表解析器
|
||||||
|
self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
|
||||||
|
|
||||||
|
# 信号量,限制并发请求数量
|
||||||
|
self.semaphore = None
|
||||||
|
|
||||||
|
# 线程锁,用于保护共享资源
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
def _convert_html_to_md(self, content, title):
|
||||||
|
"""HTML转Markdown"""
|
||||||
|
options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
|
||||||
|
new_content = (content.replace("h6>", "h7>")
|
||||||
|
.replace("h5>", "h6>")
|
||||||
|
.replace("h4>", "h5>")
|
||||||
|
.replace("h3>", "h4>")
|
||||||
|
.replace("h2>", "h3>")
|
||||||
|
.replace("h1>", "h2>"))
|
||||||
|
# 将HTML内容转换为Markdown
|
||||||
|
markdown_content = convert_html_to_md(new_content, "", **options)
|
||||||
|
markdown_content = f"# {title}\n\n{markdown_content}"
|
||||||
|
return markdown_content
|
||||||
|
|
||||||
|
def extract_from_document(self, doc_info: dict) -> List[Term]:
|
||||||
|
"""从单个文档中提取专业名词"""
|
||||||
|
try:
|
||||||
|
# 使用LLM调用处理文档
|
||||||
|
content = doc_info['content']
|
||||||
|
title = doc_info["title"]
|
||||||
|
|
||||||
|
# 转换HTML到Markdown
|
||||||
|
markdown_content = self._convert_html_to_md(content, title)
|
||||||
|
|
||||||
|
# 准备提示词
|
||||||
|
formatted_prompt = extract_wiki_nouns_prompt.replace("{content}", markdown_content)
|
||||||
|
formatted_prompt = formatted_prompt.replace("{output_format}", self.terms_list_parser.get_format_instructions())
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用LLM
|
||||||
|
response = self.llm.invoke(formatted_prompt)
|
||||||
|
# 使用Pydantic解析器解析结果
|
||||||
|
parsed_output = self.terms_list_parser.parse(response.content)
|
||||||
|
return parsed_output.terms
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"解析LLM响应时出错: {str(e)}")
|
||||||
|
logging.error(f"原始响应: {response.content}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"提取专业名词时出错: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _process_document(self, doc, path_terms):
|
||||||
|
"""处理单个文档"""
|
||||||
|
try:
|
||||||
|
# 获取信号量
|
||||||
|
with self.semaphore:
|
||||||
|
# 检查文档路径是否在我们要处理的路径中
|
||||||
|
path_prefix = None
|
||||||
|
for prefix in path_terms.keys():
|
||||||
|
if doc['path'].startswith(prefix):
|
||||||
|
path_prefix = prefix
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果不在要处理的路径中,则跳过
|
||||||
|
if not path_prefix:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取文档详细信息
|
||||||
|
doc_info = WikijsTool.query_doc_info(doc['id'])
|
||||||
|
if not doc_info or not doc_info.get('content'):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 提取专业名词
|
||||||
|
terms = self.extract_from_document(doc_info)
|
||||||
|
|
||||||
|
# 将提取的术语添加到对应路径的结果列表中
|
||||||
|
terms_dicts = [{"name": term.name, "synonymous": term.synonymous, "description": term.description} for term in terms]
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
path_terms[path_prefix].extend(terms_dicts)
|
||||||
|
logging.info(f"文档 {doc['path']} 处理完成,提取了 {len(terms)} 个专业名词")
|
||||||
|
|
||||||
|
# 每处理10个文档保存一次中间结果
|
||||||
|
current_count = len(path_terms[path_prefix])
|
||||||
|
if current_count % 10 == 0:
|
||||||
|
# 使用锁保护文件IO
|
||||||
|
self._save_terms_to_file(path_terms[path_prefix], os.path.join(self.output_dir, f"{path_prefix.split('(')[0]}_nouns.json"))
|
||||||
|
logging.info(f"已处理 {path_prefix} 的文档数达到 {current_count//10*10} 个,已保存中间结果")
|
||||||
|
|
||||||
|
return path_prefix
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def process_all_documents(self, output_dir: str = "extracted_nouns", max_concurrency: int = 5):
|
||||||
|
"""使用线程池处理所有文档"""
|
||||||
|
# 保存输出目录
|
||||||
|
self.output_dir = output_dir
|
||||||
|
|
||||||
|
# 创建输出目录
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
|
# 初始化信号量,限制并发请求数
|
||||||
|
self.semaphore = Semaphore(max_concurrency)
|
||||||
|
|
||||||
|
# 获取所有文档
|
||||||
|
all_docs = WikijsTool.get_all_documents()
|
||||||
|
|
||||||
|
# 要处理的路径前缀
|
||||||
|
# path_prefixes = [
|
||||||
|
# "技改检修计价通(2020)",
|
||||||
|
# "西藏造价软件(2023)",
|
||||||
|
# "新型储能电站建设计价通C1(2024)",
|
||||||
|
# "配网造价软件(2022)",
|
||||||
|
# ]
|
||||||
|
path_prefixes = [
|
||||||
|
"主网电力建设计价通(2018)",
|
||||||
|
]
|
||||||
|
# 为每个路径创建单独的结果列表
|
||||||
|
path_terms = {prefix: [] for prefix in path_prefixes}
|
||||||
|
|
||||||
|
# 过滤出符合路径前缀的文档
|
||||||
|
filtered_docs = []
|
||||||
|
for doc in all_docs:
|
||||||
|
for prefix in path_prefixes:
|
||||||
|
if doc['path'].startswith(prefix):
|
||||||
|
filtered_docs.append(doc)
|
||||||
|
break
|
||||||
|
|
||||||
|
logging.info(f"开始使用线程池处理 {len(filtered_docs)} 个文档...")
|
||||||
|
|
||||||
|
# 使用线程池处理所有文档
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||||
|
futures = []
|
||||||
|
for doc in filtered_docs:
|
||||||
|
future = executor.submit(self._process_document, doc, path_terms)
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
# 等待所有任务完成
|
||||||
|
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
||||||
|
try:
|
||||||
|
prefix = future.result()
|
||||||
|
if i % 10 == 0:
|
||||||
|
logging.info(f"已完成 {i+1}/{len(futures)} 个文档的处理")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"处理文档时出错: {str(e)}")
|
||||||
|
|
||||||
|
# 保存最终结果
|
||||||
|
for prefix, terms in path_terms.items():
|
||||||
|
# 为每个路径保存单独的文件
|
||||||
|
output_file = os.path.join(output_dir, f"{prefix.split('(')[0]}_nouns.json")
|
||||||
|
self._save_terms_to_file(terms, output_file)
|
||||||
|
logging.info(f"{prefix} 处理完成,共提取 {len(terms)} 个专业名词,已保存到 {output_file}")
|
||||||
|
|
||||||
|
def _save_terms_to_file(self, terms, output_file):
|
||||||
|
"""保存术语列表到文件"""
|
||||||
|
with open(output_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(terms, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Build an extractor from environment settings and run it over the wiki."""
    # To force a specific model, set LLM_MODEL_NAME before this point, e.g.:
    # os.environ["LLM_MODEL_NAME"] = "Qwen/Qwen2.5-72B-Instruct-128K"
    extractor = WikijsNounsExtractor(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_API_BASE"),
        model_name=os.getenv("LLM_MODEL_NAME"),
    )

    # Results go to data/wiki_extracted_nouns, two directories above this file.
    here = os.path.dirname(os.path.abspath(__file__))
    target_dir = os.path.join(here, "..", "..", "data", "wiki_extracted_nouns")
    extractor.process_all_documents(output_dir=target_dir, max_concurrency=2)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Console-only logging: attach a single INFO-level stream handler, with a
    # timestamped format, to the root logger, then run the extraction.
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    date_format = '%Y-%m-%d %H:%M:%S'

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter(log_format, date_format))

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(console_handler)
    main()
|
||||||
@@ -0,0 +1,189 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
File: intent_recognition_example.py
|
||||||
|
Date: 2025-05-14
|
||||||
|
Description: 意图识别和问题改写示例
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from rag2_0.intent_recognition import IntentRecognizer
|
||||||
|
import pandas as pd
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import concurrent.futures
|
||||||
|
from tqdm import tqdm
|
||||||
|
import time
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
# 读取Excel文件中的提问数据
|
||||||
|
def load_questions_from_excel(file_path=None):
    """Load the questions stored in the first column of an Excel sheet.

    Args:
        file_path: Path of the Excel file. When None, falls back to
            ``data/excel/200条提问数据.xlsx`` two directories above this
            module.

    Returns:
        list: Values of the first column, in sheet order. An empty list is
        returned (and the error logged) when the file cannot be read.
    """
    if file_path is None:
        # Honour the documented contract: None selects the default data file.
        file_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "..", "..", "data", "excel", "200条提问数据.xlsx",
        )

    try:
        df = pd.read_excel(file_path)
        questions = df.iloc[:, 0].tolist()  # first column only
    except Exception as e:
        logging.error(f"读取Excel文件时出错: {e}")
        return []

    logging.info(f"成功从{file_path}读取了{len(questions)}条提问")
    return questions
|
||||||
|
|
||||||
|
def process_query(recognizer, query, max_retries=3, retry_delay=10):
    """Run intent recognition for one query, retrying on failure.

    Args:
        recognizer: Object whose ``process_query(query)`` returns the tuple
            ``(classification, keywords, rewrite, query_keys)``.
        query: Query string.
        max_retries: Number of retries after the first failed attempt.
        retry_delay: Base delay in seconds; the wait before retry *k* is
            ``retry_delay * k``.

    Returns:
        dict: One result row. Success and failure rows carry the same keys in
        the same order, so a DataFrame built from them has a stable column
        layout no matter which row comes first.
    """
    failures = 0

    while True:
        try:
            classification, keywords, rewrite, query_keys = recognizer.process_query(query)

            # Serialise the matched terms as a readable JSON string.
            keywords_str = ""
            if keywords and keywords.terms:
                term_details = [
                    {
                        "名称": term.name,
                        "同义词": ";".join(term.synonymous) if term.synonymous else "",
                        "描述": term.description,
                    }
                    for term in keywords.terms
                ]
                # ensure_ascii=False keeps Chinese text readable in the output.
                keywords_str = json.dumps(term_details, ensure_ascii=False, indent=2)

            return {
                "提问": query,
                "问题拆解": query_keys,
                "一级分类": classification.vertical_classification,
                "二级分类": classification.sub_classification,
                "问题改写": rewrite.rewrite,
                "检索的关键词": keywords_str,
            }

        except Exception as e:
            failures += 1

            if failures > max_retries:
                logging.error(f"处理问题 '{query}' 时出错: {e.__class__}{e}")
                return {
                    "提问": query,
                    # Present (empty) so the failure row has the same columns
                    # as a success row.
                    "问题拆解": "",
                    "一级分类": "处理出错",
                    "二级分类": "处理出错",
                    "问题改写": "处理出错",
                    "检索的关键词": f"重试 {max_retries} 次后失败: {str(e)}",
                }

            # Linear back-off so a struggling backend is not hammered.
            time.sleep(retry_delay * failures)
|
||||||
|
|
||||||
|
examples_query = """下载软件在哪下载?"""
|
||||||
|
|
||||||
|
def main():
    """Run intent recognition and query rewriting over the sample questions.

    Loads the questions from the sample Excel file, processes them
    concurrently, and writes the results, in input order, to a formatted
    Excel file next to the input.
    """
    # LLM connection settings come from the environment.
    api_key = os.getenv("OPENAI_API_KEY")
    base_url = os.getenv("OPENAI_API_BASE")
    model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")

    recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)

    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条提问数据.xlsx")
    examples = load_questions_from_excel(data_file)
    # For a quick manual run use instead: examples = examples_query.split("\n")
    max_workers = 20
    logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")

    # Pre-sized so each result lands at the position of its input question.
    results = [None] * len(examples)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(process_query, recognizer, query): idx
            for idx, query in enumerate(examples)
        }

        for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(examples), desc="处理进度"):
            results[future_to_index[future]] = future.result()

    results_df = pd.DataFrame(results)

    output_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条提问数据_重写结果.xlsx")

    with pd.ExcelWriter(output_file, engine='xlsxwriter') as writer:
        results_df.to_excel(writer, index=False, sheet_name='Sheet1')
        worksheet = writer.sheets['Sheet1']

        # Width values passed to set_column, by column position:
        # question, breakdown, level-1 class, level-2 class, rewrite, keywords.
        column_widths = (
            ('A:A', 60),
            ('B:B', 20),
            ('C:C', 20),
            ('D:D', 20),
            ('E:E', 60),
            ('F:F', 60),
        )
        for column_range, width in column_widths:
            worksheet.set_column(column_range, width)

        # Same height (20) for every row; +1 covers the header row.
        for row_index in range(len(results_df) + 1):
            worksheet.set_row(row_index, 20)

    logging.info(f"处理完成,结果已保存至: {output_file}")
|
||||||
|
|
||||||
|
def setup_logging():
    """Send INFO-level records to the console and quieten the HTTP client loggers."""
    console = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[console],
    )
    # Raise these two loggers to WARNING so their INFO records are dropped.
    for noisy_logger in ('httpx', 'openai'):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
# Script entry point: configure logging first so the start-up message is shown.
if __name__ == "__main__":
    setup_logging()
    logging.info("意图识别示例程序开始运行...")
    main()
|
||||||
@@ -0,0 +1,293 @@
|
|||||||
|
"""
|
||||||
|
答案正确性评判工具
|
||||||
|
|
||||||
|
此模块用于评判问题的新旧回答是否正确,通过与标准答案(Wiki内容)进行比较,
|
||||||
|
或者在没有标准答案的情况下比较新旧回答的差异。
|
||||||
|
|
||||||
|
用法示例:
|
||||||
|
judge = AnswerCorrectnessJudge()
|
||||||
|
judge.process()
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from urllib.parse import unquote
|
||||||
|
from rag2_0.tool.WikijsTool import WikijsTool
|
||||||
|
from rag2_0.tool.html_to_md import convert_html_to_md
|
||||||
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
from tqdm import tqdm
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
class AnswerCorrectnessJudge:
|
||||||
|
"""
|
||||||
|
答案正确性评判工具类
|
||||||
|
|
||||||
|
用于评估问题的新旧回答是否正确,可以通过与标准答案(Wiki内容)进行比较,
|
||||||
|
或者在没有标准答案的情况下比较新旧回答的差异。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, wiki_excel_path="/data/Rag2_0/data/excel/部分提问_软件名称明确.xlsx",
             answer_excel_path="/data/Rag2_0/data/excel/主网软件提问_对比结果.xlsx",
             output_path="/data/Rag2_0/data/excel/主网软件提问回答_判断结果.xlsx"):
    """Load both input sheets and set up the LLM client.

    Args:
        wiki_excel_path (str): Excel file that maps questions to wiki links.
        answer_excel_path (str): Excel file with the answers to compare.
        output_path (str): Excel file that will receive the verdicts.

    Raises:
        ValueError: If OPENAI_API_KEY, OPENAI_API_BASE or LLM_MODEL_NAME is
            not set in the environment.
    """
    self.wiki_excel_path = wiki_excel_path
    self.answer_excel_path = answer_excel_path
    self.output_path = output_path

    # Both sheets are read eagerly, before the environment is checked.
    self.wiki_excel = pd.read_excel(self.wiki_excel_path)
    self.answer_excel = pd.read_excel(self.answer_excel_path)

    # LLM connection settings come from the environment.
    self.api_key = os.getenv("OPENAI_API_KEY")
    self.base_url = os.getenv("OPENAI_API_BASE")
    self.model = os.getenv("LLM_MODEL_NAME")

    settings_complete = bool(self.api_key) and bool(self.base_url) and bool(self.model)
    if not settings_complete:
        raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")

    self.openai_llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model)
|
||||||
|
|
||||||
|
def find_wiki_link(self, query) -> str | None:
|
||||||
|
"""
|
||||||
|
根据查询(对应wiki_excel中的新提问列)找出对应的词条链接
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query (str): 查询内容,对应wiki_excel中的新提问列
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 对应的词条链接,如果没有找到则返回None
|
||||||
|
"""
|
||||||
|
# 确保query不为空
|
||||||
|
if not query or pd.isna(query):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 在"新提问"列中查找匹配的行
|
||||||
|
matched_rows = self.wiki_excel[self.wiki_excel['新提问'] == query]
|
||||||
|
|
||||||
|
# 如果找到了匹配的行,返回对应的词条链接
|
||||||
|
if not matched_rows.empty:
|
||||||
|
return matched_rows.iloc[0]['对应词条链接']
|
||||||
|
|
||||||
|
# 如果没有完全匹配,尝试部分匹配
|
||||||
|
# 去除软件名称部分(如果有)
|
||||||
|
query_parts = query.split(',', 1)
|
||||||
|
if len(query_parts) > 1:
|
||||||
|
clean_query = query_parts[1].strip()
|
||||||
|
|
||||||
|
# 在"提问"列中查找包含清理后查询的行
|
||||||
|
for idx, row in self.wiki_excel.iterrows():
|
||||||
|
if pd.notna(row['提问']) and clean_query in row['提问']:
|
||||||
|
return row['对应词条链接']
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_wiki_content(self, link) -> str:
    """Fetch a wiki entry and return it as Markdown.

    Args:
        link (str): URL of the wiki entry.

    Returns:
        str: Markdown text headed by the last path segment of the link, or a
        short message when the link is empty/NaN or the entry has no content.

    Raises:
        RuntimeError: If anything fails while resolving or converting the
            entry; the original exception is chained as the cause.
    """
    try:
        if not link or pd.isna(link):
            return "链接为空或无效"

        # Keep only what follows the host, decode percent-escapes, then drop
        # the first path segment to obtain the document path.
        segments = unquote(link.split('/', 3)[-1]).split('/')
        doc_path = "/".join(segments[1:])

        matches = WikijsTool.get_all_doc_by_path(path=doc_path, path_is_dir=False)
        html_content = WikijsTool.query_doc_info(matches[0]["id"]).get('content')
        if not html_content:
            return "获取内容失败"

        # Demote every heading one level, deepest first, so that no level is
        # shifted twice (h6 -> h7 first, h1 -> h2 last).
        shifted_html = html_content
        for level in range(6, 0, -1):
            shifted_html = shifted_html.replace(f"h{level}>", f"h{level + 1}>")

        options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
        body_markdown = convert_html_to_md(shifted_html, "", **options)
        return f"# {segments[-1]}\n\n{body_markdown}"

    except Exception as e:
        raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
|
||||||
|
|
||||||
|
def create_prompt(self, standard_answer: str, answer_to_check: str) -> str:
    """
    Build the grading prompt sent to the LLM.

    Args:
        standard_answer (str): Reference answer.
        answer_to_check (str): Answer that has to be graded.

    Returns:
        str: The formatted prompt.
    """
    # One list item per prompt line; empty items are the blank lines.
    prompt_lines = [
        "请作为一个专业的答案评判专家,评估以下回答与标准答案的匹配程度。",
        "",
        "标准答案:",
        f"{standard_answer}",
        "",
        "待评估的回答:",
        f"{answer_to_check}",
        "",
        "请仔细分析两个答案的内容,并给出你的判断。只需要回答\"正确\"或\"错误\",不需要其他解释。",
        "如果待评估的回答与标准答案在核心内容和关键信息(步骤)上一致,即使表达方式不同,也应判定为\"正确\"。",
        "如果待评估的回答存在明显的错误信息或重要信息缺失,应判定为\"错误\"。",
        "",
        "请严格按以下格式输出:【正确】或【错误】:",
    ]
    return "\n".join(prompt_lines)
|
||||||
|
|
||||||
|
def judge_old_answer(self, standard_answer: str, old_answer: str) -> bool | None:
|
||||||
|
"""
|
||||||
|
调用LLM判断旧回答是否正确
|
||||||
|
|
||||||
|
参数:
|
||||||
|
standard_answer (str): 标准答案(来自Wiki)
|
||||||
|
old_answer (str): 旧流程的回答
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool | None: 判断结果,True表示正确,False表示错误,None表示判断失败
|
||||||
|
"""
|
||||||
|
prompt = self.create_prompt(standard_answer, old_answer)
|
||||||
|
try:
|
||||||
|
response = self.openai_llm.invoke(prompt)
|
||||||
|
return "正确" in response.content
|
||||||
|
except Exception as e:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def judge_new_answer(self, standard_answer: str, new_answer: str) -> bool | None:
|
||||||
|
"""
|
||||||
|
调用LLM判断新回答是否正确
|
||||||
|
|
||||||
|
参数:
|
||||||
|
standard_answer (str): 标准答案(来自Wiki)
|
||||||
|
new_answer (str): 新流程的回答
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool | None: 判断结果,True表示正确,False表示错误,None表示判断失败
|
||||||
|
"""
|
||||||
|
prompt = self.create_prompt(standard_answer, new_answer)
|
||||||
|
try:
|
||||||
|
response = self.openai_llm.invoke(prompt)
|
||||||
|
return "正确" in response.content
|
||||||
|
except Exception as e:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def judge_by_standard_answer(self, standard_answer: str, old_answer: str, new_answer: str) -> str | None:
|
||||||
|
"""
|
||||||
|
综合判断新旧回答的正确性
|
||||||
|
|
||||||
|
参数:
|
||||||
|
standard_answer (str): 标准答案(来自Wiki)
|
||||||
|
old_answer (str): 旧流程的回答
|
||||||
|
new_answer (str): 新流程的回答
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | None: 包含新旧回答判断结果的字符串,None表示判断失败
|
||||||
|
"""
|
||||||
|
old_result = self.judge_old_answer(standard_answer, old_answer)
|
||||||
|
new_result = self.judge_new_answer(standard_answer, new_answer)
|
||||||
|
if old_result is None or new_result is None:
|
||||||
|
return None
|
||||||
|
if new_result and old_result:
|
||||||
|
return "新旧答案均正确"
|
||||||
|
elif new_result and not old_result:
|
||||||
|
return "新答案正确"
|
||||||
|
elif not new_result and old_result:
|
||||||
|
return "旧答案正确"
|
||||||
|
else:
|
||||||
|
return "新旧答案均错误"
|
||||||
|
|
||||||
|
def judge_answer_diff(self, old_answer: str, new_answer: str) -> str | None:
|
||||||
|
"""
|
||||||
|
判断新旧回答是否存在较大差异
|
||||||
|
|
||||||
|
参数:
|
||||||
|
old_answer (str): 旧流程的回答
|
||||||
|
new_answer (str): 新流程的回答
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | None: 差异判断结果,None表示判断失败
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt = f"""请判断以下两个回答是否存在较大差异:
|
||||||
|
|
||||||
|
旧回答: {old_answer}
|
||||||
|
|
||||||
|
新回答: {new_answer}
|
||||||
|
|
||||||
|
主要是关键步骤、关键信息、或者关键主体的差异
|
||||||
|
请仅回答"存在较大差异"或"差异较小"。"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.openai_llm.invoke(prompt)
|
||||||
|
return "无法判断,新老答案差异较大" if "存在较大差异" in response.content else "无法判断,新老答案基本相同"
|
||||||
|
except Exception as e:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def process(self):
    """
    Grade every row of the answer sheet and write the results to Excel.

    For each row the reference answer is looked up through the wiki. When
    one is available both answers are graded against it; otherwise the
    two answers are only compared with each other. Rows whose judgement
    fails get an empty result.
    """
    # Status messages that get_wiki_content returns in place of content.
    # They are non-empty strings, so without this check they would be
    # used as the reference answer.
    fetch_failures = ("链接为空或无效", "获取内容失败")

    results = []

    for _, row in tqdm(self.answer_excel.iterrows(), total=len(self.answer_excel), desc="处理问题"):
        query = row["问题"]
        old_answer = row["旧流程答案"]
        new_answer = row["新流程答案"]
        standard_answer = ""

        try:
            wiki_url = self.find_wiki_link(query)
            if wiki_url and not pd.isna(wiki_url):
                standard_answer = self.get_wiki_content(wiki_url)
        except Exception as e:
            # Best effort: a failed lookup falls back to the diff check.
            print(f"处理问题 '{query}' 时发生错误: {str(e)}")

        if standard_answer in fetch_failures:
            standard_answer = ""

        if standard_answer:
            # Grade both answers against the reference.
            judge_result = self.judge_by_standard_answer(standard_answer, old_answer, new_answer)
        else:
            judge_result = self.judge_answer_diff(old_answer, new_answer)

        if judge_result is None:
            judge_result = ""

        results.append({
            "问题": query,
            "旧流程答案": old_answer,
            "新流程答案": new_answer,
            "判断结果": judge_result
        })

    # Persist all rows in one go.
    results_df = pd.DataFrame(results)
    results_df.to_excel(self.output_path, index=False)
    print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}")
|
||||||
|
|
||||||
|
# Manual entry point: run the full grading pass with the default paths.
if __name__ == "__main__":
    AnswerCorrectnessJudge().process()
|
||||||
@@ -0,0 +1,615 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
完整性问题判断工具
|
||||||
|
|
||||||
|
此脚本用于读取Excel文件中的问题,调用LLM判断问题是否完整,并将结果保存到Excel文件中。
|
||||||
|
|
||||||
|
用法示例:
|
||||||
|
python judge_query_full.py -i "问题数据.xlsx" -o "完整问题结果.xlsx" -w 50 -c 0
|
||||||
|
|
||||||
|
命令行参数:
|
||||||
|
-i, --input: 输入Excel文件路径
|
||||||
|
-o, --output: 输出Excel文件路径
|
||||||
|
-w, --workers: 并发处理的最大线程数
|
||||||
|
-c, --column: 要处理的问题所在列的索引(从0开始)
|
||||||
|
-t, --test: 测试单个问题,不处理Excel文件
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||||
|
from rag2_0.tool.APIKeyManager import APIKeyManager
|
||||||
|
from openpyxl.utils import get_column_letter
|
||||||
|
from openpyxl.styles import Alignment, PatternFill, Font, Border, Side
|
||||||
|
from tqdm import tqdm
|
||||||
|
import concurrent.futures
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# Default settings, used when the matching command-line option is omitted.
# Input workbook holding the questions to judge.
DEFAULT_EXCEL_PATH = r"/data/Rag2_0/data/excel/7000条对话数据.xlsx"
# Output workbook that receives the questions judged complete.
DEFAULT_OUTPUT_PATH = r"/data/Rag2_0/data/excel/7000条对话数据_完整问题结果.xlsx"
# Upper bound on concurrent worker threads.
DEFAULT_MAX_WORKERS = 50
|
||||||
|
|
||||||
|
|
||||||
|
class QueryCompletenessJudge:
|
||||||
|
"""
|
||||||
|
问题完整性判断工具类
|
||||||
|
|
||||||
|
用于评估问题是否完整,并将结果保存到Excel文件中。
|
||||||
|
可以批量处理Excel文件中的问题,也可以测试单个问题。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_path=DEFAULT_EXCEL_PATH, output_path=DEFAULT_OUTPUT_PATH,
             max_workers=DEFAULT_MAX_WORKERS, column_index=0):
    """
    Set up the question-completeness judge.

    Args:
        input_path (str): Path of the input Excel file.
        output_path (str): Path of the output Excel file.
        max_workers (int): Maximum number of concurrent worker threads.
        column_index (int): Zero-based index of the column holding the
            questions.
    """
    self.input_path, self.output_path = input_path, output_path
    self.max_workers, self.column_index = max_workers, column_index
    # One client is created up front and kept on the instance.
    self.llm_client = self._create_llm_client()
|
||||||
|
|
||||||
|
def _extract_json_from_response(self, full_answer):
|
||||||
|
"""
|
||||||
|
从LLM响应中提取JSON部分
|
||||||
|
|
||||||
|
参数:
|
||||||
|
full_answer (str): LLM的完整响应文本
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict: 解析后的JSON对象,如果解析失败则返回None
|
||||||
|
"""
|
||||||
|
# 尝试从回答中提取JSON部分
|
||||||
|
json_match = re.search(r'```json\s*(.*?)\s*```', full_answer, re.DOTALL)
|
||||||
|
if json_match:
|
||||||
|
json_str = json_match.group(1)
|
||||||
|
else:
|
||||||
|
# 如果没有找到```json```格式,尝试寻找普通的JSON对象
|
||||||
|
json_match = re.search(r'({[\s\S]*"is_complete"[\s\S]*})', full_answer)
|
||||||
|
if json_match:
|
||||||
|
json_str = json_match.group(1)
|
||||||
|
else:
|
||||||
|
# 如果仍然没有找到,返回None
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析JSON
|
||||||
|
return json.loads(json_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_llm_prompt(self, question):
    """
    Build the prompt that asks the LLM whether a question is complete.

    The prompt asks for an analysis followed by a fenced JSON object with
    the keys "is_complete", "reason" and "confidence".

    Args:
        question (str): Question whose completeness is to be judged.

    Returns:
        str: The formatted prompt.
    """
    # Doubled braces are literal braces inside the f-string.
    return f"""你是一个电力造价行业专家,用户正在使用电力造价软件,并提出了相关问题。请分析以下问题是否完整。

问题:{question}

首先,分析这个问题的结构和内容,思考它是否包含足够的信息来表达清晰的意图。
考虑以下几点:
1. 问题是否有明确的核心意图,不需要面面俱到
2. 问题是否缺少必要的上下文
3. **问题如果涉及软件相关,则只需要包含:软件名称、软件功能或软件目的即可**


在你的分析之后,请用JSON格式给出最终结论,格式如下:
```json
{{
"is_complete": true或false,
"reason": "判断原因的简要说明",
"confidence": 0到100之间的数值,表示你对判断的置信度
}}
```

请确保JSON格式正确,以便于程序解析。"""
|
||||||
|
|
||||||
|
def _create_llm_client(self, api_key=None, max_tokens=1024):
    """
    Create the LLM client used for completeness judgements.

    Args:
        api_key (str, optional): API key; taken from APIKeyManager when
            None.
        max_tokens (int, optional): Completion token limit. The prompt
            asks for an analysis followed by a JSON block, which the
            previous limit of 100 tokens cut off before the JSON could
            be emitted.

    Returns:
        OpenAiLLM: The configured client instance.
    """
    if api_key is None:
        api_key = APIKeyManager.get_api_key()

    return OpenAiLLM(
        api_key=api_key,
        base_url="https://api.siliconflow.cn/v1",  # adjust to the deployment in use
        model="deepseek-ai/DeepSeek-V3",  # adjust to the model in use
        temperature=0.2,
        max_tokens=max_tokens
    )
|
||||||
|
|
||||||
|
def is_question_complete(self, question):
    """
    Ask the LLM whether a question is complete.

    The call is attempted once and retried up to three more times, with a
    delay that starts at two seconds and doubles after each failure.

    Args:
        question (str): Question to judge.

    Returns:
        tuple: (bool, str) - the completeness verdict and the full LLM
        reply. After the last failed retry the verdict is False and the
        string is an error message.
    """
    max_retries = 3
    retry_delay = 2  # seconds
    last_error = None

    for attempt in range(max_retries + 1):
        try:
            prompt = self._create_llm_prompt(question)
            response = self.llm_client.invoke(prompt)

            # The client may hand back a message object or plain text.
            if hasattr(response, 'content'):
                full_answer = response.content
            else:
                full_answer = str(response)

            result = self._extract_json_from_response(full_answer)
            if result:
                verdict = result.get("is_complete", False)
                if isinstance(verdict, str):
                    # A quoted "false" is a non-empty string and would
                    # otherwise count as complete.
                    verdict = verdict.strip().lower() in ("true", "yes", "1")
                return bool(verdict), full_answer

            # No usable JSON: fall back to a keyword check on the start
            # of the reply. "不完整" contains "完整", so it is removed
            # before the check.
            head = full_answer[:100]
            return "完整" in head.replace("不完整", ""), full_answer

        except Exception as e:
            last_error = e
            if attempt < max_retries:
                time.sleep(retry_delay)
                # Back off further after every failure.
                retry_delay *= 2

    message = f"错误: 经过 {max_retries} 次重试后仍然失败: {str(last_error)}"
    print(message)
    return False, message
|
||||||
|
|
||||||
|
def _process_question(self, args, complete_questions, progress_counter, progress_lock, complete_questions_lock, pbar):
    """
    Judge one question and update the shared progress state.

    Runs on a worker thread. Only questions judged complete are appended
    to `complete_questions`.

    Args:
        args (tuple): (index, question, llm_client, total_questions);
            only `question` is used here.
        complete_questions (list): Shared list receiving result dicts.
        progress_counter (dict): Shared counters under the keys
            "processed", "complete" and "incomplete".
        progress_lock (threading.Lock): Guards `progress_counter` and the
            progress bar.
        complete_questions_lock (threading.Lock): Guards
            `complete_questions`.
        pbar (tqdm): Progress bar.
    """
    index, question, llm_client, total_questions = args

    # Skip empty questions; they count as processed only.
    if pd.isna(question) or question.strip() == "":
        with progress_lock:
            progress_counter["processed"] += 1
            pbar.update(1)
        return None

    # Ask the LLM for the verdict.
    is_complete, full_answer = self.is_question_complete(question)

    if is_complete:
        # Parse the reply again to pick up the reason and confidence.
        parsed_json = self._extract_json_from_response(full_answer)

        if parsed_json:
            # Result enriched with the parsed JSON fields.
            result = {
                "问题": question,
                "LLM回复": full_answer,
                "完整性": "完整" if parsed_json.get("is_complete", False) else "不完整",
                "原因": parsed_json.get("reason", "未提供"),
                "置信度": parsed_json.get("confidence", 0)
            }

            # Update the counters.
            with progress_lock:
                if result["完整性"] == "完整":
                    progress_counter["complete"] += 1
                else:
                    progress_counter["incomplete"] += 1
        else:
            # No parseable JSON: keep only the raw reply.
            result = {
                "问题": question,
                "LLM回复": full_answer,
                "完整性": "完整"
            }

            # Update the counters.
            with progress_lock:
                progress_counter["complete"] += 1

        with complete_questions_lock:
            complete_questions.append(result)
    else:
        with progress_lock:
            progress_counter["incomplete"] += 1
    # Advance the progress bar.
    with progress_lock:
        progress_counter["processed"] += 1
        # Refresh the running totals shown next to the bar.
        pbar.set_postfix(
            完整=progress_counter["complete"],
            不完整=progress_counter["incomplete"],
            完整率=f"{progress_counter['complete']/max(1, progress_counter['processed']):.1%}"
        )
        pbar.update(1)
|
||||||
|
|
||||||
|
def _shorten_response(self, response):
|
||||||
|
"""
|
||||||
|
截断LLM响应,提取重要信息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
response (str): 原始LLM响应
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 截断后的响应
|
||||||
|
"""
|
||||||
|
# 保留思考过程的前200个字符和JSON部分
|
||||||
|
json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
|
||||||
|
if json_match:
|
||||||
|
json_part = json_match.group(0)
|
||||||
|
prefix = response[:200] + "..." if len(response) > 200 else response
|
||||||
|
return f"{prefix}\n\n{json_part}"
|
||||||
|
return response[:500] + "..." if len(response) > 500 else response
|
||||||
|
|
||||||
|
def _prepare_excel_dataframe(self, complete_questions):
|
||||||
|
"""
|
||||||
|
将结果处理为DataFrame用于Excel输出
|
||||||
|
|
||||||
|
参数:
|
||||||
|
complete_questions (list): 完整问题列表
|
||||||
|
|
||||||
|
返回:
|
||||||
|
pandas.DataFrame: 处理后的DataFrame
|
||||||
|
"""
|
||||||
|
# 将结果列表转换为DataFrame
|
||||||
|
result_df = pd.DataFrame(complete_questions)
|
||||||
|
|
||||||
|
# 处理LLM回复列,截取一定长度以避免Excel单元格过大
|
||||||
|
if "LLM回复" in result_df.columns:
|
||||||
|
result_df["LLM回复"] = result_df["LLM回复"].apply(self._shorten_response)
|
||||||
|
|
||||||
|
# 调整列的顺序,确保重要列在前面
|
||||||
|
column_order = ["问题", "完整性", "置信度", "原因", "LLM回复"]
|
||||||
|
# 过滤掉不存在的列
|
||||||
|
column_order = [col for col in column_order if col in result_df.columns]
|
||||||
|
# 确保所有剩余的列也被包含
|
||||||
|
for col in result_df.columns:
|
||||||
|
if col not in column_order:
|
||||||
|
column_order.append(col)
|
||||||
|
|
||||||
|
# 重新排序列
|
||||||
|
return result_df[column_order]
|
||||||
|
|
||||||
|
def _set_excel_column_widths(self, worksheet):
    """
    Set a width for every column based on its header text.

    Args:
        worksheet (openpyxl.worksheet.worksheet.Worksheet): Sheet to
            adjust.
    """
    # Width per known header; any other column gets the fallback width.
    widths_by_header = {
        "问题": 40,
        "LLM回复": 60,
        "原因": 30,
        "完整性": 10,
        "置信度": 10,
    }
    fallback_width = 15

    for col in range(1, worksheet.max_column + 1):
        letter = get_column_letter(col)
        header = worksheet[f"{letter}1"].value
        worksheet.column_dimensions[letter].width = widths_by_header.get(header, fallback_width)
|
||||||
|
|
||||||
|
def _apply_excel_cell_styles(self, worksheet):
    """
    Apply wrapping, borders, header styling and verdict colours.

    Args:
        worksheet (openpyxl.worksheet.worksheet.Worksheet): Sheet to
            style.

    Returns:
        openpyxl.styles.Border: The thin border used for the cells, so
        the caller can reuse it for the statistics rows.
    """
    thin = 'thin'
    border = Border(
        left=Side(style=thin),
        right=Side(style=thin),
        top=Side(style=thin),
        bottom=Side(style=thin)
    )
    wrap_alignment = Alignment(wrap_text=True, vertical="top")
    header_font = Font(bold=True)
    header_fill = PatternFill(start_color="DDEBF7", end_color="DDEBF7", fill_type="solid")

    for row in worksheet.iter_rows(min_row=1, max_row=worksheet.max_row, min_col=1, max_col=worksheet.max_column):
        for cell in row:
            cell.alignment = wrap_alignment
            cell.border = border

            # Header row gets its own fill and a bold font.
            if cell.row == 1:
                cell.fill = header_fill
                cell.font = header_font
                continue

            # Data rows: only the verdict column is colour coded.
            if worksheet.cell(row=1, column=cell.column).value != "完整性":
                continue
            colour = "C6EFCE" if cell.value == "完整" else "FFC7CE"
            cell.fill = PatternFill(start_color=colour, end_color=colour, fill_type="solid")

    return border  # reused for the statistics rows
|
||||||
|
|
||||||
|
def _add_statistics_to_excel(self, worksheet, complete_questions, total_rows, total_questions, border):
    """
    Append a summary block below the data rows.

    Args:
        worksheet (openpyxl.worksheet.worksheet.Worksheet): Sheet to
            extend.
        complete_questions (list): Result dicts written to the sheet.
        total_rows (int): Number of data rows written to the sheet.
        total_questions (int): Number of questions that were judged.
        border (openpyxl.styles.Border): Border applied to the summary
            cells.

    Returns:
        int: Number of complete questions.
    """
    complete_count = sum(1 for item in complete_questions if item.get("完整性") == "完整")
    # The sheet only holds the complete questions, so counting against
    # total_rows always reported 0 incomplete and a 100% ratio. The
    # number of judged questions is the right base; total_rows is only a
    # fallback when that number is missing.
    overall = total_questions if total_questions else total_rows
    incomplete_count = max(overall - complete_count, 0)

    # Blank spacer row between the data and the summary.
    worksheet.append([""])

    stat_row = worksheet.max_row + 1
    worksheet.cell(row=stat_row, column=1, value="统计信息")
    worksheet.cell(row=stat_row, column=1).font = Font(bold=True)

    worksheet.cell(row=stat_row+1, column=1, value="总问题数")
    worksheet.cell(row=stat_row+1, column=2, value=overall)

    worksheet.cell(row=stat_row+2, column=1, value="完整问题数")
    worksheet.cell(row=stat_row+2, column=2, value=complete_count)
    worksheet.cell(row=stat_row+2, column=2).fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")

    worksheet.cell(row=stat_row+3, column=1, value="不完整问题数")
    worksheet.cell(row=stat_row+3, column=2, value=incomplete_count)
    worksheet.cell(row=stat_row+3, column=2).fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")

    worksheet.cell(row=stat_row+4, column=1, value="完整问题比例")
    worksheet.cell(row=stat_row+4, column=2, value=f"{complete_count/overall:.2%}" if overall > 0 else "0%")

    # Border around the two summary columns.
    for r in range(stat_row, stat_row+5):
        for c in range(1, 3):
            worksheet.cell(row=r, column=c).border = border

    return complete_count
|
||||||
|
|
||||||
|
def save_results_to_excel(self, complete_questions, total_questions):
    """
    Write the complete questions to the output workbook and style it.

    Nothing is written when `complete_questions` is empty.

    Args:
        complete_questions (list): Result dicts of the complete questions.
        total_questions (int): Number of questions that were judged.
    """
    if not complete_questions:
        print("没有找到完整的问题。")
        return

    # Plain dump first, styling in a second pass.
    result_df = self._prepare_excel_dataframe(complete_questions)
    row_count = len(result_df)
    result_df.to_excel(self.output_path, index=False, engine='openpyxl')

    # Reopen the file that was just written so styles can be applied.
    from openpyxl import load_workbook
    workbook = load_workbook(self.output_path)
    sheet = workbook.active

    self._set_excel_column_widths(sheet)
    border = self._apply_excel_cell_styles(sheet)
    complete_count = self._add_statistics_to_excel(sheet, complete_questions, row_count, total_questions, border)

    workbook.save(self.output_path)

    # Console summary.
    print(f"处理完成。共有{complete_count}/{total_questions}个完整问题被保存到 {self.output_path}")
    if total_questions > 0:
        print(f"完整问题比例: {complete_count/total_questions:.2%}")
    else:
        print("完整问题比例: 0%")
|
||||||
|
|
||||||
|
def process_excel_file(self):
    """
    Judge every question in the input workbook and save the results.

    Reads the workbook, judges the configured column concurrently, and
    writes the questions judged complete to the output workbook. Returns
    early, after printing a message, when the file is missing, cannot be
    read, or has too few columns.
    """
    if not os.path.exists(self.input_path):
        print(f"错误: 找不到Excel文件 '{self.input_path}'")
        return

    print(f"正在读取Excel文件: {self.input_path}")
    try:
        df = pd.read_excel(self.input_path)
    except Exception as e:
        print(f"读取Excel文件时出错: {e}")
        return

    if len(df.columns) <= self.column_index:
        print(f"错误: Excel文件没有足够的列,请求索引 {self.column_index},但只有 {len(df.columns)} 列")
        return

    target_col = df.columns[self.column_index]
    print(f"目标列名称: {target_col}")

    complete_questions = []
    total_questions = len(df)

    print(f"总共有{total_questions}个问题需要判断")

    # Shared state for the worker threads.
    complete_questions_lock = threading.Lock()
    progress_counter = {"processed": 0, "complete": 0, "incomplete": 0}
    progress_lock = threading.Lock()

    # Empty cells become "" so the worker skips them; str() alone would
    # turn NaN into the text "nan" and send it to the LLM.
    questions = [
        (i, "" if pd.isna(value) else str(value), self.llm_client, total_questions)
        for i, value in df.iloc[:, self.column_index].items()
    ]

    start_time = time.time()

    print(f"开始处理问题,使用 {self.max_workers} 个并发线程...")
    with tqdm(total=total_questions, desc="处理问题", unit="问题") as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(
                self._process_question,
                args,
                complete_questions,
                progress_counter,
                progress_lock,
                complete_questions_lock,
                pbar
            ) for args in questions]

            concurrent.futures.wait(futures)

    # An exception inside a worker stays stored on its future; report it
    # so a failed question does not vanish without a trace.
    for future in futures:
        error = future.exception()
        if error is not None:
            print(f"处理问题时发生未捕获的错误: {error}")

    processing_time = time.time() - start_time
    # An empty sheet must not cause a division by zero.
    average_time = processing_time / total_questions if total_questions else 0
    print(f"处理完成,耗时: {processing_time:.2f}秒,平均每问题: {average_time:.2f}秒")

    self.save_results_to_excel(complete_questions, total_questions)
|
||||||
|
|
||||||
|
def test_single_question(self, question):
    """
    Judge a single question and print the outcome to the console.

    Args:
        question (str): Question to judge.
    """
    print(f"问题: {question}")
    print("正在调用LLM判断问题是否完整...")

    is_complete, full_answer = self.is_question_complete(question)
    # Parse the reply once more to show the structured fields.
    parsed_json = self._extract_json_from_response(full_answer)

    print("\n==== LLM回复 ====")
    print(full_answer)
    print("================\n")

    if not parsed_json:
        # Only the keyword-based verdict is available.
        verdict = '完整' if is_complete else '不完整'
        print(f"判断结果: {verdict} (简单判断)")
        print("无法从回复中提取JSON结构化数据")
        return

    verdict = '完整' if parsed_json.get('is_complete', False) else '不完整'
    print(f"判断结果: {verdict}")
    print(f"判断原因: {parsed_json.get('reason', '未提供')}")
    print(f"置信度: {parsed_json.get('confidence', 0)}%")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
    """Parse the command-line arguments."""
    parser = argparse.ArgumentParser(description='判断Excel文件中的问题是否完整')

    # (flags, keyword arguments) for every supported option.
    option_specs = (
        (('-i', '--input'), {
            'type': str, 'default': DEFAULT_EXCEL_PATH,
            'help': f'输入Excel文件路径 (默认: {DEFAULT_EXCEL_PATH})'}),
        (('-o', '--output'), {
            'type': str, 'default': DEFAULT_OUTPUT_PATH,
            'help': f'输出Excel文件路径 (默认: {DEFAULT_OUTPUT_PATH})'}),
        (('-w', '--workers'), {
            'type': int, 'default': DEFAULT_MAX_WORKERS,
            'help': f'并发处理的最大线程数 (默认: {DEFAULT_MAX_WORKERS})'}),
        (('-c', '--column'), {
            'type': int, 'default': 0,
            'help': '要处理的问题所在列的索引 (默认: 0,即第一列)'}),
        (('-t', '--test'), {
            'type': str,
            'help': '测试单个问题,不处理Excel文件'}),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: judge one question or process a whole workbook."""
    options = parse_arguments()

    judge = QueryCompletenessJudge(
        input_path=options.input,
        output_path=options.output,
        max_workers=options.workers,
        column_index=options.column
    )

    if not options.test:
        # Default mode: process the whole workbook.
        judge.process_excel_file()
        return

    # Single-question mode; the workbook is not touched.
    judge.test_single_question(options.test)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the command-line tool only when executed as a script.
if __name__ == "__main__":
    main()
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,239 @@
|
|||||||
|
import pandas as pd
|
||||||
|
from urllib.parse import unquote
|
||||||
|
from rag2_0.tool.WikijsTool import WikijsTool
|
||||||
|
from rag2_0.tool.html_to_md import convert_html_to_md
|
||||||
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
from tqdm import tqdm
|
||||||
|
from rag2_0.dify.dify_tool import DifyTool
|
||||||
|
import json
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from langchain.output_parsers import PydanticOutputParser
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
class ContentSource(BaseModel):
    # Structured form of the LLM's relevance verdict.
    # Relevance score given by the LLM.
    score:int = Field(description="相关性分数")
    # Short justification for the score.
    reason:str = Field(description="评分理由")
|
||||||
|
|
||||||
|
class RetrieveContentScoreJudge:
|
||||||
|
"""
|
||||||
|
检索内容相关性评分工具类
|
||||||
|
|
||||||
|
用于评估检索内容与问题之间的相关性,并计算相关性分数
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, wiki_excel_path, answer_excel_path, output_path=None):
    """
    Set up the retrieval-relevance scoring tool.

    Args:
        wiki_excel_path (str): Path of the wiki Excel file; when the file
            does not exist the wiki table is left as None.
        answer_excel_path (str): Path of the answer Excel file.
        output_path (str, optional): Path of the output Excel file; a
            built-in default path is used when omitted.

    Raises:
        ValueError: When OPENAI_API_KEY, OPENAI_API_BASE or
            LLM_MODEL_NAME is not set in the environment.
    """
    self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)

    # The wiki sheet is optional, the answer sheet is not.
    self.wiki_excel = pd.read_excel(wiki_excel_path) if os.path.exists(wiki_excel_path) else None
    self.answer_excel = pd.read_excel(answer_excel_path)
    self.output_path = output_path or "/data/Rag2_0/data/excel/dify问答_检索内容评分.xlsx"

    # LLM settings come from the environment.
    self.api_key = os.getenv("OPENAI_API_KEY")
    self.base_url = os.getenv("OPENAI_API_BASE")
    self.model_name = os.getenv("LLM_MODEL_NAME")

    if not (self.api_key and self.base_url and self.model_name):
        raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")

    self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model_name)
|
||||||
|
|
||||||
|
def find_wiki_link(self, query) -> str | None:
|
||||||
|
"""
|
||||||
|
根据查询(对应wiki_excel中的新提问列)找出对应的词条链接
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query (str): 查询内容,对应wiki_excel中的新提问列
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 对应的词条链接,如果没有找到则返回None
|
||||||
|
"""
|
||||||
|
# 确保query不为空
|
||||||
|
if not query or pd.isna(query):
|
||||||
|
return None
|
||||||
|
if self.wiki_excel is None:
|
||||||
|
return None
|
||||||
|
# 在"新提问"列中查找匹配的行
|
||||||
|
matched_rows = self.wiki_excel[self.wiki_excel['新提问'] == query]
|
||||||
|
|
||||||
|
# 如果找到了匹配的行,返回对应的词条链接
|
||||||
|
if not matched_rows.empty:
|
||||||
|
return matched_rows.iloc[0]['对应词条链接']
|
||||||
|
|
||||||
|
# 如果没有完全匹配,尝试部分匹配
|
||||||
|
# 去除软件名称部分(如果有)
|
||||||
|
query_parts = query.split(',', 1)
|
||||||
|
if len(query_parts) > 1:
|
||||||
|
clean_query = query_parts[1].strip()
|
||||||
|
|
||||||
|
# 在"提问"列中查找包含清理后查询的行
|
||||||
|
for idx, row in self.wiki_excel.iterrows():
|
||||||
|
if pd.notna(row['提问']) and clean_query in row['提问']:
|
||||||
|
return row['对应词条链接']
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_wiki_title(self, link) -> str | None:
|
||||||
|
"""
|
||||||
|
获取词条标题
|
||||||
|
|
||||||
|
参数:
|
||||||
|
link (str): 词条链接
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 词条标题,如果获取失败则返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not link or pd.isna(link):
|
||||||
|
return None
|
||||||
|
# 移除域名部分,只保留路径
|
||||||
|
path = link.split('/', 3)[-1]
|
||||||
|
decoded_path = unquote(path)
|
||||||
|
path_parts = decoded_path.split('/')
|
||||||
|
return path_parts[-1]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
|
||||||
|
|
||||||
|
def calculate_score(self, answer:str, content:str) -> int:
    """
    Ask the LLM to rate how relevant a retrieved passage is to a query.

    Args:
        answer (str): The user's query (inserted as "query" in the
            prompt).
        content (str): Retrieved passage.

    Returns:
        int: The score parsed from the LLM reply (the prompt asks for
        1-10, 10 being fully relevant); -1 when the call or the parsing
        fails.
    """
    scoring_prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。

【评分标准】
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
8-9分:高度相关,核心要素匹配但存在少量信息缺失
6-7分:部分相关,涉及相同主题但存在重要信息缺失
4-5分:弱相关,仅次要信息点匹配
1-3分:完全不相关或信息冲突

【评估维度】
1. 主题一致性:核心主题/意图的匹配程度
2. 内容覆盖度:是否涵盖query的关键要素
3. 信息准确性:是否存在矛盾/错误信息
4. 细节丰富度:是否提供query要求的详细信息

【输出格式】
{{
"score": 评分,
"reason": "简明扼要的评分理由(中文)"
}}

【示例】
query: "新冠疫苗的常见副作用"
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}

现在评估:
query: "{answer}"
content: "{content}"
"""

    # Any failure, in the call or in parsing the reply, is reported as -1.
    try:
        reply = self.llm.invoke(user_prompt=scoring_prompt, need_retry=True)
        return self.content_source_parser.parse(reply.content).score
    except Exception:
        return -1
|
||||||
|
|
||||||
|
def get_retrieve_info(self, query:str, outputs:dict) -> tuple:
    """
    Score every retrieved passage and summarise the scores.

    Args:
        query (str): The user's query.
        outputs (dict): Retrieval output; each item of
            `outputs["result"]` must carry a "content" string.

    Returns:
        tuple: (labelled titles, highest score, lowest score, average
        score). The three scores are 0 when no passage could be scored;
        previously the lowest score was reported as 10 in that case.
    """
    valid_scores = []
    retrieve_content = []

    for result in outputs["result"]:
        content = result["content"].strip()
        score = self.calculate_score(answer=query, content=content)
        # -1 marks a failed scoring attempt and is left out of the stats.
        if score != -1:
            valid_scores.append(score)

        # The first line of a passage serves as its title. Every
        # retrieved passage is listed, including those whose scoring
        # failed (shown with -1).
        content_title = content.split("\n")[0]
        if content_title:
            retrieve_content.append(content_title + f"--得分({score}分)")

    if not valid_scores:
        return retrieve_content, 0, 0, 0

    avg_score = sum(valid_scores) / len(valid_scores)
    return retrieve_content, max(valid_scores), min(valid_scores), avg_score
|
||||||
|
|
||||||
|
def process(self):
    """Score the retrieved content for every question and write the workbook.

    For each row of ``self.answer_excel`` the question is sent through the
    Dify app, the debug info of the resulting workflow run is inspected for
    the retrieval post-processing node and the query-optimisation node, and
    the derived values are written back into the row. The whole sheet is
    saved to ``self.output_path`` once all rows are done.
    """
    sheet = self.answer_excel
    progress = tqdm(sheet.iterrows(), total=len(sheet), desc="处理问题评分中")
    for idx, row in progress:
        query = row["问题"]
        answer_title = self.get_wiki_title(self.find_wiki_link(query))

        # Defaults that stay in place when the workflow never ran the node
        # that would provide the value.
        retrieve_content, max_score, min_score, avg_score = [], 0, 0, 0
        rewrite_query = ""

        message_info = DifyTool.get_message_debug_info(appid="ccf92b97-2789-4a3f-90e0-135a869a37c5", query=query)
        for workflow_node in message_info["workflow_node_executions_info"]:
            node_title = workflow_node["title"]
            if node_title == "知识检索结果后处理":
                outputs = json.loads(workflow_node["outputs"])
                retrieve_content, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
            elif node_title == "问题优化结果解析":
                outputs = json.loads(workflow_node["outputs"])
                rewrite_query = outputs["optimize_query"]

        # Write the row back column by column (order fixes the column order
        # of any column that does not exist yet).
        row_values = (
            ("答案词条", answer_title if answer_title else ""),
            ("问题改写", rewrite_query),
            ("检索得到词条", "\n".join(retrieve_content) if retrieve_content else "未检索知识库"),
            ("最大得分", max_score),
            ("最小得分", min_score),
            ("平均得分", avg_score),
        )
        for column, value in row_values:
            sheet.at[idx, column] = value

    # Persist the annotated sheet.
    sheet.to_excel(self.output_path, index=False)
    print(f"结果已保存到 {self.output_path}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Workbook locations passed to the judge for this run.
    excel_paths = {
        "wiki_excel_path": "/data/Rag2_0/data/excel/400条人工标注-部分提问_软件名称明确.xlsx",
        "answer_excel_path": "/data/Rag2_0/data/excel/主网软件提问_回答内容评判.xlsx",
        "output_path": "/data/Rag2_0/data/excel/dify问答_检索内容评分.xlsx",
    }
    # Build the judge and run the full scoring pass.
    judge = RetrieveContentScoreJudge(**excel_paths)
    judge.process()
|
||||||
@@ -0,0 +1,178 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
File: merge_nouns_with_llm.py
|
||||||
|
Description: 合并多个nouns.json中的同名专业名词,利用LLM生成唯一合并结果
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import glob
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from collections import defaultdict
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||||
|
from rag2_0.intent_recognition.DataModels import Term
|
||||||
|
import logging
|
||||||
|
from langchain.output_parsers import PydanticOutputParser
|
||||||
|
from tqdm import tqdm
|
||||||
|
import time
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
class TermMerger:
    """Merge same-named professional terms coming from several JSON sources.

    Terms are grouped by their upper-cased name. Names that occur once are
    passed through unchanged; names that occur several times are consolidated
    into a single entry by an LLM.
    """

    def __init__(self, input_dir=None, output_path=None, max_workers=3):
        """Store the configuration and build the parser and LLM client.

        Args:
            input_dir: Directory that contains the ``*_nouns.json`` files.
            output_path: File the merged result is written to.
            max_workers: Maximum number of threads used for LLM merging.
        """
        self.EXTRACTED_NOUNS_DIR = input_dir
        self.OUTPUT_PATH = output_path
        self.MAX_WORKERS = max_workers
        self.terms_parser = PydanticOutputParser(pydantic_object=Term)
        self.MERGE_PROMPT = '''
请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:
- 同义词(synonymous)去重合并
- 描述(description)合并为更完整、简明的描述
- 保持输出格式为:
{output_format}
原始条目:
{items}
'''
        # LLM configuration is read from the environment; key and base URL
        # are only passed on when they are actually set.
        model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
        api_key = os.getenv("OPENAI_API_KEY")
        base_url = os.getenv("OPENAI_API_BASE")
        llm_params = {"temperature": 0.3, "model": model_name}
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url
        self.llm = OpenAiLLM(**llm_params)

        # Logging setup.
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    def load_all_terms(self):
        """Read every ``*_nouns.json`` in the input directory plus the suffix keywords.

        Files that cannot be opened or parsed are logged and skipped instead
        of aborting the whole load. Files are visited in sorted order so the
        result does not depend on directory enumeration order.

        Returns:
            list[dict]: Terms with ``name`` (upper-cased), ``synonymous`` and
            ``description`` keys.
        """
        all_terms = []
        pattern = os.path.join(self.EXTRACTED_NOUNS_DIR, '*_nouns.json')
        for file in sorted(glob.glob(pattern)):
            try:
                # open() is inside the try so an unreadable file is skipped too.
                with open(file, 'r', encoding='utf-8') as f:
                    file_terms = json.load(f)
                new_terms = [
                    {"name": term["name"].upper(), "synonymous": term["synonymous"], "description": term["description"]}
                    for term in file_terms
                ]
            except Exception as e:
                logging.warning(f"读取{file}失败: {e}")
                continue
            all_terms.extend(new_terms)
            logging.info(f"加载{file},共{len(new_terms)}条")

        # Optional suffix keywords, located at <input_dir>/../../data/nouns/.
        suffix_keywords_path = os.path.join(
            os.path.dirname(os.path.dirname(self.EXTRACTED_NOUNS_DIR)), 'data', 'nouns', 'suffix_keywords.json'
        )
        if os.path.exists(suffix_keywords_path):
            try:
                with open(suffix_keywords_path, 'r', encoding='utf-8') as f:
                    suffix_terms = json.load(f)
                suffix_terms = [{"name": term["name"].upper(), "synonymous": "", "description": ""} for term in suffix_terms]
                all_terms.extend(suffix_terms)
                logging.info(f"加载{suffix_keywords_path},共{len(suffix_terms)}条")
            except Exception as e:
                logging.warning(f"读取{suffix_keywords_path}失败: {e}")

        return all_terms

    def group_terms_by_name(self, terms):
        """Group terms by their stripped ``name``; terms without a name are dropped.

        Returns:
            defaultdict[str, list[dict]]: Name to the list of terms carrying it.
        """
        name2terms = defaultdict(list)
        for term in terms:
            # ``or ''`` also covers an explicit ``"name": null`` entry.
            name = (term.get('name') or '').strip()
            if name:
                name2terms[name].append(term)
        return name2terms

    def merge_terms_with_llm(self, name, term_list):
        """Ask the LLM to merge same-named terms, trying up to three times.

        Returns:
            dict | None: The merged term, or ``None`` when every attempt failed.
        """
        items = json.dumps(term_list, ensure_ascii=False)
        prompt = self.MERGE_PROMPT.format(
            name=name, items=items, output_format=self.terms_parser.get_format_instructions()
        )

        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                response = self.llm.invoke(prompt, False)
                parsed_output = self.terms_parser.parse(response.content)
                return {
                    "name": parsed_output.name,
                    "synonymous": parsed_output.synonymous,
                    "description": parsed_output.description,
                }
            except Exception as e:
                if attempt == max_retries:
                    logging.warning(f"解析LLM合并结果失败: {e}")
                    return None
                # Linear back-off before the next attempt.
                time.sleep(10*attempt)

    def process_term(self, name_terms_tuple):
        """Merge one ``(name, term_list)`` group; thread-pool work item.

        Falls back to the first term of the group when merging fails.
        """
        name, term_list = name_terms_tuple
        try:
            merged = self.merge_terms_with_llm(name, term_list)
        except Exception as e:
            logging.error(f"处理词条 {name} 时出错: {e}")
            return term_list[0]
        return merged if merged else term_list[0]

    def merge(self):
        """Load, group, merge and save all terms.

        Returns:
            list[dict]: The merged terms, also written to ``self.OUTPUT_PATH``.
        """
        # 1. Load every term.
        all_terms = self.load_all_terms()
        logging.info(f"共加载{len(all_terms)}条术语")

        # 2. Group by name.
        name2terms = self.group_terms_by_name(all_terms)
        logging.info(f"共{len(name2terms)}个唯一名词")

        # 3. Names with a single entry need no LLM call.
        merged_terms = []
        items_to_process = []
        for name, term_list in name2terms.items():
            if len(term_list) == 1:
                merged_terms.append(term_list[0])
            else:
                items_to_process.append((name, term_list))

        logging.info(f"共{len(merged_terms)}个单一条目,{len(items_to_process)}个需要合并的条目")

        # Only groups that really need merging go through the thread pool.
        if items_to_process:
            with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
                results = executor.map(self.process_term, items_to_process)
                for result in tqdm(results, total=len(items_to_process)):
                    merged_terms.append(result)

        # 4. Save. A bare file name has no directory part and
        #    os.makedirs('') would raise, so only create one when present.
        output_dir = os.path.dirname(self.OUTPUT_PATH)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
            json.dump(merged_terms, f, ensure_ascii=False, indent=2)
        logging.info(f"合并后结果已保存到: {self.OUTPUT_PATH}")

        return merged_terms
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: merge the extracted nouns next to this script's project data."""
    here = os.path.dirname(__file__)
    # Both locations are resolved relative to this file.
    nouns_dir = os.path.abspath(os.path.join(here, '../../data/wiki_extracted_nouns'))
    merged_file = os.path.join(here, "..", "..", "data", "nouns", 'merged_nouns.json')
    TermMerger(input_dir=nouns_dir, output_path=merged_file, max_workers=2).merge()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Raise the level of the HTTP/OpenAI client loggers to WARNING.
    for noisy_logger in ('httpx', 'openai'):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
    main()
|
||||||
@@ -0,0 +1,408 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
File: validate_excel_data_batch.py
|
||||||
|
Description: 使用LLM批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写是否正确
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import concurrent.futures
|
||||||
|
from tqdm import tqdm
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from rag2_0.intent_recognition.PromptTemplates import classification
|
||||||
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||||
|
|
||||||
|
class ExcelDataValidator:
    """Batch-validate annotated question data from an Excel sheet with an LLM.

    For every row the classification, the query decomposition, the retrieval
    keywords and the rewritten question are checked in that order; the first
    failing stage is recorded together with the model's reason.
    """

    # Rejection prefixes to strip from a verdict. The prompts ask for the
    # full-width colon form, the ASCII form was the only one handled before.
    _ERROR_MARKERS = ("错误:", "错误:")

    def __init__(self, input_file=None, output_file=None, workers=4, batch_size=10):
        """Store the configuration, load the environment and set up logging.

        Args:
            input_file: Path of the Excel file to validate.
            output_file: Path of the Excel file the result is written to.
            workers: Number of worker threads.
            batch_size: Number of rows handled per batch.
        """
        load_dotenv()

        self.input_file = input_file
        self.output_file = output_file
        self.workers = workers
        self.batch_size = batch_size
        self.df = None

        self.setup_logging()

    def setup_logging(self):
        """Configure console logging and raise the HTTP client loggers to WARNING."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler()
            ]
        )
        logging.getLogger('httpx').setLevel(logging.WARNING)
        logging.getLogger('openai').setLevel(logging.WARNING)

    def load_data_from_excel(self, file_path=None):
        """Read the input sheet and check that all required columns exist.

        Args:
            file_path: Excel path; falls back to the path given at construction.

        Returns:
            The loaded DataFrame (also stored on ``self.df``), or ``None`` when
            no path is known, a required column is missing or reading fails.
        """
        file_path = file_path or self.input_file
        if not file_path:
            logging.error("未指定输入文件路径")
            return None

        try:
            df = pd.read_excel(file_path)
            required_columns = ["提问", "问题拆解", "一级分类", "二级分类", "问题改写", "检索的关键词"]
            missing = [col for col in required_columns if col not in df.columns]
            if missing:
                # Report every missing column, not only the first one.
                for col in missing:
                    logging.error(f"缺少必要的列: {col}")
                return None
            logging.info(f"成功从{file_path}读取了{len(df)}条数据")
            self.df = df
            return df
        except Exception as e:
            logging.error(f"读取Excel文件时出错: {e}")
            return None

    def _request_verdict(self, llm, prompt, stage):
        """Send one validation prompt and turn the answer into ``(ok, reason)``.

        An answer starting with "正确" counts as a pass. Anything else is a
        failure whose reason is the answer with the rejection marker removed
        (full-width or ASCII colon). Errors raised by the model call are
        logged under ``stage`` and reported as a failure as well.
        """
        try:
            response = llm.invoke(prompt)
            result = response.content.strip()

            if result.startswith("正确"):
                return True, ""

            error_reason = result
            for marker in self._ERROR_MARKERS:
                error_reason = error_reason.replace(marker, "")
            return False, error_reason.strip()
        except Exception as e:
            logging.warning(f"验证{stage}时出错: {e}")
            return False, f"验证过程出错: {str(e)}"

    def validate_classification(self, llm, query, vertical_class, sub_class):
        """Check whether the two-level classification of a question is correct.

        Args:
            llm: LLM client exposing ``invoke(prompt)``.
            query: Original question.
            vertical_class: First-level class.
            sub_class: Second-level class.

        Returns:
            (bool, str): Whether it is correct, and the error reason if not.
        """
        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用,也可能涉及电力造价专业知识。我对用户问题进行了分类,请评估以下问题分类是否正确。

我目前总共有以下分类:
{classification}

问题的分类情况如下:
原始问题: {query}
一级分类: {vertical_class}
二级分类: {sub_class}

请从专业角度分析这个分类是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""

        return self._request_verdict(llm, prompt, "问题分类")

    def validate_query_keys(self, llm, query, query_keys):
        """Check whether the decomposition of a question is correct.

        Args:
            llm: LLM client exposing ``invoke(prompt)``.
            query: Original question.
            query_keys: Decomposition of the question.

        Returns:
            (bool, str): Whether it is correct, and the error reason if not.
        """
        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了拆解,请评估以下问题拆解是否正确。

原始问题: {query}
问题拆解: {query_keys}

问题拆解应该准确提取原始问题中的关键词和信息。请分析这个拆解是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""

        return self._request_verdict(llm, prompt, "问题拆解")

    def validate_keywords(self, llm, query, query_keys, keywords_str):
        """Check whether the retrieval keywords fit the question.

        Args:
            llm: LLM client exposing ``invoke(prompt)``.
            query: Original question.
            query_keys: Decomposition of the question.
            keywords_str: Retrieval keywords as a JSON string. Text that is not
                valid JSON is shown to the model as is; empty or non-string
                values are treated as an empty list.

        Returns:
            (bool, str): Whether it is correct, and the error reason if not.
        """
        try:
            if isinstance(keywords_str, str) and keywords_str.strip():
                keywords = json.loads(keywords_str)
            else:
                keywords = []
        except ValueError:
            # json.JSONDecodeError is a ValueError: fall back to the raw text.
            keywords = keywords_str

        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。通过问题检索出了一些关键词,请评估这些关键词是否准确,是否与问题相关
原始问题: {query}
问题拆解: {query_keys}
检索关键词: {keywords}

检索关键词应该准确反映问题中需要检索的关键概念和术语。请分析这些关键词是否准确、完整。只需返回"正确"或"错误:原因",不需要其他解释。"""

        return self._request_verdict(llm, prompt, "检索关键词")

    def validate_rewrite(self, llm, query, rewrite):
        """Check whether the rewritten question is correct.

        Args:
            llm: LLM client exposing ``invoke(prompt)``.
            query: Original question.
            rewrite: Rewritten question.

        Returns:
            (bool, str): Whether it is correct, and the error reason if not.
        """
        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了改写,请评估以下问题改写是否正确。

原始问题: {query}
问题改写: {rewrite}

问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""

        return self._request_verdict(llm, prompt, "问题改写")

    def validate_row(self, llm, row_data):
        """Run the four checks on one row, stopping at the first failure.

        Args:
            llm: LLM client exposing ``invoke(prompt)``.
            row_data: ``(index, row)`` tuple.

        Returns:
            (index, is_all_correct, error_phase, error_reason)
        """
        index, row = row_data
        query = row["提问"]
        query_keys = row["问题拆解"]
        vertical_class = row["一级分类"]
        sub_class = row["二级分类"]
        rewrite = row["问题改写"]
        keywords_str = row["检索的关键词"]

        # Stage label -> deferred check, evaluated in order so that later
        # (paid) model calls are skipped after the first failure.
        checks = (
            ("问题分类", lambda: self.validate_classification(llm, query, vertical_class, sub_class)),
            ("问题拆解", lambda: self.validate_query_keys(llm, query, query_keys)),
            ("关键词检索", lambda: self.validate_keywords(llm, query, query_keys, keywords_str)),
            ("问题改写", lambda: self.validate_rewrite(llm, query, rewrite)),
        )
        try:
            for phase, check in checks:
                is_correct, error_reason = check()
                if not is_correct:
                    return index, False, phase, error_reason
            return index, True, "", ""
        except Exception as e:
            error_msg = f"处理行 {index} 时发生错误: {str(e)}"
            logging.error(error_msg)
            return index, False, "处理错误", error_msg

    def process_batch(self, llm, batch_data):
        """Validate every row of a batch sequentially with the given LLM."""
        return [self.validate_row(llm, row_data) for row_data in batch_data]

    def create_llm_instances(self, count):
        """Create ``count`` LLM clients configured from the environment."""
        api_key = os.getenv("OPENAI_API_KEY")
        base_url = os.getenv("OPENAI_API_BASE")
        model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")

        llm_params = {"temperature": 0.7, "model": model_name}
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url

        return [OpenAiLLM(**llm_params) for _ in range(count)]

    def validate(self, input_file=None, output_file=None, workers=None, batch_size=None):
        """Validate the whole sheet and save the annotated result.

        Args:
            input_file: Input Excel path (defaults to the constructor value).
            output_file: Output Excel path (defaults to the constructor value,
                then to ``validated_<input name>`` next to the input).
            workers: Number of worker threads.
            batch_size: Number of rows per batch.

        Returns:
            The validated DataFrame, or ``None`` when the input cannot be loaded.
        """
        input_file = input_file or self.input_file
        output_file = output_file or self.output_file
        workers = workers or self.workers
        batch_size = batch_size or self.batch_size

        df = self.load_data_from_excel(input_file)
        if df is None:
            return None

        # Result columns.
        df["验证结果"] = ""
        df["错误环节"] = ""
        df["错误原因"] = ""

        # Split the rows into batches.
        all_rows = list(df.iterrows())
        batches = [all_rows[i:i+batch_size] for i in range(0, len(all_rows), batch_size)]

        # Batches are assigned to the LLM clients round-robin.
        llm_instances = self.create_llm_instances(min(workers, len(batches)))

        all_results = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            future_to_batch = {
                executor.submit(self.process_batch, llm_instances[i % len(llm_instances)], batch):
                i for i, batch in enumerate(batches)
            }

            for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(batches), desc="批次处理进度"):
                all_results.extend(future.result())

        # Batches finish out of order; restore the original row order.
        all_results.sort(key=lambda x: x[0])

        for index, is_correct, error_phase, error_reason in all_results:
            df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
            df.at[index, "错误环节"] = error_phase
            df.at[index, "错误原因"] = error_reason

        if output_file is None:
            output_file = os.path.join(
                os.path.dirname(input_file),
                f"validated_{os.path.basename(input_file)}"
            )
        df.to_excel(output_file, index=False)
        logging.info(f"验证完成,结果已保存至: {output_file}")

        self.print_statistics(df)

        return df

    def print_statistics(self, df):
        """Log the pass rate and the number of failures per stage."""
        total = len(df)
        passed = len(df[df["验证结果"] == "通过"])
        error_stats = df[df["验证结果"] == "不通过"]["错误环节"].value_counts()

        # An empty sheet has no meaningful rate; report 0 instead of dividing by zero.
        pass_rate = passed/total*100 if total else 0.0
        logging.info(f"统计信息: 总计 {total} 条, 通过 {passed} 条, 通过率 {pass_rate:.2f}%")
        logging.info("错误环节统计:")
        for phase, count in error_stats.items():
            logging.info(f"- {phase}: {count} 条")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: parse the options and run the validation."""
    excel_dir = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel")
    # NOTE(review): the default input name has no ".xlsx" extension while the
    # default output has one - confirm the intended file name.
    input_excel = os.path.join(excel_dir, "问题分类重写结果")
    output_excel = os.path.join(excel_dir, "自动验证_问题分类重写结果.xlsx")

    parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写")
    # --input used to be required=True, which made its default unreachable;
    # it is optional now so the default above actually applies.
    parser.add_argument("--input", "-i", type=str, help="输入Excel文件路径", default=input_excel)
    parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
    parser.add_argument("--workers", "-w", type=int, default=2, help="并行工作线程数")
    parser.add_argument("--batch-size", "-b", type=int, default=5, help="每批处理的行数")
    args = parser.parse_args()

    # Build the validator from the parsed options and run it.
    validator = ExcelDataValidator(
        input_file=args.input,
        output_file=args.output,
        workers=args.workers,
        batch_size=args.batch_size
    )
    validator.validate()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
File: vectorize_save_noun.py
|
||||||
|
Date: 2025-05-15
|
||||||
|
Description: 专业名词向量化和保存的示例程序
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from rag2_0.intent_recognition import ProfessionalNounVectorizer
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
def main():
    """Build the professional-noun index from the merged nouns file and save it."""
    # All paths are resolved relative to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    nouns_dir = os.path.join(script_dir, "..", "..", "data", "nouns")
    merged_nouns = os.path.join(script_dir, "..", "..", "data/nouns/merged_nouns.json")

    # Vectorize and save in one step.
    vectorizer = ProfessionalNounVectorizer(output_dir=nouns_dir)
    if not vectorizer.vectorize_files_and_save([merged_nouns]):
        logging.error("✗ 索引创建失败")
        return

    logging.info("✓ 索引创建和保存成功")
    logging.info(f"  索引保存路径: {os.path.join(nouns_dir, 'professional_nouns_index')}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Log bare messages at INFO level using logging's default handler.
    console_logging = {"level": logging.INFO, "format": '%(message)s'}
    logging.basicConfig(**console_logging)
    main()
|
||||||
Binary file not shown.
@@ -0,0 +1 @@
|
|||||||
|
from dify_client.client import ChatClient, CompletionClient, DifyClient
|
||||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,459 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class DifyClient:
    """Minimal HTTP client for the Dify service API.

    Every request carries the API key as a bearer token and is sent to
    ``base_url`` + endpoint. Methods return the raw ``requests`` response.
    """

    def __init__(self, api_key, base_url: str = "https://api.dify.ai/v1"):
        """Remember the API key and the base URL that endpoints are appended to."""
        self.api_key = api_key
        self.base_url = base_url

    def _send_request(self, method, endpoint, json=None, params=None, stream=False):
        """Send a JSON request and return the response object.

        Args:
            method: HTTP method name.
            endpoint: Path appended to ``base_url``.
            json: Optional JSON-serialisable request body.
            params: Optional query parameters.
            stream: Passed through to ``requests`` to defer body download.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        url = f"{self.base_url}{endpoint}"
        # NOTE(review): TLS certificate verification is disabled here but not
        # in _send_request_with_files - confirm this is intended.
        response = requests.request(
            method, url, json=json, params=params, headers=headers, stream=stream, verify=False
        )

        return response

    def _send_request_with_files(self, method, endpoint, data, files):
        """Send a multipart request (form ``data`` plus ``files``) and return the response."""
        headers = {"Authorization": f"Bearer {self.api_key}"}

        url = f"{self.base_url}{endpoint}"
        response = requests.request(
            method, url, data=data, headers=headers, files=files
        )

        return response

    def message_feedback(self, message_id, rating, user):
        """POST a rating for ``message_id`` on behalf of ``user``."""
        data = {"rating": rating, "user": user}
        return self._send_request("POST", f"/messages/{message_id}/feedbacks", data)

    def get_application_parameters(self, user):
        """GET ``/parameters`` for ``user``."""
        params = {"user": user}
        return self._send_request("GET", "/parameters", params=params)

    def file_upload(self, user, files):
        """POST ``files`` to ``/files/upload`` on behalf of ``user``."""
        data = {"user": user}
        return self._send_request_with_files(
            "POST", "/files/upload", data=data, files=files
        )

    def text_to_audio(self, text: str, user: str, streaming: bool = False):
        """POST ``text`` to ``/text-to-audio`` on behalf of ``user``."""
        data = {"text": text, "user": user, "streaming": streaming}
        # _send_request has no ``data`` parameter; the payload is its JSON body.
        return self._send_request("POST", "/text-to-audio", json=data)

    def get_meta(self, user):
        """GET ``/meta`` for ``user``."""
        params = {"user": user}
        return self._send_request("GET", "/meta", params=params)
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionClient(DifyClient):
    """Client for the completion-message endpoint."""

    def create_completion_message(self, inputs, response_mode, user, files=None):
        """POST a completion request; the response is streamed in "streaming" mode."""
        payload = dict(inputs=inputs, response_mode=response_mode, user=user, files=files)
        wants_stream = response_mode == "streaming"
        return self._send_request("POST", "/completion-messages", payload, stream=wants_stream)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatClient(DifyClient):
    """Client for the chat, conversation and audio-to-text endpoints."""

    def create_chat_message(
        self,
        inputs,
        query,
        user,
        response_mode="blocking",
        conversation_id=None,
        files=None,
    ):
        """POST a chat message; the response is streamed in "streaming" mode.

        ``conversation_id`` is only sent when it is truthy.
        """
        payload = dict(inputs=inputs, query=query, user=user, response_mode=response_mode, files=files)
        if conversation_id:
            payload["conversation_id"] = conversation_id

        wants_stream = response_mode == "streaming"
        return self._send_request("POST", "/chat-messages", payload, stream=wants_stream)

    def get_suggested(self, message_id, user: str):
        """GET ``/messages/<message_id>/suggested`` for ``user``."""
        return self._send_request(
            "GET", f"/messages/{message_id}/suggested", params={"user": user}
        )

    def stop_message(self, task_id, user):
        """POST a stop request for the chat task ``task_id``."""
        return self._send_request("POST", f"/chat-messages/{task_id}/stop", {"user": user})

    def get_conversations(self, user, last_id=None, limit=None, pinned=None):
        """GET ``/conversations`` with the given paging/filter values as query parameters."""
        query_params = dict(user=user, last_id=last_id, limit=limit, pinned=pinned)
        return self._send_request("GET", "/conversations", params=query_params)

    def get_conversation_messages(
        self, user, conversation_id=None, first_id=None, limit=None
    ):
        """GET ``/messages``; optional filters are only sent when truthy."""
        query_params = {"user": user}
        optional = (
            ("conversation_id", conversation_id),
            ("first_id", first_id),
            ("limit", limit),
        )
        query_params.update({key: value for key, value in optional if value})

        return self._send_request("GET", "/messages", params=query_params)

    def rename_conversation(
        self, conversation_id, name, auto_generate: bool, user: str
    ):
        """POST a new name (or the auto-generate flag) for a conversation."""
        payload = dict(name=name, auto_generate=auto_generate, user=user)
        endpoint = f"/conversations/{conversation_id}/name"
        return self._send_request("POST", endpoint, payload)

    def delete_conversation(self, conversation_id, user):
        """Send a DELETE for ``conversation_id`` with ``user`` in the JSON body."""
        return self._send_request("DELETE", f"/conversations/{conversation_id}", {"user": user})

    def audio_to_text(self, audio_file, user):
        """POST ``audio_file`` as the ``audio_file`` multipart field to ``/audio-to-text``."""
        form_fields = {"user": user}
        upload = {"audio_file": audio_file}
        return self._send_request_with_files("POST", "/audio-to-text", form_fields, upload)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowClient(DifyClient):
    """Client for the workflow run, stop and result endpoints."""

    def run(
        self, inputs: dict, response_mode: str = "streaming", user: str = "abc-123"
    ):
        """POST a workflow run request with the given inputs."""
        payload = dict(inputs=inputs, response_mode=response_mode, user=user)
        return self._send_request("POST", "/workflows/run", payload)

    def stop(self, task_id, user):
        """POST a stop request for the workflow task ``task_id``."""
        endpoint = f"/workflows/tasks/{task_id}/stop"
        return self._send_request("POST", endpoint, {"user": user})

    def get_result(self, workflow_run_id):
        """GET ``/workflows/run/<workflow_run_id>``."""
        endpoint = f"/workflows/run/{workflow_run_id}"
        return self._send_request("GET", endpoint)
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeBaseClient(DifyClient):
    """Client for the dataset (knowledge base) endpoints.

    Several methods accept ``extra_params``: a dict that is merged into the
    request payload and overrides the defaults, such as
    ``indexing_technique`` and ``process_rule``, e.g.::

        {
            'indexing_technique': 'high_quality',
            'process_rule': {
                'rules': {
                    'pre_processing_rules': [
                        {'id': 'remove_extra_spaces', 'enabled': True},
                        {'id': 'remove_urls_emails', 'enabled': True}
                    ],
                    'segmentation': {
                        'separator': '\\n',
                        'max_tokens': 500
                    }
                },
                'mode': 'custom'
            }
        }
    """

    def __init__(
        self,
        api_key,
        base_url: str = "https://api.dify.ai/v1",
        dataset_id: str | None = None,
    ):
        """
        Construct a KnowledgeBaseClient object.

        Args:
            api_key (str): API key of Dify.
            base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'.
            dataset_id (str, optional): ID of the dataset. Defaults to None. Not needed to
                create a new dataset or to list datasets; every other method requires it.
        """
        super().__init__(api_key=api_key, base_url=base_url)
        self.dataset_id = dataset_id

    def _get_dataset_id(self):
        """Return the configured dataset ID.

        :raises ValueError: if no ``dataset_id`` was set.
        """
        if self.dataset_id is None:
            raise ValueError("dataset_id is not set")
        return self.dataset_id

    def create_dataset(self, name: str, **kwargs):
        """POST ``/datasets`` to create a dataset called ``name``."""
        return self._send_request("POST", "/datasets", {"name": name}, **kwargs)

    def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
        """GET one page of datasets (``page_size`` is sent as ``limit``)."""
        return self._send_request(
            "GET", f"/datasets?page={page}&limit={page_size}", **kwargs
        )

    def create_document_by_text(
        self, name, text, extra_params: dict | None = None, **kwargs
    ):
        """
        Create a document by text.

        :param name: Name of the document
        :param text: Text content of the document
        :param extra_params: extra parameters merged into the payload (optional);
            see the class docstring for an example
        :return: Response from the API
        """
        data = {
            "indexing_technique": "high_quality",
            "process_rule": {"mode": "automatic"},
            "name": name,
            "text": text,
        }
        if isinstance(extra_params, dict):
            data.update(extra_params)
        url = f"/datasets/{self._get_dataset_id()}/document/create_by_text"
        return self._send_request("POST", url, json=data, **kwargs)

    def update_document_by_text(
        self, document_id, name, text, extra_params: dict | None = None, **kwargs
    ):
        """
        Update a document by text.

        :param document_id: ID of the document
        :param name: Name of the document
        :param text: Text content of the document
        :param extra_params: extra parameters merged into the payload (optional);
            see the class docstring for an example
        :return: Response from the API
        """
        data = {"name": name, "text": text}
        if isinstance(extra_params, dict):
            data.update(extra_params)
        url = (
            f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
        )
        return self._send_request("POST", url, json=data, **kwargs)

    def create_document_by_file(
        self, file_path, original_document_id=None, extra_params: dict | None = None
    ):
        """
        Create a document by file.

        :param file_path: Path to the file
        :param original_document_id: pass this ID if you want to replace the original document (optional)
        :param extra_params: extra parameters merged into the payload (optional);
            see the class docstring for an example
        :return: Response from the API
        """
        # The handle is closed when the call returns.
        # NOTE(review): assumes _send_request_with_files finishes reading the
        # file before it returns - confirm against its implementation.
        with open(file_path, "rb") as file_obj:
            data = {
                "process_rule": {"mode": "automatic"},
                "indexing_technique": "high_quality",
            }
            if isinstance(extra_params, dict):
                data.update(extra_params)
            if original_document_id is not None:
                data["original_document_id"] = original_document_id
            url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
            return self._send_request_with_files(
                "POST", url, {"data": json.dumps(data)}, {"file": file_obj}
            )

    def update_document_by_file(
        self, document_id, file_path, extra_params: dict | None = None
    ):
        """
        Update a document by file.

        :param document_id: ID of the document
        :param file_path: Path to the file
        :param extra_params: extra parameters merged into the payload (optional);
            see the class docstring for an example
        :return: Response from the API
        """
        # Same handle-lifetime handling as create_document_by_file.
        with open(file_path, "rb") as file_obj:
            data = {}
            if isinstance(extra_params, dict):
                data.update(extra_params)
            url = (
                f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
            )
            return self._send_request_with_files(
                "POST", url, {"data": json.dumps(data)}, {"file": file_obj}
            )

    def batch_indexing_status(self, batch_id: str, **kwargs):
        """
        Get the status of the batch indexing.

        :param batch_id: ID of the batch uploading
        :return: Response from the API
        """
        url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status"
        return self._send_request("GET", url, **kwargs)

    def delete_dataset(self):
        """
        Delete this dataset.

        :return: Response from the API
        """
        url = f"/datasets/{self._get_dataset_id()}"
        return self._send_request("DELETE", url)

    def delete_document(self, document_id):
        """
        Delete a document.

        :param document_id: ID of the document
        :return: Response from the API
        """
        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}"
        return self._send_request("DELETE", url)

    def list_documents(
        self,
        page: int | None = None,
        page_size: int | None = None,
        keyword: str | None = None,
        **kwargs,
    ):
        """
        Get a list of documents in this dataset.

        Only the filters that are not None are sent; ``page_size`` is sent
        as ``limit``.

        :return: Response from the API
        """
        params = {}
        if page is not None:
            params["page"] = page
        if page_size is not None:
            params["limit"] = page_size
        if keyword is not None:
            params["keyword"] = keyword
        url = f"/datasets/{self._get_dataset_id()}/documents"
        return self._send_request("GET", url, params=params, **kwargs)

    def add_segments(self, document_id, segments, **kwargs):
        """
        Add segments to a document.

        :param document_id: ID of the document
        :param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}]
        :return: Response from the API
        """
        data = {"segments": segments}
        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
        return self._send_request("POST", url, json=data, **kwargs)

    def query_segments(
        self,
        document_id,
        keyword: str | None = None,
        status: str | None = None,
        **kwargs,
    ):
        """
        Query segments in this document.

        :param document_id: ID of the document
        :param keyword: query keyword, optional
        :param status: status of the segment, optional, e.g. completed
        :param kwargs: forwarded to ``_send_request``; a ``params`` dict, if
            given, is merged into the query parameters
        :return: Response from the API
        """
        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
        params = {}
        if keyword is not None:
            params["keyword"] = keyword
        if status is not None:
            params["status"] = status
        # Pop instead of read: leaving "params" inside kwargs would pass the
        # keyword twice below and raise TypeError.
        caller_params = kwargs.pop("params", None)
        if caller_params:
            params.update(caller_params)
        return self._send_request("GET", url, params=params, **kwargs)

    def delete_document_segment(self, document_id, segment_id):
        """
        Delete a segment from a document.

        :param document_id: ID of the document
        :param segment_id: ID of the segment
        :return: Response from the API
        """
        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
        return self._send_request("DELETE", url)

    def update_document_segment(self, document_id, segment_id, segment_data, **kwargs):
        """
        Update a segment in a document.

        :param document_id: ID of the document
        :param segment_id: ID of the segment
        :param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True}
        :return: Response from the API
        """
        data = {"segment": segment_data}
        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
        return self._send_request("POST", url, json=data, **kwargs)
|
||||||
@@ -0,0 +1,215 @@
|
|||||||
|
import psycopg2
|
||||||
|
from psycopg2 import sql
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from datetime import timezone, timedelta
|
||||||
|
|
||||||
|
class PgSql:
|
||||||
|
"""
|
||||||
|
用于连接和操作 PostgreSQL 数据库的类。
|
||||||
|
|
||||||
|
该类封装了数据库连接、关闭连接以及执行特定查询的方法,
|
||||||
|
主要用于从 Dify 应用相关的表中获取数据。
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
初始化 PgSql 实例并建立数据库连接。
|
||||||
|
"""
|
||||||
|
self.connection = None
|
||||||
|
self.connect_sql()
|
||||||
|
|
||||||
|
def connect_sql(self):
|
||||||
|
"""
|
||||||
|
连接到 PostgreSQL 数据库。
|
||||||
|
|
||||||
|
使用预定义的凭据连接到 'dify' 数据库。
|
||||||
|
如果连接失败,会打印错误信息。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 连接数据库
|
||||||
|
self.connection = psycopg2.connect(
|
||||||
|
user="postgres",
|
||||||
|
password="difyai123456",
|
||||||
|
host="172.20.0.145",
|
||||||
|
port=5432,
|
||||||
|
database="dify"
|
||||||
|
)
|
||||||
|
|
||||||
|
except (Exception, psycopg2.Error) as error:
|
||||||
|
print("Error while connecting to PostgreSQL", error)
|
||||||
|
|
||||||
|
def close_connection(self):
|
||||||
|
"""
|
||||||
|
关闭当前的 PostgreSQL 数据库连接。
|
||||||
|
|
||||||
|
如果存在活动的连接,则关闭它并打印确认信息。
|
||||||
|
"""
|
||||||
|
if self.connection:
|
||||||
|
self.connection.close()
|
||||||
|
print("PostgreSQL connection is closed")
|
||||||
|
|
||||||
|
|
||||||
|
def get_appinfo(self, appid:str)->dict | None:
|
||||||
|
"""
|
||||||
|
根据应用 ID 从 'apps' 表中获取应用信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
appid: 目标应用的 ID。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个字典,其中键是列名,值是对应的应用数据。
|
||||||
|
如果未找到应用或发生错误,则返回 None。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.connection.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM apps WHERE id = %s
|
||||||
|
""",
|
||||||
|
(appid,)
|
||||||
|
)
|
||||||
|
result = cursor.fetchone()
|
||||||
|
if result:
|
||||||
|
colnames = [desc[0] for desc in cursor.description]
|
||||||
|
return dict(zip(colnames, result))
|
||||||
|
return None
|
||||||
|
except (Exception, psycopg2.Error) as error:
|
||||||
|
print("Error while getting tenant_id by appid", error)
|
||||||
|
|
||||||
|
|
||||||
|
def get_messages_info(self, appid:str, query:str)->dict | None:
|
||||||
|
"""
|
||||||
|
根据应用 ID 和查询内容从 'messages' 表中获取消息信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
appid: 目标应用的 ID。
|
||||||
|
query: 用户查询的具体内容。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个字典,其中键是列名,值是对应的消息数据。
|
||||||
|
如果未找到消息或发生错误,则返回 None。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.connection.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC
|
||||||
|
""",
|
||||||
|
(appid, query)
|
||||||
|
)
|
||||||
|
result = cursor.fetchone()
|
||||||
|
if result:
|
||||||
|
colnames = [desc[0] for desc in cursor.description]
|
||||||
|
return dict(zip(colnames, result))
|
||||||
|
return None
|
||||||
|
except (Exception, psycopg2.Error) as error:
|
||||||
|
print("Error while getting messages_info", error)
|
||||||
|
|
||||||
|
def get_messages_info_by_id(self, message_id:str)->dict | None:
|
||||||
|
"""
|
||||||
|
根据消息 ID 从 'messages' 表中获取消息信息。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.connection.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM messages WHERE id = %s
|
||||||
|
""",
|
||||||
|
(message_id, )
|
||||||
|
)
|
||||||
|
result = cursor.fetchone()
|
||||||
|
if result:
|
||||||
|
colnames = [desc[0] for desc in cursor.description]
|
||||||
|
return dict(zip(colnames, result))
|
||||||
|
return None
|
||||||
|
except (Exception, psycopg2.Error) as error:
|
||||||
|
print("Error while getting messages_info", error)
|
||||||
|
|
||||||
|
def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None:
|
||||||
|
"""
|
||||||
|
根据工作流运行 ID 从 'workflow_node_executions' 表中获取节点执行信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: 目标工作流运行的 ID。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个字典,其中键是列名,值是对应的节点执行数据。
|
||||||
|
如果未找到执行信息或发生错误,则返回 None。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.connection.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s
|
||||||
|
""",
|
||||||
|
(workflow_run_id,)
|
||||||
|
)
|
||||||
|
result = cursor.fetchall()
|
||||||
|
if result:
|
||||||
|
colnames = [desc[0] for desc in cursor.description]
|
||||||
|
return [dict(zip(colnames, row)) for row in result]
|
||||||
|
return None
|
||||||
|
except (Exception, psycopg2.Error) as error:
|
||||||
|
print("Error while getting workflow_node_executions_info", error)
|
||||||
|
|
||||||
|
class DifyTool:
    """
    Helpers that collect debug information about Dify messages.

    Each call opens its own PgSql connection, reads the message (and
    optionally the app) plus the workflow node executions of the message's
    workflow run, and closes the connection again before returning.
    """

    @staticmethod
    def _release(pgsql):
        """Close the connection held by ``pgsql`` without printing anything."""
        connection = pgsql.connection
        if connection:
            connection.close()
            pgsql.connection = None

    @staticmethod
    def get_message_debug_info_id(message_id: str) -> dict | None:
        """
        Collect debug information for the message with the given ID.

        Returns:
            A dict with the keys "messages_info" and
            "workflow_node_executions_info", or None when the message or its
            node executions cannot be found.
        """
        dify_pgsql = PgSql()
        try:
            messages_info = dify_pgsql.get_messages_info_by_id(message_id)
            if not messages_info:
                return None
            workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(
                messages_info['workflow_run_id']
            )
            if not workflow_node_executions_info:
                return None
            return {
                "messages_info": messages_info,
                "workflow_node_executions_info": workflow_node_executions_info
            }
        finally:
            # Previously the connection was never closed, leaking one
            # connection per call.
            DifyTool._release(dify_pgsql)

    @staticmethod
    def get_message_debug_info(appid: str, query: str) -> dict | None:
        """
        Collect debug information for an app and an exact query text.

        Args:
            appid: ID of the app.
            query: Query text of the message to look up.

        Returns:
            A dict with the keys "appinfo", "messages_info" and
            "workflow_node_executions_info", or None when any of the three
            lookups finds nothing.
        """
        dify_pgsql = PgSql()
        try:
            appinfo = dify_pgsql.get_appinfo(appid)
            if not appinfo:
                return None
            messages_info = dify_pgsql.get_messages_info(appid, query)
            if not messages_info:
                return None
            workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(
                messages_info['workflow_run_id']
            )
            if not workflow_node_executions_info:
                return None
            return {
                "appinfo": appinfo,
                "messages_info": messages_info,
                "workflow_node_executions_info": workflow_node_executions_info
            }
        finally:
            DifyTool._release(dify_pgsql)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: print the debug info for one known app/query pair.
    demo_app_id = "ccf92b97-2789-4a3f-90e0-135a869a37c5"
    demo_query = "电力建设计价通软件,导入结算后没有暂列金怎么办?要手动添加么?"
    print(DifyTool.get_message_debug_info(demo_app_id, demo_query))
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
from flask import Flask, request, Response
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from rag2_0.intent_recognition import IntentRecognizer
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
# Load environment variables (python-dotenv).
load_dotenv()

app = Flask(__name__)

# Build the shared intent recognizer from the environment configuration.
api_key, base_url = os.getenv("OPENAI_API_KEY"), os.getenv("OPENAI_API_BASE")
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
recognizer = IntentRecognizer(
    model_name=model_name,
    base_url=base_url,
    api_key=api_key,
)
|
||||||
|
|
||||||
|
@app.route('/intent_recognize', methods=['POST'])
def intent_recognize():
    """
    Run intent recognition on the ``query`` field of a JSON request body.

    Responses (all JSON, UTF-8, non-ASCII text left unescaped):
      * 200 - source query, extracted query keys, both classification
        levels, the rewritten query and the matched keywords.
      * 400 - the body is not a JSON object or has no non-empty ``query``.
      * 500 - any exception raised while processing; its text is returned.
    """
    def _json_response(payload, status=200):
        # Single place that serializes a payload into a JSON Response.
        return Response(
            json.dumps(payload, ensure_ascii=False),
            content_type='application/json; charset=utf-8',
            status=status,
        )

    try:
        # silent=True makes malformed JSON yield None, so it is answered with
        # the 400 below instead of falling into the generic 500 handler.
        data = request.get_json(force=True, silent=True)
        query = data.get('query') if isinstance(data, dict) else None
        if not query:
            return _json_response({"error": "缺少query参数"}, 400)

        start_time = time.perf_counter()
        classification, keywords, rewrite, query_keys = recognizer.process_query(query)
        print(f"意图识别耗时: {time.perf_counter() - start_time:.2f}秒")

        # "keywords" is a list of term dicts when terms were matched and an
        # empty string otherwise (kept as-is for existing clients).
        keywords_payload = ""
        if keywords and keywords.terms:
            keywords_payload = [
                {
                    "名称": term.name,
                    "同义词": ";".join(term.synonymous) if term.synonymous else "",
                    "描述": term.description
                }
                for term in keywords.terms
            ]

        return _json_response({
            "source_query": query,
            "source_query_keys": query_keys,
            "vertical_classification": classification.vertical_classification,
            "sub_classification": classification.sub_classification,
            "rewrite_query": rewrite.rewrite,
            "keywords": keywords_payload
        })
    except Exception as e:
        return _json_response({"error": str(e)}, 500)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Start the Flask server bound to 0.0.0.0 on port 8001.
    app.run(port=8001, host="0.0.0.0")
|
||||||
@@ -0,0 +1,136 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
from rag2_0.dify.dify_client import ChatClient, DifyClient
|
||||||
|
import pandas as pd
|
||||||
|
# 使用线程池并发执行
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from tqdm import tqdm
|
||||||
|
from rag2_0.dify.dify_tool import DifyTool
|
||||||
|
import json
|
||||||
|
|
||||||
|
class DifyComparisonTester:
    """
    Compare the answers of an old and a new Dify flow for a list of questions.

    Questions are read from the first column of an Excel file; each one is
    sent to both flows and the results are written to a new Excel file.
    """

    # Title of the workflow node whose outputs hold the rewritten query.
    REWRITE_NODE_TITLE = "问题优化结果解析"

    def __init__(self, excel_path: str, baseurl: str, old_workflow_api_key: str, new_workflow_api_key: str):
        """
        Args:
            excel_path: Path of the Excel file that contains the questions.
            baseurl: Base URL of the Dify API.
            old_workflow_api_key: API key of the old flow.
            new_workflow_api_key: API key of the new flow.
        """
        self.excel_path = excel_path
        self.baseurl = baseurl
        self.old_workflow_api_key = old_workflow_api_key
        self.new_workflow_api_key = new_workflow_api_key
        self.old_chat = ChatClient(api_key=old_workflow_api_key, base_url=baseurl)
        self.new_chat = ChatClient(api_key=new_workflow_api_key, base_url=baseurl)

    @staticmethod
    def _ask(chat, question):
        """Send one question; return the decoded JSON, or an "error: ..." string."""
        try:
            return chat.create_chat_message(inputs={}, query=question, user="AutoTestDifyChat").json()
        except Exception as e:
            return f"error: {str(e)}"

    @staticmethod
    def _answer_text(result):
        """Return the answer of a chat result, or a readable error text."""
        if isinstance(result, dict):
            if "answer" in result:
                return result["answer"]
            return f"error: {result}"
        return result

    @classmethod
    def _lookup_rewrite(cls, message_id):
        """Return the rewritten query recorded for a message, or "" if none."""
        info = DifyTool.get_message_debug_info_id(message_id=message_id)
        rewrite_query = ""
        if not info:
            return rewrite_query
        for workflow_node in info["workflow_node_executions_info"]:
            if workflow_node["title"] == cls.REWRITE_NODE_TITLE:
                outputs = workflow_node["outputs"]
                # NOTE(review): the column may arrive as JSON text or already
                # decoded, depending on its type - both are accepted here.
                if isinstance(outputs, (str, bytes, bytearray)):
                    outputs = json.loads(outputs)
                rewrite_query = outputs["optimize_query"]
        return rewrite_query

    def process_question(self, q: str):
        """
        Ask both flows the same question in parallel.

        A failing chat call or debug lookup does not raise; the failure is
        reported in the corresponding field so a long run is not aborted by
        one bad question.

        Args:
            q: The question text.

        Returns:
            dict: the question, the new flow's rewritten query ("" when it is
            unavailable) and the answers of the old and the new flow.
        """
        with ThreadPoolExecutor(max_workers=2) as executor:
            future_old = executor.submit(self._ask, self.old_chat, q)
            future_new = executor.submit(self._ask, self.new_chat, q)
            old_result = future_old.result()
            new_result = future_new.result()

        rewrite_query = ""
        if isinstance(new_result, dict) and new_result.get("message_id"):
            try:
                rewrite_query = self._lookup_rewrite(new_result["message_id"])
            except Exception as e:
                print(f"获取问题改写失败: {e}")

        return {
            "问题": q,
            "问题改写": rewrite_query,
            "旧流程答案": self._answer_text(old_result),
            "新流程答案": self._answer_text(new_result),
        }

    def run_comparison(self):
        """
        Process every question and write the comparison workbook.

        Returns:
            str: Path of the generated Excel file (written next to the input
            file).
        """
        df = pd.read_excel(self.excel_path)
        questions = df.iloc[:, 0].tolist()

        # Questions are processed one after another; tqdm shows progress.
        results = [self.process_question(q) for q in tqdm(questions, desc="处理问题进度")]

        out_path = os.path.join(os.path.dirname(self.excel_path), "dify问答_对比结果.xlsx")
        df_results = pd.DataFrame(results)

        with pd.ExcelWriter(out_path, engine='xlsxwriter') as writer:
            df_results.to_excel(writer, index=False, sheet_name='Sheet1')
            worksheet = writer.sheets['Sheet1']
            # Four columns: question, rewritten query, old answer, new answer.
            worksheet.set_column('A:B', 50)
            worksheet.set_column('C:D', 70)

        return out_path
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Question workbook, resolved relative to this script.
    excel_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", ".." ,"data/excel/历史提问数据(dislike)_1000条_软件明确.xlsx")

    if not os.path.exists(excel_path):
        print(f"错误:Excel文件不存在: {excel_path}")
        # SystemExit instead of exit(): exit is injected by the site module
        # and is not guaranteed to exist.
        raise SystemExit(1)

    # Dify API configuration; each value can be overridden via environment.
    # NOTE(review): the fallback API keys are committed in source - move them
    # to the environment and rotate them.
    baseurl = os.getenv("DIFY_BASE_URL", "http://172.20.0.145/v1")
    old_workflow_api_key = os.getenv("DIFY_OLD_WORKFLOW_API_KEY", "app-wUdkWJx5zeOvmvBUZizMoSw3")
    new_workflow_api_key = os.getenv("DIFY_NEW_WORKFLOW_API_KEY", "app-Lf1pQ1NVwdMfCRVNTBCOTPHT")

    # Run the comparison and report where the result was written.
    tester = DifyComparisonTester(excel_path, baseurl, old_workflow_api_key, new_workflow_api_key)
    output_file = tester.run_comparison()
    print(f"对比结果已保存至: {output_file}")
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
File: DataModels.py
|
||||||
|
Author: oyyz
|
||||||
|
Date: 2025-05-13
|
||||||
|
Description: 提取和分类的数据模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
# 定义输出模型
|
||||||
|
class Term(BaseModel):
    """A professional term with its synonyms and a description.

    Equality and hashing use ``name`` only.
    """

    name: str = Field(description="专业名词")
    synonymous: List[str] = Field(description="同义词列表")
    description: str = Field(default="", description="描述信息")

    def __eq__(self, other):
        return isinstance(other, Term) and self.name == other.name

    def __hash__(self):
        return hash(self.name)
|
||||||
|
|
||||||
|
class TermList(BaseModel):
    """Container model holding a list of ``Term`` objects."""

    terms: List[Term] = Field(description="专业名词列表")
|
||||||
|
|
||||||
|
class Classification(BaseModel):
    """Two-level classification result: a vertical category and its sub-category."""

    vertical_classification: str = Field(description="垂直领域一级分类")
    sub_classification: str = Field(description="一级分类下的二级分类")
|
||||||
|
|
||||||
|
class QueryRewrite(BaseModel):
    """Model holding the rewritten form of a query."""

    rewrite: str = Field(description="问题改写")
|
||||||
@@ -0,0 +1,289 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
File: IntentRecognition.py
|
||||||
|
Author: oyyz
|
||||||
|
Date: 2025-05-13
|
||||||
|
Description: 意图分类、改写核心逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain.output_parsers import PydanticOutputParser
|
||||||
|
import json
|
||||||
|
from typing import List, Tuple
|
||||||
|
import re
|
||||||
|
from .PromptTemplates import classification_prompt, query_rewrite_prompt, extract_nouns_prompt, classification_info
|
||||||
|
from .DataModels import Classification, QueryRewrite, Term, TermList
|
||||||
|
from .ProfessionalNounVector import ProfessionalNounRetriever
|
||||||
|
from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM
|
||||||
|
|
||||||
|
|
||||||
|
class IntentRecognizer:
|
||||||
|
"""
|
||||||
|
意图识别和问题改写类
|
||||||
|
"""
|
||||||
|
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo", vector_index_dir: str = None):
    """
    Set up the LLM, the output parsers, the suffix keywords and the retriever.

    Args:
        api_key: API key; left out of the LLM arguments when falsy. Also
            passed to the noun retriever.
        base_url: API base URL; left out of the LLM arguments when falsy.
        model_name: Name of the model to use.
        vector_index_dir: Index directory handed to the noun retriever.
    """
    # Temperature 0.2 keeps the output close to deterministic.
    llm_params = {"temperature": 0.2, "model": model_name}
    # Optional settings are forwarded only when they were provided.
    for option, value in (("api_key", api_key), ("base_url", base_url)):
        if value:
            llm_params[option] = value
    self.llm = OpenAiLLM(**llm_params)

    # One parser per structured output: classification, rewrite, term list.
    self.classification_parser = PydanticOutputParser(pydantic_object=Classification)
    self.query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
    self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)

    # Suffix keywords are loaded from the default JSON file.
    self.suffix_keywords = self._load_suffix_keywords()

    # Retriever for professional nouns.
    self.noun_retriever = ProfessionalNounRetriever(index_dir=vector_index_dir, api_key=api_key)
|
||||||
|
|
||||||
|
def _load_suffix_keywords(self, filepath: str = None) -> List[str]:
|
||||||
|
"""
|
||||||
|
加载后缀关键词列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepath: 后缀关键词文件路径,默认为None使用默认路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
后缀关键词列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 如果未指定路径,使用默认路径
|
||||||
|
if filepath is None:
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
filepath = os.path.join(current_dir, "..", "..", "data", "nouns", "suffix_keywords.json")
|
||||||
|
|
||||||
|
# 读取JSON文件
|
||||||
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
|
suffix_data = json.load(f)
|
||||||
|
|
||||||
|
# 添加额外的固定后缀
|
||||||
|
return suffix_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"加载后缀关键词失败: {e}") from e
|
||||||
|
|
||||||
|
def classify_intent(self, query: str, keywords: TermList) -> Classification:
    """
    Classify the intent of a user query with the LLM.

    The classification prompt template is filled with the query, the
    classification info, the parser's format instructions and the matched
    keywords serialized as JSON.

    Args:
        query: The user's query text.
        keywords: Matched terms to give the LLM as context.

    Returns:
        The parsed classification result.

    Raises:
        RuntimeError: if the LLM response cannot be parsed.
    """
    keywords_str = json.dumps(
        [term.model_dump() for term in keywords.terms], ensure_ascii=False
    )
    substitutions = {
        "{user_input}": query,
        "{classification_info}": classification_info,
        "{output_format}": self.classification_parser.get_format_instructions(),
        "{keywords}": keywords_str,
    }
    # Fill all placeholders in one pass. Chained str.replace calls rescanned
    # already inserted text, so a query containing e.g. "{keywords}" had that
    # token expanded too.
    placeholder = re.compile("|".join(re.escape(token) for token in substitutions))
    formatted_prompt = placeholder.sub(
        lambda match: substitutions[match.group(0)], classification_prompt
    )

    # The second positional argument (False) is passed through as before;
    # its meaning is defined by the LLM wrapper.
    response = self.llm.invoke(formatted_prompt, False)

    try:
        return self.classification_parser.parse(response.content.strip())
    except Exception as e:
        raise RuntimeError(f"解析分类结果时出错: {e}") from e
|
||||||
|
|
||||||
|
def extract_keywords_with_llm(self, query: str) -> List[Term]:
    """
    Extract professional keywords from a user query with the LLM.

    Args:
        query: The user's query text.

    Returns:
        The terms parsed from the LLM response.

    Raises:
        RuntimeError: if the LLM response cannot be parsed into a term list.
    """
    substitutions = {
        "{content}": query,
        "{output_format}": self.terms_list_parser.get_format_instructions(),
    }
    # Single-pass substitution: inserted text is never rescanned, so a query
    # that itself contains "{output_format}" is sent to the LLM verbatim.
    placeholder = re.compile("|".join(re.escape(token) for token in substitutions))
    formatted_prompt = placeholder.sub(
        lambda match: substitutions[match.group(0)], extract_nouns_prompt
    )

    response = self.llm.invoke(formatted_prompt, False)

    try:
        return self.terms_list_parser.parse(response.content).terms
    except Exception as e:
        raise RuntimeError(f"无法解析LLM关键词提取响应: {e}") from e
|
||||||
|
|
||||||
|
def match_keywords(self, query: str) -> Tuple[TermList, List[str]]:
    """
    Match glossary terms for a user question (LLM extraction + vector search).

    Args:
        query: The user question.

    Returns:
        A tuple ``(term_list, query_keys)``: the matched terms (reranked
        against the query when the reranker answers) and the raw keyword
        names extracted by the LLM.

    Raises:
        RuntimeError: If keyword extraction or vector retrieval fails.
    """
    # Step 1: let the LLM extract keywords from the query.
    try:
        extracted_terms = self.extract_keywords_with_llm(query)
        query_keys = [term.name for term in extracted_terms]
    except Exception as e:
        raise RuntimeError(f"LLM关键词提取失败: {e}") from e

    # Step 2: look up similar professional nouns for every keyword.
    # A dict is used as an insertion-ordered set: same de-duplication as a
    # set, but the candidate order is stable from run to run.
    candidates = {}
    try:
        for current_key in query_keys:
            vector_results = self.noun_retriever.query(current_key, top_k=3, use_intersection=True)
            for result in vector_results:
                term = Term(
                    name=result.get('name'),
                    # Accept both spellings of the synonym key so the
                    # synonyms are not dropped when the retrieved metadata
                    # stores them under 'synonyms'.
                    synonymous=result.get('synonymous') or result.get('synonyms') or [],
                    description=result.get('description', '')
                )
                candidates.setdefault(term, None)
    except Exception as e:
        raise RuntimeError(f"向量检索关键词时出错: {e}") from e

    matched_terms = list(candidates)

    # Step 3: rerank the candidates against the full query.
    if matched_terms:
        txts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) + "|" + "描述:" + term.description for term in matched_terms]
        xinference_reranker = XinferenceReRankerModel()
        rerank_results = xinference_reranker.rerank(query, txts, top_k=5)
        if rerank_results:
            matched_terms = [matched_terms[result["index"]] for result in rerank_results]
        else:
            # The reranker returned nothing (e.g. the request failed): keep
            # the first candidates in retrieval order instead of discarding
            # every matched term.
            matched_terms = matched_terms[:5]

    term_list = TermList(terms=matched_terms)
    return term_list, query_keys
|
||||||
|
|
||||||
|
def rewrite_query(self, query: str, keywords: TermList) -> QueryRewrite:
    """
    Rewrite the user question with the LLM.

    Args:
        query: The original user question.
        keywords: Matched terms; their ``description`` field is excluded
            before they are serialized into the prompt.

    Returns:
        The rewrite result parsed from the LLM response.

    Raises:
        RuntimeError: If the LLM response cannot be parsed.
    """
    # Serialize the terms (without descriptions) for the prompt.
    serialized_terms = json.dumps(
        [item.model_dump(exclude={"description"}) for item in keywords.terms],
        ensure_ascii=False,
    )

    # Fill the rewrite template.
    prompt_text = query_rewrite_prompt.format(
        query=query,
        output_format=self.query_rewrite_parser.get_format_instructions(),
        keywords=serialized_terms,
    )

    # Ask the LLM (second positional argument is passed as False).
    llm_reply = self.llm.invoke(prompt_text, False)

    # Parse the reply; any failure is reported as a RuntimeError.
    try:
        return self.query_rewrite_parser.parse(llm_reply.content)
    except Exception as e:
        raise RuntimeError(f"解析问题改写结果时出错: {e}") from e
|
||||||
|
|
||||||
|
def judge_define_suffix(self, input_str: str) -> Tuple[bool, List[str]]:
    """
    Report which configured suffix names occur in the input string.

    Matching is case-insensitive and an optional leading '.' is allowed.

    Args:
        input_str: The string to scan.

    Returns:
        Tuple[bool, List[str]]: whether any suffix was found, and the matched
        suffix texts exactly as they appear in ``input_str``.
    """
    # Entries without a usable name are skipped.
    names = [field.get('name') for field in self.suffix_keywords]
    names = [name for name in names if name]

    # With nothing to search for, the alternation would be empty and the
    # pattern would match the empty string at every position.
    if not names:
        return False, []

    # Regex alternation takes the first alternative that matches, so put
    # longer names first to keep a short name from shadowing a longer one.
    names.sort(key=len, reverse=True)
    pattern = r'(?:\.?)(' + '|'.join(re.escape(name) for name in names) + r')'

    # re.IGNORECASE makes the search case-insensitive; collect every match.
    matched_suffixes = [match.group(1) for match in re.finditer(pattern, input_str, re.IGNORECASE)]

    return bool(matched_suffixes), matched_suffixes
|
||||||
|
|
||||||
|
def process_query(self, query: str) -> Tuple[Classification, TermList, QueryRewrite, List[str]]:
    """
    Run the full pipeline for one user question.

    Args:
        query: The original user question.

    Returns:
        A tuple ``(classification, keywords, rewrite, query_keys)``. When
        either classification level is "其他" or "闲聊", the keyword list is
        empty, the rewrite is the unchanged query and ``query_keys`` is empty.
    """
    # NOTE: a shortcut for file-suffix questions (based on
    # judge_define_suffix) used to sit here as commented-out code; it is
    # currently disabled.

    # Step 1: match keywords.
    keywords_terms, query_keys = self.match_keywords(query)

    # Step 2: classify the intent.
    classification = self.classify_intent(query, keywords_terms)

    # Off-topic questions return the query untouched. This check runs before
    # the rewrite so no LLM call is spent on a rewrite that would be thrown
    # away.
    off_topic = ("其他", "闲聊")
    if (classification.vertical_classification in off_topic
            or classification.sub_classification in off_topic):
        return classification, TermList(terms=[]), QueryRewrite(rewrite=query), []

    # Step 3: rewrite the question using the matched keywords.
    rewrite = self.rewrite_query(
        query=query,
        keywords=keywords_terms
    )

    return classification, keywords_terms, rewrite, query_keys
|
||||||
@@ -0,0 +1,321 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
File: ProfessionalNounVector.py
|
||||||
|
Date: 2025-05-15
|
||||||
|
Author: oyyz
|
||||||
|
Description: 专业名词向量化和检索的核心逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
from typing import List, Dict, Any, Tuple, Optional
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain_community.vectorstores import FAISS
|
||||||
|
from rag2_0.tool.ModelTool import SiliconFlowEmbeddings
|
||||||
|
import logging
|
||||||
|
|
||||||
|
def get_embedding_model(api_key: str = None) -> Embeddings:
    """
    Build the embedding model.

    Args:
        api_key: API key; when empty or None it is read from the
            SILICONFLOW_API_KEY environment variable.

    Returns:
        A SiliconFlowEmbeddings instance.
    """
    if not api_key:
        # No credential is embedded in source any more: the key must come
        # from the caller or from the environment.
        api_key = os.getenv("SILICONFLOW_API_KEY", "")
        if not api_key:
            logging.warning("未设置 SILICONFLOW_API_KEY,嵌入请求将不携带有效密钥")
    return SiliconFlowEmbeddings(api_key=api_key)
|
||||||
|
|
||||||
|
|
||||||
|
class ProfessionalNounVectorizer:
    """Vectorize professional nouns into a FAISS index and save it to disk."""

    def __init__(self,
                 embedding_model: Optional[Embeddings] = None,
                 api_key: str = None,
                 output_dir: str = None):
        """
        Initialize the vectorizer.

        Args:
            embedding_model: Embedding model; when falsy the default model
                from get_embedding_model() is used.
            api_key: API key, only used when embedding_model is not given.
            output_dir: Directory that receives the index; None selects
                ``../../data/nouns`` relative to this file.
        """
        # Embedding model.
        if embedding_model:
            self.embedding_model = embedding_model
        else:
            self.embedding_model = get_embedding_model(api_key)

        # Output directory.
        self.output_dir = output_dir
        if self.output_dir is None:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            self.output_dir = os.path.join(current_dir, "..", "..", "data", "nouns")

        # Index location inside the output directory.
        self.index_path = os.path.join(self.output_dir, "professional_nouns_index")

    def _loadfile(self, file_paths: List[str]) -> List[Dict[str, Any]]:
        """
        Load several term JSON files and concatenate their entries.

        Missing files and files whose top level is not a JSON array are
        skipped with a warning. Any exception aborts the whole load.

        Args:
            file_paths: Paths of the JSON files.

        Returns:
            The merged term list, or [] if any exception occurred.
        """
        merged_terms = []

        try:
            for file_path in file_paths:
                if not os.path.exists(file_path):
                    logging.warning(f"文件不存在: {file_path}")
                    continue

                with open(file_path, "r", encoding="utf-8") as f:
                    terms_data = json.load(f)

                if isinstance(terms_data, list):
                    merged_terms.extend(terms_data)
                    logging.info(f"从 {file_path} 加载了 {len(terms_data)} 条专业名词")
                else:
                    logging.warning(f"文件格式错误: {file_path},应为JSON数组")

            logging.info(f"总共加载了 {len(merged_terms)} 条专业名词")
            return merged_terms
        except Exception as e:
            logging.error(f"加载多个文件失败: {e}")
            return []

    def vectorize_files_and_save(self, file_paths: List[str]) -> bool:
        """
        Load term files, build the FAISS index and save it.

        Args:
            file_paths: Paths of the JSON files.

        Returns:
            True on success, False when no terms were found or any step
            raised.
        """
        try:
            # Load the terms of all files.
            terms = self._loadfile(file_paths)

            if not terms:
                logging.warning("未找到术语数据,退出处理")
                return False

            # De-duplicate by name; the first occurrence wins and entries
            # without a name are dropped.
            unique_terms = {}
            for term in terms:
                name = term.get("name", "")
                if name and name not in unique_terms:
                    unique_terms[name] = term

            deduplicated_terms = list(unique_terms.values())
            logging.info(f"去重后剩余 {len(deduplicated_terms)} 条专业名词")

            # Prepare texts and metadata, build the index, then persist it.
            texts, metadatas = self._prepare_terms_for_faiss(deduplicated_terms)
            faiss_index = self._create_index(texts, metadatas)
            self._save_index(faiss_index)

            logging.info("完成多文件专业名词向量化和索引创建")
            return True
        except Exception as e:
            logging.error(f"多文件向量化处理失败: {e}")
            return False

    def _prepare_terms_for_faiss(self, terms: List[Dict[str, Any]]) -> Tuple[List[str], List[Dict]]:
        """
        Turn terms into parallel text / metadata lists for FAISS.

        Each term contributes up to three texts (its name, its synonyms
        joined with ', ', and its description). Every text carries the same
        metadata record, so a hit on any of them leads back to the term.

        Args:
            terms: Term dicts with "name" and optionally "synonymous" and
                "description".

        Returns:
            The texts and their metadata dicts, index-aligned.
        """
        texts = []
        metadatas = []

        for term in terms:
            name = term["name"]
            synonyms = term.get("synonymous", [])
            description = term.get("description", "")

            candidate_texts = [name.strip()]
            if len(synonyms) > 0:
                candidate_texts.append(', '.join(synonyms).strip())
            if len(description) > 0:
                candidate_texts.append(description.strip())

            for text in candidate_texts:
                # A text that is blank after stripping carries no meaning
                # and is not sent to the embedding model.
                if not text:
                    continue
                texts.append(text)
                # NOTE: the input key is "synonymous" but the metadata key is
                # "synonyms"; it is kept unchanged so already-saved indexes
                # and freshly built ones share one schema.
                metadatas.append({
                    "name": name,
                    "synonyms": synonyms,
                    "description": description
                })

        return texts, metadatas

    def _create_index(self, texts: List[str], metadatas: List[Dict]) -> FAISS:
        """
        Build a FAISS index from texts and metadata.

        Args:
            texts: Texts to embed.
            metadatas: Metadata dicts aligned with ``texts``.

        Returns:
            The new FAISS index.
        """
        logging.info(f"正在创建FAISS索引,共 {len(texts)} 条数据...")
        return FAISS.from_texts(texts=texts, embedding=self.embedding_model, metadatas=metadatas)

    def _save_index(self, faiss_index: FAISS) -> None:
        """
        Save the FAISS index to ``self.index_path``.

        The index is first written to a staging directory next to the target
        and only then swapped in, so a failed save leaves the existing index
        untouched.

        Args:
            faiss_index: The index to save.

        Raises:
            Exception: Whatever the save raised, after logging it.
        """
        staging_path = self.index_path + ".tmp"
        try:
            # Make sure the output directory exists.
            os.makedirs(self.output_dir, exist_ok=True)

            # Clear leftovers of an earlier, interrupted save.
            if os.path.exists(staging_path):
                shutil.rmtree(staging_path)

            # Write the new index beside the old one.
            faiss_index.save_local(folder_path=staging_path)

            # Only now remove the previous index and move the new one in.
            if os.path.exists(self.index_path):
                shutil.rmtree(self.index_path)
            os.replace(staging_path, self.index_path)

            logging.info(f"FAISS索引已保存至 {self.index_path}")
        except Exception as e:
            logging.error(f"保存FAISS索引失败: {e}")
            shutil.rmtree(staging_path, ignore_errors=True)
            raise
|
||||||
|
|
||||||
|
|
||||||
|
class ProfessionalNounRetriever:
    """Retrieve professional nouns from a saved FAISS index."""

    def __init__(self,
                 embedding_model: Optional[Embeddings] = None,
                 api_key: str = None,
                 index_dir: str = None):
        """
        Initialize the retriever and load the index.

        Args:
            embedding_model: Embedding model; when falsy the default model
                from get_embedding_model() is used.
            api_key: API key, only used when embedding_model is not given.
            index_dir: Index directory; None selects
                ``../../data/nouns/professional_nouns_index`` relative to
                this file.
        """
        # Embedding model.
        if embedding_model:
            self.embedding_model = embedding_model
        else:
            self.embedding_model = get_embedding_model(api_key)

        # Index location.
        self.index_dir = index_dir
        if self.index_dir is None:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            self.index_dir = os.path.join(current_dir, "..", "..", "data", "nouns", "professional_nouns_index")

        # Load the index right away; it stays None when loading fails.
        self.faiss_index = None
        self._load_index()

    def _load_index(self) -> None:
        """
        Load the FAISS index from ``self.index_dir``.

        A failure is logged as a warning and leaves ``self.faiss_index`` as
        None; nothing is raised.
        """
        try:
            # allow_dangerous_deserialization=True permits unpickling the
            # stored index; only point index_dir at trusted data.
            self.faiss_index = FAISS.load_local(
                folder_path=self.index_dir,
                embeddings=self.embedding_model,
                allow_dangerous_deserialization=True
            )
            logging.info(f"成功从 {self.index_dir} 加载FAISS索引")
        except Exception as e:
            logging.warning(f"加载FAISS索引失败: {e}")
            self.faiss_index = None

    def query(self, query_text: str, top_k: int = 5, use_intersection: bool = True) -> List[Dict]:
        """
        Query the index for the professional nouns closest to a text.

        Three searches are run (plain similarity, MMR, and similarity with a
        score threshold of 0.5) and their metadata records are merged into a
        de-duplicated UNION, despite the parameter name below. Records keep
        the order in which they were first returned, so repeated calls give
        the same ordering. The merged list may hold more than ``top_k``
        entries.

        Args:
            query_text: Text to search for.
            top_k: ``k`` passed to the plain and the MMR search.
            use_intersection: Accepted for compatibility; currently ignored.

        Returns:
            The merged metadata dicts, or [] when the index is not loaded or
            the search fails.
        """
        try:
            # Nothing to search without a loaded index.
            if self.faiss_index is None:
                logging.warning("FAISS索引未加载,无法执行查询")
                return []

            retrievers = (
                self.faiss_index.as_retriever(search_kwargs={"k": top_k}),
                self.faiss_index.as_retriever(
                    search_type="mmr",
                    search_kwargs={"k": top_k, "fetch_k": 3, "lambda_mult": 0.5}
                ),
                self.faiss_index.as_retriever(
                    search_type="similarity_score_threshold",
                    search_kwargs={"score_threshold": 0.5}
                ),
            )

            # Metadata dicts are not hashable, so each one is keyed by its
            # canonical JSON text. A dict serves as an insertion-ordered set.
            merged = {}
            for retriever in retrievers:
                for document in retriever.invoke(query_text):
                    key = json.dumps(document.metadata, sort_keys=True, ensure_ascii=False)
                    merged.setdefault(key, None)

            if not merged:
                # An empty union means the plain search found nothing either.
                logging.warning("三种检索方式无交集,使用普通检索结果")
                return []

            # Back to dicts.
            return [json.loads(item) for item in merged]

        except Exception as e:
            logging.error(f"查询FAISS索引失败: {e}")
            return []
|
||||||
@@ -0,0 +1,130 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
File: PromptTemplates.py
|
||||||
|
Author: oyyz
|
||||||
|
Date: 2025-05-13
|
||||||
|
Description: 提示词模板
|
||||||
|
"""
|
||||||
|
|
||||||
|
extract_nouns_prompt="""
|
||||||
|
【智能关键词提取助手】
|
||||||
|
请根据用户问题自动识别核心关键词,并按照以下规则输出:
|
||||||
|
1. 只输出最终关键词列表,不要解释说明
|
||||||
|
2. 关键词提取范围包括但不限于以下内容:
|
||||||
|
- 软件相关:功能模块/操作步骤/报错提示/扩展名后缀名
|
||||||
|
- 造价专业:费用类型/计算标准/行业规范
|
||||||
|
- 电力工程:项目类型/设备型号/工程阶段
|
||||||
|
3. 自动展开缩写(如将'导excel'转为'Excel导入')
|
||||||
|
4. 严格基于用户问题提取关键词,不要输出与用户问题无关的关键词
|
||||||
|
|
||||||
|
三、输出格式:
|
||||||
|
{output_format}
|
||||||
|
|
||||||
|
四、用户问题:
|
||||||
|
{content}
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Category catalogue inserted into the classification prompt: the top-level
# (vertical) categories followed by the sub-categories of each.
classification_info="""【垂直领域分类】:
1. 软件问题 -- 指涉及软件使用、功能询问、软件故障排查等方面的提问或请求。
2. 业务问题 -- 指涉及电力造价领域专业知识、造价费用计算等电力造价业务知识
3. 安装下载注册 -- 指涉及软件(或插件)安装下载、注册、激活等操作类问题。
4. 其他 -- 指与软件或电力造价专业无关的日常对话、问候、感慨、情绪表达等。

【软件问题包括以下两类】:
1. 软件功能:询问软件功能的使用、操作、位置等
2. 故障排查:软件运行异常、软件报错、软件显示错误等

【业务问题包括以下两类】:
1. 专业咨询:涉及电力造价规范、工程计价规则问题、行业标准解读等
2. 数据问题:涉及电力造价费用、造价指标等

【安装下载注册包括以下四类】:
1. 后缀名查询:询问有关软件后缀名、工程文件扩展名等问题,例如:BDY3是什么文件?、用什么软件打开.BDY3文件?
2. 软件锁类:询问软件锁信息、锁注册号查询、许可证查询、锁激活问题等软件锁相关问题
3. 安装下载类:安装下载咨询、组件(插件)选择、环境配置等
4. 问题排查类:软件安装下载失败、报错,系统兼容性问题等

【其他】:
1. 其他"""
|
||||||
|
|
||||||
|
# Prompt template for intent classification.
# Placeholders: {classification_info}, {user_input}, {keywords},
# {output_format}. They are meant to be filled by plain text substitution,
# not str.format: the JSON braces in the example below are not escaped.
# The example answer uses a category name that exists in the catalogue
# ("软件问题"), so the model is not shown an undefined label.
classification_prompt="""
用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容,将其归类为以下垂直领域之一:
{classification_info}

【用户输入】:
{user_input}

【相关关键词】:
{keywords}

【输出格式要求】:
{output_format}

【示例】
用户输入1: 技改T1怎样新建工程
输出1:
{
"vertical_classification":"软件问题",
"sub_classification":"软件功能"
}
"""
|
||||||
|
|
||||||
|
query_rewrite_prompt = """
|
||||||
|
|
||||||
|
你是一名电力造价专业问答优化工程师,负责通过多维度信息整合重构用户问题以提升知识库检索准确率。请严格遵循以下流程处理:
|
||||||
|
|
||||||
|
# 任务处理框架
|
||||||
|
## 第一阶段:输入分析
|
||||||
|
1. 解析基础信息
|
||||||
|
- 原始问题(需保留核心语义):{query}
|
||||||
|
- 关键词集合:{keywords}
|
||||||
|
|
||||||
|
## 第二阶段:语义匹配验证
|
||||||
|
2. 执行关键词校验
|
||||||
|
- 建立意图关联矩阵,验证关键词与原始问题的语义一致性
|
||||||
|
- 若存在≥1个有效关联词 → 进入重构流程
|
||||||
|
- 若无有效关联 → 直接输出原始问题
|
||||||
|
|
||||||
|
## 第三阶段:专业重构
|
||||||
|
3. 术语规范化处理
|
||||||
|
a. 实施术语映射:将口语表达替换为知识库标准术语
|
||||||
|
b. 执行结构优化:
|
||||||
|
- 采用【术语标记】规范标注关键概念
|
||||||
|
- 构建主谓宾明确的问题句式
|
||||||
|
- 保持原问题时态与语态特征
|
||||||
|
|
||||||
|
# 输出规范
|
||||||
|
{output_format}
|
||||||
|
|
||||||
|
# 示范案例库
|
||||||
|
▶ 案例1(有效匹配)
|
||||||
|
输入:
|
||||||
|
原始问题:怎么把旧版西藏定额工程转到Z1新版
|
||||||
|
关键词:【'老版本定额升级', '批量设置定额', '西藏造价软件Z1'】
|
||||||
|
输出:
|
||||||
|
{{"rewrite":"【西藏造价软件Z1】如何执行【老版本定额升级】操作?"}}
|
||||||
|
|
||||||
|
▶ 案例2(无效匹配)
|
||||||
|
输入:
|
||||||
|
原始问题:程序界面文字显示过小如何处理?
|
||||||
|
关键词:【'定额升级', '工程批量导入'】
|
||||||
|
输出:
|
||||||
|
{{"rewrite":"程序界面文字显示过小如何处理?"}}
|
||||||
|
|
||||||
|
# 质量约束条款
|
||||||
|
1. 语义内容保真原则
|
||||||
|
- 禁止修改原问题核心诉求(如转换主语/变更操作对象)
|
||||||
|
- 保留原始问题的限定条件
|
||||||
|
|
||||||
|
2. 术语使用规范
|
||||||
|
- 仅使用检索返回的关键词进行术语替换
|
||||||
|
- 新增术语必须来自关键词集合
|
||||||
|
|
||||||
|
3. 结构优化标准
|
||||||
|
- 问题长度控制在20字内
|
||||||
|
- 必须包含≥1个【标注术语】
|
||||||
|
- 禁止添加解释性语句
|
||||||
|
|
||||||
|
4. 异常处理机制
|
||||||
|
- 当关键词与问题无明显关联时,触发直通输出规则
|
||||||
|
- 出现术语冲突时优先保留原始表述
|
||||||
|
"""
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from .ProfessionalNounVector import ProfessionalNounVectorizer, ProfessionalNounRetriever
|
||||||
|
from .IntentRecognition import IntentRecognizer
|
||||||
|
from .DataModels import Term, TermList, Classification, QueryRewrite
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,256 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from typing import List, Optional, Dict
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
def load_api_keys_from_env(env_var_name: str = "OPENAI_API_KEY", separator: str = ";") -> List[str]:
    """
    Read API keys from an environment variable.

    Args:
        env_var_name: Name of the variable holding one or more keys.
        separator: Character sequence separating the keys.

    Returns:
        The keys with surrounding whitespace removed; empty entries are
        dropped. [] when the variable is unset or empty.
    """
    raw_value = os.environ.get(env_var_name, "")
    return [key.strip() for key in raw_value.split(separator) if key.strip()]


# Credentials must never be committed to source control. The keys that used
# to be listed here are exposed in the repository history and should be
# revoked; supply replacements through OPENAI_API_KEY, separated by ';'.
API_KEY_LIST = load_api_keys_from_env()
|
||||||
|
|
||||||
|
class APIKeyManager:
    """
    Singleton manager that hands out API keys from API_KEY_LIST.

    Keys can be taken in rotation or at random; a usage counter and the time
    of last use are kept per key.
    """
    # Singleton instance and the lock guarding its creation and key selection.
    _instance = None
    _lock = Lock()

    # Per-key usage counter and last-use timestamp (shared at class level).
    _key_usage: Dict[str, Dict] = {}
    # Index of the key handed out most recently by the rotation.
    _current_index = 0

    @classmethod
    def get_instance(cls, env_var_name: str = "OPENAI_API_KEY", separator: str = ";"):
        """
        Return the singleton instance, creating it on first use.

        Args:
            env_var_name: Environment variable name stored on the instance.
            separator: Key separator stored on the instance.

        Returns:
            The APIKeyManager instance. The arguments only take effect on
            the call that creates it.
        """
        # Check, lock, check again: at most one instance is ever created.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls(env_var_name, separator)
        return cls._instance

    @classmethod
    def get_api_key(cls) -> Optional[str]:
        """
        Return the next API key in rotation.

        Returns:
            An API key, or None when no key is available.
        """
        instance = cls.get_instance()
        return instance._get_next_api_key()

    @classmethod
    def get_random_api_key(cls) -> Optional[str]:
        """
        Return a randomly chosen API key.

        Returns:
            An API key, or None when no key is available.
        """
        instance = cls.get_instance()
        return instance._get_random_api_key()

    @staticmethod
    def _mask_key(key: str) -> str:
        """Return a short, non-reversible label for a key, safe for logs."""
        return f"{key[:5]}..."

    @classmethod
    def get_valid_api_keys(cls) -> List[str]:
        """
        Probe every key in API_KEY_LIST and return the ones that work.

        Each key is used for a one-token chat completion against
        https://api.siliconflow.cn/v1/; a key counts as valid when the
        response status is 200. Rejected keys are logged (masked) and left
        out of the result; API_KEY_LIST itself is not modified.

        Returns:
            The keys that passed the probe.
        """
        import requests
        import logging

        valid_api_keys = []
        url = "https://api.siliconflow.cn/v1/chat/completions"
        headers_template = {
            "Content-Type": "application/json"
        }
        data = {
            "model": "deepseek-ai/DeepSeek-V3",
            "messages": [
                {"role": "user", "content": "ping"}
            ],
            "max_tokens": 1
        }
        for key in API_KEY_LIST:
            headers = headers_template.copy()
            headers["Authorization"] = f"Bearer {key}"
            # Only a masked form of the key goes to the log: log files must
            # not become a second copy of the credentials.
            masked = cls._mask_key(key)
            try:
                resp = requests.post(url, headers=headers, json=data, timeout=8)
                if resp.status_code == 200:
                    valid_api_keys.append(key)
                else:
                    logging.warning(f"API密钥无效(被移除): {masked}, 状态码: {resp.status_code}, 响应: {resp.text}")
            except Exception as e:
                logging.warning(f"API密钥验证异常(被移除): {masked}, 错误: {e}")
        return valid_api_keys

    @classmethod
    def count(cls) -> int:
        """
        Return the number of API keys held by the singleton.

        Returns:
            The number of keys.
        """
        instance = cls.get_instance()
        return len(instance.api_keys)

    def __init__(self, env_var_name: str = "OPENAI_API_KEY", separator: str = ";"):
        """
        Initialize the manager.

        Args:
            env_var_name: Environment variable name (stored only).
            separator: Key separator (stored only).
        """
        self.env_var_name = env_var_name
        self.separator = separator
        self.api_keys = self._load_api_keys()

        # Create a usage record for every key that does not have one yet.
        for key in self.api_keys:
            if key not in self._key_usage:
                self._key_usage[key] = {
                    "count": 0,
                    "last_used": 0
                }

    def _load_api_keys(self) -> List[str]:
        """
        Return the keys this manager works with.

        The keys are taken from the module-level API_KEY_LIST;
        ``env_var_name`` and ``separator`` are not consulted here.

        Returns:
            A copy of API_KEY_LIST, so changes to ``self.api_keys`` do not
            alter the module-level list.
        """
        return list(API_KEY_LIST)

    def _record_use(self, key: str) -> None:
        """Bump the usage counter and the last-use time of ``key``."""
        usage = self._key_usage.setdefault(key, {"count": 0, "last_used": 0})
        usage["count"] += 1
        usage["last_used"] = time.time()

    def _get_next_api_key(self) -> Optional[str]:
        """
        Return the next API key in rotation.

        The index is advanced BEFORE the key is read, so the first call
        returns the key at index 1 (or index 0 when there is only one key).

        Returns:
            An API key, or None when no key is available.
        """
        if not self.api_keys:
            return None

        with self._lock:
            # Advance to the next key and record its use.
            self._current_index = (self._current_index + 1) % len(self.api_keys)
            selected_key = self.api_keys[self._current_index]
            self._record_use(selected_key)
            return selected_key

    def _get_random_api_key(self) -> Optional[str]:
        """
        Return a randomly chosen API key.

        Returns:
            An API key, or None when no key is available.
        """
        if not self.api_keys:
            return None

        with self._lock:
            selected_key = random.choice(self.api_keys)
            self._record_use(selected_key)
            return selected_key

    def get_all_api_keys(self) -> List[str]:
        """
        Return all API keys.

        Returns:
            A copy of the key list.
        """
        return self.api_keys.copy()

    def is_valid(self) -> bool:
        """
        Tell whether at least one API key is available.

        Returns:
            True when there is at least one key, otherwise False.
        """
        return len(self.api_keys) > 0

    def get_usage_stats(self) -> Dict:
        """
        Return the per-key usage statistics.

        Returns:
            A mapping key -> {"count", "last_used"}. Both the mapping and
            the per-key records are copies, so callers cannot change the
            manager's counters through the result.
        """
        return {key: dict(usage) for key, usage in self._key_usage.items()}
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
if __name__ == "__main__":
    # Usage example: probe the keys and print the ones that work.
    usable_keys = APIKeyManager.get_valid_api_keys()
    print(f"有效的API密钥列表:\n" + "\n".join(usable_keys))

    # Total number of keys held by the manager.
    print(f"总共有 {APIKeyManager.count()} 个API密钥")

    # Per-key usage counters of the singleton.
    manager = APIKeyManager.get_instance()
    for api_key, usage in manager.get_usage_stats().items():
        print(f"密钥 {api_key[:5]}... 使用次数: {usage['count']}")
|
||||||
@@ -0,0 +1,143 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
File: ModelTool.py
|
||||||
|
Date: 2025-05-15
|
||||||
|
Author: oyyz
|
||||||
|
Description: 模型工具类
|
||||||
|
"""
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
import httpx
|
||||||
|
import time
|
||||||
|
import logging # 导入 logging 模块
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from typing import List, Any
|
||||||
|
import requests
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from .APIKeyManager import APIKeyManager
|
||||||
|
|
||||||
|
class SiliconFlowEmbeddings(Embeddings):
    """Embeddings implementation that calls an HTTP ``/v1/embeddings`` endpoint."""

    def __init__(self, api_key: str, model: str = "bge-m3",
                 url: str = "http://10.1.16.39:9995/v1/embeddings",
                 timeout: float = 300.0):
        """
        Args:
            api_key: Key sent as the bearer token.
            model: Name of the embedding model to request.
            url: Endpoint to post to; defaults to the address that was
                previously fixed in the code.
            timeout: Seconds to wait for the endpoint before giving up.
        """
        self.api_key = api_key
        self.model = model
        self.url = url
        self.timeout = timeout
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def _embed(self, input: List[str]) -> List[List[float]]:
        """
        Request embeddings for a batch of texts.

        Args:
            input: Texts to embed.

        Returns:
            One vector per response item, in response order. [] for an empty
            batch (no request is sent).

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts and non-2xx responses.
        """
        if not input:
            return []

        payload = {
            "model": self.model,
            "input": input,
            "encoding_format": "float"
        }
        # Without a timeout an unresponsive endpoint blocks the caller
        # forever.
        response = requests.post(self.url, json=payload, headers=self.headers, timeout=self.timeout)
        response.raise_for_status()
        data = response.json()
        return [item["embedding"] for item in data["data"]]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents in a single request."""
        return self._embed(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed one query text."""
        return self._embed([text])[0]
|
||||||
|
|
||||||
|
class XinferenceReRankerModel:
    """Wrapper around the HTTP rerank endpoint."""

    @staticmethod
    def rerank(query: str, documents: List[str], top_k: int = 10) -> List[dict]:
        """
        Rerank documents against a query.

        The model name is read from the RERANKER_MODEL_NAME environment
        variable.

        Args:
            query: The user query text.
            documents: The documents to rerank.
            top_k: Number of results requested (sent as ``top_n``).

        Returns:
            List[dict]: one dict per result with the keys ``document``
            (text), ``score`` (relevance score) and ``index`` (position in
            ``documents``). [] when ``documents`` is empty, the request
            fails, or the response does not have the expected shape.
        """
        # Nothing to rank: skip the network round trip.
        if not documents:
            return []

        url = "http://10.1.16.39:9995/v1/rerank"

        params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": os.getenv("RERANKER_MODEL_NAME")}
        headers = {
            "Authorization": "Bearer <token>",  # placeholder; replace with a real token if the endpoint requires one
            "Content-Type": "application/json"
        }

        try:
            # Without a timeout an unresponsive endpoint blocks the caller
            # forever.
            response = requests.post(url, json=params, headers=headers, timeout=60)
            response.raise_for_status()  # non-2xx responses raise here
            results = response.json()

            # Flatten the response into document / score / index records.
            return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]

        except requests.exceptions.RequestException as e:
            logging.error(f"重排序请求失败: {str(e)}")
            return []
        except (KeyError, TypeError, ValueError) as e:
            # The endpoint answered, but not with the expected structure.
            logging.error(f"重排序响应格式异常: {str(e)}")
            return []
|
||||||
|
|
||||||
|
class OpenAiLLM:
    """Thin client for an OpenAI-compatible chat completion endpoint."""

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: Must contain ``api_key``, ``base_url`` and ``model``;
                every other entry is forwarded unchanged to each
                chat-completion request.

        Raises:
            ValueError: If api_key, base_url or model is missing or None.
        """
        if kwargs.get("api_key") is None or kwargs.get("base_url") is None or kwargs.get("model") is None:
            raise ValueError("api_key, base_url, model 不能为空")

        self._api_key = kwargs.pop("api_key")
        self._url = kwargs.pop("base_url")
        self._model = kwargs.pop("model")
        # Whatever is left are extra request parameters.
        self._kwargs = kwargs

    def _create_completion(self, client, user_prompt):
        """Send one single-message chat request and return the reply message."""
        completion = client.chat.completions.create(
            model=self._model,
            messages=[{'role': 'user', 'content': user_prompt}],
            timeout=httpx.Timeout(300.0),  # 300 s request timeout
            **self._kwargs
        )
        return completion.choices[0].message

    def invoke(self, user_prompt="你是谁?", need_retry=True):
        """
        Send a prompt to the model.

        Args:
            user_prompt: Text sent as the single user message.
            need_retry: When True, a failing request is attempted up to 3
                times in total, waiting 5 s and then 10 s in between. When
                False, the request is made once and any error propagates.

        Returns:
            The message object of the first choice. With ``need_retry=True``
            an empty string is returned when every attempt failed, so
            callers should check the result before reading ``.content``.
        """
        # Prefer a key from the manager; fall back to the key given to the
        # constructor when the manager has none to offer.
        api_key = APIKeyManager.get_api_key() or self._api_key
        client = OpenAI(api_key=api_key, base_url=self._url)

        if not need_retry:
            return self._create_completion(client, user_prompt)

        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                return self._create_completion(client, user_prompt)
            except Exception as e:
                if attempt == max_retries:
                    logging.error(f"LLM 重试{max_retries}次后仍然失败: {e}")
                    return ""
                # Linear back-off before the next attempt.
                time.sleep(5 * attempt)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: rank three short passages against one question
    # and print whatever the reranker returns.
    sample_query = "什么是AI"
    sample_documents = ["AI是人工智能", "AI是机器学习", "AI是深度学习"]
    print(XinferenceReRankerModel().rerank(sample_query, sample_documents))
|
||||||
@@ -0,0 +1,159 @@
|
|||||||
|
import os.path
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class WikijsTool:
|
||||||
|
BASE_URL = "http://10.1.16.39:8090/graphql"
|
||||||
|
HEADERS = {
|
||||||
|
"Authorization": "Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJhcGkiOjcsImdycCI6MSwiaWF"
|
||||||
|
"0IjoxNzIzMDIwNzg4LCJleHAiOjE4MTc2OTM1ODgsImF1ZCI6InVybjp3aWtpLmpzIiwiaX"
|
||||||
|
"NzIjoidXJuOndpa2kuanMifQ.NSfE4tB7tkN8yapAs0CgkR-Yll6wc3gO3QGKMAv-TlGxx6A-9fJRmkwhRDTVMj_yPVG6"
|
||||||
|
"NXVy_AZpJtLapRXFGn0cvscsRJxq3fY1KgEyt8wO99jvd8DpNHpHhAIgrtyDelmHsBD2Wb5Ib3WJFsWC6d8Yhm9dkpx6tZ"
|
||||||
|
"vMAlFIKOg6UodMoMIry3YWiPGLaqJPQ0gcKmcnB2tC7sPXIIZnvfb5912GVM0n-4wvWobQnb_tXQuYZf99wH_leXjC_7BK8"
|
||||||
|
"8JSaAmB980i3rBxfejmaJ8E6D48zRxwwPFa0veVjjzRkVqHPwAjl1CXb2HE29pGtNmSEE1kLQVqOZD_ibOwKQ"
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
    """Do nothing; the methods defined on this class are static."""
    pass
|
||||||
|
|
||||||
|
@staticmethod
def init_url():
    """Override the class-level endpoint settings from ``wikiconfig.json``.

    The file is looked up next to this module. When it contains a ``url``
    key, ``WikijsTool.BASE_URL`` is replaced; when it contains an
    ``Authorization`` key, that header in ``WikijsTool.HEADERS`` is replaced.

    Returns:
        ``True`` when the file exists and was read, ``False`` when it is
        absent.
    """
    module_dir = Path(__file__).resolve().parent
    config_path = os.path.join(module_dir, 'wikiconfig.json')
    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as config_file:
            settings = json.load(config_file)

        if 'url' in settings:
            WikijsTool.BASE_URL = settings['url']
        if 'Authorization' in settings:
            WikijsTool.HEADERS['Authorization'] = settings['Authorization']
        return True
    return False
|
||||||
|
|
||||||
|
@staticmethod
def get_all_documents() -> list[dict]:
    """Fetch the page list from the GraphQL endpoint.

    Returns:
        One dict per page with the fields requested below (``path``,
        ``locale``, ``title``, ``contentType``, ``id``, ``isPublished``).

    Raises:
        ValueError: On a non-200 response, or when the response body
            carries a GraphQL ``errors`` array.
    """
    query = """
        query Pages {
          pages {
            list {
              path
              locale
              title
              contentType
              id
              isPublished
            }
          }
        }
    """
    # A timeout keeps an unresponsive server from hanging the caller forever.
    response = requests.post(
        WikijsTool.BASE_URL,
        headers=WikijsTool.HEADERS,
        json={'query': query},
        timeout=30,
    )
    if response.status_code != 200:
        raise ValueError(f"获取文档列表失败,原因:“{response.text}”")

    payload = json.loads(response.content)
    # GraphQL reports query failures in an ``errors`` array, in which case
    # ``data`` may be null and must not be indexed.
    if payload.get('errors'):
        raise ValueError(f"获取文档列表失败,原因:“{payload['errors']}”")
    return list(payload['data']['pages']['list'])
|
||||||
|
|
||||||
|
@staticmethod
def get_all_doc_by_path(path: str, path_is_dir: bool = True) -> list[dict]:
    """Return the listed documents whose path starts with ``path``.

    Args:
        path: Path prefix to filter on.
        path_is_dir: When true, ``'/'`` is appended to ``path`` before
            matching, so only entries below that directory are kept.

    Returns:
        The matching entries from ``WikijsTool.get_all_documents()``, in
        their original order.
    """
    every_document = WikijsTool.get_all_documents()
    prefix = path + '/' if path_is_dir else path
    return [
        entry
        for entry in every_document
        if str(entry["path"]).startswith(prefix)
    ]
|
||||||
|
|
||||||
|
@staticmethod
def search_document(query_str: str) -> list[dict]:
    """Search pages through the GraphQL ``pages.search`` field.

    The search text travels as a GraphQL variable instead of being pasted
    into the query document, so quotes or braces in ``query_str`` cannot
    break or alter the query.

    Args:
        query_str: Text to search for.

    Returns:
        The ``results`` list of the search response; each entry carries
        ``id``, ``path``, ``locale`` and ``title``.

    Raises:
        ValueError: On a non-200 response, or when the response body
            carries a GraphQL ``errors`` array.
    """
    graphql_query = """
        query Pages($query: String!) {
          pages {
            search(query: $query) {
              results {
                id
                path
                locale
                title
              }
            }
          }
        }
    """
    data = {
        'query': graphql_query,
        'variables': {'query': query_str},
    }

    # A timeout keeps an unresponsive server from hanging the caller forever.
    response = requests.post(
        WikijsTool.BASE_URL,
        headers=WikijsTool.HEADERS,
        json=data,
        timeout=30,
    )
    if response.status_code != 200:
        raise ValueError(f"查询文档失败,原因:“{response.text}”")

    payload = json.loads(response.content)
    if payload.get('errors'):
        raise ValueError(f"查询文档失败,原因:“{payload['errors']}”")
    return payload['data']['pages']['search']['results']
|
||||||
|
|
||||||
|
@staticmethod
def query_doc_info(doc_id: int) -> dict:
    """Fetch a single page by id through ``pages.single``.

    Args:
        doc_id: Numeric page id, sent as the ``$doc_id`` GraphQL variable.

    Returns:
        The value of ``data.pages.single`` from the response, or ``{}``
        when the response carries a GraphQL ``errors`` array.
    """
    import logging  # local import: this module does not import logging at the top

    query = """
        query singlePages($doc_id: Int!) {
          pages {
            single(id: $doc_id) {
              id
              path
              title
              isPublished
              content
              contentType
              isPrivate
              updatedAt
              createdAt
            }
          }
        }
    """
    data = {
        'query': query,
        'variables': {'doc_id': doc_id},
    }

    # A timeout keeps an unresponsive server from hanging the caller forever.
    response = requests.post(
        WikijsTool.BASE_URL,
        headers=WikijsTool.HEADERS,
        json=data,
        timeout=30,
    )
    payload = json.loads(response.content)

    # Check the parsed ``errors`` key rather than searching the raw body:
    # page content that merely contains the word "errors" must not be
    # mistaken for a failed query.
    errors = payload.get('errors')
    if errors:
        logging.warning("query_doc_info(%s) failed: %s", doc_id, errors[0].get('message'))
        return {}
    return payload['data']['pages']['single']
|
||||||
|
|
||||||
|
|
||||||
|
WikijsTool.init_url()
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: fetch one page and show the result. The previous
    # call to ``WikijsTool.rename_directory`` was dropped because the class
    # defines no such method, so it could only raise AttributeError.
    print(WikijsTool.query_doc_info(6448))
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,3 @@
|
|||||||
|
from . import custom_markdownify
|
||||||
|
|
||||||
|
# Package-level alias: expose ``custom_markdownify.md`` as ``convert_html_to_md``.
convert_html_to_md = custom_markdownify.md
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,491 @@
|
|||||||
|
import re
|
||||||
|
from textwrap import fill
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from bs4 import NavigableString
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
from markdownify import MarkdownConverter, chomp, UNDERLINED, ATX_CLOSED
|
||||||
|
import copy
|
||||||
|
from . import picture_process
|
||||||
|
|
||||||
|
|
||||||
|
def judge_br_in_table(el):
    """Tell whether ``el`` is, or sits inside, a ``td`` or ``tr`` element.

    Walks from ``el`` up through ``.parent`` links and returns ``True`` at
    the first node named ``td`` or ``tr``; returns ``False`` once a node
    without a parent is reached.
    """
    node = el
    while True:
        if node.name in ('td', 'tr'):
            return True
        if node.parent is None:
            return False
        node = node.parent
|
||||||
|
|
||||||
|
|
||||||
|
def get_markdown_title_level(el):
    """Return the Markdown heading prefix for a title ``div``.

    Returns ``'## '`` when the div's ``class`` attribute contains
    ``hdwiki_tmml``, ``'### '`` when it contains ``hdwiki_tmmll``, and
    ``''`` for anything else (including non-div elements and divs without
    a ``class`` attribute).
    """
    if el.name != 'div' or 'class' not in el.attrs:
        return ''

    css_classes = el.attrs['class']
    # Checked in this order; the first marker found wins.
    for marker, prefix in (('hdwiki_tmml', '## '), ('hdwiki_tmmll', '### ')):
        if marker in css_classes:
            return prefix
    return ''
|
||||||
|
|
||||||
|
|
||||||
|
def str_is_title(text) -> bool:
    """Return ``True`` when ``text``, ignoring surrounding whitespace, starts with ``#``."""
    return text.strip().startswith('#')
|
||||||
|
|
||||||
|
|
||||||
|
def is_img_div_tag(el) -> bool:
    """Return ``True`` when ``el`` is a ``div`` whose class attribute contains ``img`` or ``img_l``."""
    if el is None or el.name != "div":
        return False

    css_classes = el.get('class')
    if css_classes is None:
        return False
    return "img" in css_classes or "img_l" in css_classes
|
||||||
|
|
||||||
|
|
||||||
|
def is_only_text_div(el) -> bool:
    """Tell whether ``el`` is a ``div`` that carries nothing but text.

    Returns ``False`` when ``el`` is ``None``, is not a ``div``, has empty
    text, or has a ``display`` attribute other than ``"block"``. Otherwise
    returns ``True`` when the div's ``.string`` is a ``NavigableString``,
    or when every descendant tag with non-empty text is itself a plain
    string wrapper and none of them is a ``table``, ``td`` or ``img``.
    """
    if el is None or el.name != "div" or el.text == "":
        return False
    if el.get("display", "block") != "block":
        return False

    # The div directly wraps a single string.
    if isinstance(el.string, NavigableString):
        return True

    # Handles markup such as <div><b>1. 版本概述</b> </div>: inspect every
    # descendant tag instead of only the div itself.
    for descendant in el.find_all(recursive=True):
        if descendant.text == "":
            continue
        is_layout_tag = descendant.name in ["table", "td", "img"]
        if is_layout_tag or not isinstance(descendant.string, NavigableString):
            return False
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def a_tag_is_in_img(el) -> bool:
    """Return ``True`` when ``el`` is an ``a`` tag whose direct parent is an image ``div``.

    The parent must be a ``div`` accepted by ``is_img_div_tag``.
    """
    parent = el.parent
    if parent is None or el.name != "a" or parent.name != "div":
        return False
    return is_img_div_tag(parent)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomMarkDownConverter(MarkdownConverter):
|
||||||
|
"""
|
||||||
|
创建自定义的换行装换函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, img_download_path, **options):
    """Create the converter.

    Args:
        img_download_path: Stored on the instance; ``convert_img`` passes
            it to ``picture_process.process_img_tag``.
        **options: Forwarded unchanged to the base class constructor.
    """
    super().__init__(**options)
    self.img_download_path = img_download_path
|
||||||
|
|
||||||
|
def convert_br(self, el, text, convert_as_inline):
    """Convert a ``<br>``.

    Inside a table cell or row the line break is kept as a literal
    ``<br/>``. If the tag unexpectedly carries content, that content is
    returned followed by a newline. Otherwise the base class decides.
    """
    inside_table = judge_br_in_table(el)
    if inside_table:
        return "<br/>"

    # Fault tolerance (article 4696): a bs4 parsing error placed a category
    # image tag under the <br>, which made the image disappear.
    if not text.strip():
        return super().convert_br(el, text, convert_as_inline)
    return text + "\n"
|
||||||
|
|
||||||
|
# 图片div标签 在图片与图片描述之间添加换行
|
||||||
|
@staticmethod
|
||||||
|
def convert_img_div(text):
|
||||||
|
pattern = r'\*\*(.*?)\*\*'
|
||||||
|
match = re.search(pattern, text)
|
||||||
|
if match:
|
||||||
|
start_index = match.start()
|
||||||
|
text = text[:start_index] + "\n" + text[start_index:]
|
||||||
|
return text
|
||||||
|
|
||||||
|
def convert_div(self, el, text, convert_as_inline):
    """Convert a ``<div>``.

    Title divs become Markdown headings, image divs get a newline between
    picture and caption, and text-only divs are padded with blank lines.
    Any other div returns ``text`` unchanged.
    """
    heading_prefix = get_markdown_title_level(el)
    if heading_prefix != '':
        return "\n\n" + heading_prefix + text + '\n\n'

    if is_img_div_tag(el):
        # Separate the image from its caption text.
        return self.convert_img_div(text)

    if not is_only_text_div(el):
        return text
    return "\n\n" + text + "\n\n"
|
||||||
|
|
||||||
|
@staticmethod
def is_valid_url(url, timeout=10):
    """Check whether ``url`` answers a HEAD request with status 200.

    Args:
        url: Address to probe; redirects are followed.
        timeout: Seconds to wait for the server. Without it a stalled
            host would block the whole conversion.

    Returns:
        ``True`` on a 200 response, ``False`` on any other status or on
        any ``requests`` error (including a timeout).
    """
    try:
        response = requests.head(url, allow_redirects=True, timeout=timeout)
        return response.status_code == 200
    except requests.RequestException:
        return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def try_complete_img_description(img_el):
|
||||||
|
if img_el is None or img_el.name != "img":
|
||||||
|
return
|
||||||
|
|
||||||
|
# 找到父级的div标签
|
||||||
|
img_el_parent_div = None
|
||||||
|
cur_el = img_el
|
||||||
|
while cur_el.parent is not None:
|
||||||
|
if is_img_div_tag(cur_el.parent):
|
||||||
|
img_el_parent_div = cur_el.parent
|
||||||
|
break
|
||||||
|
cur_el = cur_el.parent
|
||||||
|
|
||||||
|
if img_el_parent_div is not None and len(img_el_parent_div.text) != 0:
|
||||||
|
img_el.attrs["alt"] = img_el_parent_div.text
|
||||||
|
return
|
||||||
|
|
||||||
|
# 找到父级的figure标签
|
||||||
|
img_el_parent_div = None
|
||||||
|
cur_el = img_el
|
||||||
|
while cur_el.parent is not None:
|
||||||
|
if cur_el.parent is not None and cur_el.parent.name == 'figure':
|
||||||
|
img_el_parent_div = cur_el.parent
|
||||||
|
break
|
||||||
|
cur_el = cur_el.parent
|
||||||
|
|
||||||
|
if img_el_parent_div is not None and len(img_el_parent_div.text) != 0:
|
||||||
|
img_el.attrs["alt"] = img_el_parent_div.text
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def convert_figcaption(self, el, text, convert_as_inline):
    """Drop ``<figcaption>`` content by always returning an empty string."""
    return ""
|
||||||
|
|
||||||
|
def convert_img(self, el, text, convert_as_inline):
    """Convert an ``<img>`` and hand the Markdown tag to the picture processor.

    Steps: complete the alt text, let the base class render the tag, strip
    line breaks from it, swap a ``_s`` thumbnail URL for the larger picture
    when one can be identified, then pass the tag and
    ``self.img_download_path`` to ``picture_process.process_img_tag``.
    An empty tag (``![]()``) yields a blank line instead (sample article
    6925: an image should sit on its own line).
    """
    self.try_complete_img_description(el)
    markdown_img = super().convert_img(el, text, convert_as_inline)

    # Article 5195: line breaks inside the img tag break Markdown rendering.
    markdown_img = markdown_img.replace("\r\n", "").replace("\n", "")
    if markdown_img == "![]()":
        return '\n\n'

    src = el.attrs.get('src', None) or ''
    src_stem = src.rsplit(".", 1)[0]

    # Prefer the picture linked by the wrapping <a> when the image itself
    # is that picture's "_s" variant.
    wrapper = el.parent
    if wrapper is not None and wrapper.name == "a":
        href = wrapper.attrs.get('href', None) or ''
        if href.rsplit(".", 1)[0] + "_s" == src_stem:
            markdown_img = markdown_img.replace(src, href)

    # Otherwise derive the URL without the "_s" suffix and use it if it exists.
    if '_s' in markdown_img and src_stem.endswith('_s'):
        full_size_url = src_stem[:-2] + "." + src.split(".")[-1]
        if self.is_valid_url(full_size_url):
            markdown_img = markdown_img.replace(src, full_size_url)

    # Convert and download the picture.
    return picture_process.process_img_tag(markdown_img, self.img_download_path)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_img_describe_strong(el) -> bool:
|
||||||
|
if el is None or el.parent is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if len(el.contents) == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# if not isinstance(el.contents[0], NavigableString):
|
||||||
|
# return False
|
||||||
|
|
||||||
|
img_list = el.parent.findAll("img")
|
||||||
|
if len(img_list) == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for img_tag in img_list:
|
||||||
|
alt = img_tag.get("alt", None)
|
||||||
|
title = img_tag.get("title", None)
|
||||||
|
if alt is None and title is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if alt == el.text or title == el.text:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def convert_b(self, el, text, convert_as_inline):
    """Convert ``<b>`` / ``<strong>``.

    Text is returned without bold markers when the tag's only child is a
    title div (sample article 6925) or when the text already starts with a
    Markdown heading. A tag that merely repeats an image description is
    dropped. Otherwise the base class output is surrounded by spaces.
    """
    only_child_is_title = (
        len(el.contents) == 1
        and get_markdown_title_level(el.contents[0]) != ''
    )
    if only_child_is_title or str_is_title(text):
        return text

    if self.is_img_describe_strong(el):
        return ""

    text = text.strip(" \t")
    suffix = " \n" if text.endswith("\n") else ""
    bold = super().convert_b(el, text, convert_as_inline)

    # Pad with spaces so adjacent bold runs do not fuse into
    # "**1.****版本概述**" (articles 2377, 4292 and others).
    return " " + bold + suffix + " "

convert_strong = convert_b
|
||||||
|
|
||||||
|
def convert_p(self, el, text, convert_as_inline):
    """Convert ``<p>``, surrounding the paragraph with blank lines.

    A title div may directly follow a ``<p>``, hence the padding on both
    sides. Inline conversion returns ``text`` untouched; empty text yields
    ``''``. When the ``wrap`` option is set the text is re-flowed to
    ``wrap_width`` first.
    """
    if convert_as_inline:
        return text

    if self.options['wrap']:
        text = fill(
            text,
            width=self.options['wrap_width'],
            break_long_words=False,
            break_on_hyphens=False,
        )

    if not text:
        return ''
    return '\n\n%s\n\n' % text
|
||||||
|
|
||||||
|
def convert_a(self, el, text, convert_as_inline):
    """Convert ``<a>``.

    Links pointing at an image file return only their inner text. The
    remaining logic covers the autolink shortcut, the optional default
    title and the ``[text](href "title")`` form; a link without ``href``
    returns its text.
    """
    prefix, suffix, text = chomp(text)
    if not text:
        return ''

    href = el.get('href')
    if self.is_href_img(href):
        return text

    title = el.get('title')
    # Article 5195: line breaks inside the tag break Markdown rendering.
    if title is not None:
        title = title.replace("\n", "")

    options = self.options
    # For the replacement see #29: text nodes underscores are escaped
    use_autolink = (
        options['autolinks']
        and text.replace(r'\_', '_') == href
        and not title
        and not options['default_title']
    )
    if use_autolink:
        # Shortcut syntax
        return '<%s>' % href

    if options['default_title'] and not title:
        title = href
    title_part = ' "%s"' % title.replace('"', r'\"') if title else ''

    if not href:
        return text
    return '%s[%s](%s%s)%s' % (prefix, text, href, title_part, suffix)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_href_img(href_url) -> bool:
|
||||||
|
if href_url is None:
|
||||||
|
return False
|
||||||
|
file_extension = href_url.split(".")[-1]
|
||||||
|
# 不是图片不处理
|
||||||
|
file_extension = file_extension.lower()
|
||||||
|
if file_extension not in ["jpg", "jpeg", "png", "gif"]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def convert_li(self, el, text, convert_as_inline):
    """Convert ``<li>``; whitespace-only items produce an empty string (article 4347)."""
    if text.strip():
        return super().convert_li(el, text, convert_as_inline)
    return ""
|
||||||
|
|
||||||
|
def convert_td(self, el, text, convert_as_inline):
    """Render a table cell as ``' <text> |'`` with line breaks turned into ``<br>``."""
    # CRLF first so each "\r\n" yields a single <br>.
    single_line = text.replace("\r\n", "<br>").replace("\n", "<br>")
    return ' ' + single_line + ' |'
|
||||||
|
|
||||||
|
def convert_hn(self, n, el, text, convert_as_inline):
    """Convert an ``<hN>`` heading according to the ``heading_style`` option.

    Inline conversion returns ``text`` unchanged. Levels 1 and 2 use
    ``self.underline`` when the style is ``UNDERLINED``; otherwise the
    heading is written with ``#`` markers, closed on both sides for
    ``ATX_CLOSED``.
    """
    if convert_as_inline:
        return text

    heading_style = self.options['heading_style'].lower()
    text = text.rstrip()

    if heading_style == UNDERLINED and n <= 2:
        return self.underline(text, '=' if n == 1 else '-')

    marker = '#' * n + " "
    if heading_style == ATX_CLOSED:
        return '\n\n %s %s %s\n\n' % (marker, text, marker)
    return '\n\n%s %s\n\n' % (marker, text)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_thead_table(el, text, cell_name, convert_as_inline):
|
||||||
|
cells = el.find_all(['td', 'th'])
|
||||||
|
is_headrow = all([cell.name == cell_name for cell in cells])
|
||||||
|
overline = ''
|
||||||
|
underline = ''
|
||||||
|
if is_headrow and not el.previous_sibling:
|
||||||
|
# first row and is headline: print headline underline
|
||||||
|
underline += '| ' + ' | '.join(['---'] * len(cells)) + ' |' + '\n'
|
||||||
|
elif (not el.previous_sibling
|
||||||
|
and (el.parent.name == 'table'
|
||||||
|
or (el.parent.name == 'tbody'
|
||||||
|
and not el.parent.previous_sibling))):
|
||||||
|
# first row, not headline, and:
|
||||||
|
# - the parent is table or
|
||||||
|
# - the parent is tbody at the beginning of a table.
|
||||||
|
# print empty headline above this row
|
||||||
|
overline += '| ' + ' | '.join([''] * len(cells)) + ' |' + '\n'
|
||||||
|
overline += '| ' + ' | '.join(['---'] * len(cells)) + ' |' + '\n'
|
||||||
|
return overline + '|' + text + '\n' + underline
|
||||||
|
|
||||||
|
def convert_tr(self, el, text, convert_as_inline):
    """Convert ``<tr>``, with special handling for two table layouts.

    Rows under ``<thead>`` (articles 4061, 1976) and rows under a
    ``<tbody>`` that directly follows a ``<colgroup>`` (article 4364) go
    through ``convert_thead_table``; all other rows use the base class.
    """
    section = el.parent if el else None

    in_thead = bool(section) and section.name == "thead"
    if in_thead:
        return CustomMarkDownConverter.convert_thead_table(el, text, 'td', convert_as_inline)

    tbody_after_colgroup = (
        bool(section)
        and section.previousSibling
        and section.name == "tbody"
        and section.previousSibling.name == "colgroup"
    )
    if tbody_after_colgroup:
        return CustomMarkDownConverter.convert_thead_table(el, text, 'td', convert_as_inline)

    return super().convert_tr(el, text, convert_as_inline)
|
||||||
|
|
||||||
|
def convert_pre(self, el, text, convert_as_inline):
    """Return ``text`` unchanged, giving ``<pre>`` no special treatment."""
    # Article 5192 contains a <pre> whose content is not code, so <pre>
    # is deliberately not handled specially.
    return text
|
||||||
|
|
||||||
|
def escape(self, text):
    """Backslash-escape Markdown-significant characters in ``text``.

    Which characters are escaped depends on the ``escape_misc``,
    ``escape_asterisks`` and ``escape_underscores`` options. Falsy input
    returns ``''``. Digits followed by ``.`` or ``)`` are intentionally
    left unescaped.
    """
    if not text:
        return ''

    options = self.options
    if options['escape_misc']:
        text = re.sub(r'([\\&<`[>~#%=+|-])', r'\\\1', text)
    for option_name, char in (('escape_asterisks', '*'), ('escape_underscores', '_')):
        if options[option_name]:
            text = text.replace(char, '\\' + char)
    return text
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_span(el, text, convert_as_inline):
|
||||||
|
# 文章3526出现图片后面紧接图片文本的问题。图片文本在span标签内
|
||||||
|
if "style" not in el.attrs:
|
||||||
|
return text
|
||||||
|
|
||||||
|
style_attr = el.attrs['style']
|
||||||
|
|
||||||
|
if style_attr is None:
|
||||||
|
return text
|
||||||
|
style_content = style_attr.split(';')
|
||||||
|
# 遍历style属性内容,找到display的值
|
||||||
|
for item in style_content:
|
||||||
|
if 'display' in item:
|
||||||
|
display_value = item.split(': ')[1] # 获取冒号后的值
|
||||||
|
if display_value == "block" and text != "":
|
||||||
|
return f"\n\n{text}\n\n"
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def expand_html_table(html) -> tuple[str, bool]:
    """Rewrite every table in ``html`` so that no cell spans rows or columns.

    Each merged cell is expanded into the individual slots it covered, and
    the cell's text is repeated in every one of them. The rebuilt table
    keeps only the text of each cell.

    Malformed markup is tolerated instead of raising: a non-numeric or
    non-positive ``colspan``/``rowspan`` counts as 1, a ``rowspan`` that
    runs past the last row is cut off there, and a row that needs more
    columns than first estimated is widened.

    Args:
        html: HTML source text.

    Returns:
        ``(new_html, True)`` when at least one table was found, otherwise
        ``(html, False)`` with the input returned untouched.
    """
    soup = BeautifulSoup(html, 'html.parser')
    tables = soup.find_all('table')
    if not tables:
        return html, False

    def span_of(cell, attr):
        # Read a colspan/rowspan attribute, falling back to 1 when unusable.
        try:
            return max(int(cell.get(attr, 1)), 1)
        except (TypeError, ValueError):
            return 1

    for table in tables:
        table_rows = table.find_all('tr')

        # Initial width estimate: the widest row, counting column spans.
        max_cols = 0
        for row in table_rows:
            row_width = sum(span_of(cell, 'colspan') for cell in row.find_all(['td', 'th']))
            max_cols = max(max_cols, row_width)

        # grid[r][c] holds the cell occupying that slot, or None.
        grid = [[None] * max_cols for _ in table_rows]

        for r, row in enumerate(table_rows):
            c = 0
            for cell in row.find_all(['td', 'th']):
                # Skip slots already claimed by a rowspan from a row above.
                while c < len(grid[r]) and grid[r][c] is not None:
                    c += 1
                colspan = span_of(cell, 'colspan')
                # Never reach past the last row of the table.
                rowspan = min(span_of(cell, 'rowspan'), len(grid) - r)
                for i in range(rowspan):
                    target_row = grid[r + i]
                    missing = c + colspan - len(target_row)
                    if missing > 0:
                        target_row.extend([None] * missing)
                    for j in range(colspan):
                        # Repeat the content in every slot of a merged cell.
                        target_row[c + j] = copy.copy(cell)
                c += colspan

        # Build the replacement table from the grid.
        new_table = soup.new_tag('table', border="1", cellspacing="0")
        tbody = soup.new_tag('tbody')
        new_table.append(tbody)
        for grid_row in grid:
            tr = soup.new_tag('tr')
            for cell in grid_row:
                if cell is not None:
                    new_cell = soup.new_tag(cell.name)
                    new_cell.string = cell.get_text()
                    tr.append(new_cell)
            tbody.append(tr)

        # Swap the old table for the expanded one.
        table.replace_with(new_table)

    return str(soup), True
|
||||||
|
|
||||||
|
|
||||||
|
def md(html, img_download_path, **options):
    """Convert ``html`` to Markdown with the custom converter.

    Tables are expanded by ``expand_html_table`` first, then the result is
    converted by ``CustomMarkDownConverter`` and its blank lines are
    normalised.

    Args:
        html: HTML source text.
        img_download_path: Passed to the converter's constructor.
        **options: Forwarded to the converter's constructor.

    Returns:
        The Markdown text.
    """
    expanded_html, _ = expand_html_table(html)
    converter = CustomMarkDownConverter(img_download_path, **options)
    markdown = converter.convert(expanded_html)

    # Drop whitespace sitting between consecutive line breaks.
    markdown = re.sub(r'\n\s*\n', '\n\n', markdown)
    # Cap runs of line breaks at three.
    return re.sub(r'\n{3,}', '\n\n\n', markdown)
|
||||||
@@ -0,0 +1,170 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def get_img_tag_url(img_tag):
    """Extract the URL from the first Markdown image in ``img_tag``.

    Args:
        img_tag: Text containing a Markdown image such as
            ``![alt](url "title")``.

    Returns:
        The part of the parenthesised target before the first space (i.e.
        the URL without an optional title), or ``""`` when no image is
        found.
    """
    match = re.search(r'\!\[.*?\]\((.*?)\)', img_tag)
    if not match:
        return ""
    # The target may be 'url "title"'; the URL is the first token.
    return match.group(1).split(" ")[0]


def fill_img_url(img_tag):
    """Turn a relative image link inside a Markdown image tag into a full URL.

    Links that already start with ``http://`` or ``https://`` are kept.
    Other links are resolved against the ``IMG_URL_PREFIX`` environment
    variable when it is set.

    Args:
        img_tag: The Markdown image tag; line breaks in it are removed.

    Returns:
        tuple: ``(img_tag, full_link)`` where ``img_tag`` is the possibly
        rewritten tag and ``full_link`` is the absolute image URL. The
        second item is ``''`` when the tag has no link, or when the link is
        relative and ``IMG_URL_PREFIX`` is not set.
    """
    # A complete image tag must not contain line breaks.
    img_tag = img_tag.replace("\n", "")
    link = get_img_tag_url(img_tag)
    if len(link) == 0:
        return img_tag, ''

    # Already absolute: both schemes count, not only plain http.
    if link.lower().startswith(("http://", "https://")):
        return img_tag, link

    base_url = os.getenv("IMG_URL_PREFIX")
    if base_url:
        full_link = urljoin(base_url, link)
        return img_tag.replace(link, full_link), full_link
    return img_tag, ''
|
||||||
|
|
||||||
|
|
||||||
|
def download_picture(img_tag, download_path):
    """Download the image referenced by ``img_tag`` and point the tag at the local copy.

    The file keeps the last path segment of the URL as its name and is not
    downloaded again when it already exists.

    Args:
        img_tag: Markdown image tag.
        download_path: Directory that receives the file.

    Returns:
        The tag with its URL replaced by the local file path (forward
        slashes). When no absolute URL can be determined, or the download
        fails, the tag is returned with its URL left in place.
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '
                      'Chrome/94.0.4606.71 Safari/537.36 '
    }
    img_tag, img_url = fill_img_url(img_tag)
    if img_url == '':
        return img_tag

    file_name = img_url.split("/")[-1]
    file_path = os.path.normpath(os.path.join(download_path, file_name))
    file_path = file_path.replace("\\", "/")

    # Skip the download when the file is already there.
    if not os.path.exists(file_path):
        try:
            response = requests.get(url=img_url, headers=headers, timeout=30)
            # Do not store an error page as if it were the image.
            response.raise_for_status()
        except requests.RequestException as e:
            logging.warning(f"img download error url:{img_url}, reason:{e}")
            return img_tag
        with open(file_path, 'wb') as fp:
            fp.write(response.content)
        logging.info(f"图片下载成功:{img_url}")

    # Replace the URL in the tag with the downloaded file's path.
    return img_tag.replace(img_url, file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def download_picture_from_other_url(img_tag, download_path):
    """Download an externally hosted image and point the tag at the local copy.

    The file is stored under a freshly generated UUID name with a ``.png``
    extension.

    Args:
        img_tag: Markdown image tag.
        download_path: Directory that receives the file.

    Returns:
        The tag with its URL replaced by the local file path. When no
        absolute URL can be determined, or the download fails, the tag is
        returned with its URL left in place.
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '
                      'Chrome/94.0.4606.71 Safari/537.36 '
    }
    img_tag, img_url = fill_img_url(img_tag)
    # Without a URL there is nothing to fetch, and replacing an empty
    # string below would corrupt the tag.
    if img_url == '':
        return img_tag

    file_name = uuid.uuid4()
    file_path = os.path.normpath(os.path.join(download_path, f"{file_name}.png"))

    # Skip the download when the file is already there.
    if not os.path.exists(file_path):
        try:
            response = requests.get(url=img_url, headers=headers, timeout=30)
            # Do not store an error page as if it were the image.
            response.raise_for_status()
            with open(file_path, 'wb') as fp:
                fp.write(response.content)
            logging.info(f"图片下载成功:{img_url}")
        except Exception as e:
            # Best effort: keep the original link rather than abort the conversion.
            logging.warning(f"img download error url:{img_url}, reason:{e}")
            return img_tag

    # Replace the URL in the tag with the downloaded file's path.
    return img_tag.replace(img_url, file_path)
|
||||||
|
|
||||||
|
|
||||||
|
# Captures the payload that follows "base64," up to the first character that
# cannot be part of base64 text (for example the closing parenthesis of a
# Markdown image).
_BASE64_PAYLOAD_RE = re.compile(r'base64,([A-Za-z0-9+/=]*)')


def extract_base64_from_data_uri(data_uri):
    """Return the base64 payload contained in ``data_uri``.

    Works both for a bare data URI and for a Markdown image tag that wraps
    one, e.g. ``![alt](data:image/png;base64,AAAA)``; commas in the alt
    text do not matter.

    Args:
        data_uri: Text containing ``base64,<payload>``.

    Returns:
        The payload string, or ``None`` when no ``base64,`` marker is found.
    """
    match = _BASE64_PAYLOAD_RE.search(data_uri)
    if match is None:
        return None
    return match.group(1)


def picture_base64(img_tag, picture_save_path):
    """Save an inline base64 image to disk and return a tag that links to the file.

    The file name is the MD5 hex digest of the base64 text plus ``.png``,
    so identical images are written once.

    Args:
        img_tag: Markdown image tag whose target is a base64 data URI.
        picture_save_path: Directory that receives the file.

    Returns:
        ``![](path)`` when the tag has no alt text, otherwise
        ``![alt](path "alt")``. ``""`` is returned when
        ``picture_save_path`` is empty or ``None``, or when the tag holds
        no decodable base64 payload.
    """
    if picture_save_path is None or picture_save_path == "":
        return ""

    base64_str = extract_base64_from_data_uri(img_tag)
    if base64_str is None:
        logging.warning("no base64 payload found in image tag")
        return ""

    # Use the MD5 of the image content as the file name.
    img_md5 = hashlib.md5(base64_str.encode()).hexdigest()
    file_path = os.path.normpath(os.path.join(picture_save_path, "%s.png" % img_md5))
    file_path = file_path.replace("\\", "/")

    # Do not write the file again when it already exists.
    if not os.path.exists(file_path):
        try:
            decoded = base64.b64decode(base64_str)
        except ValueError as e:  # binascii.Error is a ValueError subclass
            logging.warning(f"invalid base64 image data: {e}")
            return ""
        with open(file_path, 'wb') as fp:
            fp.write(decoded)

    # Rebuild the tag around the saved file, keeping the alt text.
    alt_match = re.search(r"\[(.*?)\]", img_tag)
    alt_text = alt_match.group(1) if alt_match else ""
    if alt_text == "":
        return "![](%s)" % file_path
    return '![%s](%s "%s")' % (alt_text, file_path, alt_text)
|
||||||
|
|
||||||
|
|
||||||
|
def process_img_tag(str_img_tag, img_path):
    """Route a Markdown image tag to the handler matching its kind of link.

    Args:
        str_img_tag: Markdown image tag.
        img_path: Directory where pictures are stored.

    Returns:
        ``""`` for tags that point at a local ``file:///`` path or when
        ``img_path`` is empty or ``None``; otherwise the result of the
        chosen handler: ``picture_base64`` for PNG data URIs,
        ``download_picture_from_other_url`` for tags containing an http(s)
        address, and ``download_picture`` for the remaining links.
    """
    # A tag pointing at a local disk path is ignored.
    if "file:///" in str_img_tag:
        logging.warning(f"存在非法的链接地址:{str_img_tag}")
        return ""
    if img_path is None or img_path == "":
        return ""

    img_url = get_img_tag_url(str_img_tag)

    if "data:image/png;base64" in str_img_tag:
        return picture_base64(str_img_tag, img_path)

    # Tags pointing at external sites (article 4696 and others) are kept for now.
    has_http_address = "http://" in str_img_tag or "https://" in str_img_tag
    if has_http_address:
        return download_picture_from_other_url(str_img_tag, img_path)

    if not img_url.startswith("http"):
        return download_picture(str_img_tag, img_path)

    logging.warning(f"未处理的图片标签:{str_img_tag}")
    return str_img_tag
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"url":"http://10.1.0.145:8090/graphql",
|
||||||
|
"Authorization":"Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJhcGkiOjEsImdycCI6MSwiaWF0IjoxNzIzNjMxMjcwLCJleHAiOjE4MTgzMDQwNzAsImF1ZCI6InVybjp3aWtpLmpzIiwiaXNzIjoidXJuOndpa2kuanMifQ.g5H1xVMtk7Q3uvrRdtD3aTm49dQkS11cYdDKIwXo7DthOOTGj9DmFO7yILNDU7XFACTZc1Ej6ryguYV_8vGqoc-Rc7LciwvqS_RHDYUKZNKENbv8df9UGDMB-F9DT_airGc1lGJXgVqypxejDL3fY8aRMGXm7GBIlZKY4JTeI2uJZxffgfqKGrOvc3EOtsGgJzKZo4OyQ8UInGtCTiuq6-mLj_Syix_1z52K1tgfnF4E4-rZH_zCD05hUlUMYUV-KWhPkeOEGR5xbRTrulfCvzDD4T0CX4pI-keSKmgVn1HYSSN4o1Tj_l9zsyhUoLRzhzPK29Q3uekIc9obrvCHrg"
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user