改进rerank效果
This commit is contained in:
@@ -10,11 +10,17 @@ from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||
|
||||
|
||||
def get_node_postprocessors():
|
||||
rerank_enabled = os.getenv("RERANK_ENABLED")
|
||||
if rerank_enabled is None or rerank_enabled is False:
|
||||
return []
|
||||
|
||||
rerank_model = os.getenv("RERANK_MODEL")
|
||||
rerank_url = os.getenv("RERANK_BASE_URL")
|
||||
rerank_top_n = os.getenv("RERANK_TOP_N")
|
||||
rerank_threshold = os.getenv("RERANK_THRESHOLD")
|
||||
postprocess = None
|
||||
if rerank_model is not None:
|
||||
postprocess = [XinferenceRerank(rerank_model, rerank_url)]
|
||||
postprocess = [XinferenceRerank(rerank_model, rerank_url, top_n=rerank_top_n, threshold=rerank_threshold)]
|
||||
return postprocess
|
||||
|
||||
def init_settings():
|
||||
@@ -79,7 +85,7 @@ def init_xinference():
|
||||
embed_model_name = os.getenv("EMBEDDING_MODEL")
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
dimensions = int(dimensions) if dimensions is not None else None
|
||||
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url)
|
||||
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions)
|
||||
|
||||
def init_openai():
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
|
||||
@@ -5,10 +5,11 @@ from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding, dispatcher
|
||||
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||
from llama_index.core.callbacks import CBEventType, EventPayload
|
||||
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
|
||||
from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReRankEndEvent
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
|
||||
from pydantic import Field
|
||||
@@ -51,7 +52,7 @@ class XinferenceEmbedding(BaseEmbedding):
|
||||
# num_workers: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
generator, model_description = self.load_model(
|
||||
generator, model_description, embed_batch_size, dimensions = self.load_model(
|
||||
model_uid, endpoint
|
||||
)
|
||||
self._generator = generator
|
||||
@@ -106,8 +107,11 @@ class XinferenceEmbedding(BaseEmbedding):
|
||||
)
|
||||
|
||||
model = model_description["model_name"]
|
||||
replica = model_description['replica']
|
||||
dimensions = model_description['dimensions']
|
||||
max_tokens = model_description['max_tokens']
|
||||
|
||||
return generator, model_description
|
||||
return generator, model_description, replica, dimensions
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
@@ -151,23 +155,27 @@ class XinferenceRerank(BaseNodePostprocessor):
|
||||
description="The model description from Xinference."
|
||||
)
|
||||
_generator: Any = PrivateAttr()
|
||||
_model_uid: str = Field(description="The Xinference model to use.")
|
||||
_endpoint: str = Field(description="The Xinference endpoint URL to use.")
|
||||
#model: str = Field(description="Dashscope rerank model name.")
|
||||
_model_uid: str
|
||||
_endpoint: str
|
||||
model: str = Field(description="Dashscope rerank model name.")
|
||||
top_n: int = Field(description="Top N nodes to return.")
|
||||
threshold: float = Field(description="threshold nodes to return.")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_uid: str,
|
||||
endpoint: str,
|
||||
top_n: int = 3,
|
||||
threshold: float = 0.3,
|
||||
return_documents: bool = False
|
||||
):
|
||||
_model_uid = model_uid
|
||||
_endpoint = endpoint
|
||||
generator, model_description = self.load_model(
|
||||
model_uid, endpoint
|
||||
)
|
||||
self._generator = generator
|
||||
super().__init__(top_n=top_n, model=model_uid, return_documents=return_documents)
|
||||
super().__init__(top_n=top_n, model=model_uid, threshold = threshold, return_documents=return_documents)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
@@ -183,6 +191,15 @@ class XinferenceRerank(BaseNodePostprocessor):
|
||||
if len(nodes) == 0:
|
||||
return []
|
||||
|
||||
dispatcher.event(
|
||||
ReRankStartEvent(
|
||||
nodes = nodes,
|
||||
top_n = self.top_n,
|
||||
query = query_bundle,
|
||||
model_name = self.model
|
||||
)
|
||||
)
|
||||
|
||||
with self.callback_manager.event(
|
||||
CBEventType.RERANKING,
|
||||
payload={
|
||||
@@ -199,12 +216,16 @@ class XinferenceRerank(BaseNodePostprocessor):
|
||||
new_node_with_score = NodeWithScore(
|
||||
node=nodes[result['index']].node, score=result['relevance_score']
|
||||
)
|
||||
print(new_node_with_score.node.get_content)
|
||||
print('\n')
|
||||
print(new_node_with_score.score)
|
||||
new_nodes.append(new_node_with_score)
|
||||
if new_node_with_score.score >=self.threshold:
|
||||
new_nodes.append(new_node_with_score)
|
||||
|
||||
event.on_end(payload={EventPayload.NODES: new_nodes})
|
||||
|
||||
dispatcher.event(
|
||||
ReRankEndEvent(
|
||||
nodes= new_nodes
|
||||
)
|
||||
)
|
||||
return new_nodes
|
||||
|
||||
def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]:
|
||||
|
||||
Reference in New Issue
Block a user