调试后整体可以使用的版本。
This commit is contained in:
+102
@@ -0,0 +1,102 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import json
|
||||
from src.dialog_manager import DialogManager
|
||||
from src.llm_client import LLMClient
|
||||
from src.code_executor import CodeExecutor
|
||||
from src.neo4j_raw_retriever import Neo4jRawRetriever
|
||||
from src.prompt_manager import PromptManager
|
||||
import yaml
|
||||
from src.config import Config
|
||||
from src.document_loader import load_file
|
||||
from src.embedding_client import EmbeddingClient
|
||||
|
||||
from project import ProjectBuilder, ProjectToolkit
|
||||
from project_implementation import ProjectToolkitNeo4j
|
||||
|
||||
def main():
|
||||
config = Config()
|
||||
|
||||
business_structure = load_file(config.business_object_structure_path)
|
||||
bowei_api_docs = load_file(config.bowei_api_docs_path)
|
||||
|
||||
llm_client = LLMClient(config.openai)
|
||||
|
||||
llm_client_coder = LLMClient(config.openai_coder)
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
|
||||
neo4j_conf = config.neo4j_conf
|
||||
embedding_conf = config.embedding
|
||||
|
||||
embedding_client = EmbeddingClient(embedding_conf)
|
||||
|
||||
# 创建Neo4j检索器
|
||||
knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
|
||||
|
||||
ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)
|
||||
|
||||
code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder)
|
||||
|
||||
dialog_manager = DialogManager(
|
||||
llm_client,
|
||||
business_structure,
|
||||
bowei_api_docs,
|
||||
code_executor,
|
||||
knowledge_retriever,
|
||||
prompt_manager,
|
||||
)
|
||||
|
||||
# 加载 zhibiao.json
|
||||
with open('./tests/zhibiao.json', 'r', encoding='utf-8') as f:
|
||||
zhibiao_data = json.load(f)
|
||||
|
||||
# 提取指标映射关系并批量执行
|
||||
for item in zhibiao_data:
|
||||
query = item['指标描述']['指标映射']
|
||||
rewritten_results = dialog_manager.understand_user_question_stream(query)
|
||||
if rewritten_results is None or rewritten_results == []:
|
||||
print('问题: {} 没有找到符合要求的数据'.format(query))
|
||||
continue
|
||||
|
||||
selected_rewritten, selected_knowledge = rewritten_results[0]
|
||||
questions = []
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
error_list = []
|
||||
|
||||
# 检查 understand_user_question_stream 方法调用结果,假设存在该方法调用
|
||||
# 示例调用,实际使用时请替换为真实调用
|
||||
# stream_result = dialog_manager.understand_user_question_stream(query)
|
||||
# if stream_result is None or stream_result == []:
|
||||
# questions.append(query)
|
||||
# fail_count += 1
|
||||
# error_list.append(f"问题 {query} 调用 understand_user_question_stream 返回空结果")
|
||||
|
||||
result = dialog_manager.execute_generated_code(selected_rewritten, selected_knowledge)
|
||||
if result is None or result == []:
|
||||
questions.append(selected_rewritten)
|
||||
fail_count += 1
|
||||
error_list.append(f"问题 {selected_rewritten} 调用 execute_generated_code 返回空结果")
|
||||
else:
|
||||
success_count += 1
|
||||
|
||||
print(result)
|
||||
|
||||
if __name__ == "__main__":
|
||||
total = success_count + fail_count
|
||||
if total > 0:
|
||||
success_rate = success_count / total
|
||||
fail_rate = fail_count / total
|
||||
else:
|
||||
success_rate = fail_rate = 0
|
||||
print(f"问题总数: {total}")
|
||||
print(f"成功比例: {success_rate * 100:.2f}%")
|
||||
print(f"失败比例: {fail_rate * 100:.2f}%")
|
||||
print("错误列表:")
|
||||
for error in error_list:
|
||||
print(error)
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
import sys
|
||||
import os
|
||||
import io
|
||||
import logging
|
||||
import traceback
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from project import ProjectBuilder, ProjectToolkit
|
||||
from project_implementation import ProjectToolkitNeo4j
|
||||
from neo4j import GraphDatabase
|
||||
from src.config import Config
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, # 生产环境可改为 INFO 或 WARNING
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.FileHandler("test_runcode.log", encoding="utf-8"), logging.StreamHandler()],
|
||||
)
|
||||
|
||||
logger = logging.getLogger("test_runcode")
|
||||
|
||||
def main():
|
||||
|
||||
config = Config()
|
||||
neo4j_conf = config.neo4j_conf
|
||||
|
||||
code_str = '''
|
||||
def project_get_calculate_function():
|
||||
project = ProjectBuilder.build()
|
||||
status, data, error, helper_info = project.get_division_item_by_path("安装/架空输电线路本体工程")
|
||||
if status == 'success':
|
||||
return status, data.get('单位', ''), error, helper_info
|
||||
return status, None, error, helper_info
|
||||
'''
|
||||
|
||||
neo4j_driver = GraphDatabase.driver(neo4j_conf.get("uri"), auth=(neo4j_conf.get("username"), neo4j_conf.get("password")))
|
||||
|
||||
old_stdout = sys.stdout
|
||||
redirected_output = io.StringIO()
|
||||
|
||||
ProjectBuilder.register(ProjectToolkitNeo4j, neo4j_driver)
|
||||
|
||||
try:
|
||||
namespace = {
|
||||
"project": __import__("project"),
|
||||
"ProjectBuilder": ProjectBuilder,
|
||||
}
|
||||
|
||||
sys.stdout = redirected_output
|
||||
|
||||
exec(code_str, namespace)
|
||||
|
||||
# 确保neo4j_find_function存在
|
||||
if "project_get_calculate_function" not in namespace:
|
||||
raise ValueError("代码中未定义project_get_calculate_function函数")
|
||||
|
||||
result_tuple = namespace["project_get_calculate_function"]()
|
||||
|
||||
sys.stdout = old_stdout
|
||||
output = redirected_output.getvalue().strip()
|
||||
|
||||
if not isinstance(result_tuple, tuple) or len(result_tuple) != 4:
|
||||
raise ValueError("函数应返回包含4个元素的元组(status, data, error, helper_info)")
|
||||
|
||||
status, data, error, helper_info = result_tuple
|
||||
|
||||
logger.info(f"执行结果: status={status}, data={data}, error={error}")
|
||||
|
||||
except Exception as e:
|
||||
# 确保恢复stdout
|
||||
sys.stdout = old_stdout
|
||||
logger.error(f"执行代码时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+1146
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user