改进rerank效果
This commit is contained in:
+11
-10
@@ -3,6 +3,16 @@
|
|||||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||||
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||||
|
|
||||||
|
#--------------------------
|
||||||
|
# 是否启用检索重排功能
|
||||||
|
ENABLE_RERANK=true
|
||||||
|
# The number of similar embeddings to return when retrieving documents.
|
||||||
|
TOP_K=5
|
||||||
|
# Rerank model
|
||||||
|
RERANK_MODEL=bge-reranker-v2-m3
|
||||||
|
RERANK_BASE_URL=http://10.1.16.39:9995
|
||||||
|
RERANK_TOP_N=5
|
||||||
|
RERANK_THRESHOLD=0.3
|
||||||
#---------- Xinference ----------------
|
#---------- Xinference ----------------
|
||||||
# The provider for the AI models to use.
|
# The provider for the AI models to use.
|
||||||
MODEL_PROVIDER=xinference
|
MODEL_PROVIDER=xinference
|
||||||
@@ -19,9 +29,7 @@ EMBEDDING_MODEL=bge-m3
|
|||||||
EMBEDDING_BASE_URL=http://10.1.16.39:9995
|
EMBEDDING_BASE_URL=http://10.1.16.39:9995
|
||||||
# Dimension of the embedding model to use.
|
# Dimension of the embedding model to use.
|
||||||
EMBEDDING_DIM=1024
|
EMBEDDING_DIM=1024
|
||||||
# Rerank model
|
|
||||||
RERANK_MODEL=bge-reranker-v2-m3
|
|
||||||
RERANK_BASE_URL=http://10.1.16.39:9995
|
|
||||||
##---------- OpenAI ----------------
|
##---------- OpenAI ----------------
|
||||||
## The provider for the AI models to use.
|
## The provider for the AI models to use.
|
||||||
#MODEL_PROVIDER=openai
|
#MODEL_PROVIDER=openai
|
||||||
@@ -46,17 +54,10 @@ RERANK_BASE_URL=http://10.1.16.39:9995
|
|||||||
## Name of the embedding model to use.
|
## Name of the embedding model to use.
|
||||||
#EMBEDDING_MODEL=text-embedding-v2
|
#EMBEDDING_MODEL=text-embedding-v2
|
||||||
|
|
||||||
#--------------------------
|
|
||||||
# 是否启用检索重排功能
|
|
||||||
ENABLE_RERANK=true
|
|
||||||
|
|
||||||
|
|
||||||
# The questions to help users get started (multi-line).
|
# The questions to help users get started (multi-line).
|
||||||
CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容?
|
CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容?
|
||||||
|
|
||||||
# The number of similar embeddings to return when retrieving documents.
|
|
||||||
TOP_K=5
|
|
||||||
|
|
||||||
# The time in milliseconds to wait for the stream to return a response.
|
# The time in milliseconds to wait for the stream to return a response.
|
||||||
STREAM_TIMEOUT=60000
|
STREAM_TIMEOUT=60000
|
||||||
|
|
||||||
|
|||||||
@@ -10,11 +10,17 @@ from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
|||||||
|
|
||||||
|
|
||||||
def get_node_postprocessors():
|
def get_node_postprocessors():
|
||||||
|
rerank_enabled = os.getenv("RERANK_ENABLED")
|
||||||
|
if rerank_enabled is None or rerank_enabled is False:
|
||||||
|
return []
|
||||||
|
|
||||||
rerank_model = os.getenv("RERANK_MODEL")
|
rerank_model = os.getenv("RERANK_MODEL")
|
||||||
rerank_url = os.getenv("RERANK_BASE_URL")
|
rerank_url = os.getenv("RERANK_BASE_URL")
|
||||||
|
rerank_top_n = os.getenv("RERANK_TOP_N")
|
||||||
|
rerank_threshold = os.getenv("RERANK_THRESHOLD")
|
||||||
postprocess = None
|
postprocess = None
|
||||||
if rerank_model is not None:
|
if rerank_model is not None:
|
||||||
postprocess = [XinferenceRerank(rerank_model, rerank_url)]
|
postprocess = [XinferenceRerank(rerank_model, rerank_url, top_n=rerank_top_n, threshold=rerank_threshold)]
|
||||||
return postprocess
|
return postprocess
|
||||||
|
|
||||||
def init_settings():
|
def init_settings():
|
||||||
@@ -79,7 +85,7 @@ def init_xinference():
|
|||||||
embed_model_name = os.getenv("EMBEDDING_MODEL")
|
embed_model_name = os.getenv("EMBEDDING_MODEL")
|
||||||
dimensions = os.getenv("EMBEDDING_DIM")
|
dimensions = os.getenv("EMBEDDING_DIM")
|
||||||
dimensions = int(dimensions) if dimensions is not None else None
|
dimensions = int(dimensions) if dimensions is not None else None
|
||||||
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url)
|
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions)
|
||||||
|
|
||||||
def init_openai():
|
def init_openai():
|
||||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||||
|
|||||||
@@ -5,10 +5,11 @@ from enum import Enum
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||||
|
|
||||||
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
|
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding, dispatcher
|
||||||
from llama_index.core.bridge.pydantic import PrivateAttr
|
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||||
from llama_index.core.callbacks import CBEventType, EventPayload
|
from llama_index.core.callbacks import CBEventType, EventPayload
|
||||||
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
|
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
|
||||||
|
from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReRankEndEvent
|
||||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||||
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
|
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@@ -51,7 +52,7 @@ class XinferenceEmbedding(BaseEmbedding):
|
|||||||
# num_workers: Optional[int] = None,
|
# num_workers: Optional[int] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
generator, model_description = self.load_model(
|
generator, model_description, embed_batch_size, dimensions = self.load_model(
|
||||||
model_uid, endpoint
|
model_uid, endpoint
|
||||||
)
|
)
|
||||||
self._generator = generator
|
self._generator = generator
|
||||||
@@ -106,8 +107,11 @@ class XinferenceEmbedding(BaseEmbedding):
|
|||||||
)
|
)
|
||||||
|
|
||||||
model = model_description["model_name"]
|
model = model_description["model_name"]
|
||||||
|
replica = model_description['replica']
|
||||||
|
dimensions = model_description['dimensions']
|
||||||
|
max_tokens = model_description['max_tokens']
|
||||||
|
|
||||||
return generator, model_description
|
return generator, model_description, replica, dimensions
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def class_name(cls) -> str:
|
def class_name(cls) -> str:
|
||||||
@@ -151,23 +155,27 @@ class XinferenceRerank(BaseNodePostprocessor):
|
|||||||
description="The model description from Xinference."
|
description="The model description from Xinference."
|
||||||
)
|
)
|
||||||
_generator: Any = PrivateAttr()
|
_generator: Any = PrivateAttr()
|
||||||
_model_uid: str = Field(description="The Xinference model to use.")
|
_model_uid: str
|
||||||
_endpoint: str = Field(description="The Xinference endpoint URL to use.")
|
_endpoint: str
|
||||||
#model: str = Field(description="Dashscope rerank model name.")
|
model: str = Field(description="Dashscope rerank model name.")
|
||||||
top_n: int = Field(description="Top N nodes to return.")
|
top_n: int = Field(description="Top N nodes to return.")
|
||||||
|
threshold: float = Field(description="threshold nodes to return.")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_uid: str,
|
model_uid: str,
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
top_n: int = 3,
|
top_n: int = 3,
|
||||||
|
threshold: float = 0.3,
|
||||||
return_documents: bool = False
|
return_documents: bool = False
|
||||||
):
|
):
|
||||||
|
_model_uid = model_uid
|
||||||
|
_endpoint = endpoint
|
||||||
generator, model_description = self.load_model(
|
generator, model_description = self.load_model(
|
||||||
model_uid, endpoint
|
model_uid, endpoint
|
||||||
)
|
)
|
||||||
self._generator = generator
|
self._generator = generator
|
||||||
super().__init__(top_n=top_n, model=model_uid, return_documents=return_documents)
|
super().__init__(top_n=top_n, model=model_uid, threshold = threshold, return_documents=return_documents)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def class_name(cls) -> str:
|
def class_name(cls) -> str:
|
||||||
@@ -183,6 +191,15 @@ class XinferenceRerank(BaseNodePostprocessor):
|
|||||||
if len(nodes) == 0:
|
if len(nodes) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
dispatcher.event(
|
||||||
|
ReRankStartEvent(
|
||||||
|
nodes = nodes,
|
||||||
|
top_n = self.top_n,
|
||||||
|
query = query_bundle,
|
||||||
|
model_name = self.model
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
with self.callback_manager.event(
|
with self.callback_manager.event(
|
||||||
CBEventType.RERANKING,
|
CBEventType.RERANKING,
|
||||||
payload={
|
payload={
|
||||||
@@ -199,12 +216,16 @@ class XinferenceRerank(BaseNodePostprocessor):
|
|||||||
new_node_with_score = NodeWithScore(
|
new_node_with_score = NodeWithScore(
|
||||||
node=nodes[result['index']].node, score=result['relevance_score']
|
node=nodes[result['index']].node, score=result['relevance_score']
|
||||||
)
|
)
|
||||||
print(new_node_with_score.node.get_content)
|
if new_node_with_score.score >=self.threshold:
|
||||||
print('\n')
|
new_nodes.append(new_node_with_score)
|
||||||
print(new_node_with_score.score)
|
|
||||||
new_nodes.append(new_node_with_score)
|
|
||||||
event.on_end(payload={EventPayload.NODES: new_nodes})
|
event.on_end(payload={EventPayload.NODES: new_nodes})
|
||||||
|
|
||||||
|
dispatcher.event(
|
||||||
|
ReRankEndEvent(
|
||||||
|
nodes= new_nodes
|
||||||
|
)
|
||||||
|
)
|
||||||
return new_nodes
|
return new_nodes
|
||||||
|
|
||||||
def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]:
|
def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]:
|
||||||
|
|||||||
Reference in New Issue
Block a user