新增ollama重排类
This commit is contained in:
@@ -26,6 +26,7 @@ RERANK_ENABLED=true
|
|||||||
RERANK_PROVIDER=ollama
|
RERANK_PROVIDER=ollama
|
||||||
RERANK_MODEL= /models/bge-reranker-base
|
RERANK_MODEL= /models/bge-reranker-base
|
||||||
RERANK_TOP_N=5
|
RERANK_TOP_N=5
|
||||||
|
RERANK_THRESHOLD=0.3
|
||||||
|
|
||||||
#---------- model - Xinference ----------------
|
#---------- model - Xinference ----------------
|
||||||
#MODEL_PROVIDER=xinference
|
#MODEL_PROVIDER=xinference
|
||||||
|
|||||||
@@ -0,0 +1,70 @@
|
|||||||
|
from typing import Any, List, Optional
|
||||||
|
from llama_index.core.postprocessor import SentenceTransformerRerank
|
||||||
|
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
|
||||||
|
from llama_index.core.callbacks import CBEventType, EventPayload
|
||||||
|
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||||
|
|
||||||
|
class OllamaRerank(SentenceTransformerRerank):
    """Cross-encoder reranker that also discards low-scoring nodes.

    Behaves like ``SentenceTransformerRerank`` (score every (query, node)
    pair with the model, sort by score, keep ``top_n``), with one extra
    step: nodes whose rerank score is below ``score_threshold`` are dropped
    before the top-n cut.
    """

    # Minimum rerank score a node must reach to be returned.
    _score_threshold: float = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        model: str = "cross-encoder/stsb-distilroberta-base",
        device: Optional[str] = None,
        keep_retrieval_score: Optional[bool] = False,
        score_threshold: float = 0.3,
    ):
        """Create the reranker.

        Args:
            top_n: Maximum number of nodes returned after reranking.
            model: Name or path of the cross-encoder model, forwarded to the
                parent class.
            device: Device string forwarded to the parent class.
            keep_retrieval_score: When true, the pre-rerank score is saved in
                each node's metadata under ``"retrieval_score"``.
            score_threshold: Nodes scoring strictly below this value are
                dropped.
        """
        # Initialize the pydantic base first: private attributes can only be
        # assigned once the model's private storage exists (required on
        # pydantic v2, harmless on v1). Keyword arguments keep this call
        # independent of the parent's positional parameter order.
        super().__init__(
            top_n=top_n,
            model=model,
            device=device,
            keep_retrieval_score=keep_retrieval_score,
        )
        self._score_threshold = score_threshold

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier of this postprocessor class."""
        return "OllamaRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query and filter by threshold.

        Args:
            nodes: Candidate nodes; their ``score`` is overwritten with the
                rerank score.
            query_bundle: Query to score the nodes against.

        Returns:
            At most ``top_n`` nodes whose rerank score is not below the
            threshold, ordered by descending score.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        query_and_nodes = [
            (
                query_bundle.query_str,
                node.node.get_content(metadata_mode=MetadataMode.EMBED),
            )
            for node in nodes
        ]

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            scores = self._model.predict(query_and_nodes)

            assert len(scores) == len(nodes)

            for node, score in zip(nodes, scores):
                if self.keep_retrieval_score:
                    node.node.metadata["retrieval_score"] = node.score
                # Store a plain Python float rather than the model's raw
                # scalar type.
                node.score = float(score)

            # Build a new list instead of removing from ``nodes`` in place:
            # the caller's list (also referenced by the event payload above)
            # stays intact. ``not (a < b)`` keeps exactly the nodes the
            # original ``<`` test did not reject.
            threshold = self._score_threshold
            kept = [node for node in nodes if not node.score < threshold]

            new_nodes = sorted(kept, key=lambda x: -x.score if x.score else 0)[
                : self.top_n
            ]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes
|
||||||
+12
-1
@@ -88,7 +88,18 @@ class OllamaPlatform(ModelPlatform):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def rerank(self):
    """Build the rerank postprocessor list for this platform.

    Reads its settings from environment variables:

    * ``RERANK_MODEL``: model path appended verbatim to the current working
      directory (required).
    * ``RERANK_TOP_N``: maximum number of nodes kept (default 5).
    * ``RERANK_THRESHOLD``: minimum rerank score (default 0.3).

    Returns:
        A one-element list containing the configured ``OllamaRerank``.

    Raises:
        ValueError: If ``RERANK_MODEL`` is unset or empty, or if the numeric
            settings cannot be parsed.
    """
    # Imported inside the method, as in the original change.
    from app.engine.rerank.ollamRerank import OllamaRerank

    model_rel_path = os.getenv('RERANK_MODEL')
    if not model_rel_path or not model_rel_path.strip():
        # Fail with a clear message instead of the TypeError that
        # ``str + None`` would raise.
        raise ValueError("RERANK_MODEL environment variable is not set")
    # Plain concatenation is intentional: the configured value starts with
    # "/" and is meant to be relative to the working directory, which
    # os.path.join would discard. Surrounding whitespace (e.g. from
    # ``RERANK_MODEL= /models/...``) is stripped.
    modelpath = os.getcwd() + model_rel_path.strip()

    # os.getenv returns a str when the variable is set; normalize to numbers.
    top_n = int(os.getenv('RERANK_TOP_N', 5))
    threshold = float(os.getenv('RERANK_THRESHOLD', 0.3))

    reranker = OllamaRerank(
        model=modelpath,
        top_n=top_n,
        device="cpu",
        score_threshold=threshold,
    )
    return [reranker]
|
||||||
|
|
||||||
|
|
||||||
@register(ModelPlateCategory,'xinference')
|
@register(ModelPlateCategory,'xinference')
|
||||||
class XinferencePlatform(ModelPlatform):
|
class XinferencePlatform(ModelPlatform):
|
||||||
|
|||||||
Reference in New Issue
Block a user