dev #5
+11
-10
@@ -3,6 +3,16 @@
|
||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||
|
||||
#--------------------------
|
||||
# 是否启用检索重排功能
|
||||
ENABLE_RERANK=true
|
||||
# The number of similar embeddings to return when retrieving documents.
|
||||
TOP_K=5
|
||||
# Rerank model
|
||||
RERANK_MODEL=bge-reranker-v2-m3
|
||||
RERANK_BASE_URL=http://10.1.16.39:9995
|
||||
RERANK_TOP_N=5
|
||||
RERANK_THRESHOLD=0.3
|
||||
#---------- Xinference ----------------
|
||||
# The provider for the AI models to use.
|
||||
MODEL_PROVIDER=xinference
|
||||
@@ -19,9 +29,7 @@ EMBEDDING_MODEL=bge-m3
|
||||
EMBEDDING_BASE_URL=http://10.1.16.39:9995
|
||||
# Dimension of the embedding model to use.
|
||||
EMBEDDING_DIM=1024
|
||||
# Rerank model
|
||||
RERANK_MODEL=bge-reranker-v2-m3
|
||||
RERANK_BASE_URL=http://10.1.16.39:9995
|
||||
|
||||
##---------- OpenAI ----------------
|
||||
## The provider for the AI models to use.
|
||||
#MODEL_PROVIDER=openai
|
||||
@@ -46,17 +54,10 @@ RERANK_BASE_URL=http://10.1.16.39:9995
|
||||
## Name of the embedding model to use.
|
||||
#EMBEDDING_MODEL=text-embedding-v2
|
||||
|
||||
#--------------------------
|
||||
# 是否启用检索重排功能
|
||||
ENABLE_RERANK=true
|
||||
|
||||
|
||||
# The questions to help users get started (multi-line).
|
||||
CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容?
|
||||
|
||||
# The number of similar embeddings to return when retrieving documents.
|
||||
TOP_K=5
|
||||
|
||||
# The time in milliseconds to wait for the stream to return a response.
|
||||
STREAM_TIMEOUT=60000
|
||||
|
||||
|
||||
@@ -10,11 +10,17 @@ from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||
|
||||
|
||||
def get_node_postprocessors():
|
||||
rerank_enabled = os.getenv("RERANK_ENABLED")
|
||||
if rerank_enabled is None or rerank_enabled is False:
|
||||
return []
|
||||
|
||||
rerank_model = os.getenv("RERANK_MODEL")
|
||||
rerank_url = os.getenv("RERANK_BASE_URL")
|
||||
rerank_top_n = os.getenv("RERANK_TOP_N")
|
||||
rerank_threshold = os.getenv("RERANK_THRESHOLD")
|
||||
postprocess = None
|
||||
if rerank_model is not None:
|
||||
postprocess = [XinferenceRerank(rerank_model, rerank_url)]
|
||||
postprocess = [XinferenceRerank(rerank_model, rerank_url, top_n=rerank_top_n, threshold=rerank_threshold)]
|
||||
return postprocess
|
||||
|
||||
def init_settings():
|
||||
@@ -79,7 +85,7 @@ def init_xinference():
|
||||
embed_model_name = os.getenv("EMBEDDING_MODEL")
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
dimensions = int(dimensions) if dimensions is not None else None
|
||||
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url)
|
||||
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions)
|
||||
|
||||
def init_openai():
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
|
||||
@@ -5,10 +5,11 @@ from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding, dispatcher
|
||||
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||
from llama_index.core.callbacks import CBEventType, EventPayload
|
||||
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
|
||||
from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReRankEndEvent
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
|
||||
from pydantic import Field
|
||||
@@ -51,7 +52,7 @@ class XinferenceEmbedding(BaseEmbedding):
|
||||
# num_workers: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
generator, model_description = self.load_model(
|
||||
generator, model_description, embed_batch_size, dimensions = self.load_model(
|
||||
model_uid, endpoint
|
||||
)
|
||||
self._generator = generator
|
||||
@@ -106,8 +107,11 @@ class XinferenceEmbedding(BaseEmbedding):
|
||||
)
|
||||
|
||||
model = model_description["model_name"]
|
||||
replica = model_description['replica']
|
||||
dimensions = model_description['dimensions']
|
||||
max_tokens = model_description['max_tokens']
|
||||
|
||||
return generator, model_description
|
||||
return generator, model_description, replica, dimensions
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
@@ -151,23 +155,27 @@ class XinferenceRerank(BaseNodePostprocessor):
|
||||
description="The model description from Xinference."
|
||||
)
|
||||
_generator: Any = PrivateAttr()
|
||||
_model_uid: str = Field(description="The Xinference model to use.")
|
||||
_endpoint: str = Field(description="The Xinference endpoint URL to use.")
|
||||
#model: str = Field(description="Dashscope rerank model name.")
|
||||
_model_uid: str
|
||||
_endpoint: str
|
||||
model: str = Field(description="Dashscope rerank model name.")
|
||||
top_n: int = Field(description="Top N nodes to return.")
|
||||
threshold: float = Field(description="threshold nodes to return.")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_uid: str,
|
||||
endpoint: str,
|
||||
top_n: int = 3,
|
||||
threshold: float = 0.3,
|
||||
return_documents: bool = False
|
||||
):
|
||||
_model_uid = model_uid
|
||||
_endpoint = endpoint
|
||||
generator, model_description = self.load_model(
|
||||
model_uid, endpoint
|
||||
)
|
||||
self._generator = generator
|
||||
super().__init__(top_n=top_n, model=model_uid, return_documents=return_documents)
|
||||
super().__init__(top_n=top_n, model=model_uid, threshold = threshold, return_documents=return_documents)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
@@ -183,6 +191,15 @@ class XinferenceRerank(BaseNodePostprocessor):
|
||||
if len(nodes) == 0:
|
||||
return []
|
||||
|
||||
dispatcher.event(
|
||||
ReRankStartEvent(
|
||||
nodes = nodes,
|
||||
top_n = self.top_n,
|
||||
query = query_bundle,
|
||||
model_name = self.model
|
||||
)
|
||||
)
|
||||
|
||||
with self.callback_manager.event(
|
||||
CBEventType.RERANKING,
|
||||
payload={
|
||||
@@ -199,12 +216,16 @@ class XinferenceRerank(BaseNodePostprocessor):
|
||||
new_node_with_score = NodeWithScore(
|
||||
node=nodes[result['index']].node, score=result['relevance_score']
|
||||
)
|
||||
print(new_node_with_score.node.get_content)
|
||||
print('\n')
|
||||
print(new_node_with_score.score)
|
||||
if new_node_with_score.score >=self.threshold:
|
||||
new_nodes.append(new_node_with_score)
|
||||
|
||||
event.on_end(payload={EventPayload.NODES: new_nodes})
|
||||
|
||||
dispatcher.event(
|
||||
ReRankEndEvent(
|
||||
nodes= new_nodes
|
||||
)
|
||||
)
|
||||
return new_nodes
|
||||
|
||||
def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]:
|
||||
|
||||
Reference in New Issue
Block a user