Compare commits
7 Commits
a9b5dc94fe
...
26ecb256ce
| Author | SHA1 | Date | |
|---|---|---|---|
| 26ecb256ce | |||
| 3e2bdea196 | |||
| 176b49983a | |||
| 2942730c9a | |||
| 8d4382376f | |||
| 0f6d76ddbe | |||
| 01c815a17b |
@@ -19,6 +19,9 @@ EMBEDDING_MODEL=bge-m3
|
||||
EMBEDDING_BASE_URL=http://10.1.16.39:9995
|
||||
# Dimension of the embedding model to use.
|
||||
EMBEDDING_DIM=1024
|
||||
# Rerank model
|
||||
RERANK_MODEL=bge-reranker-v2-m3
|
||||
RERANK_BASE_URL=http://10.1.16.39:9995
|
||||
##---------- OpenAI ----------------
|
||||
## The provider for the AI models to use.
|
||||
#MODEL_PROVIDER=openai
|
||||
|
||||
@@ -59,6 +59,8 @@ async def chat(
|
||||
event_handler = EventCallbackHandler()
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
|
||||
# 由于基于历史消息的提示词没有调整好,所以暂时屏蔽历史消息
|
||||
messages = None
|
||||
response = await chat_engine.astream_chat(last_message_content, messages)
|
||||
process_response_nodes(response.source_nodes, background_tasks)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import create_engine, Engine
|
||||
from app.engine.loaders.db import makeDescriptionByEngine
|
||||
from app.engine.tools import ToolFactory
|
||||
from app.engine.index import get_index
|
||||
from app.settings import get_node_postprocessors
|
||||
|
||||
sql_database = None
|
||||
sql_obj_index = None
|
||||
@@ -53,12 +54,15 @@ def get_chat_engine(filters=None, params=None):
|
||||
)
|
||||
|
||||
# 创建向量检索查询工具
|
||||
postprocess = get_node_postprocessors()
|
||||
query_engine = index.as_query_engine(
|
||||
similarity_top_k=top_k, filters=filters
|
||||
similarity_top_k=top_k, filters=filters,
|
||||
node_postprocessors=postprocess,
|
||||
)
|
||||
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
|
||||
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
|
||||
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。如果你不知道答案,就说你不知道,不要编造答案。",
|
||||
)
|
||||
|
||||
tools.append(summary_query_tool)
|
||||
tools.append(query_engine_tool)
|
||||
#tools.append(sql_query_tool)
|
||||
|
||||
@@ -43,7 +43,7 @@ def llama_parse_extractor() -> Dict[str, LlamaParse]:
|
||||
return {file_type: parser for file_type in SUPPORTED_FILE_TYPES}
|
||||
|
||||
def llama_local_extractor() -> Dict[str, BaseReader]:
    """Return the locally available readers, keyed by file extension.

    Only JSON is handled here. The key carries the leading dot (".json")
    and the value is a reader *instance*, as the ``Dict[str, BaseReader]``
    return annotation requires; the earlier mapping used the bare key
    "json" and the ``JSONReader`` class itself.

    Returns:
        A mapping with a single ".json" entry holding a ``JSONReader``
        constructed with ``clean_json=False`` and ``levels_back=0``.
    """
    # NOTE(review): presumably consumed as a SimpleDirectoryReader
    # ``file_extractor`` mapping, which is keyed by dotted suffix -- confirm
    # against the caller.
    return {".json": JSONReader(clean_json=False, levels_back=0)}
|
||||
|
||||
|
||||
def get_file_documents(config: FileLoaderConfig):
|
||||
|
||||
@@ -6,9 +6,17 @@ from llama_index.core.settings import Settings
|
||||
from llama_index.llms.xinference import Xinference
|
||||
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
||||
|
||||
from app.xinference.base import XinferenceEmbedding
|
||||
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||
|
||||
|
||||
def get_node_postprocessors():
    """Build the node postprocessors used to rerank retrieved nodes.

    Reads ``RERANK_MODEL`` and ``RERANK_BASE_URL`` from the environment.

    Returns:
        A one-element list containing an ``XinferenceRerank`` built from
        those two values when ``RERANK_MODEL`` is set; ``None`` when it is
        unset or empty, so reranking is simply disabled.
    """
    rerank_model = os.getenv("RERANK_MODEL")
    rerank_url = os.getenv("RERANK_BASE_URL")

    # The reranker can only be created when a model is configured. The
    # previous check was inverted (`is None`), which built a reranker with a
    # missing model name and returned None whenever one was configured.
    if not rerank_model:
        return None

    return [XinferenceRerank(rerank_model, rerank_url)]
|
||||
|
||||
def init_settings():
|
||||
model_provider = os.getenv("MODEL_PROVIDER")
|
||||
match model_provider:
|
||||
|
||||
+102
-131
@@ -7,147 +7,19 @@ from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
|
||||
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||
from llama_index.core.callbacks import CBEventType, EventPayload
|
||||
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
|
||||
from llama_index.core.schema import ImageType
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
|
||||
from pydantic import Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# class XinferenceTextEmbeddingType(str, Enum):
|
||||
# """DashScope TextEmbedding text_type."""
|
||||
#
|
||||
# TEXT_TYPE_QUERY = "query"
|
||||
# TEXT_TYPE_DOCUMENT = "document"
|
||||
#
|
||||
#
|
||||
# class DashScopeTextEmbeddingModels(str, Enum):
|
||||
# """DashScope TextEmbedding models."""
|
||||
#
|
||||
# TEXT_EMBEDDING_V1 = "text-embedding-v1"
|
||||
# TEXT_EMBEDDING_V2 = "text-embedding-v2"
|
||||
# TEXT_EMBEDDING_V3 = "text-embedding-v3"
|
||||
#
|
||||
#
|
||||
# class DashScopeBatchTextEmbeddingModels(str, Enum):
|
||||
# """DashScope TextEmbedding models."""
|
||||
#
|
||||
# TEXT_EMBEDDING_ASYNC_V1 = "text-embedding-async-v1"
|
||||
# TEXT_EMBEDDING_ASYNC_V2 = "text-embedding-async-v2"
|
||||
# TEXT_EMBEDDING_ASYNC_V3 = "text-embedding-async-v3"
|
||||
|
||||
|
||||
EMBED_MAX_INPUT_LENGTH = 2048
|
||||
EMBED_MAX_BATCH_SIZE = 1
|
||||
|
||||
|
||||
# class DashScopeMultiModalEmbeddingModels(str, Enum):
|
||||
# """DashScope MultiModalEmbedding models."""
|
||||
#
|
||||
# MULTIMODAL_EMBEDDING_ONE_PEACE_V1 = "multimodal-embedding-one-peace-v1"
|
||||
|
||||
|
||||
# def get_text_embedding(
|
||||
# model: str,
|
||||
# text: Union[str, List[str]],
|
||||
# api_key: Optional[str] = None,
|
||||
# **kwargs: Any,
|
||||
# ) -> List[List[float]]:
|
||||
# """Call DashScope text embedding.
|
||||
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details.
|
||||
#
|
||||
# Args:
|
||||
# model (str): The `DashScopeTextEmbeddingModels`
|
||||
# text (Union[str, List[str]]): text or list text to embedding.
|
||||
#
|
||||
# Raises:
|
||||
# ImportError: need import dashscope
|
||||
#
|
||||
# Returns:
|
||||
# List[List[float]]: The list of embedding result, if failed return empty list.
|
||||
# if some of test no output, the correspond index of output is None.
|
||||
# """
|
||||
# try:
|
||||
# import dashscope
|
||||
# except ImportError:
|
||||
# raise ImportError("DashScope requires `pip install dashscope")
|
||||
# if isinstance(text, str):
|
||||
# text = [text]
|
||||
# response = dashscope.TextEmbedding.call(
|
||||
# model=model, input=text, api_key=api_key, kwargs=kwargs
|
||||
# )
|
||||
# embedding_results = [None] * len(text)
|
||||
# if response.status_code == HTTPStatus.OK:
|
||||
# for emb in response.output["embeddings"]:
|
||||
# embedding_results[emb["text_index"]] = emb["embedding"]
|
||||
# else:
|
||||
# logger.error("Calling TextEmbedding failed, details: %s" % response)
|
||||
#
|
||||
# return embedding_results
|
||||
#
|
||||
#
|
||||
# def get_batch_text_embedding(
|
||||
# model: str, url: str, api_key: Optional[str] = None, **kwargs: Any
|
||||
# ) -> Optional[str]:
|
||||
# """Call DashScope batch text embedding.
|
||||
#
|
||||
# Args:
|
||||
# model (str): The `DashScopeMultiModalEmbeddingModels`
|
||||
# url (str): The url of the file to embedding which with lines of text to embedding.
|
||||
#
|
||||
# Raises:
|
||||
# ImportError: Need install dashscope package.
|
||||
#
|
||||
# Returns:
|
||||
# str: The url of the embedding result, format ref:
|
||||
# https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details
|
||||
# """
|
||||
# try:
|
||||
# import dashscope
|
||||
# except ImportError:
|
||||
# raise ImportError("DashScope requires `pip install dashscope")
|
||||
# response = dashscope.BatchTextEmbedding.call(
|
||||
# model=model, url=url, api_key=api_key, kwargs=kwargs
|
||||
# )
|
||||
# if response.status_code == HTTPStatus.OK:
|
||||
# return response.output["url"]
|
||||
# else:
|
||||
# logger.error("Calling BatchTextEmbedding failed, details: %s" % response)
|
||||
# return None
|
||||
|
||||
|
||||
# def get_multimodal_embedding(
|
||||
# model: str, input: list, api_key: Optional[str] = None, **kwargs: Any
|
||||
# ) -> List[float]:
|
||||
# """Call DashScope multimodal embedding.
|
||||
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/one-peace-multimodal-embedding-api-details.
|
||||
#
|
||||
# Args:
|
||||
# model (str): The `DashScopeBatchTextEmbeddingModels`
|
||||
# input (str): The input of the embedding, eg:
|
||||
# [{'factor': 1, 'text': '你好'},
|
||||
# {'factor': 2, 'audio': 'https://dashscope.oss-cn-beijing.aliyuncs.com/audios/cow.flac'},
|
||||
# {'factor': 3, 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png'}]
|
||||
#
|
||||
# Raises:
|
||||
# ImportError: Need install dashscope package.
|
||||
#
|
||||
# Returns:
|
||||
# List[float]: Embedding result, if failed return empty list.
|
||||
# """
|
||||
# try:
|
||||
# import dashscope
|
||||
# except ImportError:
|
||||
# raise ImportError("DashScope requires `pip install dashscope")
|
||||
# response = dashscope.MultiModalEmbedding.call(
|
||||
# model=model, input=input, api_key=api_key, kwargs=kwargs
|
||||
# )
|
||||
# if response.status_code == HTTPStatus.OK:
|
||||
# return response.output["embedding"]
|
||||
# else:
|
||||
# logger.error("Calling MultiModalEmbedding failed, details: %s" % response)
|
||||
# return []
|
||||
|
||||
class XinferenceEmbedding(BaseEmbedding):
|
||||
"""Xinference class for text embedding.
|
||||
|
||||
@@ -270,3 +142,102 @@ class XinferenceEmbedding(BaseEmbedding):
|
||||
docstring for more information.
|
||||
"""
|
||||
return self._get_query_embedding(query)
|
||||
|
||||
class XinferenceRerank(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes with a Xinference rerank model.

    The model handle is fetched from a Xinference endpoint at construction
    time; ``_postprocess_nodes`` sends the node texts and the query string to
    that handle and rebuilds the node list from the returned scores.
    """

    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )
    top_n: int = Field(description="Top N nodes to return.")
    # Accepted by the constructor for interface compatibility. It is stored
    # here but not forwarded to the rerank call.
    return_documents: bool = Field(
        default=False,
        description="Whether document text was requested in rerank results.",
    )

    # Private (non-validated) state: the Xinference model handle plus the
    # identifiers it was loaded from.
    _generator: Any = PrivateAttr()
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        top_n: int = 3,
        return_documents: bool = False,
    ):
        """Load the rerank model handle and initialise the postprocessor.

        Args:
            model_uid: UID of the rerank model on the Xinference endpoint.
            endpoint: URL of the Xinference endpoint.
            top_n: Maximum number of reranked nodes to return.
            return_documents: Stored on the instance; see the field comment.

        Raises:
            ImportError: If the ``xinference`` client cannot be imported.
            RuntimeError: If the model cannot be obtained from the endpoint.
        """
        generator, model_description = self.load_model(model_uid, endpoint)
        # Initialise the pydantic model first (passing the required
        # `model_description` field), then assign private attributes.
        super().__init__(
            top_n=top_n,
            model_description=model_description,
            return_documents=return_documents,
        )
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    @classmethod
    def class_name(cls) -> str:
        """Return the class name used to identify this component."""
        return "XinferenceRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query and keep the best ``top_n``.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` is sent to the model.

        Returns:
            At most ``top_n`` nodes, ordered by descending relevance score,
            each carrying the score returned by the rerank model. An empty
            list when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self._model_uid,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            texts = [candidate.node.get_content() for candidate in nodes]
            response = self._generator.rerank(texts, query_bundle.query_str)

            # Order by score locally so the top_n cut does not depend on the
            # ordering of the service response.
            ranked = sorted(
                response["results"],
                key=lambda item: item["relevance_score"],
                reverse=True,
            )
            new_nodes = [
                NodeWithScore(
                    node=nodes[item["index"]].node,
                    score=item["relevance_score"],
                )
                for item in ranked[: self.top_n]
            ]
            for reranked in new_nodes:
                logger.debug(
                    "rerank score=%s content=%s",
                    reranked.score,
                    reranked.node.get_content(),
                )
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes

    def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, Dict[str, Any]]:
        """Fetch the model handle and its description from Xinference.

        Args:
            model_uid: UID of the model to look up.
            endpoint: URL of the Xinference endpoint.

        Returns:
            A ``(generator, model_description)`` tuple: the handle returned
            by ``client.get_model`` and the matching entry of
            ``client.list_models()``.

        Raises:
            ImportError: If ``xinference.client`` cannot be imported.
            RuntimeError: If the handle or the description is missing.
        """
        try:
            from xinference.client import RESTfulClient
        except ImportError as err:
            raise ImportError(
                "Could not import Xinference library."
                'Please install Xinference with `pip install "xinference[all]"`'
            ) from err

        client = RESTfulClient(endpoint)

        generator = client.get_model(model_uid)
        # `.get` keeps an unknown UID on the RuntimeError path below instead
        # of surfacing as a bare KeyError.
        model_description = client.list_models().get(model_uid)

        # Explicit checks rather than `assert`, which is stripped under -O.
        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint."
                "Please make sure Xinference endpoint is running at the correct port."
            )

        return generator, model_description
|
||||
+44
-43
@@ -1,4 +1,5 @@
|
||||
from dotenv import load_dotenv
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -13,55 +14,55 @@ from app.api.routers.upload import file_upload_router
|
||||
from app.settings import init_settings
|
||||
from app.observability import init_observability
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from phoenix.trace import using_project
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
app = None
|
||||
|
||||
def init_webserver():
|
||||
global app
|
||||
app = FastAPI()
|
||||
environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
|
||||
if environment == "dev":
|
||||
logger.warning("Running in development mode - allowing CORS for all origins")
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
usPrj = using_project(os.getenv("PHOENIX_PROJECT_NAME"))
|
||||
usPrj.__enter__()
|
||||
|
||||
def mount_static_files(directory, path):
|
||||
if os.path.exists(directory):
|
||||
for dir, _, _ in os.walk(directory):
|
||||
relative_path = os.path.relpath(dir, directory)
|
||||
mount_path = path if relative_path == "." else f"{path}/{relative_path}"
|
||||
logger.info(f"Mounting static files '{dir}' at {mount_path}")
|
||||
app.mount(mount_path, StaticFiles(directory=dir), name=f"{dir}-static")
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
# Mount the data files to serve the file viewer
|
||||
mount_static_files("data", "/api/files/data")
|
||||
# Mount the output files from tools
|
||||
mount_static_files("data_output", "/api/files/output")
|
||||
app.include_router(chat_router, prefix="/api/chat")
|
||||
app.include_router(file_upload_router, prefix="/api/chat/upload")
|
||||
app = FastAPI()
|
||||
|
||||
# Redirect to documentation page when accessing base URL
|
||||
@app.get("/")
|
||||
async def redirect_to_docs():
|
||||
return RedirectResponse(url="/docs")
|
||||
environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
|
||||
if environment == "dev":
|
||||
logger.warning("Running in development mode - allowing CORS for all origins")
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
def mount_static_files(directory, path):
|
||||
if os.path.exists(directory):
|
||||
for dir, _, _ in os.walk(directory):
|
||||
relative_path = os.path.relpath(dir, directory)
|
||||
mount_path = path if relative_path == "." else f"{path}/{relative_path}"
|
||||
logger.info(f"Mounting static files '{dir}' at {mount_path}")
|
||||
app.mount(mount_path, StaticFiles(directory=dir), name=f"{dir}-static")
|
||||
|
||||
# Mount the data files to serve the file viewer
|
||||
mount_static_files("data", "/api/files/data")
|
||||
# Mount the output files from tools
|
||||
mount_static_files("data_output", "/api/files/output")
|
||||
app.include_router(chat_router, prefix="/api/chat")
|
||||
app.include_router(file_upload_router, prefix="/api/chat/upload")
|
||||
|
||||
# Redirect to documentation page when accessing base URL
|
||||
@app.get("/")
|
||||
async def redirect_to_docs():
|
||||
return RedirectResponse(url="/docs")
|
||||
|
||||
SentenceSplitter
|
||||
if __name__ == "__main__":
|
||||
from phoenix.trace import using_project
|
||||
with using_project(os.getenv("PHOENIX_PROJECT_NAME")) as obj:
|
||||
app_host = os.getenv("APP_HOST", "0.0.0.0")
|
||||
app_port = int(os.getenv("APP_PORT", "8000"))
|
||||
reload = True if environment == "dev" else False
|
||||
reload = False
|
||||
uvicorn.run(app="main:app", host=app_host, port=app_port, reload=reload)
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
init_webserver()
|
||||
|
||||
app_host = os.getenv("APP_HOST", "0.0.0.0")
|
||||
app_port = int(os.getenv("APP_PORT", "8000"))
|
||||
#reload = True if environment == "dev" else False
|
||||
reload = False
|
||||
uvicorn.run(app=app, host=app_host, port=app_port, reload=reload)
|
||||
#usPrj.__exit__()
|
||||
|
||||
@@ -24,7 +24,8 @@ def main():
|
||||
top_k = 5
|
||||
filters = generate_filters([])
|
||||
#question = "从工程属性表中查找工程名称"
|
||||
question = "总算表中名称等于架空输电线路本体工程的金额?"
|
||||
#question = "总算表中名称等于架空输电线路本体工程的金额?"
|
||||
question = "工程监理费的金额是多少?"
|
||||
# 创建向量检索查询工具
|
||||
query_engine = index.as_query_engine(
|
||||
similarity_top_k=top_k, filters=filters
|
||||
@@ -35,18 +36,20 @@ def main():
|
||||
engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
|
||||
sql_database = SQLDatabase(engine)
|
||||
|
||||
loader = CustomDatabaseReader(sql_database)
|
||||
documents = loader.load_data(query="select * from ProjectProperties")
|
||||
|
||||
table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||
table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||
|
||||
vectorIndex = VectorStoreIndex()
|
||||
# 创建SQL查询工具
|
||||
sql_obj_index = ObjectIndex.from_objects(
|
||||
# sql_obj_index = ObjectIndex.from_objects(
|
||||
# table_schema_objs,
|
||||
# table_node_mapping,
|
||||
# index_cls=VectorStoreIndex,
|
||||
# )
|
||||
sql_obj_index = ObjectIndex.from_objects_and_index(
|
||||
table_schema_objs,
|
||||
vectorIndex,
|
||||
table_node_mapping,
|
||||
index_cls=VectorStoreIndex,
|
||||
)
|
||||
|
||||
query_result =vectorIndex.as_query_engine(
|
||||
|
||||
Reference in New Issue
Block a user