增加对Rerank功能支持

This commit is contained in:
2024-08-19 08:59:08 +08:00
parent 8d4382376f
commit 2942730c9a
3 changed files with 18 additions and 3 deletions
+3
View File
@@ -19,6 +19,9 @@ EMBEDDING_MODEL=bge-m3
EMBEDDING_BASE_URL=http://10.1.16.39:9995 EMBEDDING_BASE_URL=http://10.1.16.39:9995
# Dimension of the embedding model to use. # Dimension of the embedding model to use.
EMBEDDING_DIM=1024 EMBEDDING_DIM=1024
# Rerank model
RERANK_MODEL=bge-reranker-v2-m3
RERANK_BASE_URL=http://10.1.16.39:9995
##---------- OpenAI ---------------- ##---------- OpenAI ----------------
## The provider for the AI models to use. ## The provider for the AI models to use.
#MODEL_PROVIDER=openai #MODEL_PROVIDER=openai
+6 -2
View File
@@ -11,6 +11,7 @@ from sqlalchemy import create_engine, Engine
from app.engine.loaders.db import makeDescriptionByEngine from app.engine.loaders.db import makeDescriptionByEngine
from app.engine.tools import ToolFactory from app.engine.tools import ToolFactory
from app.engine.index import get_index from app.engine.index import get_index
from app.settings import get_node_postprocessors
sql_database = None sql_database = None
sql_obj_index = None sql_obj_index = None
@@ -53,12 +54,15 @@ def get_chat_engine(filters=None, params=None):
) )
# 创建向量检索查询工具 # 创建向量检索查询工具
postprocess = get_node_postprocessors()
query_engine = index.as_query_engine( query_engine = index.as_query_engine(
similarity_top_k=top_k, filters=filters similarity_top_k=top_k, filters=filters,
node_postprocessors=postprocess,
) )
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool", query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。", description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。如果你不知道答案,就说你不知道,不要编造答案。",
) )
tools.append(summary_query_tool) tools.append(summary_query_tool)
tools.append(query_engine_tool) tools.append(query_engine_tool)
#tools.append(sql_query_tool) #tools.append(sql_query_tool)
+9 -1
View File
@@ -6,9 +6,17 @@ from llama_index.core.settings import Settings
from llama_index.llms.xinference import Xinference from llama_index.llms.xinference import Xinference
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from app.xinference.base import XinferenceEmbedding from app.xinference.base import XinferenceEmbedding, XinferenceRerank
def get_node_postprocessors():
rerank_model = os.getenv("RERANK_MODEL")
rerank_url = os.getenv("RERANK_BASE_URL")
postprocess = None
if rerank_model is None:
postprocess = [XinferenceRerank(rerank_model, rerank_url)]
return postprocess
def init_settings(): def init_settings():
model_provider = os.getenv("MODEL_PROVIDER") model_provider = os.getenv("MODEL_PROVIDER")
match model_provider: match model_provider: