更新最新代码
This commit is contained in:
+42
-22
@@ -2,9 +2,25 @@ import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import logging
|
||||
def setup_logger(logger_name):
|
||||
"""
|
||||
设置指定名称的logger,将其级别设置为WARNING并禁用传播
|
||||
:param logger_name: logger的名称
|
||||
"""
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(logging.WARNING) # 设置httpcore及其子模块的级别
|
||||
logger.propagate = False # 可选:禁用传播(防止被根logger处理)
|
||||
return logger
|
||||
|
||||
|
||||
logger_names = ["httpx", "openai", "langsmith.client", "neo4j", "urllib3", "httpcore"]
|
||||
for name in logger_names:
|
||||
setup_logger(name)
|
||||
|
||||
import json
|
||||
from src.dialog_manager import DialogManager
|
||||
from src.llm_client import LLMClient
|
||||
from src.multi_llm_client import MultiAPIKeyChatOpenAI
|
||||
from src.code_executor import CodeExecutor
|
||||
from src.neo4j_raw_retriever import Neo4jRawRetriever
|
||||
from src.prompt_manager import PromptManager
|
||||
@@ -13,18 +29,25 @@ from src.config import Config
|
||||
from src.document_loader import load_file
|
||||
from src.embedding_client import EmbeddingClient
|
||||
|
||||
from project import ProjectBuilder, ProjectToolkit
|
||||
from project_implementation import ProjectToolkitNeo4j
|
||||
from src.project import ProjectBuilder, ProjectToolkit
|
||||
from src.project_implementation import ProjectToolkitNeo4j
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
questions = []
|
||||
error_list = []
|
||||
|
||||
def main():
|
||||
global fail_count, success_count, questions, error_list
|
||||
|
||||
config = Config()
|
||||
|
||||
business_structure = load_file(config.business_object_structure_path)
|
||||
bowei_api_docs = load_file(config.bowei_api_docs_path)
|
||||
|
||||
llm_client = LLMClient(config.openai)
|
||||
llm_client = MultiAPIKeyChatOpenAI(config.openai)
|
||||
|
||||
llm_client_coder = LLMClient(config.openai_coder)
|
||||
llm_client_coder = MultiAPIKeyChatOpenAI(config.openai_coder)
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
|
||||
@@ -38,7 +61,7 @@ def main():
|
||||
|
||||
ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)
|
||||
|
||||
code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder)
|
||||
code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder, config.max_retries)
|
||||
|
||||
dialog_manager = DialogManager(
|
||||
llm_client,
|
||||
@@ -55,37 +78,34 @@ def main():
|
||||
|
||||
# 提取指标映射关系并批量执行
|
||||
for item in zhibiao_data:
|
||||
datasource = item['数据来源']
|
||||
if datasource == '报表指标' or datasource == '指标库':
|
||||
continue
|
||||
|
||||
query = item['指标描述']['指标映射']
|
||||
rewritten_results = dialog_manager.understand_user_question_stream(query)
|
||||
#rewritten_results = dialog_manager.understand_user_question(query)
|
||||
rewritten_results = []
|
||||
if rewritten_results is None or rewritten_results == []:
|
||||
print('问题: {} 没有找到符合要求的数据'.format(query))
|
||||
fail_count += 1
|
||||
error_list.append(f"问题 {query} 调用 understand_user_question 返回空结果")
|
||||
continue
|
||||
|
||||
selected_rewritten, selected_knowledge = rewritten_results[0]
|
||||
questions = []
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
error_list = []
|
||||
|
||||
# 检查 understand_user_question_stream 方法调用结果,假设存在该方法调用
|
||||
# 示例调用,实际使用时请替换为真实调用
|
||||
# stream_result = dialog_manager.understand_user_question_stream(query)
|
||||
# if stream_result is None or stream_result == []:
|
||||
# questions.append(query)
|
||||
# fail_count += 1
|
||||
# error_list.append(f"问题 {query} 调用 understand_user_question_stream 返回空结果")
|
||||
|
||||
result = dialog_manager.execute_generated_code(selected_rewritten, selected_knowledge)
|
||||
if result is None or result == []:
|
||||
questions.append(selected_rewritten)
|
||||
fail_count += 1
|
||||
error_list.append(f"问题 {selected_rewritten} 调用 execute_generated_code 返回空结果")
|
||||
error_list.append(f"问题 {query} {selected_rewritten} 调用 execute_generated_code 返回空结果")
|
||||
else:
|
||||
success_count += 1
|
||||
|
||||
print(result)
|
||||
#print(result)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
total = success_count + fail_count
|
||||
if total > 0:
|
||||
success_rate = success_count / total
|
||||
@@ -98,5 +118,5 @@ if __name__ == "__main__":
|
||||
print("错误列表:")
|
||||
for error in error_list:
|
||||
print(error)
|
||||
main()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user