From a165d55822caff4aec8e1528bccd7172580cae18 Mon Sep 17 00:00:00 2001 From: wanyaokun <12345678> Date: Tue, 10 Sep 2024 14:07:52 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0LlamaIndex=E7=89=88=E6=9C=AC?= =?UTF-8?q?=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/engine/__init__.py | 4 +- backend/app/engine/model/xinfeng.py | 72 +++++++++++++++++++ .../app/engine/response/treeSummResponse.py | 4 +- backend/app/settings.py | 7 +- 4 files changed, 80 insertions(+), 7 deletions(-) create mode 100644 backend/app/engine/model/xinfeng.py diff --git a/backend/app/engine/__init__.py b/backend/app/engine/__init__.py index 56531f0..2de0a85 100644 --- a/backend/app/engine/__init__.py +++ b/backend/app/engine/__init__.py @@ -52,8 +52,8 @@ def get_chat_engine(filters=None, params:dict=None): description=tree_summary_query_engine_tool_messages) tools.append(query_engine_tool) - tools.append(query_engine_tool_1) - tools.append(summary_query_tool) + #tools.append(query_engine_tool_1) + #tools.append(summary_query_tool) # Add additional tools tools += ToolFactory.from_env() diff --git a/backend/app/engine/model/xinfeng.py b/backend/app/engine/model/xinfeng.py new file mode 100644 index 0000000..c2ec772 --- /dev/null +++ b/backend/app/engine/model/xinfeng.py @@ -0,0 +1,72 @@ + +from llama_index.llms.xinference import Xinference +from typing import Any, Callable, Dict, Optional, Sequence, Tuple +from llama_index.core.llms.callbacks import ( + llm_chat_callback, + llm_completion_callback, +) +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseGen, + CompletionResponse, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) +from llama_index.llms.xinference.utils import ( + xinference_message_to_history, + xinference_modelname_to_contextsize, +) + +class XinfengModel(Xinference): + @llm_chat_callback() + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + assert self._generator is not None + response_text = self._generator.chat( + messages=messages, + generate_config={ + "stream": False, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + }, + )["choices"][0]["message"]["content"] + return ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, + content=response_text, + ), + delta=None, + ) + + @llm_chat_callback() + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + msgs = [] + for message in messages: + msgs.append(message.dict()) + assert self._generator is not None + response_iter = self._generator.chat( + messages=msgs, + generate_config={ + "stream": True, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + }, + ) + + def gen() -> ChatResponseGen: + text = "" + for c in response_iter: + delta = c["choices"][0]["delta"].get("content", "") + text += delta + yield ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, + content=text, + ), + delta=delta, + ) + + return gen() \ No newline at end of file diff --git a/backend/app/engine/response/treeSummResponse.py b/backend/app/engine/response/treeSummResponse.py index 7cf1868..b5127ef 100644 --- a/backend/app/engine/response/treeSummResponse.py +++ b/backend/app/engine/response/treeSummResponse.py @@ -5,7 +5,7 @@ from llama_index.core.callbacks.base import CallbackManager from llama_index.core.indices.prompt_helper import PromptHelper from llama_index.core.prompts import BasePromptTemplate from llama_index.core.service_context import ServiceContext -from llama_index.core.service_context_elements.llm_predictor import LLMPredictorType +from llama_index.core.llms import LLM from llama_index.core.types import BaseModel,RESPONSE_TEXT_TYPE from llama_index.core.async_utils import run_async_tasks from llama_index.core.utils import get_tokenizer @@ -14,7 +14,7 @@ from llama_index.core.prompts.prompt_utils import get_empty_prompt_txt class CustomTreeResponse(TreeSummarize): def __init__( self, - llm: Optional[LLMPredictorType] = None, + llm: Optional[LLM] = None, callback_manager: Optional[CallbackManager] = None, prompt_helper: Optional[PromptHelper] = None, summary_template: Optional[BasePromptTemplate] = None, diff --git a/backend/app/settings.py b/backend/app/settings.py index a4fd97e..c35bc0f 100644 --- a/backend/app/settings.py +++ b/backend/app/settings.py @@ -4,7 +4,8 @@ from abc import abstractmethod from llama_index.core.constants import DEFAULT_TEMPERATURE from llama_index.core.settings import Settings from llama_index.embeddings.xinference import XinferenceEmbedding -from llama_index.llms.xinference import Xinference +#from llama_index.llms.xinference import Xinference +from app.engine.model.xinfeng import XinfengModel #from llama_index.embeddings.xinference import XinferenceEmbedding from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP from llama_index.postprocessor.xinference_rerank import XinferenceRerank @@ -96,7 +97,7 @@ class XinferencePlatform(ModelPlatform): model = os.getenv("MODEL") max_tokens = int(os.getenv("LLM_MAX_TOKENS")) if os.getenv("LLM_MAX_TOKENS") is not None else None temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP)) - return Xinference(model, base_url, temperature, max_tokens) + return XinfengModel(model_uid = model,endpoint = base_url,temperature = temperature,max_tokens = max_tokens) def embedding(self): base_url = os.getenv("BASE_URL") @@ -115,7 +116,7 @@ class XinferencePlatform(ModelPlatform): rerank_threshold = os.getenv("RERANK_THRESHOLD") postprocess = None if rerank_model is not None: - postprocess = [XinferenceRerank(rerank_model, rerank_url, top_n=rerank_top_n, threshold=rerank_threshold)] + postprocess = [XinferenceRerank(model = rerank_model, base_url = rerank_url, top_n=rerank_top_n)] return postprocess @register(ModelPlateCategory,'openai')