更新
This commit is contained in:
@@ -0,0 +1,162 @@
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime

# Put the parent of this script's directory on sys.path; presumably this is the
# project root that contains the `src` package imported further down.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Base name of this script without its extension; used as the logger name.
current_file = os.path.splitext(os.path.basename(__file__))[0]

# Current local time as a compact string, used to make file names unique per run.
now_str = datetime.now().strftime("%Y%m%d%H%M%S")

log_filename = f"test_code1{now_str}.log"

# logging.FileHandler opens its file as soon as it is constructed and does not
# create missing parent directories, so make sure the log directory exists.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(os.path.join(log_dir, log_filename), encoding="utf-8"),
        logging.StreamHandler(),
    ],
)

logger = logging.getLogger(current_file)
|
||||
|
||||
import logging
|
||||
def setup_logger(logger_name, level=logging.WARNING, propagate=False):
    """
    Configure the logger called ``logger_name`` and return it.

    With the default arguments the logger's level is set to WARNING and
    propagation to ancestor loggers is disabled.

    :param logger_name: name of the logger to configure
    :param level: level to set on the logger (defaults to ``logging.WARNING``)
    :param propagate: whether records are passed on to the handlers of
        ancestor loggers (defaults to ``False``)
    :return: the configured ``logging.Logger`` instance
    """
    # Use a distinct local name so the module-level ``logger`` is not shadowed.
    target = logging.getLogger(logger_name)
    # Descendant loggers that have no level of their own inherit this threshold.
    target.setLevel(level)
    # With propagation off, records from this logger never reach the handlers
    # installed on the root logger.
    target.propagate = propagate
    return target
|
||||
|
||||
|
||||
# Library loggers that are passed through setup_logger one by one.
logger_names = [
    "httpx",
    "openai",
    "langsmith.client",
    "neo4j",
    "urllib3",
    "httpcore",
]
for name in logger_names:
    setup_logger(name)
|
||||
|
||||
import json
|
||||
import os
|
||||
from src.dialog_manager import DialogManager
|
||||
from src.multi_llm_client import MultiAPIKeyChatOpenAI
|
||||
from src.code_executor import CodeExecutor
|
||||
from src.neo4j_raw_retriever import Neo4jRawRetriever
|
||||
from src.prompt_manager import PromptManager
|
||||
import yaml
|
||||
from src.config import Config
|
||||
from src.document_loader import load_file
|
||||
from src.embedding_client import EmbeddingClient
|
||||
|
||||
from src.project import ProjectBuilder, ProjectToolkit
|
||||
from src.project_implementation import ProjectToolkitNeo4j
|
||||
|
||||
# Run-wide tallies updated by main() and reported by the __main__ section.
success_count = fail_count = 0

# Queries that failed, and the matching error messages.
questions, error_list = [], []

# Per-item records that the __main__ section writes out as JSONL.
success_results, fail_results = [], []
|
||||
|
||||
def main():
    """
    Generate code for every record in ``./data/zhibiao.jsonl``.

    Builds the LLM clients, prompt manager, Neo4j retriever, code executor and
    dialog manager from ``Config``, then calls
    ``dialog_manager.generated_code(query, selected_knowledge)`` once per
    record. Outcomes are accumulated in the module-level ``success_count``,
    ``fail_count``, ``questions``, ``error_list``, ``success_results`` and
    ``fail_results``; nothing is returned.

    A record is a success when the call returns a dict with a truthy
    ``status``. Any other return value, or an exception raised by the call,
    is recorded as a failure and the batch continues with the next record.
    """
    global fail_count, success_count, questions, error_list

    config = Config()

    business_structure = load_file(config.business_object_structure_path)
    bowei_api_docs = load_file(config.bowei_api_docs_path)

    llm_client = MultiAPIKeyChatOpenAI(config.openai)
    llm_client_coder = MultiAPIKeyChatOpenAI(config.openai_coder)

    prompt_manager = PromptManager()

    neo4j_conf = config.neo4j_conf
    embedding_conf = config.embedding

    # NOTE(review): embedding_client is not used below; it is still constructed
    # in case the constructor has side effects. Confirm and remove if not.
    embedding_client = EmbeddingClient(embedding_conf)

    # Create the Neo4j retriever
    knowledge_retriever = Neo4jRawRetriever(neo4j_conf)

    ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)

    code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder, config.max_retries)

    dialog_manager = DialogManager(
        llm_client,
        business_structure,
        bowei_api_docs,
        code_executor,
        knowledge_retriever,
        prompt_manager,
    )

    # Load zhibiao.jsonl (one JSON object per line)
    zhibiao_data = []
    with open('./data/zhibiao.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            # Skip blank lines, which json.loads would reject.
            if line.strip():
                zhibiao_data.append(json.loads(line))

    # Run code generation for each record
    for item in zhibiao_data:
        query = item['query']
        name = item['name']
        selected_knowledge = item['result']
        logger.info(f"指标名称 {name} 问题: {query} 开始生成代码")

        error_msg = None
        try:
            result = dialog_manager.generated_code(query, selected_knowledge)
        except Exception as e:
            # One failing item must not abort the batch: results are only
            # written to disk after main() returns.
            logger.exception(f"指标名称 {name} 问题: {query} 处理异常: {e}")
            result = None
            error_msg = f"指标名称 {name} 问题 {query} 异常: {e}"

        if isinstance(result, dict) and result.get('status', False):
            code = result['data']
            success_count += 1
            success_results.append({
                "name": name,
                "query": query,
                "code": code,
            })
            continue

        if error_msg is None:
            if isinstance(result, dict):
                error_msg = result.get('message', '调用 generated_code 返回空结果')
            else:
                error_msg = f"问题 {query} {selected_knowledge} 调用 execute_generated_code 返回空结果"

        questions.append(query)
        fail_count += 1
        error_list.append(error_msg)
        fail_results.append({
            "name": name,
            "query": query,
            "messages": error_msg,
            "selected_knowledge": selected_knowledge,
        })
|
||||
|
||||
# NOTE(review): the reporting and saving below are assumed to belong under the
# entry guard; confirm against the original file layout.
if __name__ == "__main__":
    main()

    # Summarise the run from the counters main() updated.
    total = success_count + fail_count
    success_rate = success_count / total if total > 0 else 0
    fail_rate = fail_count / total if total > 0 else 0

    print(f"问题总数: {total}")
    print(f"成功比例: {success_rate * 100:.2f}%")
    print(f"失败比例: {fail_rate * 100:.2f}%")
    print("错误列表:")
    for error in error_list:
        print(error)

    # Save the successful results as JSONL (fixed path, overwritten each run)
    success_filename = f'./data/code.jsonl'
    with open(success_filename, 'w', encoding='utf-8') as f:
        f.writelines(
            json.dumps(item, ensure_ascii=False) + '\n' for item in success_results
        )

    # Save the failed results as JSONL (file name carries the run timestamp)
    fail_filename = f'./data/fail_{now_str}.jsonl'
    with open(fail_filename, 'w', encoding='utf-8') as f:
        f.writelines(
            json.dumps(item, ensure_ascii=False) + '\n' for item in fail_results
        )
|
||||
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
# tests/test_userinteraction.py
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
# Put the parent of this script's directory on sys.path; presumably this is the
# project root that contains the `src` package imported further down.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Base name of this script without its extension; used for the log file name
# and as the logger name.
current_file = os.path.splitext(os.path.basename(__file__))[0]
# Current local time as a compact string, used to make the log file unique per run.
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"

# logging.FileHandler opens its file as soon as it is constructed and does not
# create missing parent directories, so make sure the log directory exists.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(os.path.join(log_dir, log_filename), encoding="utf-8"),
        logging.StreamHandler(),
    ],
)

logger = logging.getLogger(current_file)
|
||||
|
||||
def setup_logger(logger_name, level=logging.WARNING, propagate=False):
    """
    Configure the logger called ``logger_name`` and return it.

    With the default arguments the logger's level is set to WARNING and
    propagation to ancestor loggers is disabled.

    :param logger_name: name of the logger to configure
    :param level: level to set on the logger (defaults to ``logging.WARNING``)
    :param propagate: whether records are passed on to the handlers of
        ancestor loggers (defaults to ``False``)
    :return: the configured ``logging.Logger`` instance
    """
    # Use a distinct local name so the module-level ``logger`` is not shadowed.
    target = logging.getLogger(logger_name)
    # Descendant loggers that have no level of their own inherit this threshold.
    target.setLevel(level)
    # With propagation off, records from this logger never reach the handlers
    # installed on the root logger.
    target.propagate = propagate
    return target
|
||||
|
||||
|
||||
# Library loggers that are passed through setup_logger one by one.
logger_names = [
    "httpx",
    "openai",
    "langsmith.client",
    "neo4j",
    "urllib3",
    "httpcore",
]
for name in logger_names:
    setup_logger(name)
|
||||
|
||||
from src.config import Config
|
||||
from src.document_loader import load_file
|
||||
from src.multi_llm_client import MultiAPIKeyChatOpenAI
|
||||
from src.user_interaction import UserInteraction
|
||||
import json
|
||||
from src.dialog_manager import DialogManager
|
||||
from src.multi_llm_client import MultiAPIKeyChatOpenAI
|
||||
from src.code_executor import CodeExecutor
|
||||
from src.neo4j_raw_retriever import Neo4jRawRetriever
|
||||
from src.prompt_manager import PromptManager
|
||||
import yaml
|
||||
from src.config import Config
|
||||
from src.document_loader import load_file
|
||||
from src.embedding_client import EmbeddingClient
|
||||
|
||||
from src.project import ProjectBuilder, ProjectToolkit
|
||||
from src.project_implementation import ProjectToolkitNeo4j
|
||||
|
||||
# Run-wide tallies updated by main().
success_count = fail_count = 0

# `questions` is declared global in main() but not filled there;
# `error_list` collects one message per failed item.
questions, error_list = [], []

# One {"name", "query", "result"} record per successfully understood item.
success_list = []
|
||||
|
||||
def main():
    """
    Run ``user_interaction.understand`` over the indicators in ``../data/zhibiao.json``.

    Items whose "数据来源" is "报表指标" or "指标库", and items without a
    "指标描述" -> "指标映射" query, are skipped and not counted. For every other
    item the query is passed to ``understand``; a falsy result or a raised
    exception counts as a failure, anything else as a success.

    Prints the totals and the error list, then writes one
    ``{"name", "query", "result"}`` JSON object per successful item to
    ``../data/zhibiao.jsonl`` (relative to this file). Updates the module-level
    ``success_count``, ``fail_count``, ``error_list`` and ``success_list``;
    returns nothing.
    """
    global success_count, fail_count, questions, error_list, success_list

    config = Config()

    business_structure = load_file(config.business_object_structure_path)
    bowei_api_docs = load_file(config.bowei_api_docs_path)

    llm_client = MultiAPIKeyChatOpenAI(config.openai)
    user_interaction = UserInteraction(llm_client.llm, business_structure)

    llm_client_coder = MultiAPIKeyChatOpenAI(config.openai_coder)

    prompt_manager = PromptManager()

    neo4j_conf = config.neo4j_conf
    embedding_conf = config.embedding

    embedding_client = EmbeddingClient(embedding_conf)

    # Create the Neo4j retriever
    knowledge_retriever = Neo4jRawRetriever(neo4j_conf)

    ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)

    code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder, config.max_retries)

    # NOTE(review): dialog_manager, and the objects built only to feed it, are
    # not used below. They are kept because their constructors and
    # ProjectBuilder.register may have side effects understand() relies on.
    # Confirm before removing.
    dialog_manager = DialogManager(
        llm_client,
        business_structure,
        bowei_api_docs,
        code_executor,
        knowledge_retriever,
        prompt_manager,
    )

    # Read zhibiao.json
    zhibiao_path = os.path.join(os.path.dirname(__file__), "../data/zhibiao.json")
    with open(zhibiao_path, "r", encoding="utf-8") as f:
        zhibiao_data = json.load(f)

    # Debug switch: set to True to run against the single built-in sample
    # below instead of the data file.
    is_test = False

    if is_test:
        zhibiao_data = [
            {
                "指标名称": "杆塔总基数",
                "指标描述": {
                    "指标映射": "从【架空输电线路本体工程/附件安装工程】项目划分中获取名称属于【'合计'】的费用",
                    "映射规则": "YX2-1~7"
                },
                "code": "",
                "单位": "基",
                "单价类型": None,
                "序号": "1",
                "提取方式": None,
                "指标类型": "工程量指标",
                "数据来源": "定额数量"
            }
        ]

    for idx, item in enumerate(zhibiao_data):
        name = item.get("指标名称", "")
        datasource = item.get("数据来源", "")
        if datasource in ("报表指标", "指标库"):
            logger.info(f"跳过索引 {idx},数据来源为 {datasource}")
            continue

        # "指标描述" may be missing or null; fall back to an empty mapping so
        # the item is skipped instead of raising AttributeError.
        query = (item.get("指标描述") or {}).get("指标映射", "")
        if not query:
            logger.warning(f"索引 {idx} 缺少指标映射,跳过")
            continue

        try:
            # Synchronous call to the user-interaction understanding step
            result = user_interaction.understand(query)
            if not result:
                logger.error(f"问题: {query} 没有找到符合要求的数据")
                fail_count += 1
                error_list.append(f"指标名称 {name} 问题 {query} 调用 understand 返回空结果")
                continue

            # Build the name/constraints summary once and reuse it for both
            # the log line and the saved record.
            summary = [
                {'name': r.get('name'), 'constraints': r.get('constraints')}
                for r in result
            ]
            logger.info(f"指标名称 {name} 问题: {query} 理解结果: {summary}")

            success_list.append({
                "name": name,
                "query": query,
                "result": summary,
            })
            success_count += 1

        except Exception as e:
            # Record the failure and continue with the next item.
            logger.exception(f"指标名称 {name} 问题: {query} 处理异常: {e}")
            fail_count += 1
            error_list.append(f"指标名称 {name} 问题 {query} 异常: {e}")

    total = success_count + fail_count
    success_rate = (success_count / total) * 100 if total > 0 else 0
    fail_rate = (fail_count / total) * 100 if total > 0 else 0

    print(f"问题总数: {total}")
    # The original format strings ended in a stray ")" that was printed verbatim.
    print(f"成功比例: {success_rate:.2f}%")
    print(f"失败比例: {fail_rate:.2f}%")
    print("错误列表:")
    for error in error_list:
        print(error)

    # Save the successful items as a JSONL file
    success_jsonl_path = os.path.join(os.path.dirname(__file__), "../data/zhibiao.jsonl")
    with open(success_jsonl_path, "w", encoding="utf-8") as f:
        for item in success_list:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
||||
|
||||
# Run the batch only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user