优化意图识别示例,新增命令行参数解析功能,支持输入输出文件路径和调试模式,增强代码可读性和灵活性。同时更新Dify工具,调整检索信息获取逻辑,确保重排得分信息的正确传递。
This commit is contained in:
+34
-19
@@ -318,7 +318,7 @@ content: "{content}"
|
||||
except Exception as e:
|
||||
return -1
|
||||
|
||||
def get_retrieve_info(self, query: str, outputs: dict) -> tuple:
|
||||
def get_retrieve_info(self, query: str, outputs: dict, reranker_sorce_info:list) -> tuple:
|
||||
"""
|
||||
获取检索信息并计算分数
|
||||
|
||||
@@ -333,20 +333,21 @@ content: "{content}"
|
||||
min_score = 10
|
||||
total_score = 0
|
||||
valid_scores = 0
|
||||
retrieve_content = []
|
||||
retrieve_title = []
|
||||
|
||||
# 使用线程池并发计算分数
|
||||
with ThreadPoolExecutor() as executor:
|
||||
# 创建任务列表
|
||||
future_to_content = {}
|
||||
for result in outputs["result"]:
|
||||
content = result["content"].strip()
|
||||
for result in outputs:
|
||||
content = result["segment_content"].strip()
|
||||
segment_id = result["segment_id"].strip()
|
||||
future = executor.submit(self.calculate_score, query=query, content=content)
|
||||
future_to_content[future] = content
|
||||
future_to_content[future] = (content, segment_id)
|
||||
|
||||
# 收集结果
|
||||
for future in as_completed(future_to_content):
|
||||
content = future_to_content[future]
|
||||
content, segment_id = future_to_content[future]
|
||||
score = future.result()
|
||||
content_title = content.split("\n")[0]
|
||||
|
||||
@@ -357,10 +358,11 @@ content: "{content}"
|
||||
valid_scores += 1
|
||||
|
||||
if content_title:
|
||||
retrieve_content.append(content_title + f"--相关性得分({score}分)")
|
||||
current_score = next((cur_source_info["score"] for cur_source_info in reranker_sorce_info if cur_source_info["segment_id"] == segment_id), None)
|
||||
retrieve_title.append(content_title + f"--LLM得分({score}分)--重排得分({current_score:.2f}分)")
|
||||
|
||||
avg_score = total_score / valid_scores if valid_scores > 0 else 0
|
||||
return retrieve_content, max_score, min_score, avg_score
|
||||
return retrieve_title, max_score, min_score, avg_score
|
||||
|
||||
|
||||
class NewWorkflowChat(BaseWorkflowChat):
|
||||
@@ -395,7 +397,6 @@ class NewWorkflowChat(BaseWorkflowChat):
|
||||
"新问题分类": workflow_info["问题分类"],
|
||||
"槽点信息": workflow_info["槽点信息"],
|
||||
"新检索词条": workflow_info["检索词条"],
|
||||
"检索内容": workflow_info["检索内容"],
|
||||
"message_id":message_id
|
||||
}
|
||||
|
||||
@@ -421,14 +422,23 @@ class NewWorkflowChat(BaseWorkflowChat):
|
||||
vertical_classification = ""
|
||||
sub_classification = ""
|
||||
slot_info = ""
|
||||
|
||||
reranker_sorce=[]
|
||||
try:
|
||||
# 先取出重排得分
|
||||
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
|
||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||
if workflow_node["title"] == "知识检索结果后处理":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
|
||||
retrieve_content = outputs["result"]
|
||||
if workflow_node["title"] == "软件知识检索聚合":
|
||||
retrieve_outputs = json.loads(workflow_node["inputs"])["result"]
|
||||
reranker_sorce = [{"score":result["metadata"]["score"], "segment_id":result["metadata"]["segment_id"]} for result in retrieve_outputs]
|
||||
|
||||
|
||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||
if workflow_node["title"] == "软件知识检索聚合":
|
||||
retrieve_outputs = json.loads(workflow_node["inputs"])["result"]
|
||||
reranker_sorce = [{"score":result["metadata"]["score"], "segment_id":result["metadata"]["segment_id"]} for result in retrieve_outputs]
|
||||
elif workflow_node["title"] == "提取处理后的知识":
|
||||
outputs = json.loads(workflow_node["outputs"])["knowledge_list"]
|
||||
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs, reranker_sorce_info=reranker_sorce)
|
||||
elif workflow_node["title"] == "问题优化结果解析":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
rewrite_query = outputs["optimize_query"]
|
||||
@@ -439,11 +449,18 @@ class NewWorkflowChat(BaseWorkflowChat):
|
||||
slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
retrieve_content = ""
|
||||
if len(reranker_sorce)==0:
|
||||
retrieve_content="未检索知识库"
|
||||
elif len(reranker_sorce) > 0 and len(retrieve_title)==0:
|
||||
retrieve_content = "知识与提问不相关,被丢弃"
|
||||
else:
|
||||
retrieve_content = "\n".join(retrieve_title)
|
||||
|
||||
return {
|
||||
"问题改写": rewrite_query,
|
||||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
||||
"检索内容": retrieve_content,
|
||||
"检索词条": retrieve_content,
|
||||
"问题分类": f"{vertical_classification} - {sub_classification}",
|
||||
"槽点信息": slot_info,
|
||||
|
||||
@@ -479,7 +496,6 @@ class OldWorkFlowChat(BaseWorkflowChat):
|
||||
"旧流程答案": answer,
|
||||
"旧问题改写": workflow_info["问题改写"],
|
||||
"旧检索词条": workflow_info["检索词条"],
|
||||
"检索内容": workflow_info["检索内容"],
|
||||
"message_id":message_id
|
||||
}
|
||||
|
||||
@@ -519,7 +535,6 @@ class OldWorkFlowChat(BaseWorkflowChat):
|
||||
return {
|
||||
"问题改写": rewrite_query,
|
||||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
||||
"检索内容": retrieve_content,
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user