Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(weave): Add support for jpegs and pngs #3304

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion weave/trace/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from collections.abc import Iterable, Iterator
from collections.abc import Iterable, Iterator, MutableMapping
from concurrent.futures import ThreadPoolExecutor as _ThreadPoolExecutor
from contextvars import Context, copy_context
from functools import partial, wraps
Expand Down Expand Up @@ -168,6 +168,52 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return deco


class InvertableDict(MutableMapping):
"""A bijective mapping that behaves like a dict.

Invert the dict using the `inv` property.
"""

def __init__(self, *args, **kwargs):
self._forward = dict(*args, **kwargs)
self._backward = {}
for key, value in self._forward.items():
if value in self._backward:
raise ValueError(f"Duplicate value found: {value}")
self._backward[value] = key

def __getitem__(self, key):
return self._forward[key]

def __setitem__(self, key, value):
if key in self._forward:
del self._backward[self._forward[key]]
if value in self._backward:
raise ValueError(f"Duplicate value found: {value}")
self._forward[key] = value
self._backward[value] = key

def __delitem__(self, key):
value = self._forward.pop(key)
del self._backward[value]

def __iter__(self):
return iter(self._forward)

def __len__(self):
return len(self._forward)

def __repr__(self):
return repr(self._forward)

def __contains__(self, key):
return key in self._forward

@property
def inv(self):
return self._backward


# rename for cleaner export
ThreadPoolExecutor = ContextAwareThreadPoolExecutor
Thread = ContextAwareThread
Expand Down
34 changes: 28 additions & 6 deletions weave/type_handlers/Image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from weave.trace import object_preparers, serializer
from weave.trace.custom_objs import MemTraceFilesArtifact
from weave.trace.util import InvertableDict

try:
from PIL import Image
Expand All @@ -17,6 +18,14 @@

logger = logging.getLogger(__name__)

DEFAULT_FORMAT = "PNG"
PIL_FORMAT_TO_EXT = InvertableDict(
{
"JPEG": "jpg",
"PNG": "png",
}
)


class PILImagePreparer:
def should_prepare(self, obj: Any) -> bool:
Expand All @@ -35,6 +44,12 @@ def prepare(self, obj: Image.Image) -> None:


def save(obj: Image.Image, artifact: MemTraceFilesArtifact, name: str) -> None:
fmt = getattr(obj, "format", DEFAULT_FORMAT)
ext = PIL_FORMAT_TO_EXT.get(fmt)
if ext is None:
logger.warning(f"Unknown image format {fmt}, defaulting to {DEFAULT_FORMAT}")
ext = PIL_FORMAT_TO_EXT[DEFAULT_FORMAT]

# Note: I am purposely ignoring the `name` here and hard-coding the filename to "image.png".
# There is an extensive internal discussion here:
# https://weightsandbiases.slack.com/archives/C03BSTEBD7F/p1723670081582949
Expand All @@ -49,15 +64,22 @@ def save(obj: Image.Image, artifact: MemTraceFilesArtifact, name: str) -> None:
# using the same artifact. Moreover, since we package the deserialization logic with the
# object payload, we can always change the serialization logic later without breaking
# existing payloads.
with artifact.new_file("image.png", binary=True) as f:
obj.save(f, format="png") # type: ignore
fname = f"image.{ext}"
with artifact.new_file(fname, binary=True) as f:
obj.save(f, format=PIL_FORMAT_TO_EXT.inv[ext]) # type: ignore


def load(artifact: MemTraceFilesArtifact, name: str) -> Image.Image:
# Note: I am purposely ignoring the `name` here and hard-coding the filename. See comment
# on save.
path = artifact.path("image.png")
return Image.open(path)
for ext in PIL_FORMAT_TO_EXT.values():
# Note: I am purposely ignoring the `name` here and hard-coding the filename.
# See comment on save.
fname = f"image.{ext}"
path = artifact.path(fname)
try:
return Image.open(path)
except FileNotFoundError:
continue
raise FileNotFoundError(f"No image found in artifact {artifact}")


def register() -> None:
Expand Down
Loading