更新API密钥管理,优化意图识别和Excel数据验证功能,增强日志记录,改进错误处理机制,支持文档检索功能,提升代码可读性和灵活性。
This commit is contained in:
@@ -69,7 +69,7 @@ class JsonDeduplicator:
|
||||
logging.info(f"从{self.INPUT_PATH}加载了{len(data)}条记录")
|
||||
return data
|
||||
except Exception as e:
|
||||
logging.error(f"读取{self.INPUT_PATH}失败: {e}")
|
||||
logging.error(f"读取{self.INPUT_PATH}失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def group_items_by_key(self, items):
|
||||
@@ -118,7 +118,7 @@ class JsonDeduplicator:
|
||||
# 如果合并失败,返回第一个项目
|
||||
return item_list[0]
|
||||
except Exception as e:
|
||||
logging.error(f"处理键 {key} 时出错: {e}")
|
||||
logging.error(f"处理键 {key} 时出错: {e}", exc_info=True)
|
||||
return item_list[0]
|
||||
|
||||
def deduplicate(self):
|
||||
|
||||
@@ -135,11 +135,10 @@ class WikijsNounsExtractor:
|
||||
parsed_output = self.terms_list_parser.parse(response.content)
|
||||
return parsed_output.terms
|
||||
except Exception as e:
|
||||
logging.error(f"解析LLM响应时出错: {str(e)}")
|
||||
logging.error(f"原始响应: {response.content}")
|
||||
logging.error(f"解析LLM响应时出错: {str(e)}", exc_info=True)
|
||||
return []
|
||||
except Exception as e:
|
||||
logging.error(f"提取专业名词时出错: {str(e)}")
|
||||
logging.error(f"提取专业名词时出错: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
def _process_document(self, doc, path_terms):
|
||||
@@ -182,7 +181,7 @@ class WikijsNounsExtractor:
|
||||
|
||||
return path_prefix
|
||||
except Exception as e:
|
||||
logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}")
|
||||
logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def process_all_documents(self, output_dir: str = "extracted_nouns", max_concurrency: int = 5):
|
||||
@@ -237,7 +236,7 @@ class WikijsNounsExtractor:
|
||||
if i % 10 == 0:
|
||||
logging.info(f"已完成 {i+1}/{len(futures)} 个文档的处理")
|
||||
except Exception as e:
|
||||
logging.error(f"处理文档时出错: {str(e)}")
|
||||
logging.error(f"处理文档时出错: {str(e)}", exc_info=True)
|
||||
|
||||
# 保存最终结果
|
||||
for prefix, terms in path_terms.items():
|
||||
|
||||
@@ -27,249 +27,6 @@ from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
dify_query_retrieval = DifyQueryRetrieval(api_key="dataset-skLjmPVonjHo119OWNf3kAmY", base_url="http://172.20.0.145/v1")
|
||||
|
||||
def is_retrieved_doc_relevant(query: str, retrieved_doc: List[Dict[str, Any]], api_key: str = None, base_url: str = None, model_name: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
使用LLM判断检索出的文档是否与用户提问相关
|
||||
|
||||
Args:
|
||||
query: 用户提问
|
||||
retrieved_doc: 检索出的文档列表
|
||||
api_key: API密钥,默认使用环境变量
|
||||
base_url: API基础URL,默认使用环境变量
|
||||
model_name: 模型名称,默认使用环境变量或默认模型
|
||||
|
||||
Returns:
|
||||
包含相关性判断结果的字典,包括is_relevant(布尔值)和explanation(解释)
|
||||
"""
|
||||
# 使用环境变量或参数值
|
||||
api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
base_url = base_url or os.getenv("OPENAI_API_BASE")
|
||||
model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
|
||||
# 如果没有检索到文档,直接返回不相关
|
||||
if not retrieved_doc or len(retrieved_doc) == 0:
|
||||
return {
|
||||
"is_relevant": False,
|
||||
"explanation": "没有检索到任何文档",
|
||||
"relevance_score": 0.0
|
||||
}
|
||||
|
||||
# 构建文档内容
|
||||
doc_contents = []
|
||||
for i, doc in enumerate(retrieved_doc[:3]): # 只取前3个文档进行判断
|
||||
content = doc.get("content", "")
|
||||
title = doc.get("title", "")
|
||||
doc_contents.append(f"文档{i+1}标题: {title}\n文档{i+1}内容: {content}")
|
||||
|
||||
doc_text = "\n\n".join(doc_contents)
|
||||
class TempModel(BaseModel):
|
||||
is_relevant: bool = Field(description="是否与用户提问相关")
|
||||
relevance_score: int = Field(description="相关性评分,0-100分")
|
||||
explanation: str = Field(description="解释各个文档与提问的相关性或不相关性")
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=TempModel)
|
||||
# 构建提示词
|
||||
prompt = f"""请判断以下检索文档是否与用户提问相关,并给出相关性评分(0-100分)。
|
||||
|
||||
用户提问: {query}
|
||||
|
||||
检索文档:
|
||||
{doc_text}
|
||||
|
||||
请按照以下JSON格式返回结果:
|
||||
{parser.get_format_instructions()}
|
||||
"""
|
||||
|
||||
try:
|
||||
# 初始化LLM并调用
|
||||
llm = OpenAiLLM(api_key=api_key, base_url=base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"})
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
result = parser.parse(response.content)
|
||||
|
||||
return {
|
||||
"is_relevant": result.is_relevant,
|
||||
"relevance_score": result.relevance_score,
|
||||
"explanation": result.explanation
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"判断文档相关性时出错: {str(e)}")
|
||||
return {
|
||||
"is_relevant": False,
|
||||
"explanation": f"判断过程出错: {str(e)}",
|
||||
"relevance_score": 0.0
|
||||
}
|
||||
|
||||
# 读取Excel文件中的提问数据
|
||||
def load_questions_from_excel(file_path=None):
|
||||
"""
|
||||
从Excel文件中读取提问数据
|
||||
|
||||
Args:
|
||||
file_path: Excel文件路径,如果为None则使用默认路径
|
||||
|
||||
Returns:
|
||||
提问列表
|
||||
"""
|
||||
|
||||
try:
|
||||
# 读取Excel文件的第一列数据
|
||||
df = pd.read_excel(file_path)
|
||||
questions = df.iloc[:, 0].tolist() # 获取第一列数据
|
||||
logging.info(f"成功从{file_path}读取了{len(questions)}条提问")
|
||||
return questions
|
||||
except Exception as e:
|
||||
logging.error(f"读取Excel文件时出错: {e}")
|
||||
return []
|
||||
|
||||
def process_query(recognizer: IntentRecognizer, query: str, conversation_context: str = "", chat_history: List[Dict[str, str]] = None, previous_slots: Dict[str, str] = None):
|
||||
"""
|
||||
处理单个查询,支持重试机制,并包含槽位填充
|
||||
|
||||
Args:
|
||||
recognizer: 意图识别器实例
|
||||
query: 查询字符串
|
||||
|
||||
Returns:
|
||||
处理结果字典
|
||||
"""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# 使用新的process_query_with_slots方法处理查询
|
||||
# result = recognizer.process_query_with_slots(query)
|
||||
result = recognizer.process_query(query,
|
||||
conversation_context=conversation_context,
|
||||
chat_history=chat_history,
|
||||
previous_slots=previous_slots,
|
||||
enable_query_expansion=True)
|
||||
# 提取分类信息
|
||||
classification = result["classification"]
|
||||
original_query = result["rewrite"]["rewrite"]
|
||||
query_list = result["query_expand"]["all"]
|
||||
soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","")
|
||||
# 将字典转换为Classification对象
|
||||
classification_obj = Classification(**classification)
|
||||
retrieved_doc=dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name)
|
||||
|
||||
# 判断检索文档是否相关
|
||||
relevance_result = {}
|
||||
if retrieved_doc:
|
||||
# 获取API密钥和基础URL
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
# 判断文档相关性
|
||||
relevance_result = is_retrieved_doc_relevant(query, retrieved_doc, api_key, base_url, model_name)
|
||||
else:
|
||||
retrieved_doc_str = []
|
||||
relevance_result = {
|
||||
"is_relevant": False,
|
||||
"explanation": "没有检索到文档",
|
||||
"relevance_score": 0.0
|
||||
}
|
||||
|
||||
retrieved_doc_titles=[]
|
||||
if retrieved_doc:
|
||||
retrieved_doc_titles=[doc["title"].split("/")[-1] for doc in retrieved_doc]
|
||||
# 提取槽位填充信息
|
||||
slot_filling = result.get("slot_filling", {})
|
||||
slot_filling_str = ""
|
||||
if slot_filling and "filled_data" in slot_filling:
|
||||
# 格式化槽位填充结果
|
||||
slot_filling_str = json.dumps({
|
||||
"是否完整": slot_filling.get("is_complete", False),
|
||||
"缺失槽位": slot_filling.get("missing_slots", {}),
|
||||
"填充数据": slot_filling.get("filled_data", {})
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
# 处理成功,返回结果
|
||||
return {
|
||||
"提问": query,
|
||||
"问题分类": f"{classification['vertical_classification']} - {classification['sub_classification']}",
|
||||
"问题改写": result["rewrite"]["rewrite"],
|
||||
"槽位填充": slot_filling_str,
|
||||
"检索的文档": "\n".join(retrieved_doc_titles),
|
||||
"文档是否相关": "相关" if relevance_result["is_relevant"] else "不相关",
|
||||
"文档相关性解释": relevance_result["explanation"]
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"处理问题 '{query}' 时出错: ",exc_info=True)
|
||||
retry_count += 1
|
||||
|
||||
# 如果已经重试了最大次数,则记录错误并返回错误结果
|
||||
if retry_count > max_retries:
|
||||
return {
|
||||
"提问": query,
|
||||
"问题分类": "处理出错",
|
||||
"问题改写": "处理出错",
|
||||
"槽位填充": "处理出错",
|
||||
"检索的文档": f"重试 {max_retries} 次后失败: {str(e)}",
|
||||
"文档是否相关": "处理出错",
|
||||
"文档相关性解释": "处理出错"
|
||||
}
|
||||
else:
|
||||
# 可以在这里添加延迟,避免过快重试
|
||||
time.sleep(10)
|
||||
|
||||
def save_results_to_excel(results, output_file, is_final=False):
|
||||
"""
|
||||
将结果保存到Excel文件
|
||||
|
||||
Args:
|
||||
results: 结果列表
|
||||
output_file: 输出文件路径
|
||||
is_final: 是否为最终保存,如果是则使用完整文件名,否则添加临时标记
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# 过滤掉None值
|
||||
valid_results = [r for r in results if r is not None]
|
||||
|
||||
if not valid_results:
|
||||
logging.warning("没有有效结果可保存")
|
||||
return
|
||||
|
||||
# 创建DataFrame
|
||||
results_df = pd.DataFrame(valid_results)
|
||||
|
||||
# 根据是否为最终保存确定文件名
|
||||
if not is_final:
|
||||
file_name, file_ext = os.path.splitext(output_file)
|
||||
temp_output_file = f"{file_name}_temp{file_ext}"
|
||||
else:
|
||||
temp_output_file = output_file
|
||||
|
||||
# 使用ExcelWriter设置格式
|
||||
with pd.ExcelWriter(temp_output_file, engine='xlsxwriter') as writer:
|
||||
results_df.to_excel(writer, index=False, sheet_name='Sheet1')
|
||||
|
||||
# 获取工作簿和工作表对象
|
||||
workbook = writer.book
|
||||
worksheet = writer.sheets['Sheet1']
|
||||
|
||||
# 设置列宽(单位:像素)
|
||||
# 定义列宽(厘米转为Excel单位,1cm约等于4.7个Excel单位)
|
||||
worksheet.set_column('A:A', 60) # 提问列 60个Excel单位
|
||||
worksheet.set_column('B:B', 20) # 问题拆解 20个Excel单位
|
||||
worksheet.set_column('C:C', 20) # 一级分类 20个Excel单位
|
||||
worksheet.set_column('D:D', 20) # 二级分类 20个Excel单位
|
||||
worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位
|
||||
worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位
|
||||
worksheet.set_column('G:G', 80) # 槽位填充 80个Excel单位
|
||||
worksheet.set_column('H:H', 60) # 文档相关性 60个Excel单位
|
||||
|
||||
# 设置所有行高为20磅
|
||||
for i in range(len(results_df) + 1): # +1 是为了包括表头
|
||||
worksheet.set_row(i, 20)
|
||||
|
||||
logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}")
|
||||
|
||||
# 示例查询
|
||||
examples_query = """主网电力建设计价通软件, 35kV的软件 土质比例不能一起设置吗"""
|
||||
conversation_context=""
|
||||
@@ -296,29 +53,327 @@ previous_slots={
|
||||
"operation_steps": None
|
||||
}
|
||||
|
||||
class QueryRewriteProcessor:
|
||||
"""
|
||||
查询改写处理器,用于意图识别、问题改写和文档检索
|
||||
"""
|
||||
def __init__(self,
|
||||
api_key: str = None,
|
||||
base_url: str = None,
|
||||
model_name: str = None,
|
||||
dify_api_key: str = "dataset-skLjmPVonjHo119OWNf3kAmY",
|
||||
dify_base_url: str = "http://172.20.0.145/v1"):
|
||||
"""
|
||||
初始化查询改写处理器
|
||||
|
||||
Args:
|
||||
api_key: API密钥,默认使用环境变量
|
||||
base_url: API基础URL,默认使用环境变量
|
||||
model_name: 模型名称,默认使用环境变量或默认模型
|
||||
dify_api_key: Dify API密钥
|
||||
dify_base_url: Dify API基础URL
|
||||
"""
|
||||
# 初始化意图识别器
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = base_url or os.getenv("OPENAI_API_BASE")
|
||||
self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
|
||||
self.recognizer = IntentRecognizer(api_key=self.api_key, base_url=self.base_url, model_name=self.model_name)
|
||||
self.dify_query_retrieval = DifyQueryRetrieval(api_key=dify_api_key, base_url=dify_base_url)
|
||||
|
||||
def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
使用LLM判断检索出的文档是否与用户提问相关
|
||||
|
||||
Args:
|
||||
query: 用户提问
|
||||
retrieved_doc: 检索出的文档列表
|
||||
|
||||
Returns:
|
||||
包含相关性判断结果的字典,包括is_relevant(布尔值)和explanation(解释)
|
||||
"""
|
||||
# 如果没有检索到文档,直接返回不相关
|
||||
if not retrieved_doc or len(retrieved_doc) == 0:
|
||||
return {
|
||||
"is_relevant": False,
|
||||
"explanation": "没有检索到任何文档",
|
||||
"relevance_score": 0.0
|
||||
}
|
||||
|
||||
# 构建文档内容
|
||||
doc_contents = []
|
||||
for i, doc in enumerate(retrieved_doc[:3]): # 只取前3个文档进行判断
|
||||
content = doc.get("content", "")
|
||||
title = doc.get("title", "")
|
||||
doc_contents.append(f"文档{i+1}标题: {title}\n文档{i+1}内容: {content}")
|
||||
|
||||
doc_text = "\n\n".join(doc_contents)
|
||||
class TempModel(BaseModel):
|
||||
is_relevant: bool = Field(description="是否与用户提问相关")
|
||||
relevance_score: int = Field(description="相关性评分,0-100分")
|
||||
explanation: str = Field(description="解释各个文档与提问的相关性或不相关性")
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=TempModel)
|
||||
# 构建提示词
|
||||
prompt = f"""请判断以下检索文档是否与用户提问相关,并给出相关性评分(0-100分)。
|
||||
|
||||
用户提问: {query}
|
||||
|
||||
检索文档:
|
||||
{doc_text}
|
||||
|
||||
请按照以下JSON格式返回结果:
|
||||
{parser.get_format_instructions()}
|
||||
"""
|
||||
|
||||
try:
|
||||
# 初始化LLM并调用
|
||||
llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"})
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
result = parser.parse(response.content)
|
||||
|
||||
return {
|
||||
"is_relevant": result.is_relevant,
|
||||
"relevance_score": result.relevance_score,
|
||||
"explanation": result.explanation
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"判断文档相关性时出错: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"is_relevant": False,
|
||||
"explanation": f"判断过程出错: {str(e)}",
|
||||
"relevance_score": 0.0
|
||||
}
|
||||
|
||||
def load_questions_from_excel(self, file_path=None):
|
||||
"""
|
||||
从Excel文件中读取提问数据
|
||||
|
||||
Args:
|
||||
file_path: Excel文件路径,如果为None则使用默认路径
|
||||
|
||||
Returns:
|
||||
提问列表
|
||||
"""
|
||||
try:
|
||||
# 读取Excel文件的第一列数据
|
||||
df = pd.read_excel(file_path)
|
||||
questions = df.iloc[:, 0].tolist() # 获取第一列数据
|
||||
logging.info(f"成功从{file_path}读取了{len(questions)}条提问")
|
||||
return questions
|
||||
except Exception as e:
|
||||
logging.error(f"读取Excel文件时出错: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def process_query(self, query: str, conversation_context: str = "", chat_history: List[Dict[str, str]] = None, previous_slots: Dict[str, str] = None, enable_retrieval: bool = False):
|
||||
"""
|
||||
处理单个查询,支持重试机制,并包含槽位填充
|
||||
|
||||
Args:
|
||||
query: 查询字符串
|
||||
conversation_context: 对话上下文
|
||||
chat_history: 聊天历史记录
|
||||
previous_slots: 之前识别的槽位信息
|
||||
enable_retrieval: 是否启用文档检索功能,默认为False
|
||||
|
||||
Returns:
|
||||
处理结果字典
|
||||
"""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# 使用process_query方法处理查询
|
||||
result = self.recognizer.process_query(query,
|
||||
conversation_context=conversation_context,
|
||||
chat_history=chat_history,
|
||||
previous_slots=previous_slots,
|
||||
enable_query_expansion=True)
|
||||
# 提取分类信息
|
||||
classification = result["classification"]
|
||||
original_query = result["rewrite"]["rewrite"]
|
||||
query_list = result["query_expand"]["all"]
|
||||
soft_name = result.get("slot_filling", {}).get("filled_data", {}).get("software_name","")
|
||||
# 将字典转换为Classification对象
|
||||
classification_obj = Classification(**classification)
|
||||
|
||||
# 根据enable_retrieval参数决定是否进行文档检索
|
||||
retrieved_doc = None
|
||||
if enable_retrieval:
|
||||
retrieved_doc = self.dify_query_retrieval.retrieve(original_query, query_list, classification_obj, soft_name)
|
||||
|
||||
# 判断检索文档是否相关
|
||||
relevance_result = {}
|
||||
if retrieved_doc:
|
||||
# 判断文档相关性
|
||||
relevance_result = self.is_retrieved_doc_relevant(query, retrieved_doc)
|
||||
else:
|
||||
relevance_result = {
|
||||
"is_relevant": False,
|
||||
"explanation": "没有检索到文档" if enable_retrieval else "文档检索功能未启用",
|
||||
"relevance_score": 0.0
|
||||
}
|
||||
|
||||
retrieved_doc_titles=[]
|
||||
if retrieved_doc:
|
||||
retrieved_doc_titles=[doc["title"].split("/")[-1] for doc in retrieved_doc]
|
||||
# 提取槽位填充信息
|
||||
slot_filling = result.get("slot_filling", {})
|
||||
slot_filling_str = ""
|
||||
if slot_filling and "filled_data" in slot_filling:
|
||||
# 格式化槽位填充结果
|
||||
slot_filling_str = json.dumps({
|
||||
"是否完整": slot_filling.get("is_complete", False),
|
||||
"缺失槽位": slot_filling.get("missing_slots", {}),
|
||||
"填充数据": slot_filling.get("filled_data", {})
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
# 处理成功,返回结果
|
||||
return {
|
||||
"问题": query,
|
||||
"问题分类": f"{classification['vertical_classification']} - {classification['sub_classification']}",
|
||||
"问题改写": result["rewrite"]["rewrite"],
|
||||
"槽位信息": slot_filling_str,
|
||||
"检索的文档": "\n".join(retrieved_doc_titles),
|
||||
"检索的内容": json.dumps(retrieved_doc, ensure_ascii=False, indent=2) if retrieved_doc else "",
|
||||
"文档是否相关": "相关" if relevance_result["is_relevant"] else "不相关",
|
||||
"文档相关性解释": relevance_result["explanation"]
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"处理问题 '{query}' 时出错: ",exc_info=True)
|
||||
retry_count += 1
|
||||
|
||||
# 如果已经重试了最大次数,则记录错误并返回错误结果
|
||||
if retry_count > max_retries:
|
||||
return {
|
||||
"问题": query,
|
||||
"问题分类": "处理出错",
|
||||
"问题改写": "处理出错",
|
||||
"槽位信息": "处理出错",
|
||||
"检索的文档": f"重试 {max_retries} 次后失败: {str(e)}",
|
||||
"检索的内容":"",
|
||||
"文档是否相关": "处理出错",
|
||||
"文档相关性解释": "处理出错"
|
||||
}
|
||||
else:
|
||||
# 可以在这里添加延迟,避免过快重试
|
||||
time.sleep(10)
|
||||
|
||||
def save_results_to_excel(self, results, output_file, is_final=False):
|
||||
"""
|
||||
将结果保存到Excel文件
|
||||
|
||||
Args:
|
||||
results: 结果列表
|
||||
output_file: 输出文件路径
|
||||
is_final: 是否为最终保存,如果是则使用完整文件名,否则添加临时标记
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# 过滤掉None值
|
||||
valid_results = [r for r in results if r is not None]
|
||||
|
||||
if not valid_results:
|
||||
logging.warning("没有有效结果可保存")
|
||||
return
|
||||
|
||||
# 创建DataFrame
|
||||
results_df = pd.DataFrame(valid_results)
|
||||
|
||||
# 根据是否为最终保存确定文件名
|
||||
if not is_final:
|
||||
file_name, file_ext = os.path.splitext(output_file)
|
||||
temp_output_file = f"{file_name}_temp{file_ext}"
|
||||
else:
|
||||
temp_output_file = output_file
|
||||
|
||||
# 使用ExcelWriter设置格式
|
||||
with pd.ExcelWriter(temp_output_file, engine='xlsxwriter') as writer:
|
||||
results_df.to_excel(writer, index=False, sheet_name='Sheet1')
|
||||
|
||||
# 获取工作簿和工作表对象
|
||||
workbook = writer.book
|
||||
worksheet = writer.sheets['Sheet1']
|
||||
|
||||
# 设置列宽(单位:像素)
|
||||
# 定义列宽(厘米转为Excel单位,1cm约等于4.7个Excel单位)
|
||||
worksheet.set_column('A:A', 60) # 提问列 60个Excel单位
|
||||
worksheet.set_column('B:B', 20) # 问题拆解 20个Excel单位
|
||||
worksheet.set_column('C:C', 20) # 一级分类 20个Excel单位
|
||||
worksheet.set_column('D:D', 20) # 二级分类 20个Excel单位
|
||||
worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位
|
||||
worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位
|
||||
worksheet.set_column('G:G', 80) # 槽位填充 80个Excel单位
|
||||
worksheet.set_column('H:H', 60) # 文档相关性 60个Excel单位
|
||||
|
||||
# 设置所有行高为20磅
|
||||
for i in range(len(results_df) + 1): # +1 是为了包括表头
|
||||
worksheet.set_row(i, 20)
|
||||
|
||||
logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}")
|
||||
|
||||
def process_batch(self, questions: List[str], max_workers: int = 2, enable_retrieval: bool = False, output_file: str = None):
|
||||
"""
|
||||
批量处理多个问题
|
||||
|
||||
Args:
|
||||
questions: 问题列表
|
||||
max_workers: 并发处理的最大线程数,默认为2
|
||||
enable_retrieval: 是否启用文档检索功能,默认为False
|
||||
output_file: 输出文件路径,如果为None则不保存结果
|
||||
|
||||
Returns:
|
||||
处理结果列表
|
||||
"""
|
||||
logging.info(f"共有 {len(questions)} 个问题需要处理,使用 {max_workers} 个并发线程")
|
||||
logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}")
|
||||
|
||||
# 创建一个与输入顺序相同的结果列表
|
||||
results = [None] * len(questions)
|
||||
|
||||
# 使用线程池进行并发处理
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交所有任务并记录它们的索引
|
||||
future_to_index = {}
|
||||
for idx, query in enumerate(questions):
|
||||
if not query or query.strip() == "":
|
||||
continue
|
||||
future = executor.submit(self.process_query, query, enable_retrieval=enable_retrieval)
|
||||
future_to_index[future] = idx
|
||||
|
||||
# 使用tqdm显示进度条
|
||||
for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(future_to_index), desc="处理进度"):
|
||||
idx = future_to_index[future]
|
||||
result = future.result()
|
||||
# 将结果放在与输入相同的位置
|
||||
results[idx] = result
|
||||
|
||||
# 如果提供了输出文件路径,则保存结果
|
||||
if output_file:
|
||||
self.save_results_to_excel(results, output_file, is_final=True)
|
||||
|
||||
return results
|
||||
|
||||
def parse_arguments():
|
||||
"""解析命令行参数"""
|
||||
parser = argparse.ArgumentParser(description='意图识别和问题改写工具')
|
||||
|
||||
input_file="data/excel/1500条点踩软件问题测试.xlsx"
|
||||
ouput_file="data/excel/1500条点踩软件问题测试_意图分类.xlsx"
|
||||
# 添加数据文件路径参数
|
||||
parser.add_argument('--input', '-i', type=str,
|
||||
parser.add_argument('--input', '-i', type=str, default=input_file,
|
||||
help='输入Excel文件路径,包含待处理的提问数据(第一列)')
|
||||
parser.add_argument('--output', '-o', type=str,
|
||||
parser.add_argument('--output', '-o', type=str,default=ouput_file,
|
||||
help='输出Excel文件路径,用于保存处理结果')
|
||||
|
||||
# 添加LLM相关参数
|
||||
parser.add_argument('--model', '-m', type=str,
|
||||
help='LLM模型名称,默认使用环境变量中的配置')
|
||||
parser.add_argument('--api_base', '-a', type=str,
|
||||
help='API基础URL,默认使用环境变量中的配置')
|
||||
|
||||
# 添加处理相关参数
|
||||
parser.add_argument('--max_workers', '-w', type=int, default=2,
|
||||
help='并发处理的最大线程数,默认为20')
|
||||
parser.add_argument('--debug', '-d', action='store_true',
|
||||
help='启用调试模式,使用示例查询而非从文件读取')
|
||||
parser.add_argument('--query', '-q', type=str,
|
||||
help='在调试模式下使用的查询字符串')
|
||||
|
||||
parser.add_argument('--enable_retrieval', '-r', action='store_true',
|
||||
help='是否启用文档检索功能,默认不启用')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -332,11 +387,12 @@ def main():
|
||||
|
||||
# 从环境变量中获取配置,命令行参数优先
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = args.api_base if args.api_base else os.getenv("OPENAI_API_BASE")
|
||||
model_name = args.model if args.model else os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
enable_retrieval = args.enable_retrieval
|
||||
|
||||
# 初始化意图识别器
|
||||
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
# 初始化查询改写处理器
|
||||
processor = QueryRewriteProcessor(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
|
||||
# 读取提问数据
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -344,51 +400,29 @@ def main():
|
||||
output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "1500条点踩软件问题_槽位(分类)填充结果.xlsx")
|
||||
|
||||
# 检测是否为调试模式
|
||||
is_debug = args.debug or (hasattr(sys, 'gettrace') and sys.gettrace() is not None)
|
||||
is_debug = False
|
||||
is_debug =hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||
if is_debug:
|
||||
# 如果提供了查询参数,使用它;否则使用默认示例
|
||||
if args.query:
|
||||
examples = [args.query]
|
||||
else:
|
||||
examples = examples_query.strip().split("\n")
|
||||
examples = examples_query.strip().split("\n")
|
||||
else:
|
||||
examples = load_questions_from_excel(data_file)
|
||||
examples = processor.load_questions_from_excel(data_file)
|
||||
|
||||
if not is_debug:
|
||||
max_workers = args.max_workers
|
||||
logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")
|
||||
|
||||
# 创建一个与输入顺序相同的结果列表
|
||||
results = [None] * len(examples)
|
||||
batch_size = 100 # 每100条保存一次
|
||||
|
||||
# 使用线程池进行并发处理
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交所有任务并记录它们的索引
|
||||
future_to_index = {}
|
||||
for idx, query in enumerate(examples):
|
||||
future = executor.submit(process_query, recognizer, query)
|
||||
future_to_index[future] = idx
|
||||
|
||||
# 使用tqdm显示进度条
|
||||
completed = 0
|
||||
for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(examples), desc="处理进度"):
|
||||
idx = future_to_index[future]
|
||||
result = future.result()
|
||||
# 将结果放在与输入相同的位置
|
||||
results[idx] = result
|
||||
completed += 1
|
||||
|
||||
# 处理完所有数据后,保存最终结果
|
||||
save_results_to_excel(results, output_file, is_final=True)
|
||||
# 批量处理问题
|
||||
results = processor.process_batch(questions=examples, max_workers=args.max_workers, enable_retrieval=enable_retrieval, output_file=output_file)
|
||||
logging.info(f"所有处理完成,最终结果已保存至: {output_file}")
|
||||
else:
|
||||
logging.info(f"文档检索功能状态: {'已启用' if enable_retrieval else '未启用'}")
|
||||
for idx, query in enumerate(examples):
|
||||
if query.strip() == "":
|
||||
continue
|
||||
# process_query(recognizer, query, conversation_context, chat_history, previous_slots)
|
||||
print(json.dumps(process_query(recognizer, query), ensure_ascii=False, indent=2))
|
||||
# 在调试模式下使用完整的参数
|
||||
print(json.dumps(processor.process_query(
|
||||
query,
|
||||
conversation_context=conversation_context,
|
||||
chat_history=chat_history,
|
||||
previous_slots=previous_slots,
|
||||
enable_retrieval=enable_retrieval
|
||||
), ensure_ascii=False, indent=2))
|
||||
|
||||
def setup_logging():
|
||||
# 配置日志输出到控制台
|
||||
|
||||
@@ -81,21 +81,21 @@ class ExcelDataValidator:
|
||||
"""
|
||||
file_path = file_path or self.input_file
|
||||
if not file_path:
|
||||
logging.error("未指定输入文件路径")
|
||||
logging.error("未指定输入文件路径", exc_info=True)
|
||||
return None
|
||||
|
||||
try:
|
||||
df = pd.read_excel(file_path)
|
||||
required_columns = ["问题", "问题分类", "问题改写", "槽点信息", "检索的内容"]
|
||||
required_columns = ["问题", "问题分类", "问题改写", "槽点信息"]
|
||||
for col in required_columns:
|
||||
if col not in df.columns:
|
||||
logging.error(f"缺少必要的列: {col}")
|
||||
logging.error(f"缺少必要的列: {col}", exc_info=True)
|
||||
return None
|
||||
logging.info(f"成功从{file_path}读取了{len(df)}条数据")
|
||||
self.df = df
|
||||
return df
|
||||
except Exception as e:
|
||||
logging.error(f"读取Excel文件时出错: {e}")
|
||||
logging.error(f"读取Excel文件时出错: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def validate_classification(self, llm:OpenAiLLM , query:str, vertical_class:str, sub_class:str):
|
||||
@@ -413,10 +413,7 @@ class ExcelDataValidator:
|
||||
return index, True, "", "", confidence_score
|
||||
except Exception as e:
|
||||
error_msg = f"处理行 {index} 时发生错误: {str(e)}"
|
||||
logging.error(error_msg)
|
||||
if self.debug:
|
||||
import traceback
|
||||
logging.error(traceback.format_exc())
|
||||
logging.error(error_msg, exc_info=True)
|
||||
return index, False, "处理错误", error_msg, 0.0
|
||||
|
||||
def process_batch(self, llm, batch_data):
|
||||
|
||||
@@ -635,7 +635,7 @@ class IntentRecognizer:
|
||||
return parsed_output
|
||||
except Exception as e:
|
||||
# 如果解析失败,返回原始查询作为后退提示
|
||||
logging.error(f"后退提示生成失败: {e}")
|
||||
logging.error(f"后退提示生成失败: {e}", exc_info=True)
|
||||
return StepBackPrompt(original_query=query, step_back_query=query)
|
||||
|
||||
def _generate_follow_up_questions(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> FollowUpQuestions:
|
||||
@@ -672,7 +672,7 @@ class IntentRecognizer:
|
||||
return parsed_output
|
||||
except Exception as e:
|
||||
# 如果解析失败,返回原始查询作为后续问题
|
||||
logging.error(f"后续问题生成失败: {e}")
|
||||
logging.error(f"后续问题生成失败: {e}", exc_info=True)
|
||||
return FollowUpQuestions(original_query=query, follow_up_query=query)
|
||||
|
||||
def _generate_hypothetical_document(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> HypotheticalDocument:
|
||||
@@ -709,7 +709,7 @@ class IntentRecognizer:
|
||||
return parsed_output
|
||||
except Exception as e:
|
||||
# 如果解析失败,返回空的假设性回答
|
||||
logging.error(f"假设性文档生成失败: {e}")
|
||||
logging.error(f"假设性文档生成失败: {e}", exc_info=True)
|
||||
return HypotheticalDocument(original_query=query, hypothetical_answer="")
|
||||
|
||||
def _generate_multi_questions(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> MultiQuestions:
|
||||
@@ -746,7 +746,7 @@ class IntentRecognizer:
|
||||
return parsed_output
|
||||
except Exception as e:
|
||||
# 如果解析失败,返回原始查询作为唯一子问题
|
||||
logging.error(f"多角度问题生成失败: {e},LLM返回内容:{response.content}")
|
||||
logging.error(f"多角度问题生成失败: {e}",exc_info=True)
|
||||
return MultiQuestions(original_query=query, sub_questions=[query])
|
||||
|
||||
def _run_in_thread(self, func, args=(), kwargs={}):
|
||||
@@ -768,7 +768,7 @@ class IntentRecognizer:
|
||||
result = func(*args, **kwargs)
|
||||
result_container.append(result)
|
||||
except Exception as e:
|
||||
logging.error(f"线程执行函数 {func.__name__} 时出错: {e}")
|
||||
logging.error(f"线程执行函数 {func.__name__} 时出错: {e}", exc_info=True)
|
||||
result_container.append(None)
|
||||
|
||||
thread = threading.Thread(target=thread_target)
|
||||
@@ -866,5 +866,4 @@ class IntentRecognizer:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"process_intent_and_slot error:{e}")
|
||||
raise RuntimeError(f"process_intent_and_slot error:{e}") from e
|
||||
@@ -93,7 +93,7 @@ class ProfessionalNounVectorizer:
|
||||
logging.info(f"总共加载了 {len(merged_terms)} 条专业名词")
|
||||
return merged_terms
|
||||
except Exception as e:
|
||||
logging.error(f"加载多个文件失败: {e}")
|
||||
logging.error(f"加载多个文件失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def vectorize_files_and_save(self, file_paths: List[str]) -> bool:
|
||||
@@ -139,7 +139,7 @@ class ProfessionalNounVectorizer:
|
||||
logging.info("完成多文件专业名词向量化和索引创建")
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"多文件向量化处理失败: {e}")
|
||||
logging.error(f"多文件向量化处理失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def _updata_suffix_item(self)->Tuple[List[str], List[Dict]] :
|
||||
@@ -246,7 +246,7 @@ class ProfessionalNounVectorizer:
|
||||
faiss_index.save_local(folder_path=self.index_path)
|
||||
logging.info(f"FAISS索引已保存至 {self.index_path}")
|
||||
except Exception as e:
|
||||
logging.error(f"保存FAISS索引失败: {e}")
|
||||
logging.error(f"保存FAISS索引失败: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
|
||||
@@ -349,5 +349,5 @@ class ProfessionalNounRetriever:
|
||||
return [json.loads(item) for item in intersection]
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"查询FAISS索引失败: {e}")
|
||||
logging.error(f"查询FAISS索引失败: {e}", exc_info=True)
|
||||
return []
|
||||
Reference in New Issue
Block a user