自定义重排类,实现分数阈值过滤
This commit is contained in:
@@ -5,10 +5,10 @@ from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.embeddings.xinference import XinferenceEmbedding
|
||||
#from llama_index.llms.xinference import Xinference
|
||||
from app.engine.model.xinfeng import XinfengModel
|
||||
#from llama_index.embeddings.xinference import XinferenceEmbedding
|
||||
from app.engine.model.xinference import XinferenceModel
|
||||
from app.engine.rerank.xinferenceRerank import CustomXinFerenceRerank
|
||||
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
||||
from llama_index.postprocessor.xinference_rerank import XinferenceRerank
|
||||
|
||||
|
||||
from app.engine.loaders import getProjectInfos
|
||||
from app.api.routers.request.base import ProjectInfo
|
||||
@@ -97,7 +97,7 @@ class XinferencePlatform(ModelPlatform):
|
||||
model = os.getenv("MODEL")
|
||||
max_tokens = int(os.getenv("LLM_MAX_TOKENS")) if os.getenv("LLM_MAX_TOKENS") is not None else None
|
||||
temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
|
||||
return XinfengModel(model_uid = model,endpoint = base_url,temperature = temperature,max_tokens = max_tokens)
|
||||
return XinferenceModel(model_uid = model,endpoint = base_url,temperature = temperature,max_tokens = max_tokens)
|
||||
|
||||
def embedding(self):
|
||||
base_url = os.getenv("BASE_URL")
|
||||
@@ -116,7 +116,7 @@ class XinferencePlatform(ModelPlatform):
|
||||
rerank_threshold = os.getenv("RERANK_THRESHOLD")
|
||||
postprocess = None
|
||||
if rerank_model is not None:
|
||||
postprocess = [XinferenceRerank(model = rerank_model, base_url = rerank_url, top_n=rerank_top_n)]
|
||||
postprocess = [CustomXinFerenceRerank(model = rerank_model, base_url = rerank_url, top_n=rerank_top_n,score_threshold=rerank_threshold)]
|
||||
return postprocess
|
||||
|
||||
@register(ModelPlateCategory,'openai')
|
||||
|
||||
Reference in New Issue
Block a user