初始化提交
This commit is contained in:
@@ -20,9 +20,9 @@ class CallbackEvent(BaseModel):
|
||||
if self.payload:
|
||||
nodes = self.payload.get("nodes")
|
||||
if nodes:
|
||||
msg = f"Retrieved {len(nodes)} sources to use as context for the query"
|
||||
msg = f"根据查询检索到 {len(nodes)} 源文件"
|
||||
else:
|
||||
msg = f"Retrieving context for query: '{self.payload.get('query_str')}'"
|
||||
msg = f"查询检索中: '{self.payload.get('query_str')}'"
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {"title": msg},
|
||||
@@ -37,7 +37,7 @@ class CallbackEvent(BaseModel):
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {
|
||||
"title": f"Calling tool: {tool.name} with inputs: {func_call_args}",
|
||||
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ class CallbackEvent(BaseModel):
|
||||
case _:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error in converting event to response: {e}")
|
||||
logger.error(f"转换回应时间时发生错误,原因: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -173,12 +173,12 @@ class SourceNodes(BaseModel):
|
||||
def from_source_node(cls, source_node: "NodeWithScore"):
    """Build an instance of ``cls`` from a scored retrieval node.

    The ``text`` field carries the node's ``filename`` metadata value when it
    is present and non-empty; otherwise it falls back to the node id.

    Args:
        source_node: Object exposing ``node.metadata``, ``node.node_id`` and
            ``score``.

    Returns:
        The result of ``cls(id=..., metadata=..., score=..., text=..., url=...)``.
    """
    node = source_node.node
    metadata = node.metadata
    url = cls.get_url_from_metadata(metadata)

    # ``get(...) or fallback`` replaces the fragile ``a and b or c`` idiom.
    text = metadata.get("filename") or node.node_id
    return cls(
        id=node.node_id,
        metadata=metadata,
        score=source_node.score,
        text=text,  # type: ignore
        url=url,
    )
|
||||
|
||||
|
||||
@@ -1,24 +1,67 @@
|
||||
import os
|
||||
|
||||
from llama_index.core import SQLDatabase, SummaryIndex, VectorStoreIndex
|
||||
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.agent import AgentRunner
|
||||
from llama_index.core.agent import AgentRunner, StructuredPlannerAgent, FunctionCallingAgentWorker
|
||||
from llama_index.core.tools.query_engine import QueryEngineTool
|
||||
from sqlalchemy import create_engine, Engine
|
||||
|
||||
from app.engine.loaders.db import makeDescriptionByEngine
|
||||
from app.engine.tools import ToolFactory
|
||||
from app.engine.index import get_index
|
||||
|
||||
sql_database = None
|
||||
sql_obj_index = None
|
||||
|
||||
def get_chat_engine(filters=None, params=None):
|
||||
system_prompt = os.getenv("SYSTEM_PROMPT")
|
||||
top_k = os.getenv("TOP_K", "3")
|
||||
top_k = int(os.getenv("TOP_K", "3"))
|
||||
tools = []
|
||||
|
||||
global sql_obj_index
|
||||
global sql_database
|
||||
if sql_obj_index is None:
|
||||
sqlengine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
|
||||
sql_database = SQLDatabase(sqlengine)
|
||||
table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||
table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||
|
||||
sql_obj_index = ObjectIndex.from_objects(
|
||||
table_schema_objs,
|
||||
table_node_mapping,
|
||||
index_cls=VectorStoreIndex,
|
||||
)
|
||||
|
||||
# 创建SQL查询工具
|
||||
sql_query_engine = SQLTableRetrieverQueryEngine(sql_database,
|
||||
sql_obj_index.as_retriever(similarity_top_k=top_k),
|
||||
verbose=True,)
|
||||
sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine,
|
||||
name="zjdata_query_tool",
|
||||
description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具")
|
||||
|
||||
# Add query tool if index exists
|
||||
index = get_index()
|
||||
if index is not None:
|
||||
summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None))
|
||||
summary_query_engine = summary_index.as_query_engine()
|
||||
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
|
||||
description="适用于任何需要进行全面总结、概括的要求。",
|
||||
#description="适用于任何需要对所有内容进行全面总结的请求。有关电力造价领域更具体部分的问题,请使用zj_query_engine_tool",
|
||||
)
|
||||
|
||||
# 创建向量检索查询工具
|
||||
query_engine = index.as_query_engine(
|
||||
similarity_top_k=int(top_k), filters=filters
|
||||
similarity_top_k=top_k, filters=filters
|
||||
)
|
||||
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
|
||||
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
|
||||
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
|
||||
)
|
||||
tools.append(summary_query_tool)
|
||||
tools.append(query_engine_tool)
|
||||
#tools.append(sql_query_tool)
|
||||
|
||||
# Add additional tools
|
||||
tools += ToolFactory.from_env()
|
||||
@@ -29,3 +72,10 @@ def get_chat_engine(filters=None, params=None):
|
||||
system_prompt=system_prompt,
|
||||
verbose=True,
|
||||
)
|
||||
# create the function calling worker for reasoning
|
||||
# worker = FunctionCallingAgentWorker.from_tools(
|
||||
# tools, verbose=True
|
||||
# )
|
||||
#
|
||||
# # wrap the worker in the top-level planner
|
||||
# return StructuredPlannerAgent(worker, tools)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
STORAGE_DIR = "storage" # directory to cache the generated index
|
||||
@@ -2,50 +2,84 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
|
||||
import os
|
||||
|
||||
from app.engine.loaders import get_documents
|
||||
from app.engine.vectordb import get_vector_store
|
||||
from app.settings import init_settings
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
|
||||
|
||||
|
||||
def get_doc_store():
    """Return the document store used by the ingestion run.

    When ``STORAGE_DIR`` exists the store is loaded from it; otherwise a new,
    empty ``SimpleDocumentStore`` is returned because there is nothing to load.
    """
    if not os.path.exists(STORAGE_DIR):
        # No persisted directory yet: start from an empty store.
        return SimpleDocumentStore()
    return SimpleDocumentStore.from_persist_dir(STORAGE_DIR)
|
||||
|
||||
|
||||
def run_pipeline(docstore, vector_store, documents):
    """Run the ingestion pipeline over ``documents`` and return its nodes.

    The pipeline splits the documents with a ``SentenceSplitter`` configured
    from ``Settings`` and then applies ``Settings.embed_model``, using the
    given ``docstore`` (strategy ``"upserts_and_delete"``) and ``vector_store``.
    """
    splitter = SentenceSplitter(
        chunk_size=Settings.chunk_size,
        chunk_overlap=Settings.chunk_overlap,
    )
    ingestion = IngestionPipeline(
        transformations=[splitter, Settings.embed_model],
        docstore=docstore,
        docstore_strategy="upserts_and_delete",
        vector_store=vector_store,
    )
    # Run the ingestion pipeline and hand back the produced nodes
    return ingestion.run(show_progress=True, documents=documents)
|
||||
|
||||
|
||||
def persist_storage(docstore, vector_store):
    """Persist ``docstore`` and ``vector_store`` under ``STORAGE_DIR``."""
    context = StorageContext.from_defaults(docstore=docstore, vector_store=vector_store)
    context.persist(STORAGE_DIR)
|
||||
|
||||
|
||||
def generate_datasource():
    """Ingest the configured documents into the vector store and persist them.

    Steps: initialise settings, load documents, mark each as public, run the
    ingestion pipeline against the document store and vector store, then
    persist the storage context.

    The superseded LlamaCloud upload (which required ``LLAMA_CLOUD_*``
    variables and used ``LlamaCloudIndex``) is removed; only the local
    pipeline remains.
    """
    init_settings()
    logger.info("Generate index for the provided data")

    # Get the stores and documents or create new ones
    documents = get_documents()
    # Set private=false to mark the document as public (required for filtering)
    for doc in documents:
        doc.metadata["private"] = "false"
    docstore = get_doc_store()
    vector_store = get_vector_store()

    # Run the ingestion pipeline
    run_pipeline(docstore, vector_store, documents)

    # Build the index and persist storage
    persist_storage(docstore, vector_store)

    logger.info("Finished generating the index")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    from phoenix.trace import using_project

    # Run the generation exactly once, inside the Phoenix project context.
    # The fallback avoids ``None + str`` when PHOENIX_PROJECT_NAME is unset.
    project_name = os.getenv("PHOENIX_PROJECT_NAME", "default") + "_generate"
    with using_project(project_name):
        generate_datasource()
|
||||
|
||||
+13
-22
@@ -1,31 +1,22 @@
|
||||
import logging
|
||||
import os
|
||||
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from app.engine.vectordb import get_vector_store
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
index = None
|
||||
|
||||
def get_index(params=None):
    """Return the module-level index, building it on first use.

    The index is created from the store returned by ``get_vector_store()``
    and cached in the module-level ``index`` variable.

    Args:
        params: Kept for call compatibility; not used when building the index.

    Returns:
        The cached ``VectorStoreIndex``.
    """
    global index
    if index is None:
        logger.info("Connecting vector store...")
        store = get_vector_store()
        # Load the index from the vector store
        # If you are using a vector store that doesn't store text,
        # you must load the index from both the vector store and the document store
        index = VectorStoreIndex.from_vector_store(store)
        logger.info("Finished load index from vector store.")

    return index
|
||||
|
||||
@@ -17,19 +17,22 @@ def load_configs():
|
||||
def get_documents():
|
||||
documents = []
|
||||
config = load_configs()
|
||||
if config is None or len(config.items()) == 0:
|
||||
return documents
|
||||
|
||||
for loader_type, loader_config in config.items():
|
||||
logger.info(
|
||||
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
||||
)
|
||||
|
||||
loader_config = loader_config or []
|
||||
match loader_type:
|
||||
case "file":
|
||||
document = get_file_documents(FileLoaderConfig(**loader_config))
|
||||
case "web":
|
||||
document = get_web_documents(WebLoaderConfig(**loader_config))
|
||||
case "db":
|
||||
document = get_db_documents(
|
||||
configs=[DBLoaderConfig(**cfg) for cfg in loader_config]
|
||||
)
|
||||
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
|
||||
case _:
|
||||
raise ValueError(f"Invalid loader type: {loader_type}")
|
||||
documents.extend(document)
|
||||
|
||||
@@ -1,26 +1,187 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.schema import Document
|
||||
from llama_index.core.utilities.sql_wrapper import SQLDatabase
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Engine
|
||||
from llama_index.core import SQLDatabase, Document
|
||||
from llama_index.core.objects import SQLTableSchema, SQLTableNodeMapping
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from pydantic import BaseModel, validator
|
||||
from llama_index.core.indices.vector_store import VectorStoreIndex
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CustomDatabaseReader(BaseReader):
    """Database reader that packs a whole query result into one Document.

    The connection is taken from the first of these that is provided:

    Args:
        sql_database (Optional[SQLDatabase]): SQL database to use.
        engine (Optional[Engine]): SQLAlchemy Engine of the database connection.
        uri (Optional[str]): URI of the database connection.
        scheme, host, port, user, password, dbname: Individual connection
            parts; all six must be given, and they are combined into a URI.
        *args, **kwargs: Forwarded to ``SQLDatabase`` when it is built from an
            engine or a URI.

    Raises:
        ValueError: If none of the above identifies a connection.
    """

    def __init__(
        self,
        sql_database: Optional[SQLDatabase] = None,
        engine: Optional[Engine] = None,
        uri: Optional[str] = None,
        scheme: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        dbname: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize with parameters."""
        if sql_database:
            self.sql_database = sql_database
        elif engine:
            self.sql_database = SQLDatabase(engine, *args, **kwargs)
        elif uri:
            self.uri = uri
            self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
        elif scheme and host and port and user and password and dbname:
            uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
            self.uri = uri
            self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
        else:
            raise ValueError(
                "You must provide either a SQLDatabase, "
                "a SQL Alchemy Engine, a valid connection URI, or a valid "
                "set of credentials."
            )

    def load_data(self, query: str) -> List[Document]:
        """Run ``query`` and return its result as a single Document.

        The document text is a header line with the column names followed by
        one line per row, values separated by ``", "``.

        Args:
            query (str): SQL text to execute.

        Returns:
            List[Document]: A one-element list holding the result document.

        Raises:
            ValueError: If ``query`` is None.
        """
        # Validate before opening a connection.
        if query is None:
            raise ValueError("A query parameter is necessary to filter the data")

        with self.sql_database.engine.connect() as connection:
            result = connection.execute(text(query))
            lines = [", ".join(f"{column}" for column in result.keys())]
            lines.extend(
                ", ".join(f"{value}" for value in row) for row in result.fetchall()
            )

        # Terminate every line, header included, so the header is no longer
        # fused with the first data row.
        doc = Document(text="\n".join(lines) + "\n")
        doc.metadata["name"] = query
        doc.metadata["context"] = query
        doc.metadata["file_type"] = "application/vnd.ms-excel"
        return [doc]
|
||||
|
||||
class DBLoaderConfig(BaseModel):
    """Configuration for one database source read by ``get_db_documents``."""

    # Connection URI; get_db_documents passes it to create_engine.
    uri: str
    # SQL queries; each one is executed and its result loaded as a document.
    queries: List[str]
|
||||
|
||||
def makeDescriptionByEngine(sql_database: SQLDatabase):
    """Build one ``SQLTableSchema`` per usable table of ``sql_database``.

    Tables with more than 150 columns are skipped. For the table
    ``gongchengshuxing`` the schema's context string lists values read from
    its ``name`` column; every other table gets an empty context string.

    Args:
        sql_database: Database whose usable tables are described.

    Returns:
        A list of ``SQLTableSchema`` objects.
    """
    max_columns = 150
    max_documents = 30
    reader = DatabaseReader(sql_database)

    table_schema_objs = []
    for table_name in sql_database.get_usable_table_names():
        if len(sql_database.get_table_columns(table_name)) > max_columns:
            continue

        stats_txt = ""
        if table_name == 'gongchengshuxing':
            documents = reader.load_data(query='select name from gongchengshuxing')
            names = []
            # The first document is skipped and at most the first 30 are
            # considered, exactly as before.
            for document in documents[1:max_documents]:
                # Presumably the text looks like "<column>: <value>" (verify
                # against the reader). Texts without ':' are skipped instead
                # of raising IndexError.
                parts = document.text.split(':')
                if len(parts) > 1:
                    names.append(parts[1])
            stats_txt = '该表中有以下属性:' + ','.join(names)

        table_schema_objs.append(
            SQLTableSchema(table_name=table_name, context_str=stats_txt)
        )

    return table_schema_objs
|
||||
|
||||
def get_db_documents(configs: list[DBLoaderConfig]):
    """Load table-schema nodes and query results for each configured database.

    For every config entry this builds a ``SQLDatabase`` from ``entry.uri``,
    converts the table descriptions from ``makeDescriptionByEngine`` into
    nodes tagged with a ``file_type`` metadata value, and then appends one
    document per query in ``entry.queries``.

    Args:
        configs: Database loader configurations.

    Returns:
        A list of the schema nodes and query documents of all entries. The
        list is empty when ``configs`` is empty or the first entry has an
        empty ``uri``.
    """
    docs = []

    if len(configs) == 0 or configs[0].uri == "":
        logger.warning(
            "Failed to load database, error message: uri is empty. Return as empty document list."
        )
        return docs

    metadata = {
        'file_type': 'application/booway.document.zj',
    }

    for entry in configs:
        # One engine per entry, not one per query.
        sql_database = SQLDatabase(create_engine(entry.uri))

        table_schema_objs = makeDescriptionByEngine(sql_database)
        table_node_mapping = SQLTableNodeMapping(sql_database)

        nodes = table_node_mapping.to_nodes(table_schema_objs)
        for node in nodes:
            node.metadata.update(metadata)
        docs.extend(nodes)

        loader = CustomDatabaseReader(sql_database)
        for query in entry.queries or []:
            logger.info(f"Loading data from database with query: {query}")
            docs.extend(loader.load_data(query=query))

    return docs
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.readers.json import JSONReader
|
||||
from llama_parse import LlamaParse
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
@@ -39,6 +42,9 @@ def llama_parse_extractor() -> Dict[str, LlamaParse]:
|
||||
parser = llama_parse_parser()
|
||||
return {file_type: parser for file_type in SUPPORTED_FILE_TYPES}
|
||||
|
||||
def llama_local_extractor() -> Dict[str, BaseReader]:
    """Return the file extractors used when LlamaParse is not enabled.

    The mapping is keyed by file suffix including the leading dot and holds
    reader instances, which is the form ``SimpleDirectoryReader`` looks up
    and then calls ``load_data`` on.
    """
    return {".json": JSONReader()}
|
||||
|
||||
|
||||
def get_file_documents(config: FileLoaderConfig):
|
||||
from llama_index.core.readers import SimpleDirectoryReader
|
||||
@@ -53,6 +59,9 @@ def get_file_documents(config: FileLoaderConfig):
|
||||
nest_asyncio.apply()
|
||||
|
||||
file_extractor = llama_parse_extractor()
|
||||
else:
|
||||
file_extractor = llama_local_extractor()
|
||||
|
||||
reader = SimpleDirectoryReader(
|
||||
config.data_dir,
|
||||
recursive=True,
|
||||
|
||||
@@ -11,7 +11,7 @@ class CrawlUrl(BaseModel):
|
||||
|
||||
class WebLoaderConfig(BaseModel):
    """Configuration for the web loader.

    ``urls`` is declared once, with an empty-list default, so a config
    without URLs is valid.
    """

    # Browser driver arguments; get_web_documents adds each to the options.
    driver_arguments: list[str] = Field(default=None)
    # Sites to crawl.
    urls: list[CrawlUrl] = []
|
||||
|
||||
|
||||
def get_web_documents(config: WebLoaderConfig):
|
||||
@@ -25,6 +25,7 @@ def get_web_documents(config: WebLoaderConfig):
|
||||
options.add_argument(arg)
|
||||
|
||||
docs = []
|
||||
urls = config.urls or []
|
||||
for url in config.urls:
|
||||
scraper = WholeSiteReader(
|
||||
prefix=url.prefix,
|
||||
|
||||
@@ -48,9 +48,13 @@ class ToolFactory:
|
||||
if os.path.exists("config/tools.yaml"):
|
||||
with open("config/tools.yaml", "r") as f:
|
||||
tool_configs = yaml.safe_load(f)
|
||||
for tool_type, config_entries in tool_configs.items():
|
||||
for tool_name, config in config_entries.items():
|
||||
tools.extend(
|
||||
ToolFactory.load_tools(tool_type, tool_name, config)
|
||||
)
|
||||
if tool_configs != None and len(tool_configs.items()) != 0:
|
||||
for tool_type, config_entries in tool_configs.items():
|
||||
if config_entries == None or len(config_entries.items()) == 0:
|
||||
continue
|
||||
|
||||
for tool_name, config in config_entries.items():
|
||||
tools.extend(
|
||||
ToolFactory.load_tools(tool_type, tool_name, config)
|
||||
)
|
||||
return tools
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
from qdrant_client import qdrant_client
|
||||
|
||||
qclient = None
|
||||
|
||||
def get_qdrant_vector_store():
    """Return a ``QdrantVectorStore`` backed by a cached Qdrant client.

    Environment variables:
        VECTOR_STORE_COLLECTION: Collection name (default ``"default"``).
        VECTOR_STORE_PATH: If set, a local on-disk client is used.
        VECTOR_STORE_HOST / VECTOR_STORE_PORT: Remote server address used
            when no path is set (defaults ``127.0.0.1`` / ``6333``).

    The client is created once and reused through the module-level ``qclient``.

    Raises:
        ValueError: If neither a path nor a host is available.
    """
    global qclient

    collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
    vector_store_path = os.getenv("VECTOR_STORE_PATH")
    # No trailing commas: they turned host and port into 1-tuples.
    host = os.getenv("VECTOR_STORE_HOST", "127.0.0.1")
    port = int(os.getenv("VECTOR_STORE_PORT", "6333"))

    # Fail only when both are missing; the previous ``or`` rejected every
    # remote (path-less) configuration.
    if not vector_store_path and not host:
        raise ValueError(
            "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
        )

    # if VECTOR_STORE_PATH is set, use a local QdrantVectorStore from the path
    # otherwise, use a remote QdrantVectorStore
    if qclient is None:
        if vector_store_path:
            qclient = qdrant_client.QdrantClient(
                path=vector_store_path,
            )
        else:
            qclient = qdrant_client.QdrantClient(
                host=host,
                port=port,
            )

    return QdrantVectorStore(client=qclient, collection_name=collection_name)
|
||||
|
||||
def get_chroma_vector_store():
    """Return a ``ChromaVectorStore`` configured from the environment.

    A local store is opened when ``VECTOR_STORE_PATH`` is set; otherwise a
    remote one is opened from ``VECTOR_STORE_HOST`` and ``VECTOR_STORE_PORT``.
    Both variants request cosine distance for the collection.

    Raises:
        ValueError: If no path is set and the host or port is missing.
    """
    collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
    persist_dir = os.getenv("VECTOR_STORE_PATH")

    # if VECTOR_STORE_PATH is set, use a local ChromaVectorStore from the path
    if persist_dir:
        return ChromaVectorStore.from_params(
            persist_dir=persist_dir,
            collection_name=collection_name,
            collection_kwargs={"metadata": {"hnsw:space": "cosine"}},
        )

    # otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
    remote_host = os.getenv("VECTOR_STORE_HOST")
    remote_port = os.getenv("VECTOR_STORE_PORT")
    if not (remote_host and remote_port):
        raise ValueError(
            "Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
        )
    return ChromaVectorStore.from_params(
        host=remote_host,
        port=int(remote_port),
        collection_name=collection_name,
        collection_kwargs={"metadata": {"hnsw:space": "cosine"}},
    )
|
||||
|
||||
def get_vector_store():
    """Return the vector store selected by ``VECTOR_STORE_TYPE``.

    Raises:
        ValueError: If the variable is neither ``"chroma"`` nor ``"qdrant"``.
    """
    store_type = os.getenv("VECTOR_STORE_TYPE")

    if store_type == "chroma":
        return get_chroma_vector_store()
    if store_type == "qdrant":
        return get_qdrant_vector_store()
    raise ValueError(f"Invalid vector store type: {store_type}")
|
||||
@@ -1,2 +1,20 @@
|
||||
import os
|
||||
|
||||
import llama_index.core
|
||||
|
||||
def init_observability():
    """Register the Arize Phoenix global handler for LlamaIndex.

    Reads ``PHOENIX_API_KEY`` (required) and ``PHOENIX_URL`` from the
    environment, exports the key through ``OTEL_EXPORTER_OTLP_HEADERS`` and
    installs the ``"arize_phoenix"`` handler with ``PHOENIX_URL`` as endpoint.

    Raises:
        ValueError: If ``PHOENIX_API_KEY`` is unset or empty.
    """
    # NOTE(review): statement nesting was reconstructed from a flattened diff
    # view; confirm these statements belong to this function.
    PHOENIX_API_KEY = os.getenv("PHOENIX_API_KEY")
    if not PHOENIX_API_KEY:
        raise ValueError("PHOENIX_API_KEY environment variable is not set")
    os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"api_key={PHOENIX_API_KEY}"
    PHOENIX_URL = os.getenv("PHOENIX_URL")
    llama_index.core.set_global_handler(
        "arize_phoenix", endpoint=PHOENIX_URL, eval_params={}
    )
|
||||
|
||||
#debugHandle=[]
|
||||
# llama_debug = LlamaDebugHandler(print_trace_on_end=True)
|
||||
# debugHandle.append(llama_debug)
|
||||
# callback_manager = CallbackManager(debugHandle)
|
||||
# settings.Settings.callback_manager = callback_manager
|
||||
|
||||
+125
-96
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.core.settings import Settings
|
||||
|
||||
|
||||
@@ -9,6 +10,8 @@ def init_settings():
|
||||
match model_provider:
|
||||
case "openai":
|
||||
init_openai()
|
||||
case "dashscope":
|
||||
init_dashscope()
|
||||
case "groq":
|
||||
init_groq()
|
||||
case "ollama":
|
||||
@@ -33,20 +36,21 @@ def init_settings():
|
||||
|
||||
|
||||
def init_ollama():
    """Configure ``Settings`` to use an Ollama server for embeddings and LLM.

    Environment variables:
        OLLAMA_BASE_URL: Server URL (default ``http://127.0.0.1:11434``).
        OLLAMA_REQUEST_TIMEOUT: Request timeout, parsed as float.
        EMBEDDING_MODEL / MODEL: Embedding and chat model names.
    """
    # Imported lazily so the Ollama packages are only needed for this provider.
    from llama_index.embeddings.ollama import OllamaEmbedding
    from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama

    base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
    request_timeout = float(
        os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
    )
    Settings.embed_model = OllamaEmbedding(
        base_url=base_url,
        model_name=os.getenv("EMBEDDING_MODEL"),
    )
    Settings.llm = Ollama(
        base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
    )
|
||||
|
||||
|
||||
def init_openai():
|
||||
@@ -69,104 +73,129 @@ def init_openai():
|
||||
}
|
||||
Settings.embed_model = OpenAIEmbedding(**config)
|
||||
|
||||
def init_dashscope():
    """Configure ``Settings`` with DashScope's Qwen-Max LLM and text-embedding-v2.

    NOTE(review): both models are fixed in code. The dictionaries previously
    built from ``MODEL``, ``LLM_TEMPERATURE``, ``LLM_MAX_TOKENS``,
    ``EMBEDDING_MODEL`` and ``EMBEDDING_DIM`` were never passed to either
    constructor and have been removed as dead code; wire those variables in
    explicitly if they are meant to take effect.
    """
    from llama_index.embeddings.dashscope import (
        DashScopeEmbedding,
        DashScopeTextEmbeddingModels,
        DashScopeTextEmbeddingType,
    )
    from llama_index.llms.dashscope import DashScope, DashScopeGenerationModels

    Settings.llm = DashScope(model_name=DashScopeGenerationModels.QWEN_MAX)
    Settings.embed_model = DashScopeEmbedding(
        model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
        text_type=DashScopeTextEmbeddingType.TEXT_TYPE_QUERY,
    )
|
||||
|
||||
|
||||
def init_azure_openai():
    """Configure ``Settings`` to use Azure OpenAI for the LLM and embeddings.

    Required environment variables (``KeyError`` if missing):
        AZURE_OPENAI_LLM_DEPLOYMENT, AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
        AZURE_OPENAI_KEY, AZURE_OPENAI_ENDPOINT.

    Optional: LLM_MAX_TOKENS, LLM_TEMPERATURE, EMBEDDING_DIM, MODEL,
    EMBEDDING_MODEL, and AZURE_OPENAI_API_VERSION (falling back to
    OPENAI_API_VERSION).
    """
    # Imported lazily so the Azure packages are only needed for this provider.
    from llama_index.core.constants import DEFAULT_TEMPERATURE
    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
    from llama_index.llms.azure_openai import AzureOpenAI

    llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
    embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
    max_tokens = os.getenv("LLM_MAX_TOKENS")
    temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
    dimensions = os.getenv("EMBEDDING_DIM")

    azure_config = {
        "api_key": os.environ["AZURE_OPENAI_KEY"],
        "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
        "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
        or os.getenv("OPENAI_API_VERSION"),
    }

    Settings.llm = AzureOpenAI(
        model=os.getenv("MODEL"),
        max_tokens=int(max_tokens) if max_tokens is not None else None,
        temperature=float(temperature),
        deployment_name=llm_deployment,
        **azure_config,
    )

    Settings.embed_model = AzureOpenAIEmbedding(
        model=os.getenv("EMBEDDING_MODEL"),
        dimensions=int(dimensions) if dimensions is not None else None,
        deployment_name=embedding_deployment,
        **azure_config,
    )
|
||||
|
||||
|
||||
def init_fastembed():
    """
    Use Qdrant Fastembed as the local embedding provider.

    ``EMBEDDING_MODEL`` must be one of the short names in the map below;
    any other value raises ``KeyError``.
    """
    # Imported lazily so fastembed is only needed for this provider.
    from llama_index.embeddings.fastembed import FastEmbedEmbedding

    embed_model_map: Dict[str, str] = {
        # Small and multilingual
        "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
        # Large and multilingual
        "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",  # noqa: E501
    }

    # This will download the model automatically if it is not already downloaded
    Settings.embed_model = FastEmbedEmbedding(
        model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
    )
|
||||
|
||||
|
||||
def init_groq():
    """Configure ``Settings.llm`` with Groq and embeddings with FastEmbed.

    ``MODEL`` must be one of the short names in the map below; any other
    value raises ``KeyError``.
    """
    # Imported lazily so the Groq package is only needed for this provider.
    from llama_index.llms.groq import Groq

    model_map: Dict[str, str] = {
        "llama3-8b": "llama3-8b-8192",
        "llama3-70b": "llama3-70b-8192",
        "mixtral-8x7b": "mixtral-8x7b-32768",
    }

    Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
    # Groq does not provide embeddings, so we use FastEmbed instead
    init_fastembed()
|
||||
|
||||
|
||||
def init_anthropic():
    """Configure ``Settings.llm`` with Anthropic and embeddings with FastEmbed.

    ``MODEL`` must be one of the short names in the map below; any other
    value raises ``KeyError``.
    """
    # Imported lazily so the Anthropic package is only needed for this provider.
    from llama_index.llms.anthropic import Anthropic

    model_map: Dict[str, str] = {
        "claude-3-opus": "claude-3-opus-20240229",
        "claude-3-sonnet": "claude-3-sonnet-20240229",
        "claude-3-haiku": "claude-3-haiku-20240307",
        "claude-2.1": "claude-2.1",
        "claude-instant-1.2": "claude-instant-1.2",
    }

    Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
    # Anthropic does not provide embeddings, so we use FastEmbed instead
    init_fastembed()
|
||||
|
||||
|
||||
def init_gemini():
    """Configure ``Settings`` to use Gemini for the LLM and embeddings.

    ``MODEL`` and ``EMBEDDING_MODEL`` are each prefixed with ``models/``
    before being passed to the Gemini classes.
    """
    # Imported lazily so the Gemini packages are only needed for this provider.
    from llama_index.embeddings.gemini import GeminiEmbedding
    from llama_index.llms.gemini import Gemini

    model_name = f"models/{os.getenv('MODEL')}"
    embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"

    Settings.llm = Gemini(model=model_name)
    Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
|
||||
|
||||
def init_mistral():
    """Configure ``Settings`` to use Mistral AI for the LLM and embeddings.

    Model names are read from ``MODEL`` and ``EMBEDDING_MODEL``.
    """
    # Imported lazily so the Mistral packages are only needed for this provider.
    from llama_index.embeddings.mistralai import MistralAIEmbedding
    from llama_index.llms.mistralai import MistralAI

    Settings.llm = MistralAI(model=os.getenv("MODEL"))
    Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
|
||||
Reference in New Issue
Block a user