新增openai的向量,重排模型
This commit is contained in:
+21
-8
@@ -42,14 +42,27 @@ DASHSCOPE_API_KEY=sk-221d2d202e104618a56002ce2e7dc0d0
|
|||||||
MODEL=qwen2-math-72b-instruct
|
MODEL=qwen2-math-72b-instruct
|
||||||
|
|
||||||
# #---------- model - openai ----------------
|
#---------- model - openai ----------------
|
||||||
# MODEL_PROVIDER=openai
|
MODEL_PROVIDER=openai
|
||||||
# OPENAI_API_KEY=sk-hhoqttvhibirwheyponjifsqwssgxotoqlcjufkidytwxngi
|
OPENAI_API_KEY=
|
||||||
# BASE_URL=https://api.siliconflow.cn/v1
|
BASE_URL=https://api.siliconflow.cn/v1
|
||||||
# MODEL=alibaba/Qwen1.5-110B-Chat
|
MODEL=alibaba/Qwen1.5-110B-Chat
|
||||||
# LLM_TEMPERATURE=0.1
|
LLM_TEMPERATURE=0.1
|
||||||
# CONTEXT_WINDOW = 8192
|
CONTEXT_WINDOW = 8192
|
||||||
# IS_CHAT_MODEL = true
|
IS_CHAT_MODEL = true
|
||||||
# IS_FUN_CALL_MODEL = false
|
IS_FUN_CALL_MODEL = false
|
||||||
|
|
||||||
|
#---------- embedding - openai ----------------
|
||||||
|
EMBEDDING_PROVIDER=openai
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
EMBEDDING_MODEL=BAAI/bge-m3
|
||||||
|
EMBEDDING_BASE_URL=https://api.siliconflow.cn/v1
|
||||||
|
EMBEDDING_DIM=1024
|
||||||
|
|
||||||
|
RERANK_PROVIDER=openai
|
||||||
|
# Same variable as in the model/embedding sections; fill in your own key and never commit a real one.
OPENAI_API_KEY=
|
||||||
|
RERANK_MODEL=BAAI/bge-reranker-v2-m3
|
||||||
|
RERANK_BASE_URL=https://api.siliconflow.cn/v1
|
||||||
|
RERANK_TOP_N=5
|
||||||
|
|
||||||
|
|
||||||
#---------- embedding - Xinference ----------------
|
#---------- embedding - Xinference ----------------
|
||||||
|
|||||||
@@ -0,0 +1,97 @@
|
|||||||
|
import requests
|
||||||
|
from typing import List, Optional
|
||||||
|
from llama_index.core.bridge.pydantic import Field
|
||||||
|
from llama_index.core.callbacks import CBEventType, EventPayload
|
||||||
|
from llama_index.core.instrumentation import get_dispatcher
|
||||||
|
from llama_index.core.instrumentation.events.rerank import (
|
||||||
|
ReRankEndEvent,
|
||||||
|
ReRankStartEvent,
|
||||||
|
)
|
||||||
|
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||||
|
from llama_index.core.schema import NodeWithScore, QueryBundle, MetadataMode
|
||||||
|
|
||||||
|
dispatcher = get_dispatcher(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SiliconCloudRerank(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes through an HTTP rerank endpoint.

    The query string and the text of every candidate node are POSTed as JSON
    to ``{base_url}/rerank``. Each entry of the ``results`` list in the reply
    is expected to carry an ``index`` (position in the submitted documents)
    and a ``relevance_score``; the score becomes the score of the returned
    node.
    """

    top_n: int = Field(
        default=5,
        description="The number of nodes to return.",
    )
    model: str = Field(
        default="bge-reranker-base",
        description="The SiliconCloud model uid to use.",
    )
    base_url: str = Field(
        default="https://api.siliconflow.cn/v1",
        description="The SiliconCloud base url to use.",
    )
    api_key: str = Field(
        default="",
        description="The SiliconCloud Api key to use.",
    )
    # Score threshold. NOTE(review): declared but not applied anywhere in this
    # class; left untouched so existing constructor calls keep working.
    score_threshold: float = Field(default=0.3, description="分数阈值")
    # Added with a default so existing callers are unaffected; without a
    # timeout a stalled endpoint would block the caller indefinitely.
    timeout: float = Field(
        default=60.0,
        description="Seconds to wait for the rerank HTTP response.",
    )

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier of this postprocessor class."""
        return "SiliconCloudRerank"

    def get_query_str(self, query):
        """Return the plain query text for a ``QueryBundle`` or a raw string."""
        return query.query_str if isinstance(query, QueryBundle) else query

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query and keep the best ``top_n``.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: The query the nodes are scored against.

        Returns:
            At most ``top_n`` new ``NodeWithScore`` objects ordered by
            descending relevance score; an empty list when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
            Exception: If the endpoint answers with a non-200 status code.
        """
        dispatcher.event(
            ReRankStartEvent(
                query=query_bundle,
                nodes=nodes,
                top_n=self.top_n,
                model_name=self.model,
            )
        )
        if query_bundle is None:
            raise ValueError("Missing query bundle.")
        if not nodes:
            return []

        query_str = self.get_query_str(query_bundle)
        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            json_data = {
                "model": self.model,
                "query": query_str,
                "documents": [
                    node.node.get_content(metadata_mode=MetadataMode.EMBED)
                    for node in nodes
                ],
            }
            # Strip a trailing slash so a base_url such as ".../v1/" does not
            # produce ".../v1//rerank".
            endpoint = f"{self.base_url.rstrip('/')}/rerank"
            response = requests.post(
                url=endpoint,
                headers=headers,
                json=json_data,
                timeout=self.timeout,
            )
            response.encoding = "utf-8"
            if response.status_code != 200:
                raise Exception(
                    f"SiliconCloud call failed with status code {response.status_code}. "
                    f"Details: {response.text}"
                )

            # Sort locally instead of trusting the service's ordering, so the
            # top_n slice always keeps the highest-scoring results.
            ranked = sorted(
                response.json()["results"],
                key=lambda item: item["relevance_score"],
                reverse=True,
            )
            rerank_nodes = [
                NodeWithScore(
                    node=nodes[item["index"]].node,
                    score=item["relevance_score"],
                )
                for item in ranked[: self.top_n]
            ]
            event.on_end(payload={EventPayload.NODES: rerank_nodes})

        dispatcher.event(ReRankEndEvent(nodes=rerank_nodes))
        return rerank_nodes
|
||||||
+11
-8
@@ -16,7 +16,6 @@ from modelProvide.customDashScope import CustomDashScope
|
|||||||
from util.register import *
|
from util.register import *
|
||||||
from llama_index.core.callbacks import CallbackManager
|
from llama_index.core.callbacks import CallbackManager
|
||||||
|
|
||||||
|
|
||||||
ModelPlateCategory = '模型平台'
|
ModelPlateCategory = '模型平台'
|
||||||
|
|
||||||
def init_settings():
|
def init_settings():
|
||||||
@@ -130,15 +129,19 @@ class OpenAIPlatform(ModelPlatform):
|
|||||||
|
|
||||||
def embedding(self):
    """Build the OpenAI-compatible embedding model from environment settings.

    Reads ``OPENAI_API_KEY``, ``EMBEDDING_BASE_URL``, ``EMBEDDING_MODEL`` and
    ``EMBEDDING_DIM``.

    Returns:
        An ``OpenAIEmbedding`` instance configured from those variables.
    """
    from llama_index.embeddings.openai import OpenAIEmbedding

    # EMBEDDING_DIM is optional: int(None) would raise TypeError, so an unset
    # or blank value is passed on as None.
    raw_dimensions = os.getenv("EMBEDDING_DIM")
    dimensions = int(raw_dimensions) if raw_dimensions else None

    # NOTE(review): the model id is passed as model_name rather than model,
    # presumably so non-OpenAI ids (e.g. BAAI/bge-m3) are accepted - confirm
    # against the installed llama_index version.
    return OpenAIEmbedding(
        api_key=os.getenv("OPENAI_API_KEY"),
        api_base=os.getenv("EMBEDDING_BASE_URL"),
        model_name=os.getenv("EMBEDDING_MODEL"),
        dimensions=dimensions,
    )
|
|
||||||
|
|
||||||
def rerank(self):
    """Build the rerank postprocessor list from environment settings.

    Reads ``RERANK_TOP_N`` (default 5), ``RERANK_MODEL``, ``RERANK_BASE_URL``
    and ``OPENAI_API_KEY``.

    Returns:
        A one-element list containing the configured ``SiliconCloudRerank``.
    """
    from app.engine.rerank.siliconCloudRerank import SiliconCloudRerank

    optional_settings = {
        "model": os.getenv("RERANK_MODEL"),
        "base_url": os.getenv("RERANK_BASE_URL"),
        "api_key": os.getenv("OPENAI_API_KEY"),
    }
    # Leave out variables that are not set so SiliconCloudRerank falls back to
    # its own field defaults; passing None into its str-typed fields is
    # expected to fail validation.
    configured = {
        key: value for key, value in optional_settings.items() if value is not None
    }

    return [
        SiliconCloudRerank(
            top_n=int(os.getenv("RERANK_TOP_N", 5)),
            **configured,
        )
    ]
|
||||||
|
|
||||||
@register(ModelPlateCategory,'dashscope')
|
@register(ModelPlateCategory,'dashscope')
|
||||||
class DashscopePlatform(ModelPlatform):
|
class DashscopePlatform(ModelPlatform):
|
||||||
|
|||||||
Reference in New Issue
Block a user