"""Batch-evaluate the dialog pipeline against the indicators in ``./tests/zhibiao.json``.

Every indicator whose data source is not '报表指标' or '指标库' has its
``['指标描述']['指标映射']`` text sent through ``DialogManager``:

1. ``understand_user_question``
2. ``execute_generated_code`` on the first rewritten candidate

Each outcome is tallied in the module-level counters. A success/failure
summary and the collected error messages are printed when run as a script.
"""
import os
import sys

# Make the parent of this script's directory importable so that the ``src``
# package resolves; this must run before the ``src.*`` imports below.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import logging


def setup_logger(logger_name):
    """Quiet the named logger: set its level to WARNING and disable propagation.

    :param logger_name: name passed to ``logging.getLogger``.
    :return: the configured ``logging.Logger``.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.WARNING)
    # Optional: stop records from also being handled by the root logger's handlers.
    logger.propagate = False
    return logger


# Silence noisy third-party loggers (HTTP clients, OpenAI, LangSmith, Neo4j).
logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"]
for name in logger_names:
    setup_logger(name)

import json

from src.dialog_manager import DialogManager
from src.multi_llm_client import MultiAPIKeyChatOpenAI
from src.code_executor import CodeExecutor
from src.neo4j_raw_retriever import Neo4jRawRetriever
from src.prompt_manager import PromptManager
import yaml
from src.config import Config
from src.document_loader import load_file
from src.embedding_client import EmbeddingClient
from src.project import ProjectBuilder, ProjectToolkit
from src.project_implementation import ProjectToolkitNeo4j

# Path of the indicator list; relative to the current working directory.
INDICATOR_FILE = './tests/zhibiao.json'
# Indicators from these data sources are not evaluated.
SKIPPED_DATASOURCES = ('报表指标', '指标库')

_log = logging.getLogger(__name__)

# Run-wide tallies, filled in by main() and reported by print_summary().
success_count = 0
fail_count = 0
questions = []  # rewritten questions whose generated code produced no result
error_list = []  # human-readable description of every failure


def _build_dialog_manager(config):
    """Construct the LLM clients, Neo4j retriever and code executor, and wire them into a DialogManager."""
    business_structure = load_file(config.business_object_structure_path)
    bowei_api_docs = load_file(config.bowei_api_docs_path)
    llm_client = MultiAPIKeyChatOpenAI(config.openai)
    llm_client_coder = MultiAPIKeyChatOpenAI(config.openai_coder)
    prompt_manager = PromptManager()
    neo4j_conf = config.neo4j_conf
    embedding_conf = config.embedding
    # NOTE(review): the embedding client is never passed to anything below;
    # it is kept in case its constructor has required side effects — confirm.
    embedding_client = EmbeddingClient(embedding_conf)

    # Create the Neo4j retriever and register the Neo4j project toolkit on its driver.
    knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
    ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)

    code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder, config.max_retries)
    return DialogManager(
        llm_client,
        business_structure,
        bowei_api_docs,
        code_executor,
        knowledge_retriever,
        prompt_manager,
    )


def _evaluate_query(dialog_manager, query):
    """Run *query* through both pipeline steps and record the outcome in the module tallies.

    Exceptions from either step are logged and counted as failures so that one
    bad indicator does not abort the whole batch.
    """
    global fail_count, success_count

    try:
        rewritten_results = dialog_manager.understand_user_question(query)
    except Exception as exc:
        _log.exception("understand_user_question failed for %r", query)
        fail_count += 1
        error_list.append(f"问题 {query} 调用 understand_user_question 抛出异常: {exc!r}")
        return
    if not rewritten_results:
        print('问题: {} 没有找到符合要求的数据'.format(query))
        fail_count += 1
        error_list.append(f"问题 {query} 调用 understand_user_question 返回空结果")
        return

    # Only the first (rewritten question, knowledge) candidate is executed.
    selected_rewritten, selected_knowledge = rewritten_results[0]
    try:
        result = dialog_manager.execute_generated_code(selected_rewritten, selected_knowledge)
    except Exception as exc:
        _log.exception("execute_generated_code failed for %r", selected_rewritten)
        questions.append(selected_rewritten)
        fail_count += 1
        error_list.append(f"问题 {query} {selected_rewritten} 调用 execute_generated_code 抛出异常: {exc!r}")
        return
    if result is None or result == []:
        questions.append(selected_rewritten)
        fail_count += 1
        error_list.append(f"问题 {query} {selected_rewritten} 调用 execute_generated_code 返回空结果")
    else:
        success_count += 1


def main():
    """Build the pipeline, then evaluate every eligible indicator from INDICATOR_FILE."""
    dialog_manager = _build_dialog_manager(Config())

    with open(INDICATOR_FILE, 'r', encoding='utf-8') as f:
        zhibiao_data = json.load(f)

    # Extract each indicator's mapping text and evaluate them in sequence.
    for item in zhibiao_data:
        if item['数据来源'] in SKIPPED_DATASOURCES:
            continue
        _evaluate_query(dialog_manager, item['指标描述']['指标映射'])


def print_summary():
    """Print the total count, the success and failure percentages, and every recorded error."""
    total = success_count + fail_count
    if total > 0:
        success_rate = success_count / total
        fail_rate = fail_count / total
    else:
        success_rate = fail_rate = 0
    print(f"问题总数: {total}")
    print(f"成功比例: {success_rate * 100:.2f}%")
    print(f"失败比例: {fail_rate * 100:.2f}%")
    print("错误列表:")
    for error in error_list:
        print(error)


if __name__ == "__main__":
    main()
    print_summary()