From 7e7109e01b7765fa668fb8b46d642395e97fc62d Mon Sep 17 00:00:00 2001 From: chentianrui Date: Fri, 27 Jun 2025 22:36:39 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_node.py | 111 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/test_node.py diff --git a/tests/test_node.py b/tests/test_node.py new file mode 100644 index 0000000..a81e5db --- /dev/null +++ b/tests/test_node.py @@ -0,0 +1,111 @@ +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import logging + +logging.basicConfig( + level=logging.DEBUG, # 生产环境可改为 INFO 或 WARNING + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.FileHandler("test.log", encoding="utf-8"), logging.StreamHandler()], +) + +logger = logging.getLogger("test") + +import logging + + +def setup_logger(logger_name): + """ + 设置指定名称的logger,将其级别设置为WARNING并禁用传播 + :param logger_name: logger的名称 + """ + logger = logging.getLogger(logger_name) + logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别 + logger.propagate = False # 可选:禁用传播(防止被根logger处理) + return logger + + +logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"] +for name in logger_names: + setup_logger(name) + +import asyncio +import json +from src.dialog_manager import DialogManager +from src.llm_client import LLMClient +from src.code_executor import CodeExecutor +from src.neo4j_raw_retriever import Neo4jRawRetriever +from src.prompt_manager import PromptManager +import yaml +from src.config import Config +from src.document_loader import load_file +from src.embedding_client import EmbeddingClient + +from project import ProjectBuilder, ProjectToolkit +from project_implementation import ProjectToolkitNeo4j + + +def main(): + config = Config() + + business_structure = load_file(config.business_object_structure_path) + bowei_api_docs = load_file(config.bowei_api_docs_path) + + llm_client = LLMClient(config.openai) + + llm_client_coder = LLMClient(config.openai_coder) + + prompt_manager = PromptManager() + + neo4j_conf = config.neo4j_conf + embedding_conf = config.embedding + + embedding_client = EmbeddingClient(embedding_conf) + + # 创建Neo4j检索器 + knowledge_retriever = Neo4jRawRetriever(neo4j_conf) + + ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver) + + code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder) + + dialog_manager = DialogManager( + llm_client, + business_structure, + bowei_api_docs, + code_executor, + knowledge_retriever, + prompt_manager, + ) + + query = "查找一下项目划分节点【架空输电线路本体工程/基础工程/基础砌筑】下的费用预览中的【主材费_ZCF】" + knowledge = """ +: 4:3d187051-ffce-4fca-8060-a4d9949fdd38:13295 +: 13295 +cost: 996944.127814 +id: 主材费_ZCF +name: 主材费_ZCF +unique_id: {8C4CC636-3741-4F07-865A-28D38AB7F31D}_主材费_ZCF +: 4:3d187051-ffce-4fca-8060-a4d9949fdd38:14682 +: 14682 +GUID: {8C4CC636-3741-4F07-865A-28D38AB7F31D} +name: 基础砌筑 +path: 安装/架空输电线路本体工程/基础工程/基础砌筑 +type: 项目划分 +专业类型: 线路 +取费表: 线路取费表 +取费表id: 3_1 +序号: 1.3 +费率: 0 +""" + + result = dialog_manager.execute_generated_code(query, knowledge) + print(result) + + +if __name__ == "__main__": + main()