diff --git a/backend/app/engine/__init__.py b/backend/app/engine/__init__.py index 1eaf1fe..6e2a97a 100644 --- a/backend/app/engine/__init__.py +++ b/backend/app/engine/__init__.py @@ -4,7 +4,7 @@ from llama_index.core.agent import AgentRunner, ReActChatFormatter from llama_index.core.settings import Settings from llama_index.core.tools.query_engine import QueryEngineTool -from app.engine.engine import create_query_engine, create_summary_query_engine, create_sql_query_engine +from app.engine.engine import create_query_engine, create_summary_query_engine from app.engine.index import get_index #from app.engine.loaders.db import makeDescriptionByEngine from app.engine.tools import ToolFactory @@ -17,11 +17,11 @@ def get_chat_engine(filters=None, params=None): tools = [] # 创建SQL查询工具 - sql_query_engine = create_sql_query_engine() - sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine, - name="zjdata_query_tool", - description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具" - ) +# sql_query_engine = create_summary_query_engine(index) + # sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine, + # name="zjdata_query_tool", + # description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具" + # ) #tools.append(sql_query_tool) # Add query tool if index exists diff --git a/backend/app/engine/engine.py b/backend/app/engine/engine.py index 6cb552f..379275e 100644 --- a/backend/app/engine/engine.py +++ b/backend/app/engine/engine.py @@ -52,8 +52,8 @@ def get_Retriever(index,**kwargs): sql_database = None sql_obj_index = None -# Create a sql query engine -def create_sql_query_engine(top_k=3, use_reranker=False, filters=None): +# Create a summary query engine +def create_summary_query_engine(top_k=3, use_reranker=False, filters=None): global sql_obj_index global sql_database if sql_obj_index is None or sql_database is None: diff --git a/backend/app/engine/loaders/__init__.py b/backend/app/engine/loaders/__init__.py index a220170..d311124 100644 --- a/backend/app/engine/loaders/__init__.py +++ b/backend/app/engine/loaders/__init__.py @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) def load_configs(): - with open("config/loaders.yaml") as f: + with open("config/loaders.yaml",'r', encoding='utf-8') as f: configs = yaml.safe_load(f) return configs diff --git a/backend/app/engine/loaders/db.py b/backend/app/engine/loaders/db.py index d6310e2..4be984d 100644 --- a/backend/app/engine/loaders/db.py +++ b/backend/app/engine/loaders/db.py @@ -2,17 +2,14 @@ import logging from typing import Any, List, Optional from llama_index.core import SQLDatabase, Document -from llama_index.core.objects import SQLTableSchema -from llama_index.core.readers.base import BaseReader from llama_index.readers.database import DatabaseReader from pydantic import BaseModel -from sqlalchemy import create_engine -from sqlalchemy import text +from sqlalchemy import create_engine, text from sqlalchemy.engine import Engine logger = logging.getLogger(__name__) -class CustomDatabaseReader(BaseReader): +class CustomDatabaseReader(DatabaseReader): """Simple Database reader. Concatenates each row into Document used by LlamaIndex. @@ -76,28 +73,30 @@ class CustomDatabaseReader(BaseReader): "set of credentials." ) - def load_data(self, query: str) -> List[Document]: + def load_data(self, query: str, explanation: str) -> List[Document]: """Query and load data from the Database, returning a list of Documents. Args: query (str): Query parameter to filter tables and rows. + explanation (str): Explanation for the query to be included in the document. Returns: List[Document]: A list of Document objects. """ - dco_str = "" + dco_str = explanation + "\n" + with self.sql_database.engine.connect() as connection: if query is None: raise ValueError("A query parameter is necessary to filter the data") else: result = connection.execute(text(query)) - dco_str = ", ".join( + dco_str += ", ".join( [f"{entry}" for entry in result.keys()] - ) + ) + "\n" for item in result.fetchall(): - # fetch each item + # Fetch each item record_str = ", ".join( [f"{entry}" for col, entry in zip(result.keys(), item)] ) @@ -111,7 +110,7 @@ class CustomDatabaseReader(BaseReader): class DBLoaderConfig(BaseModel): uri: str - queries: List[str] + queries: List[dict] def get_db_documents(configs: list[DBLoaderConfig]): docs = [] @@ -123,33 +122,19 @@ def get_db_documents(configs: list[DBLoaderConfig]): return docs metadata = { - #'file_name':'', - 'file_type':'application/booway.document.zj', - #'file_path':'', - #'file_size':'', - #'creation_date':'', - #'last_modified_date':'', + 'file_type': 'application/booway.document.zj', } - #from llama_index.readers.database import DatabaseReader for entry in configs: engine = create_engine(entry.uri) sql_database = SQLDatabase(engine) - # table_schema_objs = makeDescriptionByEngine(sql_database) - # table_node_mapping = SQLTableNodeMapping(sql_database) - # - # nodes = table_node_mapping.to_nodes(table_schema_objs) - # for node in nodes: - # node.metadata.update(metadata) - # - # docs.extend(nodes) - - queries = entry.queries or [] loader = CustomDatabaseReader(sql_database) - for query in queries: + for query_dict in entry.queries: + query = query_dict.get("sql", "") + explanation = query_dict.get("explanation", "") logger.info(f"Loading data from database with query: {query}") - documents = loader.load_data(query=query) + documents = loader.load_data(query=query, explanation=explanation) docs.extend(documents) return docs diff --git a/backend/app/engine/prompt.py b/backend/app/engine/prompt.py index 101b6bf..511040e 100644 --- a/backend/app/engine/prompt.py +++ b/backend/app/engine/prompt.py @@ -39,15 +39,16 @@ refine_template_str = ( "这是原本的问题: {query_str}\n" "我们已经提供了回答: {existing_answer}\n" "现在我们有机会改进这个回答 " - "使用以下更多上下文(仅当需要用时)\n" + "使用以下更多上下文(仅当有助于改进回答时使用)\n" "------------\n" "{context_msg}\n" "------------\n" - "根据新的上下文, 请改进原来的回答。" - "如果新的上下文没有用, 直接返回原本的回答。\n" - "如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n" + "如果新的上下文对回答没有影响,或者原来的回答已经正确,直接返回原本的回答。\n" + "如果新的上下文有助于改进,请基于它更新回答,但不要引入与问题无关的信息。\n" + "如果是表结构或者是数据库的相关内容,仅用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n" "改进的回答: " ) + refine_template = PromptTemplate(refine_template_str) summary_template_str = (