Files
langchain_projectagent/tests/test.py
T

103 lines
3.3 KiB
Python

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
from src.dialog_manager import DialogManager
from src.llm_client import LLMClient
from src.code_executor import CodeExecutor
from src.neo4j_raw_retriever import Neo4jRawRetriever
from src.prompt_manager import PromptManager
import yaml
from src.config import Config
from src.document_loader import load_file
from src.embedding_client import EmbeddingClient
from project import ProjectBuilder, ProjectToolkit
from project_implementation import ProjectToolkitNeo4j
def _run_indicator_batch(dialog_manager, indicators):
    """Run each indicator query through the dialog manager and tally outcomes.

    For every record the query text is read from
    ``item['指标描述']['指标映射']``, passed to
    ``dialog_manager.understand_user_question_stream`` and, when that yields
    at least one candidate, the first ``(rewritten, knowledge)`` pair is
    handed to ``dialog_manager.execute_generated_code``.

    Args:
        dialog_manager: Object exposing ``understand_user_question_stream``
            and ``execute_generated_code``.
        indicators: Iterable of records loaded from ``zhibiao.json``.

    Returns:
        A ``(success_count, fail_count, error_list)`` tuple, where
        ``error_list`` holds one message per execution that returned
        ``None`` or an empty list.
    """
    # The tallies must live outside the loop so they accumulate across all
    # items and are still defined when there is nothing to process.
    success_count = 0
    fail_count = 0
    error_list = []

    for item in indicators:
        query = item['指标描述']['指标映射']
        rewritten_results = dialog_manager.understand_user_question_stream(query)
        if rewritten_results is None or rewritten_results == []:
            # NOTE(review): such items are reported and skipped; they are
            # counted neither as a success nor as a failure. Confirm whether
            # they should be added to the failure tally instead.
            print('问题: {} 没有找到符合要求的数据'.format(query))
            continue

        # Only the first candidate rewrite is executed.
        selected_rewritten, selected_knowledge = rewritten_results[0]
        result = dialog_manager.execute_generated_code(selected_rewritten, selected_knowledge)
        if result is None or result == []:
            fail_count += 1
            error_list.append(f"问题 {selected_rewritten} 调用 execute_generated_code 返回空结果")
        else:
            success_count += 1
        # Echo the raw result in both cases so empty results show up in the log.
        print(result)

    return success_count, fail_count, error_list


def _print_summary(success_count, fail_count, error_list):
    """Print the total, the success/failure percentages and the error list.

    Args:
        success_count: Number of executions that returned a non-empty result.
        fail_count: Number of executions that returned ``None`` or ``[]``.
        error_list: Messages to print, one per line, after the percentages.
    """
    total = success_count + fail_count
    if total > 0:
        success_rate = success_count / total
        fail_rate = fail_count / total
    else:
        # Nothing was executed; avoid dividing by zero.
        success_rate = fail_rate = 0
    print(f"问题总数: {total}")
    print(f"成功比例: {success_rate * 100:.2f}%")
    print(f"失败比例: {fail_rate * 100:.2f}%")
    print("错误列表:")
    for error in error_list:
        print(error)


def main():
    """Build the dialog pipeline, run every query in zhibiao.json, print stats.

    The configuration, both LLM clients, the Neo4j retriever and the code
    executor are constructed from ``Config``. The indicator file is read from
    ``./tests/zhibiao.json``, which is resolved against the current working
    directory. A summary is printed once all queries have been processed.
    """
    config = Config()
    business_structure = load_file(config.business_object_structure_path)
    bowei_api_docs = load_file(config.bowei_api_docs_path)
    llm_client = LLMClient(config.openai)
    llm_client_coder = LLMClient(config.openai_coder)
    prompt_manager = PromptManager()
    neo4j_conf = config.neo4j_conf
    embedding_conf = config.embedding
    # NOTE(review): this client is not referenced again in this script; it is
    # kept in case its construction has side effects. Confirm and drop if not.
    embedding_client = EmbeddingClient(embedding_conf)

    # Create the Neo4j retriever and register the Neo4j-backed toolkit on its driver.
    knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
    ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)

    code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder)
    dialog_manager = DialogManager(
        llm_client,
        business_structure,
        bowei_api_docs,
        code_executor,
        knowledge_retriever,
        prompt_manager,
    )

    # Load zhibiao.json.
    with open('./tests/zhibiao.json', 'r', encoding='utf-8') as f:
        zhibiao_data = json.load(f)

    # Extract the indicator mappings, execute them in batch, then report.
    success_count, fail_count, error_list = _run_indicator_batch(dialog_manager, zhibiao_data)
    _print_summary(success_count, fail_count, error_list)


if __name__ == "__main__":
    main()