Skip to content

Commit

Permalink
fix: add media_type field to Mistral image format
Browse files Browse the repository at this point in the history
  • Loading branch information
devin-ai-integration[bot] and jxnl committed Dec 16, 2024
1 parent 48a17fd commit c4fd5e6
Showing 1 changed file with 85 additions and 73 deletions.
158 changes: 85 additions & 73 deletions instructor/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@
import mimetypes
import re
from collections.abc import Mapping
from functools import lru_cache, cache
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Literal, Optional, TypeVar, TypedDict, ClassVar, Union
from typing import (
Any, Callable, Final, Literal, Optional,
TypeVar, TypedDict, Union
)
from urllib.parse import urlparse

import requests
from pydantic import BaseModel, Field

from .mode import Mode

ImgT = TypeVar('ImgT', bound='Image')

# Constants for Mistral image validation
VALID_MISTRAL_MIME_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}
MAX_MISTRAL_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB in bytes
Expand All @@ -41,22 +46,22 @@ class ImageParams(ImageParamsBase, total=False):


class Image(BaseModel):
VALID_MIME_TYPES: ClassVar[list[str]] = [
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
]
source: Union[str, Path] = Field(
description="URL, file path, or base64 data of the image"
)
"""Represents an image that can be loaded from a URL or file path."""

VALID_MIME_TYPES: Final[frozenset[str]] = frozenset({
"image/jpeg", "image/png", "image/gif", "image/webp"
})
VALID_MISTRAL_MIME_TYPES: Final[frozenset[str]] = frozenset({
"image/jpeg", "image/png", "image/gif", "image/webp"
})

source: Union[str, Path] = Field(description="URL or file path of the image")
media_type: str = Field(description="MIME type of the image")
data: Union[str, None] = Field(
data: Optional[str] = Field(
None, description="Base64 encoded image data", repr=False
)

@classmethod
def autodetect(cls, source: Union[str, Path]) -> Union[Image, None]:
def autodetect(cls: type[ImgT], source: Union[str, Path]) -> Optional[ImgT]:
"""Attempt to autodetect an image from a source string or Path.
Args:
Expand All @@ -71,28 +76,33 @@ def autodetect(cls, source: Union[str, Path]) -> Union[Image, None]:
try:
if isinstance(source, str):
if cls.is_base64(source):
return cls.from_base64(source)
result = cls.from_base64(source)
return result if isinstance(result, cls) else None
elif urlparse(source).scheme in {"http", "https"}:
return cls.from_url(source)
result = cls.from_url(source)
return result if isinstance(result, cls) else None
elif Path(source).is_file():
return cls.from_path(source)
result = cls.from_path(source)
return result if isinstance(result, cls) else None
else:
return cls.from_raw_base64(source)
result = cls.from_raw_base64(source)
return result if isinstance(result, cls) else None
elif isinstance(source, Path):
return cls.from_path(source)
result = cls.from_path(source)
return result if isinstance(result, cls) else None
return None
except Exception:
return None

@classmethod
def autodetect_safely(cls, source: Union[str, Path]) -> Union[Image, str]:
def autodetect_safely(cls: type[ImgT], source: Union[str, Path]) -> Union[str, ImgT]:
"""Safely attempt to autodetect an image from a source string or path.
Args:
source: URL, file path, or base64 data
Returns:
Union[Image, str]: An Image instance or the original string if not an image
Union[str, Image]: An Image instance or the original string if not an image
"""
try:
result = cls.autodetect(source)
Expand All @@ -101,11 +111,11 @@ def autodetect_safely(cls, source: Union[str, Path]) -> Union[Image, str]:
return str(source)

@classmethod
def is_base64(cls, s: str) -> bool:
def is_base64(cls: type[ImgT], s: str) -> bool:
return bool(re.match(r"^data:image/[a-zA-Z]+;base64,", s))

@classmethod
def from_base64(cls, data: str) -> Image:
def from_base64(cls: type[ImgT], data: str) -> ImgT:
"""Create an Image instance from base64 data."""
if not cls.is_base64(data):
raise ValueError("Invalid base64 data")
Expand All @@ -127,8 +137,8 @@ def from_base64(cls, data: str) -> Image:
raise ValueError(f"Unsupported image format: {media_type}")
return cls(source=data, media_type=media_type, data=encoded)

@classmethod # Caching likely unnecessary
def from_raw_base64(cls, data: str) -> Union[Image, None]:
@classmethod
def from_raw_base64(cls: type[ImgT], data: str) -> Optional[ImgT]:
"""Create an Image from raw base64 data.
Args:
Expand All @@ -139,37 +149,48 @@ def from_raw_base64(cls, data: str) -> Union[Image, None]:
"""
try:
decoded: bytes = base64.b64decode(data)
img_type: Union[str, None] = imghdr.what(None, decoded)
img_type: Optional[str] = imghdr.what(None, decoded)
if img_type:
media_type = mimetypes.guess_type(data)[0]
media_type = f"image/{img_type}"
if media_type in cls.VALID_MIME_TYPES:
return cls(source=data, media_type=media_type, data=data)
except Exception:
pass
return None

@classmethod
@cache # Use cache instead of lru_cache to avoid memory leaks
def from_url(cls, url: str) -> Image:
@lru_cache
def from_url(cls: type[ImgT], url: str) -> ImgT:
"""Create an Image instance from a URL.
Args:
url: The URL of the image
Returns:
Image: An Image instance
Raises:
ValueError: If unable to fetch image or unsupported format
"""
if cls.is_base64(url):
return cls.from_base64(url)
parsed_url = urlparse(url)
media_type: Union[str, None] = mimetypes.guess_type(parsed_url.path)[0]
media_type: Optional[str] = mimetypes.guess_type(parsed_url.path)[0]

if not media_type:
try:
response = requests.head(url, allow_redirects=True)
media_type = response.headers.get("Content-Type")
except requests.RequestException as e:
raise ValueError(f"Failed to fetch image from URL") from e
raise ValueError("Failed to fetch image from URL") from e

if media_type not in cls.VALID_MIME_TYPES:
raise ValueError(f"Unsupported image format: {media_type}")
return cls(source=url, media_type=media_type, data=None)

@classmethod
@lru_cache
def from_path(cls, path: Union[str, Path]) -> Image:
def from_path(cls: type[ImgT], path: Union[str, Path]) -> ImgT:
path = Path(path)
if not path.is_file():
raise FileNotFoundError(f"Image file not found: {path}")
Expand All @@ -182,11 +203,11 @@ def from_path(cls, path: Union[str, Path]) -> Image:
f"Image file size ({path.stat().st_size / 1024 / 1024:.1f}MB) "
f"exceeds Mistral's limit of {MAX_MISTRAL_IMAGE_SIZE / 1024 / 1024:.1f}MB"
)
media_type: Union[str, None] = mimetypes.guess_type(str(path))[0]
if media_type not in VALID_MISTRAL_MIME_TYPES:
media_type: Optional[str] = mimetypes.guess_type(str(path))[0]
if media_type not in cls.VALID_MIME_TYPES:
raise ValueError(
f"Unsupported image format: {media_type}. "
f"Supported formats are: {', '.join(VALID_MISTRAL_MIME_TYPES)}"
f"Supported formats are: {', '.join(cls.VALID_MIME_TYPES)}"
)

data = base64.b64encode(path.read_bytes()).decode("utf-8")
Expand Down Expand Up @@ -235,46 +256,42 @@ def to_openai(self) -> dict[str, Any]:
raise ValueError("Image data is missing for base64 encoding.")

def to_mistral(self) -> dict[str, Any]:
"""Convert the image to Mistral's API format.
"""Convert the image to Mistral's format.
Returns:
dict[str, Any]: Image data in Mistral's API format, either as a URL or base64 data URI.
dict[str, Any]: Image in Mistral's format
Raises:
ValueError: If the image format is not supported by Mistral or exceeds size limit.
ValueError: If image data is missing or format is unsupported
"""
# Validate media type
if self.media_type not in VALID_MISTRAL_MIME_TYPES:
raise ValueError(
f"Unsupported image format for Mistral: {self.media_type}. "
f"Supported formats are: {', '.join(VALID_MISTRAL_MIME_TYPES)}"
)

# For base64 data, validate size
if self.data:
# Calculate size of decoded base64 data
data_size = len(base64.b64decode(self.data))
if data_size > MAX_MISTRAL_IMAGE_SIZE:
raise ValueError(
f"Image size ({data_size / 1024 / 1024:.1f}MB) exceeds "
f"Mistral's limit of {MAX_MISTRAL_IMAGE_SIZE / 1024 / 1024:.1f}MB"
)
if not self.data:
if urlparse(str(self.source)).scheme in {"http", "https"}:
self.data = self.url_to_base64(str(self.source))
elif Path(str(self.source)).is_file():
source_path = Path(str(self.source))
binary_data = source_path.read_bytes()
self.data = base64.b64encode(binary_data).decode('utf-8')

if not self.data:
raise ValueError("No image data available")

if self.media_type not in self.VALID_MISTRAL_MIME_TYPES:
raise ValueError(f"Unsupported image format: {self.media_type}")

# Ensure data is properly formatted as a data URL
data_url = (
self.data if self.data.startswith("data:")
else f"data:{self.media_type};base64,{self.data}"
)

if (
isinstance(self.source, str)
and self.source.startswith(("http://", "https://"))
and not self.is_base64(self.source)
):
return {"type": "image_url", "url": self.source}
elif self.data or self.is_base64(str(self.source)):
data = self.data or str(self.source).split(",", 1)[1]
return {
"type": "image_url",
"data": f"data:{self.media_type};base64,{data}",
return {
"type": "image_url",
"source": {
"type": "base64",
"media_type": self.media_type,
"data": data_url
}
else:
raise ValueError("Image data is missing for base64 encoding.")

}

class Audio(BaseModel):
"""Represents an audio that can be loaded from a URL or file path."""
Expand Down Expand Up @@ -330,8 +347,6 @@ def convert_contents(
str, Image, dict[str, Any], list[Union[str, Image, dict[str, Any]]]
],
mode: Mode,
*, # Make autodetect_images keyword-only
autodetect_images: bool = True,
) -> Union[str, list[dict[str, Any]]]:
"""Convert contents to the appropriate format for the given mode."""
# Handle single string case
Expand Down Expand Up @@ -377,15 +392,12 @@ def convert_contents(
def convert_messages(
messages: list[dict[str, Any]],
mode: Mode,
*, # Make autodetect_images keyword-only
autodetect_images: bool = True,
) -> list[dict[str, Any]]:
"""Convert messages to the appropriate format for the given mode.
Args:
messages: List of message dictionaries to convert
mode: The mode to convert messages for (e.g. MISTRAL_JSON)
autodetect_images: Whether to attempt to autodetect images in string content
Returns:
List of converted message dictionaries
Expand Down

0 comments on commit c4fd5e6

Please sign in to comment.