78 lines
3.6 KiB
Python
78 lines
3.6 KiB
Python
import os
|
|
|
|
from llama_index.core.agent import AgentRunner, ReActChatFormatter
|
|
from llama_index.core.settings import Settings
|
|
from llama_index.core.tools.query_engine import QueryEngineTool
|
|
|
|
from app.engine.engine import create_query_engine, create_summary_query_engine
|
|
from app.engine.index import get_index
|
|
from app.engine.prompt import ReActChatFormatter_messages, tree_summary_query_engine_tool_messages, \
|
|
query_engine_tool_messages, summary_query_tool_messages
|
|
#from app.engine.loaders.db import makeDescriptionByEngine
|
|
from app.engine.tools import ToolFactory
|
|
from app.api.routers.request.base import ProjectInfo
|
|
from llama_index.core.response_synthesizers import ResponseMode
|
|
|
|
def getPrjFalg(params: dict = None) -> str:
    """Resolve the project flag for the project named in *params*.

    Args:
        params: Optional request parameters. When given, the value stored
            under the ``'projectname'`` key (``None`` if the key is absent)
            is handed to ``ProjectInfo().prjFalg``.

    Returns:
        ``''`` when *params* is ``None``; otherwise whatever
        ``ProjectInfo().prjFalg`` returns for that project name
        (presumably a str, per the annotation — confirm in ProjectInfo).
    """
    # Guard clause: without request parameters there is no project to look up.
    if params is None:
        return ''
    project_name = params.get('projectname')
    return ProjectInfo().prjFalg(project_name)
|
|
|
|
|
|
def get_chat_engine(filters=None, params: dict = None):
    """Build the agent used to answer chat requests for one project.

    The project index is looked up with ``get_index`` using the flag derived
    from *params* via ``getPrjFalg``. When an index exists, one
    tree-summarize query tool named ``zj_query_tool`` is registered over it.
    Any tools configured through ``ToolFactory.from_env()`` are appended
    after it.

    Environment variables read:
        SYSTEM_PROMPT  -- system prompt passed to the agent (None if unset).
        TOP_K          -- retrieval depth for the query engine; defaults to 3
                          when unset or set to an empty string.
        RERANK_ENABLED -- forwarded unchanged (raw string or None) to
                          ``create_query_engine``.

    Args:
        filters: Optional filters forwarded to ``create_query_engine``.
        params: Optional request parameters; ``params['projectname']``
            selects which project's index is used.

    Returns:
        The agent built by ``AgentRunner.from_llm`` using ``Settings.llm``.
    """
    system_prompt = os.getenv("SYSTEM_PROMPT")
    # `or` also covers TOP_K="" (set but empty), which int() would reject.
    top_k = int(os.getenv("TOP_K") or "3")
    # NOTE(review): this is the raw env string (or None), so values such as
    # "false" or "0" are truthy. Confirm how create_query_engine interprets it
    # before relying on RERANK_ENABLED=false to disable reranking.
    use_reranker = os.getenv("RERANK_ENABLED")

    tools = []

    # Add the project query tool if an index exists for this project.
    index = get_index(getPrjFalg(params))
    if index is not None:
        query_engine = create_query_engine(
            index,
            top_k,
            use_reranker,
            filters,
            response_mode=ResponseMode.TREE_SUMMARIZE,
        )
        tools.append(
            QueryEngineTool.from_defaults(
                query_engine=query_engine,
                name="zj_query_tool",
                description=query_engine_tool_messages,
            )
        )
        # NOTE: earlier revisions built further tools here, none of which were
        # ever registered:
        #   - summary_query_tool (create_summary_query_engine)
        #   - zj_query_tool_1 (a second tree-summarize engine)
        #   - an SQL-style zjdata_query_tool prototype
        # They are no longer built, so requests do not pay for unused query
        # engines. Restore them from version history if needed.

    # Add additional tools configured through the environment.
    tools += ToolFactory.from_env()

    return AgentRunner.from_llm(
        llm=Settings.llm,
        tools=tools,
        system_prompt=system_prompt,
        verbose=True,
    )
|