102 lines
3.1 KiB
Python
102 lines
3.1 KiB
Python
# main.py
|
|
|
|
import os
|
|
import logging
|
|
from datetime import datetime
|
|
import logging
|
|
import os
|
|
import asyncio
|
|
import os
|
|
|
|
# Each run writes to its own log file named "<script name>_<YYYYmmddHHMMSS>.log".
current_file = os.path.splitext(os.path.basename(__file__))[0]
now_str = datetime.now().strftime("%Y%m%d%H%M%S")
log_filename = f"{current_file}_{now_str}.log"

# logging.FileHandler opens its file on construction and does not create
# missing parent directories, so make sure "logs" exists first; otherwise a
# fresh checkout fails at import time with FileNotFoundError.
os.makedirs("logs", exist_ok=True)

# Root configuration: DEBUG and above goes to both the per-run file and the
# console stream.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(os.path.join("logs", log_filename), encoding="utf-8"),
        logging.StreamHandler(),
    ],
)

# Module logger, named after this script file.
logger = logging.getLogger(current_file)
|
|
|
|
def setup_logger(logger_name):
    """
    Restrict the named logger to WARNING and above and detach it from its parents.

    :param logger_name: name of the logger to configure
    :return: the configured ``logging.Logger`` instance
    """
    target = logging.getLogger(logger_name)
    # Stop records from being handed to ancestor loggers (including the root).
    target.propagate = False
    # Discard anything below WARNING emitted on this logger.
    target.setLevel(logging.WARNING)
    return target
|
|
|
|
# Library loggers that are capped at WARNING and detached from the root
# logger through setup_logger().
logger_names = [
    "httpx",
    "openai",
    "langsmith.client",
    "neo4j",
    "urllib3",
    "httpcore",
]
for name in logger_names:
    setup_logger(name)
|
|
|
|
from src.config import Config
|
|
from src.document_loader import load_file
|
|
from src.multi_llm_client import MultiAPIKeyChatOpenAI
|
|
from src.prompt_manager import PromptManager
|
|
from src.code_executor import CodeExecutor
|
|
from src.dialog_manager import DialogManager
|
|
from src.neo4j_raw_retriever import Neo4jRawRetriever
|
|
from src.embedding_client import EmbeddingClient
|
|
|
|
from src.project import ProjectBuilder, ProjectToolkit
|
|
from src.project_implementation import ProjectToolkitNeo4j
|
|
|
|
|
|
|
|
def main():
    """
    Build the application components from ``Config`` and run one dialog session.

    Steps, in order:

    1. Export the LangSmith settings from the configuration as environment
       variables.
    2. Load the two reference documents whose paths the configuration names.
    3. Construct the LLM clients, prompt manager, embedding client, Neo4j
       retriever, code executor and dialog manager.
    4. Run ``DialogManager.run_async`` with a preset question, then release
       the retriever if it exposes ``close()``.
    """
    config = Config()

    # Export the LangSmith settings as environment variables.
    tracing_enabled = config.langsmith.get("tracing_enabled", False)
    langsmith_env = {
        "LANGSMITH_PROJECT": config.langsmith.get("project"),
        "LANGSMITH_TRACING": "true" if tracing_enabled else "false",
        "LANGSMITH_API_KEY": config.langsmith.get("api_key"),
    }
    for env_name, env_value in langsmith_env.items():
        # os.environ only accepts strings; assigning None raises TypeError,
        # so settings missing from the configuration are left unset.
        if env_value is not None:
            os.environ[env_name] = env_value
    # NOTE(review): a commented-out line used to assign the configured
    # "api_url" to LANGSMITH_API_KEY. It was removed because enabling it would
    # overwrite the key; a custom endpoint presumably needs its own variable
    # -- confirm before adding one.

    business_structure = load_file(config.business_object_structure_path)
    bowei_api_docs = load_file(config.bowei_api_docs_path)

    # Two separate clients: one for the dialog, one for code generation.
    llm_client = MultiAPIKeyChatOpenAI(config.openai)
    llm_client_coder = MultiAPIKeyChatOpenAI(config.openai_coder)
    prompt_manager = PromptManager()

    neo4j_conf = config.neo4j_conf
    embedding_conf = config.embedding

    # Not referenced again in this function; the construction is kept because
    # the constructor may have side effects that cannot be seen from here.
    embedding_client = EmbeddingClient(embedding_conf)

    # Create the Neo4j retriever.
    knowledge_retriever = Neo4jRawRetriever(neo4j_conf)

    code_executor = CodeExecutor(
        prompt_manager.prompts, llm_client_coder, config.max_retries
    )

    dialog_manager = DialogManager(
        llm_client,
        business_structure,
        bowei_api_docs,
        code_executor,
        knowledge_retriever,
        prompt_manager,
    )

    # Sample questions for the session. The first assignment is immediately
    # overridden; it is kept only as an alternative sample to switch to.
    pre_input_question = '查找名称中包含"工程"的项目划分项,并返回其人工费乘以1000的值。'
    pre_input_question = '查找名称中包含"工程"的项目划分项,并返回单位。'

    try:
        asyncio.run(dialog_manager.run_async(pre_input=pre_input_question))
    finally:
        # The previous cleanup called close() on `neo4j_driver`, a name that is
        # never defined, so every run ended with NameError and any error raised
        # by the dialog was masked.
        # NOTE(review): presumably the retriever owns the Neo4j connection;
        # it is closed only if it exposes close() -- confirm against
        # Neo4jRawRetriever.
        close_retriever = getattr(knowledge_retriever, "close", None)
        if callable(close_retriever):
            close_retriever()


if __name__ == "__main__":
    main()
|