# Threshold used when ``score_threshold`` is not configured (i.e. passed as None).
_DEFAULT_SCORE_THRESHOLD = 0.3


class CustomXinFerenceRerank(XinferenceRerank):
    """Xinference reranker that additionally filters nodes by relevance score.

    The nodes are sent to the ``/v1/rerank`` endpoint under ``base_url``.
    Only nodes whose relevance score is strictly greater than
    ``score_threshold`` are kept, and at most ``top_n`` of them are returned,
    ordered from highest to lowest score.
    """

    # Minimum relevance score (exclusive) a node needs in order to be kept.
    # Optional so that a caller forwarding an unset environment variable
    # (None) gets the default threshold instead of a validation error.
    score_threshold: Optional[float] = Field(
        default=_DEFAULT_SCORE_THRESHOLD,
        description="分数阈值",
    )

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query and drop low-scoring ones.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query the nodes are scored against.

        Returns:
            Up to ``top_n`` nodes whose relevance score exceeds the
            threshold, sorted by descending score. Empty if ``nodes`` is
            empty.

        Raises:
            ValueError: If ``query_bundle`` is None.
            Exception: If the rerank endpoint answers with a non-200 status.
        """
        dispatcher.event(
            ReRankStartEvent(
                query=query_bundle,
                nodes=nodes,
                top_n=self.top_n,
                model_name=self.model,
            )
        )
        if query_bundle is None:
            raise ValueError("Missing query bundle.")
        if not nodes:
            return []

        threshold = (
            _DEFAULT_SCORE_THRESHOLD
            if self.score_threshold is None
            else self.score_threshold
        )
        query_str = self.get_query_str(query_bundle)

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            headers = {"Content-Type": "application/json"}
            json_data = {
                "model": self.model,
                "query": query_str,
                "documents": [
                    node.node.get_content(metadata_mode=MetadataMode.EMBED)
                    for node in nodes
                ],
            }
            response = requests.post(
                url=f"{self.base_url}/v1/rerank", headers=headers, json=json_data
            )
            response.encoding = "utf-8"
            if response.status_code != 200:
                raise Exception(
                    f"Xinference call failed with status code {response.status_code}."
                    f"Details: {response.text}"
                )

            # Each result refers back to its input document through "index".
            scored_nodes = [
                NodeWithScore(
                    node=nodes[result["index"]].node,
                    score=result["relevance_score"],
                )
                for result in response.json()["results"]
            ]
            # Sort best-first BEFORE truncating: an ascending sort followed by
            # a [:top_n] slice would keep the least relevant nodes.
            rerank_nodes = sorted(
                (node for node in scored_nodes if node.score > threshold),
                key=lambda node: node.score,
                reverse=True,
            )[: self.top_n]
            event.on_end(payload={EventPayload.NODES: rerank_nodes})

        dispatcher.event(ReRankEndEvent(nodes=rerank_nodes))
        return rerank_nodes
llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP -from llama_index.postprocessor.xinference_rerank import XinferenceRerank + from app.engine.loaders import getProjectInfos from app.api.routers.request.base import ProjectInfo @@ -97,7 +97,7 @@ class XinferencePlatform(ModelPlatform): model = os.getenv("MODEL") max_tokens = int(os.getenv("LLM_MAX_TOKENS")) if os.getenv("LLM_MAX_TOKENS") is not None else None temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP)) - return XinfengModel(model_uid = model,endpoint = base_url,temperature = temperature,max_tokens = max_tokens) + return XinferenceModel(model_uid = model,endpoint = base_url,temperature = temperature,max_tokens = max_tokens) def embedding(self): base_url = os.getenv("BASE_URL") @@ -116,7 +116,7 @@ class XinferencePlatform(ModelPlatform): rerank_threshold = os.getenv("RERANK_THRESHOLD") postprocess = None if rerank_model is not None: - postprocess = [XinferenceRerank(model = rerank_model, base_url = rerank_url, top_n=rerank_top_n)] + postprocess = [CustomXinFerenceRerank(model = rerank_model, base_url = rerank_url, top_n=rerank_top_n,score_threshold=rerank_threshold)] return postprocess @register(ModelPlateCategory,'openai')