改进rerank效果

This commit is contained in:
2024-08-19 10:03:46 +08:00
parent 806b694b37
commit 22c51218b3
3 changed files with 51 additions and 23 deletions
+8 -2
View File
@@ -10,11 +10,17 @@ from app.xinference.base import XinferenceEmbedding, XinferenceRerank
def get_node_postprocessors():
rerank_enabled = os.getenv("RERANK_ENABLED")
if rerank_enabled is None or rerank_enabled is False:
return []
rerank_model = os.getenv("RERANK_MODEL")
rerank_url = os.getenv("RERANK_BASE_URL")
rerank_top_n = os.getenv("RERANK_TOP_N")
rerank_threshold = os.getenv("RERANK_THRESHOLD")
postprocess = None
if rerank_model is not None:
postprocess = [XinferenceRerank(rerank_model, rerank_url)]
postprocess = [XinferenceRerank(rerank_model, rerank_url, top_n=rerank_top_n, threshold=rerank_threshold)]
return postprocess
def init_settings():
@@ -79,7 +85,7 @@ def init_xinference():
embed_model_name = os.getenv("EMBEDDING_MODEL")
dimensions = os.getenv("EMBEDDING_DIM")
dimensions = int(dimensions) if dimensions is not None else None
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url)
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions)
def init_openai():
from llama_index.core.constants import DEFAULT_TEMPERATURE