更新环境变量配置,调整模型名称获取方式,新增Dify API相关配置,删除无用的脚本文件,优化意图识别逻辑,添加LLM提取词条逻辑
This commit is contained in:
@@ -3,19 +3,15 @@ import json
|
||||
|
||||
from regex import search
|
||||
|
||||
import ijson
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.dify.dify_tool import DifyTool
|
||||
|
||||
df = pd.read_excel("data/excel/已分析数据汇总(第一轮).xlsx")
|
||||
df=df[df["评价"]=="dislike"]
|
||||
dify_tool = DifyTool()
|
||||
|
||||
df = pd.read_excel("data/excel/0714提问数据汇总(已分析)_软件.xlsx")
|
||||
|
||||
msg_id_list = df["msg_id"].tolist()
|
||||
msg_debug_list = {}
|
||||
# 流式解析 JSON 数组
|
||||
with open("data/excel/msg_debug_list.json", "r", encoding="utf-8") as f:
|
||||
# 使用ijson.items直接获取顶层键值对
|
||||
for msg_id, data in ijson.kvitems(f, ''):
|
||||
if msg_id in msg_id_list:
|
||||
msg_debug_list[msg_id] = data
|
||||
|
||||
def get_rewrite_query(intent_node_execution_info)->str:
|
||||
outputs_result =json.loads(intent_node_execution_info['outputs'])
|
||||
@@ -28,7 +24,7 @@ def judge_error_node_and_reason(intent_node_execution_info, knowledge_filter_nod
|
||||
|
||||
outputs_result =json.loads(intent_node_execution_info['outputs'])
|
||||
result["问题改写结果"] = outputs_result['optimize_query']
|
||||
if outputs_result['is_complete'] == False:
|
||||
if outputs_result['is_complete'] == False and outputs_result["has_slot_filling"] == True:
|
||||
result["错误环节"] = "槽点填充"
|
||||
result["错误原因"] = f"槽点缺失"
|
||||
result["具体描述"] = f"缺失内容:{outputs_result['missing_slots']}"
|
||||
@@ -80,6 +76,8 @@ for index, row in df.iterrows():
|
||||
answer = row["回答"]
|
||||
query = row["提问"]
|
||||
rating = row["评价"]
|
||||
if rating != "dislike":
|
||||
continue
|
||||
class_type = row["问题分类"]
|
||||
dislike_reason = row["点踩原因"]
|
||||
if dislike_reason is None or pd.isna(dislike_reason):
|
||||
@@ -87,7 +85,8 @@ for index, row in df.iterrows():
|
||||
|
||||
answer_wiki_name = row["关联词条"]
|
||||
search_wiki = row["检索到的词条"]
|
||||
node_executions_info = msg_debug_list[msg_id]
|
||||
msg_debug_info = dify_tool.get_message_debug_info_by_id(msg_id)
|
||||
node_executions_info = msg_debug_info["workflow_node_executions_info"]
|
||||
intent_node_execution_info = [node_execution_info for node_execution_info in node_executions_info
|
||||
if node_execution_info["title"] == "意图识别结果解析"]
|
||||
|
||||
@@ -109,7 +108,7 @@ for index, row in df.iterrows():
|
||||
print(f"msg_id: {msg_id} 处理失败: {e}")
|
||||
continue
|
||||
|
||||
df.to_excel("data/excel/已分析数据汇总(第一轮)_分析.xlsx", index=False)
|
||||
df.to_excel("data/excel/0714提问数据汇总(已分析)_软件_分析.xlsx", index=False)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -84,15 +84,14 @@ async def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/query_type", summary="异步检索API")
|
||||
async def query_type(query: str, query_type: str, workflow_run_id:str):
|
||||
async def query_type(query_type: str, workflow_run_id:str):
|
||||
try:
|
||||
# 记录请求
|
||||
logger.info(f"接收到请求: {query}, 类型: {query_type}, workflow_run_id: {workflow_run_id}")
|
||||
logger.info(f"接收到请求: 类型: {query_type}, workflow_run_id: {workflow_run_id}")
|
||||
|
||||
# 保存 提问、问题类型、当前时间戳到json
|
||||
timestamp = datetime.datetime.now().isoformat()
|
||||
query_data = {
|
||||
"query": query,
|
||||
"query_type": query_type,
|
||||
"timestamp": timestamp,
|
||||
"workflow_run_id": workflow_run_id
|
||||
@@ -127,7 +126,7 @@ async def query_type(query: str, query_type: str, workflow_run_id:str):
|
||||
logger.error(f"保存查询数据时出错: {str(e)}", exc_info=True)
|
||||
|
||||
# 返回响应
|
||||
content = f"<strong>当前提问</strong>: {query}<br><strong>问题类型</strong>: {query_type}<br><strong>操作是否成功</strong>: {'成功' if success else '失败'}"
|
||||
content = f"<strong>问题类型</strong>: {query_type}<br><strong>操作是否成功</strong>: {'成功' if success else '失败'}"
|
||||
return HTMLResponse(content=content)
|
||||
except Exception as e:
|
||||
logger.error(f"处理请求时出错: {str(e)}", exc_info=True)
|
||||
|
||||
@@ -84,7 +84,7 @@ class DifyComparisonTester:
|
||||
def get_llm(self, **kwargs):
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model = os.getenv("LLM_MODEL_NAME")
|
||||
model = os.getenv("MODEL_NAME")
|
||||
return OpenAiLLM(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
|
||||
def find_wiki_link(self, row) -> str | None:
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from rag2_0.dify.dify_client import DifyApi
|
||||
|
||||
soft_name_map = {
|
||||
"配网造价软件知识(new)": "配网计价通D3软件",
|
||||
"西藏造价软件知识(new)": "西藏计价通Z1软件",
|
||||
"储能C1计价通软件知识(new)": "储能计价通C1软件",
|
||||
"技改检修工程计价通T1软件知识(new)": "技改检修工程计价通T1软件",
|
||||
"技改检修清单计价通T1软件知识(new)": "技改检修清单计价通T1软件",
|
||||
"电力建设计价通(2018)软件知识(new)": "电力建设计价通软件",
|
||||
"下载安装注册(new)": "下载安装注册",
|
||||
}
|
||||
|
||||
soft_wiki_file_name = {
|
||||
"配网计价通D3软件": ["配网计价通D3软件.txt", []],
|
||||
"西藏计价通Z1软件": ["西藏计价通Z1软件.txt", []],
|
||||
"储能计价通C1软件": ["储能计价通C1软件.txt", []],
|
||||
"技改检修工程计价通T1软件": ["技改检修工程计价通T1软件.txt", []],
|
||||
"技改检修清单计价通T1软件": ["技改检修清单计价通T1软件.txt", []],
|
||||
"电力建设计价通软件": ["电力建设计价通软件.txt", []],
|
||||
"下载安装注册": ["下载安装注册.txt", []],
|
||||
}
|
||||
|
||||
def get_soft_wiki_titles(dify_api, soft_name_map, soft_wiki_file_name):
|
||||
"""获取每个软件的wiki标题列表"""
|
||||
dataset_list = dify_api.get_all_dataset_list()
|
||||
soft_name_map_keys = list(soft_name_map.keys())
|
||||
for dataset in dataset_list:
|
||||
if dataset["name"] not in soft_name_map_keys:
|
||||
continue
|
||||
dataset_name = dataset["name"]
|
||||
dataset_id = dataset["id"]
|
||||
documents = dify_api.get_documents(dataset_id=dataset_id)
|
||||
for document_id, doc_info in documents.items():
|
||||
document_name = doc_info["name"]
|
||||
wiki_name = document_name.split("/")[-1]
|
||||
wiki_title = re.sub(r'^(.*?)|^\(.*?\)', '', wiki_name)
|
||||
if wiki_title not in soft_wiki_file_name[soft_name_map[dataset_name]][1]:
|
||||
soft_wiki_file_name[soft_name_map[dataset_name]][1].append(wiki_title)
|
||||
return soft_wiki_file_name
|
||||
|
||||
def save_wiki_titles(soft_wiki_file_name, output_dir="data/wiki_data"):
|
||||
"""将wiki标题列表保存到对应txt文件"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
for soft_name, (txt_file_name, wiki_titles) in soft_wiki_file_name.items():
|
||||
output_path = os.path.join(output_dir, txt_file_name)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for title in wiki_titles:
|
||||
f.write(title + "\n")
|
||||
print(f"已保存 {soft_name} 的wiki标题列表到 {output_path},共 {len(wiki_titles)} 条")
|
||||
|
||||
def main():
|
||||
dify_api = DifyApi()
|
||||
wiki_titles = get_soft_wiki_titles(dify_api, soft_name_map, soft_wiki_file_name)
|
||||
save_wiki_titles(wiki_titles)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,151 +0,0 @@
|
||||
from rag2_0.dify.dify_tool import NewWorkflowChat
|
||||
import pandas as pd
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
class ChatDifyByWorkorder:
|
||||
|
||||
def __init__(self, api_key=None, base_url="https://api.dify.ai/v1") -> None:
|
||||
"""
|
||||
初始化ChatDifyByWorkorder类
|
||||
|
||||
Args:
|
||||
api_key: Dify API密钥,默认为None
|
||||
base_url: Dify API的基础URL,默认为"https://api.dify.ai/v1"
|
||||
"""
|
||||
baseurl = "http://172.20.0.145/v1"
|
||||
new_workflow_api_key = "app-qxsSybCs7ABiKlC1JabTYVn6"
|
||||
self.new_chat = NewWorkflowChat(api_key=new_workflow_api_key, base_url=baseurl)
|
||||
self.new_chat_answer = NewWorkflowChat(api_key=new_workflow_api_key, base_url=baseurl)
|
||||
|
||||
|
||||
def get_soft_name(self, row) -> str:
|
||||
if "博微配网计价通D3" in row["产品线"]:
|
||||
return "博微配网计价通D3"
|
||||
elif "博微电力建设计价通软件" in row["产品线"]:
|
||||
return "电力建设计价通软件"
|
||||
elif "新能源系列" in row["产品线"] and "博微新型储能电站建设计价通C1软件" in row["产品名称"]:
|
||||
return "储能C1软件"
|
||||
elif "博微西藏计价通Z1" in row["产品线"]:
|
||||
return "西藏计价通Z1"
|
||||
elif "博微技改检修计价通T1软件" in row["产品线"] and "技改检修计价通T1软件-概预算" in row["产品名称"]:
|
||||
return "技改检修工程计价通T1"
|
||||
elif "博微技改检修计价通T1软件" in row["产品线"] and "技改检修计价通T1软件-清单" in row["产品名称"]:
|
||||
return "检修清单计价通T1"
|
||||
return ""
|
||||
|
||||
def process_query(self, q:str) -> dict:
|
||||
"""
|
||||
发送问题并获取回答及相关工作流信息
|
||||
|
||||
Args:
|
||||
q: 用户问题
|
||||
|
||||
Returns:
|
||||
dict: 包含问题、回答和工作流信息的字典
|
||||
"""
|
||||
retry_count = 0
|
||||
max_retries = 2
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# 发送问题获取回答和消息ID
|
||||
result = self.new_chat.process_question(q)
|
||||
return result
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count <= max_retries:
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
|
||||
def process_answer(self, q:str) -> dict:
|
||||
"""
|
||||
发送问题并获取回答及相关工作流信息
|
||||
|
||||
Args:
|
||||
q: 用户问题
|
||||
|
||||
Returns:
|
||||
dict: 包含问题、回答和工作流信息的字典
|
||||
"""
|
||||
retry_count = 0
|
||||
max_retries = 2
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# 发送问题获取回答和消息ID
|
||||
result = self.new_chat_answer.process_question(q)
|
||||
return result
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count <= max_retries:
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
def process_row(self, row):
|
||||
"""处理单行数据"""
|
||||
soft_name = self.get_soft_name(row=row)
|
||||
if soft_name == "":
|
||||
return None
|
||||
|
||||
# 使用线程池并发执行查询
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
# 提交两个任务并获取Future对象
|
||||
query_future = executor.submit(self.process_query, q=f"{soft_name},{row['客户问题']}")
|
||||
answer_future = executor.submit(self.process_answer, q=f"{soft_name},{row['解决方案']}")
|
||||
|
||||
# 获取结果
|
||||
query_result = query_future.result()
|
||||
answer_result = answer_future.result()
|
||||
except Exception as e:
|
||||
print(f"处理工单 {row.get('工单编号', '未知')} 时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
worker_id = str(row["工单编号"])
|
||||
if query_result is None or answer_result is None:
|
||||
print("处理对话出现错误")
|
||||
return None
|
||||
|
||||
worker_order_info = {
|
||||
"工单编号": worker_id,
|
||||
"用户问题": row['客户问题'],
|
||||
"解决方案": row['解决方案'],
|
||||
"AI回答": query_result["新流程答案"],
|
||||
"用户问题检索到的词条": query_result["新检索词条"],
|
||||
"解决方案检索到的词条": answer_result["新检索词条"],
|
||||
}
|
||||
return worker_order_info
|
||||
|
||||
def run(self, excel_path:str):
|
||||
df_data = pd.read_excel(excel_path)
|
||||
list_worker_order_info = []
|
||||
|
||||
# 创建进度条
|
||||
with tqdm(total=len(df_data), desc="处理工单") as pbar:
|
||||
# 创建线程池,最大并发数可以根据需要调整
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
# 提交所有任务
|
||||
future_to_row = {executor.submit(self.process_row, row): idx for idx, row in df_data.iterrows()}
|
||||
|
||||
# 处理完成的任务
|
||||
for future in concurrent.futures.as_completed(future_to_row):
|
||||
result = future.result()
|
||||
if result is not None:
|
||||
list_worker_order_info.append(result)
|
||||
pbar.update(1)
|
||||
|
||||
return list_worker_order_info
|
||||
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
worker_chat = ChatDifyByWorkorder()
|
||||
result = worker_chat.run(excel_path="data/excel/工单记录_均衡提取2000条.xlsx")
|
||||
# 可以选择保存结果到Excel
|
||||
if result:
|
||||
pd.DataFrame(result).to_excel("data/excel/工单处理结果.xlsx", index=False)
|
||||
@@ -1,4 +1,5 @@
|
||||
|
||||
__all__ = ["ChatClient", "CompletionClient", "DifyClient"]
|
||||
__all__ = ["ChatClient", "CompletionClient", "DifyClient", "DifyApi"]
|
||||
|
||||
from .client import ChatClient, CompletionClient, DifyClient
|
||||
from .dify_api import DifyApi
|
||||
|
||||
@@ -14,12 +14,12 @@ class DifyApi:
|
||||
用于与Dify API进行交互的类。
|
||||
"""
|
||||
|
||||
def __init__(self, dify_url: str="http://10.1.16.39/v1",
|
||||
dify_dataset_api_key: str="dataset-skLjmPVonjHo119OWNf3kAmY",
|
||||
dify_app_api_key: str="app-wUdkWJx5zeOvmvBUZizMoSw3"):
|
||||
self.dify_url = dify_url
|
||||
self.dify_dataset_api_key = dify_dataset_api_key
|
||||
self.dify_app_api_key = dify_app_api_key
|
||||
def __init__(self, dify_url: str=None,
|
||||
dify_dataset_api_key: str=None,
|
||||
dify_app_api_key: str=None):
|
||||
self.dify_url = dify_url if dify_url else os.environ.get('DIFY_BSAE_URL')
|
||||
self.dify_dataset_api_key = dify_dataset_api_key if dify_dataset_api_key else os.environ.get('DIFY_DATASET_KEY')
|
||||
self.dify_app_api_key = dify_app_api_key if dify_app_api_key else os.environ.get('DIFY_APP_KEY')
|
||||
|
||||
def get_document_indexing_status(self, datasets_id: str, batch: str) -> bool:
|
||||
"""
|
||||
|
||||
@@ -449,7 +449,7 @@ content: "{content}"
|
||||
"""
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model = os.getenv("LLM_MODEL_NAME")
|
||||
model = os.getenv("MODEL_NAME")
|
||||
llm = OpenAiLLM(api_key=api_key, base_url=base_url, model=model)
|
||||
response = llm.invoke(user_prompt=prompt, need_retry=True)
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
|
||||
# 定义请求模型
|
||||
class IntentRecognizeRequest(BaseModel):
|
||||
query: str
|
||||
conversation_context: str = ""
|
||||
conversation_context: Dict = None
|
||||
chat_history: Optional[List] = None
|
||||
previous_slots: str | Dict = None
|
||||
|
||||
@@ -89,13 +89,15 @@ _instance = None
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global _instance
|
||||
# 初始化AsyncIntentRecognizer实例
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
_instance = await AsyncIntentRecognizer.create(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
_instance = await AsyncIntentRecognizer.create()
|
||||
logger.info("AsyncIntentRecognizer初始化完成")
|
||||
|
||||
@app.post("/intent_recognize1")
|
||||
async def intent_recognize(request: Request):
|
||||
data = await request.json()
|
||||
print(data)
|
||||
return {"message": "success"}
|
||||
|
||||
@app.post("/intent_recognize", response_model=IntentRecognizeResponse, summary="意图识别", description="识别用户查询的意图并进行问题改写")
|
||||
async def intent_recognize(request: IntentRecognizeRequest):
|
||||
try:
|
||||
@@ -103,14 +105,15 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
||||
raise HTTPException(status_code=400, detail="缺少query参数")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
current_softname = request.conversation_context.get("current_softname", "")
|
||||
result = await _instance.process_query_async(
|
||||
query=request.query,
|
||||
conversation_context=request.conversation_context,
|
||||
chat_history=request.chat_history,
|
||||
previous_slots=request.previous_slots,
|
||||
use_jieba=True,
|
||||
enable_query_expansion=True
|
||||
enable_query_expansion=True,
|
||||
cur_soft_name=current_softname
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
import pandas as pd
|
||||
import random
|
||||
import math
|
||||
|
||||
work_order_excel="data/excel/6万工单记录.xlsx"
|
||||
|
||||
soft_row_data={
|
||||
"博微配网计价通D3":{"基本功能":[], "高级功能":[]},
|
||||
"储能C1软件":{"基本功能":[], "高级功能":[]},
|
||||
"西藏计价通Z1":{"基本功能":[], "高级功能":[]},
|
||||
"技改检修工程计价通T1":{"基本功能":[], "高级功能":[]},
|
||||
"检修清单计价通T1":{"基本功能":[], "高级功能":[]},
|
||||
"电力建设计价通软件":{"基本功能":[], "高级功能":[]},
|
||||
}
|
||||
|
||||
df = pd.read_excel(work_order_excel)
|
||||
|
||||
for idx, row in df.iterrows():
|
||||
if pd.isna(row["产品线"]):
|
||||
continue
|
||||
|
||||
if "博微配网计价通D3" in row["产品线"]:
|
||||
soft_row_data["博微配网计价通D3"][row["问题类型"]].append((idx, row))
|
||||
elif "博微电力建设计价通软件" in row["产品线"]:
|
||||
soft_row_data["电力建设计价通软件"][row["问题类型"]].append((idx, row))
|
||||
elif "新能源系列" in row["产品线"] and "博微新型储能电站建设计价通C1软件" in row["产品名称"]:
|
||||
soft_row_data["储能C1软件"][row["问题类型"]].append((idx, row))
|
||||
elif "博微西藏计价通Z1" in row["产品线"]:
|
||||
soft_row_data["西藏计价通Z1"][row["问题类型"]].append((idx, row))
|
||||
elif "博微技改检修计价通T1软件" in row["产品线"] and "技改检修计价通T1软件-概预算" in row["产品名称"]:
|
||||
soft_row_data["技改检修工程计价通T1"][row["问题类型"]].append((idx, row))
|
||||
elif "博微技改检修计价通T1软件" in row["产品线"] and "技改检修计价通T1软件-清单" in row["产品名称"]:
|
||||
soft_row_data["检修清单计价通T1"][row["问题类型"]].append((idx, row))
|
||||
|
||||
# 计算每个软件和功能类型的数据量
|
||||
total_count = 0
|
||||
counts = {}
|
||||
for software, types in soft_row_data.items():
|
||||
counts[software] = {}
|
||||
for type_name, rows in types.items():
|
||||
counts[software][type_name] = len(rows)
|
||||
total_count += len(rows)
|
||||
|
||||
print(f"原始数据总量: {total_count}条")
|
||||
for software, types in counts.items():
|
||||
print(f"{software}: 基本功能 {types['基本功能']}条, 高级功能 {types['高级功能']}条")
|
||||
|
||||
# 计算均衡提取的数量
|
||||
total_target = 2000
|
||||
categories_count = sum(len(types) for types in soft_row_data.values())
|
||||
per_category_target = math.ceil(total_target / categories_count)
|
||||
|
||||
# 均衡提取数据
|
||||
balanced_data = []
|
||||
extracted_counts = {}
|
||||
extracted_indices = set() # 使用集合存储已提取数据的索引
|
||||
|
||||
for software, types in soft_row_data.items():
|
||||
extracted_counts[software] = {}
|
||||
|
||||
for type_name, rows in types.items():
|
||||
# 如果数据量不足,全部提取;否则随机抽取目标数量
|
||||
if len(rows) <= per_category_target:
|
||||
extracted = rows
|
||||
else:
|
||||
extracted = random.sample(rows, per_category_target)
|
||||
|
||||
extracted_counts[software][type_name] = len(extracted)
|
||||
for idx, row in extracted:
|
||||
extracted_indices.add(idx) # 记录已提取数据的索引
|
||||
balanced_data.append(row)
|
||||
|
||||
# 数据量不足2000时,从剩余数据中补充
|
||||
remaining_target = total_target - len(balanced_data)
|
||||
if remaining_target > 0:
|
||||
# 收集所有未被选中的数据
|
||||
remaining_data = []
|
||||
for software, types in soft_row_data.items():
|
||||
for type_name, rows in types.items():
|
||||
# 添加未被选中的数据
|
||||
for idx, row in rows:
|
||||
if idx not in extracted_indices:
|
||||
remaining_data.append(row)
|
||||
|
||||
# 如果剩余数据足够,随机抽取补充
|
||||
if len(remaining_data) >= remaining_target:
|
||||
additional_data = random.sample(remaining_data, remaining_target)
|
||||
else:
|
||||
additional_data = remaining_data
|
||||
|
||||
balanced_data.extend(additional_data)
|
||||
|
||||
# 输出结果
|
||||
print(f"\n均衡提取后数据总量: {len(balanced_data)}条")
|
||||
for software, types in extracted_counts.items():
|
||||
print(f"{software}: 基本功能 {types['基本功能']}条, 高级功能 {types['高级功能']}条")
|
||||
|
||||
# 将均衡提取的数据转换为DataFrame并保存
|
||||
balanced_df = pd.DataFrame(balanced_data)
|
||||
balanced_df.to_excel("data/excel/均衡提取2000条工单.xlsx", index=False)
|
||||
print(f"\n已将均衡提取的{len(balanced_data)}条数据保存至'data/excel/均衡提取2000条工单.xlsx'")
|
||||
Reference in New Issue
Block a user