合并
This commit is contained in:
@@ -56,6 +56,7 @@ from src.code_executor import CodeExecutor
|
||||
from src.dialog_manager import DialogManager
|
||||
from src.neo4j_raw_retriever import Neo4jRawRetriever
|
||||
from src.embedding_client import EmbeddingClient
|
||||
|
||||
from project import ProjectBuilder, ProjectToolkit
|
||||
from project_implementation import ProjectToolkitNeo4j
|
||||
|
||||
@@ -69,14 +70,11 @@ def main():
|
||||
os.environ["LANGSMITH_TRACING"] = "true" if tracing_enabled else "false"
|
||||
os.environ["LANGSMITH_API_KEY"] = config.langsmith.get("api_key")
|
||||
#os.environ["LANGSMITH_API_KEY"] = config.langsmith.get("api_url")
|
||||
|
||||
business_structure = load_file(config.business_object_structure_path)
|
||||
bowei_api_docs = load_file(config.bowei_api_docs_path)
|
||||
|
||||
llm_client = LLMClient(config.openai)
|
||||
|
||||
llm_client_coder = LLMClient(config.openai_coder)
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
|
||||
neo4j_conf = config.neo4j_conf
|
||||
@@ -85,12 +83,10 @@ def main():
|
||||
embedding_client = EmbeddingClient(embedding_conf)
|
||||
|
||||
# 创建Neo4j检索器
|
||||
|
||||
knowledge_retriever = Neo4jRawRetriever(neo4j_conf)
|
||||
|
||||
|
||||
ProjectBuilder.register(ProjectToolkitNeo4j, knowledge_retriever.driver)
|
||||
|
||||
|
||||
code_executor = CodeExecutor(prompt_manager.prompts, llm_client_coder)
|
||||
|
||||
dialog_manager = DialogManager(
|
||||
@@ -102,10 +98,14 @@ def main():
|
||||
prompt_manager,
|
||||
)
|
||||
|
||||
pre_input_question = "查找名称中包含“工程”的项目划分项,并返回其人工费乘以1000的值。"
|
||||
pre_input_question = "查找名称中包含“工程”的项目划分项,并返回单位。"
|
||||
pre_input_question = '查找名称中包含"工程"的项目划分项,并返回其人工费乘以1000的值。'
|
||||
pre_input_question = '查找名称中包含"工程"的项目划分项,并返回单位。'
|
||||
|
||||
try:
|
||||
asyncio.run(dialog_manager.run_async(pre_input=pre_input_question))
|
||||
finally:
|
||||
|
||||
neo4j_driver.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+3
-1
@@ -161,6 +161,7 @@ class ProjectToolkit(ABC):
|
||||
|
||||
Args:
|
||||
parent_path (str): 父节点的路径,以'/'分隔的多级节点路径
|
||||
|
||||
code (str): 目标节点编码
|
||||
|
||||
Returns:
|
||||
@@ -770,7 +771,7 @@ class ProjectBuilder:
|
||||
if not issubclass(toolkit_class, ProjectToolkit):
|
||||
raise TypeError(f"{toolkit_class.__name__} 必须继承自 ProjectToolkit")
|
||||
|
||||
cls._config = config
|
||||
_registry = toolkit_class cls._config = config
|
||||
cls._registry = toolkit_class
|
||||
|
||||
@classmethod
|
||||
@@ -786,4 +787,5 @@ class ProjectBuilder:
|
||||
if cls._registry is None:
|
||||
raise KeyError(f"未注册的类,请先注册类")
|
||||
|
||||
|
||||
return cls._registry(cls._config)
|
||||
|
||||
+2
-1
@@ -3,4 +3,5 @@ langchain-core
|
||||
langchain-experimental
|
||||
pyyaml
|
||||
neo4j
|
||||
langchain-neo4j
|
||||
langchain-neo4j
|
||||
langgraph
|
||||
Reference in New Issue
Block a user