commit 49233370382286cffec22b2996023a32523cc590 Author: paituo <330435863@qq.com> Date: Thu Aug 8 18:33:08 2024 +0800 Initial commit from Create Llama diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..888481b --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,47 @@ +{ + "image": "mcr.microsoft.com/vscode/devcontainers/typescript-node:dev-20-bullseye", + "features": { + "ghcr.io/devcontainers-contrib/features/turborepo-npm:1": {}, + "ghcr.io/devcontainers-contrib/features/typescript:2": {}, + "ghcr.io/devcontainers/features/python:1": { + "version": "3.11", + "toolsToInstall": [ + "flake8", + "black", + "mypy", + "poetry" + ] + } + }, + "customizations": { + "codespaces": { + "openFiles": [ + "README.md" + ] + }, + "vscode": { + "extensions": [ + "ms-vscode.typescript-language-features", + "esbenp.prettier-vscode", + "ms-python.python", + "ms-python.black-formatter", + "ms-python.vscode-flake8", + "ms-python.vscode-pylance" + ], + "settings": { + "python.formatting.provider": "black", + "python.languageServer": "Pylance", + "python.analysis.typeCheckingMode": "basic" + } + } + }, + "containerEnv": { + "POETRY_VIRTUALENVS_CREATE": "false", + "PYTHONPATH": "${PYTHONPATH}:${workspaceFolder}/backend" + }, + "forwardPorts": [ + 3000, + 8000 + ], + "postCreateCommand": "cd backend && poetry install && cd ../frontend && npm install" +} \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..5a41b8c --- /dev/null +++ b/README.md @@ -0,0 +1,18 @@ +This is a [LlamaIndex](https://www.llamaindex.ai/) project bootstrapped with [`create-llama`](https://github.com/run-llama/LlamaIndexTS/tree/main/packages/create-llama). + +## Getting Started + +First, startup the backend as described in the [backend README](./backend/README.md). + +Second, run the development server of the frontend as described in the [frontend README](./frontend/README.md). 
+ +Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. + +## Learn More + +To learn more about LlamaIndex, take a look at the following resources: + +- [LlamaIndex Documentation](https://docs.llamaindex.ai) - learn about LlamaIndex (Python features). +- [LlamaIndexTS Documentation](https://ts.llamaindex.ai) - learn about LlamaIndex (Typescript features). + +You can check out [the LlamaIndexTS GitHub repository](https://github.com/run-llama/LlamaIndexTS) - your feedback and contributions are welcome! diff --git a/backend/.gitignore b/backend/.gitignore new file mode 100644 index 0000000..ae22d34 --- /dev/null +++ b/backend/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +storage +.env +output diff --git a/backend/Dockerfile b/backend/Dockerfile new file mode 100644 index 0000000..624364b --- /dev/null +++ b/backend/Dockerfile @@ -0,0 +1,26 @@ +FROM python:3.11 as build + +WORKDIR /app + +ENV PYTHONPATH=/app + +# Install Poetry +RUN curl -sSL https://install.python-poetry.org | POETRY_HOME=/opt/poetry python && \ + cd /usr/local/bin && \ + ln -s /opt/poetry/bin/poetry && \ + poetry config virtualenvs.create false + +# Install Chromium for web loader +# Can disable this if you don't use the web loader to reduce the image size +RUN apt update && apt install -y chromium chromium-driver + +# Install dependencies +COPY ./pyproject.toml ./poetry.lock* /app/ +RUN poetry install --no-root --no-cache --only main + +# ==================================== +FROM build as release + +COPY . . + +CMD ["python", "main.py"] \ No newline at end of file diff --git a/backend/README.md b/backend/README.md new file mode 100644 index 0000000..7969ff0 --- /dev/null +++ b/backend/README.md @@ -0,0 +1,101 @@ +This is a [LlamaIndex](https://www.llamaindex.ai/) project using [FastAPI](https://fastapi.tiangolo.com/) bootstrapped with [`create-llama`](https://github.com/run-llama/LlamaIndexTS/tree/main/packages/create-llama). 
+ +## Getting Started + +First, set up the environment with poetry: + +> **_Note:_** This step is not needed if you are using the dev-container. + +``` +poetry install +poetry shell +``` + +Then check the parameters that have been pre-configured in the `.env` file in this directory. (E.g. you might need to configure an `OPENAI_API_KEY` if you're using OpenAI as model provider). + +If you are using any tools or data sources, you can update their config files in the `config` folder. + +Second, generate the embeddings of the documents in the `./data` directory (if this folder exists - otherwise, skip this step): + +``` +poetry run generate +``` + +Third, run the development server: + +``` +python main.py +``` + +The example provides two different API endpoints: + +1. `/api/chat` - a streaming chat endpoint +2. `/api/chat/request` - a non-streaming chat endpoint + +You can test the streaming endpoint with the following curl request: + +``` +curl --location 'localhost:8000/api/chat' \ +--header 'Content-Type: application/json' \ +--data '{ "messages": [{ "role": "user", "content": "Hello" }] }' +``` + +And for the non-streaming endpoint run: + +``` +curl --location 'localhost:8000/api/chat/request' \ +--header 'Content-Type: application/json' \ +--data '{ "messages": [{ "role": "user", "content": "Hello" }] }' +``` + +You can start editing the API endpoints by modifying `app/api/routers/chat.py`. The endpoints auto-update as you save the file. You can delete the endpoint you're not using. + +Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API. + +The API allows CORS for all origins to simplify development. You can change this behavior by setting the `ENVIRONMENT` environment variable to `prod`: + +``` +ENVIRONMENT=prod python main.py +``` + +## Using Docker + +1. Build an image for the FastAPI app: + +``` +docker build -t <your_backend_image_name> . +``` + +2. 
Generate embeddings: + +Parse the data and generate the vector embeddings if the `./data` folder exists - otherwise, skip this step: + +``` +docker run \ + --rm \ + -v $(pwd)/.env:/app/.env \ # Use ENV variables and configuration from your file-system + -v $(pwd)/config:/app/config \ + -v $(pwd)/data:/app/data \ # Use your local folder to read the data + -v $(pwd)/storage:/app/storage \ # Use your file system to store the vector database + <your_backend_image_name> \ + poetry run generate +``` + +3. Start the API: + +``` +docker run \ + -v $(pwd)/.env:/app/.env \ # Use ENV variables and configuration from your file-system + -v $(pwd)/config:/app/config \ + -v $(pwd)/storage:/app/storage \ # Use your file system to store the vector database + -p 8000:8000 \ + <your_backend_image_name> +``` + +## Learn More + +To learn more about LlamaIndex, take a look at the following resources: + +- [LlamaIndex Documentation](https://docs.llamaindex.ai) - learn about LlamaIndex. + +You can check out [the LlamaIndex GitHub repository](https://github.com/run-llama/llama_index) - your feedback and contributions are welcome! 
diff --git a/backend/app/__init__.py b/backend/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/api/routers/__init__.py b/backend/app/api/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/api/routers/chat.py b/backend/app/api/routers/chat.py new file mode 100644 index 0000000..cb7036d --- /dev/null +++ b/backend/app/api/routers/chat.py @@ -0,0 +1,148 @@ +import logging +import os +from typing import List + +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status +from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore +from llama_index.core.llms import MessageRole +from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters + +from app.api.routers.events import EventCallbackHandler +from app.api.routers.models import ( + ChatConfig, + ChatData, + Message, + Result, + SourceNodes, +) +from app.api.routers.vercel_response import VercelStreamResponse +from app.api.services.llama_cloud import LLamaCloudFileService +from app.engine import get_chat_engine + +chat_router = r = APIRouter() + +logger = logging.getLogger("uvicorn") + + +def process_response_nodes( + nodes: List[NodeWithScore], + background_tasks: BackgroundTasks, +): + """ + Start background tasks on the source nodes if needed. 
+ """ + files_to_download = SourceNodes.get_download_files(nodes) + for file in files_to_download: + background_tasks.add_task( + LLamaCloudFileService.download_llamacloud_pipeline_file, file + ) + + +# streaming endpoint - delete if not needed +@r.post("") +async def chat( + request: Request, + data: ChatData, + background_tasks: BackgroundTasks, + chat_engine: BaseChatEngine = Depends(get_chat_engine), +): + try: + last_message_content = data.get_last_message_content() + messages = data.get_history_messages() + + doc_ids = data.get_chat_document_ids() + filters = generate_filters(doc_ids) + params = data.data or {} + logger.info("Creating chat engine with filters", filters.dict()) + chat_engine = get_chat_engine(filters=filters, params=params) + + event_handler = EventCallbackHandler() + chat_engine.callback_manager.handlers.append(event_handler) # type: ignore + + response = await chat_engine.astream_chat(last_message_content, messages) + process_response_nodes(response.source_nodes, background_tasks) + + return VercelStreamResponse(request, event_handler, response, data) + except Exception as e: + logger.exception("Error in chat engine", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error in chat engine: {e}", + ) from e + + +def generate_filters(doc_ids): + if len(doc_ids) > 0: + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="private", + value=["true"], + operator="nin", # type: ignore + ), + MetadataFilter( + key="doc_id", + value=doc_ids, + operator="in", # type: ignore + ), + ], + condition="or", # type: ignore + ) + else: + filters = MetadataFilters( + # Use the "NIN" - "not in" operator to include all public documents (don't have the private key set) + filters=[ + MetadataFilter( + key="private", + value=["true"], + operator="nin", # type: ignore + ), + ] + ) + + return filters + + +# non-streaming endpoint - delete if not needed +@r.post("/request") +async def chat_request( + data: 
ChatData, + chat_engine: BaseChatEngine = Depends(get_chat_engine), +) -> Result: + last_message_content = data.get_last_message_content() + messages = data.get_history_messages() + + response = await chat_engine.achat(last_message_content, messages) + return Result( + result=Message(role=MessageRole.ASSISTANT, content=response.response), + nodes=SourceNodes.from_source_nodes(response.source_nodes), + ) + + +@r.get("/config") +async def chat_config() -> ChatConfig: + starter_questions = None + conversation_starters = os.getenv("CONVERSATION_STARTERS") + if conversation_starters and conversation_starters.strip(): + starter_questions = conversation_starters.strip().split("\n") + return ChatConfig(starter_questions=starter_questions) + + +@r.get("/config/llamacloud") +async def chat_llama_cloud_config(): + projects = LLamaCloudFileService.get_all_projects_with_pipelines() + pipeline = os.getenv("LLAMA_CLOUD_INDEX_NAME") + project = os.getenv("LLAMA_CLOUD_PROJECT_NAME") + pipeline_config = ( + pipeline + and project + and { + "pipeline": pipeline, + "project": project, + } + or None + ) + return { + "projects": projects, + "pipeline": pipeline_config, + } diff --git a/backend/app/api/routers/events.py b/backend/app/api/routers/events.py new file mode 100644 index 0000000..94cc585 --- /dev/null +++ b/backend/app/api/routers/events.py @@ -0,0 +1,149 @@ +import json +import asyncio +import logging +from typing import AsyncGenerator, Dict, Any, List, Optional +from llama_index.core.callbacks.base import BaseCallbackHandler +from llama_index.core.callbacks.schema import CBEventType +from llama_index.core.tools.types import ToolOutput +from pydantic import BaseModel + + +logger = logging.getLogger(__name__) + + +class CallbackEvent(BaseModel): + event_type: CBEventType + payload: Optional[Dict[str, Any]] = None + event_id: str = "" + + def get_retrieval_message(self) -> dict | None: + if self.payload: + nodes = self.payload.get("nodes") + if nodes: + msg = f"Retrieved 
{len(nodes)} sources to use as context for the query" + else: + msg = f"Retrieving context for query: '{self.payload.get('query_str')}'" + return { + "type": "events", + "data": {"title": msg}, + } + else: + return None + + def get_tool_message(self) -> dict | None: + func_call_args = self.payload.get("function_call") + if func_call_args is not None and "tool" in self.payload: + tool = self.payload.get("tool") + return { + "type": "events", + "data": { + "title": f"Calling tool: {tool.name} with inputs: {func_call_args}", + }, + } + + def _is_output_serializable(self, output: Any) -> bool: + try: + json.dumps(output) + return True + except TypeError: + return False + + def get_agent_tool_response(self) -> dict | None: + response = self.payload.get("response") + if response is not None: + sources = response.sources + for source in sources: + # Return the tool response here to include the toolCall information + if isinstance(source, ToolOutput): + if self._is_output_serializable(source.raw_output): + output = source.raw_output + else: + output = source.content + + return { + "type": "tools", + "data": { + "toolOutput": { + "output": output, + "isError": source.is_error, + }, + "toolCall": { + "id": None, # There is no tool id in the ToolOutput + "name": source.tool_name, + "input": source.raw_input, + }, + }, + } + + def to_response(self): + try: + match self.event_type: + case "retrieve": + return self.get_retrieval_message() + case "function_call": + return self.get_tool_message() + case "agent_step": + return self.get_agent_tool_response() + case _: + return None + except Exception as e: + logger.error(f"Error in converting event to response: {e}") + return None + + +class EventCallbackHandler(BaseCallbackHandler): + _aqueue: asyncio.Queue + is_done: bool = False + + def __init__( + self, + ): + """Initialize the base callback handler.""" + ignored_events = [ + CBEventType.CHUNKING, + CBEventType.NODE_PARSING, + CBEventType.EMBEDDING, + CBEventType.LLM, + 
CBEventType.TEMPLATING, + ] + super().__init__(ignored_events, ignored_events) + self._aqueue = asyncio.Queue() + + def on_event_start( + self, + event_type: CBEventType, + payload: Optional[Dict[str, Any]] = None, + event_id: str = "", + **kwargs: Any, + ) -> str: + event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload) + if event.to_response() is not None: + self._aqueue.put_nowait(event) + + def on_event_end( + self, + event_type: CBEventType, + payload: Optional[Dict[str, Any]] = None, + event_id: str = "", + **kwargs: Any, + ) -> None: + event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload) + if event.to_response() is not None: + self._aqueue.put_nowait(event) + + def start_trace(self, trace_id: Optional[str] = None) -> None: + """No-op.""" + + def end_trace( + self, + trace_id: Optional[str] = None, + trace_map: Optional[Dict[str, List[str]]] = None, + ) -> None: + """No-op.""" + + async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]: + while not self._aqueue.empty() or not self.is_done: + try: + yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1) + except asyncio.TimeoutError: + pass diff --git a/backend/app/api/routers/models.py b/backend/app/api/routers/models.py new file mode 100644 index 0000000..c9ea1ad --- /dev/null +++ b/backend/app/api/routers/models.py @@ -0,0 +1,252 @@ +import logging +import os +from typing import Any, Dict, List, Literal, Optional, Set + +from llama_index.core.llms import ChatMessage, MessageRole +from llama_index.core.schema import NodeWithScore +from pydantic import BaseModel, Field, validator +from pydantic.alias_generators import to_camel + +logger = logging.getLogger("uvicorn") + + +class FileContent(BaseModel): + type: Literal["text", "ref"] + # If the file is pure text then the value is be a string + # otherwise, it's a list of document IDs + value: str | List[str] + + +class File(BaseModel): + id: str + content: FileContent + filename: str + 
filesize: int + filetype: str + + +class AnnotationFileData(BaseModel): + files: List[File] = Field( + default=[], + description="List of files", + ) + + class Config: + json_schema_extra = { + "example": { + "csvFiles": [ + { + "content": "Name, Age\nAlice, 25\nBob, 30", + "filename": "example.csv", + "filesize": 123, + "id": "123", + "type": "text/csv", + } + ] + } + } + alias_generator = to_camel + + +class Annotation(BaseModel): + type: str + data: AnnotationFileData | List[str] + + def to_content(self) -> str | None: + if self.type == "document_file": + # We only support generating context content for CSV files for now + csv_files = [file for file in self.data.files if file.filetype == "csv"] + if len(csv_files) > 0: + return "Use data from following CSV raw content\n" + "\n".join( + [f"```csv\n{csv_file.content.value}\n```" for csv_file in csv_files] + ) + else: + logger.warning( + f"The annotation {self.type} is not supported for generating context content" + ) + return None + + +class Message(BaseModel): + role: MessageRole + content: str + annotations: List[Annotation] | None = None + + +class ChatData(BaseModel): + messages: List[Message] + data: Any = None + + class Config: + json_schema_extra = { + "example": { + "messages": [ + { + "role": "user", + "content": "What standards for letters exist?", + } + ] + } + } + + @validator("messages") + def messages_must_not_be_empty(cls, v): + if len(v) == 0: + raise ValueError("Messages must not be empty") + return v + + def get_last_message_content(self) -> str: + """ + Get the content of the last message along with the data content if available. 
+ Fallback to use data content from previous messages + """ + if len(self.messages) == 0: + raise ValueError("There is not any message in the chat") + last_message = self.messages[-1] + message_content = last_message.content + for message in reversed(self.messages): + if message.role == MessageRole.USER and message.annotations is not None: + annotation_contents = filter( + None, + [annotation.to_content() for annotation in message.annotations], + ) + if not annotation_contents: + continue + annotation_text = "\n".join(annotation_contents) + message_content = f"{message_content}\n{annotation_text}" + break + return message_content + + def get_history_messages(self) -> List[ChatMessage]: + """ + Get the history messages + """ + return [ + ChatMessage(role=message.role, content=message.content) + for message in self.messages[:-1] + ] + + def is_last_message_from_user(self) -> bool: + return self.messages[-1].role == MessageRole.USER + + def get_chat_document_ids(self) -> List[str]: + """ + Get the document IDs from the chat messages + """ + document_ids: List[str] = [] + for message in self.messages: + if message.role == MessageRole.USER and message.annotations is not None: + for annotation in message.annotations: + if ( + annotation.type == "document_file" + and annotation.data.files is not None + ): + for fi in annotation.data.files: + if fi.content.type == "ref": + document_ids += fi.content.value + return list(set(document_ids)) + + +class LlamaCloudFile(BaseModel): + file_name: str + pipeline_id: str + + def __eq__(self, other): + if not isinstance(other, LlamaCloudFile): + return NotImplemented + return ( + self.file_name == other.file_name and self.pipeline_id == other.pipeline_id + ) + + def __hash__(self): + return hash((self.file_name, self.pipeline_id)) + + +class SourceNodes(BaseModel): + id: str + metadata: Dict[str, Any] + score: Optional[float] + text: str + url: Optional[str] + + @classmethod + def from_source_node(cls, source_node: NodeWithScore): + 
metadata = source_node.node.metadata + url = cls.get_url_from_metadata(metadata) + + return cls( + id=source_node.node.node_id, + metadata=metadata, + score=source_node.score, + text=source_node.node.text, # type: ignore + url=url, + ) + + @classmethod + def get_url_from_metadata(cls, metadata: Dict[str, Any]) -> str: + url_prefix = os.getenv("FILESERVER_URL_PREFIX") + if not url_prefix: + logger.warning( + "Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server" + ) + file_name = metadata.get("file_name") + if file_name and url_prefix: + # file_name exists and file server is configured + pipeline_id = metadata.get("pipeline_id") + if pipeline_id and metadata.get("private") is None: + # file is from LlamaCloud and was not ingested locally + file_name = f"{pipeline_id}${file_name}" + return f"{url_prefix}/output/llamacloud/{file_name}" + is_private = metadata.get("private", "false") == "true" + if is_private: + return f"{url_prefix}/output/uploaded/{file_name}" + return f"{url_prefix}/data/{file_name}" + else: + # fallback to URL in metadata (e.g. 
for websites) + return metadata.get("URL") + + @classmethod + def from_source_nodes(cls, source_nodes: List[NodeWithScore]): + return [cls.from_source_node(node) for node in source_nodes] + + @staticmethod + def get_download_files(nodes: List[NodeWithScore]) -> Set[LlamaCloudFile]: + source_nodes = SourceNodes.from_source_nodes(nodes) + llama_cloud_files = [ + LlamaCloudFile( + file_name=node.metadata.get("file_name"), + pipeline_id=node.metadata.get("pipeline_id"), + ) + for node in source_nodes + if ( + node.metadata.get("private") + is None # Only download files are from LlamaCloud and were not ingested locally + and node.metadata.get("pipeline_id") is not None + and node.metadata.get("file_name") is not None + ) + ] + # Remove duplicates and return + return set(llama_cloud_files) + + +class Result(BaseModel): + result: Message + nodes: List[SourceNodes] + + +class ChatConfig(BaseModel): + starter_questions: Optional[List[str]] = Field( + default=None, + description="List of starter questions", + serialization_alias="starterQuestions", + ) + + class Config: + json_schema_extra = { + "example": { + "starterQuestions": [ + "What standards for letters exist?", + "What are the requirements for a letter to be considered a letter?", + ] + } + } diff --git a/backend/app/api/routers/upload.py b/backend/app/api/routers/upload.py new file mode 100644 index 0000000..94f3ce7 --- /dev/null +++ b/backend/app/api/routers/upload.py @@ -0,0 +1,25 @@ +import logging +from typing import List + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from app.api.services.file import PrivateFileService + +file_upload_router = r = APIRouter() + +logger = logging.getLogger("uvicorn") + + +class FileUploadRequest(BaseModel): + base64: str + + +@r.post("") +def upload_file(request: FileUploadRequest) -> List[str]: + try: + logger.info("Processing file") + return PrivateFileService.process_file(request.base64) + except Exception as e: + logger.error(f"Error 
processing file: {e}", exc_info=True) + raise HTTPException(status_code=500, detail="Error processing file") diff --git a/backend/app/api/routers/vercel_response.py b/backend/app/api/routers/vercel_response.py new file mode 100644 index 0000000..0222a14 --- /dev/null +++ b/backend/app/api/routers/vercel_response.py @@ -0,0 +1,109 @@ +import json + +from aiostream import stream +from fastapi import Request +from fastapi.responses import StreamingResponse +from llama_index.core.chat_engine.types import StreamingAgentChatResponse + +from app.api.routers.events import EventCallbackHandler +from app.api.routers.models import ChatData, Message, SourceNodes +from app.api.services.suggestion import NextQuestionSuggestion + + +class VercelStreamResponse(StreamingResponse): + """ + Class to convert the response from the chat engine to the streaming format expected by Vercel + """ + + TEXT_PREFIX = "0:" + DATA_PREFIX = "8:" + + @classmethod + def convert_text(cls, token: str): + # Escape newlines and double quotes to avoid breaking the stream + token = json.dumps(token) + return f"{cls.TEXT_PREFIX}{token}\n" + + @classmethod + def convert_data(cls, data: dict): + data_str = json.dumps(data) + return f"{cls.DATA_PREFIX}[{data_str}]\n" + + def __init__( + self, + request: Request, + event_handler: EventCallbackHandler, + response: StreamingAgentChatResponse, + chat_data: ChatData, + ): + content = VercelStreamResponse.content_generator( + request, event_handler, response, chat_data + ) + super().__init__(content=content) + + @classmethod + async def content_generator( + cls, + request: Request, + event_handler: EventCallbackHandler, + response: StreamingAgentChatResponse, + chat_data: ChatData, + ): + # Yield the text response + async def _chat_response_generator(): + final_response = "" + async for token in response.async_response_gen(): + final_response += token + yield VercelStreamResponse.convert_text(token) + + # Generate questions that user might interested to + 
conversation = chat_data.messages + [ + Message(role="assistant", content=final_response) + ] + questions = await NextQuestionSuggestion.suggest_next_questions( + conversation + ) + if len(questions) > 0: + yield VercelStreamResponse.convert_data( + { + "type": "suggested_questions", + "data": questions, + } + ) + + # the text_generator is the leading stream, once it's finished, also finish the event stream + event_handler.is_done = True + + # Yield the source nodes + yield cls.convert_data( + { + "type": "sources", + "data": { + "nodes": [ + SourceNodes.from_source_node(node).dict() + for node in response.source_nodes + ] + }, + } + ) + + # Yield the events from the event handler + async def _event_generator(): + async for event in event_handler.async_event_gen(): + event_response = event.to_response() + if event_response is not None: + yield VercelStreamResponse.convert_data(event_response) + + combine = stream.merge(_chat_response_generator(), _event_generator()) + is_stream_started = False + async with combine.stream() as streamer: + async for output in streamer: + if not is_stream_started: + is_stream_started = True + # Stream a blank message to start the stream + yield VercelStreamResponse.convert_text("") + + yield output + + if await request.is_disconnected(): + break diff --git a/backend/app/api/services/file.py b/backend/app/api/services/file.py new file mode 100644 index 0000000..a478570 --- /dev/null +++ b/backend/app/api/services/file.py @@ -0,0 +1,113 @@ +import base64 +import mimetypes +import os +from pathlib import Path +from typing import Dict, List +from uuid import uuid4 + +from app.engine.index import get_index +from llama_index.core import VectorStoreIndex +from llama_index.core.ingestion import IngestionPipeline +from llama_index.core.readers.file.base import ( + _try_loading_included_file_formats as get_file_loaders_map, +) +from llama_index.core.readers.file.base import ( + default_file_metadata_func, +) +from llama_index.core.schema import 
Document +from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex +from llama_index.readers.file import FlatReader + + +def get_llamaparse_parser(): + from app.engine.loaders import load_configs + from app.engine.loaders.file import FileLoaderConfig, llama_parse_parser + + config = load_configs() + file_loader_config = FileLoaderConfig(**config["file"]) + if file_loader_config.use_llama_parse: + return llama_parse_parser() + else: + return None + + +def default_file_loaders_map(): + default_loaders = get_file_loaders_map() + default_loaders[".txt"] = FlatReader + return default_loaders + + +class PrivateFileService: + PRIVATE_STORE_PATH = "output/uploaded" + + @staticmethod + def preprocess_base64_file(base64_content: str) -> tuple: + header, data = base64_content.split(",", 1) + mime_type = header.split(";")[0].split(":", 1)[1] + extension = mimetypes.guess_extension(mime_type) + # File data as bytes + return base64.b64decode(data), extension + + @staticmethod + def store_and_parse_file(file_data, extension) -> List[Document]: + # Store file to the private directory + os.makedirs(PrivateFileService.PRIVATE_STORE_PATH, exist_ok=True) + + # random file name + file_name = f"{uuid4().hex}{extension}" + file_path = Path(os.path.join(PrivateFileService.PRIVATE_STORE_PATH, file_name)) + + # write file + with open(file_path, "wb") as f: + f.write(file_data) + + # Load file to documents + # If LlamaParse is enabled, use it to parse the file + # Otherwise, use the default file loaders + reader = get_llamaparse_parser() + if reader is None: + reader_cls = default_file_loaders_map().get(extension) + if reader_cls is None: + raise ValueError(f"File extension {extension} is not supported") + reader = reader_cls() + documents = reader.load_data(file_path) + # Add custom metadata + for doc in documents: + doc.metadata["file_name"] = file_name + doc.metadata["private"] = "true" + return documents + + @staticmethod + def process_file(base64_content: str) -> 
List[str]: + file_data, extension = PrivateFileService.preprocess_base64_file(base64_content) + documents = PrivateFileService.store_and_parse_file(file_data, extension) + + # Only process nodes, no store the index + pipeline = IngestionPipeline() + nodes = pipeline.run(documents=documents) + + # Add the nodes to the index and persist it + current_index = get_index() + + # Insert the documents into the index + if isinstance(current_index, LlamaCloudIndex): + # LlamaCloudIndex is a managed index so we don't need to process the nodes + # just insert the documents + for doc in documents: + current_index.insert(doc) + else: + # Only process nodes, no store the index + pipeline = IngestionPipeline() + nodes = pipeline.run(documents=documents) + + # Add the nodes to the index and persist it + if current_index is None: + current_index = VectorStoreIndex(nodes=nodes) + else: + current_index.insert_nodes(nodes=nodes) + current_index.storage_context.persist( + persist_dir=os.environ.get("STORAGE_DIR", "storage") + ) + + # Return the document ids + return [doc.doc_id for doc in documents] diff --git a/backend/app/api/services/llama_cloud.py b/backend/app/api/services/llama_cloud.py new file mode 100644 index 0000000..852ae7c --- /dev/null +++ b/backend/app/api/services/llama_cloud.py @@ -0,0 +1,114 @@ +import logging +import os +from typing import Any, Dict, List, Optional + +import requests +from app.api.routers.models import LlamaCloudFile + +logger = logging.getLogger("uvicorn") + + +class LLamaCloudFileService: + LLAMA_CLOUD_URL = "https://cloud.llamaindex.ai/api/v1" + LOCAL_STORE_PATH = "output/llamacloud" + + DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}" + + @classmethod + def get_all_projects(cls) -> List[Dict[str, Any]]: + url = f"{cls.LLAMA_CLOUD_URL}/projects" + return cls._make_request(url) + + @classmethod + def get_all_pipelines(cls) -> List[Dict[str, Any]]: + url = f"{cls.LLAMA_CLOUD_URL}/pipelines" + return cls._make_request(url) + + @classmethod + def 
get_all_projects_with_pipelines(cls) -> List[Dict[str, Any]]: + try: + projects = cls.get_all_projects() + pipelines = cls.get_all_pipelines() + return [ + { + **project, + "pipelines": [p for p in pipelines if p["project_id"] == project["id"]], + } + for project in projects + ] + except Exception as error: + logger.error(f"Error listing projects and pipelines: {error}") + return [] + + @classmethod + def _get_files(cls, pipeline_id: str) -> List[Dict[str, Any]]: + url = f"{cls.LLAMA_CLOUD_URL}/pipelines/{pipeline_id}/files" + return cls._make_request(url) + + @classmethod + def _get_file_detail(cls, project_id: str, file_id: str) -> Dict[str, Any]: + url = f"{cls.LLAMA_CLOUD_URL}/files/{file_id}/content?project_id={project_id}" + return cls._make_request(url) + + @classmethod + def _download_file(cls, url: str, local_file_path: str): + logger.info(f"Downloading file to {local_file_path}") + # Create directory if it doesn't exist + os.makedirs(cls.LOCAL_STORE_PATH, exist_ok=True) + # Download the file + with requests.get(url, stream=True) as r: + r.raise_for_status() + with open(local_file_path, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + logger.info("File downloaded successfully") + + @classmethod + def download_llamacloud_pipeline_file( + cls, + file: LlamaCloudFile, + force_download: bool = False, + ): + file_name = file.file_name + pipeline_id = file.pipeline_id + + # Check is the file already exists + downloaded_file_path = cls.get_file_path(file_name, pipeline_id) + if os.path.exists(downloaded_file_path) and not force_download: + logger.debug(f"File {file_name} already exists in local storage") + return + try: + logger.info(f"Downloading file {file_name} for pipeline {pipeline_id}") + files = cls._get_files(pipeline_id) + if not files or not isinstance(files, list): + raise Exception("No files found in LlamaCloud") + for file_entry in files: + if file_entry["name"] == file_name: + file_id = file_entry["file_id"] + 
project_id = file_entry["project_id"] + file_detail = cls._get_file_detail(project_id, file_id) + cls._download_file(file_detail["url"], downloaded_file_path) + break + except Exception as error: + logger.info(f"Error fetching file from LlamaCloud: {error}") + + @classmethod + def get_file_name(cls, name: str, pipeline_id: str) -> str: + return cls.DOWNLOAD_FILE_NAME_TPL.format(pipeline_id=pipeline_id, filename=name) + + @classmethod + def get_file_path(cls, name: str, pipeline_id: str) -> str: + return os.path.join(cls.LOCAL_STORE_PATH, cls.get_file_name(name, pipeline_id)) + + @staticmethod + def _make_request( + url: str, data=None, headers: Optional[Dict] = None, method: str = "get" + ): + if headers is None: + headers = { + "Accept": "application/json", + "Authorization": f'Bearer {os.getenv("LLAMA_CLOUD_API_KEY")}', + } + response = requests.request(method, url, headers=headers, data=data) + response.raise_for_status() + return response.json() diff --git a/backend/app/api/services/suggestion.py b/backend/app/api/services/suggestion.py new file mode 100644 index 0000000..406b0ae --- /dev/null +++ b/backend/app/api/services/suggestion.py @@ -0,0 +1,48 @@ +from typing import List + +from app.api.routers.models import Message +from llama_index.core.prompts import PromptTemplate +from llama_index.core.settings import Settings +from pydantic import BaseModel + +NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate( + "You're a helpful assistant! Your task is to suggest the next question that user might ask. " + "\nHere is the conversation history" + "\n---------------------\n{conversation}\n---------------------" + "Given the conversation history, please give me $number_of_questions questions that you might ask next!" 
+) +N_QUESTION_TO_GENERATE = 3 + + +class NextQuestions(BaseModel): + """A list of questions that user might ask next""" + + questions: List[str] + + +class NextQuestionSuggestion: + @staticmethod + async def suggest_next_questions( + messages: List[Message], + number_of_questions: int = N_QUESTION_TO_GENERATE, + ) -> List[str]: + # Reduce the cost by only using the last two messages + last_user_message = None + last_assistant_message = None + for message in reversed(messages): + if message.role == "user": + last_user_message = f"User: {message.content}" + elif message.role == "assistant": + last_assistant_message = f"Assistant: {message.content}" + if last_user_message and last_assistant_message: + break + conversation: str = f"{last_user_message}\n{last_assistant_message}" + + output: NextQuestions = await Settings.llm.astructured_predict( + NextQuestions, + prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT, + conversation=conversation, + nun_questions=number_of_questions, + ) + + return output.questions diff --git a/backend/app/engine/__init__.py b/backend/app/engine/__init__.py new file mode 100644 index 0000000..fb8d410 --- /dev/null +++ b/backend/app/engine/__init__.py @@ -0,0 +1,31 @@ +import os +from llama_index.core.settings import Settings +from llama_index.core.agent import AgentRunner +from llama_index.core.tools.query_engine import QueryEngineTool +from app.engine.tools import ToolFactory +from app.engine.index import get_index + + +def get_chat_engine(filters=None, params=None): + system_prompt = os.getenv("SYSTEM_PROMPT") + top_k = os.getenv("TOP_K", "3") + tools = [] + + # Add query tool if index exists + index = get_index() + if index is not None: + query_engine = index.as_query_engine( + similarity_top_k=int(top_k), filters=filters + ) + query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine) + tools.append(query_engine_tool) + + # Add additional tools + tools += ToolFactory.from_env() + + return AgentRunner.from_llm( + 
llm=Settings.llm, + tools=tools, + system_prompt=system_prompt, + verbose=True, + ) diff --git a/backend/app/engine/generate.py b/backend/app/engine/generate.py new file mode 100644 index 0000000..8bcf606 --- /dev/null +++ b/backend/app/engine/generate.py @@ -0,0 +1,51 @@ +from dotenv import load_dotenv + +load_dotenv() + +import os +import logging +from app.settings import init_settings +from app.engine.loaders import get_documents +from llama_index.indices.managed.llama_cloud import LlamaCloudIndex + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger() + + +def generate_datasource(): + init_settings() + logger.info("Generate index for the provided data") + + name = os.getenv("LLAMA_CLOUD_INDEX_NAME") + project_name = os.getenv("LLAMA_CLOUD_PROJECT_NAME") + api_key = os.getenv("LLAMA_CLOUD_API_KEY") + base_url = os.getenv("LLAMA_CLOUD_BASE_URL") + organization_id = os.getenv("LLAMA_CLOUD_ORGANIZATION_ID") + + if name is None or project_name is None or api_key is None: + raise ValueError( + "Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY" + " to your environment variables or config them in .env file" + ) + + documents = get_documents() + + # Set private=false to mark the document as public (required for filtering) + for doc in documents: + doc.metadata["private"] = "false" + + LlamaCloudIndex.from_documents( + documents=documents, + name=name, + project_name=project_name, + api_key=api_key, + base_url=base_url, + organization_id=organization_id + ) + + logger.info("Finished generating the index") + + +if __name__ == "__main__": + generate_datasource() diff --git a/backend/app/engine/index.py b/backend/app/engine/index.py new file mode 100644 index 0000000..e54e8ca --- /dev/null +++ b/backend/app/engine/index.py @@ -0,0 +1,31 @@ +import logging +import os +from llama_index.indices.managed.llama_cloud import LlamaCloudIndex + + +logger = logging.getLogger("uvicorn") + +def get_index(params=None): + configParams 
= params or {} + pipelineConfig = configParams.get("llamaCloudPipeline", {}) + name = pipelineConfig.get("pipeline", os.getenv("LLAMA_CLOUD_INDEX_NAME")) + project_name = pipelineConfig.get("project", os.getenv("LLAMA_CLOUD_PROJECT_NAME")) + api_key = os.getenv("LLAMA_CLOUD_API_KEY") + base_url = os.getenv("LLAMA_CLOUD_BASE_URL") + organization_id = os.getenv("LLAMA_CLOUD_ORGANIZATION_ID") + + if name is None or project_name is None or api_key is None: + raise ValueError( + "Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY" + " to your environment variables or config them in .env file" + ) + + index = LlamaCloudIndex( + name=name, + project_name=project_name, + api_key=api_key, + base_url=base_url, + organization_id=organization_id + ) + + return index diff --git a/backend/app/engine/loaders/__init__.py b/backend/app/engine/loaders/__init__.py new file mode 100644 index 0000000..4a278a4 --- /dev/null +++ b/backend/app/engine/loaders/__init__.py @@ -0,0 +1,37 @@ +import logging + +import yaml +from app.engine.loaders.db import DBLoaderConfig, get_db_documents +from app.engine.loaders.file import FileLoaderConfig, get_file_documents +from app.engine.loaders.web import WebLoaderConfig, get_web_documents + +logger = logging.getLogger(__name__) + + +def load_configs(): + with open("config/loaders.yaml") as f: + configs = yaml.safe_load(f) + return configs + + +def get_documents(): + documents = [] + config = load_configs() + for loader_type, loader_config in config.items(): + logger.info( + f"Loading documents from loader: {loader_type}, config: {loader_config}" + ) + match loader_type: + case "file": + document = get_file_documents(FileLoaderConfig(**loader_config)) + case "web": + document = get_web_documents(WebLoaderConfig(**loader_config)) + case "db": + document = get_db_documents( + configs=[DBLoaderConfig(**cfg) for cfg in loader_config] + ) + case _: + raise ValueError(f"Invalid loader type: {loader_type}") + 
documents.extend(document) + + return documents diff --git a/backend/app/engine/loaders/db.py b/backend/app/engine/loaders/db.py new file mode 100644 index 0000000..d5c9ffd --- /dev/null +++ b/backend/app/engine/loaders/db.py @@ -0,0 +1,26 @@ +import os +import logging +from typing import List +from pydantic import BaseModel, validator +from llama_index.core.indices.vector_store import VectorStoreIndex + +logger = logging.getLogger(__name__) + + +class DBLoaderConfig(BaseModel): + uri: str + queries: List[str] + + +def get_db_documents(configs: list[DBLoaderConfig]): + from llama_index.readers.database import DatabaseReader + + docs = [] + for entry in configs: + loader = DatabaseReader(uri=entry.uri) + for query in entry.queries: + logger.info(f"Loading data from database with query: {query}") + documents = loader.load_data(query=query) + docs.extend(documents) + + return documents diff --git a/backend/app/engine/loaders/file.py b/backend/app/engine/loaders/file.py new file mode 100644 index 0000000..4dea4f8 --- /dev/null +++ b/backend/app/engine/loaders/file.py @@ -0,0 +1,79 @@ +import os +import logging +from typing import Dict +from llama_parse import LlamaParse +from pydantic import BaseModel, validator + +logger = logging.getLogger(__name__) + + +class FileLoaderConfig(BaseModel): + data_dir: str = "data" + use_llama_parse: bool = False + + @validator("data_dir") + def data_dir_must_exist(cls, v): + if not os.path.isdir(v): + raise ValueError(f"Directory '{v}' does not exist") + return v + + +def llama_parse_parser(): + if os.getenv("LLAMA_CLOUD_API_KEY") is None: + raise ValueError( + "LLAMA_CLOUD_API_KEY environment variable is not set. " + "Please set it in .env file or in your shell environment then run again!" 
+ ) + parser = LlamaParse( + result_type="markdown", + verbose=True, + language="en", + ignore_errors=False, + ) + return parser + + +def llama_parse_extractor() -> Dict[str, LlamaParse]: + from llama_parse.utils import SUPPORTED_FILE_TYPES + + parser = llama_parse_parser() + return {file_type: parser for file_type in SUPPORTED_FILE_TYPES} + + +def get_file_documents(config: FileLoaderConfig): + from llama_index.core.readers import SimpleDirectoryReader + + try: + file_extractor = None + if config.use_llama_parse: + # LlamaParse is async first, + # so we need to use nest_asyncio to run it in sync mode + import nest_asyncio + + nest_asyncio.apply() + + file_extractor = llama_parse_extractor() + reader = SimpleDirectoryReader( + config.data_dir, + recursive=True, + filename_as_id=True, + raise_on_error=True, + file_extractor=file_extractor, + ) + return reader.load_data() + except Exception as e: + import sys + import traceback + + # Catch the error if the data dir is empty + # and return as empty document list + _, _, exc_traceback = sys.exc_info() + function_name = traceback.extract_tb(exc_traceback)[-1].name + if function_name == "_add_files": + logger.warning( + f"Failed to load file documents, error message: {e} . Return as empty document list." 
+ ) + return [] + else: + # Raise the error if it is not the case of empty data dir + raise e diff --git a/backend/app/engine/loaders/web.py b/backend/app/engine/loaders/web.py new file mode 100644 index 0000000..563e51b --- /dev/null +++ b/backend/app/engine/loaders/web.py @@ -0,0 +1,36 @@ +import os +import json +from pydantic import BaseModel, Field + + +class CrawlUrl(BaseModel): + base_url: str + prefix: str + max_depth: int = Field(default=1, ge=0) + + +class WebLoaderConfig(BaseModel): + driver_arguments: list[str] = Field(default=None) + urls: list[CrawlUrl] + + +def get_web_documents(config: WebLoaderConfig): + from llama_index.readers.web import WholeSiteReader + from selenium import webdriver + from selenium.webdriver.chrome.options import Options + + options = Options() + driver_arguments = config.driver_arguments or [] + for arg in driver_arguments: + options.add_argument(arg) + + docs = [] + for url in config.urls: + scraper = WholeSiteReader( + prefix=url.prefix, + max_depth=url.max_depth, + driver=webdriver.Chrome(options=options), + ) + docs.extend(scraper.load_data(url.base_url)) + + return docs diff --git a/backend/app/engine/tools/__init__.py b/backend/app/engine/tools/__init__.py new file mode 100644 index 0000000..111bee5 --- /dev/null +++ b/backend/app/engine/tools/__init__.py @@ -0,0 +1,56 @@ +import os +import yaml +import json +import importlib +from cachetools import cached, LRUCache +from llama_index.core.tools.tool_spec.base import BaseToolSpec +from llama_index.core.tools.function_tool import FunctionTool + + +class ToolType: + LLAMAHUB = "llamahub" + LOCAL = "local" + + +class ToolFactory: + + TOOL_SOURCE_PACKAGE_MAP = { + ToolType.LLAMAHUB: "llama_index.tools", + ToolType.LOCAL: "app.engine.tools", + } + + def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]: + source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type] + try: + if "ToolSpec" in tool_name: + tool_package, tool_cls_name = 
tool_name.split(".") + module_name = f"{source_package}.{tool_package}" + module = importlib.import_module(module_name) + tool_class = getattr(module, tool_cls_name) + tool_spec: BaseToolSpec = tool_class(**config) + return tool_spec.to_tool_list() + else: + module = importlib.import_module(f"{source_package}.{tool_name}") + tools = module.get_tools(**config) + if not all(isinstance(tool, FunctionTool) for tool in tools): + raise ValueError( + f"The module {module} does not contain valid tools" + ) + return tools + except ImportError as e: + raise ValueError(f"Failed to import tool {tool_name}: {e}") + except AttributeError as e: + raise ValueError(f"Failed to load tool {tool_name}: {e}") + + @staticmethod + def from_env() -> list[FunctionTool]: + tools = [] + if os.path.exists("config/tools.yaml"): + with open("config/tools.yaml", "r") as f: + tool_configs = yaml.safe_load(f) + for tool_type, config_entries in tool_configs.items(): + for tool_name, config in config_entries.items(): + tools.extend( + ToolFactory.load_tools(tool_type, tool_name, config) + ) + return tools diff --git a/backend/app/engine/tools/duckduckgo.py b/backend/app/engine/tools/duckduckgo.py new file mode 100644 index 0000000..b63612a --- /dev/null +++ b/backend/app/engine/tools/duckduckgo.py @@ -0,0 +1,36 @@ +from llama_index.core.tools.function_tool import FunctionTool + + +def duckduckgo_search( + query: str, + region: str = "wt-wt", + max_results: int = 10, +): + """ + Use this function to search for any query in DuckDuckGo. + Args: + query (str): The query to search in DuckDuckGo. + region Optional(str): The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc... + max_results Optional(int): The maximum number of results to be returned. Default is 10. + """ + try: + from duckduckgo_search import DDGS + except ImportError: + raise ImportError( + "duckduckgo_search package is required to use this function." 
+ "Please install it by running: `poetry add duckduckgo_search` or `pip install duckduckgo_search`" + ) + + params = { + "keywords": query, + "region": region, + "max_results": max_results, + } + results = [] + with DDGS() as ddg: + results = list(ddg.text(**params)) + return results + + +def get_tools(**kwargs): + return [FunctionTool.from_defaults(duckduckgo_search)] diff --git a/backend/app/engine/tools/img_gen.py b/backend/app/engine/tools/img_gen.py new file mode 100644 index 0000000..966e95d --- /dev/null +++ b/backend/app/engine/tools/img_gen.py @@ -0,0 +1,108 @@ +import os +import uuid +import logging +import requests +from typing import Optional +from pydantic import BaseModel, Field +from llama_index.core.tools import FunctionTool + +logger = logging.getLogger(__name__) + + +class ImageGeneratorToolOutput(BaseModel): + is_success: bool = Field( + ..., + description="Whether the image generation was successful.", + ) + image_url: Optional[str] = Field( + None, + description="The URL of the generated image.", + ) + error_message: Optional[str] = Field( + None, + description="The error message if the image generation failed.", + ) + + +class ImageGeneratorTool: + _IMG_OUTPUT_FORMAT = "webp" + _IMG_OUTPUT_DIR = "output/tool" + _IMG_GEN_API = "https://api.stability.ai/v2beta/stable-image/generate/core" + + def __init__(self, api_key: str = None): + if not api_key: + api_key = os.getenv("STABILITY_API_KEY") + self._api_key = api_key + self.fileserver_url_prefix = os.getenv("FILESERVER_URL_PREFIX") + if self._api_key is None: + raise ValueError( + "STABILITY_API_KEY key is required to run image generator. 
Get it here: https://platform.stability.ai/account/keys" + ) + if self.fileserver_url_prefix is None: + raise ValueError("FILESERVER_URL_PREFIX is required.") + + def _prepare_output_dir(self): + """ + Create the output directory if it doesn't exist + """ + if not os.path.exists(self._IMG_OUTPUT_DIR): + os.makedirs(self._IMG_OUTPUT_DIR, exist_ok=True) + + def _save_image(self, image_data: bytes): + self._prepare_output_dir() + filename = f"{uuid.uuid4()}.{self._IMG_OUTPUT_FORMAT}" + output_path = os.path.join(self._IMG_OUTPUT_DIR, filename) + with open(output_path, "wb") as f: + f.write(image_data) + url = f"{os.getenv('FILESERVER_URL_PREFIX')}/{self._IMG_OUTPUT_DIR}/{filename}" + logger.info(f"Saved image to {output_path}.\nURL: {url}") + return url + + def _call_stability_api(self, prompt: str): + headers = { + "authorization": f"Bearer {self._api_key}", + "accept": "image/*", + } + data = { + "prompt": prompt, + "output_format": self._IMG_OUTPUT_FORMAT, + } + + response = requests.post( + self._IMG_GEN_API, + headers=headers, + files={"none": ""}, + data=data, + ) + response.raise_for_status() + + return response + + def generate_image(self, prompt: str) -> ImageGeneratorToolOutput: + """ + Use this tool to generate an image based on the prompt. + Args: + prompt (str): The prompt to generate the image from. 
+ """ + + try: + # Call the Stability API + response = self._call_stability_api(prompt) + + # Save the image and get the URL + image_url = self._save_image(response.content) + + return ImageGeneratorToolOutput( + is_success=True, + image_url=image_url, + ) + except Exception as e: + logger.exception(e, exc_info=True) + return ImageGeneratorToolOutput( + is_success=False, + error_message=str(e), + ) + + +def get_tools(**kwargs): + return [FunctionTool.from_defaults(ImageGeneratorTool(**kwargs).generate_image)] diff --git a/backend/app/engine/tools/interpreter.py b/backend/app/engine/tools/interpreter.py new file mode 100644 index 0000000..1d2c02c --- /dev/null +++ b/backend/app/engine/tools/interpreter.py @@ -0,0 +1,143 @@ +import os +import logging +import base64 +import uuid +from pydantic import BaseModel +from typing import List, Tuple, Dict, Optional +from llama_index.core.tools import FunctionTool +from e2b_code_interpreter import CodeInterpreter +from e2b_code_interpreter.models import Logs + + +logger = logging.getLogger(__name__) + + +class InterpreterExtraResult(BaseModel): + type: str + content: Optional[str] = None + filename: Optional[str] = None + url: Optional[str] = None + + +class E2BToolOutput(BaseModel): + is_error: bool + logs: Logs + results: List[InterpreterExtraResult] = [] + + +class E2BCodeInterpreter: + + output_dir = "output/tool" + + def __init__(self, api_key: str = None): + if api_key is None: + api_key = os.getenv("E2B_API_KEY") + filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX") + if not api_key: + raise ValueError( + "E2B_API_KEY key is required to run code interpreter. 
Get it here: https://e2b.dev/docs/getting-started/api-key" + ) + if not filesever_url_prefix: + raise ValueError( + "FILESERVER_URL_PREFIX is required to display file output from sandbox" + ) + + self.filesever_url_prefix = filesever_url_prefix + self.interpreter = CodeInterpreter(api_key=api_key) + + def __del__(self): + self.interpreter.close() + + def get_output_path(self, filename: str) -> str: + # if output directory doesn't exist, create it + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir, exist_ok=True) + return os.path.join(self.output_dir, filename) + + def save_to_disk(self, base64_data: str, ext: str) -> Dict: + filename = f"{uuid.uuid4()}.{ext}" # generate a unique filename + buffer = base64.b64decode(base64_data) + output_path = self.get_output_path(filename) + + try: + with open(output_path, "wb") as file: + file.write(buffer) + except IOError as e: + logger.error(f"Failed to write to file {output_path}: {str(e)}") + raise e + + logger.info(f"Saved file to {output_path}") + + return { + "outputPath": output_path, + "filename": filename, + } + + def get_file_url(self, filename: str) -> str: + return f"{self.filesever_url_prefix}/{self.output_dir}/{filename}" + + def parse_result(self, result) -> List[InterpreterExtraResult]: + """ + The result could include multiple formats (e.g. png, svg, etc.) 
but encoded in base64 + We save each result to disk and return saved file metadata (extension, filename, url) + """ + if not result: + return [] + + output = [] + + try: + formats = result.formats() + results = [result[format] for format in formats] + + for ext, data in zip(formats, results): + match ext: + case "png" | "svg" | "jpeg" | "pdf": + result = self.save_to_disk(data, ext) + filename = result["filename"] + output.append( + InterpreterExtraResult( + type=ext, + filename=filename, + url=self.get_file_url(filename), + ) + ) + case _: + output.append( + InterpreterExtraResult( + type=ext, + content=data, + ) + ) + except Exception as error: + logger.exception(error, exc_info=True) + logger.error("Error when parsing output from E2b interpreter tool", error) + + return output + + def interpret(self, code: str) -> E2BToolOutput: + """ + Execute python code in a Jupyter notebook cell, the toll will return result, stdout, stderr, display_data, and error. + + Parameters: + code (str): The python code to be executed in a single cell. 
+ """ + logger.info( + f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}" + ) + exec = self.interpreter.notebook.exec_cell(code) + + if exec.error: + logger.error("Error when executing code", exec.error) + output = E2BToolOutput(is_error=True, logs=exec.logs, results=[]) + else: + if len(exec.results) == 0: + output = E2BToolOutput(is_error=False, logs=exec.logs, results=[]) + else: + results = self.parse_result(exec.results[0]) + output = E2BToolOutput(is_error=False, logs=exec.logs, results=results) + return output + + +def get_tools(**kwargs): + return [FunctionTool.from_defaults(E2BCodeInterpreter(**kwargs).interpret)] diff --git a/backend/app/engine/tools/openapi_action.py b/backend/app/engine/tools/openapi_action.py new file mode 100644 index 0000000..c19187d --- /dev/null +++ b/backend/app/engine/tools/openapi_action.py @@ -0,0 +1,78 @@ +from typing import Dict, List, Tuple +from llama_index.tools.openapi import OpenAPIToolSpec +from llama_index.tools.requests import RequestsToolSpec + + +class OpenAPIActionToolSpec(OpenAPIToolSpec, RequestsToolSpec): + """ + A combination of OpenAPI and Requests tool specs that can parse OpenAPI specs and make requests. + + openapi_uri: str: The file path or URL to the OpenAPI spec. + domain_headers: dict: Whitelist domains and the headers to use. 
+ """ + + spec_functions = OpenAPIToolSpec.spec_functions + RequestsToolSpec.spec_functions + # Cached parsed specs by URI + _specs: Dict[str, Tuple[Dict, List[str]]] = {} + + def __init__(self, openapi_uri: str, domain_headers: dict = None, **kwargs): + if domain_headers is None: + domain_headers = {} + if openapi_uri not in self._specs: + openapi_spec, servers = self._load_openapi_spec(openapi_uri) + self._specs[openapi_uri] = (openapi_spec, servers) + else: + openapi_spec, servers = self._specs[openapi_uri] + + # Add the servers to the domain headers if they are not already present + for server in servers: + if server not in domain_headers: + domain_headers[server] = {} + + OpenAPIToolSpec.__init__(self, spec=openapi_spec) + RequestsToolSpec.__init__(self, domain_headers) + + @staticmethod + def _load_openapi_spec(uri: str) -> Tuple[Dict, List[str]]: + """ + Load an OpenAPI spec from a URI. + + Args: + uri (str): A file path or URL to the OpenAPI spec. + + Returns: + List[Document]: A list of Document objects. + """ + import yaml + from urllib.parse import urlparse + + if uri.startswith("http"): + import requests + + response = requests.get(uri) + if response.status_code != 200: + raise ValueError( + "Could not initialize OpenAPIActionToolSpec: " + f"Failed to load OpenAPI spec from {uri}, status code: {response.status_code}" + ) + spec = yaml.safe_load(response.text) + elif uri.startswith("file"): + filepath = urlparse(uri).path + with open(filepath, "r") as file: + spec = yaml.safe_load(file) + else: + raise ValueError( + "Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI URI provided. " + "Only HTTP and file path are supported." + ) + # Add the servers to the whitelist + try: + servers = [ + urlparse(server["url"]).netloc for server in spec.get("servers", []) + ] + except KeyError as e: + raise ValueError( + "Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI spec provided. " + "Could not get `servers` from the spec." 
+ ) from e + return spec, servers diff --git a/backend/app/engine/tools/weather.py b/backend/app/engine/tools/weather.py new file mode 100644 index 0000000..c8b6f1b --- /dev/null +++ b/backend/app/engine/tools/weather.py @@ -0,0 +1,73 @@ +"""Open Meteo weather map tool spec.""" + +import logging +import requests +import pytz +from llama_index.core.tools import FunctionTool + +logger = logging.getLogger(__name__) + + +class OpenMeteoWeather: + geo_api = "https://geocoding-api.open-meteo.com/v1" + weather_api = "https://api.open-meteo.com/v1" + + @classmethod + def _get_geo_location(cls, location: str) -> dict: + """Get geo location from location name.""" + params = {"name": location, "count": 10, "language": "en", "format": "json"} + response = requests.get(f"{cls.geo_api}/search", params=params) + if response.status_code != 200: + raise Exception(f"Failed to fetch geo location: {response.status_code}") + else: + data = response.json() + result = data["results"][0] + geo_location = { + "id": result["id"], + "name": result["name"], + "latitude": result["latitude"], + "longitude": result["longitude"], + } + return geo_location + + @classmethod + def get_weather_information(cls, location: str) -> dict: + """Use this function to get the weather of any given location. 
+ Note that the weather code should follow WMO Weather interpretation codes (WW): + 0: Clear sky + 1, 2, 3: Mainly clear, partly cloudy, and overcast + 45, 48: Fog and depositing rime fog + 51, 53, 55: Drizzle: Light, moderate, and dense intensity + 56, 57: Freezing Drizzle: Light and dense intensity + 61, 63, 65: Rain: Slight, moderate and heavy intensity + 66, 67: Freezing Rain: Light and heavy intensity + 71, 73, 75: Snow fall: Slight, moderate, and heavy intensity + 77: Snow grains + 80, 81, 82: Rain showers: Slight, moderate, and violent + 85, 86: Snow showers slight and heavy + 95: Thunderstorm: Slight or moderate + 96, 99: Thunderstorm with slight and heavy hail + """ + logger.info( + f"Calling open-meteo api to get weather information of location: {location}" + ) + geo_location = cls._get_geo_location(location) + timezone = pytz.timezone("UTC").zone + params = { + "latitude": geo_location["latitude"], + "longitude": geo_location["longitude"], + "current": "temperature_2m,weather_code", + "hourly": "temperature_2m,weather_code", + "daily": "weather_code", + "timezone": timezone, + } + response = requests.get(f"{cls.weather_api}/forecast", params=params) + if response.status_code != 200: + raise Exception( + f"Failed to fetch weather information: {response.status_code}" + ) + return response.json() + + +def get_tools(**kwargs): + return [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)] diff --git a/backend/app/llmhub.py b/backend/app/llmhub.py new file mode 100644 index 0000000..69e0e32 --- /dev/null +++ b/backend/app/llmhub.py @@ -0,0 +1,61 @@ +from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.core.settings import Settings +from typing import Dict +import os + +DEFAULT_MODEL = "gpt-3.5-turbo" +DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large" + +class TSIEmbedding(OpenAIEmbedding): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._query_engine = self._text_engine = self.model_name + +def 
llm_config_from_env() -> Dict: + from llama_index.core.constants import DEFAULT_TEMPERATURE + + model = os.getenv("MODEL", DEFAULT_MODEL) + temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE) + max_tokens = os.getenv("LLM_MAX_TOKENS") + api_key = os.getenv("T_SYSTEMS_LLMHUB_API_KEY") + api_base = os.getenv("T_SYSTEMS_LLMHUB_BASE_URL") + + config = { + "model": model, + "api_key": api_key, + "api_base": api_base, + "temperature": float(temperature), + "max_tokens": int(max_tokens) if max_tokens is not None else None, + } + return config + + +def embedding_config_from_env() -> Dict: + from llama_index.core.constants import DEFAULT_EMBEDDING_DIM + + model = os.getenv("EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL) + dimension = os.getenv("EMBEDDING_DIM", DEFAULT_EMBEDDING_DIM) + api_key = os.getenv("T_SYSTEMS_LLMHUB_API_KEY") + api_base = os.getenv("T_SYSTEMS_LLMHUB_BASE_URL") + + config = { + "model_name": model, + "dimension": int(dimension) if dimension is not None else None, + "api_key": api_key, + "api_base": api_base, + } + return config + +def init_llmhub(): + from llama_index.llms.openai_like import OpenAILike + + llm_configs = llm_config_from_env() + embedding_configs = embedding_config_from_env() + + Settings.embed_model = TSIEmbedding(**embedding_configs) + Settings.llm = OpenAILike( + **llm_configs, + is_chat_model=True, + is_function_calling_model=False, + context_window=4096, + ) \ No newline at end of file diff --git a/backend/app/observability.py b/backend/app/observability.py new file mode 100644 index 0000000..28019c3 --- /dev/null +++ b/backend/app/observability.py @@ -0,0 +1,2 @@ +def init_observability(): + pass diff --git a/backend/app/settings.py b/backend/app/settings.py new file mode 100644 index 0000000..b723bf3 --- /dev/null +++ b/backend/app/settings.py @@ -0,0 +1,172 @@ +import os +from typing import Dict + +from llama_index.core.settings import Settings + + +def init_settings(): + model_provider = os.getenv("MODEL_PROVIDER") + 
match model_provider: + case "openai": + init_openai() + case "groq": + init_groq() + case "ollama": + init_ollama() + case "anthropic": + init_anthropic() + case "gemini": + init_gemini() + case "mistral": + init_mistral() + case "azure-openai": + init_azure_openai() + case "t-systems": + from .llmhub import init_llmhub + + init_llmhub() + case _: + raise ValueError(f"Invalid model provider: {model_provider}") + + Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024")) + Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20")) + + +def init_ollama(): + from llama_index.embeddings.ollama import OllamaEmbedding + from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama + + base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434" + request_timeout = float( + os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT) + ) + Settings.embed_model = OllamaEmbedding( + base_url=base_url, + model_name=os.getenv("EMBEDDING_MODEL"), + ) + Settings.llm = Ollama( + base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout + ) + + +def init_openai(): + from llama_index.core.constants import DEFAULT_TEMPERATURE + from llama_index.embeddings.openai import OpenAIEmbedding + from llama_index.llms.openai import OpenAI + + max_tokens = os.getenv("LLM_MAX_TOKENS") + config = { + "model": os.getenv("MODEL"), + "temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)), + "max_tokens": int(max_tokens) if max_tokens is not None else None, + } + Settings.llm = OpenAI(**config) + + dimensions = os.getenv("EMBEDDING_DIM") + config = { + "model": os.getenv("EMBEDDING_MODEL"), + "dimensions": int(dimensions) if dimensions is not None else None, + } + Settings.embed_model = OpenAIEmbedding(**config) + + +def init_azure_openai(): + from llama_index.core.constants import DEFAULT_TEMPERATURE + from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding + from llama_index.llms.azure_openai import AzureOpenAI + + 
llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"] + embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"] + max_tokens = os.getenv("LLM_MAX_TOKENS") + temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE) + dimensions = os.getenv("EMBEDDING_DIM") + + azure_config = { + "api_key": os.environ["AZURE_OPENAI_KEY"], + "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"], + "api_version": os.getenv("AZURE_OPENAI_API_VERSION") + or os.getenv("OPENAI_API_VERSION"), + } + + Settings.llm = AzureOpenAI( + model=os.getenv("MODEL"), + max_tokens=int(max_tokens) if max_tokens is not None else None, + temperature=float(temperature), + deployment_name=llm_deployment, + **azure_config, + ) + + Settings.embed_model = AzureOpenAIEmbedding( + model=os.getenv("EMBEDDING_MODEL"), + dimensions=int(dimensions) if dimensions is not None else None, + deployment_name=embedding_deployment, + **azure_config, + ) + + +def init_fastembed(): + """ + Use Qdrant Fastembed as the local embedding provider. 
+ """ + from llama_index.embeddings.fastembed import FastEmbedEmbedding + + embed_model_map: Dict[str, str] = { + # Small and multilingual + "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2", + # Large and multilingual + "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", # noqa: E501 + } + + # This will download the model automatically if it is not already downloaded + Settings.embed_model = FastEmbedEmbedding( + model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")] + ) + + +def init_groq(): + from llama_index.llms.groq import Groq + + model_map: Dict[str, str] = { + "llama3-8b": "llama3-8b-8192", + "llama3-70b": "llama3-70b-8192", + "mixtral-8x7b": "mixtral-8x7b-32768", + } + + Settings.llm = Groq(model=model_map[os.getenv("MODEL")]) + # Groq does not provide embeddings, so we use FastEmbed instead + init_fastembed() + + +def init_anthropic(): + from llama_index.llms.anthropic import Anthropic + + model_map: Dict[str, str] = { + "claude-3-opus": "claude-3-opus-20240229", + "claude-3-sonnet": "claude-3-sonnet-20240229", + "claude-3-haiku": "claude-3-haiku-20240307", + "claude-2.1": "claude-2.1", + "claude-instant-1.2": "claude-instant-1.2", + } + + Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")]) + # Anthropic does not provide embeddings, so we use FastEmbed instead + init_fastembed() + + +def init_gemini(): + from llama_index.embeddings.gemini import GeminiEmbedding + from llama_index.llms.gemini import Gemini + + model_name = f"models/{os.getenv('MODEL')}" + embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}" + + Settings.llm = Gemini(model=model_name) + Settings.embed_model = GeminiEmbedding(model_name=embed_model_name) + + +def init_mistral(): + from llama_index.embeddings.mistralai import MistralAIEmbedding + from llama_index.llms.mistralai import MistralAI + + Settings.llm = MistralAI(model=os.getenv("MODEL")) + Settings.embed_model = 
MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL")) diff --git a/backend/config/loaders.yaml b/backend/config/loaders.yaml new file mode 100644 index 0000000..d746c61 --- /dev/null +++ b/backend/config/loaders.yaml @@ -0,0 +1,10 @@ +file: + # use_llama_parse: Use LlamaParse if `true`. Needs a `LLAMA_CLOUD_API_KEY` from https://cloud.llamaindex.ai set as environment variable + use_llama_parse: true +db: + # The configuration for the database loader, only supports MySQL and PostgreSQL databases for now. + # uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db + # query: The query to fetch data from the database. E.g.: SELECT * FROM table + - uri: mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1 + queries: + - SELECT * FROM mytable diff --git a/backend/config/tools.yaml b/backend/config/tools.yaml new file mode 100644 index 0000000..df5690c --- /dev/null +++ b/backend/config/tools.yaml @@ -0,0 +1,4 @@ +local: + weather: {} + interpreter: {} +llamahub: {} diff --git a/backend/data/101.pdf b/backend/data/101.pdf new file mode 100644 index 0000000..ae5acff Binary files /dev/null and b/backend/data/101.pdf differ diff --git a/backend/main.py b/backend/main.py new file mode 100644 index 0000000..a72745e --- /dev/null +++ b/backend/main.py @@ -0,0 +1,64 @@ +from dotenv import load_dotenv + +load_dotenv() + +import logging +import os +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from app.api.routers.chat import chat_router +from app.api.routers.upload import file_upload_router +from app.settings import init_settings +from app.observability import init_observability +from fastapi.staticfiles import StaticFiles + + +app = FastAPI() + +init_settings() +init_observability() + +environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set 
+logger = logging.getLogger("uvicorn") + +if environment == "dev": + logger.warning("Running in development mode - allowing CORS for all origins") + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Redirect to documentation page when accessing base URL + @app.get("/") + async def redirect_to_docs(): + return RedirectResponse(url="/docs") + + +def mount_static_files(directory, path): + if os.path.exists(directory): + for dir, _, _ in os.walk(directory): + relative_path = os.path.relpath(dir, directory) + mount_path = path if relative_path == "." else f"{path}/{relative_path}" + logger.info(f"Mounting static files '{dir}' at {mount_path}") + app.mount(mount_path, StaticFiles(directory=dir), name=f"{dir}-static") + + +# Mount the data files to serve the file viewer +mount_static_files("data", "/api/files/data") +# Mount the output files from tools +mount_static_files("output", "/api/files/output") + +app.include_router(chat_router, prefix="/api/chat") +app.include_router(file_upload_router, prefix="/api/chat/upload") + +if __name__ == "__main__": + app_host = os.getenv("APP_HOST", "0.0.0.0") + app_port = int(os.getenv("APP_PORT", "8000")) + reload = True if environment == "dev" else False + + uvicorn.run(app="main:app", host=app_host, port=app_port, reload=reload) diff --git a/backend/pyproject.toml b/backend/pyproject.toml new file mode 100644 index 0000000..21715b3 --- /dev/null +++ b/backend/pyproject.toml @@ -0,0 +1,48 @@ +[tool] +[tool.poetry] +name = "app" +version = "0.1.0" +description = "" +authors = [ "Marcus Schiesser " ] +readme = "README.md" + +[tool.poetry.scripts] +generate = "app.engine.generate:generate_datasource" + +[tool.poetry.dependencies] +python = "^3.11,<3.12" +fastapi = "^0.109.1" +python-dotenv = "^1.0.0" +aiostream = "^0.5.2" +llama-index = "0.10.58" +cachetools = "^5.3.3" + +[tool.poetry.dependencies.uvicorn] +extras = [ "standard" ] +version = 
"^0.23.2" + +[tool.poetry.dependencies.llama-index-readers-database] +version = "^0.1.3" + +[tool.poetry.dependencies.pymysql] +version = "^1.1.0" +extras = [ "rsa" ] + +[tool.poetry.dependencies.psycopg2] +version = "^2.9.9" + +[tool.poetry.dependencies.llama-index-indices-managed-llama-cloud] +version = "^0.2.7" + +[tool.poetry.dependencies.docx2txt] +version = "^0.8" + +[tool.poetry.dependencies.e2b_code_interpreter] +version = "0.0.7" + +[tool.poetry.dependencies.llama-index-agent-openai] +version = "0.2.6" + +[build-system] +requires = [ "poetry-core" ] +build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/frontend/.env b/frontend/.env new file mode 100644 index 0000000..faf27f1 --- /dev/null +++ b/frontend/.env @@ -0,0 +1,6 @@ +# The backend API for chat endpoint. +NEXT_PUBLIC_CHAT_API=http://localhost:8000/api/chat + +# Let's the user change indexes in LlamaCloud projects +NEXT_PUBLIC_USE_LLAMACLOUD=true + diff --git a/frontend/.eslintrc.json b/frontend/.eslintrc.json new file mode 100644 index 0000000..e96fdb8 --- /dev/null +++ b/frontend/.eslintrc.json @@ -0,0 +1,7 @@ +{ + "extends": ["next/core-web-vitals", "prettier"], + "rules": { + "max-params": ["error", 4], + "prefer-const": "error" + } +} diff --git a/frontend/.gitignore b/frontend/.gitignore new file mode 100644 index 0000000..b7ee22a --- /dev/null +++ b/frontend/.gitignore @@ -0,0 +1,37 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
+ +# dependencies +/node_modules +/.pnp +.pnp.js + +# testing +/coverage + +# next.js +/.next/ +/out/ + +# production +/build + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# local env files +.env*.local + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts + +output/ diff --git a/frontend/Dockerfile b/frontend/Dockerfile new file mode 100644 index 0000000..5c738ab --- /dev/null +++ b/frontend/Dockerfile @@ -0,0 +1,16 @@ +FROM node:20-alpine as build + +WORKDIR /app + +# Install dependencies +COPY package.json package-lock.* ./ +RUN npm install + +# Build the application +COPY . . +RUN npm run build + +# ==================================== +FROM build as release + +CMD ["npm", "run", "start"] \ No newline at end of file diff --git a/frontend/README.md b/frontend/README.md new file mode 100644 index 0000000..d2eb1eb --- /dev/null +++ b/frontend/README.md @@ -0,0 +1,71 @@ +This is a [LlamaIndex](https://www.llamaindex.ai/) project using [Next.js](https://nextjs.org/) bootstrapped with [`create-llama`](https://github.com/run-llama/LlamaIndexTS/tree/main/packages/create-llama). + +## Getting Started + +First, install the dependencies: + +``` +npm install +``` + +Second, generate the embeddings of the documents in the `./data` directory (if this folder exists - otherwise, skip this step): + +``` +npm run generate +``` + +Third, run the development server: + +``` +npm run dev +``` + +Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. + +You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file. + +This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font. + +## Using Docker + +1. Build an image for the Next.js app: + +``` +docker build -t . +``` + +2. 
Generate embeddings: + +Parse the data and generate the vector embeddings if the `./data` folder exists - otherwise, skip this step: + +``` +docker run \ + --rm \ + -v $(pwd)/.env:/app/.env \ # Use ENV variables and configuration from your file-system + -v $(pwd)/config:/app/config \ + -v $(pwd)/data:/app/data \ + -v $(pwd)/cache:/app/cache \ # Use your file system to store the vector database + your_app_image_name \ + npm run generate +``` + +3. Start the app: + +``` +docker run \ + --rm \ + -v $(pwd)/.env:/app/.env \ # Use ENV variables and configuration from your file-system + -v $(pwd)/config:/app/config \ + -v $(pwd)/cache:/app/cache \ # Use your file system to store the vector database + -p 3000:3000 \ + your_app_image_name +``` + +## Learn More + +To learn more about LlamaIndex, take a look at the following resources: + +- [LlamaIndex Documentation](https://docs.llamaindex.ai) - learn about LlamaIndex (Python features). +- [LlamaIndexTS Documentation](https://ts.llamaindex.ai) - learn about LlamaIndex (Typescript features). + +You can check out [the LlamaIndexTS GitHub repository](https://github.com/run-llama/LlamaIndexTS) - your feedback and contributions are welcome!
diff --git a/frontend/app/components/chat-section.tsx b/frontend/app/components/chat-section.tsx new file mode 100644 index 0000000..b33a6cf --- /dev/null +++ b/frontend/app/components/chat-section.tsx @@ -0,0 +1,51 @@ +"use client"; + +import { useChat } from "ai/react"; +import { ChatInput, ChatMessages } from "./ui/chat"; +import { useClientConfig } from "./ui/chat/hooks/use-config"; + +export default function ChatSection() { + const { backend } = useClientConfig(); + const { + messages, + input, + isLoading, + handleSubmit, + handleInputChange, + reload, + stop, + append, + setInput, + } = useChat({ + api: `${backend}/api/chat`, + headers: { + "Content-Type": "application/json", // using JSON because of vercel/ai 2.2.26 + }, + onError: (error: unknown) => { + if (!(error instanceof Error)) throw error; + const message = JSON.parse(error.message); + alert(message.detail); + }, + }); + + return ( +
+ + +
+ ); +} diff --git a/frontend/app/components/header.tsx b/frontend/app/components/header.tsx new file mode 100644 index 0000000..f02ce73 --- /dev/null +++ b/frontend/app/components/header.tsx @@ -0,0 +1,28 @@ +import Image from "next/image"; + +export default function Header() { + return ( +
+

+ Get started by editing  + app/page.tsx +

+ +
+ ); +} diff --git a/frontend/app/components/ui/README.md b/frontend/app/components/ui/README.md new file mode 100644 index 0000000..ebfcf48 --- /dev/null +++ b/frontend/app/components/ui/README.md @@ -0,0 +1 @@ +Using the chat component from https://github.com/marcusschiesser/ui (based on https://ui.shadcn.com/) diff --git a/frontend/app/components/ui/button.tsx b/frontend/app/components/ui/button.tsx new file mode 100644 index 0000000..662b040 --- /dev/null +++ b/frontend/app/components/ui/button.tsx @@ -0,0 +1,56 @@ +import { Slot } from "@radix-ui/react-slot"; +import { cva, type VariantProps } from "class-variance-authority"; +import * as React from "react"; + +import { cn } from "./lib/utils"; + +const buttonVariants = cva( + "inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50", + { + variants: { + variant: { + default: "bg-primary text-primary-foreground hover:bg-primary/90", + destructive: + "bg-destructive text-destructive-foreground hover:bg-destructive/90", + outline: + "border border-input bg-background hover:bg-accent hover:text-accent-foreground", + secondary: + "bg-secondary text-secondary-foreground hover:bg-secondary/80", + ghost: "hover:bg-accent hover:text-accent-foreground", + link: "text-primary underline-offset-4 hover:underline", + }, + size: { + default: "h-10 px-4 py-2", + sm: "h-9 rounded-md px-3", + lg: "h-11 rounded-md px-8", + icon: "h-10 w-10", + }, + }, + defaultVariants: { + variant: "default", + size: "default", + }, + }, +); + +export interface ButtonProps + extends React.ButtonHTMLAttributes, + VariantProps { + asChild?: boolean; +} + +const Button = React.forwardRef( + ({ className, variant, size, asChild = false, ...props }, ref) => { + const Comp = asChild ? 
Slot : "button"; + return ( + + ); + }, +); +Button.displayName = "Button"; + +export { Button, buttonVariants }; diff --git a/frontend/app/components/ui/chat/chat-actions.tsx b/frontend/app/components/ui/chat/chat-actions.tsx new file mode 100644 index 0000000..151ef61 --- /dev/null +++ b/frontend/app/components/ui/chat/chat-actions.tsx @@ -0,0 +1,28 @@ +import { PauseCircle, RefreshCw } from "lucide-react"; + +import { Button } from "../button"; +import { ChatHandler } from "./chat.interface"; + +export default function ChatActions( + props: Pick & { + showReload?: boolean; + showStop?: boolean; + }, +) { + return ( +
+ {props.showStop && ( + + )} + {props.showReload && ( + + )} +
+ ); +} diff --git a/frontend/app/components/ui/chat/chat-input.tsx b/frontend/app/components/ui/chat/chat-input.tsx new file mode 100644 index 0000000..4c58296 --- /dev/null +++ b/frontend/app/components/ui/chat/chat-input.tsx @@ -0,0 +1,127 @@ +import { JSONValue } from "ai"; +import { useState } from "react"; +import { Button } from "../button"; +import { DocumentPreview } from "../document-preview"; +import FileUploader from "../file-uploader"; +import { Input } from "../input"; +import UploadImagePreview from "../upload-image-preview"; +import { ChatHandler } from "./chat.interface"; +import { useFile } from "./hooks/use-file"; +import { LlamaCloudSelector } from "./widgets/LlamaCloudSelector"; + +const ALLOWED_EXTENSIONS = ["png", "jpg", "jpeg", "csv", "pdf", "txt", "docx"]; + +export default function ChatInput( + props: Pick< + ChatHandler, + | "isLoading" + | "input" + | "onFileUpload" + | "onFileError" + | "handleSubmit" + | "handleInputChange" + | "messages" + | "setInput" + | "append" + > & { + requestParams?: any; + }, +) { + const { + imageUrl, + setImageUrl, + uploadFile, + files, + removeDoc, + reset, + getAnnotations, + } = useFile(); + const [requestData, setRequestData] = useState(); + + // default submit function does not handle including annotations in the message + // so we need to use append function to submit new message with annotations + const handleSubmitWithAnnotations = ( + e: React.FormEvent, + annotations: JSONValue[] | undefined, + ) => { + e.preventDefault(); + props.append!( + { + content: props.input, + role: "user", + createdAt: new Date(), + annotations, + }, + { data: requestData }, + ); + props.setInput!(""); + }; + + const onSubmit = (e: React.FormEvent) => { + const annotations = getAnnotations(); + if (annotations.length) { + handleSubmitWithAnnotations(e, annotations); + return reset(); + } + props.handleSubmit(e, { data: requestData }); + }; + + const handleUploadFile = async (file: File) => { + if (imageUrl || 
files.length > 0) { + alert("You can only upload one file at a time."); + return; + } + try { + await uploadFile(file, props.requestParams); + props.onFileUpload?.(file); + } catch (error: any) { + props.onFileError?.(error.message); + } + }; + + return ( +
+ {imageUrl && ( + setImageUrl(null)} /> + )} + {files.length > 0 && ( +
+ {files.map((file) => ( + removeDoc(file)} + /> + ))} +
+ )} +
+ + + {process.env.NEXT_PUBLIC_USE_LLAMACLOUD === "true" && ( + + )} + +
+ + ); +} diff --git a/frontend/app/components/ui/chat/chat-message/chat-avatar.tsx b/frontend/app/components/ui/chat/chat-message/chat-avatar.tsx new file mode 100644 index 0000000..ce04e30 --- /dev/null +++ b/frontend/app/components/ui/chat/chat-message/chat-avatar.tsx @@ -0,0 +1,25 @@ +import { User2 } from "lucide-react"; +import Image from "next/image"; + +export default function ChatAvatar({ role }: { role: string }) { + if (role === "user") { + return ( +
+ +
+ ); + } + + return ( +
+ Llama Logo +
+ ); +} diff --git a/frontend/app/components/ui/chat/chat-message/chat-events.tsx b/frontend/app/components/ui/chat/chat-message/chat-events.tsx new file mode 100644 index 0000000..3dfad75 --- /dev/null +++ b/frontend/app/components/ui/chat/chat-message/chat-events.tsx @@ -0,0 +1,50 @@ +import { ChevronDown, ChevronRight, Loader2 } from "lucide-react"; +import { useState } from "react"; +import { Button } from "../../button"; +import { + Collapsible, + CollapsibleContent, + CollapsibleTrigger, +} from "../../collapsible"; +import { EventData } from "../index"; + +export function ChatEvents({ + data, + isLoading, +}: { + data: EventData[]; + isLoading: boolean; +}) { + const [isOpen, setIsOpen] = useState(false); + + const buttonLabel = isOpen ? "Hide events" : "Show events"; + + const EventIcon = isOpen ? ( + + ) : ( + + ); + + return ( +
+ + + + + +
+ {data.map((eventItem, index) => ( +
+ {eventItem.title} +
+ ))} +
+
+
+
+ ); +} diff --git a/frontend/app/components/ui/chat/chat-message/chat-files.tsx b/frontend/app/components/ui/chat/chat-message/chat-files.tsx new file mode 100644 index 0000000..5139c54 --- /dev/null +++ b/frontend/app/components/ui/chat/chat-message/chat-files.tsx @@ -0,0 +1,13 @@ +import { DocumentPreview } from "../../document-preview"; +import { DocumentFileData } from "../index"; + +export function ChatFiles({ data }: { data: DocumentFileData }) { + if (!data.files.length) return null; + return ( +
+ {data.files.map((file) => ( + + ))} +
+ ); +} diff --git a/frontend/app/components/ui/chat/chat-message/chat-image.tsx b/frontend/app/components/ui/chat/chat-message/chat-image.tsx new file mode 100644 index 0000000..2de28c3 --- /dev/null +++ b/frontend/app/components/ui/chat/chat-message/chat-image.tsx @@ -0,0 +1,17 @@ +import Image from "next/image"; +import { type ImageData } from "../index"; + +export function ChatImage({ data }: { data: ImageData }) { + return ( +
+ +
+ ); +} diff --git a/frontend/app/components/ui/chat/chat-message/chat-sources.tsx b/frontend/app/components/ui/chat/chat-message/chat-sources.tsx new file mode 100644 index 0000000..1d4ccb6 --- /dev/null +++ b/frontend/app/components/ui/chat/chat-message/chat-sources.tsx @@ -0,0 +1,123 @@ +import { Check, Copy } from "lucide-react"; +import { useMemo } from "react"; +import { Button } from "../../button"; +import { + HoverCard, + HoverCardContent, + HoverCardTrigger, +} from "../../hover-card"; +import { useCopyToClipboard } from "../hooks/use-copy-to-clipboard"; +import { SourceData } from "../index"; +import PdfDialog from "../widgets/PdfDialog"; + +const SCORE_THRESHOLD = 0.3; + +function SourceNumberButton({ index }: { index: number }) { + return ( +
+ {index + 1} +
+ ); +} + +type NodeInfo = { + id: string; + url?: string; +}; + +export function ChatSources({ data }: { data: SourceData }) { + const sources: NodeInfo[] = useMemo(() => { + // aggregate nodes by url or file_path (get the highest one by score) + const nodesByPath: { [path: string]: NodeInfo } = {}; + + data.nodes + .filter((node) => (node.score ?? 1) > SCORE_THRESHOLD) + .sort((a, b) => (b.score ?? 1) - (a.score ?? 1)) + .forEach((node) => { + const nodeInfo = { + id: node.id, + url: node.url, + }; + const key = nodeInfo.url ?? nodeInfo.id; // use id as key for UNKNOWN type + if (!nodesByPath[key]) { + nodesByPath[key] = nodeInfo; + } + }); + + return Object.values(nodesByPath); + }, [data.nodes]); + + if (sources.length === 0) return null; + + return ( +
+ Sources: +
+ {sources.map((nodeInfo: NodeInfo, index: number) => { + if (nodeInfo.url?.endsWith(".pdf")) { + return ( + } + /> + ); + } + return ( +
+ + + + + + + + +
+ ); + })} +
+
+ ); +} + +function NodeInfo({ nodeInfo }: { nodeInfo: NodeInfo }) { + const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 1000 }); + + if (nodeInfo.url) { + // this is a node generated by the web loader or file loader, + // add a link to view its URL and a button to copy the URL to the clipboard + return ( +
+ + {nodeInfo.url} + + +
+ ); + } + + // node generated by unknown loader, implement renderer by analyzing logged out metadata + return ( +

+ Sorry, unknown node type. Please add a new renderer in the NodeInfo + component. +

+ ); +} diff --git a/frontend/app/components/ui/chat/chat-message/chat-suggestedQuestions.tsx b/frontend/app/components/ui/chat/chat-message/chat-suggestedQuestions.tsx new file mode 100644 index 0000000..ea662e4 --- /dev/null +++ b/frontend/app/components/ui/chat/chat-message/chat-suggestedQuestions.tsx @@ -0,0 +1,32 @@ +import { useState } from "react"; +import { ChatHandler, SuggestedQuestionsData } from ".."; + +export function SuggestedQuestions({ + questions, + append, +}: { + questions: SuggestedQuestionsData; + append: Pick["append"]; +}) { + const [showQuestions, setShowQuestions] = useState(questions.length > 0); + + return ( + showQuestions && + append !== undefined && ( + + ) + ); +} diff --git a/frontend/app/components/ui/chat/chat-message/chat-tools.tsx b/frontend/app/components/ui/chat/chat-message/chat-tools.tsx new file mode 100644 index 0000000..202f982 --- /dev/null +++ b/frontend/app/components/ui/chat/chat-message/chat-tools.tsx @@ -0,0 +1,26 @@ +import { ToolData } from "../index"; +import { WeatherCard, WeatherData } from "../widgets/WeatherCard"; + +// TODO: If needed, add displaying more tool outputs here +export default function ChatTools({ data }: { data: ToolData }) { + if (!data) return null; + const { toolCall, toolOutput } = data; + + if (toolOutput.isError) { + return ( +
+ There was an error when calling the tool {toolCall.name} with input:{" "} +
+ {JSON.stringify(toolCall.input)} +
+ ); + } + + switch (toolCall.name) { + case "get_weather_information": + const weatherData = toolOutput.output as unknown as WeatherData; + return ; + default: + return null; + } +} diff --git a/frontend/app/components/ui/chat/chat-message/codeblock.tsx b/frontend/app/components/ui/chat/chat-message/codeblock.tsx new file mode 100644 index 0000000..e71a408 --- /dev/null +++ b/frontend/app/components/ui/chat/chat-message/codeblock.tsx @@ -0,0 +1,139 @@ +"use client"; + +import { Check, Copy, Download } from "lucide-react"; +import { FC, memo } from "react"; +import { Prism, SyntaxHighlighterProps } from "react-syntax-highlighter"; +import { coldarkDark } from "react-syntax-highlighter/dist/cjs/styles/prism"; + +import { Button } from "../../button"; +import { useCopyToClipboard } from "../hooks/use-copy-to-clipboard"; + +// TODO: Remove this when @type/react-syntax-highlighter is updated +const SyntaxHighlighter = Prism as unknown as FC; + +interface Props { + language: string; + value: string; +} + +interface languageMap { + [key: string]: string | undefined; +} + +export const programmingLanguages: languageMap = { + javascript: ".js", + python: ".py", + java: ".java", + c: ".c", + cpp: ".cpp", + "c++": ".cpp", + "c#": ".cs", + ruby: ".rb", + php: ".php", + swift: ".swift", + "objective-c": ".m", + kotlin: ".kt", + typescript: ".ts", + go: ".go", + perl: ".pl", + rust: ".rs", + scala: ".scala", + haskell: ".hs", + lua: ".lua", + shell: ".sh", + sql: ".sql", + html: ".html", + css: ".css", + // add more file extensions here, make sure the key is same as language prop in CodeBlock.tsx component +}; + +export const generateRandomString = (length: number, lowercase = false) => { + const chars = "ABCDEFGHJKLMNPQRSTUVWXY3456789"; // excluding similar looking characters like Z, 2, I, 1, O, 0 + let result = ""; + for (let i = 0; i < length; i++) { + result += chars.charAt(Math.floor(Math.random() * chars.length)); + } + return lowercase ? 
result.toLowerCase() : result; +}; + +const CodeBlock: FC = memo(({ language, value }) => { + const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 2000 }); + + const downloadAsFile = () => { + if (typeof window === "undefined") { + return; + } + const fileExtension = programmingLanguages[language] || ".file"; + const suggestedFileName = `file-${generateRandomString( + 3, + true, + )}${fileExtension}`; + const fileName = window.prompt("Enter file name" || "", suggestedFileName); + + if (!fileName) { + // User pressed cancel on prompt. + return; + } + + const blob = new Blob([value], { type: "text/plain" }); + const url = URL.createObjectURL(blob); + const link = document.createElement("a"); + link.download = fileName; + link.href = url; + link.style.display = "none"; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + URL.revokeObjectURL(url); + }; + + const onCopy = () => { + if (isCopied) return; + copyToClipboard(value); + }; + + return ( +
+
+ {language} +
+ + +
+
+ + {value} + +
+ ); +}); +CodeBlock.displayName = "CodeBlock"; + +export { CodeBlock }; diff --git a/frontend/app/components/ui/chat/chat-message/index.tsx b/frontend/app/components/ui/chat/chat-message/index.tsx new file mode 100644 index 0000000..e71903e --- /dev/null +++ b/frontend/app/components/ui/chat/chat-message/index.tsx @@ -0,0 +1,156 @@ +import { Check, Copy } from "lucide-react"; + +import { Message } from "ai"; +import { Fragment } from "react"; +import { Button } from "../../button"; +import { useCopyToClipboard } from "../hooks/use-copy-to-clipboard"; +import { + ChatHandler, + DocumentFileData, + EventData, + ImageData, + MessageAnnotation, + MessageAnnotationType, + SourceData, + SuggestedQuestionsData, + ToolData, + getAnnotationData, +} from "../index"; +import ChatAvatar from "./chat-avatar"; +import { ChatEvents } from "./chat-events"; +import { ChatFiles } from "./chat-files"; +import { ChatImage } from "./chat-image"; +import { ChatSources } from "./chat-sources"; +import { SuggestedQuestions } from "./chat-suggestedQuestions"; +import ChatTools from "./chat-tools"; +import Markdown from "./markdown"; + +type ContentDisplayConfig = { + order: number; + component: JSX.Element | null; +}; + +function ChatMessageContent({ + message, + isLoading, + append, +}: { + message: Message; + isLoading: boolean; + append: Pick["append"]; +}) { + const annotations = message.annotations as MessageAnnotation[] | undefined; + if (!annotations?.length) return ; + + const imageData = getAnnotationData( + annotations, + MessageAnnotationType.IMAGE, + ); + const contentFileData = getAnnotationData( + annotations, + MessageAnnotationType.DOCUMENT_FILE, + ); + const eventData = getAnnotationData( + annotations, + MessageAnnotationType.EVENTS, + ); + const sourceData = getAnnotationData( + annotations, + MessageAnnotationType.SOURCES, + ); + const toolData = getAnnotationData( + annotations, + MessageAnnotationType.TOOLS, + ); + const suggestedQuestionsData = getAnnotationData( + 
annotations, + MessageAnnotationType.SUGGESTED_QUESTIONS, + ); + + const contents: ContentDisplayConfig[] = [ + { + order: 1, + component: imageData[0] ? : null, + }, + { + order: -3, + component: + eventData.length > 0 ? ( + + ) : null, + }, + { + order: 2, + component: contentFileData[0] ? ( + + ) : null, + }, + { + order: -1, + component: toolData[0] ? : null, + }, + { + order: 0, + component: , + }, + { + order: 3, + component: sourceData[0] ? : null, + }, + { + order: 4, + component: suggestedQuestionsData[0] ? ( + + ) : null, + }, + ]; + + return ( +
+ {contents + .sort((a, b) => a.order - b.order) + .map((content, index) => ( + {content.component} + ))} +
+ ); +} + +export default function ChatMessage({ + chatMessage, + isLoading, + append, +}: { + chatMessage: Message; + isLoading: boolean; + append: Pick["append"]; +}) { + const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 2000 }); + return ( +
+ +
+ + +
+
+ ); +} diff --git a/frontend/app/components/ui/chat/chat-message/markdown.tsx b/frontend/app/components/ui/chat/chat-message/markdown.tsx new file mode 100644 index 0000000..79791b4 --- /dev/null +++ b/frontend/app/components/ui/chat/chat-message/markdown.tsx @@ -0,0 +1,88 @@ +import "katex/dist/katex.min.css"; +import { FC, memo } from "react"; +import ReactMarkdown, { Options } from "react-markdown"; +import rehypeKatex from "rehype-katex"; +import remarkGfm from "remark-gfm"; +import remarkMath from "remark-math"; + +import { CodeBlock } from "./codeblock"; + +const MemoizedReactMarkdown: FC = memo( + ReactMarkdown, + (prevProps, nextProps) => + prevProps.children === nextProps.children && + prevProps.className === nextProps.className, +); + +const preprocessLaTeX = (content: string) => { + // Replace block-level LaTeX delimiters \[ \] with $$ $$ + const blockProcessedContent = content.replace( + /\\\[([\s\S]*?)\\\]/g, + (_, equation) => `$$${equation}$$`, + ); + // Replace inline LaTeX delimiters \( \) with $ $ + const inlineProcessedContent = blockProcessedContent.replace( + /\\\(([\s\S]*?)\\\)/g, + (_, equation) => `$${equation}$`, + ); + return inlineProcessedContent; +}; + +const preprocessMedia = (content: string) => { + // Remove `sandbox:` from the beginning of the URL + // to fix OpenAI's models issue appending `sandbox:` to the relative URL + return content.replace(/(sandbox|attachment|snt):/g, ""); +}; + +const preprocessContent = (content: string) => { + return preprocessMedia(preprocessLaTeX(content)); +}; + +export default function Markdown({ content }: { content: string }) { + const processedContent = preprocessContent(content); + + return ( + {children}

; + }, + code({ node, inline, className, children, ...props }) { + if (children.length) { + if (children[0] == "▍") { + return ( + + ); + } + + children[0] = (children[0] as string).replace("`▍`", "▍"); + } + + const match = /language-(\w+)/.exec(className || ""); + + if (inline) { + return ( + + {children} + + ); + } + + return ( + + ); + }, + }} + > + {processedContent} +
+ ); +} diff --git a/frontend/app/components/ui/chat/chat-messages.tsx b/frontend/app/components/ui/chat/chat-messages.tsx new file mode 100644 index 0000000..e0afd8b --- /dev/null +++ b/frontend/app/components/ui/chat/chat-messages.tsx @@ -0,0 +1,95 @@ +import { Loader2 } from "lucide-react"; +import { useEffect, useRef } from "react"; + +import { Button } from "../button"; +import ChatActions from "./chat-actions"; +import ChatMessage from "./chat-message"; +import { ChatHandler } from "./chat.interface"; +import { useClientConfig } from "./hooks/use-config"; + +export default function ChatMessages( + props: Pick< + ChatHandler, + "messages" | "isLoading" | "reload" | "stop" | "append" + >, +) { + const { starterQuestions } = useClientConfig(); + const scrollableChatContainerRef = useRef(null); + const messageLength = props.messages.length; + const lastMessage = props.messages[messageLength - 1]; + + const scrollToBottom = () => { + if (scrollableChatContainerRef.current) { + scrollableChatContainerRef.current.scrollTop = + scrollableChatContainerRef.current.scrollHeight; + } + }; + + const isLastMessageFromAssistant = + messageLength > 0 && lastMessage?.role !== "user"; + const showReload = + props.reload && !props.isLoading && isLastMessageFromAssistant; + const showStop = props.stop && props.isLoading; + + // `isPending` indicate + // that stream response is not yet received from the server, + // so we show a loading indicator to give a better UX. + const isPending = props.isLoading && !isLastMessageFromAssistant; + + useEffect(() => { + scrollToBottom(); + }, [messageLength, lastMessage]); + + return ( +
+
+ {props.messages.map((m, i) => { + const isLoadingMessage = i === messageLength - 1 && props.isLoading; + return ( + + ); + })} + {isPending && ( +
+ +
+ )} +
+ {(showReload || showStop) && ( +
+ +
+ )} + {!messageLength && starterQuestions?.length && props.append && ( +
+
+ {starterQuestions.map((question, i) => ( + + ))} +
+
+ )} +
+ ); +} diff --git a/frontend/app/components/ui/chat/chat.interface.ts b/frontend/app/components/ui/chat/chat.interface.ts new file mode 100644 index 0000000..6b74d4f --- /dev/null +++ b/frontend/app/components/ui/chat/chat.interface.ts @@ -0,0 +1,25 @@ +import { Message } from "ai"; + +export interface ChatHandler { + messages: Message[]; + input: string; + isLoading: boolean; + handleSubmit: ( + e: React.FormEvent, + ops?: { + data?: any; + }, + ) => void; + handleInputChange: (e: React.ChangeEvent) => void; + reload?: () => void; + stop?: () => void; + onFileUpload?: (file: File) => Promise; + onFileError?: (errMsg: string) => void; + setInput?: (input: string) => void; + append?: ( + message: Message | Omit, + ops?: { + data: any; + }, + ) => Promise; +} diff --git a/frontend/app/components/ui/chat/hooks/use-config.ts b/frontend/app/components/ui/chat/hooks/use-config.ts new file mode 100644 index 0000000..05de32a --- /dev/null +++ b/frontend/app/components/ui/chat/hooks/use-config.ts @@ -0,0 +1,31 @@ +"use client"; + +import { useEffect, useMemo, useState } from "react"; + +export interface ChatConfig { + backend?: string; + starterQuestions?: string[]; +} + +export function useClientConfig(): ChatConfig { + const chatAPI = process.env.NEXT_PUBLIC_CHAT_API; + const [config, setConfig] = useState(); + + const backendOrigin = useMemo(() => { + return chatAPI ? 
new URL(chatAPI).origin : ""; + }, [chatAPI]); + + const configAPI = `${backendOrigin}/api/chat/config`; + + useEffect(() => { + fetch(configAPI) + .then((response) => response.json()) + .then((data) => setConfig({ ...data, chatAPI })) + .catch((error) => console.error("Error fetching config", error)); + }, [chatAPI, configAPI]); + + return { + backend: backendOrigin, + starterQuestions: config?.starterQuestions, + }; +} diff --git a/frontend/app/components/ui/chat/hooks/use-copy-to-clipboard.tsx b/frontend/app/components/ui/chat/hooks/use-copy-to-clipboard.tsx new file mode 100644 index 0000000..e011d69 --- /dev/null +++ b/frontend/app/components/ui/chat/hooks/use-copy-to-clipboard.tsx @@ -0,0 +1,38 @@ +"use client"; + +import * as React from "react"; + +export interface useCopyToClipboardProps { + timeout?: number; +} + +export function useCopyToClipboard({ + timeout = 2000, +}: useCopyToClipboardProps) { + const [isCopied, setIsCopied] = React.useState(false); + + const copyToClipboard = (value: string) => { + if (typeof window === "undefined" || !navigator.clipboard?.writeText) { + return; + } + + if (!value) { + return; + } + + navigator.clipboard + .writeText(value) + .then(() => { + setIsCopied(true); + + setTimeout(() => { + setIsCopied(false); + }, timeout); + }) + // writeText() returns a promise that can reject (e.g. the write is not + // allowed); log it instead of leaving an unhandled rejection. + .catch((error) => console.error("Error copying to clipboard", error)); + }; + + return { isCopied, copyToClipboard }; +} diff --git a/frontend/app/components/ui/chat/hooks/use-file.ts b/frontend/app/components/ui/chat/hooks/use-file.ts new file mode 100644 index 0000000..2c2c34b --- /dev/null +++ b/frontend/app/components/ui/chat/hooks/use-file.ts @@ -0,0 +1,153 @@ +"use client"; + +import { JSONValue } from "llamaindex"; +import { useState } from "react"; +import { v4 as uuidv4 } from "uuid"; +import { + DocumentFile, + DocumentFileType, + MessageAnnotation, + MessageAnnotationType, +} from ".."; +import { useClientConfig } from "./use-config"; + +const docMineTypeMap: Record = { + "text/csv": "csv", + "application/pdf": "pdf", + "text/plain": "txt", +
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": + "docx", +}; + +export function useFile() { + const { backend } = useClientConfig(); + const [imageUrl, setImageUrl] = useState(null); + const [files, setFiles] = useState([]); + + const docEqual = (a: DocumentFile, b: DocumentFile) => { + if (a.id === b.id) return true; + if (a.filename === b.filename && a.filesize === b.filesize) return true; + return false; + }; + + const addDoc = (file: DocumentFile) => { + const existedFile = files.find((f) => docEqual(f, file)); + if (!existedFile) { + setFiles((prev) => [...prev, file]); + return true; + } + return false; + }; + + const removeDoc = (file: DocumentFile) => { + setFiles((prev) => prev.filter((f) => f.id !== file.id)); + }; + + const reset = () => { + imageUrl && setImageUrl(null); + files.length && setFiles([]); + }; + + const uploadContent = async ( + base64: string, + requestParams: any = {}, + ): Promise => { + const uploadAPI = `${backend}/api/chat/upload`; + const response = await fetch(uploadAPI, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + base64, + ...requestParams, + }), + }); + if (!response.ok) throw new Error("Failed to upload document."); + return await response.json(); + }; + + const getAnnotations = () => { + const annotations: MessageAnnotation[] = []; + if (imageUrl) { + annotations.push({ + type: MessageAnnotationType.IMAGE, + data: { url: imageUrl }, + }); + } + if (files.length > 0) { + annotations.push({ + type: MessageAnnotationType.DOCUMENT_FILE, + data: { files }, + }); + } + return annotations as JSONValue[]; + }; + + const readContent = async (input: { + file: File; + asUrl?: boolean; + }): Promise => { + const { file, asUrl } = input; + const content = await new Promise((resolve, reject) => { + const reader = new FileReader(); + if (asUrl) { + reader.readAsDataURL(file); + } else { + reader.readAsText(file); + } + reader.onload = () => 
resolve(reader.result as string); + reader.onerror = (error) => reject(error); + }); + return content; + }; + + const uploadFile = async (file: File, requestParams: any = {}) => { + if (file.type.startsWith("image/")) { + const base64 = await readContent({ file, asUrl: true }); + return setImageUrl(base64); + } + + const filetype = docMineTypeMap[file.type]; + if (!filetype) throw new Error("Unsupported document type."); + const newDoc: Omit = { + id: uuidv4(), + filetype, + filename: file.name, + filesize: file.size, + }; + switch (file.type) { + case "text/csv": { + const content = await readContent({ file }); + return addDoc({ + ...newDoc, + content: { + type: "text", + value: content, + }, + }); + } + default: { + const base64 = await readContent({ file, asUrl: true }); + const ids = await uploadContent(base64, requestParams); + return addDoc({ + ...newDoc, + content: { + type: "ref", + value: ids, + }, + }); + } + } + }; + + return { + imageUrl, + setImageUrl, + files, + removeDoc, + reset, + getAnnotations, + uploadFile, + }; +} diff --git a/frontend/app/components/ui/chat/index.ts b/frontend/app/components/ui/chat/index.ts new file mode 100644 index 0000000..dcfc9cd --- /dev/null +++ b/frontend/app/components/ui/chat/index.ts @@ -0,0 +1,91 @@ +import { JSONValue } from "ai"; +import ChatInput from "./chat-input"; +import ChatMessages from "./chat-messages"; + +export { type ChatHandler } from "./chat.interface"; +export { ChatInput, ChatMessages }; + +export enum MessageAnnotationType { + IMAGE = "image", + DOCUMENT_FILE = "document_file", + SOURCES = "sources", + EVENTS = "events", + TOOLS = "tools", + SUGGESTED_QUESTIONS = "suggested_questions", +} + +export type ImageData = { + url: string; +}; + +export type DocumentFileType = "csv" | "pdf" | "txt" | "docx"; + +export type DocumentFileContent = { + type: "ref" | "text"; + value: string[] | string; +}; + +export type DocumentFile = { + id: string; + filename: string; + filesize: number; + filetype: 
DocumentFileType; + content: DocumentFileContent; +}; + +export type DocumentFileData = { + files: DocumentFile[]; +}; + +export type SourceNode = { + id: string; + metadata: Record; + score?: number; + text: string; + url?: string; +}; + +export type SourceData = { + nodes: SourceNode[]; +}; + +export type EventData = { + title: string; + isCollapsed: boolean; +}; + +export type ToolData = { + toolCall: { + id: string; + name: string; + input: { + [key: string]: JSONValue; + }; + }; + toolOutput: { + output: JSONValue; + isError: boolean; + }; +}; + +export type SuggestedQuestionsData = string[]; + +export type AnnotationData = + | ImageData + | DocumentFileData + | SourceData + | EventData + | ToolData + | SuggestedQuestionsData; + +export type MessageAnnotation = { + type: MessageAnnotationType; + data: AnnotationData; +}; + +export function getAnnotationData( + annotations: MessageAnnotation[], + type: MessageAnnotationType, +): T[] { + return annotations.filter((a) => a.type === type).map((a) => a.data as T); +} diff --git a/frontend/app/components/ui/chat/widgets/LlamaCloudSelector.tsx b/frontend/app/components/ui/chat/widgets/LlamaCloudSelector.tsx new file mode 100644 index 0000000..aa995c9 --- /dev/null +++ b/frontend/app/components/ui/chat/widgets/LlamaCloudSelector.tsx @@ -0,0 +1,151 @@ +import { Loader2 } from "lucide-react"; +import { useEffect, useState } from "react"; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectLabel, + SelectTrigger, + SelectValue, +} from "../../select"; +import { useClientConfig } from "../hooks/use-config"; + +type LLamaCloudPipeline = { + id: string; + name: string; +}; + +type LLamaCloudProject = { + id: string; + organization_id: string; + name: string; + is_default: boolean; + pipelines: Array; +}; + +type PipelineConfig = { + project: string; // project name + pipeline: string; // pipeline name +}; + +type LlamaCloudConfig = { + projects?: LLamaCloudProject[]; + pipeline?: PipelineConfig; +}; + 
+export interface LlamaCloudSelectorProps { + setRequestData: React.Dispatch; +} + +export function LlamaCloudSelector({ + setRequestData, +}: LlamaCloudSelectorProps) { + const { backend } = useClientConfig(); + const [config, setConfig] = useState(); + + useEffect(() => { + if (process.env.NEXT_PUBLIC_USE_LLAMACLOUD === "true" && !config) { + fetch(`${backend}/api/chat/config/llamacloud`) + .then((response) => response.json()) + .then((data) => { + setConfig(data); + setRequestData({ + llamaCloudPipeline: data.pipeline, + }); + }) + .catch((error) => console.error("Error fetching config", error)); + } + }, [backend, config, setRequestData]); + + const setPipeline = (pipelineConfig?: PipelineConfig) => { + setConfig((prevConfig: any) => ({ + ...prevConfig, + pipeline: pipelineConfig, + })); + setRequestData((prevData: any) => { + if (!prevData) return { llamaCloudPipeline: pipelineConfig }; + return { + ...prevData, + llamaCloudPipeline: pipelineConfig, + }; + }); + }; + + const handlePipelineSelect = async (value: string) => { + setPipeline(JSON.parse(value) as PipelineConfig); + }; + + if (!config) { + return ( +
+ +
+ ); + } + if (!isValid(config)) { + return ( +

+ Invalid LlamaCloud configuration. Check console logs. +

+ ); + } + const { projects, pipeline } = config; + + return ( + + ); +} + +function isValid(config: LlamaCloudConfig): boolean { + const { projects, pipeline } = config; + if (!projects?.length) return false; + if (!pipeline) return false; + const matchedProject = projects.find( + (project: LLamaCloudProject) => project.name === pipeline.project, + ); + if (!matchedProject) { + console.error( + `LlamaCloud project ${pipeline.project} not found. Check LLAMA_CLOUD_PROJECT_NAME variable`, + ); + return false; + } + const pipelineExists = matchedProject.pipelines.some( + (p) => p.name === pipeline.pipeline, + ); + if (!pipelineExists) { + console.error( + `LlamaCloud pipeline ${pipeline.pipeline} not found. Check LLAMA_CLOUD_INDEX_NAME variable`, + ); + return false; + } + return true; +} diff --git a/frontend/app/components/ui/chat/widgets/PdfDialog.tsx b/frontend/app/components/ui/chat/widgets/PdfDialog.tsx new file mode 100644 index 0000000..8dafffc --- /dev/null +++ b/frontend/app/components/ui/chat/widgets/PdfDialog.tsx @@ -0,0 +1,67 @@ +import dynamic from "next/dynamic"; +import { Button } from "../../button"; +import { + Drawer, + DrawerClose, + DrawerContent, + DrawerDescription, + DrawerHeader, + DrawerTitle, + DrawerTrigger, +} from "../../drawer"; + +export interface PdfDialogProps { + documentId: string; + url: string; + trigger: React.ReactNode; +} + +// Dynamic imports for client-side rendering only +const PDFViewer = dynamic( + () => import("@llamaindex/pdf-viewer").then((module) => module.PDFViewer), + { ssr: false }, +); + +const PdfFocusProvider = dynamic( + () => + import("@llamaindex/pdf-viewer").then((module) => module.PdfFocusProvider), + { ssr: false }, +); + +export default function PdfDialog(props: PdfDialogProps) { + return ( + + {props.trigger} + + +
+ PDF Content + + File URL:{" "} + + {props.url} + + +
+ + + +
+
+ + + +
+
+
+ ); +} diff --git a/frontend/app/components/ui/chat/widgets/WeatherCard.tsx b/frontend/app/components/ui/chat/widgets/WeatherCard.tsx new file mode 100644 index 0000000..f2115ae --- /dev/null +++ b/frontend/app/components/ui/chat/widgets/WeatherCard.tsx @@ -0,0 +1,213 @@ +export interface WeatherData { + latitude: number; + longitude: number; + generationtime_ms: number; + utc_offset_seconds: number; + timezone: string; + timezone_abbreviation: string; + elevation: number; + current_units: { + time: string; + interval: string; + temperature_2m: string; + weather_code: string; + }; + current: { + time: string; + interval: number; + temperature_2m: number; + weather_code: number; + }; + hourly_units: { + time: string; + temperature_2m: string; + weather_code: string; + }; + hourly: { + time: string[]; + temperature_2m: number[]; + weather_code: number[]; + }; + daily_units: { + time: string; + weather_code: string; + }; + daily: { + time: string[]; + weather_code: number[]; + }; +} + +// Follow WMO Weather interpretation codes (WW) +const weatherCodeDisplayMap: Record< + string, + { + icon: JSX.Element; + status: string; + } +> = { + "0": { + icon: ☀️, + status: "Clear sky", + }, + "1": { + icon: 🌤️, + status: "Mainly clear", + }, + "2": { + icon: ☁️, + status: "Partly cloudy", + }, + "3": { + icon: ☁️, + status: "Overcast", + }, + "45": { + icon: 🌫️, + status: "Fog", + }, + "48": { + icon: 🌫️, + status: "Depositing rime fog", + }, + "51": { + icon: 🌧️, + status: "Drizzle", + }, + "53": { + icon: 🌧️, + status: "Drizzle", + }, + "55": { + icon: 🌧️, + status: "Drizzle", + }, + "56": { + icon: 🌧️, + status: "Freezing Drizzle", + }, + "57": { + icon: 🌧️, + status: "Freezing Drizzle", + }, + "61": { + icon: 🌧️, + status: "Rain", + }, + "63": { + icon: 🌧️, + status: "Rain", + }, + "65": { + icon: 🌧️, + status: "Rain", + }, + "66": { + icon: 🌧️, + status: "Freezing Rain", + }, + "67": { + icon: 🌧️, + status: "Freezing Rain", + }, + "71": { + icon: ❄️, + status: "Snow 
fall", + }, + "73": { + icon: ❄️, + status: "Snow fall", + }, + "75": { + icon: ❄️, + status: "Snow fall", + }, + "77": { + icon: ❄️, + status: "Snow grains", + }, + "80": { + icon: 🌧️, + status: "Rain showers", + }, + "81": { + icon: 🌧️, + status: "Rain showers", + }, + "82": { + icon: 🌧️, + status: "Rain showers", + }, + "85": { + icon: ❄️, + status: "Snow showers", + }, + "86": { + icon: ❄️, + status: "Snow showers", + }, + "95": { + icon: ⛈️, + status: "Thunderstorm", + }, + "96": { + icon: ⛈️, + status: "Thunderstorm", + }, + "99": { + icon: ⛈️, + status: "Thunderstorm", + }, +}; + +const displayDay = (time: string) => { + return new Date(time).toLocaleDateString("en-US", { + weekday: "long", + }); +}; + +export function WeatherCard({ data }: { data: WeatherData }) { + const currentDayString = new Date(data.current.time).toLocaleDateString( + "en-US", + { + weekday: "long", + month: "long", + day: "numeric", + }, + ); + + return ( +
+
+
+
{currentDayString}
+
+ + {data.current.temperature_2m} {data.current_units.temperature_2m} + + {weatherCodeDisplayMap[data.current.weather_code].icon} +
+
+ + {weatherCodeDisplayMap[data.current.weather_code].status} + +
+
+ {data.daily.time.map((time, index) => { + if (index === 0) return null; // skip the current day + return ( +
+ {displayDay(time)} +
+ {weatherCodeDisplayMap[data.daily.weather_code[index]].icon} +
+ + {weatherCodeDisplayMap[data.daily.weather_code[index]].status} + +
+ ); + })} +
+
+ ); +} diff --git a/frontend/app/components/ui/collapsible.tsx b/frontend/app/components/ui/collapsible.tsx new file mode 100644 index 0000000..1fe76f5 --- /dev/null +++ b/frontend/app/components/ui/collapsible.tsx @@ -0,0 +1,11 @@ +"use client"; + +import * as CollapsiblePrimitive from "@radix-ui/react-collapsible"; + +const Collapsible = CollapsiblePrimitive.Root; + +const CollapsibleTrigger = CollapsiblePrimitive.CollapsibleTrigger; + +const CollapsibleContent = CollapsiblePrimitive.CollapsibleContent; + +export { Collapsible, CollapsibleContent, CollapsibleTrigger }; diff --git a/frontend/app/components/ui/document-preview.tsx b/frontend/app/components/ui/document-preview.tsx new file mode 100644 index 0000000..eb9d6d9 --- /dev/null +++ b/frontend/app/components/ui/document-preview.tsx @@ -0,0 +1,119 @@ +import { XCircleIcon } from "lucide-react"; +import Image from "next/image"; +import DocxIcon from "../ui/icons/docx.svg"; +import PdfIcon from "../ui/icons/pdf.svg"; +import SheetIcon from "../ui/icons/sheet.svg"; +import TxtIcon from "../ui/icons/txt.svg"; +import { Button } from "./button"; +import { DocumentFile, DocumentFileType } from "./chat"; +import { + Drawer, + DrawerClose, + DrawerContent, + DrawerDescription, + DrawerHeader, + DrawerTitle, + DrawerTrigger, +} from "./drawer"; +import { cn } from "./lib/utils"; + +export interface DocumentPreviewProps { + file: DocumentFile; + onRemove?: () => void; +} + +export function DocumentPreview(props: DocumentPreviewProps) { + const { filename, filesize, content, filetype } = props.file; + + if (content.type === "ref") { + return ( +
+ +
+ ); + } + + return ( + + +
+ +
+
+ + +
+ {filetype.toUpperCase()} Raw Content + + {filename} ({inKB(filesize)} KB) + +
+ + + +
+
+ {content.type === "text" && ( +
+              {content.value as string}
+            
+ )} +
+
+
+ ); +} + +const FileIcon: Record = { + csv: SheetIcon, + pdf: PdfIcon, + docx: DocxIcon, + txt: TxtIcon, +}; + +function PreviewCard(props: DocumentPreviewProps) { + const { onRemove, file } = props; + return ( +
+
+
+ Icon +
+
+
+ {file.filename} ({inKB(file.filesize)} KB) +
+
+ {file.filetype.toUpperCase()} File +
+
+
+ {onRemove && ( +
+ +
+ )} +
+ ); +} + +function inKB(size: number) { + return Math.round((size / 1024) * 10) / 10; +} diff --git a/frontend/app/components/ui/drawer.tsx b/frontend/app/components/ui/drawer.tsx new file mode 100644 index 0000000..bf733c8 --- /dev/null +++ b/frontend/app/components/ui/drawer.tsx @@ -0,0 +1,118 @@ +"use client"; + +import * as React from "react"; +import { Drawer as DrawerPrimitive } from "vaul"; + +import { cn } from "./lib/utils"; + +const Drawer = ({ + shouldScaleBackground = true, + ...props +}: React.ComponentProps) => ( + +); +Drawer.displayName = "Drawer"; + +const DrawerTrigger = DrawerPrimitive.Trigger; + +const DrawerPortal = DrawerPrimitive.Portal; + +const DrawerClose = DrawerPrimitive.Close; + +const DrawerOverlay = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)); +DrawerOverlay.displayName = DrawerPrimitive.Overlay.displayName; + +const DrawerContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + + + +
+ {children} + + +)); +DrawerContent.displayName = "DrawerContent"; + +const DrawerHeader = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+); +DrawerHeader.displayName = "DrawerHeader"; + +const DrawerFooter = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+); +DrawerFooter.displayName = "DrawerFooter"; + +const DrawerTitle = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)); +DrawerTitle.displayName = DrawerPrimitive.Title.displayName; + +const DrawerDescription = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)); +DrawerDescription.displayName = DrawerPrimitive.Description.displayName; + +export { + Drawer, + DrawerClose, + DrawerContent, + DrawerDescription, + DrawerFooter, + DrawerHeader, + DrawerOverlay, + DrawerPortal, + DrawerTitle, + DrawerTrigger, +}; diff --git a/frontend/app/components/ui/file-uploader.tsx b/frontend/app/components/ui/file-uploader.tsx new file mode 100644 index 0000000..e42a267 --- /dev/null +++ b/frontend/app/components/ui/file-uploader.tsx @@ -0,0 +1,105 @@ +"use client"; + +import { Loader2, Paperclip } from "lucide-react"; +import { ChangeEvent, useState } from "react"; +import { buttonVariants } from "./button"; +import { cn } from "./lib/utils"; + +export interface FileUploaderProps { + config?: { + inputId?: string; + fileSizeLimit?: number; + allowedExtensions?: string[]; + checkExtension?: (extension: string) => string | null; + disabled: boolean; + }; + onFileUpload: (file: File) => Promise; + onFileError?: (errMsg: string) => void; +} + +const DEFAULT_INPUT_ID = "fileInput"; +const DEFAULT_FILE_SIZE_LIMIT = 1024 * 1024 * 50; // 50 MB + +export default function FileUploader({ + config, + onFileUpload, + onFileError, +}: FileUploaderProps) { + const [uploading, setUploading] = useState(false); + + const inputId = config?.inputId || DEFAULT_INPUT_ID; + const fileSizeLimit = config?.fileSizeLimit || DEFAULT_FILE_SIZE_LIMIT; + const allowedExtensions = config?.allowedExtensions; + const defaultCheckExtension = (extension: string) => { + if (allowedExtensions && !allowedExtensions.includes(extension)) { + return `Invalid file type. 
Please select a file with one of these formats: ${allowedExtensions!.join( + ",", + )}`; + } + return null; + }; + const checkExtension = config?.checkExtension ?? defaultCheckExtension; + + const isFileSizeExceeded = (file: File) => { + return file.size > fileSizeLimit; + }; + + const resetInput = () => { + const fileInput = document.getElementById(inputId) as HTMLInputElement; + fileInput.value = ""; + }; + + const onFileChange = async (e: ChangeEvent) => { + const file = e.target.files?.[0]; + if (!file) return; + + setUploading(true); + try { + await handleUpload(file); + } finally { + // Always clear the input and the uploading flag, even if the upload rejects. + resetInput(); + setUploading(false); + } + }; + + const handleUpload = async (file: File) => { + const onFileUploadError = onFileError || window.alert; + const fileExtension = file.name.split(".").pop() || ""; + const extensionFileError = checkExtension(fileExtension); + if (extensionFileError) { + return onFileUploadError(extensionFileError); + } + + if (isFileSizeExceeded(file)) { + return onFileUploadError( + `File size exceeded. Limit is ${fileSizeLimit / 1024 / 1024} MB`, + ); + } + + await onFileUpload(file); + }; + + return ( +
+ + +
+ ); +} diff --git a/frontend/app/components/ui/hover-card.tsx b/frontend/app/components/ui/hover-card.tsx new file mode 100644 index 0000000..e886235 --- /dev/null +++ b/frontend/app/components/ui/hover-card.tsx @@ -0,0 +1,29 @@ +"use client"; + +import * as HoverCardPrimitive from "@radix-ui/react-hover-card"; +import * as React from "react"; + +import { cn } from "./lib/utils"; + +const HoverCard = HoverCardPrimitive.Root; + +const HoverCardTrigger = HoverCardPrimitive.Trigger; + +const HoverCardContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, align = "center", sideOffset = 4, ...props }, ref) => ( + +)); +HoverCardContent.displayName = HoverCardPrimitive.Content.displayName; + +export { HoverCard, HoverCardContent, HoverCardTrigger }; diff --git a/frontend/app/components/ui/icons/docx.svg b/frontend/app/components/ui/icons/docx.svg new file mode 100644 index 0000000..4278239 --- /dev/null +++ b/frontend/app/components/ui/icons/docx.svg @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/frontend/app/components/ui/icons/pdf.svg b/frontend/app/components/ui/icons/pdf.svg new file mode 100644 index 0000000..f32146c --- /dev/null +++ b/frontend/app/components/ui/icons/pdf.svg @@ -0,0 +1,19 @@ + + + + + + + + + \ No newline at end of file diff --git a/frontend/app/components/ui/icons/sheet.svg b/frontend/app/components/ui/icons/sheet.svg new file mode 100644 index 0000000..65f1b0f --- /dev/null +++ b/frontend/app/components/ui/icons/sheet.svg @@ -0,0 +1,90 @@ + + + Sheets-icon + Created with Sketch. 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/app/components/ui/icons/txt.svg b/frontend/app/components/ui/icons/txt.svg new file mode 100644 index 0000000..0afb11b --- /dev/null +++ b/frontend/app/components/ui/icons/txt.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/app/components/ui/input.tsx b/frontend/app/components/ui/input.tsx new file mode 100644 index 0000000..edfa129 --- /dev/null +++ b/frontend/app/components/ui/input.tsx @@ -0,0 +1,25 @@ +import * as React from "react"; + +import { cn } from "./lib/utils"; + +export interface InputProps + extends React.InputHTMLAttributes {} + +const Input = React.forwardRef( + ({ className, type, ...props }, ref) => { + return ( + + ); + }, +); +Input.displayName = "Input"; + +export { Input }; diff --git a/frontend/app/components/ui/lib/utils.ts b/frontend/app/components/ui/lib/utils.ts new file mode 100644 index 0000000..a5ef193 --- /dev/null +++ b/frontend/app/components/ui/lib/utils.ts @@ -0,0 +1,6 @@ +import { clsx, type ClassValue } from "clsx"; +import { twMerge } from "tailwind-merge"; + +export function cn(...inputs: ClassValue[]) { + return twMerge(clsx(inputs)); +} diff --git a/frontend/app/components/ui/select.tsx b/frontend/app/components/ui/select.tsx new file mode 100644 index 0000000..c01b068 --- /dev/null +++ b/frontend/app/components/ui/select.tsx @@ -0,0 +1,159 @@ +"use client"; + +import * as SelectPrimitive from "@radix-ui/react-select"; +import { Check, ChevronDown, ChevronUp } from "lucide-react"; +import * as React from "react"; +import { cn } from "./lib/utils"; + +const Select = SelectPrimitive.Root; + +const SelectGroup = SelectPrimitive.Group; + +const SelectValue = SelectPrimitive.Value; + +const SelectTrigger = React.forwardRef< + React.ElementRef, + 
React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + span]:line-clamp-1", + className, + )} + {...props} + > + {children} + + + + +)); +SelectTrigger.displayName = SelectPrimitive.Trigger.displayName; + +const SelectScrollUpButton = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + +)); +SelectScrollUpButton.displayName = SelectPrimitive.ScrollUpButton.displayName; + +const SelectScrollDownButton = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + +)); +SelectScrollDownButton.displayName = + SelectPrimitive.ScrollDownButton.displayName; + +const SelectContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, position = "popper", ...props }, ref) => ( + + + + + {children} + + + + +)); +SelectContent.displayName = SelectPrimitive.Content.displayName; + +const SelectLabel = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)); +SelectLabel.displayName = SelectPrimitive.Label.displayName; + +const SelectItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + + + + + + + + {children} + +)); +SelectItem.displayName = SelectPrimitive.Item.displayName; + +const SelectSeparator = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)); +SelectSeparator.displayName = SelectPrimitive.Separator.displayName; + +export { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectLabel, + SelectScrollDownButton, + SelectScrollUpButton, + SelectSeparator, + SelectTrigger, + SelectValue, +}; diff --git a/frontend/app/components/ui/upload-image-preview.tsx b/frontend/app/components/ui/upload-image-preview.tsx new file mode 100644 index 0000000..55ef6e9 --- /dev/null +++ 
b/frontend/app/components/ui/upload-image-preview.tsx @@ -0,0 +1,32 @@ +import { XCircleIcon } from "lucide-react"; +import Image from "next/image"; +import { cn } from "./lib/utils"; + +export default function UploadImagePreview({ + url, + onRemove, +}: { + url: string; + onRemove: () => void; +}) { + return ( +
+ Uploaded image + +
+ ); +} diff --git a/frontend/app/favicon.ico b/frontend/app/favicon.ico new file mode 100644 index 0000000..a1eaef6 Binary files /dev/null and b/frontend/app/favicon.ico differ diff --git a/frontend/app/globals.css b/frontend/app/globals.css new file mode 100644 index 0000000..0c2b9bd --- /dev/null +++ b/frontend/app/globals.css @@ -0,0 +1,97 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +@layer base { + :root { + --background: 0 0% 100%; + --foreground: 222.2 47.4% 11.2%; + + --muted: 210 40% 96.1%; + --muted-foreground: 215.4 16.3% 46.9%; + + --popover: 0 0% 100%; + --popover-foreground: 222.2 47.4% 11.2%; + + --border: 214.3 31.8% 91.4%; + --input: 214.3 31.8% 91.4%; + + --card: 0 0% 100%; + --card-foreground: 222.2 47.4% 11.2%; + + --primary: 222.2 47.4% 11.2%; + --primary-foreground: 210 40% 98%; + + --secondary: 210 40% 96.1%; + --secondary-foreground: 222.2 47.4% 11.2%; + + --accent: 210 40% 96.1%; + --accent-foreground: 222.2 47.4% 11.2%; + + --destructive: 0 100% 50%; + --destructive-foreground: 210 40% 98%; + + --ring: 215 20.2% 65.1%; + + --radius: 0.5rem; + } + + .dark { + --background: 224 71% 4%; + --foreground: 213 31% 91%; + + --muted: 223 47% 11%; + --muted-foreground: 215.4 16.3% 56.9%; + + --accent: 216 34% 17%; + --accent-foreground: 210 40% 98%; + + --popover: 224 71% 4%; + --popover-foreground: 215 20.2% 65.1%; + + --border: 216 34% 17%; + --input: 216 34% 17%; + + --card: 224 71% 4%; + --card-foreground: 213 31% 91%; + + --primary: 210 40% 98%; + --primary-foreground: 222.2 47.4% 1.2%; + + --secondary: 222.2 47.4% 11.2%; + --secondary-foreground: 210 40% 98%; + + --destructive: 0 63% 31%; + --destructive-foreground: 210 40% 98%; + + --ring: 216 34% 17%; + + --radius: 0.5rem; + } +} + +@layer base { + * { + @apply border-border; + } + html { + @apply h-full; + } + body { + @apply bg-background text-foreground h-full; + font-feature-settings: + "rlig" 1, + "calt" 1; + } + .background-gradient { + background-color: #fff; + 
background-image: radial-gradient( + at 21% 11%, + rgba(186, 186, 233, 0.53) 0, + transparent 50% + ), + radial-gradient(at 85% 0, hsla(46, 57%, 78%, 0.52) 0, transparent 50%), + radial-gradient(at 91% 36%, rgba(194, 213, 255, 0.68) 0, transparent 50%), + radial-gradient(at 8% 40%, rgba(251, 218, 239, 0.46) 0, transparent 50%); + } +} diff --git a/frontend/app/layout.tsx b/frontend/app/layout.tsx new file mode 100644 index 0000000..8f7cab9 --- /dev/null +++ b/frontend/app/layout.tsx @@ -0,0 +1,23 @@ +import type { Metadata } from "next"; +import { Inter } from "next/font/google"; +import "./globals.css"; +import "./markdown.css"; + +const inter = Inter({ subsets: ["latin"] }); + +export const metadata: Metadata = { + title: "Create Llama App", + description: "Generated by create-llama", +}; + +export default function RootLayout({ + children, +}: { + children: React.ReactNode; +}) { + return ( + + {children} + + ); +} diff --git a/frontend/app/markdown.css b/frontend/app/markdown.css new file mode 100644 index 0000000..a843eeb --- /dev/null +++ b/frontend/app/markdown.css @@ -0,0 +1,23 @@ +/* Custom CSS for chat message markdown */ +.custom-markdown ul { + list-style-type: disc; + margin-left: 20px; +} + +.custom-markdown ol { + list-style-type: decimal; + margin-left: 20px; +} + +.custom-markdown li { + margin-bottom: 5px; +} + +.custom-markdown ol ol { + list-style: lower-alpha; +} + +.custom-markdown ul ul, +.custom-markdown ol ol { + margin-left: 20px; +} diff --git a/frontend/app/observability/index.ts b/frontend/app/observability/index.ts new file mode 100644 index 0000000..2e4ce2b --- /dev/null +++ b/frontend/app/observability/index.ts @@ -0,0 +1 @@ +export const initObservability = () => {}; diff --git a/frontend/app/page.tsx b/frontend/app/page.tsx new file mode 100644 index 0000000..04d4302 --- /dev/null +++ b/frontend/app/page.tsx @@ -0,0 +1,15 @@ +import Header from "@/app/components/header"; +import ChatSection from "./components/chat-section"; + 
+export default function Home() { + return ( +
+
+
+
+ +
+
+
+ ); +} diff --git a/frontend/config/tools.json b/frontend/config/tools.json new file mode 100644 index 0000000..298e248 --- /dev/null +++ b/frontend/config/tools.json @@ -0,0 +1,7 @@ +{ + "local": { + "weather": {}, + "interpreter": {} + }, + "llamahub": {} +} \ No newline at end of file diff --git a/frontend/next.config.json b/frontend/next.config.json new file mode 100644 index 0000000..018bd38 --- /dev/null +++ b/frontend/next.config.json @@ -0,0 +1,16 @@ +{ + "experimental": { + "outputFileTracingIncludes": { + "/*": [ + "./cache/**/*" + ], + "/api/**/*": [ + "./node_modules/**/*.wasm" + ] + } + }, + "output": "export", + "images": { + "unoptimized": true + } +} diff --git a/frontend/next.config.mjs b/frontend/next.config.mjs new file mode 100644 index 0000000..64bdff2 --- /dev/null +++ b/frontend/next.config.mjs @@ -0,0 +1,10 @@ +/** @type {import('next').NextConfig} */ +import fs from "fs"; +import withLlamaIndex from "llamaindex/next"; +import webpack from "./webpack.config.mjs"; + +const nextConfig = JSON.parse(fs.readFileSync("./next.config.json", "utf-8")); +nextConfig.webpack = webpack; + +// use withLlamaIndex to add necessary modifications for llamaindex library +export default withLlamaIndex(nextConfig); diff --git a/frontend/package.json b/frontend/package.json new file mode 100644 index 0000000..e3a26e4 --- /dev/null +++ b/frontend/package.json @@ -0,0 +1,65 @@ +{ + "name": "testapp", + "version": "0.1.0", + "scripts": { + "format": "prettier --ignore-unknown --cache --check .", + "format:write": "prettier --ignore-unknown --write .", + "dev": "next dev", + "build": "next build", + "start": "next start", + "lint": "next lint", + "generate": "tsx app\\api\\chat\\engine\\generate.ts" + }, + "dependencies": { + "@apidevtools/swagger-parser": "^10.1.0", + "@e2b/code-interpreter": "^0.0.5", + "@llamaindex/pdf-viewer": "^1.1.3", + "@radix-ui/react-collapsible": "^1.0.3", + "@radix-ui/react-hover-card": "^1.0.7", + "@radix-ui/react-select": "^2.1.1", + 
"@radix-ui/react-slot": "^1.0.2", + "ai": "^3.0.21", + "ajv": "^8.12.0", + "class-variance-authority": "^0.7.0", + "clsx": "^2.1.1", + "dotenv": "^16.3.1", + "duck-duck-scrape": "^2.2.5", + "formdata-node": "^6.0.3", + "got": "^14.4.1", + "llamaindex": "0.5.12", + "lucide-react": "^0.294.0", + "next": "^14.2.4", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "react-markdown": "^8.0.7", + "react-syntax-highlighter": "^15.5.0", + "rehype-katex": "^7.0.0", + "remark": "^14.0.3", + "remark-code-import": "^1.2.0", + "remark-gfm": "^3.0.1", + "remark-math": "^5.1.1", + "supports-color": "^8.1.1", + "tailwind-merge": "^2.1.0", + "tiktoken": "^1.0.15", + "uuid": "^9.0.1", + "vaul": "^0.9.1" + }, + "devDependencies": { + "@types/node": "^20.10.3", + "@types/react": "^18.2.42", + "@types/react-dom": "^18.2.17", + "@types/react-syntax-highlighter": "^15.5.11", + "@types/uuid": "^9.0.8", + "autoprefixer": "^10.4.16", + "cross-env": "^7.0.3", + "eslint": "^8.55.0", + "eslint-config-next": "^14.2.4", + "eslint-config-prettier": "^8.10.0", + "postcss": "^8.4.32", + "prettier": "^3.2.5", + "prettier-plugin-organize-imports": "^3.2.4", + "tailwindcss": "^3.3.6", + "tsx": "^4.7.2", + "typescript": "^5.3.2" + } +} diff --git a/frontend/postcss.config.js b/frontend/postcss.config.js new file mode 100644 index 0000000..12a703d --- /dev/null +++ b/frontend/postcss.config.js @@ -0,0 +1,6 @@ +module.exports = { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +}; diff --git a/frontend/prettier.config.js b/frontend/prettier.config.js new file mode 100644 index 0000000..1fe03c6 --- /dev/null +++ b/frontend/prettier.config.js @@ -0,0 +1,3 @@ +module.exports = { + plugins: ["prettier-plugin-organize-imports"], +}; diff --git a/frontend/public/llama.png b/frontend/public/llama.png new file mode 100644 index 0000000..d4efba3 Binary files /dev/null and b/frontend/public/llama.png differ diff --git a/frontend/tailwind.config.ts b/frontend/tailwind.config.ts new file mode 100644 index 
0000000..aa5580a --- /dev/null +++ b/frontend/tailwind.config.ts @@ -0,0 +1,78 @@ +import type { Config } from "tailwindcss"; +import { fontFamily } from "tailwindcss/defaultTheme"; + +const config: Config = { + darkMode: ["class"], + content: ["app/**/*.{ts,tsx}", "components/**/*.{ts,tsx}"], + theme: { + container: { + center: true, + padding: "2rem", + screens: { + "2xl": "1400px", + }, + }, + extend: { + colors: { + border: "hsl(var(--border))", + input: "hsl(var(--input))", + ring: "hsl(var(--ring))", + background: "hsl(var(--background))", + foreground: "hsl(var(--foreground))", + primary: { + DEFAULT: "hsl(var(--primary))", + foreground: "hsl(var(--primary-foreground))", + }, + secondary: { + DEFAULT: "hsl(var(--secondary))", + foreground: "hsl(var(--secondary-foreground))", + }, + destructive: { + DEFAULT: "hsl(var(--destructive) / <alpha-value>)", + foreground: "hsl(var(--destructive-foreground) / <alpha-value>)", + }, + muted: { + DEFAULT: "hsl(var(--muted))", + foreground: "hsl(var(--muted-foreground))", + }, + accent: { + DEFAULT: "hsl(var(--accent))", + foreground: "hsl(var(--accent-foreground))", + }, + popover: { + DEFAULT: "hsl(var(--popover))", + foreground: "hsl(var(--popover-foreground))", + }, + card: { + DEFAULT: "hsl(var(--card))", + foreground: "hsl(var(--card-foreground))", + }, + }, + borderRadius: { + xl: `calc(var(--radius) + 4px)`, + lg: `var(--radius)`, + md: `calc(var(--radius) - 2px)`, + sm: "calc(var(--radius) - 4px)", + }, + fontFamily: { + sans: ["var(--font-sans)", ...fontFamily.sans], + }, + keyframes: { + "accordion-down": { + from: { height: "0" }, + to: { height: "var(--radix-accordion-content-height)" }, + }, + "accordion-up": { + from: { height: "var(--radix-accordion-content-height)" }, + to: { height: "0" }, + }, + }, + animation: { + "accordion-down": "accordion-down 0.2s ease-out", + "accordion-up": "accordion-up 0.2s ease-out", + }, + }, + }, + plugins: [], +}; +export default config; diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json
new file mode 100644
index 0000000..e7ff90f
--- /dev/null
+++ b/frontend/tsconfig.json
@@ -0,0 +1,26 @@
+{
+  "compilerOptions": {
+    "lib": ["dom", "dom.iterable", "esnext"],
+    "allowJs": true,
+    "skipLibCheck": true,
+    "strict": true,
+    "noEmit": true,
+    "esModuleInterop": true,
+    "module": "esnext",
+    "moduleResolution": "bundler",
+    "resolveJsonModule": true,
+    "isolatedModules": true,
+    "jsx": "preserve",
+    "incremental": true,
+    "plugins": [
+      {
+        "name": "next"
+      }
+    ],
+    "paths": {
+      "@/*": ["./*"]
+    }
+  },
+  "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
+  "exclude": ["node_modules"]
+}
diff --git a/frontend/webpack.config.mjs b/frontend/webpack.config.mjs
new file mode 100644
index 0000000..29decaf
--- /dev/null
+++ b/frontend/webpack.config.mjs
@@ -0,0 +1,16 @@
+// webpack config must be a function in NextJS that is used to patch the default webpack config provided by NextJS, see https://nextjs.org/docs/pages/api-reference/next-config-js/webpack
+/**
+ * Patches the webpack config handed in by Next.js and returns it.
+ *
+ * @param {object} config - webpack configuration object; it is mutated in place.
+ * @returns {object} The same config object, with `resolve.fallback.aws4` set to `false`.
+ */
+export default function webpack(config) {
+  config.resolve.fallback = {
+    // Keep fallbacks that are already present on the incoming config instead of replacing them wholesale.
+    ...config.resolve.fallback,
+    aws4: false,
+  };
+
+  return config;
+}