diff --git a/main.py b/main.py index b5ecd4a..591ce0b 100644 --- a/main.py +++ b/main.py @@ -56,6 +56,7 @@ from src.code_executor import CodeExecutor from src.dialog_manager import DialogManager from src.neo4j_raw_retriever import Neo4jRawRetriever from src.embedding_client import EmbeddingClient + from project import ProjectBuilder, ProjectToolkit from project_implementation import ProjectToolkitNeo4j @@ -69,14 +70,11 @@ def main(): os.environ["LANGSMITH_TRACING"] = "true" if tracing_enabled else "false" os.environ["LANGSMITH_API_KEY"] = config.langsmith.get("api_key") #os.environ["LANGSMITH_API_KEY"] = config.langsmith.get("api_url") - business_structure = load_file(config.business_object_structure_path) bowei_api_docs = load_file(config.bowei_api_docs_path) llm_client = LLMClient(config.openai) - llm_client_coder = LLMClient(config.openai_coder) - prompt_manager = PromptManager() neo4j_conf = config.neo4j_conf @@ -85,12 +83,10 @@ def main(): embedding_client = EmbeddingClient(embedding_conf) # 创建Neo4j检索器 + knowledge_retriever = Neo4jRawRetriever(neo4j_conf) - ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver) - - code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder) dialog_manager = DialogManager( @@ -102,10 +98,14 @@ def main(): prompt_manager, ) - pre_input_question = "查找名称中包含“工程”的项目划分项,并返回其人工费乘以1000的值。" - pre_input_question = "查找名称中包含“工程”的项目划分项,并返回单位。" + pre_input_question = '查找名称中包含"工程"的项目划分项,并返回其人工费乘以1000的值。' + pre_input_question = '查找名称中包含"工程"的项目划分项,并返回单位。' + try: asyncio.run(dialog_manager.run_async(pre_input=pre_input_question)) + finally: + + neo4j_driver.close() if __name__ == "__main__": diff --git a/project.py b/project.py index ee7c47e..40ba9ba 100644 --- a/project.py +++ b/project.py @@ -161,6 +161,7 @@ class ProjectToolkit(ABC): Args: parent_path (str): 父节点的路径,以'/'分隔的多级节点路径 + code (str): 目标节点编码 Returns: @@ -770,7 +771,7 @@ class ProjectBuilder: if not issubclass(toolkit_class, ProjectToolkit): raise TypeError(f"{toolkit_class.__name__} 必须继承自 ProjectToolkit") - cls._config = config + _registry = toolkit_class cls._config = config cls._registry = toolkit_class @classmethod @@ -786,4 +787,5 @@ class ProjectBuilder: if cls._registry is None: raise KeyError(f"未注册的类,请先注册类") + return cls._registry(cls._config) diff --git a/requirements.txt b/requirements.txt index 1bd838c..0d25c76 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ langchain-core langchain-experimental pyyaml neo4j -langchain-neo4j \ No newline at end of file +langchain-neo4j +langgraph \ No newline at end of file