This commit is contained in:
2025-07-07 12:47:34 +08:00
parent b352571e17
commit fcb09c04f2
7 changed files with 464 additions and 13 deletions
-162
View File
@@ -1,162 +0,0 @@
import os
import sys
import logging
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from datetime import datetime
# 获取当前时间,格式化为字符串
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"test_code1{now_str}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
logging.StreamHandler()
],
)
logger = logging.getLogger("test_code1")
import logging
def setup_logger(logger_name):
"""
设置指定名称的logger,将其级别设置为WARNING并禁用传播
:param logger_name: logger的名称
"""
logger = logging.getLogger(logger_name)
logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别
logger.propagate = False # 可选:禁用传播(防止被根logger处理)
return logger
logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"]
for name in logger_names:
setup_logger(name)
import json
import os
from src.dialog_manager import DialogManager
from src.multi_llm_client import MultiAPIKeyChatOpenAI
from src.code_executor import CodeExecutor
from src.neo4j_raw_retriever import Neo4jRawRetriever
from src.prompt_manager import PromptManager
import yaml
from src.config import Config
from src.document_loader import load_file
from src.embedding_client import EmbeddingClient
from src.project import ProjectBuilder, ProjectToolkit
from src.project_implementation import ProjectToolkitNeo4j
success_count = 0
fail_count = 0
questions = []
error_list = []
success_results = []
fail_results = []
def main():
global fail_count, success_count, questions, error_list
config = Config()
business_structure = load_file(config.business_object_structure_path)
bowei_api_docs = load_file(config.bowei_api_docs_path)
llm_client = MultiAPIKeyChatOpenAI(config.openai)
llm_client_coder = MultiAPIKeyChatOpenAI(config.openai_coder)
prompt_manager = PromptManager()
neo4j_conf = config.neo4j_conf
embedding_conf = config.embedding
embedding_client = EmbeddingClient(embedding_conf)
# 创建Neo4j检索器
knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)
code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder, config.max_retries)
dialog_manager = DialogManager(
llm_client,
business_structure,
bowei_api_docs,
code_executor,
knowledge_retriever,
prompt_manager,
)
# 加载 zhibiao.jsonl
zhibiao_data = []
with open('./tests/zhibiao.jsonl', 'r', encoding='utf-8') as f:
for line in f:
zhibiao_data.append(json.loads(line))
# 提取指标映射关系并批量执行
for item in zhibiao_data:
query = item['query']
name = item['name']
selected_knowledge = item['result']
logger.info(f"指标名称 {name} 问题: {query} 开始生成代码")
result = dialog_manager.generated_code(query, selected_knowledge)
if isinstance(result, dict) and result.get('status', False):
code = result['data']
success_count += 1
success_results.append({
"name": name,
"query": query,
"code": code,
})
else:
questions.append(query)
fail_count += 1
error_msg = result.get('message', '调用 generated_code 返回空结果') if isinstance(result, dict) else f"问题 {query} {selected_knowledge} 调用 execute_generated_code 返回空结果"
error_list.append(error_msg)
fail_results.append({
"name": name,
"query": query,
"messages": error_msg,
"selected_knowledge": selected_knowledge,
})
#print(result)
if __name__ == "__main__":
main()
total = success_count + fail_count
if total > 0:
success_rate = success_count / total
fail_rate = fail_count / total
else:
success_rate = fail_rate = 0
print(f"问题总数: {total}")
print(f"成功比例: {success_rate * 100:.2f}%")
print(f"失败比例: {fail_rate * 100:.2f}%")
print("错误列表:")
for error in error_list:
print(error)
# 保存成功结果到 jsonl 文件
success_filename = f'./tests/code_{now_str}.jsonl'
with open(success_filename, 'w', encoding='utf-8') as f:
for item in success_results:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# 保存失败结果到 jsonl 文件
fail_filename = f'./tests/fail_{now_str}.jsonl'
with open(fail_filename, 'w', encoding='utf-8') as f:
for item in fail_results:
f.write(json.dumps(item, ensure_ascii=False) + '\n')