更新环境变量配置,调整模型名称获取方式,新增Dify API相关配置,删除无用的脚本文件,优化意图识别逻辑,添加LLM提取词条逻辑
This commit is contained in:
@@ -48,7 +48,7 @@ class JsonDeduplicator:
|
||||
{items}
|
||||
'''
|
||||
# 配置LLM
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
llm_params = {"temperature": 0.3, "model": model_name}
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: extract_wikijs_nouns.py
|
||||
Author: oyyz
|
||||
Description: 从 Wikijs 文档中提取专业名词
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
from dotenv import load_dotenv
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from rag2_0.tool.WikijsTool import WikijsTool
|
||||
from rag2_0.intent_recognition.DataModels import Term, TermList
|
||||
from rag2_0.tool.html_to_md import convert_html_to_md
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
import json
|
||||
import datetime
|
||||
import logging
|
||||
import threading
|
||||
import concurrent.futures
|
||||
from threading import Semaphore
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
extract_wiki_nouns_prompt="""
|
||||
我在完善我的专业词库,请从提供的电力行业造价软件相关文本中提取关键词,要求如下:
|
||||
|
||||
一、提取范围
|
||||
1. 核心功能模块
|
||||
(例:多工程批量计价、材机数据反算、变电工程智能组价、架空线路地形系数计算)
|
||||
2、软件功能及界面名称(包括:界面页签、功能按钮、功能名称等)
|
||||
(例:新建工程量清单、导出工程量清单等)
|
||||
3. 业务专用术语
|
||||
(例:装置性材料、甲供材保管费、施工降效补偿、电缆头试验配套费)
|
||||
4. 计价标准体系
|
||||
(例:预规2020版、电网检修定额2015版、配网工程概算定额)
|
||||
|
||||
|
||||
二、提取规则
|
||||
1. 识别核心功能名称(如"多工程批量设置工程量、工程设置密码")
|
||||
2. 提取业务专用名词(如"主材卸车保管费")
|
||||
3. 标注关联术语的对应关系(如"市场价"与"市场价格"互为同义词)
|
||||
4. 包含定额标准相关术语(如"预规2020版")
|
||||
5. 复合型术语需保持完整
|
||||
√ 正确:"地形增加系数批量设置"
|
||||
× 错误:"地形"、"系数"、"设置"
|
||||
6. 总结生成关键词解释
|
||||
关键词:编制依据
|
||||
描述:造价文件编制基准规范
|
||||
|
||||
7. 软件的特定版本号不作为关键词
|
||||
|
||||
三、输出格式:
|
||||
{output_format}
|
||||
|
||||
四、输入内容:
|
||||
{content}
|
||||
"""
|
||||
|
||||
|
||||
class WikijsNounsExtractor:
|
||||
"""从 Wikijs 文档中提取专业名词"""
|
||||
|
||||
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo"):
|
||||
"""
|
||||
初始化专业名词提取器
|
||||
|
||||
Args:
|
||||
api_key: API密钥,如果为None则从环境变量获取
|
||||
base_url: API基础URL,如果为None则使用默认URL
|
||||
model_name: 要使用的模型名称
|
||||
"""
|
||||
# 保存参数
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.model_name = model_name
|
||||
|
||||
# 初始化LLM
|
||||
llm_params = {
|
||||
"temperature": 0.6,
|
||||
"model": model_name
|
||||
}
|
||||
|
||||
if api_key:
|
||||
llm_params["api_key"] = api_key
|
||||
|
||||
if base_url:
|
||||
llm_params["base_url"] = base_url
|
||||
|
||||
self.llm = OpenAiLLM(**llm_params)
|
||||
|
||||
# 准备术语列表解析器
|
||||
self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
|
||||
|
||||
# 信号量,限制并发请求数量
|
||||
self.semaphore = None
|
||||
|
||||
# 线程锁,用于保护共享资源
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def _convert_html_to_md(self, content, title):
|
||||
"""HTML转Markdown"""
|
||||
options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
|
||||
new_content = (content.replace("h6>", "h7>")
|
||||
.replace("h5>", "h6>")
|
||||
.replace("h4>", "h5>")
|
||||
.replace("h3>", "h4>")
|
||||
.replace("h2>", "h3>")
|
||||
.replace("h1>", "h2>"))
|
||||
# 将HTML内容转换为Markdown
|
||||
markdown_content = convert_html_to_md(new_content, "", **options)
|
||||
markdown_content = f"# {title}\n\n{markdown_content}"
|
||||
return markdown_content
|
||||
|
||||
def extract_from_document(self, doc_info: dict) -> List[Term]:
|
||||
"""从单个文档中提取专业名词"""
|
||||
try:
|
||||
# 使用LLM调用处理文档
|
||||
content = doc_info['content']
|
||||
title = doc_info["title"]
|
||||
|
||||
# 转换HTML到Markdown
|
||||
markdown_content = self._convert_html_to_md(content, title)
|
||||
|
||||
# 准备提示词
|
||||
formatted_prompt = extract_wiki_nouns_prompt.replace("{content}", markdown_content)
|
||||
formatted_prompt = formatted_prompt.replace("{output_format}", self.terms_list_parser.get_format_instructions())
|
||||
|
||||
try:
|
||||
# 调用LLM
|
||||
response = self.llm.invoke(formatted_prompt)
|
||||
# 使用Pydantic解析器解析结果
|
||||
parsed_output = self.terms_list_parser.parse(response.content)
|
||||
return parsed_output.terms
|
||||
except Exception as e:
|
||||
logging.error(f"解析LLM响应时出错: {str(e)}", exc_info=True)
|
||||
return []
|
||||
except Exception as e:
|
||||
logging.error(f"提取专业名词时出错: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
def _process_document(self, doc, path_terms):
|
||||
"""处理单个文档"""
|
||||
try:
|
||||
# 获取信号量
|
||||
with self.semaphore:
|
||||
# 检查文档路径是否在我们要处理的路径中
|
||||
path_prefix = None
|
||||
for prefix in path_terms.keys():
|
||||
if doc['path'].startswith(prefix):
|
||||
path_prefix = prefix
|
||||
break
|
||||
|
||||
# 如果不在要处理的路径中,则跳过
|
||||
if not path_prefix:
|
||||
return None
|
||||
|
||||
# 获取文档详细信息
|
||||
doc_info = WikijsTool.query_doc_info(doc['id'])
|
||||
if not doc_info or not doc_info.get('content'):
|
||||
return None
|
||||
|
||||
# 提取专业名词
|
||||
terms = self.extract_from_document(doc_info)
|
||||
|
||||
# 将提取的术语添加到对应路径的结果列表中
|
||||
terms_dicts = [{"name": term.name, "synonymous": term.synonymous, "description": term.description} for term in terms]
|
||||
|
||||
with self.lock:
|
||||
path_terms[path_prefix].extend(terms_dicts)
|
||||
logging.info(f"文档 {doc['path']} 处理完成,提取了 {len(terms)} 个专业名词")
|
||||
|
||||
# 每处理10个文档保存一次中间结果
|
||||
current_count = len(path_terms[path_prefix])
|
||||
if current_count % 10 == 0:
|
||||
# 使用锁保护文件IO
|
||||
self._save_terms_to_file(path_terms[path_prefix], os.path.join(self.output_dir, f"{path_prefix.split('(')[0]}_nouns.json"))
|
||||
logging.info(f"已处理 {path_prefix} 的文档数达到 {current_count//10*10} 个,已保存中间结果")
|
||||
|
||||
return path_prefix
|
||||
except Exception as e:
|
||||
logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def process_all_documents(self, output_dir: str = "extracted_nouns", max_concurrency: int = 5):
|
||||
"""使用线程池处理所有文档"""
|
||||
# 保存输出目录
|
||||
self.output_dir = output_dir
|
||||
|
||||
# 创建输出目录
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# 初始化信号量,限制并发请求数
|
||||
self.semaphore = Semaphore(max_concurrency)
|
||||
|
||||
# 获取所有文档
|
||||
all_docs = WikijsTool.get_all_documents()
|
||||
|
||||
# 要处理的路径前缀
|
||||
# path_prefixes = [
|
||||
# "技改检修计价通(2020)",
|
||||
# "西藏造价软件(2023)",
|
||||
# "新型储能电站建设计价通C1(2024)",
|
||||
# "配网造价软件(2022)",
|
||||
# ]
|
||||
path_prefixes = [
|
||||
"主网电力建设计价通(2018)",
|
||||
]
|
||||
# 为每个路径创建单独的结果列表
|
||||
path_terms = {prefix: [] for prefix in path_prefixes}
|
||||
|
||||
# 过滤出符合路径前缀的文档
|
||||
filtered_docs = []
|
||||
for doc in all_docs:
|
||||
for prefix in path_prefixes:
|
||||
if doc['path'].startswith(prefix):
|
||||
filtered_docs.append(doc)
|
||||
break
|
||||
|
||||
logging.info(f"开始使用线程池处理 {len(filtered_docs)} 个文档...")
|
||||
|
||||
# 使用线程池处理所有文档
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
futures = []
|
||||
for doc in filtered_docs:
|
||||
future = executor.submit(self._process_document, doc, path_terms)
|
||||
futures.append(future)
|
||||
|
||||
# 等待所有任务完成
|
||||
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
||||
try:
|
||||
prefix = future.result()
|
||||
if i % 10 == 0:
|
||||
logging.info(f"已完成 {i+1}/{len(futures)} 个文档的处理")
|
||||
except Exception as e:
|
||||
logging.error(f"处理文档时出错: {str(e)}", exc_info=True)
|
||||
|
||||
# 保存最终结果
|
||||
for prefix, terms in path_terms.items():
|
||||
# 为每个路径保存单独的文件
|
||||
output_file = os.path.join(output_dir, f"{prefix.split('(')[0]}_nouns.json")
|
||||
self._save_terms_to_file(terms, output_file)
|
||||
logging.info(f"{prefix} 处理完成,共提取 {len(terms)} 个专业名词,已保存到 {output_file}")
|
||||
|
||||
def _save_terms_to_file(self, terms, output_file):
|
||||
"""保存术语列表到文件"""
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(terms, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def main():
|
||||
# 从环境变量获取配置
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
|
||||
# os.environ["LLM_MODEL_NAME"] = "Qwen/Qwen2.5-72B-Instruct-128K"
|
||||
|
||||
extractor = WikijsNounsExtractor(api_key=api_key, base_url=base_url, model_name=os.getenv("LLM_MODEL_NAME"))
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
output_dir = os.path.join(current_dir, "..", "..", "data", "wiki_extracted_nouns")
|
||||
extractor.process_all_documents(output_dir=output_dir, max_concurrency=2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置日志输出到文件,并设置格式
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
date_format = '%Y-%m-%d %H:%M:%S'
|
||||
|
||||
# 创建一个控制台处理器
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setFormatter(logging.Formatter(log_format, date_format))
|
||||
|
||||
# 获取根日志记录器并添加处理器
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(logging.INFO)
|
||||
root_logger.addHandler(console_handler)
|
||||
main()
|
||||
@@ -75,15 +75,8 @@ class QueryRewriteProcessor:
|
||||
dify_base_url: Dify API基础URL
|
||||
"""
|
||||
# 初始化意图识别器
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = base_url or os.getenv("OPENAI_API_BASE")
|
||||
self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
# 使用asyncio.run()运行异步create方法
|
||||
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
model_name=self.model_name
|
||||
))
|
||||
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create())
|
||||
self.dify_query_retrieval = DifyQueryRetrieval(api_key=dify_api_key, base_url=dify_base_url)
|
||||
|
||||
def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
@@ -174,7 +167,7 @@ class QueryRewriteProcessor:
|
||||
return []
|
||||
|
||||
def process_query(self, query: str,
|
||||
conversation_context: str = "",
|
||||
conversation_context: Dict = None,
|
||||
chat_history: List[Dict[str, str]] = None,
|
||||
previous_slots: Dict[str, str] = None,
|
||||
enable_retrieval: bool = False):
|
||||
@@ -196,12 +189,17 @@ class QueryRewriteProcessor:
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
if conversation_context is None:
|
||||
conversation_context = {}
|
||||
|
||||
current_softname = conversation_context.get("current_softname", "")
|
||||
result = asyncio.run(self.recognizer_async.process_query_async(query,
|
||||
conversation_context=conversation_context,
|
||||
chat_history=chat_history,
|
||||
previous_slots=previous_slots,
|
||||
enable_query_expansion=True,
|
||||
use_jieba=True))
|
||||
use_jieba=True,
|
||||
cur_soft_name=current_softname))
|
||||
|
||||
# 提取分类信息
|
||||
classification = result["classification"]
|
||||
@@ -414,7 +412,7 @@ def main():
|
||||
# 从环境变量中获取配置,命令行参数优先
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
||||
enable_retrieval = args.enable_retrieval
|
||||
|
||||
# 初始化查询改写处理器
|
||||
@@ -441,8 +439,10 @@ def main():
|
||||
for idx, query in enumerate(examples):
|
||||
if query.strip() == "":
|
||||
continue
|
||||
query="811619150828能看一下这个锁是16的马"
|
||||
conversation_context="当前使用软件:配网计价通D3软件"
|
||||
query="怎么把一个批次拆分成多个批次工程"
|
||||
conversation_context={
|
||||
"current_softname": "配网计价通D3软件"
|
||||
}
|
||||
# 在调试模式下使用完整的参数
|
||||
print(json.dumps(processor.process_query(
|
||||
query,
|
||||
|
||||
@@ -44,7 +44,7 @@ class TermMerger:
|
||||
{items}
|
||||
'''
|
||||
# 配置LLM
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
llm_params = {"temperature": 0.3, "model": model_name}
|
||||
|
||||
@@ -1,573 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: validate_excel_data_batch.py
|
||||
Description: 使用LLM批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写是否正确
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pandas as pd
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import concurrent.futures
|
||||
from tqdm import tqdm
|
||||
from dotenv import load_dotenv
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.intent_recognition.PromptTemplates import classification_info
|
||||
from rag2_0.intent_recognition.DataModels import *
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
|
||||
|
||||
# 定义验证结果的Pydantic模型
|
||||
class ValidationResult(BaseModel):
|
||||
is_correct: bool = Field(description="验证是否通过")
|
||||
confidence_score: float = Field(description="置信度得分")
|
||||
reason: str = Field(default="", description="得出结论的原因")
|
||||
|
||||
class ExcelDataValidator:
|
||||
"""Excel数据验证类,用于批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写"""
|
||||
|
||||
def __init__(self, input_file=None, output_file=None, workers=4, debug=False):
|
||||
"""
|
||||
初始化验证器
|
||||
|
||||
Args:
|
||||
input_file: 输入Excel文件路径
|
||||
output_file: 输出结果Excel文件路径
|
||||
workers: 并行工作线程数
|
||||
debug: 是否启用调试模式(串行处理)
|
||||
"""
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
self.input_file = input_file
|
||||
self.output_file = output_file
|
||||
self.workers = workers
|
||||
self.debug = debug
|
||||
self.df = None
|
||||
|
||||
# 设置日志
|
||||
self.setup_logging()
|
||||
|
||||
def setup_logging(self):
|
||||
"""配置日志输出"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
|
||||
def load_data_from_excel(self, file_path=None):
|
||||
"""
|
||||
从Excel文件中读取数据
|
||||
|
||||
Args:
|
||||
file_path: Excel文件路径,如不提供则使用初始化时的路径
|
||||
|
||||
Returns:
|
||||
DataFrame对象
|
||||
"""
|
||||
file_path = file_path or self.input_file
|
||||
if not file_path:
|
||||
logging.error("未指定输入文件路径", exc_info=True)
|
||||
return None
|
||||
|
||||
try:
|
||||
df = pd.read_excel(file_path)
|
||||
required_columns = ["问题", "问题分类", "问题改写", "槽位信息", "检索的内容"]
|
||||
for col in required_columns:
|
||||
if col not in df.columns:
|
||||
logging.error(f"缺少必要的列: {col}", exc_info=True)
|
||||
return None
|
||||
logging.info(f"成功从{file_path}读取了{len(df)}条数据")
|
||||
self.df = df
|
||||
return df
|
||||
except Exception as e:
|
||||
logging.error(f"读取Excel文件时出错: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def validate_classification(self, llm:OpenAiLLM , query:str, vertical_class:str, sub_class:str):
|
||||
"""
|
||||
验证问题分类是否正确
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
query: 原始问题
|
||||
vertical_class: 一级分类
|
||||
sub_class: 二级分类
|
||||
|
||||
Returns:
|
||||
(bool, str, float): 是否正确,错误原因(如果有),置信度
|
||||
"""
|
||||
parser = self.create_validation_parser()
|
||||
format_instructions = parser.get_format_instructions()
|
||||
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用,也可能涉及电力造价专业知识。我对用户问题进行了分类,请评估以下问题分类是否正确。
|
||||
|
||||
我目前总共有以下分类:
|
||||
{classification_info}
|
||||
|
||||
问题的分类情况如下:
|
||||
原始问题: {query}
|
||||
一级分类: {vertical_class}
|
||||
二级分类: {sub_class}
|
||||
|
||||
请从专业角度分析这个分类是否准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
|
||||
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = parser.parse(response.content)
|
||||
return result.is_correct, result.reason, result.confidence_score
|
||||
except Exception as e:
|
||||
logging.warning(f"验证问题分类时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}", 0.0
|
||||
|
||||
def _get_slot_model(self, classification: Classification) -> Optional[type]:
|
||||
"""
|
||||
根据分类结果获取对应的槽位模型类,用于统一提示词处理
|
||||
|
||||
Args:
|
||||
classification: 意图分类结果
|
||||
|
||||
Returns:
|
||||
对应的槽位模型类
|
||||
"""
|
||||
# 软件问题
|
||||
if classification.vertical_classification == "软件问题":
|
||||
if classification.sub_classification == "软件功能":
|
||||
return SoftwareFunctionSlots
|
||||
elif classification.sub_classification == "故障排查":
|
||||
return SoftwareTroubleShootingSlots
|
||||
|
||||
# 业务问题
|
||||
elif classification.vertical_classification == "业务问题":
|
||||
if classification.sub_classification == "专业咨询":
|
||||
return ProfessionalConsultingSlots
|
||||
elif classification.sub_classification == "数据问题":
|
||||
return DataProblemSlots
|
||||
|
||||
# 安装下载注册
|
||||
elif classification.vertical_classification == "安装下载注册":
|
||||
if classification.sub_classification == "后缀名咨询":
|
||||
return FileExtensionConsultingSlots
|
||||
elif classification.sub_classification == "软件锁类":
|
||||
return SoftwareLockSlots
|
||||
elif classification.sub_classification == "安装下载类":
|
||||
return InstallationDownloadSlots
|
||||
elif classification.sub_classification == "问题排查类":
|
||||
return ProblemDiagnosisSlots
|
||||
|
||||
# 其他
|
||||
elif classification.vertical_classification == "其他":
|
||||
return OtherSlots
|
||||
|
||||
return None
|
||||
|
||||
def validate_slot(self, llm, rewrite, slot_info, vertical_class, sub_class):
|
||||
"""
|
||||
验证槽位填充是否正确
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
rewrite: 问题改写
|
||||
slot_info: 槽位信息(JSON字符串)
|
||||
|
||||
Returns:
|
||||
(bool, str, float): 是否正确,错误原因(如果有),置信度
|
||||
"""
|
||||
# 解析槽位信息JSON
|
||||
try:
|
||||
if isinstance(slot_info, str) and slot_info.strip():
|
||||
slots = json.loads(slot_info)
|
||||
else:
|
||||
slots = slot_info
|
||||
except:
|
||||
slots = slot_info
|
||||
|
||||
parser = self.create_validation_parser()
|
||||
format_instructions = parser.get_format_instructions()
|
||||
slot_info_prompt = self._get_slot_model(Classification(vertical_classification=vertical_class, sub_classification=sub_class)).model_json_schema()
|
||||
slot_info_prompt = json.dumps(slot_info_prompt, ensure_ascii=False)
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我从用户问题中提取了槽位信息,请评估这些槽位信息是否准确、完整。
|
||||
|
||||
问题改写: {rewrite}
|
||||
槽位模板:{slot_info_prompt}
|
||||
|
||||
填充的槽位信息: {slots}
|
||||
|
||||
槽位信息应该准确提取问题中的关键实体和属性,如软件名称、功能名称、错误信息等。请分析这些槽位是否准确填充,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
|
||||
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = parser.parse(response.content)
|
||||
return result.is_correct, result.reason, result.confidence_score
|
||||
except Exception as e:
|
||||
logging.warning(f"验证槽位填充时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}", 0.0
|
||||
|
||||
def validate_retrieve_content(self, llm, rewrite, retrieve_content):
|
||||
"""
|
||||
验证检索内容是否正确
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
rewrite: 问题改写
|
||||
retrieve_content: 检索内容(可能是JSON字符串或文本)
|
||||
|
||||
Returns:
|
||||
(bool, str, float): 是否正确,错误原因(如果有),置信度
|
||||
"""
|
||||
# 解析检索内容
|
||||
try:
|
||||
if isinstance(retrieve_content, str) and retrieve_content.strip():
|
||||
if retrieve_content.startswith('{') or retrieve_content.startswith('['):
|
||||
content = json.loads(retrieve_content)
|
||||
else:
|
||||
content = retrieve_content
|
||||
else:
|
||||
content = retrieve_content
|
||||
except:
|
||||
content = retrieve_content
|
||||
|
||||
parser = self.create_validation_parser()
|
||||
format_instructions = parser.get_format_instructions()
|
||||
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我针对用户问题检索了相关内容,请评估这些检索内容是否能解答提问。
|
||||
|
||||
问题改写: {rewrite}
|
||||
检索内容: {content}
|
||||
|
||||
检索内容应该与问题主题相关,能够提供有用的信息来回答问题。请分析检索内容是否能解答提问、准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
|
||||
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = parser.parse(response.content)
|
||||
return result.is_correct, result.reason, result.confidence_score
|
||||
except Exception as e:
|
||||
logging.warning(f"验证检索内容时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}", 0.0
|
||||
|
||||
def validate_rewrite(self, llm, query, rewrite):
|
||||
"""
|
||||
验证问题改写是否正确
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
query: 原始问题
|
||||
rewrite: 问题改写
|
||||
|
||||
Returns:
|
||||
(bool, str, float): 是否正确,错误原因(如果有),置信度
|
||||
"""
|
||||
parser = self.create_validation_parser()
|
||||
format_instructions = parser.get_format_instructions()
|
||||
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了改写,请评估以下问题改写是否正确。
|
||||
|
||||
原始问题: {query}
|
||||
问题改写: {rewrite}
|
||||
|
||||
问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
|
||||
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = parser.parse(response.content)
|
||||
return result.is_correct, result.reason, result.confidence_score
|
||||
except Exception as e:
|
||||
logging.warning(f"验证问题改写时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}", 0.0
|
||||
|
||||
def validate_row(self, llm, row_data):
|
||||
"""
|
||||
按顺序验证一行数据中的各个环节
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
row_data: (index, row)元组
|
||||
|
||||
Returns:
|
||||
(index, is_all_correct, error_phase, error_reason, confidence_score): 行索引,是否全部正确,错误环节,错误原因,置信度
|
||||
"""
|
||||
index, row = row_data
|
||||
query = row["问题"]
|
||||
query_class = row.get("问题分类", "")
|
||||
rewrite = row.get("问题改写", "")
|
||||
slot_info = row.get("槽位信息", "")
|
||||
retrieve_content = row.get("检索的内容", "")
|
||||
|
||||
if self.debug:
|
||||
logging.info(f"开始验证行 {index}:")
|
||||
logging.info(f" 问题: {query}")
|
||||
logging.info(f" 问题分类: {query_class}")
|
||||
logging.info(f" 问题改写: {rewrite}")
|
||||
|
||||
try:
|
||||
|
||||
confidence_score = 0.0
|
||||
# 1. 验证问题改写
|
||||
if rewrite:
|
||||
if self.debug:
|
||||
logging.info(f" 验证问题改写...")
|
||||
|
||||
result = self.validate_rewrite(llm, query, rewrite)
|
||||
if isinstance(result, tuple) and len(result) >= 3:
|
||||
is_correct, error_reason, rewrite_confidence = result[:3]
|
||||
confidence_score = max(confidence_score, rewrite_confidence)
|
||||
|
||||
if self.debug:
|
||||
logging.info(f" 问题改写验证结果: {'通过' if is_correct else '不通过'}, 置信度: {rewrite_confidence:.2f}")
|
||||
if not is_correct:
|
||||
logging.info(f" 错误原因: {error_reason}")
|
||||
|
||||
if not is_correct:
|
||||
return index, False, "问题改写", error_reason, rewrite_confidence
|
||||
|
||||
# 2. 验证问题分类
|
||||
if query_class:
|
||||
if self.debug:
|
||||
logging.info(f" 验证问题分类...")
|
||||
|
||||
query_class_list = query_class.split(" - ")
|
||||
if len(query_class_list) >= 2:
|
||||
result = self.validate_classification(llm, rewrite, query_class_list[0], query_class_list[1])
|
||||
if isinstance(result, tuple) and len(result) >= 3:
|
||||
is_correct, error_reason, classification_confidence = result[:3]
|
||||
confidence_score = max(confidence_score, classification_confidence)
|
||||
|
||||
if self.debug:
|
||||
logging.info(f" 问题分类验证结果: {'通过' if is_correct else '不通过'}, 置信度: {classification_confidence:.2f}")
|
||||
if not is_correct:
|
||||
logging.info(f" 错误原因: {error_reason}")
|
||||
|
||||
if not is_correct:
|
||||
return index, False, "问题分类", error_reason, classification_confidence
|
||||
|
||||
|
||||
|
||||
# 3. 验证槽位填充
|
||||
if slot_info:
|
||||
if self.debug:
|
||||
logging.info(f" 验证槽位填充...")
|
||||
|
||||
result = self.validate_slot(llm, rewrite, slot_info, query_class_list[0], query_class_list[1])
|
||||
if isinstance(result, tuple) and len(result) >= 3:
|
||||
is_correct, error_reason, slot_confidence = result[:3]
|
||||
confidence_score = max(confidence_score, slot_confidence)
|
||||
|
||||
if self.debug:
|
||||
logging.info(f" 槽位填充验证结果: {'通过' if is_correct else '不通过'}, 置信度: {slot_confidence:.2f}")
|
||||
if not is_correct:
|
||||
logging.info(f" 错误原因: {error_reason}")
|
||||
|
||||
if not is_correct:
|
||||
return index, False, "槽位填充", error_reason, slot_confidence
|
||||
|
||||
# 4. 验证检索内容
|
||||
if retrieve_content and retrieve_content != "" and pd.notna(retrieve_content):
|
||||
if self.debug:
|
||||
logging.info(f" 验证检索内容...")
|
||||
|
||||
result = self.validate_retrieve_content(llm, query, retrieve_content)
|
||||
if isinstance(result, tuple) and len(result) >= 3:
|
||||
is_correct, error_reason, retrieve_confidence = result[:3]
|
||||
confidence_score = max(confidence_score, retrieve_confidence)
|
||||
|
||||
if self.debug:
|
||||
logging.info(f" 检索内容验证结果: {'通过' if is_correct else '不通过'}, 置信度: {retrieve_confidence:.2f}")
|
||||
if not is_correct:
|
||||
logging.info(f" 错误原因: {error_reason}")
|
||||
|
||||
if not is_correct:
|
||||
return index, False, "检索内容", error_reason, retrieve_confidence
|
||||
|
||||
if self.debug:
|
||||
logging.info(f" 行 {index} 验证完成: 通过, 总置信度: {confidence_score:.2f}")
|
||||
|
||||
return index, True, "", "", confidence_score
|
||||
except Exception as e:
|
||||
error_msg = f"处理行 {index} 时发生错误: {str(e)}"
|
||||
logging.error(error_msg, exc_info=True)
|
||||
return index, False, "处理错误", error_msg, 0.0
|
||||
|
||||
def create_llm_instances(self, count):
|
||||
"""创建多个LLM实例"""
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = "deepseek-ai/DeepSeek-R1"
|
||||
|
||||
llm_params = {"temperature": 0.7, "model": model_name}
|
||||
if api_key:
|
||||
llm_params["api_key"] = api_key
|
||||
if base_url:
|
||||
llm_params["base_url"] = base_url
|
||||
|
||||
return [OpenAiLLM(**llm_params) for _ in range(count)]
|
||||
|
||||
def validate(self, input_file=None, output_file=None, workers=None, debug=None):
|
||||
"""
|
||||
执行验证过程
|
||||
|
||||
Args:
|
||||
input_file: 输入Excel文件路径
|
||||
output_file: 输出结果Excel文件路径
|
||||
workers: 并行工作线程数
|
||||
batch_size: 每批处理的行数(已弃用,保留参数保持兼容)
|
||||
debug: 是否启用调试模式(串行处理)
|
||||
|
||||
Returns:
|
||||
验证后的DataFrame
|
||||
"""
|
||||
input_file = input_file or self.input_file
|
||||
output_file = output_file or self.output_file
|
||||
workers = workers or self.workers
|
||||
debug = debug if debug is not None else self.debug
|
||||
|
||||
# 读取数据
|
||||
df = self.load_data_from_excel(input_file)
|
||||
if df is None:
|
||||
return None
|
||||
|
||||
# 添加验证结果列
|
||||
df["验证结果"] = ""
|
||||
df["错误环节"] = ""
|
||||
df["错误原因"] = ""
|
||||
df["置信度"] = 0.0
|
||||
|
||||
# 准备数据
|
||||
all_rows = list(df.iterrows())
|
||||
|
||||
# 创建LLM实例
|
||||
llm = self.create_llm_instances(1)[0]
|
||||
|
||||
# 根据模式选择处理方式
|
||||
all_results = []
|
||||
if debug:
|
||||
# 调试模式:串行处理
|
||||
logging.info("启用调试模式,使用串行处理...")
|
||||
for i, row_data in enumerate(all_rows):
|
||||
logging.info(f"处理第 {i+1}/{len(all_rows)} 行...")
|
||||
result = self.validate_row(llm, row_data)
|
||||
all_results.append(result)
|
||||
# 实时更新DataFrame
|
||||
index, is_correct, error_phase, error_reason, confidence_score = result
|
||||
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
|
||||
df.at[index, "错误环节"] = error_phase
|
||||
df.at[index, "错误原因"] = error_reason
|
||||
df.at[index, "置信度"] = confidence_score
|
||||
# 输出当前结果
|
||||
logging.info(f"行 {index} 验证结果: {'通过' if is_correct else '不通过'}, 错误环节: {error_phase}, 错误原因: {error_reason}, 置信度: {confidence_score:.2f}")
|
||||
else:
|
||||
# 正常模式:并行处理,每行单独处理
|
||||
llm_instances = self.create_llm_instances(min(workers, len(all_rows)))
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
# 为每行分配一个LLM实例
|
||||
future_to_row = {
|
||||
executor.submit(self.validate_row, llm_instances[i % len(llm_instances)], row_data):
|
||||
i for i, row_data in enumerate(all_rows)
|
||||
}
|
||||
|
||||
# 使用tqdm显示进度条
|
||||
for future in tqdm(concurrent.futures.as_completed(future_to_row), total=len(all_rows), desc="处理进度"):
|
||||
result = future.result()
|
||||
all_results.append(result)
|
||||
|
||||
# 按行索引排序结果,确保与原始数据顺序一致
|
||||
all_results.sort(key=lambda x: x[0])
|
||||
|
||||
# 将结果填充到DataFrame
|
||||
for result in all_results:
|
||||
if len(result) >= 5:
|
||||
index, is_correct, error_phase, error_reason, confidence_score = result
|
||||
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
|
||||
df.at[index, "错误环节"] = error_phase
|
||||
df.at[index, "错误原因"] = error_reason
|
||||
df.at[index, "置信度"] = confidence_score
|
||||
else:
|
||||
index, is_correct, error_phase, error_reason = result
|
||||
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
|
||||
df.at[index, "错误环节"] = error_phase
|
||||
df.at[index, "错误原因"] = error_reason
|
||||
|
||||
# 保存结果
|
||||
if output_file is None:
|
||||
output_file = os.path.join(
|
||||
os.path.dirname(input_file),
|
||||
f"validated_{os.path.basename(input_file)}"
|
||||
)
|
||||
df.to_excel(output_file, index=False)
|
||||
logging.info(f"验证完成,结果已保存至: {output_file}")
|
||||
|
||||
# 输出统计信息
|
||||
self.print_statistics(df)
|
||||
|
||||
return df
|
||||
|
||||
def print_statistics(self, df):
|
||||
"""打印统计信息"""
|
||||
total = len(df)
|
||||
passed = len(df[df["验证结果"] == "通过"])
|
||||
error_stats = df[df["验证结果"] == "不通过"]["错误环节"].value_counts()
|
||||
|
||||
logging.info(f"统计信息: 总计 {total} 条, 通过 {passed} 条, 通过率 {passed/total*100:.2f}%")
|
||||
logging.info("错误环节统计:")
|
||||
for phase, count in error_stats.items():
|
||||
logging.info(f"- {phase}: {count} 条")
|
||||
|
||||
def create_validation_parser(self):
|
||||
"""创建验证结果解析器"""
|
||||
return PydanticOutputParser(pydantic_object=ValidationResult)
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
# 解析命令行参数
|
||||
input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "1500条点踩软件问题测试_意图分类.xlsx")
|
||||
output_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "自动验证_问题分类重写结果.xlsx")
|
||||
|
||||
parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写")
|
||||
parser.add_argument("--input", "-i", type=str, help="输入Excel文件路径", default=input_excel)
|
||||
parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
|
||||
parser.add_argument("--workers", "-w", type=int, default=20, help="并行工作线程数")
|
||||
args = parser.parse_args()
|
||||
logging.info(f"输入文件路径: {args.input}, 输出文件路径: {args.output}, 并行工作线程数: {args.workers}")
|
||||
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||
|
||||
# 创建验证器实例并执行验证
|
||||
validator = ExcelDataValidator(
|
||||
input_file=args.input,
|
||||
output_file=args.output,
|
||||
workers=args.workers,
|
||||
debug=is_debug
|
||||
)
|
||||
validator.validate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -3,19 +3,15 @@ import json
|
||||
|
||||
from regex import search
|
||||
|
||||
import ijson
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.dify.dify_tool import DifyTool
|
||||
|
||||
df = pd.read_excel("data/excel/已分析数据汇总(第一轮).xlsx")
|
||||
df=df[df["评价"]=="dislike"]
|
||||
dify_tool = DifyTool()
|
||||
|
||||
df = pd.read_excel("data/excel/0714提问数据汇总(已分析)_软件.xlsx")
|
||||
|
||||
msg_id_list = df["msg_id"].tolist()
|
||||
msg_debug_list = {}
|
||||
# 流式解析 JSON 数组
|
||||
with open("data/excel/msg_debug_list.json", "r", encoding="utf-8") as f:
|
||||
# 使用ijson.items直接获取顶层键值对
|
||||
for msg_id, data in ijson.kvitems(f, ''):
|
||||
if msg_id in msg_id_list:
|
||||
msg_debug_list[msg_id] = data
|
||||
|
||||
def get_rewrite_query(intent_node_execution_info)->str:
|
||||
outputs_result =json.loads(intent_node_execution_info['outputs'])
|
||||
@@ -28,7 +24,7 @@ def judge_error_node_and_reason(intent_node_execution_info, knowledge_filter_nod
|
||||
|
||||
outputs_result =json.loads(intent_node_execution_info['outputs'])
|
||||
result["问题改写结果"] = outputs_result['optimize_query']
|
||||
if outputs_result['is_complete'] == False:
|
||||
if outputs_result['is_complete'] == False and outputs_result["has_slot_filling"] == True:
|
||||
result["错误环节"] = "槽点填充"
|
||||
result["错误原因"] = f"槽点缺失"
|
||||
result["具体描述"] = f"缺失内容:{outputs_result['missing_slots']}"
|
||||
@@ -80,6 +76,8 @@ for index, row in df.iterrows():
|
||||
answer = row["回答"]
|
||||
query = row["提问"]
|
||||
rating = row["评价"]
|
||||
if rating != "dislike":
|
||||
continue
|
||||
class_type = row["问题分类"]
|
||||
dislike_reason = row["点踩原因"]
|
||||
if dislike_reason is None or pd.isna(dislike_reason):
|
||||
@@ -87,7 +85,8 @@ for index, row in df.iterrows():
|
||||
|
||||
answer_wiki_name = row["关联词条"]
|
||||
search_wiki = row["检索到的词条"]
|
||||
node_executions_info = msg_debug_list[msg_id]
|
||||
msg_debug_info = dify_tool.get_message_debug_info_by_id(msg_id)
|
||||
node_executions_info = msg_debug_info["workflow_node_executions_info"]
|
||||
intent_node_execution_info = [node_execution_info for node_execution_info in node_executions_info
|
||||
if node_execution_info["title"] == "意图识别结果解析"]
|
||||
|
||||
@@ -109,7 +108,7 @@ for index, row in df.iterrows():
|
||||
print(f"msg_id: {msg_id} 处理失败: {e}")
|
||||
continue
|
||||
|
||||
df.to_excel("data/excel/已分析数据汇总(第一轮)_分析.xlsx", index=False)
|
||||
df.to_excel("data/excel/0714提问数据汇总(已分析)_软件_分析.xlsx", index=False)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -84,15 +84,14 @@ async def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/query_type", summary="异步检索API")
|
||||
async def query_type(query: str, query_type: str, workflow_run_id:str):
|
||||
async def query_type(query_type: str, workflow_run_id:str):
|
||||
try:
|
||||
# 记录请求
|
||||
logger.info(f"接收到请求: {query}, 类型: {query_type}, workflow_run_id: {workflow_run_id}")
|
||||
logger.info(f"接收到请求: 类型: {query_type}, workflow_run_id: {workflow_run_id}")
|
||||
|
||||
# 保存 提问、问题类型、当前时间戳到json
|
||||
timestamp = datetime.datetime.now().isoformat()
|
||||
query_data = {
|
||||
"query": query,
|
||||
"query_type": query_type,
|
||||
"timestamp": timestamp,
|
||||
"workflow_run_id": workflow_run_id
|
||||
@@ -127,7 +126,7 @@ async def query_type(query: str, query_type: str, workflow_run_id:str):
|
||||
logger.error(f"保存查询数据时出错: {str(e)}", exc_info=True)
|
||||
|
||||
# 返回响应
|
||||
content = f"<strong>当前提问</strong>: {query}<br><strong>问题类型</strong>: {query_type}<br><strong>操作是否成功</strong>: {'成功' if success else '失败'}"
|
||||
content = f"<strong>问题类型</strong>: {query_type}<br><strong>操作是否成功</strong>: {'成功' if success else '失败'}"
|
||||
return HTMLResponse(content=content)
|
||||
except Exception as e:
|
||||
logger.error(f"处理请求时出错: {str(e)}", exc_info=True)
|
||||
|
||||
@@ -84,7 +84,7 @@ class DifyComparisonTester:
|
||||
def get_llm(self, **kwargs):
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model = os.getenv("LLM_MODEL_NAME")
|
||||
model = os.getenv("MODEL_NAME")
|
||||
return OpenAiLLM(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
|
||||
def find_wiki_link(self, row) -> str | None:
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from rag2_0.dify.dify_client import DifyApi
|
||||
|
||||
soft_name_map = {
|
||||
"配网造价软件知识(new)": "配网计价通D3软件",
|
||||
"西藏造价软件知识(new)": "西藏计价通Z1软件",
|
||||
"储能C1计价通软件知识(new)": "储能计价通C1软件",
|
||||
"技改检修工程计价通T1软件知识(new)": "技改检修工程计价通T1软件",
|
||||
"技改检修清单计价通T1软件知识(new)": "技改检修清单计价通T1软件",
|
||||
"电力建设计价通(2018)软件知识(new)": "电力建设计价通软件",
|
||||
"下载安装注册(new)": "下载安装注册",
|
||||
}
|
||||
|
||||
soft_wiki_file_name = {
|
||||
"配网计价通D3软件": ["配网计价通D3软件.txt", []],
|
||||
"西藏计价通Z1软件": ["西藏计价通Z1软件.txt", []],
|
||||
"储能计价通C1软件": ["储能计价通C1软件.txt", []],
|
||||
"技改检修工程计价通T1软件": ["技改检修工程计价通T1软件.txt", []],
|
||||
"技改检修清单计价通T1软件": ["技改检修清单计价通T1软件.txt", []],
|
||||
"电力建设计价通软件": ["电力建设计价通软件.txt", []],
|
||||
"下载安装注册": ["下载安装注册.txt", []],
|
||||
}
|
||||
|
||||
def get_soft_wiki_titles(dify_api, soft_name_map, soft_wiki_file_name):
|
||||
"""获取每个软件的wiki标题列表"""
|
||||
dataset_list = dify_api.get_all_dataset_list()
|
||||
soft_name_map_keys = list(soft_name_map.keys())
|
||||
for dataset in dataset_list:
|
||||
if dataset["name"] not in soft_name_map_keys:
|
||||
continue
|
||||
dataset_name = dataset["name"]
|
||||
dataset_id = dataset["id"]
|
||||
documents = dify_api.get_documents(dataset_id=dataset_id)
|
||||
for document_id, doc_info in documents.items():
|
||||
document_name = doc_info["name"]
|
||||
wiki_name = document_name.split("/")[-1]
|
||||
wiki_title = re.sub(r'^(.*?)|^\(.*?\)', '', wiki_name)
|
||||
if wiki_title not in soft_wiki_file_name[soft_name_map[dataset_name]][1]:
|
||||
soft_wiki_file_name[soft_name_map[dataset_name]][1].append(wiki_title)
|
||||
return soft_wiki_file_name
|
||||
|
||||
def save_wiki_titles(soft_wiki_file_name, output_dir="data/wiki_data"):
|
||||
"""将wiki标题列表保存到对应txt文件"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
for soft_name, (txt_file_name, wiki_titles) in soft_wiki_file_name.items():
|
||||
output_path = os.path.join(output_dir, txt_file_name)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for title in wiki_titles:
|
||||
f.write(title + "\n")
|
||||
print(f"已保存 {soft_name} 的wiki标题列表到 {output_path},共 {len(wiki_titles)} 条")
|
||||
|
||||
def main():
|
||||
dify_api = DifyApi()
|
||||
wiki_titles = get_soft_wiki_titles(dify_api, soft_name_map, soft_wiki_file_name)
|
||||
save_wiki_titles(wiki_titles)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,151 +0,0 @@
|
||||
from rag2_0.dify.dify_tool import NewWorkflowChat
|
||||
import pandas as pd
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
class ChatDifyByWorkorder:
|
||||
|
||||
def __init__(self, api_key=None, base_url="https://api.dify.ai/v1") -> None:
|
||||
"""
|
||||
初始化ChatDifyByWorkorder类
|
||||
|
||||
Args:
|
||||
api_key: Dify API密钥,默认为None
|
||||
base_url: Dify API的基础URL,默认为"https://api.dify.ai/v1"
|
||||
"""
|
||||
baseurl = "http://172.20.0.145/v1"
|
||||
new_workflow_api_key = "app-qxsSybCs7ABiKlC1JabTYVn6"
|
||||
self.new_chat = NewWorkflowChat(api_key=new_workflow_api_key, base_url=baseurl)
|
||||
self.new_chat_answer = NewWorkflowChat(api_key=new_workflow_api_key, base_url=baseurl)
|
||||
|
||||
|
||||
def get_soft_name(self, row) -> str:
|
||||
if "博微配网计价通D3" in row["产品线"]:
|
||||
return "博微配网计价通D3"
|
||||
elif "博微电力建设计价通软件" in row["产品线"]:
|
||||
return "电力建设计价通软件"
|
||||
elif "新能源系列" in row["产品线"] and "博微新型储能电站建设计价通C1软件" in row["产品名称"]:
|
||||
return "储能C1软件"
|
||||
elif "博微西藏计价通Z1" in row["产品线"]:
|
||||
return "西藏计价通Z1"
|
||||
elif "博微技改检修计价通T1软件" in row["产品线"] and "技改检修计价通T1软件-概预算" in row["产品名称"]:
|
||||
return "技改检修工程计价通T1"
|
||||
elif "博微技改检修计价通T1软件" in row["产品线"] and "技改检修计价通T1软件-清单" in row["产品名称"]:
|
||||
return "检修清单计价通T1"
|
||||
return ""
|
||||
|
||||
def process_query(self, q:str) -> dict:
|
||||
"""
|
||||
发送问题并获取回答及相关工作流信息
|
||||
|
||||
Args:
|
||||
q: 用户问题
|
||||
|
||||
Returns:
|
||||
dict: 包含问题、回答和工作流信息的字典
|
||||
"""
|
||||
retry_count = 0
|
||||
max_retries = 2
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# 发送问题获取回答和消息ID
|
||||
result = self.new_chat.process_question(q)
|
||||
return result
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count <= max_retries:
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
|
||||
def process_answer(self, q:str) -> dict:
|
||||
"""
|
||||
发送问题并获取回答及相关工作流信息
|
||||
|
||||
Args:
|
||||
q: 用户问题
|
||||
|
||||
Returns:
|
||||
dict: 包含问题、回答和工作流信息的字典
|
||||
"""
|
||||
retry_count = 0
|
||||
max_retries = 2
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# 发送问题获取回答和消息ID
|
||||
result = self.new_chat_answer.process_question(q)
|
||||
return result
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count <= max_retries:
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
def process_row(self, row):
|
||||
"""处理单行数据"""
|
||||
soft_name = self.get_soft_name(row=row)
|
||||
if soft_name == "":
|
||||
return None
|
||||
|
||||
# 使用线程池并发执行查询
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
# 提交两个任务并获取Future对象
|
||||
query_future = executor.submit(self.process_query, q=f"{soft_name},{row['客户问题']}")
|
||||
answer_future = executor.submit(self.process_answer, q=f"{soft_name},{row['解决方案']}")
|
||||
|
||||
# 获取结果
|
||||
query_result = query_future.result()
|
||||
answer_result = answer_future.result()
|
||||
except Exception as e:
|
||||
print(f"处理工单 {row.get('工单编号', '未知')} 时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
worker_id = str(row["工单编号"])
|
||||
if query_result is None or answer_result is None:
|
||||
print("处理对话出现错误")
|
||||
return None
|
||||
|
||||
worker_order_info = {
|
||||
"工单编号": worker_id,
|
||||
"用户问题": row['客户问题'],
|
||||
"解决方案": row['解决方案'],
|
||||
"AI回答": query_result["新流程答案"],
|
||||
"用户问题检索到的词条": query_result["新检索词条"],
|
||||
"解决方案检索到的词条": answer_result["新检索词条"],
|
||||
}
|
||||
return worker_order_info
|
||||
|
||||
def run(self, excel_path:str):
|
||||
df_data = pd.read_excel(excel_path)
|
||||
list_worker_order_info = []
|
||||
|
||||
# 创建进度条
|
||||
with tqdm(total=len(df_data), desc="处理工单") as pbar:
|
||||
# 创建线程池,最大并发数可以根据需要调整
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
# 提交所有任务
|
||||
future_to_row = {executor.submit(self.process_row, row): idx for idx, row in df_data.iterrows()}
|
||||
|
||||
# 处理完成的任务
|
||||
for future in concurrent.futures.as_completed(future_to_row):
|
||||
result = future.result()
|
||||
if result is not None:
|
||||
list_worker_order_info.append(result)
|
||||
pbar.update(1)
|
||||
|
||||
return list_worker_order_info
|
||||
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
worker_chat = ChatDifyByWorkorder()
|
||||
result = worker_chat.run(excel_path="data/excel/工单记录_均衡提取2000条.xlsx")
|
||||
# 可以选择保存结果到Excel
|
||||
if result:
|
||||
pd.DataFrame(result).to_excel("data/excel/工单处理结果.xlsx", index=False)
|
||||
@@ -1,4 +1,5 @@
|
||||
|
||||
__all__ = ["ChatClient", "CompletionClient", "DifyClient"]
|
||||
__all__ = ["ChatClient", "CompletionClient", "DifyClient", "DifyApi"]
|
||||
|
||||
from .client import ChatClient, CompletionClient, DifyClient
|
||||
from .dify_api import DifyApi
|
||||
|
||||
@@ -14,12 +14,12 @@ class DifyApi:
|
||||
用于与Dify API进行交互的类。
|
||||
"""
|
||||
|
||||
def __init__(self, dify_url: str="http://10.1.16.39/v1",
|
||||
dify_dataset_api_key: str="dataset-skLjmPVonjHo119OWNf3kAmY",
|
||||
dify_app_api_key: str="app-wUdkWJx5zeOvmvBUZizMoSw3"):
|
||||
self.dify_url = dify_url
|
||||
self.dify_dataset_api_key = dify_dataset_api_key
|
||||
self.dify_app_api_key = dify_app_api_key
|
||||
def __init__(self, dify_url: str=None,
|
||||
dify_dataset_api_key: str=None,
|
||||
dify_app_api_key: str=None):
|
||||
self.dify_url = dify_url if dify_url else os.environ.get('DIFY_BSAE_URL')
|
||||
self.dify_dataset_api_key = dify_dataset_api_key if dify_dataset_api_key else os.environ.get('DIFY_DATASET_KEY')
|
||||
self.dify_app_api_key = dify_app_api_key if dify_app_api_key else os.environ.get('DIFY_APP_KEY')
|
||||
|
||||
def get_document_indexing_status(self, datasets_id: str, batch: str) -> bool:
|
||||
"""
|
||||
|
||||
@@ -449,7 +449,7 @@ content: "{content}"
|
||||
"""
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model = os.getenv("LLM_MODEL_NAME")
|
||||
model = os.getenv("MODEL_NAME")
|
||||
llm = OpenAiLLM(api_key=api_key, base_url=base_url, model=model)
|
||||
response = llm.invoke(user_prompt=prompt, need_retry=True)
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
|
||||
# 定义请求模型
|
||||
class IntentRecognizeRequest(BaseModel):
|
||||
query: str
|
||||
conversation_context: str = ""
|
||||
conversation_context: Dict = None
|
||||
chat_history: Optional[List] = None
|
||||
previous_slots: str | Dict = None
|
||||
|
||||
@@ -89,13 +89,15 @@ _instance = None
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global _instance
|
||||
# 初始化AsyncIntentRecognizer实例
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
_instance = await AsyncIntentRecognizer.create(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
_instance = await AsyncIntentRecognizer.create()
|
||||
logger.info("AsyncIntentRecognizer初始化完成")
|
||||
|
||||
@app.post("/intent_recognize1")
|
||||
async def intent_recognize(request: Request):
|
||||
data = await request.json()
|
||||
print(data)
|
||||
return {"message": "success"}
|
||||
|
||||
@app.post("/intent_recognize", response_model=IntentRecognizeResponse, summary="意图识别", description="识别用户查询的意图并进行问题改写")
|
||||
async def intent_recognize(request: IntentRecognizeRequest):
|
||||
try:
|
||||
@@ -103,14 +105,15 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
||||
raise HTTPException(status_code=400, detail="缺少query参数")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
current_softname = request.conversation_context.get("current_softname", "")
|
||||
result = await _instance.process_query_async(
|
||||
query=request.query,
|
||||
conversation_context=request.conversation_context,
|
||||
chat_history=request.chat_history,
|
||||
previous_slots=request.previous_slots,
|
||||
use_jieba=True,
|
||||
enable_query_expansion=True
|
||||
enable_query_expansion=True,
|
||||
cur_soft_name=current_softname
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
import pandas as pd
|
||||
import random
|
||||
import math
|
||||
|
||||
work_order_excel="data/excel/6万工单记录.xlsx"
|
||||
|
||||
soft_row_data={
|
||||
"博微配网计价通D3":{"基本功能":[], "高级功能":[]},
|
||||
"储能C1软件":{"基本功能":[], "高级功能":[]},
|
||||
"西藏计价通Z1":{"基本功能":[], "高级功能":[]},
|
||||
"技改检修工程计价通T1":{"基本功能":[], "高级功能":[]},
|
||||
"检修清单计价通T1":{"基本功能":[], "高级功能":[]},
|
||||
"电力建设计价通软件":{"基本功能":[], "高级功能":[]},
|
||||
}
|
||||
|
||||
df = pd.read_excel(work_order_excel)
|
||||
|
||||
for idx, row in df.iterrows():
|
||||
if pd.isna(row["产品线"]):
|
||||
continue
|
||||
|
||||
if "博微配网计价通D3" in row["产品线"]:
|
||||
soft_row_data["博微配网计价通D3"][row["问题类型"]].append((idx, row))
|
||||
elif "博微电力建设计价通软件" in row["产品线"]:
|
||||
soft_row_data["电力建设计价通软件"][row["问题类型"]].append((idx, row))
|
||||
elif "新能源系列" in row["产品线"] and "博微新型储能电站建设计价通C1软件" in row["产品名称"]:
|
||||
soft_row_data["储能C1软件"][row["问题类型"]].append((idx, row))
|
||||
elif "博微西藏计价通Z1" in row["产品线"]:
|
||||
soft_row_data["西藏计价通Z1"][row["问题类型"]].append((idx, row))
|
||||
elif "博微技改检修计价通T1软件" in row["产品线"] and "技改检修计价通T1软件-概预算" in row["产品名称"]:
|
||||
soft_row_data["技改检修工程计价通T1"][row["问题类型"]].append((idx, row))
|
||||
elif "博微技改检修计价通T1软件" in row["产品线"] and "技改检修计价通T1软件-清单" in row["产品名称"]:
|
||||
soft_row_data["检修清单计价通T1"][row["问题类型"]].append((idx, row))
|
||||
|
||||
# 计算每个软件和功能类型的数据量
|
||||
total_count = 0
|
||||
counts = {}
|
||||
for software, types in soft_row_data.items():
|
||||
counts[software] = {}
|
||||
for type_name, rows in types.items():
|
||||
counts[software][type_name] = len(rows)
|
||||
total_count += len(rows)
|
||||
|
||||
print(f"原始数据总量: {total_count}条")
|
||||
for software, types in counts.items():
|
||||
print(f"{software}: 基本功能 {types['基本功能']}条, 高级功能 {types['高级功能']}条")
|
||||
|
||||
# 计算均衡提取的数量
|
||||
total_target = 2000
|
||||
categories_count = sum(len(types) for types in soft_row_data.values())
|
||||
per_category_target = math.ceil(total_target / categories_count)
|
||||
|
||||
# 均衡提取数据
|
||||
balanced_data = []
|
||||
extracted_counts = {}
|
||||
extracted_indices = set() # 使用集合存储已提取数据的索引
|
||||
|
||||
for software, types in soft_row_data.items():
|
||||
extracted_counts[software] = {}
|
||||
|
||||
for type_name, rows in types.items():
|
||||
# 如果数据量不足,全部提取;否则随机抽取目标数量
|
||||
if len(rows) <= per_category_target:
|
||||
extracted = rows
|
||||
else:
|
||||
extracted = random.sample(rows, per_category_target)
|
||||
|
||||
extracted_counts[software][type_name] = len(extracted)
|
||||
for idx, row in extracted:
|
||||
extracted_indices.add(idx) # 记录已提取数据的索引
|
||||
balanced_data.append(row)
|
||||
|
||||
# 数据量不足2000时,从剩余数据中补充
|
||||
remaining_target = total_target - len(balanced_data)
|
||||
if remaining_target > 0:
|
||||
# 收集所有未被选中的数据
|
||||
remaining_data = []
|
||||
for software, types in soft_row_data.items():
|
||||
for type_name, rows in types.items():
|
||||
# 添加未被选中的数据
|
||||
for idx, row in rows:
|
||||
if idx not in extracted_indices:
|
||||
remaining_data.append(row)
|
||||
|
||||
# 如果剩余数据足够,随机抽取补充
|
||||
if len(remaining_data) >= remaining_target:
|
||||
additional_data = random.sample(remaining_data, remaining_target)
|
||||
else:
|
||||
additional_data = remaining_data
|
||||
|
||||
balanced_data.extend(additional_data)
|
||||
|
||||
# 输出结果
|
||||
print(f"\n均衡提取后数据总量: {len(balanced_data)}条")
|
||||
for software, types in extracted_counts.items():
|
||||
print(f"{software}: 基本功能 {types['基本功能']}条, 高级功能 {types['高级功能']}条")
|
||||
|
||||
# 将均衡提取的数据转换为DataFrame并保存
|
||||
balanced_df = pd.DataFrame(balanced_data)
|
||||
balanced_df.to_excel("data/excel/均衡提取2000条工单.xlsx", index=False)
|
||||
print(f"\n已将均衡提取的{len(balanced_data)}条数据保存至'data/excel/均衡提取2000条工单.xlsx'")
|
||||
@@ -39,10 +39,21 @@ from .ProfessionalNounVector import ProfessionalNounRetriever, AsyncProfessional
|
||||
from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM, SiliconFlowReRankerModel
|
||||
|
||||
class AsyncIntentRecognizer:
|
||||
SOFT_WIKI_PATH = "data/wiki_data"
|
||||
SOFT_NAMETOWIKI_MAP = {
|
||||
"配网计价通D3软件": "配网计价通D3软件.txt",
|
||||
"西藏计价通Z1软件": "西藏计价通Z1软件.txt",
|
||||
"储能计价通C1软件": "储能计价通C1软件.txt",
|
||||
"技改检修工程计价通T1软件": "技改检修工程计价通T1软件.txt",
|
||||
"技改检修清单计价通T1软件": "技改检修清单计价通T1软件.txt",
|
||||
"电力建设计价通软件": "电力建设计价通软件.txt",
|
||||
"下载安装注册": "下载安装注册.txt",
|
||||
}
|
||||
|
||||
"""
|
||||
异步意图识别和问题改写类
|
||||
"""
|
||||
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo", vector_index_dir: str = None):
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化异步意图识别器
|
||||
|
||||
@@ -52,51 +63,53 @@ class AsyncIntentRecognizer:
|
||||
model_name: 要使用的模型名称
|
||||
vector_index_dir: 向量索引目录,如果为None则使用默认目录
|
||||
"""
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
||||
# 初始化LLM
|
||||
llm_params = {
|
||||
"temperature": 0.2, # 降低随机性,使结果更确定
|
||||
"top_p": 0.7,
|
||||
"model": model_name
|
||||
"model": model_name,
|
||||
"api_key": api_key,
|
||||
"base_url": base_url
|
||||
}
|
||||
|
||||
# 如果提供了API密钥,则使用提供的密钥
|
||||
if api_key:
|
||||
llm_params["api_key"] = api_key
|
||||
|
||||
# 如果提供了自定义URL,则使用提供的URL
|
||||
if base_url:
|
||||
llm_params["base_url"] = base_url
|
||||
|
||||
self._llm = OpenAiLLM(**llm_params)
|
||||
llm_params["model"] = os.getenv("MINI_MODEL_NAME", "gpt-3.5-turbo")
|
||||
self._llm_mini = OpenAiLLM(**llm_params)
|
||||
|
||||
# 加载suffix关键词
|
||||
self._suffix_keywords = self._load_suffix_keywords()
|
||||
|
||||
# 加载软件词条名称库
|
||||
self._soft_wiki_library = self._load_soft_wiki_library()
|
||||
# 异步检索器将在create方法中初始化
|
||||
self._noun_retriever = None
|
||||
self._api_key = api_key
|
||||
self._vector_index_dir = vector_index_dir
|
||||
|
||||
def _load_soft_wiki_library(self):
|
||||
"""
|
||||
加载软件wiki库
|
||||
"""
|
||||
SOFT_WIKI_LIBRARY = {}
|
||||
for soft_name, wiki_file_name in self.SOFT_NAMETOWIKI_MAP.items():
|
||||
with open(f"{self.SOFT_WIKI_PATH}/{wiki_file_name}", "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
# 去除空行
|
||||
lines = [line.strip() for line in lines if line.strip()]
|
||||
SOFT_WIKI_LIBRARY[soft_name] = lines
|
||||
return SOFT_WIKI_LIBRARY
|
||||
|
||||
@classmethod
|
||||
async def create(cls, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo", vector_index_dir: str = None):
|
||||
async def create(cls):
|
||||
"""
|
||||
异步工厂方法:创建并初始化异步意图识别器实例
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API密钥,如果为None则从环境变量获取
|
||||
base_url: OpenAI API基础URL,如果为None则使用默认URL
|
||||
model_name: 要使用的模型名称
|
||||
vector_index_dir: 向量索引目录,如果为None则使用默认目录
|
||||
|
||||
Returns:
|
||||
初始化完成的AsyncIntentRecognizer实例
|
||||
"""
|
||||
instance = cls(api_key, base_url, model_name, vector_index_dir)
|
||||
instance = cls()
|
||||
# 异步初始化名词检索器
|
||||
instance._noun_retriever = await AsyncProfessionalNounRetriever.create(
|
||||
api_key=api_key,
|
||||
index_dir=vector_index_dir
|
||||
)
|
||||
instance._noun_retriever = await AsyncProfessionalNounRetriever.create()
|
||||
return instance
|
||||
|
||||
def _load_suffix_keywords(self, filepath: str = None) -> List[str]:
|
||||
@@ -402,11 +415,12 @@ class AsyncIntentRecognizer:
|
||||
return f"通过博微软件助手查询软件锁信息,锁注册号为{lock_number}"
|
||||
|
||||
|
||||
async def process_query_async(self, query: str, conversation_context: str = "",
|
||||
async def process_query_async(self, query: str, conversation_context: Dict = None,
|
||||
chat_history: List[Dict[str, str]] = None,
|
||||
previous_slots: Dict[str, Any] = None,
|
||||
use_jieba: bool = False,
|
||||
enable_query_expansion: bool = False) -> Dict[str, Any]:
|
||||
enable_query_expansion: bool = False,
|
||||
cur_soft_name: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
异步处理用户问题的完整流程
|
||||
|
||||
@@ -417,7 +431,7 @@ class AsyncIntentRecognizer:
|
||||
previous_slots: 历史槽位信息
|
||||
use_jieba: 是否使用jieba分词辅助提取关键词
|
||||
enable_query_expansion: 是否启用查询扩展
|
||||
|
||||
cur_soft_name: 当前查询的软件名称
|
||||
Returns:
|
||||
包含分类、关键词、改写和槽位填充结果的字典
|
||||
"""
|
||||
@@ -425,7 +439,8 @@ class AsyncIntentRecognizer:
|
||||
chat_history = []
|
||||
if previous_slots is None:
|
||||
previous_slots = {}
|
||||
|
||||
if conversation_context is None:
|
||||
conversation_context = {}
|
||||
# 步骤: 并行执行提问扩展
|
||||
query_expand_tasks = []
|
||||
if enable_query_expansion:
|
||||
@@ -437,9 +452,9 @@ class AsyncIntentRecognizer:
|
||||
# 5.2: Follow Up Questions
|
||||
asyncio.create_task(self._generate_follow_up_questions_async(query, chat_history, conversation_context)),
|
||||
|
||||
# 5.3: HyDE
|
||||
# asyncio.create_task(self._generate_hypothetical_document_async(query, chat_history, conversation_context)),
|
||||
|
||||
# 5.3: 文档查询
|
||||
asyncio.create_task(self._find_matching_software_docs_async(query, cur_soft_name, chat_history)),
|
||||
|
||||
# 5.4: 多问题查询
|
||||
asyncio.create_task(self._generate_multi_questions_async(query, chat_history, conversation_context))
|
||||
]
|
||||
@@ -497,23 +512,22 @@ class AsyncIntentRecognizer:
|
||||
# 收集结果
|
||||
step_back_result = query_expand_results[0] if query_expand_results[0] else StepBackPrompt(original_query=query, can_use_back_prompt=False, step_back_query=[query])
|
||||
follow_up_result = query_expand_results[1] if query_expand_results[1] else FollowUpQuestions(original_query=query, follow_up_query=query)
|
||||
# hyde_result = query_expand_results[2] if query_expand_results[2] else HypotheticalDocument(original_query=query, hypothetical_answer="")
|
||||
multi_questions_result = query_expand_results[2] if query_expand_results[2] else MultiQuestions(original_query=query, sub_questions=[query])
|
||||
wiki_result = query_expand_results[2] if query_expand_results[2] else []
|
||||
multi_questions_result = query_expand_results[3] if query_expand_results[3] else MultiQuestions(original_query=query, sub_questions=[query])
|
||||
|
||||
all_questions = multi_questions_result.sub_questions
|
||||
all_questions.append(query)
|
||||
all_questions.append(rewrite.rewrite)
|
||||
all_questions.extend(step_back_result.step_back_query)
|
||||
all_questions.append(follow_up_result.follow_up_query)
|
||||
# all_questions.append(hyde_result.hypothetical_answer)
|
||||
all_questions.extend(wiki_result)
|
||||
all_questions = list(set(all_questions))
|
||||
|
||||
query_expand = {
|
||||
"all": all_questions,
|
||||
"step_back": step_back_result.model_dump(),
|
||||
"follow_up": follow_up_result.model_dump(),
|
||||
# "hyde": hyde_result.model_dump(),
|
||||
"multi_questions": multi_questions_result.model_dump()
|
||||
"multi_questions": multi_questions_result.model_dump(),
|
||||
}
|
||||
|
||||
# 返回所有结果
|
||||
@@ -721,45 +735,72 @@ class AsyncIntentRecognizer:
|
||||
logging.error(f"异步后续问题生成失败: {e}", exc_info=True)
|
||||
return FollowUpQuestions(original_query=query, follow_up_query=query)
|
||||
|
||||
async def _generate_hypothetical_document_async(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> HypotheticalDocument:
|
||||
async def _find_matching_software_docs_async(self, query: str, soft_name: str,
|
||||
chat_history: List[Dict[str, str]] = None,
|
||||
top_k: int = 3) -> List[str]:
|
||||
"""
|
||||
异步生成假设性文档
|
||||
异步查找软件文档中与用户问题最匹配的几行内容
|
||||
|
||||
Args:
|
||||
query: 用户原始问题
|
||||
query: 用户问题
|
||||
soft_name: 软件名称
|
||||
chat_history: 历史对话记录
|
||||
conversation_context: 会话背景信息
|
||||
top_k: 返回的匹配行数,默认为3
|
||||
|
||||
Returns:
|
||||
假设性文档结果
|
||||
匹配的文档行列表
|
||||
"""
|
||||
if chat_history is None:
|
||||
chat_history = []
|
||||
|
||||
# 检查软件名称是否在支持的列表中
|
||||
if soft_name not in self.SOFT_NAMETOWIKI_MAP:
|
||||
return []
|
||||
|
||||
# 获取软件文档内容
|
||||
soft_docs = self._soft_wiki_library.get(soft_name, [])
|
||||
if not soft_docs:
|
||||
return []
|
||||
soft_docs.extend(self._soft_wiki_library.get("下载安装注册", []))
|
||||
# soft_docs=soft_docs[:50]
|
||||
# 构建文档字符串,只包含行内容
|
||||
soft_docs_str = "\n".join(f"{doc.strip()}" for i, doc in enumerate(soft_docs))
|
||||
|
||||
# 构建提示词,让LLM选择最匹配的行
|
||||
prompt = f"""
|
||||
{soft_docs_str}
|
||||
================================
|
||||
以上为软件功能操作、常见问题排查等功能,结合历史对话,请输出与当前提问最相关的1-3个功能名称,
|
||||
使用Json格式输出,如下:
|
||||
[{{"content": "行内容"}},...]
|
||||
当前问题: {query}
|
||||
历史对话: {json.dumps(chat_history, ensure_ascii=False)}
|
||||
"""
|
||||
hyde_start_time = time.time()
|
||||
# 准备提示词
|
||||
hyde_parser = PydanticOutputParser(pydantic_object=HypotheticalDocument)
|
||||
formatted_prompt = hyde_prompt.format(
|
||||
query=query,
|
||||
chat_history=json.dumps(chat_history, ensure_ascii=False) if chat_history else "[]",
|
||||
# conversation_context=conversation_context,
|
||||
output_format=hyde_parser.get_format_instructions()
|
||||
)
|
||||
|
||||
try:
|
||||
# 异步调用LLM
|
||||
response = await self._llm.invoke_async(formatted_prompt, False)
|
||||
start_time = time.time()
|
||||
response = await self._llm.invoke_async(prompt, False, response_format={"type": "json_object"})
|
||||
end_time = time.time()
|
||||
|
||||
# 解析输出
|
||||
response.content = response.content.strip()
|
||||
clean_output = re.sub(r'<think>.*?</think>', '', response.content, flags=re.DOTALL)
|
||||
parsed_output = hyde_parser.parse(clean_output)
|
||||
hyde_end_time = time.time()
|
||||
hyde_time = hyde_end_time - hyde_start_time
|
||||
logging.debug(f"异步假设性文档生成耗时统计 - 总耗时: {hyde_time:.2f}秒")
|
||||
return parsed_output
|
||||
# 解析JSON响应
|
||||
try:
|
||||
wiki_names = []
|
||||
json_response = json.loads(response.content)
|
||||
for match in json_response:
|
||||
wiki_names.append(match["content"])
|
||||
logging.debug(f"软件文档匹配耗时: {end_time - start_time:.2f}秒")
|
||||
return wiki_names
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"解析JSON响应时出错: {e}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
# 如果解析失败,返回空的假设性回答
|
||||
logging.error(f"异步假设性文档生成失败: {e}", exc_info=True)
|
||||
return HypotheticalDocument(original_query=query, hypothetical_answer="")
|
||||
|
||||
logging.error(f"查找匹配软件文档时出错: {e}", exc_info=True)
|
||||
# 出错时返回空列表
|
||||
return []
|
||||
|
||||
async def _generate_multi_questions_async(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> MultiQuestions:
|
||||
"""
|
||||
异步生成多角度问题
|
||||
|
||||
@@ -28,8 +28,6 @@ def get_embedding_model(api_key: str = None) -> Embeddings:
|
||||
Returns:
|
||||
嵌入模型实例
|
||||
"""
|
||||
if not api_key:
|
||||
api_key = os.getenv("SILICONFLOW_API_KEY", "sk-ftnofbucchwnscojohyxwmfzgaykdxihafnlphohsinftkbr")
|
||||
return SiliconFlowEmbeddings(api_key=api_key)
|
||||
|
||||
|
||||
|
||||
@@ -413,7 +413,7 @@ multi_questions_prompt = """
|
||||
## 任务说明
|
||||
1. 分析用户的原始问题,理解其核心意图和需求
|
||||
2. 考虑历史对话和会话背景,理解用户当前问题的上下文
|
||||
3. 从不同角度生成2-4个子问题,这些子问题应该:
|
||||
3. 从不同角度生成1-3个子问题,这些子问题应该:
|
||||
- 分别关注原始问题的不同方面或组成部分
|
||||
- 更加具体和直接
|
||||
- 共同覆盖原始问题的完整意图
|
||||
|
||||
@@ -236,7 +236,7 @@ class OpenAiLLM:
|
||||
self._model = kwargs.get("model")
|
||||
kwargs.pop("model")
|
||||
else:
|
||||
self._model = os.getenv("LLM_MODEL_NAME")
|
||||
self._model = os.getenv("MODEL_NAME")
|
||||
|
||||
self._kwargs = kwargs
|
||||
|
||||
@@ -284,13 +284,19 @@ class OpenAiLLM:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"OpenAiLLM:invoke:error:{str(e)}.api_key:{api_key}") from e
|
||||
|
||||
async def invoke_async(self, user_prompt="你是谁?", need_retry=True):
|
||||
async def invoke_async(self, user_prompt="你是谁?", need_retry=True, **extra_kwargs):
|
||||
"""异步调用OpenAI API"""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
if "timeout" not in self._kwargs:
|
||||
|
||||
# 合并额外的kwargs与self._kwargs
|
||||
kwargs = {**self._kwargs}
|
||||
if extra_kwargs:
|
||||
kwargs.update(extra_kwargs)
|
||||
|
||||
if "timeout" not in kwargs:
|
||||
timeout = httpx.Timeout(300.0)
|
||||
self._kwargs["timeout"] = timeout
|
||||
kwargs["timeout"] = timeout
|
||||
|
||||
if need_retry:
|
||||
while retry_count < max_retries:
|
||||
@@ -302,7 +308,7 @@ class OpenAiLLM:
|
||||
completion = await client.chat.completions.create(
|
||||
model=self._model,
|
||||
messages=[{'role': 'user', 'content': user_prompt}],
|
||||
**self._kwargs
|
||||
**kwargs
|
||||
)
|
||||
return completion.choices[0].message
|
||||
|
||||
@@ -319,7 +325,7 @@ class OpenAiLLM:
|
||||
completion = await client.chat.completions.create(
|
||||
model=self._model,
|
||||
messages=[{'role': 'user', 'content': user_prompt}],
|
||||
**self._kwargs
|
||||
**kwargs
|
||||
)
|
||||
return completion.choices[0].message
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user