Compare commits
45 Commits
a200e8adfc
..
dev
| Author | SHA1 | Date | |
|---|---|---|---|
| e634746a52 | |||
| d12800e14e | |||
| c1df0d1bba | |||
| 0664952ecd | |||
| 7023b54246 | |||
| aee6aa3c04 | |||
| 680e24c516 | |||
| 6663ee8976 | |||
| 0a5f335981 | |||
| 2901bd9eaf | |||
| 453b3ca55c | |||
| 03c4eb1af1 | |||
| 480a1f7fdc | |||
| cdc9d84a1e | |||
| 50f35bb0c9 | |||
| 4a8c79e83d | |||
| f0afd1a4bb | |||
| de34c3938c | |||
| eb572eff27 | |||
| 2706cf9d5a | |||
| 5fa4752d6e | |||
| aff1793c4e | |||
| 0db159ac89 | |||
| 131d6ef1d1 | |||
| 3ee1ba529f | |||
| 576a2ae737 | |||
| 9b47e1a6e1 | |||
| 20510a937b | |||
| a7c79df339 | |||
| 327bba75d5 | |||
| d1242d2080 | |||
| 0f09551f5d | |||
| 8a5facb5b6 | |||
| 0f7c900c1e | |||
| b008ad9766 | |||
| 56459c164e | |||
| 07a3b2a147 | |||
| b4c571cddb | |||
| 7068b058e8 | |||
| 33b2281b7b | |||
| 1704b61609 | |||
| afccaf6eb5 | |||
| b052d373f1 | |||
| 7462244f01 | |||
| 2b64aca26b |
@@ -0,0 +1,3 @@
|
||||
[submodule "webapp"]
|
||||
path = webapp
|
||||
url = https://git.97id.com/ly/webapp.git
|
||||
@@ -1,7 +1,13 @@
|
||||
JIEBA_DATA=./nltk_data
|
||||
NLTK_DATA=./nltk_data
|
||||
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||
DATA_SOURCE_CACHE=./restapi
|
||||
|
||||
# The Llama Cloud API key.
|
||||
# LLAMA_CLOUD_API_KEY=
|
||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||
|
||||
DASHSCOPE_API_KEY=sk-02c8540e86d84b7ca0e6f4f51bac6e60
|
||||
# The provider for the AI models to use.
|
||||
@@ -79,3 +85,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
|
||||
- You can install any pip package (if it exists) by running a cell with pip install.
|
||||
"
|
||||
|
||||
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
|
||||
@@ -1,7 +1,13 @@
|
||||
JIEBA_DATA=./nltk_data
|
||||
NLTK_DATA=./nltk_data
|
||||
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||
DATA_SOURCE_CACHE=./restapi
|
||||
|
||||
# The Llama Cloud API key.
|
||||
# LLAMA_CLOUD_API_KEY=
|
||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||
|
||||
# The number of similar embeddings to return when retrieving documents.
|
||||
TOP_K=10
|
||||
@@ -110,3 +116,4 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
|
||||
- You can install any pip package (if it exists) by running a cell with pip install.
|
||||
"
|
||||
|
||||
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
|
||||
@@ -1,61 +0,0 @@
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.core.settings import Settings
|
||||
from typing import Dict
|
||||
import os
|
||||
|
||||
DEFAULT_MODEL = "gpt-3.5-turbo"
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large"
|
||||
|
||||
class TSIEmbedding(OpenAIEmbedding):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._query_engine = self._text_engine = self.model_name
|
||||
|
||||
def llm_config_from_env() -> Dict:
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
|
||||
model = os.getenv("MODEL", DEFAULT_MODEL)
|
||||
temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
api_key = os.getenv("T_SYSTEMS_LLMHUB_API_KEY")
|
||||
api_base = os.getenv("T_SYSTEMS_LLMHUB_BASE_URL")
|
||||
|
||||
config = {
|
||||
"model": model,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"temperature": float(temperature),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def embedding_config_from_env() -> Dict:
|
||||
from llama_index.core.constants import DEFAULT_EMBEDDING_DIM
|
||||
|
||||
model = os.getenv("EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
|
||||
dimension = os.getenv("EMBEDDING_DIM", DEFAULT_EMBEDDING_DIM)
|
||||
api_key = os.getenv("T_SYSTEMS_LLMHUB_API_KEY")
|
||||
api_base = os.getenv("T_SYSTEMS_LLMHUB_BASE_URL")
|
||||
|
||||
config = {
|
||||
"model_name": model,
|
||||
"dimension": int(dimension) if dimension is not None else None,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
}
|
||||
return config
|
||||
|
||||
def init_llmhub():
|
||||
from llama_index.llms.openai_like import OpenAILike
|
||||
|
||||
llm_configs = llm_config_from_env()
|
||||
embedding_configs = embedding_config_from_env()
|
||||
|
||||
Settings.embed_model = TSIEmbedding(**embedding_configs)
|
||||
Settings.llm = OpenAILike(
|
||||
**llm_configs,
|
||||
is_chat_model=True,
|
||||
is_function_calling_model=False,
|
||||
context_window=4096,
|
||||
)
|
||||
@@ -1,20 +0,0 @@
|
||||
import os
|
||||
|
||||
import llama_index.core
|
||||
|
||||
def init_observability():
|
||||
|
||||
PHOENIX_API_KEY = os.getenv("PHOENIX_API_KEY")
|
||||
if not PHOENIX_API_KEY:
|
||||
raise ValueError("PHOENIX_API_KEY environment variable is not set")
|
||||
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"api_key={PHOENIX_API_KEY}"
|
||||
PHOENIX_URL = os.getenv("PHOENIX_URL")
|
||||
llama_index.core.set_global_handler(
|
||||
"arize_phoenix", endpoint=PHOENIX_URL, eval_params={}
|
||||
)
|
||||
|
||||
#debugHandle=[]
|
||||
# llama_debug = LlamaDebugHandler(print_trace_on_end=True)
|
||||
# debugHandle.append(llama_debug)
|
||||
# callback_manager = CallbackManager(debugHandle)
|
||||
# settings.Settings.callback_manager = callback_manager
|
||||
@@ -1,235 +0,0 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.llms.xinference import Xinference
|
||||
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
||||
|
||||
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||
|
||||
|
||||
def get_node_postprocessors():
|
||||
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
||||
if rerank_enabled is None or rerank_enabled == 'False':
|
||||
return []
|
||||
|
||||
rerank_model = os.getenv("RERANK_MODEL")
|
||||
rerank_url = os.getenv("RERANK_BASE_URL")
|
||||
rerank_top_n = os.getenv("RERANK_TOP_N")
|
||||
rerank_threshold = os.getenv("RERANK_THRESHOLD")
|
||||
postprocess = None
|
||||
if rerank_model is not None:
|
||||
postprocess = [XinferenceRerank(rerank_model, rerank_url, top_n=rerank_top_n, threshold=rerank_threshold)]
|
||||
return postprocess
|
||||
|
||||
def init_settings():
|
||||
model_provider = os.getenv("MODEL_PROVIDER")
|
||||
match model_provider:
|
||||
case "openai":
|
||||
init_openai()
|
||||
case "dashscope":
|
||||
init_dashscope()
|
||||
case "groq":
|
||||
init_groq()
|
||||
case "ollama":
|
||||
init_ollama()
|
||||
case "anthropic":
|
||||
init_anthropic()
|
||||
case "gemini":
|
||||
init_gemini()
|
||||
case "mistral":
|
||||
init_mistral()
|
||||
case "azure-openai":
|
||||
init_azure_openai()
|
||||
case "t-systems":
|
||||
from .llmhub import init_llmhub
|
||||
init_llmhub()
|
||||
case "xinference":
|
||||
init_xinference()
|
||||
case _:
|
||||
raise ValueError(f"Invalid model provider: {model_provider}")
|
||||
|
||||
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
|
||||
Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
|
||||
|
||||
|
||||
def init_ollama():
|
||||
# from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
# from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
|
||||
#
|
||||
# base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
|
||||
# request_timeout = float(
|
||||
# os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
|
||||
# )
|
||||
# Settings.embed_model = OllamaEmbedding(
|
||||
# base_url=base_url,
|
||||
# model_name=os.getenv("EMBEDDING_MODEL"),
|
||||
# )
|
||||
# Settings.llm = Ollama(
|
||||
# base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
|
||||
# )
|
||||
pass
|
||||
|
||||
def init_xinference():
|
||||
base_url = os.getenv("BASE_URL")
|
||||
model = os.getenv("MODEL")
|
||||
max_tokens = int(os.getenv("LLM_MAX_TOKENS")) if os.getenv("LLM_MAX_TOKENS") is not None else None
|
||||
temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
|
||||
|
||||
Settings.llm = Xinference(model, base_url, temperature, max_tokens)
|
||||
|
||||
embedding_base_url = os.getenv("EMBEDDING_BASE_URL")
|
||||
embedding_base_url = embedding_base_url if embedding_base_url != None and embedding_base_url != "" else base_url
|
||||
|
||||
embed_model_name = os.getenv("EMBEDDING_MODEL")
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
dimensions = int(dimensions) if dimensions is not None else None
|
||||
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions)
|
||||
|
||||
def init_openai():
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
config = {
|
||||
"model": os.getenv("MODEL"),
|
||||
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
Settings.llm = OpenAI(**config)
|
||||
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
config = {
|
||||
"model": os.getenv("EMBEDDING_MODEL"),
|
||||
"dimensions": int(dimensions) if dimensions is not None else None,
|
||||
}
|
||||
Settings.embed_model = OpenAIEmbedding(**config)
|
||||
|
||||
def init_dashscope():
|
||||
from llama_index.llms.dashscope import DashScope,DashScopeGenerationModels
|
||||
from llama_index.embeddings.dashscope import DashScopeEmbedding,DashScopeBatchTextEmbeddingModels,DashScopeTextEmbeddingType,DashScopeTextEmbeddingModels
|
||||
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
config = {
|
||||
"model": os.getenv("MODEL"),
|
||||
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
Settings.llm = llm = DashScope(model_name=DashScopeGenerationModels.QWEN_MAX)
|
||||
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
config = {
|
||||
"model": os.getenv("EMBEDDING_MODEL"),
|
||||
"dimensions": int(dimensions) if dimensions is not None else None,
|
||||
}
|
||||
Settings.embed_model = DashScopeEmbedding(model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
|
||||
text_type=DashScopeTextEmbeddingType.TEXT_TYPE_QUERY)
|
||||
|
||||
|
||||
def init_azure_openai():
|
||||
# from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
# from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
# from llama_index.llms.azure_openai import AzureOpenAI
|
||||
#
|
||||
# llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
|
||||
# embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
|
||||
# max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
# temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
|
||||
# dimensions = os.getenv("EMBEDDING_DIM")
|
||||
#
|
||||
# azure_config = {
|
||||
# "api_key": os.environ["AZURE_OPENAI_KEY"],
|
||||
# "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
|
||||
# "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
|
||||
# or os.getenv("OPENAI_API_VERSION"),
|
||||
# }
|
||||
#
|
||||
# Settings.llm = AzureOpenAI(
|
||||
# model=os.getenv("MODEL"),
|
||||
# max_tokens=int(max_tokens) if max_tokens is not None else None,
|
||||
# temperature=float(temperature),
|
||||
# deployment_name=llm_deployment,
|
||||
# **azure_config,
|
||||
# )
|
||||
#
|
||||
# Settings.embed_model = AzureOpenAIEmbedding(
|
||||
# model=os.getenv("EMBEDDING_MODEL"),
|
||||
# dimensions=int(dimensions) if dimensions is not None else None,
|
||||
# deployment_name=embedding_deployment,
|
||||
# **azure_config,
|
||||
# )
|
||||
pass
|
||||
|
||||
|
||||
def init_fastembed():
|
||||
"""
|
||||
Use Qdrant Fastembed as the local embedding provider.
|
||||
"""
|
||||
# from llama_index.embeddings.fastembed import FastEmbedEmbedding
|
||||
#
|
||||
# embed_model_map: Dict[str, str] = {
|
||||
# # Small and multilingual
|
||||
# "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
# # Large and multilingual
|
||||
# "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", # noqa: E501
|
||||
# }
|
||||
#
|
||||
# # This will download the model automatically if it is not already downloaded
|
||||
# Settings.embed_model = FastEmbedEmbedding(
|
||||
# model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
|
||||
# )
|
||||
pass
|
||||
|
||||
|
||||
def init_groq():
|
||||
# from llama_index.llms.groq import Groq
|
||||
#
|
||||
# model_map: Dict[str, str] = {
|
||||
# "llama3-8b": "llama3-8b-8192",
|
||||
# "llama3-70b": "llama3-70b-8192",
|
||||
# "mixtral-8x7b": "mixtral-8x7b-32768",
|
||||
# }
|
||||
#
|
||||
# Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
|
||||
# # Groq does not provide embeddings, so we use FastEmbed instead
|
||||
# init_fastembed()
|
||||
pass
|
||||
|
||||
|
||||
def init_anthropic():
|
||||
# from llama_index.llms.anthropic import Anthropic
|
||||
#
|
||||
# model_map: Dict[str, str] = {
|
||||
# "claude-3-opus": "claude-3-opus-20240229",
|
||||
# "claude-3-sonnet": "claude-3-sonnet-20240229",
|
||||
# "claude-3-haiku": "claude-3-haiku-20240307",
|
||||
# "claude-2.1": "claude-2.1",
|
||||
# "claude-instant-1.2": "claude-instant-1.2",
|
||||
# }
|
||||
#
|
||||
# Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
|
||||
# # Anthropic does not provide embeddings, so we use FastEmbed instead
|
||||
# init_fastembed()
|
||||
pass
|
||||
|
||||
|
||||
def init_gemini():
|
||||
# from llama_index.embeddings.gemini import GeminiEmbedding
|
||||
# from llama_index.llms.gemini import Gemini
|
||||
#
|
||||
# model_name = f"models/{os.getenv('MODEL')}"
|
||||
# embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
|
||||
#
|
||||
# Settings.llm = Gemini(model=model_name)
|
||||
# Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
|
||||
pass
|
||||
|
||||
def init_mistral():
|
||||
# from llama_index.embeddings.mistralai import MistralAIEmbedding
|
||||
# from llama_index.llms.mistralai import MistralAI
|
||||
#
|
||||
# Settings.llm = MistralAI(model=os.getenv("MODEL"))
|
||||
# Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
|
||||
pass
|
||||
@@ -1,150 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
|
||||
from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore
|
||||
from llama_index.core.llms import MessageRole
|
||||
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
|
||||
|
||||
from app.api.routers.events import EventCallbackHandler
|
||||
from app.api.routers.models import (
|
||||
ChatConfig,
|
||||
ChatData,
|
||||
Message,
|
||||
Result,
|
||||
SourceNodes,
|
||||
)
|
||||
from app.api.routers.vercel_response import VercelStreamResponse
|
||||
from app.api.services.llama_cloud import LLamaCloudFileService
|
||||
from app.engine import get_chat_engine
|
||||
|
||||
chat_router = r = APIRouter()
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def process_response_nodes(
|
||||
nodes: List[NodeWithScore],
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
"""
|
||||
Start background tasks on the source nodes if needed.
|
||||
"""
|
||||
files_to_download = SourceNodes.get_download_files(nodes)
|
||||
for file in files_to_download:
|
||||
background_tasks.add_task(
|
||||
LLamaCloudFileService.download_llamacloud_pipeline_file, file
|
||||
)
|
||||
|
||||
|
||||
# streaming endpoint - delete if not needed
|
||||
@r.post("")
|
||||
async def chat(
|
||||
request: Request,
|
||||
data: ChatData,
|
||||
background_tasks: BackgroundTasks,
|
||||
chat_engine: BaseChatEngine = Depends(get_chat_engine),
|
||||
):
|
||||
try:
|
||||
last_message_content = data.get_last_message_content()
|
||||
# 由于基于历史消息的提示词没有调整好,所以暂时屏蔽历史消息
|
||||
data.messages.clear()
|
||||
messages = data.get_history_messages()
|
||||
|
||||
doc_ids = data.get_chat_document_ids()
|
||||
filters = generate_filters(doc_ids)
|
||||
params = data.data or {}
|
||||
logger.info("Creating chat engine with filters", filters.dict())
|
||||
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||
|
||||
event_handler = EventCallbackHandler()
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
|
||||
response = await chat_engine.astream_chat(last_message_content, messages)
|
||||
process_response_nodes(response.source_nodes, background_tasks)
|
||||
|
||||
return VercelStreamResponse(request, event_handler, response, data)
|
||||
except Exception as e:
|
||||
logger.exception("Error in chat engine", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error in chat engine: {e}",
|
||||
) from e
|
||||
|
||||
|
||||
def generate_filters(doc_ids):
|
||||
if len(doc_ids) > 0:
|
||||
filters = MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="private",
|
||||
value=["true"],
|
||||
operator="nin", # type: ignore
|
||||
),
|
||||
MetadataFilter(
|
||||
key="doc_id",
|
||||
value=doc_ids,
|
||||
operator="in", # type: ignore
|
||||
),
|
||||
],
|
||||
condition="or", # type: ignore
|
||||
)
|
||||
else:
|
||||
filters = MetadataFilters(
|
||||
# Use the "NIN" - "not in" operator to include all public documents (don't have the private key set)
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="private",
|
||||
value=["true"],
|
||||
operator="nin", # type: ignore
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
return filters
|
||||
|
||||
|
||||
# non-streaming endpoint - delete if not needed
|
||||
@r.post("/request")
|
||||
async def chat_request(
|
||||
data: ChatData,
|
||||
chat_engine: BaseChatEngine = Depends(get_chat_engine),
|
||||
) -> Result:
|
||||
last_message_content = data.get_last_message_content()
|
||||
messages = data.get_history_messages()
|
||||
|
||||
response = await chat_engine.achat(last_message_content, messages)
|
||||
return Result(
|
||||
result=Message(role=MessageRole.ASSISTANT, content=response.response),
|
||||
nodes=SourceNodes.from_source_nodes(response.source_nodes),
|
||||
)
|
||||
|
||||
|
||||
@r.get("/config")
|
||||
async def chat_config() -> ChatConfig:
|
||||
starter_questions = None
|
||||
conversation_starters = os.getenv("CONVERSATION_STARTERS")
|
||||
if conversation_starters and conversation_starters.strip():
|
||||
starter_questions = conversation_starters.strip().split("\\n")
|
||||
return ChatConfig(starter_questions=starter_questions)
|
||||
|
||||
|
||||
@r.get("/config/llamacloud")
|
||||
async def chat_llama_cloud_config():
|
||||
projects = LLamaCloudFileService.get_all_projects_with_pipelines()
|
||||
pipeline = os.getenv("LLAMA_CLOUD_INDEX_NAME")
|
||||
project = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
|
||||
pipeline_config = (
|
||||
pipeline
|
||||
and project
|
||||
and {
|
||||
"pipeline": pipeline,
|
||||
"project": project,
|
||||
}
|
||||
or None
|
||||
)
|
||||
return {
|
||||
"projects": projects,
|
||||
"pipeline": pipeline_config,
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import AsyncGenerator, Dict, Any, List, Optional
|
||||
from llama_index.core.callbacks.base import BaseCallbackHandler
|
||||
from llama_index.core.callbacks.schema import CBEventType
|
||||
from llama_index.core.tools.types import ToolOutput
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CallbackEvent(BaseModel):
|
||||
event_type: CBEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
event_id: str = ""
|
||||
|
||||
def get_retrieval_message(self) -> dict | None:
|
||||
if self.payload:
|
||||
nodes = self.payload.get("nodes")
|
||||
if nodes:
|
||||
msg = f"根据查询检索到 {len(nodes)} 源文件"
|
||||
else:
|
||||
msg = f"查询检索中: '{self.payload.get('query_str')}'"
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {"title": msg},
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_tool_message(self) -> dict | None:
|
||||
func_call_args = self.payload.get("function_call")
|
||||
if func_call_args is not None and "tool" in self.payload:
|
||||
tool = self.payload.get("tool")
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {
|
||||
"title": f"调用工具 {tool.name} ,参数: {func_call_args}",
|
||||
},
|
||||
}
|
||||
|
||||
def _is_output_serializable(self, output: Any) -> bool:
|
||||
try:
|
||||
json.dumps(output)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
def get_agent_tool_response(self) -> dict | None:
|
||||
response = self.payload.get("response")
|
||||
if response is not None:
|
||||
sources = response.sources
|
||||
for source in sources:
|
||||
# Return the tool response here to include the toolCall information
|
||||
if isinstance(source, ToolOutput):
|
||||
if self._is_output_serializable(source.raw_output):
|
||||
output = source.raw_output
|
||||
else:
|
||||
output = source.content
|
||||
|
||||
return {
|
||||
"type": "tools",
|
||||
"data": {
|
||||
"toolOutput": {
|
||||
"output": output,
|
||||
"isError": source.is_error,
|
||||
},
|
||||
"toolCall": {
|
||||
"id": None, # There is no tool id in the ToolOutput
|
||||
"name": source.tool_name,
|
||||
"input": source.raw_input,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def to_response(self):
|
||||
try:
|
||||
match self.event_type:
|
||||
case "retrieve":
|
||||
return self.get_retrieval_message()
|
||||
case "function_call":
|
||||
return self.get_tool_message()
|
||||
case "agent_step":
|
||||
return self.get_agent_tool_response()
|
||||
case _:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"转换回应时间时发生错误,原因: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class EventCallbackHandler(BaseCallbackHandler):
|
||||
_aqueue: asyncio.Queue
|
||||
is_done: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
"""Initialize the base callback handler."""
|
||||
ignored_events = [
|
||||
CBEventType.CHUNKING,
|
||||
CBEventType.NODE_PARSING,
|
||||
CBEventType.EMBEDDING,
|
||||
CBEventType.LLM,
|
||||
CBEventType.TEMPLATING,
|
||||
]
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
|
||||
async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
|
||||
while not self._aqueue.empty() or not self.is_done:
|
||||
try:
|
||||
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
@@ -1,253 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from pydantic import BaseModel, Field, validator, field_validator
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class FileContent(BaseModel):
|
||||
type: Literal["text", "ref"]
|
||||
# If the file is pure text then the value is be a string
|
||||
# otherwise, it's a list of document IDs
|
||||
value: str | List[str]
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
id: str
|
||||
content: FileContent
|
||||
filename: str
|
||||
filesize: int
|
||||
filetype: str
|
||||
|
||||
|
||||
class AnnotationFileData(BaseModel):
|
||||
files: List[File] = Field(
|
||||
default=[],
|
||||
description="List of files",
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"csvFiles": [
|
||||
{
|
||||
"content": "Name, Age\nAlice, 25\nBob, 30",
|
||||
"filename": "example.csv",
|
||||
"filesize": 123,
|
||||
"id": "123",
|
||||
"type": "text/csv",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
alias_generator = to_camel
|
||||
|
||||
|
||||
class Annotation(BaseModel):
|
||||
type: str
|
||||
data: AnnotationFileData | List[str]
|
||||
|
||||
def to_content(self) -> str | None:
|
||||
if self.type == "document_file":
|
||||
# We only support generating context content for CSV files for now
|
||||
csv_files = [file for file in self.data.files if file.filetype == "csv"]
|
||||
if len(csv_files) > 0:
|
||||
return "Use data from following CSV raw content\n" + "\n".join(
|
||||
[f"```csv\n{csv_file.content.value}\n```" for csv_file in csv_files]
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"The annotation {self.type} is not supported for generating context content"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: MessageRole
|
||||
content: str
|
||||
annotations: List[Annotation] | None = None
|
||||
|
||||
|
||||
class ChatData(BaseModel):
|
||||
messages: List[Message]
|
||||
data: Any = None
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What standards for letters exist?",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@field_validator("messages")
|
||||
def messages_must_not_be_empty(cls, v):
|
||||
if len(v) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
return v
|
||||
|
||||
def get_last_message_content(self) -> str:
|
||||
"""
|
||||
Get the content of the last message along with the data content if available.
|
||||
Fallback to use data content from previous messages
|
||||
"""
|
||||
if len(self.messages) == 0:
|
||||
raise ValueError("There is not any message in the chat")
|
||||
last_message = self.messages[-1]
|
||||
message_content = last_message.content
|
||||
for message in reversed(self.messages):
|
||||
if message.role == MessageRole.USER and message.annotations is not None:
|
||||
annotation_contents = filter(
|
||||
None,
|
||||
[annotation.to_content() for annotation in message.annotations],
|
||||
)
|
||||
if not annotation_contents:
|
||||
continue
|
||||
annotation_text = "\n".join(annotation_contents)
|
||||
message_content = f"{message_content}\n{annotation_text}"
|
||||
break
|
||||
return message_content
|
||||
|
||||
def get_history_messages(self) -> List[ChatMessage]:
|
||||
"""
|
||||
Get the history messages
|
||||
"""
|
||||
return [
|
||||
ChatMessage(role=message.role, content=message.content)
|
||||
for message in self.messages[:-1]
|
||||
]
|
||||
|
||||
def is_last_message_from_user(self) -> bool:
|
||||
return self.messages[-1].role == MessageRole.USER
|
||||
|
||||
def get_chat_document_ids(self) -> List[str]:
|
||||
"""
|
||||
Get the document IDs from the chat messages
|
||||
"""
|
||||
document_ids: List[str] = []
|
||||
for message in self.messages:
|
||||
if message.role == MessageRole.USER and message.annotations is not None:
|
||||
for annotation in message.annotations:
|
||||
if (
|
||||
annotation.type == "document_file"
|
||||
and annotation.data.files is not None
|
||||
):
|
||||
for fi in annotation.data.files:
|
||||
if fi.content.type == "ref":
|
||||
document_ids += fi.content.value
|
||||
return list(set(document_ids))
|
||||
|
||||
|
||||
class LlamaCloudFile(BaseModel):
|
||||
file_name: str
|
||||
pipeline_id: str
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, LlamaCloudFile):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.file_name == other.file_name and self.pipeline_id == other.pipeline_id
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.file_name, self.pipeline_id))
|
||||
|
||||
|
||||
class SourceNodes(BaseModel):
|
||||
id: str
|
||||
metadata: Dict[str, Any]
|
||||
score: Optional[float]
|
||||
text: str
|
||||
url: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def from_source_node(cls, source_node: NodeWithScore):
|
||||
metadata = source_node.node.metadata
|
||||
url = cls.get_url_from_metadata(metadata)
|
||||
#text = 'filename' in metadata and metadata['filename'] or source_node.node.node_id
|
||||
text = source_node.node.text
|
||||
return cls(
|
||||
id=source_node.node.node_id,
|
||||
metadata=metadata,
|
||||
score=source_node.score,
|
||||
text=text, # type: ignore
|
||||
url=url,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_url_from_metadata(cls, metadata: Dict[str, Any]) -> str:
|
||||
url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if not url_prefix:
|
||||
logger.warning(
|
||||
"Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server"
|
||||
)
|
||||
file_name = metadata.get("file_name")
|
||||
if file_name and url_prefix:
|
||||
# file_name exists and file server is configured
|
||||
pipeline_id = metadata.get("pipeline_id")
|
||||
if pipeline_id and metadata.get("private") is None:
|
||||
# file is from LlamaCloud and was not ingested locally
|
||||
file_name = f"{pipeline_id}${file_name}"
|
||||
return f"{url_prefix}/output/llamacloud/{file_name}"
|
||||
is_private = metadata.get("private", "false") == "true"
|
||||
if is_private:
|
||||
return f"{url_prefix}/output/uploaded/{file_name}"
|
||||
return f"{url_prefix}/data/{file_name}"
|
||||
else:
|
||||
# fallback to URL in metadata (e.g. for websites)
|
||||
return metadata.get("URL")
|
||||
|
||||
@classmethod
|
||||
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
|
||||
return [cls.from_source_node(node) for node in source_nodes]
|
||||
|
||||
@staticmethod
|
||||
def get_download_files(nodes: List[NodeWithScore]) -> Set[LlamaCloudFile]:
|
||||
source_nodes = SourceNodes.from_source_nodes(nodes)
|
||||
llama_cloud_files = [
|
||||
LlamaCloudFile(
|
||||
file_name=node.metadata.get("file_name"),
|
||||
pipeline_id=node.metadata.get("pipeline_id"),
|
||||
)
|
||||
for node in source_nodes
|
||||
if (
|
||||
node.metadata.get("private")
|
||||
is None # Only download files are from LlamaCloud and were not ingested locally
|
||||
and node.metadata.get("pipeline_id") is not None
|
||||
and node.metadata.get("file_name") is not None
|
||||
)
|
||||
]
|
||||
# Remove duplicates and return
|
||||
return set(llama_cloud_files)
|
||||
|
||||
|
||||
class Result(BaseModel):
|
||||
result: Message
|
||||
nodes: List[SourceNodes]
|
||||
|
||||
|
||||
class ChatConfig(BaseModel):
|
||||
starter_questions: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="List of starter questions",
|
||||
serialization_alias="starterQuestions",
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"starterQuestions": [
|
||||
"What standards for letters exist?",
|
||||
"What are the requirements for a letter to be considered a letter?",
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.services.file import PrivateFileService
|
||||
|
||||
file_upload_router = r = APIRouter()
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class FileUploadRequest(BaseModel):
|
||||
base64: str
|
||||
|
||||
|
||||
@r.post("")
|
||||
def upload_file(request: FileUploadRequest) -> List[str]:
|
||||
try:
|
||||
logger.info("Processing file")
|
||||
return PrivateFileService.process_file(request.base64)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Error processing file")
|
||||
@@ -1,109 +0,0 @@
|
||||
import json
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||
|
||||
from app.api.routers.events import EventCallbackHandler
|
||||
from app.api.routers.models import ChatData, Message, SourceNodes
|
||||
from app.api.services.suggestion import NextQuestionSuggestion
|
||||
|
||||
|
||||
class VercelStreamResponse(StreamingResponse):
|
||||
"""
|
||||
Class to convert the response from the chat engine to the streaming format expected by Vercel
|
||||
"""
|
||||
|
||||
TEXT_PREFIX = "0:"
|
||||
DATA_PREFIX = "8:"
|
||||
|
||||
@classmethod
|
||||
def convert_text(cls, token: str):
|
||||
# Escape newlines and double quotes to avoid breaking the stream
|
||||
token = json.dumps(token)
|
||||
return f"{cls.TEXT_PREFIX}{token}\n"
|
||||
|
||||
@classmethod
|
||||
def convert_data(cls, data: dict):
|
||||
data_str = json.dumps(data)
|
||||
return f"{cls.DATA_PREFIX}[{data_str}]\n"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request: Request,
|
||||
event_handler: EventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
chat_data: ChatData,
|
||||
):
|
||||
content = VercelStreamResponse.content_generator(
|
||||
request, event_handler, response, chat_data
|
||||
)
|
||||
super().__init__(content=content)
|
||||
|
||||
@classmethod
|
||||
async def content_generator(
|
||||
cls,
|
||||
request: Request,
|
||||
event_handler: EventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
chat_data: ChatData,
|
||||
):
|
||||
# Yield the text response
|
||||
async def _chat_response_generator():
|
||||
final_response = ""
|
||||
async for token in response.async_response_gen():
|
||||
final_response += token
|
||||
yield VercelStreamResponse.convert_text(token)
|
||||
|
||||
# Generate questions that user might interested to
|
||||
conversation = chat_data.messages + [
|
||||
Message(role="assistant", content=final_response)
|
||||
]
|
||||
questions = await NextQuestionSuggestion.suggest_next_questions(
|
||||
conversation
|
||||
)
|
||||
if len(questions) > 0:
|
||||
yield VercelStreamResponse.convert_data(
|
||||
{
|
||||
"type": "suggested_questions",
|
||||
"data": questions,
|
||||
}
|
||||
)
|
||||
|
||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||
event_handler.is_done = True
|
||||
|
||||
# Yield the source nodes
|
||||
yield cls.convert_data(
|
||||
{
|
||||
"type": "sources",
|
||||
"data": {
|
||||
"nodes": [
|
||||
SourceNodes.from_source_node(node).dict()
|
||||
for node in response.source_nodes
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
async for event in event_handler.async_event_gen():
|
||||
event_response = event.to_response()
|
||||
if event_response is not None:
|
||||
yield VercelStreamResponse.convert_data(event_response)
|
||||
|
||||
combine = stream.merge(_chat_response_generator(), _event_generator())
|
||||
is_stream_started = False
|
||||
async with combine.stream() as streamer:
|
||||
async for output in streamer:
|
||||
if not is_stream_started:
|
||||
is_stream_started = True
|
||||
# Stream a blank message to start the stream
|
||||
yield VercelStreamResponse.convert_text("")
|
||||
|
||||
yield output
|
||||
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
@@ -0,0 +1,490 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||
from collections import deque
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core import BaseCallbackHandler
|
||||
from llama_index.core.base.llms.types import ChatMessage
|
||||
from llama_index.core.callbacks import CBEventType
|
||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||
from llama_index.core.tools import ToolOutput
|
||||
from pydantic import BaseModel
|
||||
from app.api.routers.request.base import userMng, conversations,message,parameter,feedback
|
||||
from app.api.routers.request.baseConfig import *
|
||||
from app.api.routers.request.models import ChatRequestData,ChatFileUploadRequest
|
||||
from app.engine import get_chat_engine
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
api_router = r = APIRouter()
|
||||
v1_router = v = APIRouter()
|
||||
|
||||
class ChatCallbackEvent(BaseModel):
|
||||
event_type: ChatEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
|
||||
def get_common_param(self)-> dict:
|
||||
return {
|
||||
'event': self.event_type.name,
|
||||
'conversation_id':self.payload.get("conversation_id"),
|
||||
'message_id': self.payload.get("message_id"),
|
||||
'created_at': int(time.time()),
|
||||
'task_id': self.payload.get("task_id")
|
||||
}
|
||||
|
||||
def get_WorkflowStart_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('workflow_run_id'),
|
||||
"workflow_id": self.payload.get('workflow_id'),
|
||||
"sequence_number": 1709,
|
||||
"inputs": {
|
||||
"sys.query": self.payload.get('query'),
|
||||
"sys.files": [],
|
||||
"sys.conversation_id": self.payload.get('conversation_id'),
|
||||
"sys.user_id": self.payload.get('use_id')
|
||||
},
|
||||
"created_at": int(time.time())
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def get_WorkflowFinished_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('workflow_run_id'),
|
||||
"workflow_id": self.payload.get('workflow_id'),
|
||||
"sequence_number": 1709,
|
||||
"status": "succeeded",
|
||||
"outputs": {
|
||||
"answer": self.payload.get('response')
|
||||
},
|
||||
"error": '',
|
||||
"elapsed_time": 36.03764106379822,
|
||||
"total_tokens": 11707,
|
||||
"total_steps": 10,
|
||||
"created_by": {
|
||||
"id": str(uuid.uuid4()),
|
||||
"user": self.payload.get('use_id')
|
||||
},
|
||||
"created_at": int(time.time()),
|
||||
"finished_at": int(time.time()),
|
||||
"files": []
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def get_NodeStart_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('nodeid'),
|
||||
"node_id": self.payload.get('nodeid'),
|
||||
"node_type": "http-request",
|
||||
"title": self.payload.get('title'),
|
||||
"index": self.payload.get('index'),
|
||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
||||
"inputs": '',
|
||||
"created_at": 1724398751,
|
||||
"extras": {}
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def get_NodeFinished_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'workflow_run_id':self.payload.get('workflow_run_id'),
|
||||
'data':{
|
||||
"id": self.payload.get('nodeid'),
|
||||
"node_id": self.payload.get('nodeid'),
|
||||
"node_type": "http-request",
|
||||
"title": self.payload.get('title'),
|
||||
"index": self.payload.get('index'),
|
||||
"predecessor_node_id": self.payload.get('predecessor_node_id'),
|
||||
"inputs": '',
|
||||
"process_data": '',
|
||||
"outputs": '',
|
||||
"status": "succeeded",
|
||||
"error": '',
|
||||
"elapsed_time": 0.10402441816404462,
|
||||
"execution_metadata": '',
|
||||
"created_at": 1724398751,
|
||||
"finished_at": 1724398751,
|
||||
"files": []
|
||||
}
|
||||
})
|
||||
return params
|
||||
|
||||
def get_Message_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'id':self.payload.get('message_id'),
|
||||
'answer':self.payload.get('answer')
|
||||
})
|
||||
return params
|
||||
|
||||
def get_MessageEnd_param(self) -> dict:
|
||||
params = self.get_common_param()
|
||||
params.update({
|
||||
'id':self.payload.get('message_id'),
|
||||
'metadata':self.payload.get('metadata')
|
||||
})
|
||||
return params
|
||||
|
||||
def to_response(self)-> dict|None:
|
||||
try:
|
||||
match self.event_type:
|
||||
case "workflow_started":
|
||||
return self.get_WorkflowStart_param()
|
||||
case "workflow_finished":
|
||||
return self.get_WorkflowFinished_param()
|
||||
case "node_started":
|
||||
return self.get_NodeStart_param()
|
||||
case 'node_finished':
|
||||
return self.get_NodeFinished_param()
|
||||
case 'message':
|
||||
return self.get_Message_param()
|
||||
case 'message_end':
|
||||
return self.get_MessageEnd_param()
|
||||
case _:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"转换回应时间时发生错误,原因: {e}")
|
||||
return None
|
||||
|
||||
class ChatEventCallbackHandler(BaseCallbackHandler):
|
||||
_aqueue: asyncio.Queue
|
||||
is_done: bool = False
|
||||
|
||||
def __init__(self,**params):
|
||||
"""Initialize the base callback handler."""
|
||||
ignored_events = [
|
||||
# CBEventType.CHUNKING,
|
||||
# CBEventType.NODE_PARSING,
|
||||
# CBEventType.EMBEDDING,
|
||||
# CBEventType.LLM,
|
||||
# CBEventType.TEMPLATING,
|
||||
]
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
self._response:str = ''
|
||||
self._params:Dict[str,Any] = params
|
||||
self._nodeStack:deque = deque()
|
||||
|
||||
#添加工作流开始事件
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'use_id': data.user,
|
||||
'query': data.query,
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_START,payload = args)
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
logger.info("event_start:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||
|
||||
self._nodeStack.append(event_id)
|
||||
nindex = self._nodeStack.count() - 1
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'nodeid':event_id,
|
||||
'title':event_type.name,
|
||||
'index':nindex + 1,
|
||||
'predecessor_node_id': self._nodeStack[nindex - 1] if nindex > 0 else ''
|
||||
}
|
||||
)
|
||||
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_START,payload = args)
|
||||
if nd_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(nd_event)
|
||||
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
logger.info("event_end:{} type:{} payload:{}\n".format(event_id, event_type, payload))
|
||||
|
||||
#self.response = payload.get("response","")
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
nodeID = self._nodeStack[-1]
|
||||
if nodeID == event_id:
|
||||
nindex = self._nodeStack.count() - 1
|
||||
args.update(
|
||||
{
|
||||
'nodeid':event_id,
|
||||
'title':event_type.name,
|
||||
'index':nindex + 1,
|
||||
'predecessor_node_id':self._nodeStack[nindex - 1] if nindex > 0 else ''
|
||||
}
|
||||
)
|
||||
nd_event = ChatCallbackEvent(event_type = ChatEventType.NODE_FINISHED,payload = args)
|
||||
if nd_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(nd_event)
|
||||
self._nodeStack.pop()
|
||||
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_start:{}\n".format(trace_id))
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
logger.info("trace_end:{} trace_map:{}\n".format(trace_id, trace_map))
|
||||
data:ChatRequestData = self._params['data']
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
args.update(
|
||||
{
|
||||
'response':self._response,
|
||||
'conversation_id': data.conversation_id
|
||||
}
|
||||
)
|
||||
wf_event = ChatCallbackEvent(event_type = ChatEventType.WORKFLOW_FINISHED,payload = args)
|
||||
if wf_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(wf_event)
|
||||
|
||||
|
||||
args:Dict[str,Any] = self._params['ids']
|
||||
msgEnt_event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE_END,payload = args)
|
||||
if msgEnt_event.to_response() is not None:
|
||||
self._aqueue.put_nowait(msgEnt_event)
|
||||
|
||||
async def async_event_gen(self) -> AsyncGenerator[ChatCallbackEvent, None]:
|
||||
while not self._aqueue.empty() or not self.is_done:
|
||||
try:
|
||||
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
class IDManager:
|
||||
def createID(self):
|
||||
return {
|
||||
"message_id" : str(uuid.uuid4()),
|
||||
'task_id':str(uuid.uuid4()),
|
||||
'workflow_run_id': str(uuid.uuid4()),
|
||||
"workflow_id": str(uuid.uuid4())
|
||||
}
|
||||
|
||||
class ChatStreamResponse(StreamingResponse):
|
||||
TEXT_PREFIX = "data: "
|
||||
DATA_PREFIX = "data: "
|
||||
ids:Dict[str,Any] = {}
|
||||
data:ChatRequestData = None
|
||||
|
||||
@classmethod
|
||||
def convert_Message(cls, token: str):
|
||||
params = cls.ids
|
||||
params.update({
|
||||
'answer':token,
|
||||
'conversation_id':cls.data.conversation_id
|
||||
})
|
||||
event = ChatCallbackEvent(event_type = ChatEventType.MESSAGE,payload = params)
|
||||
data_str = json.dumps(event.to_response())
|
||||
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||
|
||||
@classmethod
|
||||
def convert_Event(cls, data: dict):
|
||||
data_str = json.dumps(data)
|
||||
return f"{cls.DATA_PREFIX}{data_str}\n\n"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request: Request,
|
||||
event_handler: ChatEventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
data: ChatRequestData,
|
||||
ids:Dict[str,Any]
|
||||
):
|
||||
ChatStreamResponse.ids = ids
|
||||
ChatStreamResponse.data = data
|
||||
content = ChatStreamResponse.content_generator(
|
||||
request, event_handler, response, data
|
||||
)
|
||||
super().__init__(content=content)
|
||||
|
||||
@classmethod
|
||||
async def content_generator(
|
||||
cls,
|
||||
request: Request,
|
||||
event_handler: ChatEventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
data: ChatRequestData
|
||||
):
|
||||
|
||||
# Yield the text response
|
||||
async def _chat_response_generator():
|
||||
final_response = ""
|
||||
async for token in response.async_response_gen():
|
||||
final_response += token
|
||||
yield ChatStreamResponse.convert_Message(token)
|
||||
|
||||
# 存储消息历史
|
||||
message().add(user_id=data.user,conversation_id=data.conversation_id,query=data.query,answer=final_response)
|
||||
|
||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||
event_handler.is_done = True
|
||||
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
async for event in event_handler.async_event_gen():
|
||||
event_response = event.to_response()
|
||||
if event_response is not None:
|
||||
yield ChatStreamResponse.convert_Event(event_response)
|
||||
|
||||
combine = stream.merge(_chat_response_generator(), _event_generator())
|
||||
is_stream_started = False
|
||||
async with combine.stream() as streamer:
|
||||
async for output in streamer:
|
||||
if not is_stream_started:
|
||||
is_stream_started = True
|
||||
|
||||
yield output
|
||||
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
@v.post("/chat-messages")
|
||||
async def post_conversations(request: Request, data: ChatRequestData):
|
||||
userMng.findNoExistCreate(data.user)
|
||||
data.conversation_id = data.conversation_id if data.conversation_id else str(uuid.uuid4())
|
||||
|
||||
conversaObj = conversations()
|
||||
conversationinfo = conversaObj.get(data.conversation_id)
|
||||
if conversationinfo is None:
|
||||
conversationinfo = conversaObj.add(data.conversation_id, data.user, "新建会话")
|
||||
|
||||
# 生成聊天参数
|
||||
last_message_content = ChatMessage.from_str(data.query)
|
||||
filters = None
|
||||
params = data.inputs or {}
|
||||
|
||||
# 获取聊天引擎对象
|
||||
chat_engine = get_chat_engine(filters=filters, params=params)
|
||||
|
||||
# 启动聊天事件监听
|
||||
ids = IDManager().createID()
|
||||
event_handler = ChatEventCallbackHandler(ids = ids,data = data)
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
|
||||
# 执行异步聊天
|
||||
response = await chat_engine.astream_chat(data.query)
|
||||
|
||||
# 返回异步消息回应
|
||||
return ChatStreamResponse(request, event_handler, response, data,ids)
|
||||
|
||||
@v.get("/messages")
|
||||
async def query_messages(user:str, conversation_id:str):
|
||||
#conversation_id = default_conversation_id if conversation_id is None else conversation_id
|
||||
datas = []
|
||||
records = message().gets(user,conversation_id)
|
||||
if records is None:
|
||||
return {
|
||||
"limit": 20,
|
||||
"has_more": False,
|
||||
"data": []
|
||||
}
|
||||
|
||||
for record in records:
|
||||
res = record.dict()
|
||||
feeds = feedback().query(res['id'])
|
||||
res["message_files"] = []
|
||||
res["feedback"] = {'rating':feeds['rating'] } if feeds != None else ''
|
||||
res["retriever_resources"] = []
|
||||
res["created_at"] = 1723444905
|
||||
res["agent_thoughts"] = []
|
||||
res["status"] = "normal"
|
||||
res["error"] = ''
|
||||
datas.append(res)
|
||||
|
||||
return {
|
||||
"limit": 20,
|
||||
"has_more": False,
|
||||
"data": datas
|
||||
}
|
||||
|
||||
@v.post("/conversations/{itemid}/name")
|
||||
async def post_conversations(request: Request,itemid:str,params:Dict[str,Any]):
|
||||
consaObj = conversations()
|
||||
consaObj.rename(itemid,'知识问答')
|
||||
cond = {
|
||||
'id':itemid,
|
||||
'user_id':params['user']
|
||||
}
|
||||
results = consaObj.query(**cond)
|
||||
if len(results) > 0:
|
||||
res = results[0]
|
||||
return {
|
||||
"id": res['id'],
|
||||
"name": res['name'],
|
||||
"inputs": res['inputs'],
|
||||
"status": res['status'],
|
||||
"introduction": res['introduction'],
|
||||
"created_at": res['created_at'],
|
||||
#"工程位置"
|
||||
}
|
||||
return 'null'
|
||||
|
||||
@v.get("/conversations")
|
||||
async def query_conversations(user:str, first_id:str = None, limit:str = None, pinned:str = None):
|
||||
user_id = '' if user is None else user
|
||||
userMng.findNoExistCreate(user_id)
|
||||
|
||||
return {
|
||||
"limit": 20,
|
||||
"has_more": False,
|
||||
"data": conversations().gets(user_id)
|
||||
}
|
||||
|
||||
@v.get("/parameters")
|
||||
async def query_parameters(user:str):
|
||||
params = parameter().get(user)
|
||||
if len(params) == 0:
|
||||
params = BaseConfig().ParamterCfg()
|
||||
return params
|
||||
|
||||
@v.post("/messages/{message_id}/feedbacks")
|
||||
async def post_feedbacks(request: Request,message_id:str,params:Dict[str,Any]):
|
||||
if params['rating'] =='null':
|
||||
feedback().delete(message_id)
|
||||
else:
|
||||
condition = {'id':message_id}
|
||||
results = message().query(**condition)
|
||||
if len(results) > 0:
|
||||
result = results[0]
|
||||
feedback().add(message_id=message_id,query=result['query'],
|
||||
answer=result['answer'],rating=params['rating'])
|
||||
|
||||
@r.post("")
|
||||
def upload_file(request: ChatFileUploadRequest) -> List[str]:
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from app.api.routers.request.baseConfig import BaseConfig
|
||||
from app.api.routers.request.dbOrm import DBManager
|
||||
|
||||
dbManage = DBManager()
|
||||
|
||||
class conversations:
|
||||
def __init__(self) -> None:
|
||||
self._tableName = 'conversations'
|
||||
dbManage.createTable(self._tableName)
|
||||
|
||||
def gets(self,user_id:str):
|
||||
records = dbManage.query(self._tableName,user_id = user_id)
|
||||
datas = []
|
||||
for record in records:
|
||||
datas.append(record)
|
||||
|
||||
return datas
|
||||
|
||||
def get(self, id:str):
|
||||
records = dbManage.query(self._tableName, id=id)
|
||||
if len(records) >0:
|
||||
return records[0]
|
||||
return None
|
||||
|
||||
def add(self,id:str, user_id:str, name:str):
|
||||
template = BaseConfig().ConversationCfg()
|
||||
template['id'] = id
|
||||
template['user_id'] = user_id
|
||||
template['name'] = name
|
||||
template['created_at'] = 1724399038
|
||||
dbManage.addRecord(self._tableName,template)
|
||||
|
||||
def delete(self,id:str):
|
||||
dbManage.delete(self._tableName,id=id)
|
||||
|
||||
def rename(self,id:str,name:str):
|
||||
data = {'name':name}
|
||||
dbManage.update(self._tableName,data,id=id)
|
||||
|
||||
def query(self,**condition):
|
||||
results = []
|
||||
records = dbManage.query(self._tableName,**condition)
|
||||
for record in records:
|
||||
results.append(record.dict())
|
||||
return results
|
||||
|
||||
class user:
|
||||
def __init__(self) -> None:
|
||||
self._tableName = 'user'
|
||||
dbManage.createTable(self._tableName)
|
||||
|
||||
def gets(self):
|
||||
return dbManage.query(self._tableName)
|
||||
|
||||
def get(self,id:str):
|
||||
return dbManage.query(self._tableName,id = id)
|
||||
|
||||
def add(self,id:str):
|
||||
info = {
|
||||
'id':id,
|
||||
'createtime': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
dbManage.addRecord(self._tableName,info)
|
||||
|
||||
def delete(self,id:str):
|
||||
dbManage.delete(self._tableName,id = id)
|
||||
|
||||
class userMng:
|
||||
userObj = user()
|
||||
@classmethod
|
||||
def findNoExistCreate(cls,user_id:str):
|
||||
userInfo = cls.userObj.get(user_id)
|
||||
if len(userInfo) == 0:
|
||||
cls.userObj.add(user_id)
|
||||
|
||||
def remove(cls,user_id:str):
|
||||
cls.userObj.delete(user_id)
|
||||
|
||||
class parameter:
|
||||
def __init__(self) -> None:
|
||||
self._tableName = 'parameters'
|
||||
dbManage.createTable(self._tableName)
|
||||
|
||||
def get(self,user_id:str):
|
||||
records = dbManage.query(self._tableName,user_id = user_id)
|
||||
data = {}
|
||||
for record in records:
|
||||
key = record['name']
|
||||
value = record['value']
|
||||
data[key] = value
|
||||
return data
|
||||
|
||||
def set(self,user_id:str):
|
||||
dbManage.addRecord(self._tableName,{})
|
||||
|
||||
def delete(self,user_id:str):
|
||||
dbManage.delete(self._tableName,user_id = user_id)
|
||||
|
||||
class message:
|
||||
def __init__(self) -> None:
|
||||
self._tableName = 'messages'
|
||||
dbManage.createTable(self._tableName)
|
||||
|
||||
def gets(self,user_id:str,conversation_id:str):
|
||||
records = dbManage.query(self._tableName,user_id = user_id,conversation_id = conversation_id)
|
||||
datas = []
|
||||
for record in records:
|
||||
datas.append(record)
|
||||
return datas
|
||||
|
||||
def add(self,user_id:str,conversation_id:str,query:str,answer:str):
|
||||
template = BaseConfig.MessageCfg()
|
||||
template['id'] = str(uuid.uuid4())
|
||||
template['user_id'] = user_id
|
||||
template['conversation_id'] = conversation_id
|
||||
template['query'] = query
|
||||
template['answer'] = answer
|
||||
dbManage.addRecord(self._tableName,template)
|
||||
|
||||
def delete(self,user_id:str):
|
||||
dbManage.delete(self._tableName,user_id = user_id)
|
||||
|
||||
def query(self,**condition):
|
||||
results = []
|
||||
records = dbManage.query(self._tableName,**condition)
|
||||
for record in records:
|
||||
results.append(record.dict())
|
||||
return results
|
||||
|
||||
class feedback:
|
||||
def __init__(self) -> None:
|
||||
self._tableName = 'feedbacks'
|
||||
dbManage.createTable(self._tableName)
|
||||
|
||||
def add(self,message_id:str,query:str,answer:str,rating:str):
|
||||
record = {
|
||||
'message_id': message_id,
|
||||
'query': query,
|
||||
'answer': answer,
|
||||
'rating': rating,
|
||||
}
|
||||
dbManage.addRecord(self._tableName,record)
|
||||
|
||||
def delete(self,message_id:str):
|
||||
cond = {'message_id':message_id}
|
||||
dbManage.delete(self._tableName,**cond)
|
||||
|
||||
def query(self,message_id:str):
|
||||
cond = {'message_id':message_id}
|
||||
records = dbManage.query(self._tableName,**cond)
|
||||
if len(records) > 0:
|
||||
return records[0].dict()
|
||||
return None
|
||||
@@ -0,0 +1,80 @@
|
||||
from pydantic import BaseModel
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
projectInfo:str = os.getenv("PROJECT_TITLE","您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!")
|
||||
|
||||
def ParamterCfg(self):
|
||||
questions = os.getenv("CONVERSATION_STARTERS", "dev")
|
||||
return{
|
||||
"opening_statement": self.projectInfo,
|
||||
"suggested_questions": questions.split('\n'),
|
||||
"suggested_questions_after_answer": {
|
||||
"enabled": False
|
||||
},
|
||||
"speech_to_text": {
|
||||
"enabled": False
|
||||
},
|
||||
"text_to_speech": {
|
||||
"enabled": False,
|
||||
"language": "",
|
||||
"voice": ""
|
||||
},
|
||||
"retriever_resource": {
|
||||
"enabled": True
|
||||
},
|
||||
"annotation_reply": {
|
||||
"enabled": False
|
||||
},
|
||||
"more_like_this": {
|
||||
"enabled": False
|
||||
},
|
||||
"user_input_form": [],
|
||||
"sensitive_word_avoidance": {
|
||||
"enabled": False
|
||||
},
|
||||
"file_upload": {
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"transfer_methods": [
|
||||
"remote_url"
|
||||
]
|
||||
}
|
||||
},
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": "10"
|
||||
}
|
||||
}
|
||||
|
||||
def ConversationCfg(self):
|
||||
return{
|
||||
"id": "",
|
||||
'user_id':'',
|
||||
"name": "",
|
||||
"inputs": {},
|
||||
"status": "normal",
|
||||
"introduction": self.projectInfo,
|
||||
"created_at":''
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def MessageCfg(cls):
|
||||
return {
|
||||
"id": "",
|
||||
'user_id':'',
|
||||
"conversation_id": "",
|
||||
"inputs": {},
|
||||
"query": "",
|
||||
"answer": ""
|
||||
}
|
||||
|
||||
|
||||
class ChatEventType(str, Enum):
|
||||
WORKFLOW_START = "workflow_started"
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_START = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
MESSAGE = "message"
|
||||
MESSAGE_END = "message_end"
|
||||
@@ -0,0 +1,220 @@
|
||||
import os
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import create_engine, Column, String, Integer, JSON,Float
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
#orm类
|
||||
class ConversationOrm(Base):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
name = Column(String)
|
||||
inputs = Column(JSON)
|
||||
status = Column(String)
|
||||
introduction = Column(String)
|
||||
created_at = Column(Integer)
|
||||
|
||||
def update(self,data:Dict[str,Any]):
|
||||
if 'name' in data:
|
||||
self.name = data['name']
|
||||
|
||||
class UserOrm(Base):
|
||||
__tablename__ = "user"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
createtime = Column(String)
|
||||
|
||||
class ParametersOrm(Base):
|
||||
__tablename__ = "parameters"
|
||||
|
||||
user_id = Column(String,primary_key=True)
|
||||
name = Column(String)
|
||||
value = Column(JSON)
|
||||
|
||||
class MessagesOrm(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
id = Column(String,primary_key=True)
|
||||
user_id = Column(String)
|
||||
conversation_id = Column(String)
|
||||
inputs = Column(JSON)
|
||||
query = Column(String)
|
||||
answer = Column(String)
|
||||
|
||||
class FeedBackOrm(Base):
|
||||
__tablename__ = "feedbacks"
|
||||
|
||||
message_id = Column(String,primary_key=True)
|
||||
query = Column(String)
|
||||
answer = Column(String)
|
||||
rating = Column(String)
|
||||
|
||||
#数据结构
|
||||
class ConversationModel(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
inputs: Dict[str, Any]
|
||||
status: str
|
||||
introduction: str
|
||||
created_at: int
|
||||
|
||||
class Config:
|
||||
from_attributes=True
|
||||
|
||||
@classmethod
|
||||
def orm(cls):
|
||||
return ConversationOrm
|
||||
|
||||
class UserModel(BaseModel):
|
||||
id: str
|
||||
createtime: str
|
||||
|
||||
class Config:
|
||||
from_attributes=True
|
||||
|
||||
@classmethod
|
||||
def orm(cls):
|
||||
return UserOrm
|
||||
|
||||
class ParametersModel(BaseModel):
|
||||
user_id : str
|
||||
name : str
|
||||
value : Dict[str, Any]
|
||||
|
||||
class Config:
|
||||
from_attributes=True
|
||||
|
||||
@classmethod
|
||||
def orm(cls):
|
||||
return ParametersOrm
|
||||
|
||||
class MessagesModel(BaseModel):
|
||||
id :str
|
||||
conversation_id :str
|
||||
inputs : Dict[str, Any]
|
||||
query : str
|
||||
answer : str
|
||||
|
||||
class Config:
|
||||
from_attributes=True
|
||||
|
||||
@classmethod
|
||||
def orm(cls):
|
||||
return MessagesOrm
|
||||
|
||||
class FeedBackModel(BaseModel):
|
||||
message_id :str
|
||||
query :str
|
||||
answer :str
|
||||
rating :str
|
||||
|
||||
class Config:
|
||||
from_attributes=True
|
||||
|
||||
@classmethod
|
||||
def orm(cls):
|
||||
return FeedBackOrm
|
||||
|
||||
class DBManager:
|
||||
def __init__(self) -> None:
|
||||
DATABASE_URL = os.getenv("SQLITE_DATABASE_URL")
|
||||
self._engine = create_engine(DATABASE_URL)
|
||||
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)
|
||||
|
||||
def createTable(self,tableName:str):
|
||||
if self._engine is None:
|
||||
return
|
||||
if not self.exist(tableName):
|
||||
Base.metadata.tables[tableName].create(self._engine)
|
||||
|
||||
def addRecord(self,tableName:str,record:Dict[str,Any]):
|
||||
ormCls = self._get_orm(tableName)
|
||||
if ormCls is None:
|
||||
return
|
||||
session = self.SessionLocal()
|
||||
data = ormCls(**record)
|
||||
session.add(data)
|
||||
session.commit()
|
||||
|
||||
def addRecords(self,tableName:str,records:List[Dict[str,Any]]):
|
||||
ormCls = self._get_orm(tableName)
|
||||
if ormCls is None:
|
||||
return
|
||||
datas = []
|
||||
session = self.SessionLocal()
|
||||
for record in records:
|
||||
datas.append(ormCls(**record))
|
||||
session.add(datas)
|
||||
session.commit()
|
||||
|
||||
def delete(self,tableName:str,**filter):
|
||||
session = self.SessionLocal()
|
||||
ormCls = self._get_orm(tableName)
|
||||
if ormCls is None:
|
||||
return
|
||||
records = session.query(ormCls).filter_by(**filter).all()
|
||||
if records is not None:
|
||||
session.delete(records)
|
||||
session.commit()
|
||||
|
||||
def update(self,tableName:str,data:Dict[str,Any],**filter):
|
||||
if not self.exist(tableName):
|
||||
return
|
||||
session = self.SessionLocal()
|
||||
ormCls = self._get_orm(tableName)
|
||||
if ormCls is None:
|
||||
return
|
||||
if len(filter) > 0:
|
||||
records = session.query(ormCls).filter_by(**filter).all()
|
||||
else:
|
||||
records = session.query(ormCls).all()
|
||||
for record in records:
|
||||
if record is not None:
|
||||
record.update(data)
|
||||
session.commit()
|
||||
|
||||
def query(self,tableName:str,**filter):
|
||||
session = self.SessionLocal()
|
||||
ormCls = self._get_orm(tableName)
|
||||
if ormCls is None:
|
||||
return
|
||||
modelCls = self._get_model(ormCls)
|
||||
if modelCls is None:
|
||||
return
|
||||
|
||||
if filter is not None:
|
||||
records = session.query(ormCls).filter_by(**filter).all()
|
||||
else:
|
||||
records = session.query(ormCls).all()
|
||||
|
||||
datas = []
|
||||
for record in records:
|
||||
datas.append(modelCls.from_orm(record))
|
||||
return datas
|
||||
|
||||
def exist(self,tableName:str)->bool:
|
||||
if self._engine is None:
|
||||
return
|
||||
inspector = Inspector.from_engine(self._engine)
|
||||
return inspector.has_table(tableName)
|
||||
|
||||
def _get_orm(self,tableName:str):
|
||||
subClss = Base.__subclasses__()
|
||||
for sunCls in subClss:
|
||||
if sunCls.__tablename__ == tableName:
|
||||
return sunCls
|
||||
return None
|
||||
|
||||
def _get_model(self,orm:Any):
|
||||
subClss = BaseModel.__subclasses__()
|
||||
for sunCls in subClss:
|
||||
if 'orm' in sunCls.__dict__ and sunCls.orm() == orm:
|
||||
return sunCls
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class ChatRequestData(BaseModel):
|
||||
inputs: Dict[str,Any]
|
||||
query: str
|
||||
user: str
|
||||
response_mode: str
|
||||
files: Any
|
||||
conversation_id: str = None
|
||||
|
||||
class ChatFileUploadRequest(BaseModel):
|
||||
base64: str
|
||||
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from uuid import uuid4
|
||||
|
||||
from app.engine.index import get_index
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.readers.file.base import (
|
||||
_try_loading_included_file_formats as get_file_loaders_map,
|
||||
)
|
||||
from llama_index.core.readers.file.base import (
|
||||
default_file_metadata_func,
|
||||
)
|
||||
from llama_index.core.schema import Document
|
||||
from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex
|
||||
from llama_index.readers.file import FlatReader
|
||||
|
||||
|
||||
def get_llamaparse_parser():
|
||||
from app.engine.loaders import load_configs
|
||||
from app.engine.loaders.file import FileLoaderConfig, llama_parse_parser
|
||||
|
||||
config = load_configs()
|
||||
file_loader_config = FileLoaderConfig(**config["file"])
|
||||
if file_loader_config.use_llama_parse:
|
||||
return llama_parse_parser()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def default_file_loaders_map():
|
||||
default_loaders = get_file_loaders_map()
|
||||
default_loaders[".txt"] = FlatReader
|
||||
return default_loaders
|
||||
|
||||
|
||||
class PrivateFileService:
|
||||
PRIVATE_STORE_PATH = "output/uploaded"
|
||||
|
||||
@staticmethod
|
||||
def preprocess_base64_file(base64_content: str) -> tuple:
|
||||
header, data = base64_content.split(",", 1)
|
||||
mime_type = header.split(";")[0].split(":", 1)[1]
|
||||
extension = mimetypes.guess_extension(mime_type)
|
||||
# File data as bytes
|
||||
return base64.b64decode(data), extension
|
||||
|
||||
@staticmethod
|
||||
def store_and_parse_file(file_data, extension) -> List[Document]:
|
||||
# Store file to the private directory
|
||||
os.makedirs(PrivateFileService.PRIVATE_STORE_PATH, exist_ok=True)
|
||||
|
||||
# random file name
|
||||
file_name = f"{uuid4().hex}{extension}"
|
||||
file_path = Path(os.path.join(PrivateFileService.PRIVATE_STORE_PATH, file_name))
|
||||
|
||||
# write file
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_data)
|
||||
|
||||
# Load file to documents
|
||||
# If LlamaParse is enabled, use it to parse the file
|
||||
# Otherwise, use the default file loaders
|
||||
reader = get_llamaparse_parser()
|
||||
if reader is None:
|
||||
reader_cls = default_file_loaders_map().get(extension)
|
||||
if reader_cls is None:
|
||||
raise ValueError(f"File extension {extension} is not supported")
|
||||
reader = reader_cls()
|
||||
documents = reader.load_data(file_path)
|
||||
# Add custom metadata
|
||||
for doc in documents:
|
||||
doc.metadata["file_name"] = file_name
|
||||
doc.metadata["private"] = "true"
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def process_file(base64_content: str) -> List[str]:
|
||||
file_data, extension = PrivateFileService.preprocess_base64_file(base64_content)
|
||||
documents = PrivateFileService.store_and_parse_file(file_data, extension)
|
||||
|
||||
# Only process nodes, no store the index
|
||||
pipeline = IngestionPipeline()
|
||||
nodes = pipeline.run(documents=documents)
|
||||
|
||||
# Add the nodes to the index and persist it
|
||||
current_index = get_index()
|
||||
|
||||
# Insert the documents into the index
|
||||
if isinstance(current_index, LlamaCloudIndex):
|
||||
# LlamaCloudIndex is a managed index so we don't need to process the nodes
|
||||
# just insert the documents
|
||||
for doc in documents:
|
||||
current_index.insert(doc)
|
||||
else:
|
||||
# Only process nodes, no store the index
|
||||
pipeline = IngestionPipeline()
|
||||
nodes = pipeline.run(documents=documents)
|
||||
|
||||
# Add the nodes to the index and persist it
|
||||
if current_index is None:
|
||||
current_index = VectorStoreIndex(nodes=nodes)
|
||||
else:
|
||||
current_index.insert_nodes(nodes=nodes)
|
||||
current_index.storage_context.persist(
|
||||
persist_dir=os.environ.get("STORAGE_DIR", "storage")
|
||||
)
|
||||
|
||||
# Return the document ids
|
||||
return [doc.doc_id for doc in documents]
|
||||
@@ -1,114 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from app.api.routers.models import LlamaCloudFile
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class LLamaCloudFileService:
|
||||
LLAMA_CLOUD_URL = "https://cloud.llamaindex.ai/api/v1"
|
||||
LOCAL_STORE_PATH = "output/llamacloud"
|
||||
|
||||
DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}"
|
||||
|
||||
@classmethod
|
||||
def get_all_projects(cls) -> List[Dict[str, Any]]:
|
||||
url = f"{cls.LLAMA_CLOUD_URL}/projects"
|
||||
return cls._make_request(url)
|
||||
|
||||
@classmethod
|
||||
def get_all_pipelines(cls) -> List[Dict[str, Any]]:
|
||||
url = f"{cls.LLAMA_CLOUD_URL}/pipelines"
|
||||
return cls._make_request(url)
|
||||
|
||||
@classmethod
|
||||
def get_all_projects_with_pipelines(cls) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
projects = cls.get_all_projects()
|
||||
pipelines = cls.get_all_pipelines()
|
||||
return [
|
||||
{
|
||||
**project,
|
||||
"pipelines": [p for p in pipelines if p["project_id"] == project["id"]],
|
||||
}
|
||||
for project in projects
|
||||
]
|
||||
except Exception as error:
|
||||
logger.error(f"Error listing projects and pipelines: {error}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _get_files(cls, pipeline_id: str) -> List[Dict[str, Any]]:
|
||||
url = f"{cls.LLAMA_CLOUD_URL}/pipelines/{pipeline_id}/files"
|
||||
return cls._make_request(url)
|
||||
|
||||
@classmethod
|
||||
def _get_file_detail(cls, project_id: str, file_id: str) -> Dict[str, Any]:
|
||||
url = f"{cls.LLAMA_CLOUD_URL}/files/{file_id}/content?project_id={project_id}"
|
||||
return cls._make_request(url)
|
||||
|
||||
@classmethod
|
||||
def _download_file(cls, url: str, local_file_path: str):
|
||||
logger.info(f"Downloading file to {local_file_path}")
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(cls.LOCAL_STORE_PATH, exist_ok=True)
|
||||
# Download the file
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(local_file_path, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
logger.info("File downloaded successfully")
|
||||
|
||||
@classmethod
|
||||
def download_llamacloud_pipeline_file(
|
||||
cls,
|
||||
file: LlamaCloudFile,
|
||||
force_download: bool = False,
|
||||
):
|
||||
file_name = file.file_name
|
||||
pipeline_id = file.pipeline_id
|
||||
|
||||
# Check is the file already exists
|
||||
downloaded_file_path = cls.get_file_path(file_name, pipeline_id)
|
||||
if os.path.exists(downloaded_file_path) and not force_download:
|
||||
logger.debug(f"File {file_name} already exists in local storage")
|
||||
return
|
||||
try:
|
||||
logger.info(f"Downloading file {file_name} for pipeline {pipeline_id}")
|
||||
files = cls._get_files(pipeline_id)
|
||||
if not files or not isinstance(files, list):
|
||||
raise Exception("No files found in LlamaCloud")
|
||||
for file_entry in files:
|
||||
if file_entry["name"] == file_name:
|
||||
file_id = file_entry["file_id"]
|
||||
project_id = file_entry["project_id"]
|
||||
file_detail = cls._get_file_detail(project_id, file_id)
|
||||
cls._download_file(file_detail["url"], downloaded_file_path)
|
||||
break
|
||||
except Exception as error:
|
||||
logger.info(f"Error fetching file from LlamaCloud: {error}")
|
||||
|
||||
@classmethod
|
||||
def get_file_name(cls, name: str, pipeline_id: str) -> str:
|
||||
return cls.DOWNLOAD_FILE_NAME_TPL.format(pipeline_id=pipeline_id, filename=name)
|
||||
|
||||
@classmethod
|
||||
def get_file_path(cls, name: str, pipeline_id: str) -> str:
|
||||
return os.path.join(cls.LOCAL_STORE_PATH, cls.get_file_name(name, pipeline_id))
|
||||
|
||||
@staticmethod
|
||||
def _make_request(
|
||||
url: str, data=None, headers: Optional[Dict] = None, method: str = "get"
|
||||
):
|
||||
if headers is None:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Authorization": f'Bearer {os.getenv("LLAMA_CLOUD_API_KEY")}',
|
||||
}
|
||||
response = requests.request(method, url, headers=headers, data=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from app.api.routers.models import Message
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.settings import Settings
|
||||
from pydantic import BaseModel
|
||||
|
||||
NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
|
||||
"你是一个乐于助人的助手!你的任务是对用户可能会问的下一个问题给出建议。 "
|
||||
"\n这是对话历史记录"
|
||||
"\n---------------------\n{conversation}\n---------------------"
|
||||
"考虑到对话历史记录,仅限于现在知识库已有内容, 请给我 $number_of_questions 个你接下来可能会问题的问题!"
|
||||
)
|
||||
N_QUESTION_TO_GENERATE = 3
|
||||
|
||||
|
||||
class NextQuestions(BaseModel):
|
||||
"""A list of questions that user might ask next"""
|
||||
|
||||
questions: List[str]
|
||||
|
||||
|
||||
class NextQuestionSuggestion:
|
||||
@staticmethod
|
||||
async def suggest_next_questions(
|
||||
messages: List[Message],
|
||||
number_of_questions: int = N_QUESTION_TO_GENERATE,
|
||||
) -> List[str]:
|
||||
# Reduce the cost by only using the last two messages
|
||||
last_user_message = None
|
||||
last_assistant_message = None
|
||||
for message in reversed(messages):
|
||||
if message.role == "user":
|
||||
last_user_message = f"User: {message.content}"
|
||||
elif message.role == "assistant":
|
||||
last_assistant_message = f"Assistant: {message.content}"
|
||||
if last_user_message and last_assistant_message:
|
||||
break
|
||||
conversation: str = f"{last_user_message}\n{last_assistant_message}"
|
||||
|
||||
output: NextQuestions = await Settings.llm.astructured_predict(
|
||||
NextQuestions,
|
||||
prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
|
||||
conversation=conversation,
|
||||
nun_questions=number_of_questions,
|
||||
)
|
||||
|
||||
return output.questions
|
||||
@@ -1,22 +0,0 @@
|
||||
import logging
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from app.engine.vectordb import get_vector_store
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
index = None
|
||||
|
||||
def get_index(params=None):
|
||||
global index
|
||||
if index is None:
|
||||
logger.info("Connecting vector store...")
|
||||
|
||||
store = get_vector_store()
|
||||
# Load the index from the vector store
|
||||
# If you are using a vector store that doesn't store text,
|
||||
# you must load the index from both the vector store and the document store
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished load index from vector store.")
|
||||
|
||||
return index
|
||||
@@ -1,61 +0,0 @@
|
||||
import os
|
||||
|
||||
from llama_index.core.agent import AgentRunner, ReActChatFormatter
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.tools.query_engine import QueryEngineTool
|
||||
|
||||
from app.engine.engine import create_query_engine, create_summary_query_engine
|
||||
from app.engine.index import get_index
|
||||
#from app.engine.loaders.db import makeDescriptionByEngine
|
||||
from app.engine.tools import ToolFactory
|
||||
|
||||
|
||||
def get_chat_engine(filters=None, params=None):
|
||||
system_prompt = os.getenv("SYSTEM_PROMPT")
|
||||
top_k = int(os.getenv("TOP_K", "3"))
|
||||
use_reranker = os.getenv("RERANK_ENABLED")
|
||||
tools = []
|
||||
|
||||
# 创建SQL查询工具
|
||||
# sql_query_engine = create_summary_query_engine(index)
|
||||
# sql_query_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine,
|
||||
# name="zjdata_query_tool",
|
||||
# description="来源于一个由博微公司电力造价软件编制的造价工程文件。该文件以多张表格的形式存储存储了整个工程的全部数据内容。适用于以详细的自然语言查询表格数据方式查询造价工程各项具体属性、费用的数值。请先使用“zj_query_tool”无法解决才使用本工具"
|
||||
# )
|
||||
#tools.append(sql_query_tool)
|
||||
|
||||
# Add query tool if index exists
|
||||
index = get_index()
|
||||
if index is not None:
|
||||
summary_query_engine = create_summary_query_engine(index,top_k,use_reranker,filters)
|
||||
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
|
||||
description="适用于任何需要进行全面总结、概括的要求。",
|
||||
)
|
||||
query_engine = create_query_engine(index,top_k,use_reranker,filters)
|
||||
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
|
||||
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
|
||||
)
|
||||
|
||||
tools.append(summary_query_tool)
|
||||
tools.append(query_engine_tool)
|
||||
|
||||
# Add additional tools
|
||||
tools += ToolFactory.from_env()
|
||||
|
||||
prefix_messages = ("""您的设计旨在帮助完成各种任务,从回答问题到提供其他类型分析的摘要。\n\n##工具\n\n你可以访问各种工具。你有责任按照你认为合适的顺序使用这些工具来完成当前的任务。\n这可能需要将任务分解为子任务,并使用不同的工具来完成每个子任务。\n\n你可以访问以下工具:\n{tool_desc}\n\n\n##输出格式\n\n请用与问题相同的语言回答,并使用以下格式:\n\n \nThought: 用户当前的语言是:(user's language)。我需要使用工具来帮助我回答问题。\nAction: 如果使用工具,则为工具名称(one of {tool_names})。\nAction Input: 输入给工具的内容,使用JSON格式表示kwargs(例如{{\"input\": \"hello world\", \"num_beams\": 5}})\n \n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n请始终以Thought开始。\n\n切勿用Markdown代码标记包围你的响应。如果需要,可以在响应中使用代码标记。\n\n请为Action Input使用有效的JSON格式。不要这样做{{\'input\': \'hello world\', \'num_beams\': 5}}。\n\n如果使用此格式,用户将以下面的格式进行回应:\n\n \nObservation: 工具响应\n \n\n你应该继续重复上述格式,直到你有足够的信息来回答问题而无需使用更多工具。此时,你必须使用以下两种格式之一进行回答:\n\n \nThought: 我可以不用任何工具来回答。我将使用用户的语言来回答。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n \nThought: 我无法使用提供的工具回答问题。\nAnswer: [你的答案(与用户问题相同的语言)]\n \n\n##如果从工具中得到的回应是Empty Response,那么只需要回答“我不知道”,不需要额外回答别的内容。## 当前对话\n\n以下是当前对话,由人类和助手的消息交替组成。\n""")
|
||||
react_chat_formatter = ReActChatFormatter.from_defaults(prefix_messages)
|
||||
agentrunner = AgentRunner.from_llm(
|
||||
llm=Settings.llm,
|
||||
tools=tools,
|
||||
react_chat_formatter=react_chat_formatter,
|
||||
system_prompt=system_prompt,
|
||||
verbose=True,
|
||||
)
|
||||
return agentrunner
|
||||
# create the function calling worker for reasoning
|
||||
# worker = FunctionCallingAgentWorker.from_tools(
|
||||
# tools, verbose=True
|
||||
# )
|
||||
#
|
||||
# # wrap the worker in the top-level planner
|
||||
# return StructuredPlannerAgent(worker, tools)
|
||||
@@ -1 +0,0 @@
|
||||
STORAGE_DIR = "storage" # directory to cache the generated index
|
||||
@@ -1,108 +0,0 @@
|
||||
import os
|
||||
|
||||
from llama_index.core import SummaryIndex, SQLDatabase, VectorStoreIndex
|
||||
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
||||
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import ResponseMode
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||
from app.engine.retriever.HybridRetriever import HybridRetriever
|
||||
from app.settings import get_node_postprocessors
|
||||
|
||||
def makeDescriptionByEngine(sql_database:SQLDatabase):
|
||||
reader = DatabaseReader(sql_database)
|
||||
|
||||
table_names = sql_database.get_usable_table_names()
|
||||
table_schema_objs = []
|
||||
for table_name in table_names:
|
||||
columns = sql_database.get_table_columns(table_name)
|
||||
if len(columns) > 150:
|
||||
continue
|
||||
stats_txt = ""
|
||||
|
||||
if table_name == 'gongchengshuxing':
|
||||
stats_txt = '该表中有以下属性:'
|
||||
documents = reader.load_data(query='select name from gongchengshuxing')
|
||||
for index in range(len(documents) if len(documents) < 30 else 30):
|
||||
if index == 0:
|
||||
continue
|
||||
elif index > 1:
|
||||
stats_txt += ','
|
||||
stats_txt += documents[index].text.split(':')[1]
|
||||
|
||||
tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt))
|
||||
table_schema_objs.append(tbSchema)
|
||||
|
||||
return table_schema_objs
|
||||
|
||||
def get_Retriever(index,**kwargs):
|
||||
strEnableHybrid = os.getenv("HYBRID_ENABLED",'False')
|
||||
bEnableHybrid = True if strEnableHybrid is not None and strEnableHybrid.title() == 'True' else False
|
||||
if bEnableHybrid:
|
||||
alpha = float(os.getenv("HYBRID_ALPHA", "0.5"))
|
||||
retriever = HybridRetriever(index,alpha = alpha,**kwargs)
|
||||
else:
|
||||
retriever = index.as_retriever(**kwargs)
|
||||
return retriever
|
||||
|
||||
|
||||
sql_database = None
|
||||
sql_obj_index = None
|
||||
|
||||
# Create a summary query engine
|
||||
def create_summary_query_engine(top_k=3, use_reranker=False, filters=None):
|
||||
global sql_obj_index
|
||||
global sql_database
|
||||
if sql_obj_index is None or sql_database is None:
|
||||
sqlengine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
|
||||
sql_database = SQLDatabase(sqlengine)
|
||||
table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||
table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||
|
||||
sql_obj_index = ObjectIndex.from_objects(
|
||||
table_schema_objs,
|
||||
table_node_mapping,
|
||||
index_cls=VectorStoreIndex,
|
||||
)
|
||||
|
||||
# 创建SQL查询工具
|
||||
sql_query_engine = SQLTableRetrieverQueryEngine(sql_database,
|
||||
sql_obj_index.as_retriever(similarity_top_k=top_k),
|
||||
verbose=True,
|
||||
)
|
||||
return sql_query_engine
|
||||
|
||||
# Create a summary query engine
|
||||
def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None):
|
||||
summary_index = SummaryIndex(index.vector_store.get_nodes(node_ids=None))
|
||||
summary_query_engine = summary_index.as_query_engine(
|
||||
response_mode=ResponseMode.TREE_SUMMARIZE,
|
||||
use_async=True,
|
||||
streaming=True,
|
||||
)
|
||||
return summary_query_engine
|
||||
|
||||
# Create a query engine
|
||||
def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
|
||||
# 创建向量检索查询工具
|
||||
postprocess = None
|
||||
if use_reranker:
|
||||
postprocess = get_node_postprocessors()
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
get_Retriever(index,
|
||||
similarity_top_k=top_k,
|
||||
filters=filters),
|
||||
text_qa_template=text_qa_template,
|
||||
refine_template=refine_template,
|
||||
summary_template = summary_template,
|
||||
simple_template = simple_template,
|
||||
node_postprocessors=postprocess,
|
||||
use_async=True,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
return query_engine
|
||||
@@ -1,94 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from app.engine.loaders import get_documents
|
||||
from app.engine.vectordb import get_vector_store
|
||||
from app.settings import init_settings
|
||||
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
|
||||
|
||||
|
||||
def get_doc_store():
|
||||
|
||||
# If the storage directory is there, load the document store from it.
|
||||
# If not, set up an in-memory document store since we can't load from a directory that doesn't exist.
|
||||
if os.path.exists(STORAGE_DIR):
|
||||
return SimpleDocumentStore.from_persist_dir(STORAGE_DIR)
|
||||
else:
|
||||
return SimpleDocumentStore()
|
||||
|
||||
|
||||
def run_pipeline(docstore, vector_store, documents):
|
||||
pipeline = IngestionPipeline(
|
||||
transformations=[
|
||||
SentenceSplitter(
|
||||
chunk_size=Settings.chunk_size,
|
||||
chunk_overlap=Settings.chunk_overlap,
|
||||
),
|
||||
Settings.embed_model,
|
||||
],
|
||||
docstore=docstore,
|
||||
docstore_strategy="upserts_and_delete",
|
||||
vector_store=vector_store,
|
||||
)
|
||||
|
||||
# Run the ingestion pipeline and store the results
|
||||
nodes = pipeline.run(show_progress=True, documents=documents)
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
def persist_storage(docstore, vector_store):
|
||||
storage_context = StorageContext.from_defaults(
|
||||
docstore=docstore,
|
||||
vector_store=vector_store,
|
||||
)
|
||||
storage_context.persist(STORAGE_DIR)
|
||||
|
||||
|
||||
def persist_BMRetriever(vector_store):
|
||||
STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
|
||||
top_k = int(os.getenv("TOP_K", "3"))
|
||||
bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=top_k,nodes=vector_store.get_nodes([]))
|
||||
bmRetriver.persist(STORAGE_DIR)
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Generate index for the provided data")
|
||||
|
||||
# Get the stores and documents or create new ones
|
||||
documents = get_documents()
|
||||
# Set private=false to mark the document as public (required for filtering)
|
||||
for doc in documents:
|
||||
doc.metadata["private"] = "false"
|
||||
docstore = get_doc_store()
|
||||
vector_store = get_vector_store()
|
||||
|
||||
# Run the ingestion pipeline
|
||||
_ = run_pipeline(docstore, vector_store, documents)
|
||||
|
||||
# Build the index and persist storage
|
||||
persist_storage(docstore, vector_store)
|
||||
persist_BMRetriever(vector_store)
|
||||
|
||||
logger.info("Finished generating the index")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from phoenix.trace import using_project
|
||||
with using_project(os.getenv("PHOENIX_PROJECT_NAME") + "_generate") as obj:
|
||||
generate_datasource()
|
||||
@@ -1,93 +0,0 @@
|
||||
from llama_index.core import PromptTemplate
|
||||
|
||||
text_qa_template_str = (
|
||||
"# 角色\n"
|
||||
"你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
|
||||
"你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
|
||||
"如同直接从文件中提取的内容。\n"
|
||||
"知识库中已经导入一个工程的全部数据,请你站在当前工程的角度回答用户关于工程文件的问题。\n"
|
||||
"例如:询问“此工程”指当前导入的工程。询问“此工程名称”指当前导入的工程的工程名称。\n"
|
||||
|
||||
"## 技能\n"
|
||||
"### 技能 1: 数据查询与提供\n"
|
||||
"- 准确回答所有关于电力工程造价的相关问题。\n"
|
||||
"- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n"
|
||||
"- 确保提供的信息严格基于工程文档中的记录。\n"
|
||||
|
||||
"### 技能 2: 技术性解释\n"
|
||||
"- 解释造价工程中的技术术语和概念。\n"
|
||||
"- 为复杂的工程细节提供清晰易懂的说明。\n"
|
||||
|
||||
"## 约束\n"
|
||||
"- 仅回答与电力工程造价文件相关的具体问题。\n"
|
||||
"- 不进行任何超出文件内容的猜测或假设。\n"
|
||||
"- 所有回答均基于文件内容,采用客观和技术性的语言。\n"
|
||||
"- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n"
|
||||
"以下为上下文信息\n"
|
||||
"---------------------\n"
|
||||
"{context_str}\n"
|
||||
"---------------------\n"
|
||||
"请根据上下文信息而非先前知识回答我的问题或回复我的指令。前面的上下文信息可能有用,也可能没用,你需要从我给出的上下文信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。我的问题或指令是什么语种,你就用什么语种回复。\n"
|
||||
"如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
||||
|
||||
"问题:{query_str}\n"
|
||||
"你的回复: "
|
||||
)
|
||||
|
||||
|
||||
text_qa_template = PromptTemplate(text_qa_template_str)
|
||||
|
||||
refine_template_str = (
|
||||
"这是原本的问题: {query_str}\n"
|
||||
"我们已经提供了回答: {existing_answer}\n"
|
||||
"现在我们有机会改进这个回答 "
|
||||
"使用以下更多上下文(仅当有助于改进回答时使用)\n"
|
||||
"如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n"
|
||||
"如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n"
|
||||
"------------\n"
|
||||
"{context_msg}\n"
|
||||
"------------\n"
|
||||
"如果回答中已经包含有正确答案,不要返回多余的解释等信息,只返回正确答案\n"
|
||||
"如果是表结构或者是数据库的相关内容,仅用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
||||
"改进的回答: "
|
||||
)
|
||||
|
||||
refine_template = PromptTemplate(refine_template_str)
|
||||
|
||||
summary_template_str = (
|
||||
"# 角色\n"
|
||||
"你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
|
||||
"你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
|
||||
"如同直接从文件中提取的内容。\n"
|
||||
|
||||
"## 技能\n"
|
||||
"### 技能 1: 数据查询与提供\n"
|
||||
"- 准确回答所有关于电力工程造价的相关问题。\n"
|
||||
"- 提供具体数据,如成本估算、材料清单、劳动力需求等。\n"
|
||||
"- 确保提供的信息严格基于工程文档中的记录。\n"
|
||||
|
||||
"### 技能 2: 技术性解释\n"
|
||||
"- 解释造价工程中的技术术语和概念。\n"
|
||||
"- 为复杂的工程细节提供清晰易懂的说明。\n"
|
||||
|
||||
"## 约束\n"
|
||||
"- 仅回答与电力工程造价文件相关的具体问题。\n"
|
||||
"- 不进行任何超出文件内容的猜测或假设。\n"
|
||||
"- 所有回答均基于文件内容,采用客观和技术性的语言。\n"
|
||||
"- 请基于这些信息回答问题。如果无法找到相关信息,请不要额外发散回答,不要回答多余的信息,只需要回答“我不知道这个问题的答案”。\n"
|
||||
"来自多个来源的上下文信息如下。\n"
|
||||
"---------------------\n"
|
||||
"{context_str}\n"
|
||||
"---------------------\n"
|
||||
"鉴于来自多个来源的信息而非先验知识, "
|
||||
"回答查询。\n"
|
||||
"如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
||||
"Query: {query_str}\n"
|
||||
"Answer: "
|
||||
)
|
||||
summary_template = PromptTemplate(summary_template_str)
|
||||
|
||||
simple_template_str = (
|
||||
"{query_str}"
|
||||
)
|
||||
simple_template = PromptTemplate(simple_template_str)
|
||||
@@ -1,71 +0,0 @@
|
||||
import os
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
from qdrant_client import qdrant_client
|
||||
|
||||
qclient = None
|
||||
|
||||
def get_qdrant_vector_store():
|
||||
collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
|
||||
vector_store_path = os.getenv("VECTOR_STORE_PATH")
|
||||
host=os.getenv("VECTOR_STORE_HOST", "127.0.0.1"),
|
||||
port=int(os.getenv("VECTOR_STORE_PORT", "6333")),
|
||||
|
||||
if not vector_store_path or not host:
|
||||
raise ValueError(
|
||||
"Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
|
||||
)
|
||||
# if VECTOR_STORE_PATH is set, use a local QdrantVectorStore from the path
|
||||
# otherwise, use a remote QdrantVectorStore
|
||||
global qclient
|
||||
if qclient == None:
|
||||
if vector_store_path:
|
||||
qclient = qdrant_client.QdrantClient(
|
||||
path=vector_store_path,
|
||||
)
|
||||
else:
|
||||
qclient = qdrant_client.QdrantClient(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
|
||||
vector_store = QdrantVectorStore(client=qclient, collection_name=collection_name)
|
||||
return vector_store
|
||||
|
||||
def get_chroma_vector_store():
|
||||
collection_name = os.getenv("VECTOR_STORE_COLLECTION", "default")
|
||||
vector_store_path = os.getenv("VECTOR_STORE_PATH")
|
||||
# if VECTOR_STORE_PATH is set, use a local ChromaVectorStore from the path
|
||||
# otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
|
||||
if vector_store_path:
|
||||
store = ChromaVectorStore.from_params(
|
||||
persist_dir=vector_store_path, collection_name=collection_name,
|
||||
collection_kwargs={"metadata":{"hnsw:space":"cosine"}},
|
||||
)
|
||||
else:
|
||||
if not os.getenv("VECTOR_STORE_HOST") or not os.getenv("VECTOR_STORE_PORT"):
|
||||
raise ValueError(
|
||||
"Please provide either VECTOR_STORE_PATH or VECTOR_STORE_HOST and VECTOR_STORE_PORT"
|
||||
)
|
||||
store = ChromaVectorStore.from_params(
|
||||
host=os.getenv("VECTOR_STORE_HOST"),
|
||||
port=int(os.getenv("VECTOR_STORE_PORT")),
|
||||
collection_name=collection_name,
|
||||
collection_kwargs={"metadata":{"hnsw:space":"cosine"}},
|
||||
)
|
||||
return store
|
||||
|
||||
def get_vector_store():
|
||||
store_type=os.getenv("VECTOR_STORE_TYPE")
|
||||
|
||||
store = None
|
||||
|
||||
match store_type:
|
||||
case "chroma":
|
||||
store = get_chroma_vector_store()
|
||||
case "qdrant":
|
||||
store = get_qdrant_vector_store()
|
||||
case _:
|
||||
raise ValueError(f"Invalid vector store type: {store_type}")
|
||||
|
||||
return store
|
||||
@@ -31,13 +31,19 @@ def get_chat_engine(filters=None, params=None):
|
||||
summary_query_tool = QueryEngineTool.from_defaults( query_engine=summary_query_engine, name="summary_query_tool",
|
||||
description="适用于任何需要进行全面总结、概括的要求。",
|
||||
)
|
||||
query_engine = create_query_engine(index,top_k,use_reranker,filters)
|
||||
query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "COMPACT")
|
||||
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool",
|
||||
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后建议使用“zjdata_query_tool”工具。",
|
||||
)
|
||||
|
||||
query_engine = create_query_engine(index,top_k,use_reranker,filters,response_mode = "TREE_SUMMARIZE")
|
||||
query_engine_tool_1 = QueryEngineTool.from_defaults(query_engine=query_engine, name="zj_query_tool_1",
|
||||
description="由博微公司编制的关于电力造价知识、电力造价编制软件知识和造价工程文件结构的知识库。适用于查询电力领域、电力造价领域、博微、博微电力、博微造价等业务等内容。如果本知识库没有直接答案但有解决思路的可以返回解决办法后,且在询问工程中单位的具体数值,例如用量,费率,合计,金额等的时候建议使用“zj_query_tool_1”工具。",
|
||||
)
|
||||
|
||||
tools.append(summary_query_tool)
|
||||
tools.append(query_engine_tool)
|
||||
tools.append(query_engine_tool_1)
|
||||
|
||||
# Add additional tools
|
||||
tools += ToolFactory.from_env()
|
||||
|
||||
@@ -86,7 +86,7 @@ def create_summary_query_engine(index, top_k=3, use_reranker=False, filters=None
|
||||
return summary_query_engine
|
||||
|
||||
# Create a query engine
|
||||
def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
|
||||
def create_query_engine(index, top_k=3, use_reranker=False, filters=None, response_mode=None):
|
||||
# 创建向量检索查询工具
|
||||
postprocess = None
|
||||
if use_reranker:
|
||||
@@ -103,6 +103,7 @@ def create_query_engine(index, top_k=3, use_reranker=False, filters=None):
|
||||
node_postprocessors=postprocess,
|
||||
use_async=True,
|
||||
streaming=True,
|
||||
ResponseMode = response_mode
|
||||
)
|
||||
|
||||
return query_engine
|
||||
@@ -1,40 +0,0 @@
|
||||
import logging
|
||||
|
||||
import yaml
|
||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_configs():
|
||||
with open("config/loaders.yaml") as f:
|
||||
configs = yaml.safe_load(f)
|
||||
return configs
|
||||
|
||||
|
||||
def get_documents():
|
||||
documents = []
|
||||
config = load_configs()
|
||||
if config is None or len(config.items()) == 0:
|
||||
return documents
|
||||
|
||||
for loader_type, loader_config in config.items():
|
||||
logger.info(
|
||||
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
||||
)
|
||||
|
||||
loader_config = loader_config or []
|
||||
match loader_type:
|
||||
case "file":
|
||||
document = get_file_documents(FileLoaderConfig(**loader_config))
|
||||
case "web":
|
||||
document = get_web_documents(WebLoaderConfig(**loader_config))
|
||||
case "db":
|
||||
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
|
||||
case _:
|
||||
raise ValueError(f"Invalid loader type: {loader_type}")
|
||||
documents.extend(document)
|
||||
|
||||
return documents
|
||||
@@ -1,140 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_index.core import SQLDatabase, Document
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CustomDatabaseReader(DatabaseReader):
|
||||
"""Simple Database reader.
|
||||
|
||||
Concatenates each row into Document used by LlamaIndex.
|
||||
|
||||
Args:
|
||||
sql_database (Optional[SQLDatabase]): SQL database to use,
|
||||
including table names to specify.
|
||||
See :ref:`Ref-Struct-Store` for more details.
|
||||
|
||||
OR
|
||||
|
||||
engine (Optional[Engine]): SQLAlchemy Engine object of the database connection.
|
||||
|
||||
OR
|
||||
|
||||
uri (Optional[str]): uri of the database connection.
|
||||
|
||||
OR
|
||||
|
||||
scheme (Optional[str]): scheme of the database connection.
|
||||
host (Optional[str]): host of the database connection.
|
||||
port (Optional[int]): port of the database connection.
|
||||
user (Optional[str]): user of the database connection.
|
||||
password (Optional[str]): password of the database connection.
|
||||
dbname (Optional[str]): dbname of the database connection.
|
||||
|
||||
Returns:
|
||||
DatabaseReader: A DatabaseReader object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sql_database: Optional[SQLDatabase] = None,
|
||||
engine: Optional[Engine] = None,
|
||||
uri: Optional[str] = None,
|
||||
scheme: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
dbname: Optional[str] = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize with parameters."""
|
||||
if sql_database:
|
||||
self.sql_database = sql_database
|
||||
elif engine:
|
||||
self.sql_database = SQLDatabase(engine, *args, **kwargs)
|
||||
elif uri:
|
||||
self.uri = uri
|
||||
self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
|
||||
elif scheme and host and port and user and password and dbname:
|
||||
uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
|
||||
self.uri = uri
|
||||
self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You must provide either a SQLDatabase, "
|
||||
"a SQL Alchemy Engine, a valid connection URI, or a valid "
|
||||
"set of credentials."
|
||||
)
|
||||
|
||||
def load_data(self, query: str, explanation: str) -> List[Document]:
|
||||
"""Query and load data from the Database, returning a list of Documents.
|
||||
|
||||
Args:
|
||||
query (str): Query parameter to filter tables and rows.
|
||||
explanation (str): Explanation for the query to be included in the document.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of Document objects.
|
||||
"""
|
||||
dco_str = explanation + "\n"
|
||||
|
||||
with self.sql_database.engine.connect() as connection:
|
||||
if query is None:
|
||||
raise ValueError("A query parameter is necessary to filter the data")
|
||||
else:
|
||||
result = connection.execute(text(query))
|
||||
|
||||
dco_str += ", ".join(
|
||||
[f"{entry}" for entry in result.keys()]
|
||||
) + "\n"
|
||||
|
||||
for item in result.fetchall():
|
||||
# Fetch each item
|
||||
record_str = ", ".join(
|
||||
[f"{entry}" for col, entry in zip(result.keys(), item)]
|
||||
)
|
||||
dco_str += record_str + "\n"
|
||||
|
||||
doc = Document(text=dco_str)
|
||||
doc.metadata["name"] = query
|
||||
doc.metadata["context"] = query
|
||||
doc.metadata["file_type"] = "application/vnd.ms-excel"
|
||||
return [doc]
|
||||
|
||||
class DBLoaderConfig(BaseModel):
|
||||
uri: str
|
||||
queries: List[dict]
|
||||
|
||||
def get_db_documents(configs: list[DBLoaderConfig]):
|
||||
docs = []
|
||||
|
||||
if len(configs) == 0 or configs[0].uri == "":
|
||||
logger.warning(
|
||||
f"Failed to load database, error message: uri is empty. Return as empty document list."
|
||||
)
|
||||
return docs
|
||||
|
||||
metadata = {
|
||||
'file_type': 'application/booway.document.zj',
|
||||
}
|
||||
|
||||
for entry in configs:
|
||||
engine = create_engine(entry.uri)
|
||||
sql_database = SQLDatabase(engine)
|
||||
|
||||
loader = CustomDatabaseReader(sql_database)
|
||||
for query_dict in entry.queries:
|
||||
query = query_dict.get("sql", "")
|
||||
explanation = query_dict.get("explanation", "")
|
||||
logger.info(f"Loading data from database with query: {query}")
|
||||
documents = loader.load_data(query=query, explanation=explanation)
|
||||
|
||||
docs.extend(documents)
|
||||
return docs
|
||||
@@ -1,88 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.readers.json import JSONReader
|
||||
from llama_parse import LlamaParse
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileLoaderConfig(BaseModel):
|
||||
data_dir: str = "data"
|
||||
use_llama_parse: bool = False
|
||||
|
||||
@validator("data_dir")
|
||||
def data_dir_must_exist(cls, v):
|
||||
if not os.path.isdir(v):
|
||||
raise ValueError(f"Directory '{v}' does not exist")
|
||||
return v
|
||||
|
||||
|
||||
def llama_parse_parser():
|
||||
if os.getenv("LLAMA_CLOUD_API_KEY") is None:
|
||||
raise ValueError(
|
||||
"LLAMA_CLOUD_API_KEY environment variable is not set. "
|
||||
"Please set it in .env file or in your shell environment then run again!"
|
||||
)
|
||||
parser = LlamaParse(
|
||||
result_type="markdown",
|
||||
verbose=True,
|
||||
language="en",
|
||||
ignore_errors=False,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def llama_parse_extractor() -> Dict[str, LlamaParse]:
|
||||
from llama_parse.utils import SUPPORTED_FILE_TYPES
|
||||
|
||||
parser = llama_parse_parser()
|
||||
return {file_type: parser for file_type in SUPPORTED_FILE_TYPES}
|
||||
|
||||
def llama_local_extractor() -> Dict[str, BaseReader]:
|
||||
return {".json" : JSONReader(clean_json=False,levels_back=0)}
|
||||
|
||||
|
||||
def get_file_documents(config: FileLoaderConfig):
|
||||
from llama_index.core.readers import SimpleDirectoryReader
|
||||
|
||||
try:
|
||||
file_extractor = None
|
||||
if config.use_llama_parse:
|
||||
# LlamaParse is async first,
|
||||
# so we need to use nest_asyncio to run it in sync mode
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
file_extractor = llama_parse_extractor()
|
||||
else:
|
||||
file_extractor = llama_local_extractor()
|
||||
|
||||
reader = SimpleDirectoryReader(
|
||||
config.data_dir,
|
||||
recursive=True,
|
||||
filename_as_id=True,
|
||||
raise_on_error=True,
|
||||
file_extractor=file_extractor,
|
||||
)
|
||||
return reader.load_data()
|
||||
except Exception as e:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
# Catch the error if the data dir is empty
|
||||
# and return as empty document list
|
||||
_, _, exc_traceback = sys.exc_info()
|
||||
function_name = traceback.extract_tb(exc_traceback)[-1].name
|
||||
if function_name == "_add_files":
|
||||
logger.warning(
|
||||
f"Failed to load file documents, error message: {e} . Return as empty document list."
|
||||
)
|
||||
return []
|
||||
else:
|
||||
# Raise the error if it is not the case of empty data dir
|
||||
raise e
|
||||
@@ -1,37 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CrawlUrl(BaseModel):
|
||||
base_url: str
|
||||
prefix: str
|
||||
max_depth: int = Field(default=1, ge=0)
|
||||
|
||||
|
||||
class WebLoaderConfig(BaseModel):
|
||||
driver_arguments: list[str] = Field(default=None)
|
||||
urls: list[CrawlUrl] = []
|
||||
|
||||
|
||||
def get_web_documents(config: WebLoaderConfig):
|
||||
from llama_index.readers.web import WholeSiteReader
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
|
||||
options = Options()
|
||||
driver_arguments = config.driver_arguments or []
|
||||
for arg in driver_arguments:
|
||||
options.add_argument(arg)
|
||||
|
||||
docs = []
|
||||
urls = config.urls or []
|
||||
for url in config.urls:
|
||||
scraper = WholeSiteReader(
|
||||
prefix=url.prefix,
|
||||
max_depth=url.max_depth,
|
||||
driver=webdriver.Chrome(options=options),
|
||||
)
|
||||
docs.extend(scraper.load_data(url.base_url))
|
||||
|
||||
return docs
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
|
||||
import yaml
|
||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||
@@ -9,7 +8,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_configs():
|
||||
with open("config/loaders.yaml") as f:
|
||||
with open("config/loaders.yaml",encoding='UTF-8') as f:
|
||||
configs = yaml.safe_load(f)
|
||||
return configs
|
||||
|
||||
@@ -17,10 +16,12 @@ def load_configs():
|
||||
def get_documents():
|
||||
documents = []
|
||||
config = load_configs()
|
||||
|
||||
if config is None or len(config.items()) == 0:
|
||||
return documents
|
||||
|
||||
for loader_type, loader_config in config.items():
|
||||
if loader_config.get('enable', True): # 检查 enable 字段
|
||||
logger.info(
|
||||
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
||||
)
|
||||
|
||||
@@ -2,17 +2,14 @@ import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_index.core import SQLDatabase, Document
|
||||
from llama_index.core.objects import SQLTableSchema
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CustomDatabaseReader(BaseReader):
|
||||
class CustomDatabaseReader(DatabaseReader):
|
||||
"""Simple Database reader.
|
||||
|
||||
Concatenates each row into Document used by LlamaIndex.
|
||||
@@ -86,18 +83,19 @@ class CustomDatabaseReader(BaseReader):
|
||||
List[Document]: A list of Document objects.
|
||||
"""
|
||||
dco_str = ""
|
||||
|
||||
with self.sql_database.engine.connect() as connection:
|
||||
if query is None:
|
||||
raise ValueError("A query parameter is necessary to filter the data")
|
||||
else:
|
||||
result = connection.execute(text(query))
|
||||
|
||||
dco_str = ", ".join(
|
||||
dco_str += ", ".join(
|
||||
[f"{entry}" for entry in result.keys()]
|
||||
)
|
||||
) + "\n"
|
||||
|
||||
for item in result.fetchall():
|
||||
# fetch each item
|
||||
# Fetch each item
|
||||
record_str = ", ".join(
|
||||
[f"{entry}" for col, entry in zip(result.keys(), item)]
|
||||
)
|
||||
@@ -111,45 +109,36 @@ class CustomDatabaseReader(BaseReader):
|
||||
|
||||
class DBLoaderConfig(BaseModel):
|
||||
uri: str
|
||||
queries: List[str]
|
||||
queries: List[dict]
|
||||
|
||||
def get_db_documents(configs: list[DBLoaderConfig]):
|
||||
def get_db_documents(configs: List[DBLoaderConfig]) -> List[Document]:
|
||||
docs = []
|
||||
|
||||
if len(configs) == 0 or configs[0].uri == "":
|
||||
if not configs or not configs[0].uri:
|
||||
logger.warning(
|
||||
f"Failed to load database, error message: uri is empty. Return as empty document list."
|
||||
)
|
||||
return docs
|
||||
|
||||
metadata = {
|
||||
#'file_name':'',
|
||||
'file_type': 'application/booway.document.zj',
|
||||
#'file_path':'',
|
||||
#'file_size':'',
|
||||
#'creation_date':'',
|
||||
#'last_modified_date':'',
|
||||
}
|
||||
|
||||
#from llama_index.readers.database import DatabaseReader
|
||||
for entry in configs:
|
||||
engine = create_engine(entry.uri)
|
||||
sql_database = SQLDatabase(engine)
|
||||
|
||||
# table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||
# table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||
#
|
||||
# nodes = table_node_mapping.to_nodes(table_schema_objs)
|
||||
# for node in nodes:
|
||||
# node.metadata.update(metadata)
|
||||
#
|
||||
# docs.extend(nodes)
|
||||
|
||||
queries = entry.queries or []
|
||||
loader = CustomDatabaseReader(sql_database)
|
||||
for query in queries:
|
||||
for query_dict in entry.queries:
|
||||
query = query_dict.get("sql", "")
|
||||
explanation = query_dict.get("explanation", "")
|
||||
logger.info(f"Loading data from database with query: {query}")
|
||||
documents = loader.load_data(query=query)
|
||||
|
||||
docs.extend(documents)
|
||||
# 添加解释到元数据中
|
||||
for doc in documents:
|
||||
doc.metadata["explanation"] = explanation
|
||||
doc.metadata.update(metadata) # 更新或添加额外的元数据
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
@@ -5,6 +5,8 @@ text_qa_template_str = (
|
||||
"你是一名博微造价工程数据查询助手,专精于电力工程文件中的信息。"
|
||||
"你的职责是提供有关电力造价、造价编制软件、文件结构及相关数据的精准、客观的回答,"
|
||||
"如同直接从文件中提取的内容。\n"
|
||||
"知识库中已经导入一个工程的全部数据,请你站在当前工程的角度回答用户关于工程文件的问题。\n"
|
||||
"例如:询问“此工程”指当前导入的工程。询问“此工程名称”指当前导入的工程的工程名称。\n"
|
||||
|
||||
"## 技能\n"
|
||||
"### 技能 1: 数据查询与提供\n"
|
||||
@@ -39,15 +41,19 @@ refine_template_str = (
|
||||
"这是原本的问题: {query_str}\n"
|
||||
"我们已经提供了回答: {existing_answer}\n"
|
||||
"现在我们有机会改进这个回答 "
|
||||
"使用以下更多上下文(仅当需要用时)\n"
|
||||
"使用以下更多上下文(仅当有助于改进回答时使用)\n"
|
||||
"你需要仔细的判断新的上下文的信息与原本问题必须一个字都不差,如果有一点差别,那就不能改变我现有的回答。\n"
|
||||
"在判断回答是否正确的时候,你应该仔细对比新的上下文中包含的信息是否与原本的问题一字不差,如果一字不差,才能当作新的正确回答。\n"
|
||||
"如果新的上下文对回答没有影响,或者原来的回答已经正确,不要在上次回答的后边再加上多余的补充信息,直接返回原本的回答。\n"
|
||||
"判断一下如果原回答正确,且在新的上下文仍然包含正确的回答,请将新的回答与原回答一起返回。\n"
|
||||
"------------\n"
|
||||
"{context_msg}\n"
|
||||
"------------\n"
|
||||
"根据新的上下文, 请改进原来的回答。"
|
||||
"如果新的上下文没有用, 直接返回原本的回答。\n"
|
||||
"如果是表结构或者是数据库的相关内容,只用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
||||
"如果回答中已经包含有正确答案,不要返回多余的解释等信息,只返回正确答案\n"
|
||||
"如果是表结构或者是数据库的相关内容,仅用于推导问题,不需要告诉用户数据库或表结构等物理信息。\n"
|
||||
"改进的回答: "
|
||||
)
|
||||
|
||||
refine_template = PromptTemplate(refine_template_str)
|
||||
|
||||
summary_template_str = (
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
|
||||
from llama_index.core.indices.vector_store.base import VectorStoreIndex
|
||||
from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
|
||||
from llama_index.core.storage.docstore.types import BaseDocumentStore
|
||||
from llama_index.core.vector_stores.utils import (
|
||||
node_to_metadata_dict,
|
||||
metadata_dict_to_node,
|
||||
)
|
||||
|
||||
import bm25s
|
||||
from app.engine.retriever.CHTokener import chTokenize
|
||||
|
||||
CHDEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}
|
||||
|
||||
CHDEFAULT_PERSIST_FILENAME = "retriever.json"
|
||||
|
||||
class CHBM25Retriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
nodes: Optional[List[BaseNode]] = None,
|
||||
existing_bm25: Optional[bm25s.BM25] = None,
|
||||
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
objects: Optional[List[IndexNode]] = None,
|
||||
object_map: Optional[dict] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.similarity_top_k = similarity_top_k
|
||||
if existing_bm25 is not None:
|
||||
self.bm25 = existing_bm25
|
||||
self.corpus = existing_bm25.corpus
|
||||
else:
|
||||
from nltk.corpus import stopwords
|
||||
if nodes is None:
|
||||
raise ValueError("Please pass nodes or an existing BM25 object.")
|
||||
|
||||
self.corpus = [node_to_metadata_dict(node) for node in nodes]
|
||||
|
||||
corpus_tokens = chTokenize(
|
||||
[node.get_content() for node in nodes],
|
||||
show_progress=verbose,
|
||||
)
|
||||
self.bm25 = bm25s.BM25()
|
||||
self.bm25.index(corpus_tokens, show_progress=verbose)
|
||||
super().__init__(
|
||||
callback_manager=callback_manager,
|
||||
object_map=object_map,
|
||||
objects=objects,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_defaults(
|
||||
cls,
|
||||
index: Optional[VectorStoreIndex] = None,
|
||||
nodes: Optional[List[BaseNode]] = None,
|
||||
docstore: Optional[BaseDocumentStore] = None,
|
||||
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||
verbose: bool = False,
|
||||
) -> "CHBM25Retriever":
|
||||
if sum(bool(val) for val in [index, nodes, docstore]) != 1:
|
||||
raise ValueError("Please pass exactly one of index, nodes, or docstore.")
|
||||
|
||||
if index is not None:
|
||||
docstore = index.docstore
|
||||
|
||||
if docstore is not None:
|
||||
nodes = cast(List[BaseNode], list(docstore.docs.values()))
|
||||
|
||||
assert (
|
||||
nodes is not None
|
||||
), "Please pass exactly one of index, nodes, or docstore."
|
||||
|
||||
return cls(
|
||||
nodes=nodes,
|
||||
similarity_top_k=similarity_top_k,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
def get_persist_args(self) -> Dict[str, Any]:
|
||||
"""Get Persist Args Dict to Save."""
|
||||
return {
|
||||
CHDEFAULT_PERSIST_ARGS[key]: getattr(self, key)
|
||||
for key in CHDEFAULT_PERSIST_ARGS
|
||||
if hasattr(self, key)
|
||||
}
|
||||
|
||||
def persist(self, path: str, **kwargs: Any) -> None:
|
||||
"""Persist the retriever to a directory."""
|
||||
self.bm25.save(path, corpus=self.corpus, **kwargs)
|
||||
with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME), "w") as f:
|
||||
json.dump(self.get_persist_args(), f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_persist_dir(cls, path: str, **kwargs: Any) -> "CHBM25Retriever":
|
||||
"""Load the retriever from a directory."""
|
||||
bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
|
||||
with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME)) as f:
|
||||
retriever_data = json.load(f)
|
||||
return cls(existing_bm25=bm25, **retriever_data)
|
||||
|
||||
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
query = query_bundle.query_str
|
||||
tokenized_query = chTokenize(
|
||||
query,show_progress=self._verbose
|
||||
)
|
||||
indexes, scores = self.bm25.retrieve(
|
||||
tokenized_query, k=self.similarity_top_k, show_progress=self._verbose
|
||||
)
|
||||
|
||||
# batched, but only one query
|
||||
indexes = indexes[0]
|
||||
scores = scores[0]
|
||||
|
||||
nodes: List[NodeWithScore] = []
|
||||
for idx, score in zip(indexes, scores):
|
||||
# idx can be an int or a dict of the node
|
||||
if isinstance(idx, dict):
|
||||
node = metadata_dict_to_node(idx)
|
||||
else:
|
||||
node_dict = self.corpus[int(idx)]
|
||||
node = metadata_dict_to_node(node_dict)
|
||||
nodes.append(NodeWithScore(node=node, score=float(score)))
|
||||
|
||||
return nodes
|
||||
@@ -1,46 +0,0 @@
|
||||
from typing import Any, Dict, List, Union, Callable, NamedTuple
|
||||
from bm25s.tokenization import *
|
||||
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
except ImportError:
|
||||
|
||||
def tqdm(iterable, *args, **kwargs):
|
||||
return iterable
|
||||
|
||||
|
||||
def chinese_tokenizer(text: str) -> List[str]:
|
||||
import jieba
|
||||
from nltk.corpus import stopwords
|
||||
tokens = jieba.lcut(text)
|
||||
return [token for token in tokens if token not in stopwords.words('chinese')]
|
||||
|
||||
def chTokenize(
|
||||
texts,
|
||||
show_progress: bool = True,
|
||||
leave: bool = False,
|
||||
) -> Union[List[List[str]], Tokenized]:
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
corpus_ids = []
|
||||
token_to_index = {}
|
||||
|
||||
for text in tqdm(
|
||||
texts, desc="Split strings", leave=leave, disable=not show_progress
|
||||
):
|
||||
|
||||
splitted = chinese_tokenizer(text)
|
||||
doc_ids = []
|
||||
|
||||
for token in splitted:
|
||||
if token not in token_to_index:
|
||||
token_to_index[token] = len(token_to_index)
|
||||
|
||||
token_id = token_to_index[token]
|
||||
doc_ids.append(token_id)
|
||||
|
||||
corpus_ids.append(doc_ids)
|
||||
|
||||
return Tokenized(ids=corpus_ids, vocab=token_to_index)
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
import os
|
||||
from typing import Optional, Any, Dict, List
|
||||
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
|
||||
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||
|
||||
|
||||
class HybridRetriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
vector_index,
|
||||
similarity_top_k: int = 2,
|
||||
out_top_k: Optional[int] = None,
|
||||
alpha: float = 0.5,
|
||||
filters = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._vector_index = vector_index
|
||||
self._embed_model = vector_index._embed_model
|
||||
self._out_top_k = out_top_k or similarity_top_k
|
||||
self._vecRetriever = vector_index.as_retriever(
|
||||
similarity_top_k=similarity_top_k,filters = filters
|
||||
)
|
||||
|
||||
STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
|
||||
if os.path.exists(STORAGE_DIR) and len(os.listdir(STORAGE_DIR)) > 0:
|
||||
self._bm25Retriever = CHBM25Retriever.from_persist_dir(STORAGE_DIR)
|
||||
else:
|
||||
bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=similarity_top_k,nodes=self._vector_index.vector_store.get_nodes(None))
|
||||
bmRetriver.persist(STORAGE_DIR)
|
||||
self._alpha = alpha
|
||||
|
||||
|
||||
|
||||
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
vecNodes:List[NodeWithScore] = self._vecRetriever.retrieve(query_bundle.query_str)
|
||||
bmNodes:List[NodeWithScore] = self._bm25Retriever.retrieve(query_bundle.query_str)
|
||||
|
||||
bmDic:Dict[str,NodeWithScore] = {}
|
||||
for node in bmNodes:
|
||||
bmDic[node.node_id] = node
|
||||
|
||||
result_tups = []
|
||||
for i in range(len(vecNodes)):
|
||||
node = vecNodes[i]
|
||||
bmScore = 0.0
|
||||
if node.node_id in bmDic:
|
||||
bmScore = bmDic[node.node_id].score
|
||||
bmDic.pop(node.node_id)
|
||||
else:
|
||||
bmScore = 0.0
|
||||
full_similarity = (self._alpha * node.score) + (
|
||||
(1 - self._alpha) * bmScore
|
||||
)
|
||||
result_tups.append((full_similarity, node))
|
||||
|
||||
for _,node in bmDic.items():
|
||||
full_similarity = (1 - self._alpha) * node.score
|
||||
result_tups.append((full_similarity, node))
|
||||
|
||||
result_tups = sorted(result_tups, key=lambda x: x[0], reverse=True)
|
||||
for full_score, node in result_tups:
|
||||
node.score = full_score
|
||||
return [n for _, n in result_tups][:self._out_top_k]
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Union, Callable, NamedTuple
|
||||
from bm25s.tokenization import *
|
||||
|
||||
@@ -8,9 +9,12 @@ except ImportError:
|
||||
def tqdm(iterable, *args, **kwargs):
|
||||
return iterable
|
||||
|
||||
import jieba
|
||||
jiebapath = os.environ.get("JIEBA_DATA", "")
|
||||
jieba.set_dictionary(os.path.join(jiebapath, 'dict.txt')) #设置字典
|
||||
jieba.initialize() #初始化jeiba
|
||||
|
||||
def chinese_tokenizer(text: str) -> List[str]:
|
||||
import jieba
|
||||
from nltk.corpus import stopwords
|
||||
tokens = jieba.lcut(text)
|
||||
return [token for token in tokens if token not in stopwords.words('chinese')]
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
def duckduckgo_search(
|
||||
query: str,
|
||||
region: str = "wt-wt",
|
||||
max_results: int = 10,
|
||||
):
|
||||
"""
|
||||
Use this function to search for any query in DuckDuckGo.
|
||||
Args:
|
||||
query (str): The query to search in DuckDuckGo.
|
||||
region Optional(str): The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...
|
||||
max_results Optional(int): The maximum number of results to be returned. Default is 10.
|
||||
"""
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"duckduckgo_search package is required to use this function."
|
||||
"Please install it by running: `poetry add duckduckgo_search` or `pip install duckduckgo_search`"
|
||||
)
|
||||
|
||||
params = {
|
||||
"keywords": query,
|
||||
"region": region,
|
||||
"max_results": max_results,
|
||||
}
|
||||
results = []
|
||||
with DDGS() as ddg:
|
||||
results = list(ddg.text(**params))
|
||||
return results
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(duckduckgo_search)]
|
||||
@@ -1,60 +0,0 @@
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import importlib
|
||||
from cachetools import cached, LRUCache
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
class ToolType:
|
||||
LLAMAHUB = "llamahub"
|
||||
LOCAL = "local"
|
||||
|
||||
|
||||
class ToolFactory:
|
||||
|
||||
TOOL_SOURCE_PACKAGE_MAP = {
|
||||
ToolType.LLAMAHUB: "llama_index.tools",
|
||||
ToolType.LOCAL: "app.engine.tools",
|
||||
}
|
||||
|
||||
def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
|
||||
source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
|
||||
try:
|
||||
if "ToolSpec" in tool_name:
|
||||
tool_package, tool_cls_name = tool_name.split(".")
|
||||
module_name = f"{source_package}.{tool_package}"
|
||||
module = importlib.import_module(module_name)
|
||||
tool_class = getattr(module, tool_cls_name)
|
||||
tool_spec: BaseToolSpec = tool_class(**config)
|
||||
return tool_spec.to_tool_list()
|
||||
else:
|
||||
module = importlib.import_module(f"{source_package}.{tool_name}")
|
||||
tools = module.get_tools(**config)
|
||||
if not all(isinstance(tool, FunctionTool) for tool in tools):
|
||||
raise ValueError(
|
||||
f"The module {module} does not contain valid tools"
|
||||
)
|
||||
return tools
|
||||
except ImportError as e:
|
||||
raise ValueError(f"Failed to import tool {tool_name}: {e}")
|
||||
except AttributeError as e:
|
||||
raise ValueError(f"Failed to load tool {tool_name}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def from_env() -> list[FunctionTool]:
|
||||
tools = []
|
||||
if os.path.exists("config/tools.yaml"):
|
||||
with open("config/tools.yaml", "r") as f:
|
||||
tool_configs = yaml.safe_load(f)
|
||||
if tool_configs != None and len(tool_configs.items()) != 0:
|
||||
for tool_type, config_entries in tool_configs.items():
|
||||
if config_entries == None or len(config_entries.items()) == 0:
|
||||
continue
|
||||
|
||||
for tool_name, config in config_entries.items():
|
||||
tools.extend(
|
||||
ToolFactory.load_tools(tool_type, tool_name, config)
|
||||
)
|
||||
return tools
|
||||
@@ -1,108 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import requests
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from llama_index.core.tools import FunctionTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageGeneratorToolOutput(BaseModel):
|
||||
is_success: bool = Field(
|
||||
...,
|
||||
description="Whether the image generation was successful.",
|
||||
)
|
||||
image_url: Optional[str] = Field(
|
||||
None,
|
||||
description="The URL of the generated image.",
|
||||
)
|
||||
error_message: Optional[str] = Field(
|
||||
None,
|
||||
description="The error message if the image generation failed.",
|
||||
)
|
||||
|
||||
|
||||
class ImageGeneratorTool:
|
||||
_IMG_OUTPUT_FORMAT = "webp"
|
||||
_IMG_OUTPUT_DIR = "output/tool"
|
||||
_IMG_GEN_API = "https://api.stability.ai/v2beta/stable-image/generate/core"
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
if not api_key:
|
||||
api_key = os.getenv("STABILITY_API_KEY")
|
||||
self._api_key = api_key
|
||||
self.fileserver_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if self._api_key is None:
|
||||
raise ValueError(
|
||||
"STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys"
|
||||
)
|
||||
if self.fileserver_url_prefix is None:
|
||||
raise ValueError("FILESERVER_URL_PREFIX is required.")
|
||||
|
||||
def _prepare_output_dir(self):
|
||||
"""
|
||||
Create the output directory if it doesn't exist
|
||||
"""
|
||||
if not os.path.exists(self._IMG_OUTPUT_DIR):
|
||||
os.makedirs(self._IMG_OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
def _save_image(self, image_data: bytes):
|
||||
self._prepare_output_dir()
|
||||
filename = f"{uuid.uuid4()}.{self._IMG_OUTPUT_FORMAT}"
|
||||
output_path = os.path.join(self._IMG_OUTPUT_DIR, filename)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_data)
|
||||
url = f"{os.getenv('FILESERVER_URL_PREFIX')}/{self._IMG_OUTPUT_DIR}/{filename}"
|
||||
logger.info(f"Saved image to {output_path}.\nURL: {url}")
|
||||
return url
|
||||
|
||||
def _call_stability_api(self, prompt: str):
|
||||
headers = {
|
||||
"authorization": f"Bearer {self._api_key}",
|
||||
"accept": "image/*",
|
||||
}
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"output_format": self._IMG_OUTPUT_FORMAT,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
self._IMG_GEN_API,
|
||||
headers=headers,
|
||||
files={"none": ""},
|
||||
data=data,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return response
|
||||
|
||||
def generate_image(self, prompt: str) -> ImageGeneratorToolOutput:
|
||||
"""
|
||||
Use this tool to generate an image based on the prompt.
|
||||
Args:
|
||||
prompt (str): The prompt to generate the image from.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Call the Stability API
|
||||
response = self._call_stability_api(prompt)
|
||||
|
||||
# Save the image and get the URL
|
||||
image_url = self._save_image(response.content)
|
||||
|
||||
return ImageGeneratorToolOutput(
|
||||
is_success=True,
|
||||
image_url=image_url,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e, exc_info=True)
|
||||
return ImageGeneratorToolOutput(
|
||||
is_success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(ImageGeneratorTool(**kwargs).generate_image)]
|
||||
@@ -1,143 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import base64
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
from llama_index.core.tools import FunctionTool
|
||||
from e2b_code_interpreter import CodeInterpreter
|
||||
from e2b_code_interpreter.models import Logs
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InterpreterExtraResult(BaseModel):
|
||||
type: str
|
||||
content: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class E2BToolOutput(BaseModel):
|
||||
is_error: bool
|
||||
logs: Logs
|
||||
results: List[InterpreterExtraResult] = []
|
||||
|
||||
|
||||
class E2BCodeInterpreter:
|
||||
|
||||
output_dir = "output/tool"
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
if api_key is None:
|
||||
api_key = os.getenv("E2B_API_KEY")
|
||||
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
|
||||
)
|
||||
if not filesever_url_prefix:
|
||||
raise ValueError(
|
||||
"FILESERVER_URL_PREFIX is required to display file output from sandbox"
|
||||
)
|
||||
|
||||
self.filesever_url_prefix = filesever_url_prefix
|
||||
self.interpreter = CodeInterpreter(api_key=api_key)
|
||||
|
||||
def __del__(self):
|
||||
self.interpreter.close()
|
||||
|
||||
def get_output_path(self, filename: str) -> str:
|
||||
# if output directory doesn't exist, create it
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
return os.path.join(self.output_dir, filename)
|
||||
|
||||
def save_to_disk(self, base64_data: str, ext: str) -> Dict:
|
||||
filename = f"{uuid.uuid4()}.{ext}" # generate a unique filename
|
||||
buffer = base64.b64decode(base64_data)
|
||||
output_path = self.get_output_path(filename)
|
||||
|
||||
try:
|
||||
with open(output_path, "wb") as file:
|
||||
file.write(buffer)
|
||||
except IOError as e:
|
||||
logger.error(f"Failed to write to file {output_path}: {str(e)}")
|
||||
raise e
|
||||
|
||||
logger.info(f"Saved file to {output_path}")
|
||||
|
||||
return {
|
||||
"outputPath": output_path,
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
def get_file_url(self, filename: str) -> str:
|
||||
return f"{self.filesever_url_prefix}/{self.output_dir}/{filename}"
|
||||
|
||||
def parse_result(self, result) -> List[InterpreterExtraResult]:
|
||||
"""
|
||||
The result could include multiple formats (e.g. png, svg, etc.) but encoded in base64
|
||||
We save each result to disk and return saved file metadata (extension, filename, url)
|
||||
"""
|
||||
if not result:
|
||||
return []
|
||||
|
||||
output = []
|
||||
|
||||
try:
|
||||
formats = result.formats()
|
||||
results = [result[format] for format in formats]
|
||||
|
||||
for ext, data in zip(formats, results):
|
||||
match ext:
|
||||
case "png" | "svg" | "jpeg" | "pdf":
|
||||
result = self.save_to_disk(data, ext)
|
||||
filename = result["filename"]
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext,
|
||||
filename=filename,
|
||||
url=self.get_file_url(filename),
|
||||
)
|
||||
)
|
||||
case _:
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext,
|
||||
content=data,
|
||||
)
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(error, exc_info=True)
|
||||
logger.error("Error when parsing output from E2b interpreter tool", error)
|
||||
|
||||
return output
|
||||
|
||||
def interpret(self, code: str) -> E2BToolOutput:
|
||||
"""
|
||||
Execute python code in a Jupyter notebook cell, the toll will return result, stdout, stderr, display_data, and error.
|
||||
|
||||
Parameters:
|
||||
code (str): The python code to be executed in a single cell.
|
||||
"""
|
||||
logger.info(
|
||||
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
|
||||
)
|
||||
exec = self.interpreter.notebook.exec_cell(code)
|
||||
|
||||
if exec.error:
|
||||
logger.error("Error when executing code", exec.error)
|
||||
output = E2BToolOutput(is_error=True, logs=exec.logs, results=[])
|
||||
else:
|
||||
if len(exec.results) == 0:
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=[])
|
||||
else:
|
||||
results = self.parse_result(exec.results[0])
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=results)
|
||||
return output
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(E2BCodeInterpreter(**kwargs).interpret)]
|
||||
@@ -1,78 +0,0 @@
|
||||
from typing import Dict, List, Tuple
|
||||
from llama_index.tools.openapi import OpenAPIToolSpec
|
||||
from llama_index.tools.requests import RequestsToolSpec
|
||||
|
||||
|
||||
class OpenAPIActionToolSpec(OpenAPIToolSpec, RequestsToolSpec):
|
||||
"""
|
||||
A combination of OpenAPI and Requests tool specs that can parse OpenAPI specs and make requests.
|
||||
|
||||
openapi_uri: str: The file path or URL to the OpenAPI spec.
|
||||
domain_headers: dict: Whitelist domains and the headers to use.
|
||||
"""
|
||||
|
||||
spec_functions = OpenAPIToolSpec.spec_functions + RequestsToolSpec.spec_functions
|
||||
# Cached parsed specs by URI
|
||||
_specs: Dict[str, Tuple[Dict, List[str]]] = {}
|
||||
|
||||
def __init__(self, openapi_uri: str, domain_headers: dict = None, **kwargs):
|
||||
if domain_headers is None:
|
||||
domain_headers = {}
|
||||
if openapi_uri not in self._specs:
|
||||
openapi_spec, servers = self._load_openapi_spec(openapi_uri)
|
||||
self._specs[openapi_uri] = (openapi_spec, servers)
|
||||
else:
|
||||
openapi_spec, servers = self._specs[openapi_uri]
|
||||
|
||||
# Add the servers to the domain headers if they are not already present
|
||||
for server in servers:
|
||||
if server not in domain_headers:
|
||||
domain_headers[server] = {}
|
||||
|
||||
OpenAPIToolSpec.__init__(self, spec=openapi_spec)
|
||||
RequestsToolSpec.__init__(self, domain_headers)
|
||||
|
||||
@staticmethod
|
||||
def _load_openapi_spec(uri: str) -> Tuple[Dict, List[str]]:
|
||||
"""
|
||||
Load an OpenAPI spec from a URI.
|
||||
|
||||
Args:
|
||||
uri (str): A file path or URL to the OpenAPI spec.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of Document objects.
|
||||
"""
|
||||
import yaml
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if uri.startswith("http"):
|
||||
import requests
|
||||
|
||||
response = requests.get(uri)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: "
|
||||
f"Failed to load OpenAPI spec from {uri}, status code: {response.status_code}"
|
||||
)
|
||||
spec = yaml.safe_load(response.text)
|
||||
elif uri.startswith("file"):
|
||||
filepath = urlparse(uri).path
|
||||
with open(filepath, "r") as file:
|
||||
spec = yaml.safe_load(file)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI URI provided. "
|
||||
"Only HTTP and file path are supported."
|
||||
)
|
||||
# Add the servers to the whitelist
|
||||
try:
|
||||
servers = [
|
||||
urlparse(server["url"]).netloc for server in spec.get("servers", [])
|
||||
]
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI spec provided. "
|
||||
"Could not get `servers` from the spec."
|
||||
) from e
|
||||
return spec, servers
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Open Meteo weather map tool spec."""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
import pytz
|
||||
from llama_index.core.tools import FunctionTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenMeteoWeather:
|
||||
geo_api = "https://geocoding-api.open-meteo.com/v1"
|
||||
weather_api = "https://api.open-meteo.com/v1"
|
||||
|
||||
@classmethod
|
||||
def _get_geo_location(cls, location: str) -> dict:
|
||||
"""Get geo location from location name."""
|
||||
params = {"name": location, "count": 10, "language": "en", "format": "json"}
|
||||
response = requests.get(f"{cls.geo_api}/search", params=params)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to fetch geo location: {response.status_code}")
|
||||
else:
|
||||
data = response.json()
|
||||
result = data["results"][0]
|
||||
geo_location = {
|
||||
"id": result["id"],
|
||||
"name": result["name"],
|
||||
"latitude": result["latitude"],
|
||||
"longitude": result["longitude"],
|
||||
}
|
||||
return geo_location
|
||||
|
||||
@classmethod
|
||||
def get_weather_information(cls, location: str) -> dict:
|
||||
"""Use this function to get the weather of any given location.
|
||||
Note that the weather code should follow WMO Weather interpretation codes (WW):
|
||||
0: Clear sky
|
||||
1, 2, 3: Mainly clear, partly cloudy, and overcast
|
||||
45, 48: Fog and depositing rime fog
|
||||
51, 53, 55: Drizzle: Light, moderate, and dense intensity
|
||||
56, 57: Freezing Drizzle: Light and dense intensity
|
||||
61, 63, 65: Rain: Slight, moderate and heavy intensity
|
||||
66, 67: Freezing Rain: Light and heavy intensity
|
||||
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
|
||||
77: Snow grains
|
||||
80, 81, 82: Rain showers: Slight, moderate, and violent
|
||||
85, 86: Snow showers slight and heavy
|
||||
95: Thunderstorm: Slight or moderate
|
||||
96, 99: Thunderstorm with slight and heavy hail
|
||||
"""
|
||||
logger.info(
|
||||
f"Calling open-meteo api to get weather information of location: {location}"
|
||||
)
|
||||
geo_location = cls._get_geo_location(location)
|
||||
timezone = pytz.timezone("UTC").zone
|
||||
params = {
|
||||
"latitude": geo_location["latitude"],
|
||||
"longitude": geo_location["longitude"],
|
||||
"current": "temperature_2m,weather_code",
|
||||
"hourly": "temperature_2m,weather_code",
|
||||
"daily": "weather_code",
|
||||
"timezone": timezone,
|
||||
}
|
||||
response = requests.get(f"{cls.weather_api}/forecast", params=params)
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch weather information: {response.status_code}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]
|
||||
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import importlib
|
||||
from cachetools import cached, LRUCache
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
import os
|
||||
|
||||
import yaml
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
|
||||
|
||||
class ToolType:
|
||||
@@ -46,7 +45,7 @@ class ToolFactory:
|
||||
def from_env() -> list[FunctionTool]:
|
||||
tools = []
|
||||
if os.path.exists("config/tools.yaml"):
|
||||
with open("config/tools.yaml", "r") as f:
|
||||
with open("config/tools.yaml", "r", encoding='UTF-8') as f:
|
||||
tool_configs = yaml.safe_load(f)
|
||||
if tool_configs != None and len(tool_configs.items()) != 0:
|
||||
for tool_type, config_entries in tool_configs.items():
|
||||
|
||||
@@ -3,11 +3,10 @@ from typing import Dict
|
||||
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.core.settings import Settings
|
||||
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||
from llama_index.llms.xinference import Xinference
|
||||
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
||||
|
||||
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||
|
||||
|
||||
def get_node_postprocessors():
|
||||
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
"""Xinference embeddings file."""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding, dispatcher
|
||||
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||
from llama_index.core.callbacks import CBEventType, EventPayload
|
||||
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
|
||||
from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReRankEndEvent
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
|
||||
from pydantic import Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
EMBED_MAX_INPUT_LENGTH = 2048
|
||||
EMBED_MAX_BATCH_SIZE = 1
|
||||
|
||||
|
||||
class XinferenceEmbedding(BaseEmbedding):
|
||||
"""Xinference class for text embedding.
|
||||
|
||||
"""
|
||||
model_description: Dict[str, Any] = Field(
|
||||
description="The model description from Xinference."
|
||||
)
|
||||
_generator: Any = PrivateAttr()
|
||||
_model_uid: str = Field(description="The Xinference model to use.")
|
||||
_endpoint: str = Field(description="The Xinference endpoint URL to use.")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_uid: str,
|
||||
endpoint: str,
|
||||
embed_batch_size: int = EMBED_MAX_BATCH_SIZE,
|
||||
dimensions: Optional[int] = None,
|
||||
additional_kwargs: Optional[Dict[str, Any]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
max_retries: int = 10,
|
||||
# timeout: float = 60.0,
|
||||
# reuse_client: bool = True,
|
||||
# callback_manager: Optional[CallbackManager] = None,
|
||||
# default_headers: Optional[Dict[str, str]] = None,
|
||||
# http_client: Optional[httpx.Client] = None,
|
||||
# async_http_client: Optional[httpx.AsyncClient] = None,
|
||||
# num_workers: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
generator, model_description, embed_batch_size, dimensions = self.load_model(
|
||||
model_uid, endpoint
|
||||
)
|
||||
self._generator = generator
|
||||
#self._model_uid = model_uid
|
||||
#self._endpoint = endpoint
|
||||
super().__init__(
|
||||
embed_batch_size=embed_batch_size,
|
||||
dimensions=dimensions,
|
||||
#callback_manager=callback_manager,
|
||||
model_name=model_uid,
|
||||
additional_kwargs=additional_kwargs,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
max_retries=max_retries,
|
||||
# reuse_client=reuse_client,
|
||||
# timeout=timeout,
|
||||
# default_headers=default_headers,
|
||||
# num_workers=num_workers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]:
|
||||
try:
|
||||
from xinference.client import RESTfulClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import Xinference library."
|
||||
'Please install Xinference with `pip install "xinference[all]"`'
|
||||
)
|
||||
|
||||
client = RESTfulClient(endpoint)
|
||||
|
||||
try:
|
||||
assert isinstance(client, RESTfulClient)
|
||||
except AssertionError:
|
||||
raise RuntimeError(
|
||||
"Could not create RESTfulClient instance."
|
||||
"Please make sure Xinference endpoint is running at the correct port."
|
||||
)
|
||||
|
||||
generator = client.get_model(model_uid)
|
||||
model_description = client.list_models()[model_uid]
|
||||
|
||||
try:
|
||||
assert generator is not None
|
||||
assert model_description is not None
|
||||
except AssertionError:
|
||||
raise RuntimeError(
|
||||
"Could not get model from endpoint."
|
||||
"Please make sure Xinference endpoint is running at the correct port."
|
||||
)
|
||||
|
||||
model = model_description["model_name"]
|
||||
replica = model_description['replica']
|
||||
dimensions = model_description['dimensions']
|
||||
max_tokens = model_description['max_tokens']
|
||||
|
||||
return generator, model_description, replica, dimensions
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "XinferenceEmbedding"
|
||||
|
||||
def _get_text_embedding(self, text: str) -> Embedding:
|
||||
"""
|
||||
Embed the input text synchronously.
|
||||
|
||||
Subclasses should implement this method. Reference get_text_embedding's
|
||||
docstring for more information.
|
||||
"""
|
||||
assert self._generator is not None
|
||||
|
||||
response = self._generator.create_embedding(input=text)
|
||||
return response['data'][0]['embedding']
|
||||
|
||||
def _get_query_embedding(self, query: str) -> Embedding:
|
||||
"""
|
||||
Embed the input query synchronously.
|
||||
|
||||
Subclasses should implement this method. Reference get_query_embedding's
|
||||
docstring for more information.
|
||||
"""
|
||||
return self._get_text_embedding(query)
|
||||
|
||||
async def _aget_query_embedding(self, query: str) -> Embedding:
|
||||
"""
|
||||
Embed the input query asynchronously.
|
||||
|
||||
Subclasses should implement this method. Reference get_query_embedding's
|
||||
docstring for more information.
|
||||
"""
|
||||
return self._get_query_embedding(query)
|
||||
|
||||
class XinferenceRerank(BaseNodePostprocessor):
|
||||
"""Xinference class for rerank.
|
||||
|
||||
"""
|
||||
model_description: Dict[str, Any] = Field(
|
||||
description="The model description from Xinference."
|
||||
)
|
||||
_generator: Any = PrivateAttr()
|
||||
_model_uid: str = Field(description="The Xinference model to use.")
|
||||
_endpoint: str = Field(description="The Xinference endpoint URL to use.")
|
||||
model: str = Field(description="Dashscope rerank model name.")
|
||||
top_n: int = Field(description="Top N nodes to return.")
|
||||
threshold: float = Field(description="threshold nodes to return.")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_uid: str,
|
||||
endpoint: str,
|
||||
top_n: int = None,
|
||||
threshold: float = None,
|
||||
return_documents: bool = False
|
||||
):
|
||||
_model_uid = model_uid
|
||||
_endpoint = endpoint
|
||||
_op_n = top_n
|
||||
threshold = threshold
|
||||
generator, model_description = self.load_model(
|
||||
model_uid, endpoint
|
||||
)
|
||||
self._generator = generator
|
||||
super().__init__(top_n=top_n, model=model_uid, model_uid=model_uid, threshold = threshold, return_documents=return_documents)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "XinferenceRerank"
|
||||
|
||||
def _postprocess_nodes(
|
||||
self,
|
||||
nodes: List[NodeWithScore],
|
||||
query_bundle: Optional[QueryBundle] = None,
|
||||
) -> List[NodeWithScore]:
|
||||
if query_bundle is None:
|
||||
raise ValueError("Missing query bundle in extra info.")
|
||||
if len(nodes) == 0:
|
||||
return []
|
||||
|
||||
dispatcher.event(
|
||||
ReRankStartEvent(
|
||||
nodes = nodes,
|
||||
top_n = self.top_n,
|
||||
query = query_bundle,
|
||||
model_name = self.model
|
||||
)
|
||||
)
|
||||
|
||||
with self.callback_manager.event(
|
||||
CBEventType.RERANKING,
|
||||
payload={
|
||||
EventPayload.NODES: nodes,
|
||||
EventPayload.MODEL_NAME: self._model_uid,
|
||||
EventPayload.QUERY_STR: query_bundle.query_str,
|
||||
EventPayload.TOP_K: self.top_n,
|
||||
},
|
||||
) as event:
|
||||
texts = [node.node.get_content() for node in nodes]
|
||||
response = self._generator.rerank(texts,query_bundle.query_str)
|
||||
new_nodes = []
|
||||
for result in response['results']:
|
||||
new_node_with_score = NodeWithScore(
|
||||
node=nodes[result['index']].node, score=result['relevance_score']
|
||||
)
|
||||
if self.threshold is not None:
|
||||
if new_node_with_score.score >=self.threshold:
|
||||
new_nodes.append(new_node_with_score)
|
||||
|
||||
if self.top_n is not None:
|
||||
if len(new_nodes) > self.top_n:
|
||||
for index in new_nodes[self.top_n:-1]:
|
||||
new_nodes.remove(index)
|
||||
|
||||
event.on_end(payload={EventPayload.NODES: new_nodes})
|
||||
|
||||
dispatcher.event(
|
||||
ReRankEndEvent(
|
||||
nodes= new_nodes
|
||||
)
|
||||
)
|
||||
return new_nodes
|
||||
|
||||
def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]:
|
||||
try:
|
||||
from xinference.client import RESTfulClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import Xinference library."
|
||||
'Please install Xinference with `pip install "xinference[all]"`'
|
||||
)
|
||||
|
||||
client = RESTfulClient(endpoint)
|
||||
|
||||
try:
|
||||
assert isinstance(client, RESTfulClient)
|
||||
except AssertionError:
|
||||
raise RuntimeError(
|
||||
"Could not create RESTfulClient instance."
|
||||
"Please make sure Xinference endpoint is running at the correct port."
|
||||
)
|
||||
|
||||
generator = client.get_model(model_uid)
|
||||
model_description = client.list_models()[model_uid]
|
||||
|
||||
try:
|
||||
assert generator is not None
|
||||
assert model_description is not None
|
||||
except AssertionError:
|
||||
raise RuntimeError(
|
||||
"Could not get model from endpoint."
|
||||
"Please make sure Xinference endpoint is running at the correct port."
|
||||
)
|
||||
|
||||
model = model_description["model_name"]
|
||||
|
||||
return generator, model_description
|
||||
@@ -1,4 +1,5 @@
|
||||
file:
|
||||
enable: true # 添加 enable 字段
|
||||
# use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable
|
||||
use_llama_parse: false
|
||||
|
||||
@@ -7,27 +8,41 @@ db:
|
||||
# uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
|
||||
# query: The query to fetch data from the database. E.g.: SELECT * FROM table
|
||||
- uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
#- uri: mysql+pymysql://zjinfo:Y6EAjEEdSYmskA8B@110.42.234.166:3306/zjinfo
|
||||
# - uri: mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||
enable: true # 添加 enable 字段
|
||||
queries:
|
||||
- sql: select * from ProjectProperties limit 30;
|
||||
- sql: select * from ProjectProperties;
|
||||
explanation: "工程属性表数据,层级关系包含在博微电力造价工程文件格式_ProjectProperties.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, Name, Code, Amount, Amount_Total from TotalCalculateTable;
|
||||
explanation: "总算表数据,层级关系包含在博微电力造价工程文件格式_TotalCalculateTable.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '线路' limit 50;
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '线路';
|
||||
explanation: "专业类型为线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '余物清理' limit 50;
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '余物清理';
|
||||
explanation: "专业类型为余物清理的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 and ProfessionalType = '拆除线路' limit 50;
|
||||
- sql: select Id, ParentId, Level, SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where ProfessionalType = '拆除线路';
|
||||
explanation: "专业类型为拆除线路的项目划分表数据,层级关系包含在博微电力造价工程文件格式_ProjectDivision.json文件中。"
|
||||
|
||||
- sql: select Id, ParentId, Level, Name, Code, Rate, Amount from OtherFee;
|
||||
explanation: "其他费用表数据,层级关系包含在博微电力造价工程文件格式_OtherFee.json文件中"
|
||||
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(调试工程)aa'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '大型土石方取费表'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(余物清理)(1)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from FeeCollectionTable where FeeCollection_Table_Name = '线路取费表(拆除)'
|
||||
explanation: "取费表名称为线路取费表的取费表数据,层级关系包含在博微电力造价工程文件格式_FeeCollectionTable.json文件中"
|
||||
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '线路'
|
||||
explanation: "专业类型为线路的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
- sql: select Name, Code, Calculation_Formula, Rate, from ProjectQuantities where Professional_Type = '余物清理'
|
||||
explanation: "专业类型为余物清理的工程量表数据,层级关系包含在博微电力造价工程文件格式_ProjectQuantities.json文件中"
|
||||
#web:
|
||||
# driver_arguments:
|
||||
# # The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
+3
-2
@@ -1,7 +1,5 @@
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
@@ -12,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import RedirectResponse
|
||||
from app.api.routers.chat import chat_router
|
||||
from app.api.routers.upload import file_upload_router
|
||||
from app.api.routers.app import v1_router
|
||||
from app.settings import init_settings
|
||||
from app.observability import init_observability
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@@ -56,6 +55,8 @@ mount_static_files("data_output", "/api/files/output")
|
||||
app.include_router(chat_router, prefix="/api/chat")
|
||||
app.include_router(file_upload_router, prefix="/api/chat/upload")
|
||||
|
||||
app.include_router(v1_router, prefix="/v1")
|
||||
|
||||
@app.get("/")
|
||||
async def redirect_to_docs():
|
||||
return RedirectResponse(url="/docs")
|
||||
|
||||
Binary file not shown.
Binary file not shown.
+349046
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
@@ -17,7 +17,7 @@ aiostream = "^0.6.2"
|
||||
llama-index = "0.10.63"
|
||||
cachetools = "^5.3.3"
|
||||
protobuf = "4.25.4"
|
||||
nltk = "^3.8.2"
|
||||
nltk = "^3.9.1"
|
||||
jieba = "^0.42.1"
|
||||
|
||||
#arize-phoenix = "^4.12.0"
|
||||
@@ -35,6 +35,7 @@ chroma="^0.2.0"
|
||||
llama-index-vector-stores-chroma = "^0.1.10"
|
||||
llama-index-readers-json = "^0.1.5"
|
||||
llama-index-retrievers-bm25 = "^0.2.2"
|
||||
llama-index-experimental = "^0.2.0"
|
||||
|
||||
duckduckgo_search = "^6.2.6"
|
||||
|
||||
@@ -62,6 +63,12 @@ version = "^0.8"
|
||||
version = "0.0.7"
|
||||
|
||||
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = "mirrors"
|
||||
url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
|
||||
priority = "default"
|
||||
|
||||
[build-system]
|
||||
requires = [ "poetry-core" ]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1,138 @@
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.evaluation import (
|
||||
FaithfulnessEvaluator,
|
||||
DatasetGenerator,
|
||||
CorrectnessEvaluator,
|
||||
SemanticSimilarityEvaluator,
|
||||
)
|
||||
from llama_index.experimental.param_tuner import ParamTuner
|
||||
from llama_index.experimental.param_tuner.base import RunResult
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
import asyncio
|
||||
|
||||
# 初始化环境
|
||||
from app.observability import init_observability
|
||||
from app.settings import init_settings
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
# 读取文档
|
||||
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
||||
|
||||
# 参数字典
|
||||
param_dict = {
|
||||
"chunk_size": [512, 1024],
|
||||
"top_k": [1, 5],
|
||||
"temperature": [0.1, 1.0]
|
||||
}
|
||||
|
||||
# 辅助函数
|
||||
def _build_index(chunk_size, documents):
|
||||
# 构建索引
|
||||
splitter = SentenceSplitter(chunk_size=chunk_size)
|
||||
vector_index = VectorStoreIndex.from_documents(
|
||||
documents, transformations=[splitter],
|
||||
)
|
||||
return vector_index
|
||||
|
||||
# 评估函数
|
||||
def evaluate_query_engine(query_engine, questions):
|
||||
loop = asyncio.get_event_loop()
|
||||
correct, total = loop.run_until_complete(_evaluate_query_engine_async(query_engine, questions))
|
||||
return correct, total
|
||||
|
||||
async def _evaluate_query_engine_async(query_engine, questions):
|
||||
c = [query_engine.aquery(q) for q in questions]
|
||||
gathering_future = asyncio.gather(*c)
|
||||
results = await gathering_future
|
||||
|
||||
total_correct = 0
|
||||
for r in results:
|
||||
eval_result = (
|
||||
1 if FaithfulnessEvaluator().evaluate_response(response=r).passing else 0
|
||||
)
|
||||
total_correct += eval_result
|
||||
|
||||
return total_correct, len(results)
|
||||
|
||||
|
||||
|
||||
# 生成问题
|
||||
question_generator = DatasetGenerator.from_documents(documents)
|
||||
eval_questions = question_generator.generate_questions_from_nodes(1) # 假设生成10个问题
|
||||
|
||||
# 打印生成的问题
|
||||
for i, q in enumerate(eval_questions, start=1):
|
||||
print(f"问题 {i}: {q}")
|
||||
|
||||
# 目标函数
|
||||
def objective_function(params_dict, documents, questions):
|
||||
chunk_size = params_dict["chunk_size"]
|
||||
top_k = params_dict["top_k"]
|
||||
temperature = params_dict["temperature"]
|
||||
|
||||
# 构建索引
|
||||
vector_index = _build_index(chunk_size, documents)
|
||||
|
||||
# 查询引擎
|
||||
query_engine = vector_index.as_query_engine(
|
||||
similarity_top_k=top_k, temperature=temperature
|
||||
)
|
||||
|
||||
# 评估查询引擎
|
||||
correct, total = 0, len(questions)
|
||||
question_answers = [] # 添加列表来收集问题和答案
|
||||
|
||||
for question in questions:
|
||||
response = query_engine.query(question)
|
||||
if response is not None:
|
||||
question_answers.append((question, response.response))
|
||||
eval_result = FaithfulnessEvaluator().evaluate_response(response=response, query_str=question)
|
||||
if eval_result.passing:
|
||||
correct += 1
|
||||
|
||||
# 计算分数
|
||||
score = correct / total if total > 0 else 0
|
||||
return RunResult(score=score, params=params_dict, question_answers=question_answers)
|
||||
|
||||
# 创建 ParamTuner 实例
|
||||
param_tuner = ParamTuner(
|
||||
param_fn=lambda params_dict: objective_function(params_dict, documents, eval_questions),
|
||||
param_dict=param_dict,
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
# 调用 tune 方法
|
||||
results = param_tuner.tune()
|
||||
best_result = results.best_run_result
|
||||
best_top_k = best_result.params["top_k"]
|
||||
best_chunk_size = best_result.params["chunk_size"]
|
||||
best_temperature = best_result.params["temperature"]
|
||||
print(f"得分: {best_result.score}")
|
||||
print(f"Top-k: {best_top_k}")
|
||||
print(f"文本块大小: {best_chunk_size}")
|
||||
print(f"温度: {best_temperature}")
|
||||
|
||||
# 使用最佳参数再次运行查询引擎,并打印问题与答案
|
||||
best_vector_index = _build_index(best_chunk_size, documents)
|
||||
best_query_engine = best_vector_index.as_query_engine(
|
||||
similarity_top_k=best_top_k, temperature=best_temperature
|
||||
)
|
||||
|
||||
best_question_answers = []
|
||||
for question in eval_questions:
|
||||
response = best_query_engine.query(question)
|
||||
if response is not None:
|
||||
best_question_answers.append((question, response.response))
|
||||
|
||||
# 打印最佳参数下的问题与答案
|
||||
for i, (question, answer) in enumerate(best_question_answers, start=1):
|
||||
print(f"最佳参数 - 问题 {i}: {question}\n答案: {answer}\n")
|
||||
@@ -0,0 +1,81 @@
|
||||
from app.observability import init_observability
|
||||
from app.settings import init_settings
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core import (
|
||||
VectorStoreIndex,
|
||||
SimpleDirectoryReader,
|
||||
Response,
|
||||
)
|
||||
from llama_index.core.evaluation import (
|
||||
FaithfulnessEvaluator,
|
||||
DatasetGenerator,
|
||||
CorrectnessEvaluator,
|
||||
SemanticSimilarityEvaluator,)
|
||||
|
||||
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
faith_evaluator_qwen = FaithfulnessEvaluator() #诚实度评测
|
||||
corr_evaluator_qwen = CorrectnessEvaluator() #准确率评测
|
||||
Seman_evaluator_qwen = SemanticSimilarityEvaluator()#嵌入相似度评估
|
||||
|
||||
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
||||
|
||||
splitter = SentenceSplitter(chunk_size=512)
|
||||
|
||||
|
||||
vector_index = VectorStoreIndex.from_documents(
|
||||
documents, transformations=[splitter],
|
||||
)
|
||||
|
||||
|
||||
# # 运行评估
|
||||
# query_engine = vector_index.as_query_engine()
|
||||
# response_vector = query_engine.query("工程监理费的金额是多少?")
|
||||
# eval_result = evaluator_qwen.evaluate_response(response=response_vector)
|
||||
|
||||
# print(response_vector)
|
||||
# print(eval_result)
|
||||
|
||||
|
||||
question_generator = DatasetGenerator.from_documents(documents)
|
||||
eval_questions = question_generator.generate_questions_from_nodes(5)
|
||||
print(eval_questions)
|
||||
|
||||
import asyncio
|
||||
|
||||
async def evaluate_query_engine_async(query_engine, questions):
|
||||
c = [query_engine.aquery(q) for q in questions]
|
||||
gathering_future = asyncio.gather(*c)
|
||||
results = await gathering_future
|
||||
#print(results)
|
||||
|
||||
total_correct = 0
|
||||
for r in results:
|
||||
eval_result = (
|
||||
1 if faith_evaluator_qwen.evaluate_response(response=r).passing else 0
|
||||
)
|
||||
total_correct += eval_result
|
||||
|
||||
return total_correct, len(results)
|
||||
|
||||
def evaluate_query_engine(query_engine, questions):
|
||||
loop = asyncio.get_event_loop()
|
||||
correct, total = loop.run_until_complete(evaluate_query_engine_async(query_engine, questions))
|
||||
return correct, total
|
||||
|
||||
# 使用 evaluate_query_engine 函数
|
||||
vector_query_engine = vector_index.as_query_engine()
|
||||
correct, total = evaluate_query_engine(vector_query_engine, eval_questions[:5])
|
||||
|
||||
print(f"score: {correct}/{total}")
|
||||
@@ -0,0 +1,121 @@
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from llama_index.core.evaluation import CorrectnessEvaluator
|
||||
from app.engine import get_chat_engine
|
||||
from app.engine.index import get_index
|
||||
from app.observability import init_observability
|
||||
from app.settings import init_settings
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
index = get_index()
|
||||
|
||||
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
from llama_index.core.prompts import (
|
||||
ChatMessage,
|
||||
ChatPromptTemplate,
|
||||
MessageRole
|
||||
)
|
||||
|
||||
DEFAULT_SYSTEM_TEMPLATE = """
|
||||
您是一个问答聊天机器人的专业评估系统。
|
||||
|
||||
您将获得以下信息:
|
||||
|
||||
- 用户查询,
|
||||
- 生成的回答,
|
||||
|
||||
也可能提供一个参考答案作为评估的依据。
|
||||
|
||||
您的任务是判断生成回答的相关性和正确性。
|
||||
输出一个代表全面评估的单一分数。
|
||||
您必须在一行中仅返回该分数。
|
||||
不要以其他任何格式返回答案。
|
||||
在单独的一行提供给定分数的理由。
|
||||
|
||||
请遵循以下评分指南:
|
||||
|
||||
- 您的分数必须在1到5之间,其中1是最差,5是最好的。
|
||||
-如果生成的回答与用户查询不相关,您应该给出1分。
|
||||
-如果生成的回答相关但包含错误,您应该给出2到3分之间的分数。
|
||||
-如果生成的回答相关且完全正确,您应该给出4到5分之间的分数。
|
||||
示例响应:
|
||||
4.0
|
||||
生成的回答与参考答案的指标完全相同,但不够精炼。
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_USER_TEMPLATE = """
|
||||
## User Query
|
||||
{query}
|
||||
|
||||
## Reference Answer
|
||||
{reference_answer}
|
||||
|
||||
## Generated Answer
|
||||
{generated_answer}
|
||||
"""
|
||||
|
||||
DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
|
||||
message_templates=[
|
||||
ChatMessage(role=MessageRole.SYSTEM, content=DEFAULT_SYSTEM_TEMPLATE),
|
||||
ChatMessage(role=MessageRole.USER, content=DEFAULT_USER_TEMPLATE),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# 初始化聊天引擎和评估器
|
||||
chat_engine = get_chat_engine()
|
||||
corr_evaluator_qwen = CorrectnessEvaluator()
|
||||
|
||||
# 加载本地问题回答文件
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(script_dir, 'questions_and_answers.json')
|
||||
output_file_path = file_path.replace('.json', '_test.json')
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 异步函数用于评估查询
|
||||
async def evaluate_query(question, answer, index, output_file):
|
||||
response = await chat_engine.astream_chat(question)
|
||||
|
||||
# 检查sources是否为空
|
||||
if response.sources:
|
||||
content_str = str(response.sources[0])
|
||||
else:
|
||||
content_str = "<无回答>"
|
||||
|
||||
result = corr_evaluator_qwen.evaluate(
|
||||
query=question,
|
||||
response=content_str,
|
||||
reference=answer,
|
||||
)
|
||||
|
||||
result_dict = {
|
||||
"编号": index,
|
||||
"问题": question,
|
||||
"答案": answer,
|
||||
"回答": result.response,
|
||||
"得分(1~5)": result.score,
|
||||
"评价": result.feedback
|
||||
}
|
||||
|
||||
with open(output_file, 'a', encoding='utf-8') as f:
|
||||
f.write(json.dumps(result_dict, ensure_ascii=False, indent=4))
|
||||
f.write(',\n')
|
||||
|
||||
# 主异步函数
|
||||
async def main():
|
||||
for index, item in enumerate(data, start=1):
|
||||
await evaluate_query(item['question'], item['answer'], index, output_file_path)
|
||||
|
||||
# 运行主协程
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,55 @@
|
||||
Attribute_Prompt = (
|
||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||
"现在需要你针对数据中属性一列进行提问和回答。"
|
||||
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元),工程总投资(万元)是77469835.590045万元','尖峰及施工基面土石方量,尖峰及施工基面土石方量是8377.6','截止阀的编码,截止阀的编码是F01010203',"
|
||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||
)
|
||||
|
||||
|
||||
Amount_Prompt = (
|
||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||
"现在需要你针对上下文信息中的金额或者合价进行提问和回答。"
|
||||
"问题和回答的示例应该是这种类型的,示例:'项目建设技术服务费的金额,项目建设技术服务费的金额是16855957065.4302','项目后评价费的费率,项目后评价费的费率是0.5','架空输电线路本体工程的金额,架空输电线路本体工程的金额是55105688268.5176','工程静态投资的金额,工程静态投资的金额是715035853336.391'"
|
||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||
)
|
||||
|
||||
|
||||
|
||||
Units_Prompt = (
|
||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||
"现在需要你针对上下文信息来进行单位转化问题提问和回答。"
|
||||
"问题和回答的示例应该是这种类型的,示例:'工程总投资(万元)结果用元表示,工程总投资(万元)是774698355900.45元','本体工程(元)结果用万元表示,本体工程(元)是5490494.261046万元'"
|
||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||
)
|
||||
|
||||
Name_Prompt = (
|
||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||
"现在需要你针对上下文信息中的重名问题进行提问和回答。"
|
||||
"问题和回答的示例应该是这种类型的,示例:'专业类型为线路的杆塔工程项目划分的合价,专业类型为线路的杆塔工程项目划分的合价是220969744.905856','专业类型为线路清理的杆塔工程项目划分的合价,电缆工程的合价是0'"
|
||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||
)
|
||||
|
||||
|
||||
All_Amount_Prompt = (
|
||||
"你是一个电力造价工程相关的项目经理,现在给你一些上下文信息,"
|
||||
"你需要根据现有的上下文信息,来生成{num_questions_per_chunk}个电力造价工程相关的问题和对应的回答,"
|
||||
"现在需要你针对上下文信息中的总体金额进行提问和回答。"
|
||||
"问题和回答的示例应该是这种类型的,示例:'架空输电线路本体工程的总体金额,架空输电线路本体工程的总体金额是7.706703','工程静态投资的总体金额,工程静态投资的总体金额是100'"
|
||||
"你生成的回答必须严格按照示例中的格式('问题, 回答'),不允许有丝毫的变动。问题和回答应该在一个单引号内。"
|
||||
"这种类似的问题和答案,生成的问题和答案必须一一对应,要符合文件里的内容,不要生成一些无关的问题,不要生成一些重复的问题,"
|
||||
"不要生成一些过于简单的问题,不要生成一些过于复杂的问题。"
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
from app.observability import init_observability
|
||||
from app.settings import init_settings
|
||||
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.evaluation import DatasetGenerator
|
||||
|
||||
import prompts
|
||||
|
||||
init_settings()
|
||||
init_observability()
|
||||
|
||||
# 读取所有文档(即所有表格)
|
||||
documents = SimpleDirectoryReader("D:/LLM_model/text2sql/zjdataai-app-test/backend/data-test").load_data()
|
||||
|
||||
# 定义表格名称和索引的对应关系
|
||||
table_names = {
|
||||
"工程信息表": 0,
|
||||
"其他费用表": 1,
|
||||
"取费表": 2,
|
||||
"项目划分表": 3,
|
||||
"项目划分_费用预览表": 4,
|
||||
"总算表": 5,
|
||||
"工程量表": 6
|
||||
}
|
||||
|
||||
# 定义中文提示词和Python代码中提示词名称的映射
|
||||
prompt_mapping = {
|
||||
"普通属性": "Attribute_Prompt",
|
||||
"金额查询": "Amount_Prompt",
|
||||
"单位换算": "Units_Prompt",
|
||||
"重名项目划分": "Name_Prompt",
|
||||
"总体金额查询": "All_Amount_Prompt"
|
||||
}
|
||||
|
||||
# 定义表格与其对应的查询类别
|
||||
table_prompt_mapping = {
|
||||
"工程信息表": ["普通属性", "单位换算"],
|
||||
"其他费用表": ["金额查询", "单位换算"],
|
||||
"取费表": ["金额查询"],
|
||||
"总算表": ["金额查询", "总体金额查询"],
|
||||
"工程量表": ["普通属性", "重名项目划分"]
|
||||
}
|
||||
|
||||
# 根据表格名称选择特定的表格
|
||||
def select_document(documents, table_name):
|
||||
if table_name not in table_names:
|
||||
raise ValueError(f"未找到名为 '{table_name}' 的表格")
|
||||
index = table_names[table_name]
|
||||
return [documents[index]] # 返回一个包含所选表格的列表
|
||||
|
||||
# 选择提示词
|
||||
def select_prompt(prompt_category):
|
||||
prompt_name = prompt_mapping.get(prompt_category)
|
||||
if not prompt_name:
|
||||
raise ValueError(f"未找到名为 '{prompt_category}' 的提示词")
|
||||
try:
|
||||
return getattr(prompts, prompt_name)
|
||||
except AttributeError:
|
||||
raise ValueError(f"未找到提示词 '{prompt_name}' 对应的函数")
|
||||
|
||||
# 生成问题和答案
|
||||
def generate_questions_from_document(document, quest_prompt, num_questions):
|
||||
question_generator = DatasetGenerator.from_documents(
|
||||
documents=document,
|
||||
question_gen_query=quest_prompt,
|
||||
num_questions_per_chunk=num_questions
|
||||
)
|
||||
|
||||
eval_questions = question_generator.generate_questions_from_nodes(num_questions)
|
||||
print(eval_questions)
|
||||
|
||||
qa_pairs = []
|
||||
for qa in eval_questions:
|
||||
if ',' in qa:
|
||||
question, answer = qa.split(",", 1)
|
||||
qa_pairs.append({
|
||||
"question": question.strip(),
|
||||
"answer": answer.strip()
|
||||
})
|
||||
else:
|
||||
print(f"无法处理的问题和答案: {qa}")
|
||||
|
||||
return qa_pairs
|
||||
|
||||
# 主函数,控制生成多个表格的问题和使用多个提示词,并将结果合并到一个文件中
|
||||
def main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt):
|
||||
if table_names_input == "all":
|
||||
selected_tables = list(table_prompt_mapping.keys())
|
||||
else:
|
||||
selected_tables = table_names_input.strip('[]').split(',')
|
||||
|
||||
all_results = {}
|
||||
|
||||
for table_name in selected_tables:
|
||||
table_name = table_name.strip() # 去掉前后空格
|
||||
document = select_document(documents, table_name)
|
||||
|
||||
if prompt_categories_input == "all":
|
||||
selected_prompts = table_prompt_mapping[table_name]
|
||||
else:
|
||||
selected_prompts = prompt_categories_input.strip('[]').split(',')
|
||||
selected_prompts = [p.strip() for p in selected_prompts] # 去掉前后空格
|
||||
|
||||
for prompt_category in selected_prompts:
|
||||
if prompt_category not in table_prompt_mapping[table_name]:
|
||||
print(f"跳过表格 '{table_name}' 的提示词 '{prompt_category}',因为该表中不包含该类别的信息")
|
||||
continue
|
||||
|
||||
quest_prompt = select_prompt(prompt_category).format(num_questions_per_chunk=num_questions_per_prompt)
|
||||
qa_pairs = generate_questions_from_document(document, quest_prompt, num_questions_per_prompt)
|
||||
|
||||
label = f"test:{table_name}_{prompt_category}"
|
||||
all_results[label] = qa_pairs
|
||||
|
||||
# 自动生成输出文件名
|
||||
output_file = "combined_test.json"
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(all_results, f, ensure_ascii=False, indent=4)
|
||||
|
||||
print(f"All questions and answers have been saved to '{output_file}'")
|
||||
|
||||
# 获取命令行参数
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 4:
|
||||
print("Usage: python script.py <table_names_input> <prompt_categories_input> <num_questions_per_prompt>")
|
||||
else:
|
||||
table_names_input = sys.argv[1]
|
||||
prompt_categories_input = sys.argv[2]
|
||||
num_questions_per_prompt = int(sys.argv[3])
|
||||
|
||||
main(documents, table_names_input, prompt_categories_input, num_questions_per_prompt)
|
||||
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
import phoenix as px
|
||||
|
||||
|
||||
os.environ['PHOENIX_HOST'] = "0.0.0.0"
|
||||
|
||||
session = px.launch_app(use_temp_dir=False)
|
||||
|
||||
import msvcrt
|
||||
|
||||
Submodule
+1
Submodule webapp added at 77dbc14a64
Reference in New Issue
Block a user