改进rerank效果

This commit is contained in:
2024-08-19 10:03:46 +08:00
parent 806b694b37
commit 22c51218b3
3 changed files with 51 additions and 23 deletions
+11 -10
View File
@@ -3,6 +3,16 @@
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2 #SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
#--------------------------
# 是否启用检索重排功能
ENABLE_RERANK=true
# The number of similar embeddings to return when retrieving documents.
TOP_K=5
# Rerank model
RERANK_MODEL=bge-reranker-v2-m3
RERANK_BASE_URL=http://10.1.16.39:9995
RERANK_TOP_N=5
RERANK_THRESHOLD=0.3
#---------- Xinference ---------------- #---------- Xinference ----------------
# The provider for the AI models to use. # The provider for the AI models to use.
MODEL_PROVIDER=xinference MODEL_PROVIDER=xinference
@@ -19,9 +29,7 @@ EMBEDDING_MODEL=bge-m3
EMBEDDING_BASE_URL=http://10.1.16.39:9995 EMBEDDING_BASE_URL=http://10.1.16.39:9995
# Dimension of the embedding model to use. # Dimension of the embedding model to use.
EMBEDDING_DIM=1024 EMBEDDING_DIM=1024
# Rerank model
RERANK_MODEL=bge-reranker-v2-m3
RERANK_BASE_URL=http://10.1.16.39:9995
##---------- OpenAI ---------------- ##---------- OpenAI ----------------
## The provider for the AI models to use. ## The provider for the AI models to use.
#MODEL_PROVIDER=openai #MODEL_PROVIDER=openai
@@ -46,17 +54,10 @@ RERANK_BASE_URL=http://10.1.16.39:9995
## Name of the embedding model to use. ## Name of the embedding model to use.
#EMBEDDING_MODEL=text-embedding-v2 #EMBEDDING_MODEL=text-embedding-v2
#--------------------------
# 是否启用检索重排功能
ENABLE_RERANK=true
# The questions to help users get started (multi-line). # The questions to help users get started (multi-line).
CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容? CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容?
# The number of similar embeddings to return when retrieving documents.
TOP_K=5
# The time in milliseconds to wait for the stream to return a response. # The time in milliseconds to wait for the stream to return a response.
STREAM_TIMEOUT=60000 STREAM_TIMEOUT=60000
+8 -2
View File
@@ -10,11 +10,17 @@ from app.xinference.base import XinferenceEmbedding, XinferenceRerank
def get_node_postprocessors(): def get_node_postprocessors():
rerank_enabled = os.getenv("RERANK_ENABLED")
if rerank_enabled is None or rerank_enabled is False:
return []
rerank_model = os.getenv("RERANK_MODEL") rerank_model = os.getenv("RERANK_MODEL")
rerank_url = os.getenv("RERANK_BASE_URL") rerank_url = os.getenv("RERANK_BASE_URL")
rerank_top_n = os.getenv("RERANK_TOP_N")
rerank_threshold = os.getenv("RERANK_THRESHOLD")
postprocess = None postprocess = None
if rerank_model is not None: if rerank_model is not None:
postprocess = [XinferenceRerank(rerank_model, rerank_url)] postprocess = [XinferenceRerank(rerank_model, rerank_url, top_n=rerank_top_n, threshold=rerank_threshold)]
return postprocess return postprocess
def init_settings(): def init_settings():
@@ -79,7 +85,7 @@ def init_xinference():
embed_model_name = os.getenv("EMBEDDING_MODEL") embed_model_name = os.getenv("EMBEDDING_MODEL")
dimensions = os.getenv("EMBEDDING_DIM") dimensions = os.getenv("EMBEDDING_DIM")
dimensions = int(dimensions) if dimensions is not None else None dimensions = int(dimensions) if dimensions is not None else None
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url) Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions)
def init_openai(): def init_openai():
from llama_index.core.constants import DEFAULT_TEMPERATURE from llama_index.core.constants import DEFAULT_TEMPERATURE
+32 -11
View File
@@ -5,10 +5,11 @@ from enum import Enum
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union, Tuple from typing import Any, Dict, List, Optional, Union, Tuple
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding, dispatcher
from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReRankEndEvent
from llama_index.core.postprocessor.types import BaseNodePostprocessor from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
from pydantic import Field from pydantic import Field
@@ -51,7 +52,7 @@ class XinferenceEmbedding(BaseEmbedding):
# num_workers: Optional[int] = None, # num_workers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
generator, model_description = self.load_model( generator, model_description, embed_batch_size, dimensions = self.load_model(
model_uid, endpoint model_uid, endpoint
) )
self._generator = generator self._generator = generator
@@ -106,8 +107,11 @@ class XinferenceEmbedding(BaseEmbedding):
) )
model = model_description["model_name"] model = model_description["model_name"]
replica = model_description['replica']
dimensions = model_description['dimensions']
max_tokens = model_description['max_tokens']
return generator, model_description return generator, model_description, replica, dimensions
@classmethod @classmethod
def class_name(cls) -> str: def class_name(cls) -> str:
@@ -151,23 +155,27 @@ class XinferenceRerank(BaseNodePostprocessor):
description="The model description from Xinference." description="The model description from Xinference."
) )
_generator: Any = PrivateAttr() _generator: Any = PrivateAttr()
_model_uid: str = Field(description="The Xinference model to use.") _model_uid: str
_endpoint: str = Field(description="The Xinference endpoint URL to use.") _endpoint: str
#model: str = Field(description="Dashscope rerank model name.") model: str = Field(description="Dashscope rerank model name.")
top_n: int = Field(description="Top N nodes to return.") top_n: int = Field(description="Top N nodes to return.")
threshold: float = Field(description="threshold nodes to return.")
def __init__( def __init__(
self, self,
model_uid: str, model_uid: str,
endpoint: str, endpoint: str,
top_n: int = 3, top_n: int = 3,
threshold: float = 0.3,
return_documents: bool = False return_documents: bool = False
): ):
_model_uid = model_uid
_endpoint = endpoint
generator, model_description = self.load_model( generator, model_description = self.load_model(
model_uid, endpoint model_uid, endpoint
) )
self._generator = generator self._generator = generator
super().__init__(top_n=top_n, model=model_uid, return_documents=return_documents) super().__init__(top_n=top_n, model=model_uid, threshold = threshold, return_documents=return_documents)
@classmethod @classmethod
def class_name(cls) -> str: def class_name(cls) -> str:
@@ -183,6 +191,15 @@ class XinferenceRerank(BaseNodePostprocessor):
if len(nodes) == 0: if len(nodes) == 0:
return [] return []
dispatcher.event(
ReRankStartEvent(
nodes = nodes,
top_n = self.top_n,
query = query_bundle,
model_name = self.model
)
)
with self.callback_manager.event( with self.callback_manager.event(
CBEventType.RERANKING, CBEventType.RERANKING,
payload={ payload={
@@ -199,12 +216,16 @@ class XinferenceRerank(BaseNodePostprocessor):
new_node_with_score = NodeWithScore( new_node_with_score = NodeWithScore(
node=nodes[result['index']].node, score=result['relevance_score'] node=nodes[result['index']].node, score=result['relevance_score']
) )
print(new_node_with_score.node.get_content) if new_node_with_score.score >=self.threshold:
print('\n') new_nodes.append(new_node_with_score)
print(new_node_with_score.score)
new_nodes.append(new_node_with_score)
event.on_end(payload={EventPayload.NODES: new_nodes}) event.on_end(payload={EventPayload.NODES: new_nodes})
dispatcher.event(
ReRankEndEvent(
nodes= new_nodes
)
)
return new_nodes return new_nodes
def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]: def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]: