From 22c51218b3e4f64ead6c5f53356e2e4181c8ee3c Mon Sep 17 00:00:00 2001 From: paituo <330435863@qq.com> Date: Mon, 19 Aug 2024 10:03:46 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E8=BF=9Brerank=E6=95=88=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/.env.xinference | 21 +++++++++-------- backend/app/settings.py | 10 ++++++-- backend/app/xinference/base.py | 43 +++++++++++++++++++++++++--------- 3 files changed, 51 insertions(+), 23 deletions(-) diff --git a/backend/.env.xinference b/backend/.env.xinference index a8e9396..b0d2d0c 100644 --- a/backend/.env.xinference +++ b/backend/.env.xinference @@ -3,6 +3,16 @@ SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 #SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2 +#-------------------------- +# 是否启用检索重排功能 +ENABLE_RERANK=true +# The number of similar embeddings to return when retrieving documents. +TOP_K=5 +# Rerank model +RERANK_MODEL=bge-reranker-v2-m3 +RERANK_BASE_URL=http://10.1.16.39:9995 +RERANK_TOP_N=5 +RERANK_THRESHOLD=0.3 #---------- Xinference ---------------- # The provider for the AI models to use. MODEL_PROVIDER=xinference @@ -19,9 +29,7 @@ EMBEDDING_MODEL=bge-m3 EMBEDDING_BASE_URL=http://10.1.16.39:9995 # Dimension of the embedding model to use. EMBEDDING_DIM=1024 -# Rerank model -RERANK_MODEL=bge-reranker-v2-m3 -RERANK_BASE_URL=http://10.1.16.39:9995 + ##---------- OpenAI ---------------- ## The provider for the AI models to use. #MODEL_PROVIDER=openai @@ -46,17 +54,10 @@ RERANK_BASE_URL=http://10.1.16.39:9995 ## Name of the embedding model to use. #EMBEDDING_MODEL=text-embedding-v2 -#-------------------------- -# 是否启用检索重排功能 -ENABLE_RERANK=true - # The questions to help users get started (multi-line). CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容? -# The number of similar embeddings to return when retrieving documents. -TOP_K=5 - # The time in milliseconds to wait for the stream to return a response. STREAM_TIMEOUT=60000 diff --git a/backend/app/settings.py b/backend/app/settings.py index 55fe1e8..f5246cd 100644 --- a/backend/app/settings.py +++ b/backend/app/settings.py @@ -10,11 +10,17 @@ from app.xinference.base import XinferenceEmbedding, XinferenceRerank def get_node_postprocessors(): + rerank_enabled = os.getenv("RERANK_ENABLED") + if rerank_enabled is None or rerank_enabled is False: + return [] + rerank_model = os.getenv("RERANK_MODEL") rerank_url = os.getenv("RERANK_BASE_URL") + rerank_top_n = os.getenv("RERANK_TOP_N") + rerank_threshold = os.getenv("RERANK_THRESHOLD") postprocess = None if rerank_model is not None: - postprocess = [XinferenceRerank(rerank_model, rerank_url)] + postprocess = [XinferenceRerank(rerank_model, rerank_url, top_n=rerank_top_n, threshold=rerank_threshold)] return postprocess def init_settings(): @@ -79,7 +85,7 @@ def init_xinference(): embed_model_name = os.getenv("EMBEDDING_MODEL") dimensions = os.getenv("EMBEDDING_DIM") dimensions = int(dimensions) if dimensions is not None else None - Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url) + Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions) def init_openai(): from llama_index.core.constants import DEFAULT_TEMPERATURE diff --git a/backend/app/xinference/base.py b/backend/app/xinference/base.py index d6bca82..92f87e1 100644 --- a/backend/app/xinference/base.py +++ b/backend/app/xinference/base.py @@ -5,10 +5,11 @@ from enum import Enum from http import HTTPStatus from typing import Any, Dict, List, Optional, Union, Tuple -from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding +from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding, dispatcher from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.callbacks import CBEventType, EventPayload from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding +from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReRankEndEvent from llama_index.core.postprocessor.types import BaseNodePostprocessor from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle from pydantic import Field @@ -51,7 +52,7 @@ class XinferenceEmbedding(BaseEmbedding): # num_workers: Optional[int] = None, **kwargs: Any, ) -> None: - generator, model_description = self.load_model( + generator, model_description, embed_batch_size, dimensions = self.load_model( model_uid, endpoint ) self._generator = generator @@ -106,8 +107,11 @@ class XinferenceEmbedding(BaseEmbedding): ) model = model_description["model_name"] + replica = model_description['replica'] + dimensions = model_description['dimensions'] + max_tokens = model_description['max_tokens'] - return generator, model_description + return generator, model_description, replica, dimensions @classmethod def class_name(cls) -> str: @@ -151,23 +155,27 @@ class XinferenceRerank(BaseNodePostprocessor): description="The model description from Xinference." ) _generator: Any = PrivateAttr() - _model_uid: str = Field(description="The Xinference model to use.") - _endpoint: str = Field(description="The Xinference endpoint URL to use.") - #model: str = Field(description="Dashscope rerank model name.") + _model_uid: str + _endpoint: str + model: str = Field(description="Dashscope rerank model name.") top_n: int = Field(description="Top N nodes to return.") + threshold: float = Field(description="threshold nodes to return.") def __init__( self, model_uid: str, endpoint: str, top_n: int = 3, + threshold: float = 0.3, return_documents: bool = False ): + _model_uid = model_uid + _endpoint = endpoint generator, model_description = self.load_model( model_uid, endpoint ) self._generator = generator - super().__init__(top_n=top_n, model=model_uid, return_documents=return_documents) + super().__init__(top_n=top_n, model=model_uid, threshold = threshold, return_documents=return_documents) @classmethod def class_name(cls) -> str: @@ -183,6 +191,15 @@ class XinferenceRerank(BaseNodePostprocessor): if len(nodes) == 0: return [] + dispatcher.event( + ReRankStartEvent( + nodes = nodes, + top_n = self.top_n, + query = query_bundle, + model_name = self.model + ) + ) + with self.callback_manager.event( CBEventType.RERANKING, payload={ @@ -199,12 +216,16 @@ class XinferenceRerank(BaseNodePostprocessor): new_node_with_score = NodeWithScore( node=nodes[result['index']].node, score=result['relevance_score'] ) - print(new_node_with_score.node.get_content) - print('\n') - print(new_node_with_score.score) - new_nodes.append(new_node_with_score) + if new_node_with_score.score >=self.threshold: + new_nodes.append(new_node_with_score) + event.on_end(payload={EventPayload.NODES: new_nodes}) + dispatcher.event( + ReRankEndEvent( + nodes= new_nodes + ) + ) return new_nodes def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]: