feat(document-search): Support for ingesting images (#172)

ludwiktrammer authored Nov 13, 2024
1 parent a2e9abc commit 6e46d4a

Showing 10 changed files with 350 additions and 8 deletions.
Binary file added examples/document-search/images/bear.jpg
Binary file added examples/document-search/images/game.jpg
Binary file added examples/document-search/images/tree.jpg
98 changes: 98 additions & 0 deletions examples/document-search/multimodal.py
@@ -0,0 +1,98 @@
"""
Ragbits Document Search Example: Multimodal Embeddings
This example demonstrates how to use the `DocumentSearch` to index and search for images and text documents.
It employes the "multimodalembedding" from VertexAI. In order to use it, make sure that you are
logged in to Google Cloud (using the `gcloud auth login` command) and that you have the necessary permissions.
The script performs the following steps:
1. Create a list of example documents.
2. Initialize the `VertexAIMultimodelEmbeddings` class (which uses the VertexAI multimodal embeddings).
3. Initialize the `InMemoryVectorStore` class, which stores the embeddings for the duration of the script.
4. Initialize the `DocumentSearch` class with the embedder and the vector store.
5. Ingest the documents into the `DocumentSearch` instance.
6. List all embeddings in the vector store.
7. Search for documents using a query.
8. Print the search results.
To run the script, execute the following command:
```bash
uv run python examples/document-search/multimodal.py
```
"""

# /// script
# requires-python = ">=3.10"
# dependencies = [
# "ragbits-document-search",
# "ragbits-core[litellm]",
# ]
# ///
import asyncio
from pathlib import Path

from ragbits.core.embeddings.vertex_multimodal import VertexAIMultimodelEmbeddings
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.sources import LocalFileSource
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.dummy import DummyImageProvider

IMAGES_PATH = Path(__file__).parent / "images"


def jpg_example(file_name: str) -> DocumentMeta:
    """
    Create a document from a JPG file in the images directory.
    """
    return DocumentMeta(document_type=DocumentType.JPG, source=LocalFileSource(path=IMAGES_PATH / file_name))


documents = [
    jpg_example("bear.jpg"),
    jpg_example("game.jpg"),
    jpg_example("tree.jpg"),
    DocumentMeta.create_text_document_from_literal("A beautiful teddy bear."),
    DocumentMeta.create_text_document_from_literal("The constitution of the United States."),
]


async def main() -> None:
    """
    Run the example.
    """
    embedder = VertexAIMultimodelEmbeddings()
    vector_store = InMemoryVectorStore()
    router = DocumentProcessorRouter.from_config(
        {
            # For this example, we want to skip OCR and make sure
            # that we test direct image embeddings.
            DocumentType.JPG: DummyImageProvider(),
        }
    )

    document_search = DocumentSearch(
        embedder=embedder,
        vector_store=vector_store,
        document_processor_router=router,
    )

    await document_search.ingest(documents)

    all_embeddings = await vector_store.list()
    for embedding in all_embeddings:
        print(f"Embedding: {embedding.metadata['document_meta']}")
    print()

    results = await document_search.search("Fluffy teddy bear toy")
    print("Results for 'Fluffy teddy bear toy':")
    for result in results:
        document = await result.document_meta.fetch()
        print(f"Type: {result.element_type}, Location: {document.local_path}, Text: {result.get_text_representation()}")


if __name__ == "__main__":
    asyncio.run(main())
21 changes: 21 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/base.py
@@ -17,3 +17,24 @@ async def embed_text(self, data: list[str]) -> list[list[float]]:
        Returns:
            List of embeddings for the given strings.
        """

    def image_support(self) -> bool:  # noqa: PLR6301
        """
        Check if the model supports image embeddings.

        Returns:
            True if the model supports image embeddings, False otherwise.
        """
        return False

    async def embed_image(self, images: list[bytes]) -> list[list[float]]:
        """
        Creates embeddings for the given images.

        Args:
            images: List of images to get embeddings for.

        Returns:
            List of embeddings for the given images.
        """
        raise NotImplementedError("Image embeddings are not supported by this model.")
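Taken together, these two methods define an opt-in contract for image-capable embedders: `image_support()` defaults to `False` and `embed_image()` raises `NotImplementedError`, so existing text-only embedders keep working unchanged. A minimal, purely illustrative sketch of a subclass that opts in (the `FakeImageEmbeddings` class and its hash-derived vectors are not part of this commit):

```python
import hashlib

from ragbits.core.embeddings import Embeddings


class FakeImageEmbeddings(Embeddings):
    """Illustrative embedder that opts in to the image API with deterministic toy vectors."""

    @staticmethod
    def _vector(payload: bytes) -> list[float]:
        # Derive a small deterministic vector from a hash of the payload.
        return [byte / 255 for byte in hashlib.sha256(payload).digest()[:8]]

    async def embed_text(self, data: list[str]) -> list[list[float]]:
        return [self._vector(text.encode()) for text in data]

    def image_support(self) -> bool:
        # Opting in: callers may now route image bytes to embed_image().
        return True

    async def embed_image(self, images: list[bytes]) -> list[list[float]]:
        return [self._vector(image) for image in images]
```

As the `DocumentSearch` change later in this commit shows, `embed_image()` is only called after `image_support()` returns `True`, so an embedder that keeps both defaults never receives image payloads.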
170 changes: 170 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/vertex_multimodal.py
@@ -0,0 +1,170 @@
import asyncio
import base64

try:
    import litellm
    from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import VertexAIError
    from litellm.main import VertexMultimodalEmbedding

    HAS_LITELLM = True
except ImportError:
    HAS_LITELLM = False

from ragbits.core.audit import trace
from ragbits.core.embeddings import Embeddings
from ragbits.core.embeddings.exceptions import (
    EmbeddingResponseError,
    EmbeddingStatusError,
)


class VertexAIMultimodelEmbeddings(Embeddings):
    """
    Client for creating text and image embeddings using VertexAI multimodal models via the LiteLLM API.
    """

    VERTEX_AI_PREFIX = "vertex_ai/"

    def __init__(
        self,
        model: str = "multimodalembedding",
        api_base: str | None = None,
        api_key: str | None = None,
        concurency: int = 10,
        options: dict | None = None,
    ) -> None:
        """
        Constructs the embedding client for multimodal VertexAI models.

        Args:
            model: One of the VertexAI multimodal models to be used. Default is "multimodalembedding".
            api_base: The API endpoint you want to call the model with.
            api_key: API key to be used. If not specified, an environment variable will be used.
            concurency: The number of concurrent requests to make to the API.
            options: Additional options to pass to the API.

        Raises:
            ImportError: If the 'litellm' extra requirements are not installed.
            ValueError: If the chosen model is not supported by VertexAI multimodal embeddings.
        """
        if not HAS_LITELLM:
            raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM embeddings models")

        super().__init__()
        if model.startswith(self.VERTEX_AI_PREFIX):
            model = model[len(self.VERTEX_AI_PREFIX) :]

        self.model = model
        self.api_base = api_base
        self.api_key = api_key
        self.concurency = concurency
        self.options = options or {}

        supported_models = VertexMultimodalEmbedding().SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
        if model not in supported_models:
            raise ValueError(f"Model {model} is not supported by VertexAI multimodal embeddings")

    async def _embed(self, data: list[dict]) -> list[dict]:
        """
        Creates embeddings for the given data. The format is defined in the VertexAI API:
        https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-multimodal-embeddings

        Args:
            data: List of instances in the format expected by the VertexAI API.

        Returns:
            List of embeddings for the given VertexAI instances; each embedding is a dictionary
            in the format returned by the VertexAI API.

        Raises:
            EmbeddingStatusError: If the embedding API returns an error status code.
            EmbeddingResponseError: If the embedding API response is invalid.
        """
        with trace(
            data=data,
            model=self.model,
            api_base=self.api_base,
            options=self.options,
        ) as outputs:
            semaphore = asyncio.Semaphore(self.concurency)
            try:
                response = await asyncio.gather(
                    *[self._call_litellm(instance, semaphore) for instance in data],
                )
            except VertexAIError as exc:
                raise EmbeddingStatusError(exc.message, exc.status_code) from exc

            outputs.embeddings = []
            for i, embedding in enumerate(response):
                if embedding.data is None or not embedding.data:
                    raise EmbeddingResponseError(f"No embeddings returned for instance {i}")
                outputs.embeddings.append(embedding.data[0])

            return outputs.embeddings

    async def _call_litellm(self, instance: dict, semaphore: asyncio.Semaphore) -> litellm.EmbeddingResponse:
        """
        Calls the LiteLLM API to get embeddings for the given data.

        Args:
            instance: Single VertexAI instance to get embeddings for.
            semaphore: Semaphore to limit the number of concurrent requests.

        Returns:
            The embedding response for the given instance.
        """
        async with semaphore:
            response = await litellm.aembedding(
                input=[instance],
                model=f"{self.VERTEX_AI_PREFIX}{self.model}",
                api_base=self.api_base,
                api_key=self.api_key,
                **self.options,
            )

        return response

    async def embed_text(self, data: list[str]) -> list[list[float]]:
        """
        Creates embeddings for the given strings.

        Args:
            data: List of strings to get embeddings for.

        Returns:
            List of embeddings for the given strings.

        Raises:
            EmbeddingStatusError: If the embedding API returns an error status code.
            EmbeddingResponseError: If the embedding API response is invalid.
        """
        response = await self._embed([{"text": text} for text in data])
        return [embedding["textEmbedding"] for embedding in response]

    def image_support(self) -> bool:  # noqa: PLR6301
        """
        Check if the model supports image embeddings.

        Returns:
            True if the model supports image embeddings, False otherwise.
        """
        return True

    async def embed_image(self, images: list[bytes]) -> list[list[float]]:
        """
        Creates embeddings for the given images.

        Args:
            images: List of images to get embeddings for.

        Returns:
            List of embeddings for the given images.

        Raises:
            EmbeddingStatusError: If the embedding API returns an error status code.
            EmbeddingResponseError: If the embedding API response is invalid.
        """
        images_b64 = (base64.b64encode(image).decode() for image in images)
        response = await self._embed([{"image": {"bytesBase64Encoded": image}} for image in images_b64])

        return [embedding["imageEmbedding"] for embedding in response]
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Sequence
from typing import Any

@@ -8,7 +9,7 @@
from ragbits.core.vector_stores import VectorStore, get_vector_store
from ragbits.core.vector_stores.base import VectorStoreOptions
from ragbits.document_search.documents.document import Document, DocumentMeta
from ragbits.document_search.documents.element import Element
from ragbits.document_search.documents.element import Element, ImageElement
from ragbits.document_search.documents.sources import Source
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.base import BaseProvider
@@ -170,5 +171,22 @@ async def insert_elements(self, elements: list[Element]) -> None:
            elements: The list of Elements to insert.
        """
        vectors = await self.embedder.embed_text([element.get_text_for_embedding() for element in elements])

        image_elements = [element for element in elements if isinstance(element, ImageElement)]
        entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors, strict=False)]

        if image_elements and self.embedder.image_support():
            image_vectors = await self.embedder.embed_image([element.image_bytes for element in image_elements])
            entries.extend(
                [
                    element.to_vector_db_entry(vector)
                    for element, vector in zip(image_elements, image_vectors, strict=False)
                ]
            )
        elif image_elements:
            warnings.warn(
                f"Image elements are not supported by the embedder {self.embedder}. "
                f"Skipping {len(image_elements)} image elements."
            )

        await self.vector_store.store(entries)
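The guard introduced above — call `embed_image()` only after `image_support()` returns `True`, and warn instead of failing otherwise — is a pattern any caller of the `Embeddings` base class can reuse. A small, hedged sketch of it (the helper name and warning text are illustrative, not part of the commit):

```python
import warnings

from ragbits.core.embeddings import Embeddings


async def embed_images_if_supported(embedder: Embeddings, images: list[bytes]) -> list[list[float]]:
    """Mirror of the insert_elements check: image vectors are optional extras."""
    if not images:
        return []
    if embedder.image_support():
        return await embedder.embed_image(images)
    # Text-only embedders degrade gracefully instead of raising.
    warnings.warn(f"Image embeddings are not supported by {embedder}; skipping {len(images)} images.")
    return []
```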
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@ def id(self) -> str:
"""
id_components = [
self.document_meta.id,
self.element_type,
self.get_text_for_embedding(),
self.get_text_representation(),
str(self.location),
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
    DocumentType,
    TextDocument,
)
from ragbits.document_search.documents.element import Element, TextElement
from ragbits.document_search.documents.element import Element, ImageElement, TextElement
from ragbits.document_search.ingestion.providers.base import BaseProvider


@@ -31,3 +31,37 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]:
        if isinstance(document, TextDocument):
            return [TextElement(content=document.content, document_meta=document_meta)]
        return []


class DummyImageProvider(BaseProvider):
    """
    This is a simple provider that returns an ImageElement with the content of the image
    and empty text metadata.
    """

    SUPPORTED_DOCUMENT_TYPES = {DocumentType.JPG, DocumentType.PNG}

    async def process(self, document_meta: DocumentMeta) -> list[Element]:
        """
        Process the image document.

        Args:
            document_meta: The document to process.

        Returns:
            List with a single ImageElement containing the content of the image.
        """
        self.validate_document_type(document_meta.document_type)

        document = await document_meta.fetch()
        image_path = document.local_path
        with open(image_path, "rb") as f:
            image_bytes = f.read()
        return [
            ImageElement(
                description="",
                ocr_extracted_text="",
                image_bytes=image_bytes,
                document_meta=document_meta,
            )
        ]
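For completeness, a hedged sketch of calling the new provider directly, outside of a `DocumentProcessorRouter` (it reuses `bear.jpg` added in this commit and the same `DocumentMeta` construction as the example script above):

```python
import asyncio
from pathlib import Path

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.sources import LocalFileSource
from ragbits.document_search.ingestion.providers.dummy import DummyImageProvider


async def main() -> None:
    meta = DocumentMeta(
        document_type=DocumentType.JPG,
        source=LocalFileSource(path=Path("examples/document-search/images/bear.jpg")),
    )
    elements = await DummyImageProvider().process(meta)
    # A single ImageElement carrying the raw JPEG bytes and empty text metadata.
    print(type(elements[0]).__name__, len(elements[0].image_bytes))


asyncio.run(main())
```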