diff --git a/backend/.env.example b/backend/.env.example index 2b1db19..37ba235 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -49,6 +49,7 @@ VECTOR_STORE_COLLECTION=default # Specify this if you are using a local vector database. # Otherwise, use VECTOR_STORE__HOST and VECTOR_STORE__PORT config above VECTOR_STORE_PATH=./storage_vector +BM_RETRIEVER_PATH =./storage_bm diff --git a/backend/.env.xinference b/backend/.env.xinference index ae37317..6dd566f 100644 --- a/backend/.env.xinference +++ b/backend/.env.xinference @@ -6,12 +6,13 @@ SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zj # The number of similar embeddings to return when retrieving documents. TOP_K=10 #-------------------------- +# 是否启用混合检索 +HYBRID_ENABLED = true +# 混合检索阈值 +HYBRID_ALPHA = 0.6 +#-------------------------- # 是否启用检索重排功能 RERANK_ENABLED=true -# 是否启用混合检索 -HYBRID_ENABLED = false -# 混合检索阈值 -HYBRID_ALPHA = 0.5 # Rerank model RERANK_MODEL=bge-reranker-v2-m3 RERANK_BASE_URL=http://10.1.16.39:9995 @@ -80,7 +81,7 @@ VECTOR_STORE_COLLECTION=default # Specify this if you are using a local vector database. # Otherwise, use VECTOR_STORE__HOST and VECTOR_STORE__PORT config above VECTOR_STORE_PATH=./storage_vector - +BM_RETRIEVER_PATH =./storage_bm PHOENIX_API_KEY=123456 diff --git a/backend/app/engine/__init__.py b/backend/app/engine/__init__.py index fc36c14..1eaf1fe 100644 --- a/backend/app/engine/__init__.py +++ b/backend/app/engine/__init__.py @@ -1,154 +1,57 @@ import os -from llama_index.core import SQLDatabase, SummaryIndex, VectorStoreIndex -from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine -from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex +from llama_index.core.agent import AgentRunner, ReActChatFormatter from llama_index.core.settings import Settings -from llama_index.core.agent import AgentRunner, StructuredPlannerAgent, FunctionCallingAgentWorker from llama_index.core.tools.query_engine import QueryEngineTool -from sqlalchemy import create_engine, Engine -from app.engine.loaders.db import makeDescriptionByEngine -from app.engine.tools import ToolFactory +from app.engine.engine import create_query_engine, create_summary_query_engine, create_sql_query_engine from app.engine.index import get_index -from app.settings import get_node_postprocessors +#from app.engine.loaders.db import makeDescriptionByEngine +from app.engine.tools import ToolFactory -from llama_index.core.retrievers import BaseRetriever -from llama_index.core import QueryBundle -from llama_index.core.schema import NodeWithScore -from typing import List, Any, Optional,Dict -from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine - -class HybridRetriever(BaseRetriever): - def __init__( - self, - vector_index, - similarity_top_k: int = 2, - out_top_k: Optional[int] = None, - alpha: float = 0.5, - filters = None, - **kwargs: Any, - ) -> None: - from llama_index.retrievers.bm25 import BM25Retriever - from nltk.corpus import stopwords - - super().__init__(**kwargs) - self._vector_index = vector_index - self._embed_model = vector_index._embed_model - self._out_top_k = out_top_k or similarity_top_k - self._vecRetriever = vector_index.as_retriever( - similarity_top_k=similarity_top_k,filters = filters - ) - self._bm25Retriever = BM25Retriever.from_defaults(similarity_top_k=similarity_top_k, - nodes=self._vector_index.vector_store.get_nodes(None), - language=stopwords.words('chinese')) - self._alpha = alpha - - def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: - vecNodes:List[NodeWithScore] = self._vecRetriever.retrieve(query_bundle.query_str) - bmNodes:List[NodeWithScore] = self._bm25Retriever.retrieve(query_bundle.query_str) - - bmDic:Dict[str,NodeWithScore] = {} - for node in bmNodes: - bmDic[node.node_id] = node - - result_tups = [] - for i in range(len(vecNodes)): - node = vecNodes[i] - bmScore = 0.0 - if node.node_id in bmDic: - bmScore = bmDic[node.node_id].score - bmDic.pop(node.node_id) - else: - bmScore = 0.0 - full_similarity = (self._alpha * node.score) + ( - (1 - self._alpha) * bmScore - ) - result_tups.append((full_similarity, node)) - - for _,node in bmDic.items(): - full_similarity = (1 - self._alpha) * node.score - result_tups.append((full_similarity, node)) - - result_tups = sorted(result_tups, key=lambda x: x[0], reverse=True) - for full_score, node in result_tups: - node.score = full_score - return [n for _, n in result_tups][:self._out_top_k] - -def get_Retriever(index,**kwargs): - bEnableHybrid = True if os.getenv("HYBRID_ENABLED",False).title() == 'True' else False - if bEnableHybrid: - alpha = float(os.getenv("HYBRID_ALPHA", "0.5")) - retriever = HybridRetriever(index,alpha = alpha,**kwargs) - else: - retriever = index.as_retriever(**kwargs) - return retriever - -sql_database = None -sql_obj_index = None def get_chat_engine(filters=None, params=None): system_prompt = os.getenv("SYSTEM_PROMPT") top_k = int(os.getenv("TOP_K", "3")) + use_reranker = os.getenv("RERANK_ENABLED") tools = [] - global sql_obj_index - global sql_database - if sql_obj_index is None: - sqlengine = create_engine(os.getenv("SQL_DATABASE_URL", "")) - sql_database = SQLDatabase(sqlengine) - table_schema_objs = makeDescriptionByEngine(sql_database) - table_node_mapping = SQLTableNodeMapping(sql_database) - - sql_obj_index = ObjectIndex.from_objects( - table_schema_objs, - table_node_mapping, - index_cls=VectorStoreIndex, - ) - # 创建SQL查询工具 - sql_query_engine = SQLTableRetrieverQueryEngine(sql_database, - sql_obj_index.as_retriever(similarity_top_k=top_k), - verbose=True,) + sql_query_engine = create_sql_query_engine() sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine, name="zjdata_query_tool", - description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具") + description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具" + ) + #tools.append(sql_query_tool) # Add query tool if index exists index = get_index() if index is not None: - summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None)) - summary_query_engine = summary_index.as_query_engine() + summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters) summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool", description="适用于任何需要进行全面总结、概括的要求。", - #description="适用于任何需要对所有内容进行全面总结的请求。有关电力造价领域更具体部分的问题,请使用zj_query_engine_tool", ) - - # 创建向量检索查询工具 - postprocess = get_node_postprocessors() - query_engine = RetrieverQueryEngine.from_args( - get_Retriever(index,similarity_top_k=top_k, - filters=filters), - node_postprocessors=postprocess, - ) - + query_engine = create_query_engine(index,top_k,use_reranker,filters) query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool", - description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。如果你不知道答案,就说你不知道,不要编造答案。", + description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。", ) tools.append(summary_query_tool) tools.append(query_engine_tool) - #tools.append(sql_query_tool) # Add additional tools tools += ToolFactory.from_env() - return AgentRunner.from_llm( + prefix_messages = ("""您的设计旨在帮助完成各种任务,从回答问题到提供其他类型分析的摘要。\n\n##工具\n\n你可以访问各种工具。你有责任按照你认为合适的顺序使用这些工具来完成当前的任务。\n这可能需要将任务分解为子任务,并使用不同的工具来完成每个子任务。\n\n你可以访问以下工具:\n{tool_desc}\n\n\n##输出格式\n\n请用与问题相同的语言回答,并使用以下格式:\n\n \nThought: 用户当前的语言是:(user's language)。我需要使用工具来帮助我回答问题。\nAction: 如果使用工具,则为工具名称(one of {tool_names})。\nAction Input: 输入给工具的内容,使用JSON格式表示kwargs(例如{{\"input\": \"hello world\", \"num_beams\": 5}})\n \n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n切勿用Markdown代码标记包围你的响应。如果需要,可以在响应中使用代码标记。\n\n请为Action Input使用有效的JSON格式。不要这样做{{\'input\': \'hello world\', \'num_beams\': 5}}。\n\n如果使用此格式,用户将以下面的格式进行回应:\n\n \nObservation: 工具响应\n \n\n你应该继续重复上述格式,直到你有足够的信息来回答问题而无需使用更多工具。此时,你必须使用以下两种格式之一进行回答:\n\n \nThought: 我可以不用任何工具来回答。我将使用用户的语言来回答。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n \nThought: 我无法使用提供的工具回答问题。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n##如果从工具中得到的回应是Empty Response,那么只需要回答“我不知道”,不需要额外回答别的内容。## 当前对话\n\n以下是当前对话,由人类和助手的消息交替组成。\n""") + react_chat_formatter = ReActChatFormatter.from_defaults(prefix_messages) + agentrunner = AgentRunner.from_llm( llm=Settings.llm, tools=tools, + react_chat_formatter=react_chat_formatter, system_prompt=system_prompt, verbose=True, ) + return agentrunner # create the function calling worker for reasoning # worker = FunctionCallingAgentWorker.from_tools( # tools, verbose=True @@ -156,9 +59,3 @@ def get_chat_engine(filters=None, params=None): # # # wrap the worker in the top-level planner # return StructuredPlannerAgent(worker, tools) - - - - - - diff --git a/backend/app/engine/engine.py b/backend/app/engine/engine.py new file mode 100644 index 0000000..6cb552f --- /dev/null +++ b/backend/app/engine/engine.py @@ -0,0 +1,108 @@ +import os + +from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex +from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine +from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema +from llama_index.core.query_engine import RetrieverQueryEngine +from llama_index.core.response_synthesizers import ResponseMode +from llama_index.readers.database import DatabaseReader +from sqlalchemy import create_engine + +from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template +from app.engine.retriever.HybridRetriever import HybridRetriever +from app.settings import get_node_postprocessors + +def makeDescriptionByEngine(sql_database:SQLDatabase): + reader = DatabaseReader(sql_database) + + table_names = sql_database.get_usable_table_names() + table_schema_objs = [] + for table_name in table_names: + columns = sql_database.get_table_columns(table_name) + if len(columns) > 150: + continue + stats_txt = "" + + if table_name == 'gongchengshuxing': + stats_txt = '该表中有以下属性:' + documents = reader.load_data(query='select name from gongchengshuxing') + for index in range(len(documents) if len(documents) < 30 else 30): + if index == 0: + continue + elif index > 1: + stats_txt += ',' + stats_txt += documents[index].text.split(':')[1] + + tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt)) + table_schema_objs.append(tbSchema) + + return table_schema_objs + +def get_Retriever(index,**kwargs): + strEnableHybrid = os.getenv("HYBRID_ENABLED",'False') + bEnableHybrid = True if strEnableHybrid is not None and strEnableHybrid.title() == 'True' else False + if bEnableHybrid: + alpha = float(os.getenv("HYBRID_ALPHA", "0.5")) + retriever = HybridRetriever(index,alpha = alpha,**kwargs) + else: + retriever = index.as_retriever(**kwargs) + return retriever + + +sql_database = None +sql_obj_index = None + +# Create a sql query engine +def create_sql_query_engine(top_k=3, use_reranker=False, filters=None): + global sql_obj_index + global sql_database + if sql_obj_index is None or sql_database is None: + sqlengine = create_engine(os.getenv("SQL_DATABASE_URL", "")) + sql_database = SQLDatabase(sqlengine) + table_schema_objs = makeDescriptionByEngine(sql_database) + table_node_mapping = SQLTableNodeMapping(sql_database) + + sql_obj_index = ObjectIndex.from_objects( + table_schema_objs, + table_node_mapping, + index_cls=VectorStoreIndex, + ) + + # 创建SQL查询工具 + sql_query_engine = SQLTableRetrieverQueryEngine(sql_database, + sql_obj_index.as_retriever(similarity_top_k=top_k), + verbose=True, + ) + return sql_query_engine + +# Create a summary query engine +def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None): + summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None)) + summary_query_engine = summary_index.as_query_engine( + response_mode=ResponseMode.TREE_SUMMARIZE, + use_async=True, + streaming=True, + ) + return summary_query_engine + +# Create a query engine +def create_query_engine(index, top_k=3, use_reranker=False, filters=None): + # 创建向量检索查询工具 + postprocess = None + if use_reranker: + postprocess = get_node_postprocessors() + + query_engine = RetrieverQueryEngine.from_args( + get_Retriever(index, + similarity_top_k=top_k, + filters=filters), + text_qa_template=text_qa_template, + refine_template=refine_template, + summary_template = summary_template, + simple_template = simple_template, + node_postprocessors=postprocess, + use_async=True, + streaming=True, + ) + + return query_engine \ No newline at end of file diff --git a/backend/app/engine/generate.py b/backend/app/engine/generate.py index 115c175..87ecfa1 100644 --- a/backend/app/engine/generate.py +++ b/backend/app/engine/generate.py @@ -8,6 +8,7 @@ import os from app.engine.loaders import get_documents from app.engine.vectordb import get_vector_store from app.settings import init_settings +from app.engine.retriever.CHBM25Retriever import CHBM25Retriever from llama_index.core.ingestion import IngestionPipeline from llama_index.core.node_parser import SentenceSplitter from llama_index.core.settings import Settings @@ -58,6 +59,13 @@ def persist_storage(docstore, vector_store): storage_context.persist(STORAGE_DIR) +def persist_BMRetriever(vector_store): + STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm") + top_k = int(os.getenv("TOP_K", "3")) + bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=top_k,nodes=vector_store.get_nodes([])) + bmRetriver.persist(STORAGE_DIR) + + def generate_datasource(): init_settings() logger.info("Generate index for the provided data") @@ -75,6 +83,7 @@ def generate_datasource(): # Build the index and persist storage persist_storage(docstore, vector_store) + persist_BMRetriever(vector_store) logger.info("Finished generating the index") diff --git a/backend/app/engine/loaders/db.py b/backend/app/engine/loaders/db.py index 63a7c02..d6310e2 100644 --- a/backend/app/engine/loaders/db.py +++ b/backend/app/engine/loaders/db.py @@ -1,20 +1,14 @@ -import os import logging -from typing import List from typing import Any, List, Optional -from llama_index.core.readers.base import BaseReader -from llama_index.core.schema import Document -from llama_index.core.utilities.sql_wrapper import SQLDatabase -from sqlalchemy import text -from sqlalchemy.engine import Engine from llama_index.core import SQLDatabase, Document -from llama_index.core.objects import SQLTableSchema, SQLTableNodeMapping +from llama_index.core.objects import SQLTableSchema from llama_index.core.readers.base import BaseReader from llama_index.readers.database import DatabaseReader -from pydantic import BaseModel, validator -from llama_index.core.indices.vector_store import VectorStoreIndex +from pydantic import BaseModel from sqlalchemy import create_engine +from sqlalchemy import text +from sqlalchemy.engine import Engine logger = logging.getLogger(__name__) @@ -119,32 +113,6 @@ class DBLoaderConfig(BaseModel): uri: str queries: List[str] -def makeDescriptionByEngine(sql_database:SQLDatabase): - reader = DatabaseReader(sql_database) - - table_names = sql_database.get_usable_table_names() - table_schema_objs = [] - for table_name in table_names: - columns = sql_database.get_table_columns(table_name) - if len(columns) > 150: - continue - stats_txt = "" - - if table_name == 'gongchengshuxing': - stats_txt = '该表中有以下属性:' - documents = reader.load_data(query='select name from gongchengshuxing') - for index in range(len(documents) if len(documents) < 30 else 30): - if index == 0: - continue - elif index > 1: - stats_txt += ',' - stats_txt += documents[index].text.split(':')[1] - - tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt)) - table_schema_objs.append(tbSchema) - - return table_schema_objs - def get_db_documents(configs: list[DBLoaderConfig]): docs = [] diff --git a/backend/app/engine/prompt.py b/backend/app/engine/prompt.py new file mode 100644 index 0000000..101b6bf --- /dev/null +++ b/backend/app/engine/prompt.py @@ -0,0 +1,89 @@ +from llama_index.core import PromptTemplate + +text_qa_template_str = ( + "# 角色\n" + "你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。" + "你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答," + "如同直接从文件中提取的内容。\n" + + "## 技能\n" + "### 技能 1: 数据查询与提供\n" + "- 准确回答所有关于电力工程造价的相关问题。\n" + "- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n" + "- 确保提供的信息严格基于工程文档中的记录。\n" + + "### 技能 2: 技术性解释\n" + "- 解释造价工程中的技术术语和概念。\n" + "- 为复杂的工程细节提供清晰易懂的说明。\n" + + "## 约束\n" + "- 仅回答与电力工程造价文件相关的具体问题。\n" + "- 不进行任何超出文件内容的猜测或假设。\n" + "- 所有回答均基于文件内容,采用客观和技术性的语言。\n" + "- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n" + "以下为上下文信息\n" + "---------------------\n" + "{context_str}\n" + "---------------------\n" + "请根据上下文信息而非先前知识回答我的问题或回复我的指令。前面的上下文信息可能有用,也可能没用,你需要从我给出的上下文信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。我的问题或指令是什么语种,你就用什么语种回复。\n" + "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n" + + "问题:{query_str}\n" + "你的回复: " +) + + +text_qa_template = PromptTemplate(text_qa_template_str) + +refine_template_str = ( + "这是原本的问题: {query_str}\n" + "我们已经提供了回答: {existing_answer}\n" + "现在我们有机会改进这个回答 " + "使用以下更多上下文(仅当需要用时)\n" + "------------\n" + "{context_msg}\n" + "------------\n" + "根据新的上下文, 请改进原来的回答。" + "如果新的上下文没有用, 直接返回原本的回答。\n" + "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n" + "改进的回答: " +) +refine_template = PromptTemplate(refine_template_str) + +summary_template_str = ( + "# 角色\n" + "你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。" + "你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答," + "如同直接从文件中提取的内容。\n" + + "## 技能\n" + "### 技能 1: 数据查询与提供\n" + "- 准确回答所有关于电力工程造价的相关问题。\n" + "- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n" + "- 确保提供的信息严格基于工程文档中的记录。\n" + + "### 技能 2: 技术性解释\n" + "- 解释造价工程中的技术术语和概念。\n" + "- 为复杂的工程细节提供清晰易懂的说明。\n" + + "## 约束\n" + "- 仅回答与电力工程造价文件相关的具体问题。\n" + "- 不进行任何超出文件内容的猜测或假设。\n" + "- 所有回答均基于文件内容,采用客观和技术性的语言。\n" + "- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n" + "来自多个来源的上下文信息如下。\n" + "---------------------\n" + "{context_str}\n" + "---------------------\n" + "鉴于来自多个来源的信息而非先验知识, " + "回答查询。\n" + "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n" + "Query: {query_str}\n" + "Answer: " +) +summary_template = PromptTemplate(summary_template_str) + +simple_template_str = ( + "{query_str}" +) +simple_template = PromptTemplate(simple_template_str) diff --git a/backend/app/engine/retriever/CHBM25Retriever.py b/backend/app/engine/retriever/CHBM25Retriever.py new file mode 100644 index 0000000..fa5d5ec --- /dev/null +++ b/backend/app/engine/retriever/CHBM25Retriever.py @@ -0,0 +1,133 @@ +import json +import logging +import os + +from typing import Any, Callable, Dict, List, Optional, cast + +from llama_index.core.base.base_retriever import BaseRetriever +from llama_index.core.callbacks.base import CallbackManager +from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K +from llama_index.core.indices.vector_store.base import VectorStoreIndex +from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle +from llama_index.core.storage.docstore.types import BaseDocumentStore +from llama_index.core.vector_stores.utils import ( + node_to_metadata_dict, + metadata_dict_to_node, +) + +import bm25s +from app.engine.retriever.CHTokener import chTokenize + +CHDEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"} + +CHDEFAULT_PERSIST_FILENAME = "retriever.json" + +class CHBM25Retriever(BaseRetriever): + def __init__( + self, + nodes: Optional[List[BaseNode]] = None, + existing_bm25: Optional[bm25s.BM25] = None, + similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K, + callback_manager: Optional[CallbackManager] = None, + objects: Optional[List[IndexNode]] = None, + object_map: Optional[dict] = None, + verbose: bool = False, + ) -> None: + self.similarity_top_k = similarity_top_k + if existing_bm25 is not None: + self.bm25 = existing_bm25 + self.corpus = existing_bm25.corpus + else: + from nltk.corpus import stopwords + if nodes is None: + raise ValueError("Please pass nodes or an existing BM25 object.") + + self.corpus = [node_to_metadata_dict(node) for node in nodes] + + corpus_tokens = chTokenize( + [node.get_content() for node in nodes], + show_progress=verbose, + ) + self.bm25 = bm25s.BM25() + self.bm25.index(corpus_tokens, show_progress=verbose) + super().__init__( + callback_manager=callback_manager, + object_map=object_map, + objects=objects, + verbose=verbose, + ) + + @classmethod + def from_defaults( + cls, + index: Optional[VectorStoreIndex] = None, + nodes: Optional[List[BaseNode]] = None, + docstore: Optional[BaseDocumentStore] = None, + similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K, + verbose: bool = False, + ) -> "CHBM25Retriever": + if sum(bool(val) for val in [index, nodes, docstore]) != 1: + raise ValueError("Please pass exactly one of index, nodes, or docstore.") + + if index is not None: + docstore = index.docstore + + if docstore is not None: + nodes = cast(List[BaseNode], list(docstore.docs.values())) + + assert ( + nodes is not None + ), "Please pass exactly one of index, nodes, or docstore." + + return cls( + nodes=nodes, + similarity_top_k=similarity_top_k, + verbose=verbose, + ) + + def get_persist_args(self) -> Dict[str, Any]: + """Get Persist Args Dict to Save.""" + return { + CHDEFAULT_PERSIST_ARGS[key]: getattr(self, key) + for key in CHDEFAULT_PERSIST_ARGS + if hasattr(self, key) + } + + def persist(self, path: str, **kwargs: Any) -> None: + """Persist the retriever to a directory.""" + self.bm25.save(path, corpus=self.corpus, **kwargs) + with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME), "w") as f: + json.dump(self.get_persist_args(), f, indent=2) + + @classmethod + def from_persist_dir(cls, path: str, **kwargs: Any) -> "CHBM25Retriever": + """Load the retriever from a directory.""" + bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs) + with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME)) as f: + retriever_data = json.load(f) + return cls(existing_bm25=bm25, **retriever_data) + + def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + query = query_bundle.query_str + tokenized_query = chTokenize( + query,show_progress=self._verbose + ) + indexes, scores = self.bm25.retrieve( + tokenized_query, k=self.similarity_top_k, show_progress=self._verbose + ) + + # batched, but only one query + indexes = indexes[0] + scores = scores[0] + + nodes: List[NodeWithScore] = [] + for idx, score in zip(indexes, scores): + # idx can be an int or a dict of the node + if isinstance(idx, dict): + node = metadata_dict_to_node(idx) + else: + node_dict = self.corpus[int(idx)] + node = metadata_dict_to_node(node_dict) + nodes.append(NodeWithScore(node=node, score=float(score))) + + return nodes \ No newline at end of file diff --git a/backend/app/engine/retriever/CHTokener.py b/backend/app/engine/retriever/CHTokener.py new file mode 100644 index 0000000..9c5a071 --- /dev/null +++ b/backend/app/engine/retriever/CHTokener.py @@ -0,0 +1,46 @@ +from typing import Any, Dict, List, Union, Callable, NamedTuple +from bm25s.tokenization import * + +try: + from tqdm.auto import tqdm +except ImportError: + + def tqdm(iterable, *args, **kwargs): + return iterable + + +def chinese_tokenizer(text: str) -> List[str]: + import jieba + from nltk.corpus import stopwords + tokens = jieba.lcut(text) + return [token for token in tokens if token not in stopwords.words('chinese')] + +def chTokenize( + texts, + show_progress: bool = True, + leave: bool = False, +) -> Union[List[List[str]], Tokenized]: + if isinstance(texts, str): + texts = [texts] + + corpus_ids = [] + token_to_index = {} + + for text in tqdm( + texts, desc="Split strings", leave=leave, disable=not show_progress + ): + + splitted = chinese_tokenizer(text) + doc_ids = [] + + for token in splitted: + if token not in token_to_index: + token_to_index[token] = len(token_to_index) + + token_id = token_to_index[token] + doc_ids.append(token_id) + + corpus_ids.append(doc_ids) + + return Tokenized(ids=corpus_ids, vocab=token_to_index) + diff --git a/backend/app/engine/retriever/HybridRetriever.py b/backend/app/engine/retriever/HybridRetriever.py new file mode 100644 index 0000000..4bf0b8d --- /dev/null +++ b/backend/app/engine/retriever/HybridRetriever.py @@ -0,0 +1,67 @@ +import os +from typing import Optional, Any, Dict, List + +from llama_index.core.base.base_retriever import BaseRetriever +from llama_index.core.schema import NodeWithScore, QueryBundle + +from app.engine.retriever.CHBM25Retriever import CHBM25Retriever + + +class HybridRetriever(BaseRetriever): + def __init__( + self, + vector_index, + similarity_top_k: int = 2, + out_top_k: Optional[int] = None, + alpha: float = 0.5, + filters = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self._vector_index = vector_index + self._embed_model = vector_index._embed_model + self._out_top_k = out_top_k or similarity_top_k + self._vecRetriever = vector_index.as_retriever( + similarity_top_k=similarity_top_k,filters = filters + ) + + STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm") + if os.path.exists(STORAGE_DIR) and len(os.listdir(STORAGE_DIR)) > 0: + self._bm25Retriever = CHBM25Retriever.from_persist_dir(STORAGE_DIR) + else: + bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=similarity_top_k,nodes=self._vector_index.vector_store.get_nodes(None)) + bmRetriver.persist(STORAGE_DIR) + self._alpha = alpha + + + + def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + vecNodes:List[NodeWithScore] = self._vecRetriever.retrieve(query_bundle.query_str) + bmNodes:List[NodeWithScore] = self._bm25Retriever.retrieve(query_bundle.query_str) + + bmDic:Dict[str,NodeWithScore] = {} + for node in bmNodes: + bmDic[node.node_id] = node + + result_tups = [] + for i in range(len(vecNodes)): + node = vecNodes[i] + bmScore = 0.0 + if node.node_id in bmDic: + bmScore = bmDic[node.node_id].score + bmDic.pop(node.node_id) + else: + bmScore = 0.0 + full_similarity = (self._alpha * node.score) + ( + (1 - self._alpha) * bmScore + ) + result_tups.append((full_similarity, node)) + + for _,node in bmDic.items(): + full_similarity = (1 - self._alpha) * node.score + result_tups.append((full_similarity, node)) + + result_tups = sorted(result_tups, key=lambda x: x[0], reverse=True) + for full_score, node in result_tups: + node.score = full_score + return [n for _, n in result_tups][:self._out_top_k] \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index 5f84c63..0f5e9ad 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,3 +1,4 @@ + from dotenv import load_dotenv from llama_index.core.node_parser import SentenceSplitter @@ -16,11 +17,14 @@ from app.observability import init_observability from fastapi.staticfiles import StaticFiles from phoenix.trace import using_project + logger = logging.getLogger("uvicorn") + usPrj = using_project(os.getenv("PHOENIX_PROJECT_NAME")) usPrj.__enter__() + init_settings() init_observability() @@ -52,12 +56,10 @@ mount_static_files("data_output", "/api/files/output") app.include_router(chat_router, prefix="/api/chat") app.include_router(file_upload_router, prefix="/api/chat/upload") -# Redirect to documentation page when accessing base URL @app.get("/") async def redirect_to_docs(): return RedirectResponse(url="/docs") -SentenceSplitter if __name__ == "__main__": app_host = os.getenv("APP_HOST", "0.0.0.0") app_port = int(os.getenv("APP_PORT", "8000")) @@ -65,4 +67,3 @@ if __name__ == "__main__": reload = False uvicorn.run(app="main:app", host=app_host, port=app_port, reload=reload) - #usPrj.__exit__() diff --git a/backend/pyproject.toml b/backend/pyproject.toml index bb8fcf2..de1fbbb 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -18,6 +18,7 @@ llama-index = "0.10.63" cachetools = "^5.3.3" protobuf = "4.25.4" nltk = "^3.8.2" +jieba = "^0.42.1" #arize-phoenix = "^4.12.0" openinference-instrumentation-llama-index="2.2.3"