-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(document-search): Support for ingesting images (#172)
- Loading branch information
1 parent
a2e9abc
commit 6e46d4a
Showing
10 changed files
with
350 additions
and
8 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
""" | ||
Ragbits Document Search Example: Multimodal Embeddings | ||
This example demonstrates how to use the `DocumentSearch` to index and search for images and text documents. | ||
It employs the "multimodalembedding" model from VertexAI. In order to use it, make sure that you are | ||
logged in to Google Cloud (using the `gcloud auth login` command) and that you have the necessary permissions. | ||
The script performs the following steps: | ||
1. Create a list of example documents. | ||
2. Initialize the `VertexAIMultimodelEmbeddings` class (which uses the VertexAI multimodal embeddings). | ||
3. Initialize the `InMemoryVectorStore` class, which stores the embeddings for the duration of the script. | ||
4. Initialize the `DocumentSearch` class with the embedder and the vector store. | ||
5. Ingest the documents into the `DocumentSearch` instance. | ||
6. List all embeddings in the vector store. | ||
7. Search for documents using a query. | ||
8. Print the search results. | ||
To run the script, execute the following command: | ||
```bash | ||
uv run python examples/document-search/multimodal.py | ||
``` | ||
""" | ||
|
||
# /// script | ||
# requires-python = ">=3.10" | ||
# dependencies = [ | ||
# "ragbits-document-search", | ||
# "ragbits-core[litellm]", | ||
# ] | ||
# /// | ||
import asyncio | ||
from pathlib import Path | ||
|
||
from ragbits.core.embeddings.vertex_multimodal import VertexAIMultimodelEmbeddings | ||
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore | ||
from ragbits.document_search import DocumentSearch | ||
from ragbits.document_search.documents.document import DocumentMeta, DocumentType | ||
from ragbits.document_search.documents.sources import LocalFileSource | ||
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter | ||
from ragbits.document_search.ingestion.providers.dummy import DummyImageProvider | ||
|
||
IMAGES_PATH = Path(__file__).parent / "images" | ||
|
||
|
||
def jpg_example(file_name: str) -> DocumentMeta:
    """
    Build the document metadata for one JPG file from the images directory.

    Args:
        file_name: Name of the file, relative to `IMAGES_PATH`.

    Returns:
        Metadata of a JPG document whose source is the local image file.
    """
    image_source = LocalFileSource(path=IMAGES_PATH / file_name)
    return DocumentMeta(document_type=DocumentType.JPG, source=image_source)
|
||
|
||
# Example corpus: three local JPG images followed by two short text documents.
documents = [
    *(jpg_example(image_name) for image_name in ("bear.jpg", "game.jpg", "tree.jpg")),
    DocumentMeta.create_text_document_from_literal("A beautiful teady bear."),
    DocumentMeta.create_text_document_from_literal("The constitution of the United States."),
]
|
||
|
||
async def main() -> None:
    """
    Run the example.

    Ingests the example documents, prints the document metadata stored with every
    embedding in the vector store, and finally prints the results of a sample search.
    """
    embedder = VertexAIMultimodelEmbeddings()
    vector_store = InMemoryVectorStore()
    router = DocumentProcessorRouter.from_config(
        {
            # For this example, we want to skip OCR and make sure
            # that we test direct image embeddings.
            DocumentType.JPG: DummyImageProvider(),
        }
    )

    document_search = DocumentSearch(
        embedder=embedder,
        vector_store=vector_store,
        document_processor_router=router,
    )

    await document_search.ingest(documents)

    all_embeddings = await vector_store.list()
    for embedding in all_embeddings:
        print(f"Embedding: {embedding.metadata['document_meta']}")
    print()

    # Keep the query in a single place so the printed header always matches
    # the text that was actually searched for.
    query = "Fluffy teady bear"
    results = await document_search.search(query)
    print(f"Results for '{query}':")
    for result in results:
        document = await result.document_meta.fetch()
        print(f"Type: {result.element_type}, Location: {document.local_path}, Text: {result.get_text_representation()}")


# Run the example only when the file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
170 changes: 170 additions & 0 deletions
170
packages/ragbits-core/src/ragbits/core/embeddings/vertex_multimodal.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
import asyncio | ||
import base64 | ||
|
||
try: | ||
import litellm | ||
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import VertexAIError | ||
from litellm.main import VertexMultimodalEmbedding | ||
|
||
HAS_LITELLM = True | ||
except ImportError: | ||
HAS_LITELLM = False | ||
|
||
from ragbits.core.audit import trace | ||
from ragbits.core.embeddings import Embeddings | ||
from ragbits.core.embeddings.exceptions import ( | ||
EmbeddingResponseError, | ||
EmbeddingStatusError, | ||
) | ||
|
||
|
||
class VertexAIMultimodelEmbeddings(Embeddings):
    """
    Client for creating text and image embeddings using VertexAI multimodal models through the LiteLLM API.
    """

    # Provider prefix that LiteLLM expects in front of the VertexAI model name.
    VERTEX_AI_PREFIX = "vertex_ai/"

    def __init__(
        self,
        model: str = "multimodalembedding",
        api_base: str | None = None,
        api_key: str | None = None,
        concurency: int = 10,
        options: dict | None = None,
    ) -> None:
        """
        Constructs the embedding client for multimodal VertexAI models.

        Args:
            model: One of the VertexAI multimodal models to be used. Default is "multimodalembedding".
                The name may be given with or without the "vertex_ai/" prefix.
            api_base: The API endpoint you want to call the model with.
            api_key: API key to be used. If not specified, an environment variable will be used.
            concurency: The number of concurrent requests to make to the API.
            options: Additional options to pass to the API.

        Raises:
            ImportError: If the 'litellm' extra requirements are not installed.
            ValueError: If the chosen model is not supported by VertexAI multimodal embeddings.
        """
        if not HAS_LITELLM:
            raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM embeddings models")

        super().__init__()

        # Store the bare model name; the prefix is added back when LiteLLM is called.
        model = model.removeprefix(self.VERTEX_AI_PREFIX)

        # Validate before storing any state so a rejected model leaves nothing half-configured.
        supported_models = VertexMultimodalEmbedding().SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
        if model not in supported_models:
            raise ValueError(f"Model {model} is not supported by VertexAI multimodal embeddings")

        self.model = model
        self.api_base = api_base
        self.api_key = api_key
        self.concurency = concurency
        self.options = options or {}

    async def _embed(self, data: list[dict]) -> list[dict]:
        """
        Creates embeddings for the given data. The format is defined in the VertexAI API:
        https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-multimodal-embeddings

        Args:
            data: List of instances in the format expected by the VertexAI API.

        Returns:
            List of embeddings for the given VertexAI instances, each instance is a dictionary
            in the format returned by the VertexAI API.

        Raises:
            EmbeddingStatusError: If the embedding API returns an error status code.
            EmbeddingResponseError: If the embedding API response is invalid.
        """
        with trace(
            data=data,
            model=self.model,
            api_base=self.api_base,
            options=self.options,
        ) as outputs:
            # One request is sent per instance; the semaphore caps how many run at once.
            semaphore = asyncio.Semaphore(self.concurency)
            try:
                responses = await asyncio.gather(
                    *(self._call_litellm(instance, semaphore) for instance in data),
                )
            except VertexAIError as exc:
                raise EmbeddingStatusError(exc.message, exc.status_code) from exc

            outputs.embeddings = []
            for i, response in enumerate(responses):
                # A missing or empty `data` field means no embedding came back for this instance.
                if not response.data:
                    raise EmbeddingResponseError(f"No embeddings returned for instance {i}")
                outputs.embeddings.append(response.data[0])

            return outputs.embeddings

    # The return annotation is a string so that the class can still be defined when
    # litellm is not installed; the missing dependency is then reported by __init__.
    async def _call_litellm(self, instance: dict, semaphore: asyncio.Semaphore) -> "litellm.EmbeddingResponse":
        """
        Calls the LiteLLM API to get embeddings for a single instance.

        Args:
            instance: Single VertexAI instance to get embeddings for.
            semaphore: Semaphore to limit the number of concurrent requests.

        Returns:
            The LiteLLM embedding response for the given instance.
        """
        async with semaphore:
            return await litellm.aembedding(
                input=[instance],
                model=f"{self.VERTEX_AI_PREFIX}{self.model}",
                api_base=self.api_base,
                api_key=self.api_key,
                **self.options,
            )

    async def embed_text(self, data: list[str]) -> list[list[float]]:
        """
        Creates embeddings for the given strings.

        Args:
            data: List of strings to get embeddings for.

        Returns:
            List of embeddings for the given strings.

        Raises:
            EmbeddingStatusError: If the embedding API returns an error status code.
            EmbeddingResponseError: If the embedding API response is invalid.
        """
        response = await self._embed([{"text": text} for text in data])
        return [embedding["textEmbedding"] for embedding in response]

    def image_support(self) -> bool:  # noqa: PLR6301
        """
        Check if the model supports image embeddings.

        Returns:
            True if the model supports image embeddings, False otherwise.
        """
        return True

    async def embed_image(self, images: list[bytes]) -> list[list[float]]:
        """
        Creates embeddings for the given images.

        Args:
            images: List of images to get embeddings for.

        Returns:
            List of embeddings for the given images.

        Raises:
            EmbeddingStatusError: If the embedding API returns an error status code.
            EmbeddingResponseError: If the embedding API response is invalid.
        """
        # Each image is sent as a base64-encoded string under "bytesBase64Encoded".
        instances = [{"image": {"bytesBase64Encoded": base64.b64encode(image).decode()}} for image in images]
        response = await self._embed(instances)

        return [embedding["imageEmbedding"] for embedding in response]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.