Compare commits

7 Commits

Author SHA1 Message Date
ly 26ecb256ce Merge remote-tracking branch 'origin/dev' into dev 2024-08-19 09:07:29 +08:00
ly 3e2bdea196 修复自定义JSON文件加载支持BUG 2024-08-19 09:06:12 +08:00
ly 176b49983a 调整测试代码 2024-08-19 08:59:45 +08:00
ly 2942730c9a 增加对Rerank功能支持 2024-08-19 08:59:08 +08:00
ly 8d4382376f 由于基于历史消息的提示词没有调整好,所以暂时屏蔽历史消息 2024-08-19 08:58:08 +08:00
ly 0f6d76ddbe 增加XinferenceRerank 2024-08-19 08:27:22 +08:00
ly 01c815a17b 调整入口代码结构 2024-08-19 08:26:13 +08:00
8 changed files with 176 additions and 184 deletions
+3
View File
@@ -19,6 +19,9 @@ EMBEDDING_MODEL=bge-m3
EMBEDDING_BASE_URL=http://10.1.16.39:9995
# Dimension of the embedding model to use.
EMBEDDING_DIM=1024
# Rerank model
RERANK_MODEL=bge-reranker-v2-m3
RERANK_BASE_URL=http://10.1.16.39:9995
##---------- OpenAI ----------------
## The provider for the AI models to use.
#MODEL_PROVIDER=openai
+2
View File
@@ -59,6 +59,8 @@ async def chat(
event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
# 由于基于历史消息的提示词没有调整好,所以暂时屏蔽历史消息
messages = None
response = await chat_engine.astream_chat(last_message_content, messages)
process_response_nodes(response.source_nodes, background_tasks)
+6 -2
View File
@@ -11,6 +11,7 @@ from sqlalchemy import create_engine, Engine
from app.engine.loaders.db import makeDescriptionByEngine
from app.engine.tools import ToolFactory
from app.engine.index import get_index
from app.settings import get_node_postprocessors
sql_database = None
sql_obj_index = None
@@ -53,12 +54,15 @@ def get_chat_engine(filters=None, params=None):
)
# 创建向量检索查询工具
postprocess = get_node_postprocessors()
query_engine = index.as_query_engine(
similarity_top_k=top_k, filters=filters
similarity_top_k=top_k, filters=filters,
node_postprocessors=postprocess,
)
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。如果你不知道答案,就说你不知道,不要编造答案。",
)
tools.append(summary_query_tool)
tools.append(query_engine_tool)
#tools.append(sql_query_tool)
+1 -1
View File
@@ -43,7 +43,7 @@ def llama_parse_extractor() -> Dict[str, LlamaParse]:
return {file_type: parser for file_type in SUPPORTED_FILE_TYPES}
def llama_local_extractor() -> Dict[str, BaseReader]:
return {"json" : JSONReader}
return {".json" : JSONReader(clean_json=False,levels_back=0)}
def get_file_documents(config: FileLoaderConfig):
+9 -1
View File
@@ -6,9 +6,17 @@ from llama_index.core.settings import Settings
from llama_index.llms.xinference import Xinference
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from app.xinference.base import XinferenceEmbedding
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
def get_node_postprocessors():
rerank_model = os.getenv("RERANK_MODEL")
rerank_url = os.getenv("RERANK_BASE_URL")
postprocess = None
if rerank_model is None:
postprocess = [XinferenceRerank(rerank_model, rerank_url)]
return postprocess
def init_settings():
model_provider = os.getenv("MODEL_PROVIDER")
match model_provider:
+102 -131
View File
@@ -7,147 +7,19 @@ from typing import Any, Dict, List, Optional, Union, Tuple
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
from llama_index.core.schema import ImageType
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
from pydantic import Field
logger = logging.getLogger(__name__)
# class XinferenceTextEmbeddingType(str, Enum):
# """DashScope TextEmbedding text_type."""
#
# TEXT_TYPE_QUERY = "query"
# TEXT_TYPE_DOCUMENT = "document"
#
#
# class DashScopeTextEmbeddingModels(str, Enum):
# """DashScope TextEmbedding models."""
#
# TEXT_EMBEDDING_V1 = "text-embedding-v1"
# TEXT_EMBEDDING_V2 = "text-embedding-v2"
# TEXT_EMBEDDING_V3 = "text-embedding-v3"
#
#
# class DashScopeBatchTextEmbeddingModels(str, Enum):
# """DashScope TextEmbedding models."""
#
# TEXT_EMBEDDING_ASYNC_V1 = "text-embedding-async-v1"
# TEXT_EMBEDDING_ASYNC_V2 = "text-embedding-async-v2"
# TEXT_EMBEDDING_ASYNC_V3 = "text-embedding-async-v3"
EMBED_MAX_INPUT_LENGTH = 2048
EMBED_MAX_BATCH_SIZE = 1
# class DashScopeMultiModalEmbeddingModels(str, Enum):
# """DashScope MultiModalEmbedding models."""
#
# MULTIMODAL_EMBEDDING_ONE_PEACE_V1 = "multimodal-embedding-one-peace-v1"
# def get_text_embedding(
# model: str,
# text: Union[str, List[str]],
# api_key: Optional[str] = None,
# **kwargs: Any,
# ) -> List[List[float]]:
# """Call DashScope text embedding.
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details.
#
# Args:
# model (str): The `DashScopeTextEmbeddingModels`
# text (Union[str, List[str]]): text or list text to embedding.
#
# Raises:
# ImportError: need import dashscope
#
# Returns:
# List[List[float]]: The list of embedding result, if failed return empty list.
# if some of test no output, the correspond index of output is None.
# """
# try:
# import dashscope
# except ImportError:
# raise ImportError("DashScope requires `pip install dashscope")
# if isinstance(text, str):
# text = [text]
# response = dashscope.TextEmbedding.call(
# model=model, input=text, api_key=api_key, kwargs=kwargs
# )
# embedding_results = [None] * len(text)
# if response.status_code == HTTPStatus.OK:
# for emb in response.output["embeddings"]:
# embedding_results[emb["text_index"]] = emb["embedding"]
# else:
# logger.error("Calling TextEmbedding failed, details: %s" % response)
#
# return embedding_results
#
#
# def get_batch_text_embedding(
# model: str, url: str, api_key: Optional[str] = None, **kwargs: Any
# ) -> Optional[str]:
# """Call DashScope batch text embedding.
#
# Args:
# model (str): The `DashScopeMultiModalEmbeddingModels`
# url (str): The url of the file to embedding which with lines of text to embedding.
#
# Raises:
# ImportError: Need install dashscope package.
#
# Returns:
# str: The url of the embedding result, format ref:
# https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details
# """
# try:
# import dashscope
# except ImportError:
# raise ImportError("DashScope requires `pip install dashscope")
# response = dashscope.BatchTextEmbedding.call(
# model=model, url=url, api_key=api_key, kwargs=kwargs
# )
# if response.status_code == HTTPStatus.OK:
# return response.output["url"]
# else:
# logger.error("Calling BatchTextEmbedding failed, details: %s" % response)
# return None
# def get_multimodal_embedding(
# model: str, input: list, api_key: Optional[str] = None, **kwargs: Any
# ) -> List[float]:
# """Call DashScope multimodal embedding.
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/one-peace-multimodal-embedding-api-details.
#
# Args:
# model (str): The `DashScopeBatchTextEmbeddingModels`
# input (str): The input of the embedding, eg:
# [{'factor': 1, 'text': '你好'},
# {'factor': 2, 'audio': 'https://dashscope.oss-cn-beijing.aliyuncs.com/audios/cow.flac'},
# {'factor': 3, 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png'}]
#
# Raises:
# ImportError: Need install dashscope package.
#
# Returns:
# List[float]: Embedding result, if failed return empty list.
# """
# try:
# import dashscope
# except ImportError:
# raise ImportError("DashScope requires `pip install dashscope")
# response = dashscope.MultiModalEmbedding.call(
# model=model, input=input, api_key=api_key, kwargs=kwargs
# )
# if response.status_code == HTTPStatus.OK:
# return response.output["embedding"]
# else:
# logger.error("Calling MultiModalEmbedding failed, details: %s" % response)
# return []
class XinferenceEmbedding(BaseEmbedding):
"""Xinference class for text embedding.
@@ -270,3 +142,102 @@ class XinferenceEmbedding(BaseEmbedding):
docstring for more information.
"""
return self._get_query_embedding(query)
class XinferenceRerank(BaseNodePostprocessor):
"""Xinference class for rerank.
"""
model_description: Dict[str, Any] = Field(
description="The model description from Xinference."
)
_generator: Any = PrivateAttr()
_model_uid: str = Field(description="The Xinference model to use.")
_endpoint: str = Field(description="The Xinference endpoint URL to use.")
#model: str = Field(description="Dashscope rerank model name.")
top_n: int = Field(description="Top N nodes to return.")
def __init__(
self,
model_uid: str,
endpoint: str,
top_n: int = 3,
return_documents: bool = False
):
generator, model_description = self.load_model(
model_uid, endpoint
)
self._generator = generator
super().__init__(top_n=top_n, model=model_uid, return_documents=return_documents)
@classmethod
def class_name(cls) -> str:
return "XinferenceRerank"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
if query_bundle is None:
raise ValueError("Missing query bundle in extra info.")
if len(nodes) == 0:
return []
with self.callback_manager.event(
CBEventType.RERANKING,
payload={
EventPayload.NODES: nodes,
EventPayload.MODEL_NAME: self._model_uid,
EventPayload.QUERY_STR: query_bundle.query_str,
EventPayload.TOP_K: self.top_n,
},
) as event:
texts = [node.node.get_content() for node in nodes]
response = self._generator.rerank(texts,query_bundle.query_str)
new_nodes = []
for result in response['results']:
new_node_with_score = NodeWithScore(
node=nodes[result['index']].node, score=result['relevance_score']
)
print(new_node_with_score.node.get_content)
print('\n')
print(new_node_with_score.score)
new_nodes.append(new_node_with_score)
event.on_end(payload={EventPayload.NODES: new_nodes})
return new_nodes
def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]:
try:
from xinference.client import RESTfulClient
except ImportError:
raise ImportError(
"Could not import Xinference library."
'Please install Xinference with `pip install "xinference[all]"`'
)
client = RESTfulClient(endpoint)
try:
assert isinstance(client, RESTfulClient)
except AssertionError:
raise RuntimeError(
"Could not create RESTfulClient instance."
"Please make sure Xinference endpoint is running at the correct port."
)
generator = client.get_model(model_uid)
model_description = client.list_models()[model_uid]
try:
assert generator is not None
assert model_description is not None
except AssertionError:
raise RuntimeError(
"Could not get model from endpoint."
"Please make sure Xinference endpoint is running at the correct port."
)
model = model_description["model_name"]
return generator, model_description
+14 -13
View File
@@ -1,4 +1,5 @@
from dotenv import load_dotenv
from llama_index.core.node_parser import SentenceSplitter
load_dotenv()
@@ -13,14 +14,18 @@ from app.api.routers.upload import file_upload_router
from app.settings import init_settings
from app.observability import init_observability
from fastapi.staticfiles import StaticFiles
from phoenix.trace import using_project
logger = logging.getLogger("uvicorn")
app = None
def init_webserver():
global app
usPrj = using_project(os.getenv("PHOENIX_PROJECT_NAME"))
usPrj.__enter__()
init_settings()
init_observability()
app = FastAPI()
environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
if environment == "dev":
logger.warning("Running in development mode - allowing CORS for all origins")
@@ -52,16 +57,12 @@ def init_webserver():
async def redirect_to_docs():
return RedirectResponse(url="/docs")
SentenceSplitter
if __name__ == "__main__":
from phoenix.trace import using_project
with using_project(os.getenv("PHOENIX_PROJECT_NAME")) as obj:
init_settings()
init_observability()
init_webserver()
app_host = os.getenv("APP_HOST", "0.0.0.0")
app_port = int(os.getenv("APP_PORT", "8000"))
#reload = True if environment == "dev" else False
reload = True if environment == "dev" else False
reload = False
uvicorn.run(app=app, host=app_host, port=app_port, reload=reload)
uvicorn.run(app="main:app", host=app_host, port=app_port, reload=reload)
#usPrj.__exit__()
+9 -6
View File
@@ -24,7 +24,8 @@ def main():
top_k = 5
filters = generate_filters([])
#question = "从工程属性表中查找工程名称"
question = "总算表中名称等于架空输电线路本体工程的金额?"
#question = "总算表中名称等于架空输电线路本体工程的金额?"
question = "工程监理费的金额是多少?"
# 创建向量检索查询工具
query_engine = index.as_query_engine(
similarity_top_k=top_k, filters=filters
@@ -35,18 +36,20 @@ def main():
engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
sql_database = SQLDatabase(engine)
loader = CustomDatabaseReader(sql_database)
documents = loader.load_data(query="select * from ProjectProperties")
table_schema_objs = makeDescriptionByEngine(sql_database)
table_node_mapping = SQLTableNodeMapping(sql_database)
vectorIndex = VectorStoreIndex()
# 创建SQL查询工具
sql_obj_index = ObjectIndex.from_objects(
# sql_obj_index = ObjectIndex.from_objects(
# table_schema_objs,
# table_node_mapping,
# index_cls=VectorStoreIndex,
# )
sql_obj_index = ObjectIndex.from_objects_and_index(
table_schema_objs,
vectorIndex,
table_node_mapping,
index_cls=VectorStoreIndex,
)
query_result =vectorIndex.as_query_engine(