Compare commits
3 Commits
644b745ab2
...
eaf39e7bae
| Author | SHA1 | Date | |
|---|---|---|---|
| eaf39e7bae | |||
| a1739cb703 | |||
| c28fe97dcc |
+11
-8
@@ -1,11 +1,7 @@
|
||||
sk-kvgfuqeqvpmfsccykyoohheshclcrtvjlnewratvrjpkpbkc
|
||||
sk-mxbapcczwjsyrictwgigxgcvdgptyfrlynrewqioegqwrggv
|
||||
sk-dujjzxrknevesbagqgqmuffxsosjoueviubnmodoormlmlzt
|
||||
sk-lsukfggzghmdhtfhqcbmlfqabbtapwpuxnvtwshqqqlaesie
|
||||
sk-aulumxzhvaladchcwgmsxidtdsvzytbpvzqgfuvcxlwbwcgl
|
||||
sk-otxxemniwhxkdvroszmmkitswwuykosnqoldrkzdoflqpgvw
|
||||
sk-zlruqobfdbjebyyvkmehakpcvfgnlfbdlbfrepusazzckbnv
|
||||
sk-zryimztrlkgvcaiolarhvbcewmhwruhqfcndbylonzlqvdox
|
||||
sk-rczjqufgdisqplkrmvhaxmdgcboluvxympvzljlreuqeeviq
|
||||
sk-xfnvcksdgwufsktvmhpqrwpgovsxxtaeehtxnaqjtxmubqzl
|
||||
sk-wwonvjnowbcxmoyoluynnkjwerghspzdulyidskunkordaft
|
||||
@@ -16,10 +12,8 @@ sk-lvtdgodiaurqyiwdxtdrgxifguychhccqlqkhqctscvqbfgi
|
||||
sk-aedlbtlmqcttxwnvlfmxzaysamocamqxjceoyqjfgpcowybw
|
||||
sk-fahdvndjblyjlizamvwcrxnilsgmbgbvwssxgquhkezgpqne
|
||||
sk-tzludgttzxvpvwayazdbppbauvathdtccafjrhojpemucgyi
|
||||
sk-hrbroidbfusidwnsmxenuzljxgdzzxiimlezygxplavnxjik
|
||||
sk-ylgoiqxmtxeojdnonthxtweungyzldaqarvjxlqyztlvyrff
|
||||
sk-asuqbqwdhjcqnvtjlwufyrkrwkobnrbmukzarvcctsgjipdp
|
||||
sk-dpgpymiydutoexgvkajwgahagnfmcqzafwulccudnzvleifz
|
||||
sk-nbksjgcngsayoumnsdbkcpnqivnvxjenwpzuazzrkhnsgeoo
|
||||
sk-jgybgyayxlwoxeijgrjcneqlyusleohgbliuwpsuhocrjsmk
|
||||
sk-wzjsmwxcbbpcrqivqfzjwufqqjtlwejtncnvbpeicznkwiuh
|
||||
@@ -28,7 +22,6 @@ sk-fcsfmyivfuojsqsditvobfqprdpeunukycpcfnoxkraqevpx
|
||||
sk-szyjgyxrcvyxpvzfwgmbxnflxngxvcplitcctsdvvrqjgftk
|
||||
sk-jzbodthsnvjwbyrnynsxrudtqfnbdbrcxebjwjgajocnzqse
|
||||
sk-fxepossfzpmccibfwqpkluorzqlbtcaplepeugtfzfsctcbl
|
||||
sk-ympnflocrkxjrbubsxqdjqwicuyavvvysctlpfhunkcrzxjx
|
||||
sk-flhqvziknntednkcgjaxlyzzsrfzjhrzrmteqonajpbiinni
|
||||
sk-xfregpbbquqbxpiobjzanydsjivrjrnbokzxcqtnhxhyghhe
|
||||
sk-jrdzerhmvrtvzawkksowbgkggkubwfquplmrxbdhespqgtis
|
||||
@@ -112,4 +105,14 @@ sk-hpgusoznbejkugxugsstyezaihqatwmwiwelwjwudekoxlbq
|
||||
sk-yyqtojpqfrtkyvmhbjvnzhuujzgvitqpuxkytgxspwqscptz
|
||||
sk-vcwenoaiegiwqhdvalxrvttwqmcrudttpfqhlvtdocsxvvob
|
||||
sk-npnyztbrdaqtbnldlvfelkfbfozykuwykkkfsrsisytusymj
|
||||
sk-rjaadeqsgoskclkuqkkyjxdknngrdsjwlgucqutiskroprzk
|
||||
sk-rjaadeqsgoskclkuqkkyjxdknngrdsjwlgucqutiskroprzk
|
||||
sk-uollmeyatyiwfzszvxkpyndmzfrbqjpyixewmrastbmaqbhy
|
||||
sk-xdlsjytiwilvodadkjxvwdgulhhdytkqvfpyrcnllclgzqkb
|
||||
sk-ffkltifkylutornjhwmnmfjsqsywrjibvujhjtjctzgnkvlp
|
||||
sk-vmwocqqjqxnsvzmeyvqskahjaclifpmsbhywvnrvwygkfyuj
|
||||
sk-gzwkmzxeeunaywrdrgirdatqhdtqdgvzqpesvprwbbjhcchn
|
||||
sk-duchutcxmygrnkhzmmlykvtzwaylqtdxfbbuhvfvzuapazii
|
||||
sk-nlddwexmjxqtgdvahwvlotnomrzcgskxeakxkxauicknzfkp
|
||||
sk-lopwluipwvilwpwztvaxfebueeyilefwgncgpeprqvwazxom
|
||||
sk-rgwrklpvhhrluokkbgavzukuhhpfhqzmozpjzoezfhkxyorc
|
||||
sk-cdrpglnfmyeeqyhtvxvkpcpwscsbfouwkagjpphuksfzeipy
|
||||
@@ -9524,7 +9524,7 @@
|
||||
"description": "2020年发布的定额标准"
|
||||
},
|
||||
{
|
||||
"name": "技改检修计价通T1软件",
|
||||
"name": "技改检修工程计价通T1软件",
|
||||
"synonymous": [
|
||||
"技改T1"
|
||||
],
|
||||
|
||||
Binary file not shown.
+1
-1
@@ -1,5 +1,5 @@
|
||||
[project]
|
||||
name = "queryrewrite"
|
||||
name = "rag2_0"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
|
||||
@@ -8,13 +8,17 @@ import concurrent.futures
|
||||
from functools import wraps
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
import httpx
|
||||
import traceback
|
||||
import re
|
||||
import logging
|
||||
|
||||
# 将项目根目录添加到Python路径
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# 配置日志
|
||||
@@ -31,7 +35,7 @@ logger = logging.getLogger("dialogue_to_workorder")
|
||||
# ================ 模型定义 ================
|
||||
class UserQuestionAndSolution(BaseModel):
|
||||
user_question: str = Field(description="用户的核心问题")
|
||||
solution: str = Field(description="坐席提供的解决方案")
|
||||
solution: str = Field(description="坐席提供的解决方案,解决方案如果存在多个步骤,使用中文分号隔开")
|
||||
|
||||
class UserQuestionAndSolutionList(BaseModel):
|
||||
user_question_list: list[UserQuestionAndSolution] = Field(description="客户问题列表")
|
||||
@@ -236,7 +240,7 @@ class DialogueToWorkorder:
|
||||
```json
|
||||
{{
|
||||
"user_question": "技改软件打开报错",
|
||||
"solution": "1、告知报错原因 2、通过远程辅助解决"
|
||||
"solution": "1、告知报错原因;2、通过远程辅助解决"
|
||||
}}
|
||||
```
|
||||
=======对话记录如下所示=======
|
||||
@@ -426,14 +430,15 @@ class DialogueToWorkorder:
|
||||
# 分析用户问题和解决方案
|
||||
user_question_list = self.get_user_question_and_solution(conversation_rows)
|
||||
|
||||
# 获取第一个问题和解决方案,用于后续分析
|
||||
if user_question_list and len(user_question_list) > 0:
|
||||
first_question = user_question_list[0]
|
||||
user_question_str = first_question.user_question
|
||||
solution_str = first_question.solution
|
||||
else:
|
||||
user_question_str = ""
|
||||
solution_str = ""
|
||||
user_question_str=""
|
||||
for user_question in user_question_list:
|
||||
user_question_str = user_question_str + user_question.user_question.strip() + "\n"
|
||||
user_question_str = user_question_str.strip()
|
||||
|
||||
solution_str=""
|
||||
for user_question in user_question_list:
|
||||
solution_str = solution_str + user_question.solution.strip() + "\n"
|
||||
solution_str = solution_str.strip()
|
||||
|
||||
# 分析是否抱怨、是否投诉、抱怨级别
|
||||
is_dissatisfaction, dissatisfaction_level, dissatisfaction_reasoning, is_complaint = (
|
||||
@@ -455,15 +460,9 @@ class DialogueToWorkorder:
|
||||
# 创建工单列表
|
||||
workorder_list = []
|
||||
|
||||
for user_question in user_question_list:
|
||||
user_question_str = user_question.user_question
|
||||
solution_str = user_question.solution
|
||||
|
||||
# 创建新的工单字典,复制基本信息
|
||||
workorder_dict = base_workorder_dict.copy()
|
||||
|
||||
# 更新工单字典
|
||||
workorder_dict.update({
|
||||
# 更新工单字典
|
||||
base_workorder_dict.update({
|
||||
"产品线": product_line,
|
||||
"产品名称": product_name,
|
||||
"模块名称": module_name,
|
||||
@@ -474,9 +473,29 @@ class DialogueToWorkorder:
|
||||
"是否投诉": "是" if is_complaint else '否',
|
||||
"解决方案": (solution_str + '\n存在抱怨:' + dissatisfaction_reasoning) if is_dissatisfaction else solution_str
|
||||
})
|
||||
workorder_list.append(base_workorder_dict)
|
||||
# for user_question in user_question_list:
|
||||
# user_question_str = user_question.user_question
|
||||
# solution_str = user_question.solution
|
||||
|
||||
# 将工单添加到列表中
|
||||
workorder_list.append(workorder_dict)
|
||||
# # 创建新的工单字典,复制基本信息
|
||||
# workorder_dict = base_workorder_dict.copy()
|
||||
|
||||
# # 更新工单字典
|
||||
# workorder_dict.update({
|
||||
# "产品线": product_line,
|
||||
# "产品名称": product_name,
|
||||
# "模块名称": module_name,
|
||||
# "客户问题": user_question_str,
|
||||
# "问题类型": problem_type,
|
||||
# "是否抱怨": "是" if is_dissatisfaction else '否',
|
||||
# "抱怨级别": dissatisfaction_level if is_dissatisfaction else '',
|
||||
# "是否投诉": "是" if is_complaint else '否',
|
||||
# "解决方案": (solution_str + '\n存在抱怨:' + dissatisfaction_reasoning) if is_dissatisfaction else solution_str
|
||||
# })
|
||||
|
||||
# # 将工单添加到列表中
|
||||
# workorder_list.append(workorder_dict)
|
||||
|
||||
return workorder_list
|
||||
|
||||
|
||||
@@ -8,8 +8,6 @@ Description: 意图识别和问题改写示例
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from regex import F
|
||||
from rag2_0.intent_recognition import IntentRecognizer
|
||||
import pandas as pd
|
||||
import logging
|
||||
import json
|
||||
@@ -19,6 +17,8 @@ import time
|
||||
import sys
|
||||
import argparse
|
||||
from typing import List, Dict
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.intent_recognition import IntentRecognizer
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
@@ -176,7 +176,7 @@ def save_results_to_excel(results, output_file, is_final=False):
|
||||
logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}")
|
||||
|
||||
# 示例查询
|
||||
examples_query = """.BDD3是哪款软件编制的"""
|
||||
examples_query = """D3软件结算工程怎么解锁清单"""
|
||||
conversation_context=""
|
||||
chat_history=[
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@ Description: 使用LLM批量验证Excel数据中的问题分类、问题拆解
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pandas as pd
|
||||
import json
|
||||
import argparse
|
||||
@@ -14,13 +15,25 @@ import concurrent.futures
|
||||
from tqdm import tqdm
|
||||
from dotenv import load_dotenv
|
||||
from langchain_openai import ChatOpenAI
|
||||
from rag2_0.intent_recognition.PromptTemplates import classification
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.intent_recognition.PromptTemplates import classification_info
|
||||
from rag2_0.intent_recognition.DataModels import *
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
|
||||
|
||||
# 定义验证结果的Pydantic模型
|
||||
class ValidationResult(BaseModel):
|
||||
is_correct: bool = Field(description="验证是否通过")
|
||||
confidence_score: float = Field(description="置信度得分")
|
||||
reason: str = Field(default="", description="得出结论的原因")
|
||||
|
||||
class ExcelDataValidator:
|
||||
"""Excel数据验证类,用于批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写"""
|
||||
|
||||
def __init__(self, input_file=None, output_file=None, workers=4, batch_size=10):
|
||||
def __init__(self, input_file=None, output_file=None, workers=4, batch_size=10, debug=False):
|
||||
"""
|
||||
初始化验证器
|
||||
|
||||
@@ -29,6 +42,7 @@ class ExcelDataValidator:
|
||||
output_file: 输出结果Excel文件路径
|
||||
workers: 并行工作线程数
|
||||
batch_size: 每批处理的行数
|
||||
debug: 是否启用调试模式(串行处理)
|
||||
"""
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
@@ -37,6 +51,7 @@ class ExcelDataValidator:
|
||||
self.output_file = output_file
|
||||
self.workers = workers
|
||||
self.batch_size = batch_size
|
||||
self.debug = debug
|
||||
self.df = None
|
||||
|
||||
# 设置日志
|
||||
@@ -71,7 +86,7 @@ class ExcelDataValidator:
|
||||
|
||||
try:
|
||||
df = pd.read_excel(file_path)
|
||||
required_columns = ["提问", "问题拆解", "一级分类", "二级分类", "问题改写", "检索的关键词"]
|
||||
required_columns = ["问题", "问题分类", "问题改写", "槽点信息", "检索的内容"]
|
||||
for col in required_columns:
|
||||
if col not in df.columns:
|
||||
logging.error(f"缺少必要的列: {col}")
|
||||
@@ -94,110 +109,168 @@ class ExcelDataValidator:
|
||||
sub_class: 二级分类
|
||||
|
||||
Returns:
|
||||
(bool, str): 是否正确,错误原因(如果有)
|
||||
(bool, str, float): 是否正确,错误原因(如果有),置信度
|
||||
"""
|
||||
parser = self.create_validation_parser()
|
||||
format_instructions = parser.get_format_instructions()
|
||||
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用,也可能涉及电力造价专业知识。我对用户问题进行了分类,请评估以下问题分类是否正确。
|
||||
|
||||
我目前总共有以下分类:
|
||||
{classification}
|
||||
{classification_info}
|
||||
|
||||
问题的分类情况如下:
|
||||
原始问题: {query}
|
||||
一级分类: {vertical_class}
|
||||
二级分类: {sub_class}
|
||||
|
||||
请从专业角度分析这个分类是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""
|
||||
请从专业角度分析这个分类是否准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
|
||||
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = response.content.strip()
|
||||
|
||||
if result.startswith("正确"):
|
||||
return True, ""
|
||||
else:
|
||||
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
|
||||
return False, error_reason
|
||||
result = parser.parse(response.content)
|
||||
return result.is_correct, result.reason, result.confidence_score
|
||||
except Exception as e:
|
||||
logging.warning(f"验证问题分类时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}"
|
||||
return False, f"验证过程出错: {str(e)}", 0.0
|
||||
|
||||
def validate_query_keys(self, llm, query, query_keys):
|
||||
def _get_slot_model(self, classification: Classification) -> Optional[type]:
|
||||
"""
|
||||
验证问题拆解是否正确
|
||||
根据分类结果获取对应的槽位模型类,用于统一提示词处理
|
||||
|
||||
Args:
|
||||
classification: 意图分类结果
|
||||
|
||||
Returns:
|
||||
对应的槽位模型类
|
||||
"""
|
||||
# 软件问题
|
||||
if classification.vertical_classification == "软件问题":
|
||||
if classification.sub_classification == "软件功能":
|
||||
return SoftwareFunctionSlots
|
||||
elif classification.sub_classification == "故障排查":
|
||||
return SoftwareTroubleShootingSlots
|
||||
|
||||
# 业务问题
|
||||
elif classification.vertical_classification == "业务问题":
|
||||
if classification.sub_classification == "专业咨询":
|
||||
return ProfessionalConsultingSlots
|
||||
elif classification.sub_classification == "数据问题":
|
||||
return DataProblemSlots
|
||||
|
||||
# 安装下载注册
|
||||
elif classification.vertical_classification == "安装下载注册":
|
||||
if classification.sub_classification == "后缀名咨询":
|
||||
return FileExtensionConsultingSlots
|
||||
elif classification.sub_classification == "软件锁类":
|
||||
return SoftwareLockSlots
|
||||
elif classification.sub_classification == "安装下载类":
|
||||
return InstallationDownloadSlots
|
||||
elif classification.sub_classification == "问题排查类":
|
||||
return ProblemDiagnosisSlots
|
||||
|
||||
# 其他
|
||||
elif classification.vertical_classification == "其他":
|
||||
return OtherSlots
|
||||
|
||||
return None
|
||||
|
||||
def validate_slot(self, llm, rewrite, slot_info, vertical_class, sub_class):
|
||||
"""
|
||||
验证槽位填充是否正确
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
query: 原始问题
|
||||
query_keys: 问题拆解
|
||||
rewrite: 问题改写
|
||||
slot_info: 槽位信息(JSON字符串)
|
||||
|
||||
Returns:
|
||||
(bool, str): 是否正确,错误原因(如果有)
|
||||
(bool, str, float): 是否正确,错误原因(如果有),置信度
|
||||
"""
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了拆解,请评估以下问题拆解是否正确。
|
||||
|
||||
原始问题: {query}
|
||||
问题拆解: {query_keys}
|
||||
|
||||
问题拆解应该准确提取原始问题中的关键词和信息。请分析这个拆解是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""
|
||||
|
||||
# 解析槽位信息JSON
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = response.content.strip()
|
||||
|
||||
if result.startswith("正确"):
|
||||
return True, ""
|
||||
if isinstance(slot_info, str) and slot_info.strip():
|
||||
slots = json.loads(slot_info)
|
||||
else:
|
||||
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
|
||||
return False, error_reason
|
||||
except Exception as e:
|
||||
logging.warning(f"验证问题拆解时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}"
|
||||
|
||||
def validate_keywords(self, llm, query, query_keys, keywords_str):
|
||||
"""
|
||||
验证检索关键词是否准确
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
query: 原始问题
|
||||
query_keys: 问题拆解
|
||||
keywords_str: 检索关键词(JSON字符串)
|
||||
|
||||
Returns:
|
||||
(bool, str): 是否正确,错误原因(如果有)
|
||||
"""
|
||||
# 解析关键词JSON
|
||||
try:
|
||||
if isinstance(keywords_str, str) and keywords_str.strip():
|
||||
keywords = json.loads(keywords_str)
|
||||
else:
|
||||
keywords = []
|
||||
slots = slot_info
|
||||
except:
|
||||
keywords = keywords_str
|
||||
slots = slot_info
|
||||
|
||||
parser = self.create_validation_parser()
|
||||
format_instructions = parser.get_format_instructions()
|
||||
slot_info_prompt = self._get_slot_model(Classification(vertical_classification=vertical_class, sub_classification=sub_class)).model_json_schema()
|
||||
slot_info_prompt = json.dumps(slot_info_prompt, ensure_ascii=False)
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。通过问题检索出了一些关键词,请评估这些关键词是否准确,是否与问题相关
|
||||
原始问题: {query}
|
||||
问题拆解: {query_keys}
|
||||
检索关键词: {keywords}
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我从用户问题中提取了槽位信息,请评估这些槽位信息是否准确、完整。
|
||||
|
||||
检索关键词应该准确反映问题中需要检索的关键概念和术语。请分析这些关键词是否准确、完整。只需返回"正确"或"错误:原因",不需要其他解释。"""
|
||||
问题改写: {rewrite}
|
||||
槽位模板:{slot_info_prompt}
|
||||
|
||||
填充的槽位信息: {slots}
|
||||
|
||||
槽位信息应该准确提取问题中的关键实体和属性,如软件名称、功能名称、错误信息等。请分析这些槽位是否准确填充,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
|
||||
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = response.content.strip()
|
||||
|
||||
if result.startswith("正确"):
|
||||
return True, ""
|
||||
else:
|
||||
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
|
||||
return False, error_reason
|
||||
result = parser.parse(response.content)
|
||||
return result.is_correct, result.reason, result.confidence_score
|
||||
except Exception as e:
|
||||
logging.warning(f"验证检索关键词时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}"
|
||||
|
||||
logging.warning(f"验证槽位填充时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}", 0.0
|
||||
|
||||
def validate_retrieve_content(self, llm, rewrite, retrieve_content):
|
||||
"""
|
||||
验证检索内容是否正确
|
||||
|
||||
Args:
|
||||
llm: LLM模型
|
||||
rewrite: 问题改写
|
||||
retrieve_content: 检索内容(可能是JSON字符串或文本)
|
||||
|
||||
Returns:
|
||||
(bool, str, float): 是否正确,错误原因(如果有),置信度
|
||||
"""
|
||||
# 解析检索内容
|
||||
try:
|
||||
if isinstance(retrieve_content, str) and retrieve_content.strip():
|
||||
if retrieve_content.startswith('{') or retrieve_content.startswith('['):
|
||||
content = json.loads(retrieve_content)
|
||||
else:
|
||||
content = retrieve_content
|
||||
else:
|
||||
content = retrieve_content
|
||||
except:
|
||||
content = retrieve_content
|
||||
|
||||
parser = self.create_validation_parser()
|
||||
format_instructions = parser.get_format_instructions()
|
||||
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我针对用户问题检索了相关内容,请评估这些检索内容是否与问题相关、是否准确。
|
||||
|
||||
问题改写: {rewrite}
|
||||
检索内容: {content}
|
||||
|
||||
检索内容应该与问题主题相关,能够提供有用的信息来回答问题。请分析检索内容是否相关、准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
|
||||
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = parser.parse(response.content)
|
||||
return result.is_correct, result.reason, result.confidence_score
|
||||
except Exception as e:
|
||||
logging.warning(f"验证检索内容时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}", 0.0
|
||||
|
||||
def validate_rewrite(self, llm, query, rewrite):
|
||||
"""
|
||||
验证问题改写是否正确
|
||||
@@ -208,28 +281,29 @@ class ExcelDataValidator:
|
||||
rewrite: 问题改写
|
||||
|
||||
Returns:
|
||||
(bool, str): 是否正确,错误原因(如果有)
|
||||
(bool, str, float): 是否正确,错误原因(如果有),置信度
|
||||
"""
|
||||
parser = self.create_validation_parser()
|
||||
format_instructions = parser.get_format_instructions()
|
||||
|
||||
prompt = f"""
|
||||
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了改写,请评估以下问题改写是否正确。
|
||||
|
||||
原始问题: {query}
|
||||
问题改写: {rewrite}
|
||||
|
||||
问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""
|
||||
问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
|
||||
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
result = response.content.strip()
|
||||
|
||||
if result.startswith("正确"):
|
||||
return True, ""
|
||||
else:
|
||||
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
|
||||
return False, error_reason
|
||||
result = parser.parse(response.content)
|
||||
return result.is_correct, result.reason, result.confidence_score
|
||||
except Exception as e:
|
||||
logging.warning(f"验证问题改写时出错: {e}")
|
||||
return False, f"验证过程出错: {str(e)}"
|
||||
return False, f"验证过程出错: {str(e)}", 0.0
|
||||
|
||||
def validate_row(self, llm, row_data):
|
||||
"""
|
||||
@@ -240,42 +314,110 @@ class ExcelDataValidator:
|
||||
row_data: (index, row)元组
|
||||
|
||||
Returns:
|
||||
(index, is_all_correct, error_phase, error_reason): 行索引,是否全部正确,错误环节,错误原因
|
||||
(index, is_all_correct, error_phase, error_reason, confidence_score): 行索引,是否全部正确,错误环节,错误原因,置信度
|
||||
"""
|
||||
index, row = row_data
|
||||
query = row["提问"]
|
||||
query_keys = row["问题拆解"]
|
||||
vertical_class = row["一级分类"]
|
||||
sub_class = row["二级分类"]
|
||||
rewrite = row["问题改写"]
|
||||
keywords_str = row["检索的关键词"]
|
||||
query = row["问题"]
|
||||
query_class = row.get("问题分类", "")
|
||||
rewrite = row.get("问题改写", "")
|
||||
slot_info = row.get("槽点信息", "")
|
||||
retrieve_content = row.get("检索的内容", "")
|
||||
|
||||
if self.debug:
|
||||
logging.info(f"开始验证行 {index}:")
|
||||
logging.info(f" 问题: {query}")
|
||||
logging.info(f" 问题分类: {query_class}")
|
||||
logging.info(f" 问题改写: {rewrite}")
|
||||
|
||||
try:
|
||||
# 1. 验证问题分类
|
||||
is_correct, error_reason = self.validate_classification(llm, query, vertical_class, sub_class)
|
||||
if not is_correct:
|
||||
return index, False, "问题分类", error_reason
|
||||
|
||||
confidence_score = 0.0
|
||||
# 1. 验证问题改写
|
||||
if rewrite:
|
||||
if self.debug:
|
||||
logging.info(f" 验证问题改写...")
|
||||
|
||||
result = self.validate_rewrite(llm, query, rewrite)
|
||||
if isinstance(result, tuple) and len(result) >= 3:
|
||||
is_correct, error_reason, rewrite_confidence = result[:3]
|
||||
confidence_score = max(confidence_score, rewrite_confidence)
|
||||
|
||||
if self.debug:
|
||||
logging.info(f" 问题改写验证结果: {'通过' if is_correct else '不通过'}, 置信度: {rewrite_confidence:.2f}")
|
||||
if not is_correct:
|
||||
logging.info(f" 错误原因: {error_reason}")
|
||||
|
||||
if not is_correct:
|
||||
return index, False, "问题改写", error_reason, rewrite_confidence
|
||||
|
||||
# 2. 验证问题分类
|
||||
if query_class:
|
||||
if self.debug:
|
||||
logging.info(f" 验证问题分类...")
|
||||
|
||||
query_class_list = query_class.split(" - ")
|
||||
if len(query_class_list) >= 2:
|
||||
result = self.validate_classification(llm, rewrite, query_class_list[0], query_class_list[1])
|
||||
if isinstance(result, tuple) and len(result) >= 3:
|
||||
is_correct, error_reason, confidence_score = result[:3]
|
||||
|
||||
if self.debug:
|
||||
logging.info(f" 问题分类验证结果: {'通过' if is_correct else '不通过'}, 置信度: {confidence_score:.2f}")
|
||||
if not is_correct:
|
||||
logging.info(f" 错误原因: {error_reason}")
|
||||
|
||||
if not is_correct:
|
||||
return index, False, "问题分类", error_reason, confidence_score
|
||||
|
||||
# 2. 验证问题拆解
|
||||
is_correct, error_reason = self.validate_query_keys(llm, query, query_keys)
|
||||
if not is_correct:
|
||||
return index, False, "问题拆解", error_reason
|
||||
|
||||
|
||||
# 3. 验证检索关键词
|
||||
is_correct, error_reason = self.validate_keywords(llm, query, query_keys, keywords_str)
|
||||
if not is_correct:
|
||||
return index, False, "关键词检索", error_reason
|
||||
# 3. 验证槽位填充
|
||||
if slot_info:
|
||||
if self.debug:
|
||||
logging.info(f" 验证槽位填充...")
|
||||
|
||||
result = self.validate_slot(llm, rewrite, slot_info, query_class_list[0], query_class_list[1])
|
||||
if isinstance(result, tuple) and len(result) >= 3:
|
||||
is_correct, error_reason, slot_confidence = result[:3]
|
||||
confidence_score = max(confidence_score, slot_confidence)
|
||||
|
||||
if self.debug:
|
||||
logging.info(f" 槽位填充验证结果: {'通过' if is_correct else '不通过'}, 置信度: {slot_confidence:.2f}")
|
||||
if not is_correct:
|
||||
logging.info(f" 错误原因: {error_reason}")
|
||||
|
||||
if not is_correct:
|
||||
return index, False, "槽位填充", error_reason, slot_confidence
|
||||
|
||||
# 4. 验证问题改写
|
||||
is_correct, error_reason = self.validate_rewrite(llm, query, rewrite)
|
||||
if not is_correct:
|
||||
return index, False, "问题改写", error_reason
|
||||
# 4. 验证检索内容
|
||||
if retrieve_content:
|
||||
if self.debug:
|
||||
logging.info(f" 验证检索内容...")
|
||||
|
||||
result = self.validate_retrieve_content(llm, rewrite, retrieve_content)
|
||||
if isinstance(result, tuple) and len(result) >= 3:
|
||||
is_correct, error_reason, retrieve_confidence = result[:3]
|
||||
confidence_score = max(confidence_score, retrieve_confidence)
|
||||
|
||||
if self.debug:
|
||||
logging.info(f" 检索内容验证结果: {'通过' if is_correct else '不通过'}, 置信度: {retrieve_confidence:.2f}")
|
||||
if not is_correct:
|
||||
logging.info(f" 错误原因: {error_reason}")
|
||||
|
||||
if not is_correct:
|
||||
return index, False, "检索内容", error_reason, retrieve_confidence
|
||||
|
||||
if self.debug:
|
||||
logging.info(f" 行 {index} 验证完成: 通过, 总置信度: {confidence_score:.2f}")
|
||||
|
||||
return index, True, "", ""
|
||||
return index, True, "", "", confidence_score
|
||||
except Exception as e:
|
||||
error_msg = f"处理行 {index} 时发生错误: {str(e)}"
|
||||
logging.error(error_msg)
|
||||
return index, False, "处理错误", error_msg
|
||||
if self.debug:
|
||||
import traceback
|
||||
logging.error(traceback.format_exc())
|
||||
return index, False, "处理错误", error_msg, 0.0
|
||||
|
||||
def process_batch(self, llm, batch_data):
|
||||
"""处理一批数据"""
|
||||
@@ -288,17 +430,17 @@ class ExcelDataValidator:
|
||||
"""创建多个LLM实例"""
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
model_name = "deepseek-ai/DeepSeek-R1"
|
||||
|
||||
llm_params = {"temperature": 0.7, "model": model_name}
|
||||
if api_key:
|
||||
llm_params["api_key"] = api_key
|
||||
if base_url:
|
||||
llm_params["base_url"] = base_url
|
||||
|
||||
|
||||
return [OpenAiLLM(**llm_params) for _ in range(count)]
|
||||
|
||||
def validate(self, input_file=None, output_file=None, workers=None, batch_size=None):
|
||||
def validate(self, input_file=None, output_file=None, workers=None, batch_size=None, debug=None):
|
||||
"""
|
||||
执行验证过程
|
||||
|
||||
@@ -307,6 +449,7 @@ class ExcelDataValidator:
|
||||
output_file: 输出结果Excel文件路径
|
||||
workers: 并行工作线程数
|
||||
batch_size: 每批处理的行数
|
||||
debug: 是否启用调试模式(串行处理)
|
||||
|
||||
Returns:
|
||||
验证后的DataFrame
|
||||
@@ -315,6 +458,7 @@ class ExcelDataValidator:
|
||||
output_file = output_file or self.output_file
|
||||
workers = workers or self.workers
|
||||
batch_size = batch_size or self.batch_size
|
||||
debug = debug if debug is not None else self.debug
|
||||
|
||||
# 读取数据
|
||||
df = self.load_data_from_excel(input_file)
|
||||
@@ -325,36 +469,64 @@ class ExcelDataValidator:
|
||||
df["验证结果"] = ""
|
||||
df["错误环节"] = ""
|
||||
df["错误原因"] = ""
|
||||
df["置信度"] = 0.0
|
||||
|
||||
# 准备数据批次
|
||||
# 准备数据
|
||||
all_rows = list(df.iterrows())
|
||||
batches = [all_rows[i:i+batch_size] for i in range(0, len(all_rows), batch_size)]
|
||||
|
||||
# 创建多个LLM实例
|
||||
llm_instances = self.create_llm_instances(min(workers, len(batches)))
|
||||
# 创建LLM实例
|
||||
llm = self.create_llm_instances(1)[0]
|
||||
|
||||
# 使用线程池处理数据
|
||||
# 根据模式选择处理方式
|
||||
all_results = []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
# 为每个批次分配一个LLM实例
|
||||
future_to_batch = {
|
||||
executor.submit(self.process_batch, llm_instances[i % len(llm_instances)], batch):
|
||||
i for i, batch in enumerate(batches)
|
||||
}
|
||||
if debug:
|
||||
# 调试模式:串行处理
|
||||
logging.info("启用调试模式,使用串行处理...")
|
||||
for i, row_data in enumerate(all_rows):
|
||||
logging.info(f"处理第 {i+1}/{len(all_rows)} 行...")
|
||||
result = self.validate_row(llm, row_data)
|
||||
all_results.append(result)
|
||||
# 实时更新DataFrame
|
||||
index, is_correct, error_phase, error_reason, confidence_score = result
|
||||
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
|
||||
df.at[index, "错误环节"] = error_phase
|
||||
df.at[index, "错误原因"] = error_reason
|
||||
df.at[index, "置信度"] = confidence_score
|
||||
# 输出当前结果
|
||||
logging.info(f"行 {index} 验证结果: {'通过' if is_correct else '不通过'}, 错误环节: {error_phase}, 错误原因: {error_reason}, 置信度: {confidence_score:.2f}")
|
||||
else:
|
||||
# 正常模式:并行处理
|
||||
batches = [all_rows[i:i+batch_size] for i in range(0, len(all_rows), batch_size)]
|
||||
llm_instances = self.create_llm_instances(min(workers, len(batches)))
|
||||
|
||||
# 使用tqdm显示进度条
|
||||
for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(batches), desc="批次处理进度"):
|
||||
batch_results = future.result()
|
||||
all_results.extend(batch_results)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
# 为每个批次分配一个LLM实例
|
||||
future_to_batch = {
|
||||
executor.submit(self.process_batch, llm_instances[i % len(llm_instances)], batch):
|
||||
i for i, batch in enumerate(batches)
|
||||
}
|
||||
|
||||
# 使用tqdm显示进度条
|
||||
for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(batches), desc="批次处理进度"):
|
||||
batch_results = future.result()
|
||||
all_results.extend(batch_results)
|
||||
|
||||
# 按行索引排序结果,确保与原始数据顺序一致
|
||||
all_results.sort(key=lambda x: x[0])
|
||||
|
||||
# 将结果填充到DataFrame
|
||||
for index, is_correct, error_phase, error_reason in all_results:
|
||||
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
|
||||
df.at[index, "错误环节"] = error_phase
|
||||
df.at[index, "错误原因"] = error_reason
|
||||
for result in all_results:
|
||||
if len(result) >= 5:
|
||||
index, is_correct, error_phase, error_reason, confidence_score = result
|
||||
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
|
||||
df.at[index, "错误环节"] = error_phase
|
||||
df.at[index, "错误原因"] = error_reason
|
||||
df.at[index, "置信度"] = confidence_score
|
||||
else:
|
||||
index, is_correct, error_phase, error_reason = result
|
||||
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
|
||||
df.at[index, "错误环节"] = error_phase
|
||||
df.at[index, "错误原因"] = error_reason
|
||||
|
||||
# 保存结果
|
||||
if output_file is None:
|
||||
@@ -381,25 +553,34 @@ class ExcelDataValidator:
|
||||
for phase, count in error_stats.items():
|
||||
logging.info(f"- {phase}: {count} 条")
|
||||
|
||||
def create_validation_parser(self):
|
||||
"""创建验证结果解析器"""
|
||||
return PydanticOutputParser(pydantic_object=ValidationResult)
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
# 解析命令行参数
|
||||
input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "问题分类重写结果")
|
||||
input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "1500条点踩软件问题测试_检索结果.xlsx")
|
||||
output_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "自动验证_问题分类重写结果.xlsx")
|
||||
|
||||
parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写")
|
||||
parser.add_argument("--input", "-i", type=str, required=True, help="输入Excel文件路径", default=input_excel)
|
||||
parser.add_argument("--input", "-i", type=str, help="输入Excel文件路径", default=input_excel)
|
||||
parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
|
||||
parser.add_argument("--workers", "-w", type=int, default=2, help="并行工作线程数")
|
||||
parser.add_argument("--workers", "-w", type=int, default=20, help="并行工作线程数")
|
||||
parser.add_argument("--batch-size", "-b", type=int, default=5, help="每批处理的行数")
|
||||
parser.add_argument("--debug", "-d", action="store_true", help="启用调试模式(串行处理)")
|
||||
|
||||
args = parser.parse_args()
|
||||
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||
|
||||
# 创建验证器实例并执行验证
|
||||
validator = ExcelDataValidator(
|
||||
input_file=args.input,
|
||||
output_file=args.output,
|
||||
workers=args.workers,
|
||||
batch_size=args.batch_size
|
||||
batch_size=args.batch_size,
|
||||
debug=is_debug
|
||||
)
|
||||
validator.validate()
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ Description: 专业名词向量化和保存的示例程序
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
import sys
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.intent_recognition import ProfessionalNounVectorizer
|
||||
import logging
|
||||
|
||||
|
||||
@@ -1 +1,4 @@
|
||||
from dify_client.client import ChatClient, CompletionClient, DifyClient
|
||||
|
||||
__all__ = ["ChatClient", "CompletionClient", "DifyClient"]
|
||||
|
||||
from .client import ChatClient, CompletionClient, DifyClient
|
||||
|
||||
+107
-82
@@ -7,12 +7,13 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from rag2_0.dify.dify_client import ChatClient
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
|
||||
from threading import Lock
|
||||
|
||||
class ContentSource(BaseModel):
|
||||
score: int = Field(description="相关性分数")
|
||||
reason: str = Field(description="评分理由")
|
||||
|
||||
|
||||
class PgSql:
|
||||
"""
|
||||
用于连接和操作 PostgreSQL 数据库的类。
|
||||
@@ -21,6 +22,7 @@ class PgSql:
|
||||
主要用于从 Dify 应用相关的表中获取数据。
|
||||
"""
|
||||
def __init__(self):
|
||||
self.pg_sql_lock = Lock()
|
||||
"""
|
||||
初始化 PgSql 实例并建立数据库连接。
|
||||
"""
|
||||
@@ -53,8 +55,10 @@ class PgSql:
|
||||
|
||||
如果存在活动的连接,则关闭它。
|
||||
"""
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
with self.pg_sql_lock:
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
self.connection = None
|
||||
|
||||
|
||||
def get_appinfo(self, appid:str)->dict | None:
|
||||
@@ -68,21 +72,22 @@ class PgSql:
|
||||
一个字典,其中键是列名,值是对应的应用数据。
|
||||
如果未找到应用或发生错误,则返回 None。
|
||||
"""
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM apps WHERE id = %s
|
||||
""",
|
||||
(appid,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting tenant_id by appid: {error}")
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM apps WHERE id = %s
|
||||
""",
|
||||
(appid,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting tenant_id by appid: {error}")
|
||||
|
||||
|
||||
def get_messages_info(self, appid:str, query:str)->dict | None:
|
||||
@@ -97,41 +102,43 @@ class PgSql:
|
||||
一个字典,其中键是列名,值是对应的消息数据。
|
||||
如果未找到消息或发生错误,则返回 None。
|
||||
"""
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC
|
||||
""",
|
||||
(appid, query)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting messages_info: {error}")
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC
|
||||
""",
|
||||
(appid, query)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting messages_info: {error}")
|
||||
|
||||
def get_messages_info_by_id(self, message_id:str)->dict | None:
|
||||
"""
|
||||
根据消息 ID 从 'messages' 表中获取消息信息。
|
||||
"""
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE id = %s
|
||||
""",
|
||||
(message_id, )
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting messages_info by id: {error}")
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE id = %s
|
||||
""",
|
||||
(message_id, )
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting messages_info by id: {error}")
|
||||
|
||||
def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None:
|
||||
"""
|
||||
@@ -144,21 +151,22 @@ class PgSql:
|
||||
一个字典,其中键是列名,值是对应的节点执行数据。
|
||||
如果未找到执行信息或发生错误,则返回 None。
|
||||
"""
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s
|
||||
""",
|
||||
(workflow_run_id,)
|
||||
)
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s
|
||||
""",
|
||||
(workflow_run_id,)
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return [dict(zip(colnames, row)) for row in result]
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting workflow_node_executions_info: {error}")
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting workflow_node_executions_info: {error}")
|
||||
|
||||
class DifyTool:
|
||||
"""
|
||||
@@ -167,17 +175,30 @@ class DifyTool:
|
||||
该类利用 PgSql 类从数据库中检索与特定应用和查询相关的
|
||||
应用信息、消息详情以及工作流节点执行情况。
|
||||
"""
|
||||
@staticmethod
|
||||
def get_message_debug_info_by_id(message_id:str)->dict | None:
|
||||
|
||||
def __init__(self):
|
||||
self.dify_pgsql = PgSql()
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
析构函数,在对象被销毁时自动关闭数据库连接。
|
||||
确保在对象生命周期结束时释放数据库资源。
|
||||
"""
|
||||
try:
|
||||
self.dify_pgsql.close_connection()
|
||||
except Exception as e:
|
||||
# 析构函数中的异常不应该传播,所以这里只是简单记录
|
||||
print(f"关闭数据库连接时出错: {e}")
|
||||
|
||||
def get_message_debug_info_by_id(self, message_id:str)->dict | None:
|
||||
"""
|
||||
根据消息 ID 从 'messages' 表中获取消息信息。
|
||||
"""
|
||||
dify_pgsql = PgSql()
|
||||
try:
|
||||
messages_info = dify_pgsql.get_messages_info_by_id(message_id)
|
||||
messages_info = self.dify_pgsql.get_messages_info_by_id(message_id)
|
||||
if not messages_info:
|
||||
return None
|
||||
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||
workflow_node_executions_info = self.dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||
if not workflow_node_executions_info:
|
||||
return None
|
||||
return {
|
||||
@@ -186,12 +207,8 @@ class DifyTool:
|
||||
}
|
||||
except Exception as e:
|
||||
raise Exception(f"Error in get_message_debug_info_by_id: {e}")
|
||||
finally:
|
||||
dify_pgsql.close_connection()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_message_debug_info_by_query(appid:str, query:str)->dict:
|
||||
def get_message_debug_info_by_query(self, appid:str, query:str)->dict:
|
||||
"""
|
||||
获取指定应用和查询相关的调试信息。
|
||||
|
||||
@@ -207,15 +224,14 @@ class DifyTool:
|
||||
"workflow_node_executions_info"键的字典,分别对应
|
||||
查询到的应用数据、消息数据和节点执行数据。
|
||||
"""
|
||||
dify_pgsql = PgSql()
|
||||
try:
|
||||
appinfo = dify_pgsql.get_appinfo(appid)
|
||||
appinfo = self.dify_pgsql.get_appinfo(appid)
|
||||
if not appinfo:
|
||||
return None
|
||||
messages_info = dify_pgsql.get_messages_info(appid, query)
|
||||
return None
|
||||
messages_info = self.dify_pgsql.get_messages_info(appid, query)
|
||||
if not messages_info:
|
||||
return None
|
||||
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||
workflow_node_executions_info = self.dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||
if not workflow_node_executions_info:
|
||||
return None
|
||||
return {
|
||||
@@ -225,8 +241,6 @@ class DifyTool:
|
||||
}
|
||||
except Exception as e:
|
||||
raise Exception(f"Error in get_message_debug_info_by_query: {e}")
|
||||
finally:
|
||||
dify_pgsql.close_connection()
|
||||
|
||||
class BaseWorkflowChat:
|
||||
"""
|
||||
@@ -242,6 +256,14 @@ class BaseWorkflowChat:
|
||||
"""
|
||||
self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
|
||||
self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
|
||||
self.dify_tool = DifyTool()
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
析构函数,在对象被销毁时自动关闭数据库连接。
|
||||
确保在对象生命周期结束时释放数据库资源。
|
||||
"""
|
||||
self.dify_tool.close_connection()
|
||||
|
||||
def create_chat_message(self, query: str):
|
||||
"""
|
||||
@@ -369,6 +391,9 @@ class NewWorkflowChat(BaseWorkflowChat):
|
||||
"""
|
||||
新工作流对话类,用于调用新工作流发送对话并解析获取相关数据
|
||||
"""
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
super().__init__(api_key, base_url)
|
||||
|
||||
def process_question(self, query: str) -> dict:
|
||||
"""
|
||||
处理问题,获取新工作流的回答和相关信息
|
||||
@@ -425,7 +450,7 @@ class NewWorkflowChat(BaseWorkflowChat):
|
||||
reranker_sorce=[]
|
||||
try:
|
||||
# 先取出重排得分
|
||||
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
|
||||
message_info = self.dify_tool.get_message_debug_info_by_id(message_id=message_id)
|
||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||
if workflow_node["title"] == "软件知识检索聚合":
|
||||
retrieve_outputs = json.loads(workflow_node["inputs"])["result"]
|
||||
@@ -470,6 +495,10 @@ class OldWorkFlowChat(BaseWorkflowChat):
|
||||
"""
|
||||
旧工作流对话类,用于调用旧工作流发送对话并解析获取相关数据
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
super().__init__(api_key, base_url)
|
||||
|
||||
def process_question(self, query: str) -> dict:
|
||||
"""
|
||||
处理问题,获取旧工作流的回答和相关信息
|
||||
@@ -520,7 +549,7 @@ class OldWorkFlowChat(BaseWorkflowChat):
|
||||
rewrite_query = ""
|
||||
|
||||
try:
|
||||
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
|
||||
message_info = self.dify_tool.get_message_debug_info_by_id(message_id=message_id)
|
||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||
if workflow_node["title"] == "知识检索结果后处理":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
@@ -538,8 +567,4 @@ class OldWorkFlowChat(BaseWorkflowChat):
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
result = DifyTool.get_message_debug_info_by_query("ccf92b97-2789-4a3f-90e0-135a869a37c5", "电力建设计价通软件,导入结算后没有暂列金怎么办?要手动添加么?")
|
||||
print(result)
|
||||
except Exception as e:
|
||||
print(f"执行出错: {e}")
|
||||
pass
|
||||
|
||||
@@ -69,6 +69,14 @@ class DifyComparisonTester:
|
||||
else:
|
||||
self.wiki_excel = None
|
||||
|
||||
self.dify_tool = DifyTool()
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
析构函数,在对象被销毁时自动关闭数据库连接。
|
||||
确保在对象生命周期结束时释放数据库资源。
|
||||
"""
|
||||
self.dify_tool.close_connection()
|
||||
|
||||
def get_llm(self):
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
@@ -381,7 +389,7 @@ content: "{content}"
|
||||
"""
|
||||
try:
|
||||
# 使用DifyTool直接获取消息信息
|
||||
new_message_info = DifyTool.get_message_debug_info_by_id(message_id=new_message_id)
|
||||
new_message_info = self.dify_tool.get_message_debug_info_by_id(message_id=new_message_id)
|
||||
|
||||
# 初始化变量
|
||||
retrieve_title = []
|
||||
@@ -428,7 +436,7 @@ content: "{content}"
|
||||
"""
|
||||
try:
|
||||
# 使用DifyTool直接获取消息信息
|
||||
old_message_info = DifyTool.get_message_debug_info_by_id(message_id=old_message_id)
|
||||
old_message_info = self.dify_tool.get_message_debug_info_by_id(message_id=old_message_id)
|
||||
|
||||
# 初始化变量
|
||||
retrieve_title = []
|
||||
|
||||
@@ -149,7 +149,7 @@ class SlotBase(BaseModel):
|
||||
class SoftwareFunctionSlots(SlotBase):
|
||||
software_name: str = Field(default="", description="软件名称")
|
||||
function_name: str = Field(default="", description="具体功能名称")
|
||||
operation: str = Field(default="", description="用户操作意图(如何使用功能、功能入口、功能使用场景)")
|
||||
operation: str = Field(default="", description="用户操作意图(如何使用功能、功能入口、功能使用场景、是否支持该功能)")
|
||||
project_type: Optional[str] = Field(default="单工程", description="工程类型(单工程、多工程、批次工程), 未明确提及则默认下是(单工程)")
|
||||
software_version: Optional[str] = Field(default="", description="软件版本")
|
||||
operation_steps: Optional[str] = Field(default="", description="操作步骤描述")
|
||||
|
||||
@@ -271,7 +271,7 @@ if __name__ == "__main__":
|
||||
instance = APIKeyManager.get_instance()
|
||||
stats = instance.get_usage_stats()
|
||||
all_balance=0.0
|
||||
buy_balance=16 * 10 * 14 # 购买16次,一次10条api_key,每个api_key有14元
|
||||
buy_balance=17 * 10 * 14 # 购买16次,一次10条api_key,每个api_key有14元
|
||||
invalid_api_keys = []
|
||||
for key, data in stats.items():
|
||||
usage_stats = APIKeyManager.get_key_usage_stats(key)
|
||||
|
||||
Reference in New Issue
Block a user