Compare commits

18 Commits

Author SHA1 Message Date
ly 9ee24627c2 Merge pull request 'dev' (#2) from dev into main
Reviewed-on: #2
2024-08-23 09:37:06 +08:00
chentianrui 5fc8375a06 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-23 08:55:54 +08:00
chentianrui cf1ed4e71d 解决输出频繁出现'''的问题 2024-08-23 08:53:13 +08:00
ly 8050551a53 调整创建SQL引擎函数名称 2024-08-22 21:21:37 +08:00
ly 513ce73190 Merge branch 'dev' of https://git.97id.com/ly/zjdataai-app into dev 2024-08-22 21:18:37 +08:00
chentianrui 48d10fd1f3 修复了重排的参数问题 2024-08-22 19:40:56 +08:00
ly 9cbe414a0c 调整参数顺序 2024-08-22 17:10:32 +08:00
chentianrui 4c1c67aa50 增加开启了混合检索 2024-08-22 17:06:28 +08:00
chentianrui 59ef831a41 修改了提示词 2024-08-22 16:36:37 +08:00
ly 3ceb30c375 修复缺陷 2024-08-22 16:17:10 +08:00
ly e71da586e3 修复缺陷 2024-08-22 16:02:07 +08:00
ly b3a575d158 调整代码结构,同时修改重定义提示词的方式。 2024-08-22 15:39:49 +08:00
chentianrui db006985d7 修改了提示词,约束模型回答 2024-08-22 15:24:29 +08:00
wanyaokun 870af69189 新增包依赖 2024-08-22 12:09:15 +08:00
wanyaokun 3460b8410e 新增关键字缓存路径 2024-08-22 12:06:43 +08:00
wanyaokun 586bb76c9c 新增关键字检索缓存路径 2024-08-22 11:09:16 +08:00
wanyaokun 8d7190d0b6 新增关键字检索类 2024-08-22 11:07:23 +08:00
wanyaokun 043aea6cca 新增自定义关键词检索类 2024-08-22 11:06:22 +08:00
12 changed files with 485 additions and 164 deletions
+1
View File
@@ -49,6 +49,7 @@ VECTOR_STORE_COLLECTION=default
# Specify this if you are using a local vector database. # Specify this if you are using a local vector database.
# Otherwise, use VECTOR_STORE__HOST and VECTOR_STORE__PORT config above # Otherwise, use VECTOR_STORE__HOST and VECTOR_STORE__PORT config above
VECTOR_STORE_PATH=./storage_vector VECTOR_STORE_PATH=./storage_vector
BM_RETRIEVER_PATH =./storage_bm
+6 -5
View File
@@ -6,12 +6,13 @@ SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zj
# The number of similar embeddings to return when retrieving documents. # The number of similar embeddings to return when retrieving documents.
TOP_K=10 TOP_K=10
#-------------------------- #--------------------------
# 是否启用混合检索
HYBRID_ENABLED = true
# 混合检索阈值
HYBRID_ALPHA = 0.6
#--------------------------
# 是否启用检索重排功能 # 是否启用检索重排功能
RERANK_ENABLED=true RERANK_ENABLED=true
# 是否启用混合检索
HYBRID_ENABLED = false
# 混合检索阈值
HYBRID_ALPHA = 0.5
# Rerank model # Rerank model
RERANK_MODEL=bge-reranker-v2-m3 RERANK_MODEL=bge-reranker-v2-m3
RERANK_BASE_URL=http://10.1.16.39:9995 RERANK_BASE_URL=http://10.1.16.39:9995
@@ -80,7 +81,7 @@ VECTOR_STORE_COLLECTION=default
# Specify this if you are using a local vector database. # Specify this if you are using a local vector database.
# Otherwise, use VECTOR_STORE__HOST and VECTOR_STORE__PORT config above # Otherwise, use VECTOR_STORE__HOST and VECTOR_STORE__PORT config above
VECTOR_STORE_PATH=./storage_vector VECTOR_STORE_PATH=./storage_vector
BM_RETRIEVER_PATH =./storage_bm
PHOENIX_API_KEY=123456 PHOENIX_API_KEY=123456
+17 -120
View File
@@ -1,154 +1,57 @@
import os import os
from llama_index.core import SQLDatabase, SummaryIndex, VectorStoreIndex from llama_index.core.agent import AgentRunner, ReActChatFormatter
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from llama_index.core.settings import Settings from llama_index.core.settings import Settings
from llama_index.core.agent import AgentRunner, StructuredPlannerAgent, FunctionCallingAgentWorker
from llama_index.core.tools.query_engine import QueryEngineTool from llama_index.core.tools.query_engine import QueryEngineTool
from sqlalchemy import create_engine, Engine
from app.engine.loaders.db import makeDescriptionByEngine from app.engine.engine import create_query_engine, create_summary_query_engine, create_sql_query_engine
from app.engine.tools import ToolFactory
from app.engine.index import get_index from app.engine.index import get_index
from app.settings import get_node_postprocessors #from app.engine.loaders.db import makeDescriptionByEngine
from app.engine.tools import ToolFactory
from llama_index.core.retrievers import BaseRetriever
from llama_index.core import QueryBundle
from llama_index.core.schema import NodeWithScore
from typing import List, Any, Optional,Dict
from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine
class HybridRetriever(BaseRetriever):
    """Retriever that blends vector-similarity hits with BM25 keyword hits.

    A node's final score is ``alpha * vector_score + (1 - alpha) * bm25_score``.
    A node returned by only one of the two retrievers gets 0.0 for the
    missing side.
    """

    def __init__(
        self,
        vector_index,
        similarity_top_k: int = 2,
        out_top_k: Optional[int] = None,
        alpha: float = 0.5,
        filters = None,
        **kwargs: Any,
    ) -> None:
        """Build the vector retriever and an in-memory BM25 retriever.

        Args:
            vector_index: Index providing ``as_retriever`` and a ``vector_store``.
            similarity_top_k: Number of hits requested from each sub-retriever.
            out_top_k: Number of fused hits returned; falls back to
                ``similarity_top_k`` when falsy.
            alpha: Weight of the vector score; BM25 gets ``1 - alpha``.
            filters: Forwarded to the vector retriever only.
            **kwargs: Forwarded to ``BaseRetriever.__init__``.
        """
        # Imported lazily, exactly as the original did.
        from llama_index.retrievers.bm25 import BM25Retriever
        from nltk.corpus import stopwords

        super().__init__(**kwargs)
        self._vector_index = vector_index
        self._embed_model = vector_index._embed_model
        self._out_top_k = out_top_k or similarity_top_k
        self._vecRetriever = vector_index.as_retriever(
            similarity_top_k=similarity_top_k,
            filters=filters,
        )
        # The keyword index is rebuilt from every node held by the vector store.
        all_nodes = self._vector_index.vector_store.get_nodes(None)
        self._bm25Retriever = BM25Retriever.from_defaults(
            similarity_top_k=similarity_top_k,
            nodes=all_nodes,
            language=stopwords.words('chinese'),
        )
        self._alpha = alpha

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Run both retrievers on the query text and return the fused top hits."""
        query_text = query_bundle.query_str
        dense_hits: List[NodeWithScore] = self._vecRetriever.retrieve(query_text)
        sparse_hits: List[NodeWithScore] = self._bm25Retriever.retrieve(query_text)

        # Index keyword hits by node id; entries are removed once merged.
        sparse_by_id: Dict[str, NodeWithScore] = {
            hit.node_id: hit for hit in sparse_hits
        }

        scored = []
        for hit in dense_hits:
            match = sparse_by_id.pop(hit.node_id, None)
            keyword_score = match.score if match is not None else 0.0
            combined = (self._alpha * hit.score) + (
                (1 - self._alpha) * keyword_score
            )
            scored.append((combined, hit))

        # Whatever is left was found by BM25 only.
        for hit in sparse_by_id.values():
            scored.append(((1 - self._alpha) * hit.score, hit))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        for combined, hit in scored:
            hit.score = combined
        return [hit for _, hit in scored][:self._out_top_k]
def get_Retriever(index,**kwargs):
    """Return a hybrid or plain vector retriever depending on the environment.

    Args:
        index: Object exposing ``as_retriever(**kwargs)``.
        **kwargs: Forwarded unchanged to the chosen retriever.

    Returns:
        A ``HybridRetriever`` when ``HYBRID_ENABLED`` is "true" (any casing),
        otherwise ``index.as_retriever(**kwargs)``.
    """
    # The default must be a string: the previous default of ``False`` crashed on
    # ``.title()`` whenever HYBRID_ENABLED was not set.
    hybrid_flag = os.getenv("HYBRID_ENABLED", "False")
    if hybrid_flag.strip().lower() == "true":
        alpha = float(os.getenv("HYBRID_ALPHA", "0.5"))
        return HybridRetriever(index, alpha=alpha, **kwargs)
    return index.as_retriever(**kwargs)
sql_database = None
sql_obj_index = None
def get_chat_engine(filters=None, params=None): def get_chat_engine(filters=None, params=None):
system_prompt = os.getenv("SYSTEM_PROMPT") system_prompt = os.getenv("SYSTEM_PROMPT")
top_k = int(os.getenv("TOP_K", "3")) top_k = int(os.getenv("TOP_K", "3"))
use_reranker = os.getenv("RERANK_ENABLED")
tools = [] tools = []
global sql_obj_index
global sql_database
if sql_obj_index is None:
sqlengine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
sql_database = SQLDatabase(sqlengine)
table_schema_objs = makeDescriptionByEngine(sql_database)
table_node_mapping = SQLTableNodeMapping(sql_database)
sql_obj_index = ObjectIndex.from_objects(
table_schema_objs,
table_node_mapping,
index_cls=VectorStoreIndex,
)
# 创建SQL查询工具 # 创建SQL查询工具
sql_query_engine = SQLTableRetrieverQueryEngine(sql_database, sql_query_engine = create_sql_query_engine()
sql_obj_index.as_retriever(similarity_top_k=top_k),
verbose=True,)
sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine, sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine,
name="zjdata_query_tool", name="zjdata_query_tool",
description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具") description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具"
)
#tools.append(sql_query_tool)
# Add query tool if index exists # Add query tool if index exists
index = get_index() index = get_index()
if index is not None: if index is not None:
summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None)) summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters)
summary_query_engine = summary_index.as_query_engine()
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool", summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
description="适用于任何需要进行全面总结、概括的要求。", description="适用于任何需要进行全面总结、概括的要求。",
#description="适用于任何需要对所有内容进行全面总结的请求。有关电力造价领域更具体部分的问题,请使用zj_query_engine_tool",
) )
query_engine = create_query_engine(index,top_k,use_reranker,filters)
# 创建向量检索查询工具
postprocess = get_node_postprocessors()
query_engine = RetrieverQueryEngine.from_args(
get_Retriever(index,similarity_top_k=top_k,
filters=filters),
node_postprocessors=postprocess,
)
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool", query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。如果你不知道答案,就说你不知道,不要编造答案。", description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
) )
tools.append(summary_query_tool) tools.append(summary_query_tool)
tools.append(query_engine_tool) tools.append(query_engine_tool)
#tools.append(sql_query_tool)
# Add additional tools # Add additional tools
tools += ToolFactory.from_env() tools += ToolFactory.from_env()
return AgentRunner.from_llm( prefix_messages = ("""您的设计旨在帮助完成各种任务,从回答问题到提供其他类型分析的摘要。\n\n##工具\n\n你可以访问各种工具。你有责任按照你认为合适的顺序使用这些工具来完成当前的任务。\n这可能需要将任务分解为子任务,并使用不同的工具来完成每个子任务。\n\n你可以访问以下工具:\n{tool_desc}\n\n\n##输出格式\n\n请用与问题相同的语言回答,并使用以下格式:\n\n \nThought: 用户当前的语言是:(user's language)。我需要使用工具来帮助我回答问题。\nAction: 如果使用工具,则为工具名称(one of {tool_names})。\nAction Input: 输入给工具的内容,使用JSON格式表示kwargs(例如{{\"input\": \"hello world\", \"num_beams\": 5}}\n \n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n切勿用Markdown代码标记包围你的响应。如果需要,可以在响应中使用代码标记。\n\n请为Action Input使用有效的JSON格式。不要这样做{{\'input\': \'hello world\', \'num_beams\': 5}}。\n\n如果使用此格式,用户将以下面的格式进行回应:\n\n \nObservation: 工具响应\n \n\n你应该继续重复上述格式,直到你有足够的信息来回答问题而无需使用更多工具。此时,你必须使用以下两种格式之一进行回答:\n\n \nThought: 我可以不用任何工具来回答。我将使用用户的语言来回答。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n \nThought: 我无法使用提供的工具回答问题。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n##如果从工具中得到的回应是Empty Response,那么只需要回答“我不知道”,不需要额外回答别的内容。## 当前对话\n\n以下是当前对话,由人类和助手的消息交替组成。\n""")
react_chat_formatter = ReActChatFormatter.from_defaults(prefix_messages)
agentrunner = AgentRunner.from_llm(
llm=Settings.llm, llm=Settings.llm,
tools=tools, tools=tools,
react_chat_formatter=react_chat_formatter,
system_prompt=system_prompt, system_prompt=system_prompt,
verbose=True, verbose=True,
) )
return agentrunner
# create the function calling worker for reasoning # create the function calling worker for reasoning
# worker = FunctionCallingAgentWorker.from_tools( # worker = FunctionCallingAgentWorker.from_tools(
# tools, verbose=True # tools, verbose=True
@@ -156,9 +59,3 @@ def get_chat_engine(filters=None, params=None):
# #
# # wrap the worker in the top-level planner # # wrap the worker in the top-level planner
# return StructuredPlannerAgent(worker, tools) # return StructuredPlannerAgent(worker, tools)
+108
View File
@@ -0,0 +1,108 @@
import os
from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
from app.engine.retriever.HybridRetriever import HybridRetriever
from app.settings import get_node_postprocessors
def makeDescriptionByEngine(sql_database:SQLDatabase):
    """Build one ``SQLTableSchema`` per usable table of ``sql_database``.

    Tables with more than 150 columns are skipped. For the table
    ``gongchengshuxing`` the context string lists values read from its
    ``name`` column (rows 1..29 of the query result; row 0 is skipped).
    Every other table gets an empty context string.
    """
    db_reader = DatabaseReader(sql_database)
    schemas = []
    for name in sql_database.get_usable_table_names():
        if len(sql_database.get_table_columns(name)) > 150:
            continue

        context = ""
        if name == 'gongchengshuxing':
            rows = db_reader.load_data(query='select name from gongchengshuxing')
            # Each document text is split on ':' and the second piece is kept.
            values = [row.text.split(':')[1] for row in rows[1:30]]
            context = '该表中有以下属性:' + ','.join(values)

        schemas.append(SQLTableSchema(table_name=name, context_str=context))
    return schemas
def get_Retriever(index,**kwargs):
    """Pick the retriever implementation from the ``HYBRID_ENABLED`` setting.

    Returns ``index.as_retriever(**kwargs)`` unless the variable title-cases
    to ``'True'``; in that case a ``HybridRetriever`` is built with the weight
    read from ``HYBRID_ALPHA`` (default ``"0.5"``).
    """
    flag = os.getenv("HYBRID_ENABLED", 'False')
    if flag is None or flag.title() != 'True':
        return index.as_retriever(**kwargs)

    weight = float(os.getenv("HYBRID_ALPHA", "0.5"))
    return HybridRetriever(index, alpha=weight, **kwargs)
# Process-wide cache: the SQL database wrapper and the table-schema object
# index are built on first use and reused by later calls.
sql_database = None
sql_obj_index = None


# Create a sql query engine
def create_sql_query_engine(top_k=3, use_reranker=False, filters=None):
    """Return a ``SQLTableRetrieverQueryEngine`` over ``SQL_DATABASE_URL``.

    Args:
        top_k: ``similarity_top_k`` for the table-schema retriever.
        use_reranker: Accepted for signature parity; not used here.
        filters: Accepted for signature parity; not used here.
    """
    global sql_database, sql_obj_index

    cache_ready = sql_obj_index is not None and sql_database is not None
    if not cache_ready:
        sql_database = SQLDatabase(create_engine(os.getenv("SQL_DATABASE_URL", "")))
        schemas = makeDescriptionByEngine(sql_database)
        sql_obj_index = ObjectIndex.from_objects(
            schemas,
            SQLTableNodeMapping(sql_database),
            index_cls=VectorStoreIndex,
        )

    # Build the SQL query engine on top of the cached database and index.
    table_retriever = sql_obj_index.as_retriever(similarity_top_k=top_k)
    return SQLTableRetrieverQueryEngine(
        sql_database,
        table_retriever,
        verbose=True,
    )
# Create a summary query engine
def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None):
    """Return a tree-summarize query engine over every node of ``index``.

    ``top_k``, ``use_reranker`` and ``filters`` are accepted for signature
    parity with the other factory functions and are not used.
    """
    all_nodes = index.vector_store.get_nodes(node_ids=None)
    summary_index = SummaryIndex(all_nodes)
    engine_options = {
        "response_mode": ResponseMode.TREE_SUMMARIZE,
        "use_async": True,
        "streaming": True,
    }
    return summary_index.as_query_engine(**engine_options)
# Create a query engine
def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
    """Return a streaming retriever query engine using the custom prompts.

    Args:
        index: Vector index handed to ``get_Retriever``.
        top_k: ``similarity_top_k`` for the retriever.
        use_reranker: Enables the node postprocessors. Accepts a bool or the
            raw text of an environment variable ("true"/"1"/"yes"/"on",
            any casing, count as enabled).
        filters: Metadata filters forwarded to the retriever.

    Returns:
        The engine built by ``RetrieverQueryEngine.from_args``.
    """
    # Callers pass os.getenv("RERANK_ENABLED") straight through; a plain
    # truthiness test would treat the string "false" as enabled.
    if isinstance(use_reranker, str):
        rerank_enabled = use_reranker.strip().lower() in ("true", "1", "yes", "on")
    else:
        rerank_enabled = bool(use_reranker)

    postprocess = get_node_postprocessors() if rerank_enabled else None

    # Build the vector (or hybrid) retrieval query engine.
    retriever = get_Retriever(index, similarity_top_k=top_k, filters=filters)
    query_engine = RetrieverQueryEngine.from_args(
        retriever,
        text_qa_template=text_qa_template,
        refine_template=refine_template,
        summary_template=summary_template,
        simple_template=simple_template,
        node_postprocessors=postprocess,
        use_async=True,
        streaming=True,
    )
    return query_engine
+9
View File
@@ -8,6 +8,7 @@ import os
from app.engine.loaders import get_documents from app.engine.loaders import get_documents
from app.engine.vectordb import get_vector_store from app.engine.vectordb import get_vector_store
from app.settings import init_settings from app.settings import init_settings
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
from llama_index.core.ingestion import IngestionPipeline from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.settings import Settings from llama_index.core.settings import Settings
@@ -58,6 +59,13 @@ def persist_storage(docstore, vector_store):
storage_context.persist(STORAGE_DIR) storage_context.persist(STORAGE_DIR)
def persist_BMRetriever(vector_store):
    """Build a ``CHBM25Retriever`` from the vector store's nodes and save it.

    The target directory comes from ``BM_RETRIEVER_PATH`` (default
    ``storage_bm``) and the retriever's top-k from ``TOP_K`` (default 3).
    """
    target_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
    result_limit = int(os.getenv("TOP_K", "3"))
    keyword_retriever = CHBM25Retriever.from_defaults(
        similarity_top_k=result_limit,
        nodes=vector_store.get_nodes([]),
    )
    keyword_retriever.persist(target_dir)
def generate_datasource(): def generate_datasource():
init_settings() init_settings()
logger.info("Generate index for the provided data") logger.info("Generate index for the provided data")
@@ -75,6 +83,7 @@ def generate_datasource():
# Build the index and persist storage # Build the index and persist storage
persist_storage(docstore, vector_store) persist_storage(docstore, vector_store)
persist_BMRetriever(vector_store)
logger.info("Finished generating the index") logger.info("Finished generating the index")
+4 -36
View File
@@ -1,20 +1,14 @@
import os
import logging import logging
from typing import List
from typing import Any, List, Optional from typing import Any, List, Optional
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from sqlalchemy import text
from sqlalchemy.engine import Engine
from llama_index.core import SQLDatabase, Document from llama_index.core import SQLDatabase, Document
from llama_index.core.objects import SQLTableSchema, SQLTableNodeMapping from llama_index.core.objects import SQLTableSchema
from llama_index.core.readers.base import BaseReader from llama_index.core.readers.base import BaseReader
from llama_index.readers.database import DatabaseReader from llama_index.readers.database import DatabaseReader
from pydantic import BaseModel, validator from pydantic import BaseModel
from llama_index.core.indices.vector_store import VectorStoreIndex
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.engine import Engine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -119,32 +113,6 @@ class DBLoaderConfig(BaseModel):
uri: str uri: str
queries: List[str] queries: List[str]
def makeDescriptionByEngine(sql_database:SQLDatabase):
    """Build one ``SQLTableSchema`` per usable table of ``sql_database``.

    Tables with more than 150 columns are skipped. For the table
    ``gongchengshuxing`` the context string lists values read from its
    ``name`` column (rows 1..29 of the query result; row 0 is skipped).
    Every other table gets an empty context string.
    """
    db_reader = DatabaseReader(sql_database)
    schemas = []
    for name in sql_database.get_usable_table_names():
        if len(sql_database.get_table_columns(name)) > 150:
            continue

        context = ""
        if name == 'gongchengshuxing':
            rows = db_reader.load_data(query='select name from gongchengshuxing')
            # Each document text is split on ':' and the second piece is kept.
            values = [row.text.split(':')[1] for row in rows[1:30]]
            context = '该表中有以下属性:' + ','.join(values)

        schemas.append(SQLTableSchema(table_name=name, context_str=context))
    return schemas
def get_db_documents(configs: list[DBLoaderConfig]): def get_db_documents(configs: list[DBLoaderConfig]):
docs = [] docs = []
+89
View File
@@ -0,0 +1,89 @@
from llama_index.core import PromptTemplate
# Question-answering prompt, passed to the retriever query engine as
# ``text_qa_template``. Layout: role / skills / constraints preamble, then the
# retrieved context ({context_str}), then the user question ({query_str}).
# The text is runtime data sent to the model and must not be reformatted.
text_qa_template_str = (
    "# 角色\n"
    "你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
    "你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
    "如同直接从文件中提取的内容。\n"
    "## 技能\n"
    "### 技能 1: 数据查询与提供\n"
    "- 准确回答所有关于电力工程造价的相关问题。\n"
    "- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n"
    "- 确保提供的信息严格基于工程文档中的记录。\n"
    "### 技能 2: 技术性解释\n"
    "- 解释造价工程中的技术术语和概念。\n"
    "- 为复杂的工程细节提供清晰易懂的说明。\n"
    "## 约束\n"
    "- 仅回答与电力工程造价文件相关的具体问题。\n"
    "- 不进行任何超出文件内容的猜测或假设。\n"
    "- 所有回答均基于文件内容,采用客观和技术性的语言。\n"
    "- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n"
    "以下为上下文信息\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "请根据上下文信息而非先前知识回答我的问题或回复我的指令。前面的上下文信息可能有用,也可能没用,你需要从我给出的上下文信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。我的问题或指令是什么语种,你就用什么语种回复。\n"
    "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
    "问题:{query_str}\n"
    "你的回复: "
)
# PromptTemplate wrapper around the raw string above.
text_qa_template = PromptTemplate(text_qa_template_str)
# Refinement prompt, passed to the retriever query engine as
# ``refine_template``. Placeholders: {query_str} (original question),
# {existing_answer} (answer so far) and {context_msg} (additional context).
refine_template_str = (
    "这是原本的问题: {query_str}\n"
    "我们已经提供了回答: {existing_answer}\n"
    "现在我们有机会改进这个回答 "
    "使用以下更多上下文(仅当需要用时)\n"
    "------------\n"
    "{context_msg}\n"
    "------------\n"
    "根据新的上下文, 请改进原来的回答。"
    "如果新的上下文没有用, 直接返回原本的回答。\n"
    "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
    "改进的回答: "
)
# PromptTemplate wrapper around the raw string above.
refine_template = PromptTemplate(refine_template_str)
# Summarization prompt, passed to the retriever query engine as
# ``summary_template``. It reuses the role / skills / constraints preamble of
# the QA prompt and then presents context gathered from multiple sources
# ({context_str}) followed by the query ({query_str}).
summary_template_str = (
    "# 角色\n"
    "你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
    "你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
    "如同直接从文件中提取的内容。\n"
    "## 技能\n"
    "### 技能 1: 数据查询与提供\n"
    "- 准确回答所有关于电力工程造价的相关问题。\n"
    "- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n"
    "- 确保提供的信息严格基于工程文档中的记录。\n"
    "### 技能 2: 技术性解释\n"
    "- 解释造价工程中的技术术语和概念。\n"
    "- 为复杂的工程细节提供清晰易懂的说明。\n"
    "## 约束\n"
    "- 仅回答与电力工程造价文件相关的具体问题。\n"
    "- 不进行任何超出文件内容的猜测或假设。\n"
    "- 所有回答均基于文件内容,采用客观和技术性的语言。\n"
    "- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n"
    "来自多个来源的上下文信息如下。\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "鉴于来自多个来源的信息而非先验知识, "
    "回答查询。\n"
    "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
    "Query: {query_str}\n"
    "Answer: "
)
# PromptTemplate wrapper around the raw string above.
summary_template = PromptTemplate(summary_template_str)
# Pass-through prompt, passed to the retriever query engine as
# ``simple_template``: the query text is sent as-is with no added wording.
simple_template_str = (
    "{query_str}"
)
# PromptTemplate wrapper around the raw string above.
simple_template = PromptTemplate(simple_template_str)
@@ -0,0 +1,133 @@
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, cast
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.indices.vector_store.base import VectorStoreIndex
from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
from llama_index.core.storage.docstore.types import BaseDocumentStore
from llama_index.core.vector_stores.utils import (
node_to_metadata_dict,
metadata_dict_to_node,
)
import bm25s
from app.engine.retriever.CHTokener import chTokenize
CHDEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}
CHDEFAULT_PERSIST_FILENAME = "retriever.json"
class CHBM25Retriever(BaseRetriever):
    """BM25 keyword retriever that tokenizes text with ``chTokenize``.

    The retriever either indexes a list of nodes with ``bm25s`` or wraps an
    already-built ``bm25s.BM25`` object (for example one loaded from disk).
    ``self.corpus`` holds the node metadata dicts used to rebuild nodes from
    the indexes returned by a search.
    """

    def __init__(
        self,
        nodes: Optional[List[BaseNode]] = None,
        existing_bm25: Optional[bm25s.BM25] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        """Index ``nodes`` or adopt ``existing_bm25``.

        Raises:
            ValueError: If neither ``nodes`` nor ``existing_bm25`` is given.
        """
        self.similarity_top_k = similarity_top_k

        if existing_bm25 is not None:
            self.bm25 = existing_bm25
            self.corpus = existing_bm25.corpus
        else:
            if nodes is None:
                raise ValueError("Please pass nodes or an existing BM25 object.")

            # Store serialized nodes so search results can be turned back
            # into nodes, then index the tokenized node contents.
            self.corpus = [node_to_metadata_dict(node) for node in nodes]
            corpus_tokens = chTokenize(
                [node.get_content() for node in nodes],
                show_progress=verbose,
            )
            self.bm25 = bm25s.BM25()
            self.bm25.index(corpus_tokens, show_progress=verbose)

        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    @classmethod
    def from_defaults(
        cls,
        index: Optional[VectorStoreIndex] = None,
        nodes: Optional[List[BaseNode]] = None,
        docstore: Optional[BaseDocumentStore] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        verbose: bool = False,
    ) -> "CHBM25Retriever":
        """Create a retriever from exactly one of index, nodes, or docstore.

        Raises:
            ValueError: If zero or more than one source is truthy. An empty
                node list counts as "not given".
        """
        if sum(bool(val) for val in [index, nodes, docstore]) != 1:
            raise ValueError("Please pass exactly one of index, nodes, or docstore.")

        if index is not None:
            docstore = index.docstore

        if docstore is not None:
            nodes = cast(List[BaseNode], list(docstore.docs.values()))

        assert (
            nodes is not None
        ), "Please pass exactly one of index, nodes, or docstore."

        return cls(
            nodes=nodes,
            similarity_top_k=similarity_top_k,
            verbose=verbose,
        )

    def get_persist_args(self) -> Dict[str, Any]:
        """Get Persist Args Dict to Save."""
        # Maps attribute names on the instance to constructor keyword names.
        return {
            CHDEFAULT_PERSIST_ARGS[key]: getattr(self, key)
            for key in CHDEFAULT_PERSIST_ARGS
            if hasattr(self, key)
        }

    def persist(self, path: str, **kwargs: Any) -> None:
        """Persist the retriever to a directory."""
        self.bm25.save(path, corpus=self.corpus, **kwargs)
        # Explicit encoding keeps the file readable regardless of the
        # platform's default locale.
        with open(
            os.path.join(path, CHDEFAULT_PERSIST_FILENAME), "w", encoding="utf-8"
        ) as f:
            json.dump(self.get_persist_args(), f, indent=2)

    @classmethod
    def from_persist_dir(cls, path: str, **kwargs: Any) -> "CHBM25Retriever":
        """Load the retriever from a directory."""
        bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
        with open(
            os.path.join(path, CHDEFAULT_PERSIST_FILENAME), encoding="utf-8"
        ) as f:
            retriever_data = json.load(f)
        return cls(existing_bm25=bm25, **retriever_data)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Return up to ``similarity_top_k`` nodes ranked by BM25 score."""
        # Never request more results than there are documents.
        # NOTE(review): bm25s is expected to reject k larger than the corpus
        # size - confirm against the pinned bm25s version.
        top_k = self.similarity_top_k
        if self.corpus is not None:
            top_k = min(top_k, len(self.corpus))
        if top_k <= 0:
            return []

        tokenized_query = chTokenize(
            query_bundle.query_str, show_progress=self._verbose
        )
        indexes, scores = self.bm25.retrieve(
            tokenized_query, k=top_k, show_progress=self._verbose
        )

        # batched, but only one query
        indexes = indexes[0]
        scores = scores[0]

        nodes: List[NodeWithScore] = []
        for idx, score in zip(indexes, scores):
            # idx can be an int or a dict of the node
            if isinstance(idx, dict):
                node = metadata_dict_to_node(idx)
            else:
                node = metadata_dict_to_node(self.corpus[int(idx)])
            nodes.append(NodeWithScore(node=node, score=float(score)))

        return nodes
+46
View File
@@ -0,0 +1,46 @@
from typing import Any, Dict, List, Union, Callable, NamedTuple
from bm25s.tokenization import *
try:
from tqdm.auto import tqdm
except ImportError:
def tqdm(iterable, *args, **kwargs):
return iterable
# Lazily-built set of Chinese stop words, shared by every call.
_CHINESE_STOPWORDS = None


def chinese_tokenizer(text: str) -> List[str]:
    """Segment ``text`` with jieba and drop Chinese stop words.

    Args:
        text: Raw text to segment.

    Returns:
        The jieba tokens of ``text``, in order, without any token that
        appears in NLTK's ``chinese`` stop-word list.
    """
    global _CHINESE_STOPWORDS
    import jieba

    if _CHINESE_STOPWORDS is None:
        from nltk.corpus import stopwords

        # Load the list once and keep it as a set: the previous code called
        # stopwords.words() again for every single token and then scanned
        # the resulting list.
        _CHINESE_STOPWORDS = frozenset(stopwords.words('chinese'))

    return [token for token in jieba.lcut(text) if token not in _CHINESE_STOPWORDS]
def chTokenize(
    texts,
    show_progress: bool = True,
    leave: bool = False,
) -> Union[List[List[str]], Tokenized]:
    """Tokenize one string or a list of strings into token ids plus a vocabulary.

    Each text is split with ``chinese_tokenizer``; every distinct token gets
    the next free integer id in order of first appearance.

    Returns:
        ``Tokenized(ids=..., vocab=...)`` where ``ids`` holds one id list per
        text and ``vocab`` maps token -> id.
    """
    documents = [texts] if isinstance(texts, str) else texts

    vocabulary = {}
    encoded = []
    progress = tqdm(
        documents, desc="Split strings", leave=leave, disable=not show_progress
    )
    for document in progress:
        # setdefault assigns len(vocabulary) only when the token is new.
        encoded.append(
            [
                vocabulary.setdefault(token, len(vocabulary))
                for token in chinese_tokenizer(document)
            ]
        )

    return Tokenized(ids=encoded, vocab=vocabulary)
@@ -0,0 +1,67 @@
import os
from typing import Optional, Any, Dict, List
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.schema import NodeWithScore, QueryBundle
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
class HybridRetriever(BaseRetriever):
    """Retriever that blends vector-similarity hits with BM25 keyword hits.

    A node's final score is ``alpha * vector_score + (1 - alpha) * bm25_score``.
    A node returned by only one of the two retrievers gets 0.0 for the
    missing side.
    """

    def __init__(
        self,
        vector_index,
        similarity_top_k: int = 2,
        out_top_k: Optional[int] = None,
        alpha: float = 0.5,
        filters = None,
        **kwargs: Any,
    ) -> None:
        """Build the vector retriever and load or create the BM25 retriever.

        Args:
            vector_index: Index providing ``as_retriever`` and a ``vector_store``.
            similarity_top_k: Number of hits requested from the vector
                retriever, and from a freshly built BM25 retriever.
            out_top_k: Number of fused hits returned; falls back to
                ``similarity_top_k`` when falsy.
            alpha: Weight of the vector score; BM25 gets ``1 - alpha``.
            filters: Forwarded to the vector retriever only.
            **kwargs: Forwarded to ``BaseRetriever.__init__``.
        """
        super().__init__(**kwargs)
        self._vector_index = vector_index
        self._embed_model = vector_index._embed_model
        self._out_top_k = out_top_k or similarity_top_k
        self._alpha = alpha
        self._vecRetriever = vector_index.as_retriever(
            similarity_top_k=similarity_top_k,
            filters=filters,
        )

        storage_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
        if os.path.exists(storage_dir) and len(os.listdir(storage_dir)) > 0:
            # Reuse the keyword index persisted by an earlier run.
            self._bm25Retriever = CHBM25Retriever.from_persist_dir(storage_dir)
        else:
            # Nothing persisted yet: build from the vector store, save for the
            # next start-up, and keep the instance. The previous code
            # persisted it but never stored it on ``self``, so the first
            # query failed with AttributeError.
            bm_retriever = CHBM25Retriever.from_defaults(
                similarity_top_k=similarity_top_k,
                nodes=self._vector_index.vector_store.get_nodes(None),
            )
            bm_retriever.persist(storage_dir)
            self._bm25Retriever = bm_retriever

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Run both retrievers on the query text and return the fused top hits."""
        query_text = query_bundle.query_str
        dense_hits: List[NodeWithScore] = self._vecRetriever.retrieve(query_text)
        sparse_hits: List[NodeWithScore] = self._bm25Retriever.retrieve(query_text)

        # Index keyword hits by node id; entries are removed once merged.
        sparse_by_id: Dict[str, NodeWithScore] = {
            hit.node_id: hit for hit in sparse_hits
        }
        keyword_weight = 1 - self._alpha

        scored = []
        for hit in dense_hits:
            match = sparse_by_id.pop(hit.node_id, None)
            # A missing counterpart or a missing score counts as 0.0.
            keyword_score = (match.score or 0.0) if match is not None else 0.0
            combined = (self._alpha * (hit.score or 0.0)) + (
                keyword_weight * keyword_score
            )
            scored.append((combined, hit))

        # Whatever is left was found by BM25 only.
        for hit in sparse_by_id.values():
            scored.append((keyword_weight * (hit.score or 0.0), hit))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        for combined, hit in scored:
            hit.score = combined
        return [hit for _, hit in scored][:self._out_top_k]
+4 -3
View File
@@ -1,3 +1,4 @@
from dotenv import load_dotenv from dotenv import load_dotenv
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
@@ -16,11 +17,14 @@ from app.observability import init_observability
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from phoenix.trace import using_project from phoenix.trace import using_project
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
usPrj = using_project(os.getenv("PHOENIX_PROJECT_NAME")) usPrj = using_project(os.getenv("PHOENIX_PROJECT_NAME"))
usPrj.__enter__() usPrj.__enter__()
init_settings() init_settings()
init_observability() init_observability()
@@ -52,12 +56,10 @@ mount_static_files("data_output", "/api/files/output")
app.include_router(chat_router, prefix="/api/chat") app.include_router(chat_router, prefix="/api/chat")
app.include_router(file_upload_router, prefix="/api/chat/upload") app.include_router(file_upload_router, prefix="/api/chat/upload")
# Redirect to documentation page when accessing base URL
@app.get("/") @app.get("/")
async def redirect_to_docs(): async def redirect_to_docs():
return RedirectResponse(url="/docs") return RedirectResponse(url="/docs")
SentenceSplitter
if __name__ == "__main__": if __name__ == "__main__":
app_host = os.getenv("APP_HOST", "0.0.0.0") app_host = os.getenv("APP_HOST", "0.0.0.0")
app_port = int(os.getenv("APP_PORT", "8000")) app_port = int(os.getenv("APP_PORT", "8000"))
@@ -65,4 +67,3 @@ if __name__ == "__main__":
reload = False reload = False
uvicorn.run(app="main:app", host=app_host, port=app_port, reload=reload) uvicorn.run(app="main:app", host=app_host, port=app_port, reload=reload)
#usPrj.__exit__()
+1
View File
@@ -18,6 +18,7 @@ llama-index = "0.10.63"
cachetools = "^5.3.3" cachetools = "^5.3.3"
protobuf = "4.25.4" protobuf = "4.25.4"
nltk = "^3.8.2" nltk = "^3.8.2"
jieba = "^0.42.1"
#arize-phoenix = "^4.12.0" #arize-phoenix = "^4.12.0"
openinference-instrumentation-llama-index="2.2.3" openinference-instrumentation-llama-index="2.2.3"