改进rerank效果

This commit is contained in:
2024-08-19 10:03:46 +08:00
parent 806b694b37
commit 22c51218b3
3 changed files with 51 additions and 23 deletions
+8 -2
View File
@@ -10,11 +10,17 @@ from app.xinference.base import XinferenceEmbedding, XinferenceRerank
def get_node_postprocessors():
    """Build the node postprocessors configured through environment variables.

    Environment variables read:
        RERANK_ENABLED: reranking is off when this is unset, empty, or one of
            "0", "false", "no", "off" (case-insensitive).
        RERANK_MODEL: model identifier handed to ``XinferenceRerank``; when it
            is unset no reranker is created.
        RERANK_BASE_URL: endpoint handed to ``XinferenceRerank``.
        RERANK_TOP_N: optional integer forwarded as ``top_n``.
        RERANK_THRESHOLD: optional float forwarded as ``threshold``.

    Returns:
        A list holding one ``XinferenceRerank`` instance, or an empty list
        when reranking is disabled or no model is configured.

    Raises:
        ValueError: if RERANK_TOP_N or RERANK_THRESHOLD is set to text that
            cannot be parsed as a number.
    """
    # os.getenv only ever yields str or None, so an identity check against
    # False can never match; compare the text form of the flag instead.
    rerank_enabled = os.getenv("RERANK_ENABLED", "")
    if rerank_enabled.strip().lower() in ("", "0", "false", "no", "off"):
        return []

    rerank_model = os.getenv("RERANK_MODEL")
    if rerank_model is None:
        # Keep the return type a list on every path.
        return []

    # Forward the tuning knobs only when they are set, converted to numbers,
    # so that the reranker's own defaults apply otherwise.
    rerank_kwargs = {}
    rerank_top_n = os.getenv("RERANK_TOP_N")
    if rerank_top_n:
        rerank_kwargs["top_n"] = int(rerank_top_n)
    rerank_threshold = os.getenv("RERANK_THRESHOLD")
    if rerank_threshold:
        rerank_kwargs["threshold"] = float(rerank_threshold)

    rerank_url = os.getenv("RERANK_BASE_URL")
    return [XinferenceRerank(rerank_model, rerank_url, **rerank_kwargs)]
def init_settings():
@@ -79,7 +85,7 @@ def init_xinference():
embed_model_name = os.getenv("EMBEDDING_MODEL")
dimensions = os.getenv("EMBEDDING_DIM")
dimensions = int(dimensions) if dimensions is not None else None
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url)
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions)
def init_openai():
from llama_index.core.constants import DEFAULT_TEMPERATURE
+32 -11
View File
@@ -5,10 +5,11 @@ from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union, Tuple
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding, dispatcher
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReRankEndEvent
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
from pydantic import Field
@@ -51,7 +52,7 @@ class XinferenceEmbedding(BaseEmbedding):
# num_workers: Optional[int] = None,
**kwargs: Any,
) -> None:
generator, model_description = self.load_model(
generator, model_description, embed_batch_size, dimensions = self.load_model(
model_uid, endpoint
)
self._generator = generator
@@ -106,8 +107,11 @@ class XinferenceEmbedding(BaseEmbedding):
)
model = model_description["model_name"]
replica = model_description['replica']
dimensions = model_description['dimensions']
max_tokens = model_description['max_tokens']
return generator, model_description
return generator, model_description, replica, dimensions
@classmethod
def class_name(cls) -> str:
@@ -151,23 +155,27 @@ class XinferenceRerank(BaseNodePostprocessor):
description="The model description from Xinference."
)
_generator: Any = PrivateAttr()
_model_uid: str = Field(description="The Xinference model to use.")
_endpoint: str = Field(description="The Xinference endpoint URL to use.")
#model: str = Field(description="Dashscope rerank model name.")
_model_uid: str
_endpoint: str
model: str = Field(description="Dashscope rerank model name.")
top_n: int = Field(description="Top N nodes to return.")
threshold: float = Field(description="threshold nodes to return.")
def __init__(
    self,
    model_uid: str,
    endpoint: str,
    top_n: int = 3,
    threshold: float = 0.3,
    return_documents: bool = False,
):
    """Create a reranking postprocessor backed by an Xinference model.

    Args:
        model_uid: Identifier of the rerank model; also stored as ``model``.
        endpoint: URL passed to ``load_model`` together with ``model_uid``.
        top_n: Forwarded to the base initializer as the ``top_n`` field.
        threshold: Forwarded to the base initializer as the ``threshold``
            field.
        return_documents: Forwarded unchanged to the base initializer.
    """
    # Initialise the pydantic model first: private attributes such as
    # ``_generator`` can only be assigned safely once the base initializer
    # has run.
    super().__init__(
        top_n=top_n,
        model=model_uid,
        threshold=threshold,
        return_documents=return_documents,
    )
    # NOTE(review): the previous code bound ``_model_uid`` / ``_endpoint`` to
    # throwaway locals, so nothing was ever stored on the instance; confirm
    # whether those values need to be kept before reintroducing them.
    generator, _ = self.load_model(model_uid, endpoint)
    self._generator = generator
@classmethod
def class_name(cls) -> str:
@@ -183,6 +191,15 @@ class XinferenceRerank(BaseNodePostprocessor):
if len(nodes) == 0:
return []
dispatcher.event(
ReRankStartEvent(
nodes = nodes,
top_n = self.top_n,
query = query_bundle,
model_name = self.model
)
)
with self.callback_manager.event(
CBEventType.RERANKING,
payload={
@@ -199,12 +216,16 @@ class XinferenceRerank(BaseNodePostprocessor):
new_node_with_score = NodeWithScore(
node=nodes[result['index']].node, score=result['relevance_score']
)
print(new_node_with_score.node.get_content)
print('\n')
print(new_node_with_score.score)
new_nodes.append(new_node_with_score)
if new_node_with_score.score >=self.threshold:
new_nodes.append(new_node_with_score)
event.on_end(payload={EventPayload.NODES: new_nodes})
dispatcher.event(
ReRankEndEvent(
nodes= new_nodes
)
)
return new_nodes
def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]: