自定义重排类,实现分数阈值过滤

This commit is contained in:
wanyaokun
2024-09-10 15:05:12 +08:00
parent 0bf2799acf
commit f4b1f40173
4 changed files with 82 additions and 7 deletions
+5 -5
View File
@@ -5,10 +5,10 @@ from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings
from llama_index.embeddings.xinference import XinferenceEmbedding
#from llama_index.llms.xinference import Xinference
from app.engine.model.xinfeng import XinfengModel
#from llama_index.embeddings.xinference import XinferenceEmbedding
from app.engine.model.xinference import XinferenceModel
from app.engine.rerank.xinferenceRerank import CustomXinFerenceRerank
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from llama_index.postprocessor.xinference_rerank import XinferenceRerank
from app.engine.loaders import getProjectInfos
from app.api.routers.request.base import ProjectInfo
@@ -97,7 +97,7 @@ class XinferencePlatform(ModelPlatform):
model = os.getenv("MODEL")
max_tokens = int(os.getenv("LLM_MAX_TOKENS")) if os.getenv("LLM_MAX_TOKENS") is not None else None
temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
return XinfengModel(model_uid = model,endpoint = base_url,temperature = temperature,max_tokens = max_tokens)
return XinferenceModel(model_uid = model,endpoint = base_url,temperature = temperature,max_tokens = max_tokens)
def embedding(self):
base_url = os.getenv("BASE_URL")
@@ -116,7 +116,7 @@ class XinferencePlatform(ModelPlatform):
rerank_threshold = os.getenv("RERANK_THRESHOLD")
postprocess = None
if rerank_model is not None:
postprocess = [XinferenceRerank(model = rerank_model, base_url = rerank_url, top_n=rerank_top_n)]
postprocess = [CustomXinFerenceRerank(model = rerank_model, base_url = rerank_url, top_n=rerank_top_n,score_threshold=rerank_threshold)]
return postprocess
@register(ModelPlateCategory,'openai')