自定义重排类,实现分数阈值过滤

This commit is contained in:
wanyaokun
2024-09-10 15:05:12 +08:00
parent 0bf2799acf
commit f4b1f40173
4 changed files with 82 additions and 7 deletions
+2 -2
View File
@@ -61,8 +61,8 @@ def get_chat_engine(filters=None, params:dict=None):
react_chat_formatter = ReActChatFormatter.from_defaults(ReActChatFormatter_messages)
agentrunner = AgentRunner.from_llm(
llm=Settings.llm,
tools=tools,
react_chat_formatter=react_chat_formatter,
tools=tools,
#react_chat_formatter=react_chat_formatter,
system_prompt=system_prompt,
verbose=True,
)
@@ -0,0 +1,75 @@
import requests
from llama_index.postprocessor.xinference_rerank import XinferenceRerank
from llama_index.core.bridge.pydantic import Field
from typing import List, Optional
from llama_index.core.bridge.pydantic import Field
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
ReRankEndEvent,
ReRankStartEvent,
)
from llama_index.core.schema import NodeWithScore, QueryBundle, MetadataMode
dispatcher = get_dispatcher(__name__)
class CustomXinFerenceRerank(XinferenceRerank):
    """Xinference reranker that also filters nodes by a relevance-score threshold.

    The candidate nodes are sent to the Xinference ``/v1/rerank`` endpoint.
    Every result whose relevance score is strictly greater than
    ``score_threshold`` is kept, and at most ``top_n`` of the best-scoring
    survivors are returned, ordered from highest to lowest score.
    """

    # Minimum relevance score (exclusive) a node must exceed to be kept.
    score_threshold: float = Field(default=0.3,description="分数阈值")

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query and drop low-scoring results.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: The query to score the nodes against.

        Returns:
            Up to ``top_n`` nodes whose relevance score exceeds
            ``score_threshold``, sorted by score in descending order. An empty
            list is returned when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
            Exception: If the rerank endpoint answers with a non-200 status.
        """
        dispatcher.event(
            ReRankStartEvent(
                query=query_bundle,
                nodes=nodes,
                top_n=self.top_n,
                model_name=self.model,
            )
        )

        if query_bundle is None:
            raise ValueError("Missing query bundle.")
        if len(nodes) == 0:
            return []

        query_str = self.get_query_str(query_bundle)

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            headers = {"Content-Type": "application/json"}
            json_data = {
                "model": self.model,
                "query": query_str,
                "documents": [
                    node.node.get_content(metadata_mode=MetadataMode.EMBED)
                    for node in nodes
                ],
            }
            response = requests.post(
                url=f"{self.base_url}/v1/rerank", headers=headers, json=json_data
            )
            response.encoding = "utf-8"
            if response.status_code != 200:
                raise Exception(
                    f"Xinference call failed with status code {response.status_code}."
                    f"Details: {response.text}"
                )

            # Each result refers back to its input document through "index".
            scored_nodes = [
                NodeWithScore(
                    node=nodes[result["index"]].node,
                    score=result["relevance_score"],
                )
                for result in response.json()["results"]
            ]
            rerank_nodes = [
                node for node in scored_nodes if node.score > self.score_threshold
            ]

            # Always order best-first before truncating. An ascending sort here
            # would keep the *least* relevant nodes whenever more than top_n
            # nodes pass the threshold.
            rerank_nodes.sort(key=lambda item: item.score, reverse=True)
            rerank_nodes = rerank_nodes[: self.top_n]

            event.on_end(payload={EventPayload.NODES: rerank_nodes})

        dispatcher.event(ReRankEndEvent(nodes=rerank_nodes))
        return rerank_nodes
+5 -5
View File
@@ -5,10 +5,10 @@ from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings
from llama_index.embeddings.xinference import XinferenceEmbedding
#from llama_index.llms.xinference import Xinference
from app.engine.model.xinfeng import XinfengModel
#from llama_index.embeddings.xinference import XinferenceEmbedding
from app.engine.model.xinference import XinferenceModel
from app.engine.rerank.xinferenceRerank import CustomXinFerenceRerank
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from llama_index.postprocessor.xinference_rerank import XinferenceRerank
from app.engine.loaders import getProjectInfos
from app.api.routers.request.base import ProjectInfo
@@ -97,7 +97,7 @@ class XinferencePlatform(ModelPlatform):
model = os.getenv("MODEL")
max_tokens = int(os.getenv("LLM_MAX_TOKENS")) if os.getenv("LLM_MAX_TOKENS") is not None else None
temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
return XinfengModel(model_uid = model,endpoint = base_url,temperature = temperature,max_tokens = max_tokens)
return XinferenceModel(model_uid = model,endpoint = base_url,temperature = temperature,max_tokens = max_tokens)
def embedding(self):
base_url = os.getenv("BASE_URL")
@@ -116,7 +116,7 @@ class XinferencePlatform(ModelPlatform):
rerank_threshold = os.getenv("RERANK_THRESHOLD")
postprocess = None
if rerank_model is not None:
postprocess = [XinferenceRerank(model = rerank_model, base_url = rerank_url, top_n=rerank_top_n)]
postprocess = [CustomXinFerenceRerank(model = rerank_model, base_url = rerank_url, top_n=rerank_top_n,score_threshold=rerank_threshold)]
return postprocess
@register(ModelPlateCategory,'openai')