新增ollama重排类

This commit is contained in:
wanyaokun
2024-09-06 10:07:22 +08:00
parent 21fdc16259
commit 60b0f11ca2
3 changed files with 83 additions and 1 deletions
+1
View File
@@ -26,6 +26,7 @@ RERANK_ENABLED=true
RERANK_PROVIDER=ollama RERANK_PROVIDER=ollama
RERANK_MODEL= /models/bge-reranker-base RERANK_MODEL= /models/bge-reranker-base
RERANK_TOP_N=5 RERANK_TOP_N=5
RERANK_THRESHOLD=0.3
#---------- model - Xinference ---------------- #---------- model - Xinference ----------------
#MODEL_PROVIDER=xinference #MODEL_PROVIDER=xinference
+70
View File
@@ -0,0 +1,70 @@
from typing import Any, List, Optional
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.bridge.pydantic import PrivateAttr
class OllamaRerank(SentenceTransformerRerank):
_score_threshold: float = PrivateAttr()
def __init__(
self,
top_n: int = 2,
model: str = "cross-encoder/stsb-distilroberta-base",
device: Optional[str] = None,
keep_retrieval_score: Optional[bool] = False,
score_threshold:float = 0.3
):
self._score_threshold = score_threshold
super().__init__(top_n,model,device,keep_retrieval_score)
@classmethod
def class_name(cls) -> str:
return "OllamaRerank"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
if query_bundle is None:
raise ValueError("Missing query bundle in extra info.")
if len(nodes) == 0:
return []
query_and_nodes = [
(
query_bundle.query_str,
node.node.get_content(metadata_mode=MetadataMode.EMBED),
)
for node in nodes
]
with self.callback_manager.event(
CBEventType.RERANKING,
payload={
EventPayload.NODES: nodes,
EventPayload.MODEL_NAME: self.model,
EventPayload.QUERY_STR: query_bundle.query_str,
EventPayload.TOP_K: self.top_n,
},
) as event:
scores = self._model.predict(query_and_nodes)
assert len(scores) == len(nodes)
for node, score in zip(nodes, scores):
if self.keep_retrieval_score:
node.node.metadata["retrieval_score"] = node.score
node.score = score
for i in range(len(nodes)-1,-1,-1):
node = nodes[i]
if node.score < self._score_threshold:
nodes.remove(node)
new_nodes = sorted(nodes, key=lambda x: -x.score if x.score else 0)[
: self.top_n
]
event.on_end(payload={EventPayload.NODES: new_nodes})
return new_nodes
+12 -1
View File
@@ -88,7 +88,18 @@ class OllamaPlatform(ModelPlatform):
pass pass
def rerank(self): def rerank(self):
pass from app.engine.rerank.ollamRerank import OllamaRerank
modelpath = os.getcwd() + os.getenv('RERANK_MODEL')
top_n = os.getenv('RERANK_TOP_N',5)
threshold = float(os.getenv('RERANK_THRESHOLD',0.3))
rerank = OllamaRerank(
model=modelpath,
top_n=top_n,
device="cpu",
score_threshold= threshold
)
return [rerank]
@register(ModelPlateCategory,'xinference') @register(ModelPlateCategory,'xinference')
class XinferencePlatform(ModelPlatform): class XinferencePlatform(ModelPlatform):