优化意图识别示例,更新文档相关性判断逻辑,增强Excel数据验证功能,改进日志记录,调整参数以提升代码可读性和灵活性。
This commit is contained in:
@@ -28,7 +28,7 @@ from rag2_0.tool.ModelTool import OpenAiLLM
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# 示例查询
|
# 示例查询
|
||||||
examples_query = """主网电力建设计价通软件, 35kV的软件 土质比例不能一起设置吗"""
|
examples_query = """ PE2211PK0801是什么软件"""
|
||||||
conversation_context=""
|
conversation_context=""
|
||||||
chat_history=[
|
chat_history=[
|
||||||
{
|
{
|
||||||
@@ -100,27 +100,23 @@ class QueryRewriteProcessor:
|
|||||||
"relevance_score": 0.0
|
"relevance_score": 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
# 构建文档内容
|
doc_text_list = json.dumps(retrieved_doc, ensure_ascii=False, indent=2)
|
||||||
doc_contents = []
|
|
||||||
for i, doc in enumerate(retrieved_doc[:3]): # 只取前3个文档进行判断
|
|
||||||
content = doc.get("content", "")
|
|
||||||
title = doc.get("title", "")
|
|
||||||
doc_contents.append(f"文档{i+1}标题: {title}\n文档{i+1}内容: {content}")
|
|
||||||
|
|
||||||
doc_text = "\n\n".join(doc_contents)
|
|
||||||
class TempModel(BaseModel):
|
class TempModel(BaseModel):
|
||||||
is_relevant: bool = Field(description="是否与用户提问相关")
|
can_solve_problem: bool = Field(description="是否能解决用户问题")
|
||||||
relevance_score: int = Field(description="相关性评分,0-100分")
|
relevance_score: int = Field(description="相关性评分,0-100分")
|
||||||
explanation: str = Field(description="解释各个文档与提问的相关性或不相关性")
|
explanation: str = Field(description="解释文档是否能解决(回答)提问")
|
||||||
|
|
||||||
parser = PydanticOutputParser(pydantic_object=TempModel)
|
class most_relevant_document(BaseModel):
|
||||||
|
most_relevant_document: TempModel = Field(description="最相关的文档的判断结果")
|
||||||
|
|
||||||
|
parser = PydanticOutputParser(pydantic_object=most_relevant_document)
|
||||||
# 构建提示词
|
# 构建提示词
|
||||||
prompt = f"""请判断以下检索文档是否与用户提问相关,并给出相关性评分(0-100分)。
|
prompt = f"""请判断以下检索文档列表中是否与用户提问相关,能够解决用户的问题,并给出相关性评分(0-100分)。输出最相关的文档的判断结果。
|
||||||
|
|
||||||
用户提问: {query}
|
用户提问: {query}
|
||||||
|
|
||||||
检索文档:
|
检索文档列表:
|
||||||
{doc_text}
|
{doc_text_list}
|
||||||
|
|
||||||
请按照以下JSON格式返回结果:
|
请按照以下JSON格式返回结果:
|
||||||
{parser.get_format_instructions()}
|
{parser.get_format_instructions()}
|
||||||
@@ -131,10 +127,10 @@ class QueryRewriteProcessor:
|
|||||||
llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"})
|
llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model="deepseek-ai/DeepSeek-R1", response_format={"type": "json_object"})
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
result = parser.parse(response.content)
|
result = parser.parse(response.content).most_relevant_document
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"is_relevant": result.is_relevant,
|
"is_relevant": result.can_solve_problem,
|
||||||
"relevance_score": result.relevance_score,
|
"relevance_score": result.relevance_score,
|
||||||
"explanation": result.explanation
|
"explanation": result.explanation
|
||||||
}
|
}
|
||||||
@@ -418,9 +414,6 @@ def main():
|
|||||||
# 在调试模式下使用完整的参数
|
# 在调试模式下使用完整的参数
|
||||||
print(json.dumps(processor.process_query(
|
print(json.dumps(processor.process_query(
|
||||||
query,
|
query,
|
||||||
conversation_context=conversation_context,
|
|
||||||
chat_history=chat_history,
|
|
||||||
previous_slots=previous_slots,
|
|
||||||
enable_retrieval=enable_retrieval
|
enable_retrieval=enable_retrieval
|
||||||
), ensure_ascii=False, indent=2))
|
), ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class TermMerger:
|
|||||||
else:
|
else:
|
||||||
return term_list[0]
|
return term_list[0]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"处理词条 {name} 时出错: {e}")
|
logging.error(f"处理词条 {name} 时出错: {e}", exc_info=True)
|
||||||
return term_list[0]
|
return term_list[0]
|
||||||
|
|
||||||
def merge(self):
|
def merge(self):
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class ValidationResult(BaseModel):
|
|||||||
class ExcelDataValidator:
|
class ExcelDataValidator:
|
||||||
"""Excel数据验证类,用于批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写"""
|
"""Excel数据验证类,用于批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写"""
|
||||||
|
|
||||||
def __init__(self, input_file=None, output_file=None, workers=4, batch_size=10, debug=False):
|
def __init__(self, input_file=None, output_file=None, workers=4, debug=False):
|
||||||
"""
|
"""
|
||||||
初始化验证器
|
初始化验证器
|
||||||
|
|
||||||
@@ -41,7 +41,6 @@ class ExcelDataValidator:
|
|||||||
input_file: 输入Excel文件路径
|
input_file: 输入Excel文件路径
|
||||||
output_file: 输出结果Excel文件路径
|
output_file: 输出结果Excel文件路径
|
||||||
workers: 并行工作线程数
|
workers: 并行工作线程数
|
||||||
batch_size: 每批处理的行数
|
|
||||||
debug: 是否启用调试模式(串行处理)
|
debug: 是否启用调试模式(串行处理)
|
||||||
"""
|
"""
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
@@ -50,7 +49,6 @@ class ExcelDataValidator:
|
|||||||
self.input_file = input_file
|
self.input_file = input_file
|
||||||
self.output_file = output_file
|
self.output_file = output_file
|
||||||
self.workers = workers
|
self.workers = workers
|
||||||
self.batch_size = batch_size
|
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.df = None
|
self.df = None
|
||||||
|
|
||||||
@@ -86,7 +84,7 @@ class ExcelDataValidator:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
df = pd.read_excel(file_path)
|
df = pd.read_excel(file_path)
|
||||||
required_columns = ["问题", "问题分类", "问题改写", "槽点信息"]
|
required_columns = ["问题", "问题分类", "问题改写", "槽位信息", "检索的内容"]
|
||||||
for col in required_columns:
|
for col in required_columns:
|
||||||
if col not in df.columns:
|
if col not in df.columns:
|
||||||
logging.error(f"缺少必要的列: {col}", exc_info=True)
|
logging.error(f"缺少必要的列: {col}", exc_info=True)
|
||||||
@@ -320,7 +318,7 @@ class ExcelDataValidator:
|
|||||||
query = row["问题"]
|
query = row["问题"]
|
||||||
query_class = row.get("问题分类", "")
|
query_class = row.get("问题分类", "")
|
||||||
rewrite = row.get("问题改写", "")
|
rewrite = row.get("问题改写", "")
|
||||||
slot_info = row.get("槽点信息", "")
|
slot_info = row.get("槽位信息", "")
|
||||||
retrieve_content = row.get("检索的内容", "")
|
retrieve_content = row.get("检索的内容", "")
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
@@ -359,15 +357,16 @@ class ExcelDataValidator:
|
|||||||
if len(query_class_list) >= 2:
|
if len(query_class_list) >= 2:
|
||||||
result = self.validate_classification(llm, rewrite, query_class_list[0], query_class_list[1])
|
result = self.validate_classification(llm, rewrite, query_class_list[0], query_class_list[1])
|
||||||
if isinstance(result, tuple) and len(result) >= 3:
|
if isinstance(result, tuple) and len(result) >= 3:
|
||||||
is_correct, error_reason, confidence_score = result[:3]
|
is_correct, error_reason, classification_confidence = result[:3]
|
||||||
|
confidence_score = max(confidence_score, classification_confidence)
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
logging.info(f" 问题分类验证结果: {'通过' if is_correct else '不通过'}, 置信度: {confidence_score:.2f}")
|
logging.info(f" 问题分类验证结果: {'通过' if is_correct else '不通过'}, 置信度: {classification_confidence:.2f}")
|
||||||
if not is_correct:
|
if not is_correct:
|
||||||
logging.info(f" 错误原因: {error_reason}")
|
logging.info(f" 错误原因: {error_reason}")
|
||||||
|
|
||||||
if not is_correct:
|
if not is_correct:
|
||||||
return index, False, "问题分类", error_reason, confidence_score
|
return index, False, "问题分类", error_reason, classification_confidence
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -416,13 +415,6 @@ class ExcelDataValidator:
|
|||||||
logging.error(error_msg, exc_info=True)
|
logging.error(error_msg, exc_info=True)
|
||||||
return index, False, "处理错误", error_msg, 0.0
|
return index, False, "处理错误", error_msg, 0.0
|
||||||
|
|
||||||
def process_batch(self, llm, batch_data):
|
|
||||||
"""处理一批数据"""
|
|
||||||
results = []
|
|
||||||
for row_data in batch_data:
|
|
||||||
results.append(self.validate_row(llm, row_data))
|
|
||||||
return results
|
|
||||||
|
|
||||||
def create_llm_instances(self, count):
|
def create_llm_instances(self, count):
|
||||||
"""创建多个LLM实例"""
|
"""创建多个LLM实例"""
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
@@ -437,7 +429,7 @@ class ExcelDataValidator:
|
|||||||
|
|
||||||
return [OpenAiLLM(**llm_params) for _ in range(count)]
|
return [OpenAiLLM(**llm_params) for _ in range(count)]
|
||||||
|
|
||||||
def validate(self, input_file=None, output_file=None, workers=None, batch_size=None, debug=None):
|
def validate(self, input_file=None, output_file=None, workers=None, debug=None):
|
||||||
"""
|
"""
|
||||||
执行验证过程
|
执行验证过程
|
||||||
|
|
||||||
@@ -445,7 +437,7 @@ class ExcelDataValidator:
|
|||||||
input_file: 输入Excel文件路径
|
input_file: 输入Excel文件路径
|
||||||
output_file: 输出结果Excel文件路径
|
output_file: 输出结果Excel文件路径
|
||||||
workers: 并行工作线程数
|
workers: 并行工作线程数
|
||||||
batch_size: 每批处理的行数
|
batch_size: 每批处理的行数(已弃用,保留参数保持兼容)
|
||||||
debug: 是否启用调试模式(串行处理)
|
debug: 是否启用调试模式(串行处理)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -454,7 +446,6 @@ class ExcelDataValidator:
|
|||||||
input_file = input_file or self.input_file
|
input_file = input_file or self.input_file
|
||||||
output_file = output_file or self.output_file
|
output_file = output_file or self.output_file
|
||||||
workers = workers or self.workers
|
workers = workers or self.workers
|
||||||
batch_size = batch_size or self.batch_size
|
|
||||||
debug = debug if debug is not None else self.debug
|
debug = debug if debug is not None else self.debug
|
||||||
|
|
||||||
# 读取数据
|
# 读取数据
|
||||||
@@ -492,21 +483,20 @@ class ExcelDataValidator:
|
|||||||
# 输出当前结果
|
# 输出当前结果
|
||||||
logging.info(f"行 {index} 验证结果: {'通过' if is_correct else '不通过'}, 错误环节: {error_phase}, 错误原因: {error_reason}, 置信度: {confidence_score:.2f}")
|
logging.info(f"行 {index} 验证结果: {'通过' if is_correct else '不通过'}, 错误环节: {error_phase}, 错误原因: {error_reason}, 置信度: {confidence_score:.2f}")
|
||||||
else:
|
else:
|
||||||
# 正常模式:并行处理
|
# 正常模式:并行处理,每行单独处理
|
||||||
batches = [all_rows[i:i+batch_size] for i in range(0, len(all_rows), batch_size)]
|
llm_instances = self.create_llm_instances(min(workers, len(all_rows)))
|
||||||
llm_instances = self.create_llm_instances(min(workers, len(batches)))
|
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
|
||||||
# 为每个批次分配一个LLM实例
|
# 为每行分配一个LLM实例
|
||||||
future_to_batch = {
|
future_to_row = {
|
||||||
executor.submit(self.process_batch, llm_instances[i % len(llm_instances)], batch):
|
executor.submit(self.validate_row, llm_instances[i % len(llm_instances)], row_data):
|
||||||
i for i, batch in enumerate(batches)
|
i for i, row_data in enumerate(all_rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
# 使用tqdm显示进度条
|
# 使用tqdm显示进度条
|
||||||
for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(batches), desc="批次处理进度"):
|
for future in tqdm(concurrent.futures.as_completed(future_to_row), total=len(all_rows), desc="处理进度"):
|
||||||
batch_results = future.result()
|
result = future.result()
|
||||||
all_results.extend(batch_results)
|
all_results.append(result)
|
||||||
|
|
||||||
# 按行索引排序结果,确保与原始数据顺序一致
|
# 按行索引排序结果,确保与原始数据顺序一致
|
||||||
all_results.sort(key=lambda x: x[0])
|
all_results.sort(key=lambda x: x[0])
|
||||||
@@ -558,16 +548,14 @@ class ExcelDataValidator:
|
|||||||
def main():
|
def main():
|
||||||
"""主函数"""
|
"""主函数"""
|
||||||
# 解析命令行参数
|
# 解析命令行参数
|
||||||
input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "1500条点踩软件问题测试_检索结果.xlsx")
|
input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "1500条点踩软件问题测试_意图分类.xlsx")
|
||||||
output_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "自动验证_问题分类重写结果.xlsx")
|
output_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "自动验证_问题分类重写结果.xlsx")
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写")
|
parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写")
|
||||||
parser.add_argument("--input", "-i", type=str, help="输入Excel文件路径", default=input_excel)
|
parser.add_argument("--input", "-i", type=str, help="输入Excel文件路径", default=input_excel)
|
||||||
parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
|
parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
|
||||||
parser.add_argument("--workers", "-w", type=int, default=20, help="并行工作线程数")
|
parser.add_argument("--workers", "-w", type=int, default=20, help="并行工作线程数")
|
||||||
parser.add_argument("--batch-size", "-b", type=int, default=5, help="每批处理的行数")
|
logging.info(f"输入文件路径: {args.input}, 输出文件路径: {args.output}, 并行工作线程数: {args.workers}")
|
||||||
parser.add_argument("--debug", "-d", action="store_true", help="启用调试模式(串行处理)")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||||
|
|
||||||
@@ -576,7 +564,6 @@ def main():
|
|||||||
input_file=args.input,
|
input_file=args.input,
|
||||||
output_file=args.output,
|
output_file=args.output,
|
||||||
workers=args.workers,
|
workers=args.workers,
|
||||||
batch_size=args.batch_size,
|
|
||||||
debug=is_debug
|
debug=is_debug
|
||||||
)
|
)
|
||||||
validator.validate()
|
validator.validate()
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ class SiliconFlowReRankerModel:
|
|||||||
results = response.json()
|
results = response.json()
|
||||||
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
|
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
logging.error(f"重排序请求失败: {str(e)}")
|
logging.error(f"重排序请求失败: {str(e)}", exc_info=True)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
class XinferenceReRankerModel:
|
class XinferenceReRankerModel:
|
||||||
|
|||||||
Reference in New Issue
Block a user