From 3a847f65d70e4271e5163bc8532b9c788aef3e5c Mon Sep 17 00:00:00 2001 From: George Date: Thu, 12 Dec 2024 13:23:34 +0000 Subject: [PATCH] fix: fix pad to square for some rectangular images (#421) --- fastembed/image/transform/functional.py | 22 ++++++++++++++-------- fastembed/text/text_embedding.py | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/fastembed/image/transform/functional.py b/fastembed/image/transform/functional.py index 6b85e6ab..afefe4be 100644 --- a/fastembed/image/transform/functional.py +++ b/fastembed/image/transform/functional.py @@ -1,4 +1,4 @@ -from typing import Sized, Union, Optional +from typing import Sized, Union import numpy as np from PIL import Image @@ -127,18 +127,24 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]): def pad2square( image: Image.Image, size: int, - fill_color: Optional[Union[str, int, tuple[int, ...]]] = None, + fill_color: Union[str, int, tuple[int, ...]] = 0, ) -> Image.Image: height, width = image.height, image.width - # if the size is larger than the new canvas - if width > size or height > size: + left, right = 0, width + top, bottom = 0, height + + crop_required = False + if width > size: left = (width - size) // 2 - top = (height - size) // 2 right = left + size + crop_required = True + + if height > size: + top = (height - size) // 2 bottom = top + size - image = image.crop((left, top, right, bottom)) + crop_required = True - new_image = Image.new(mode="RGB", size=(size, size), color=fill_color or 0) - new_image.paste(image) + new_image = Image.new(mode="RGB", size=(size, size), color=fill_color) + new_image.paste(image.crop((left, top, right, bottom)) if crop_required else image) return new_image diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index 81e8d6f0..960d68f7 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -78,7 +78,7 @@ def __init__( return raise ValueError( - f"Model {model_name} is not supported in TextEmbedding." + f"Model {model_name} is not supported in TextEmbedding. " "Please check the supported models using `TextEmbedding.list_supported_models()`" )