diff --git a/imgdiet/core.py b/imgdiet/core.py index 16acce2..2e4c68a 100644 --- a/imgdiet/core.py +++ b/imgdiet/core.py @@ -50,65 +50,42 @@ def calculate_psnr( return 20.0 * math.log10(max_pixel / math.sqrt(mse)) -def measure_webp_quality_pil( +def measure_webp_pil( original_bgr: np.ndarray, pil_image: Image.Image, - quality: int + *, + quality: Optional[int] = None, + lossless: bool = False ) -> Tuple[float, int, bytes]: """ - Compresses the given PIL Image to WebP (quality-based), - returns (psnr, compressed_size, compressed_data). - """ - buffer = io.BytesIO() - icc_profile = pil_image.info.get("icc_profile") + Compresses the given PIL Image to WebP. + - If `lossless` is True, performs lossless compression. + - Otherwise, uses quality-based compression with the provided `quality`. - # Save with original mode (RGB or RGBA) - pil_image.save( - buffer, - format="WEBP", - quality=quality, - icc_profile=icc_profile, - exact=True - ) - data = buffer.getvalue() - size = len(data) - - buffer.seek(0) - # Open and convert to RGB only for PSNR calculation - compressed_pil = Image.open(buffer) - if compressed_pil.mode == 'RGBA': - compressed_pil = compressed_pil.convert('RGB') - compressed_bgr = np.array(compressed_pil)[:, :, ::-1] - - psnr_val = calculate_psnr(original_bgr, - compressed_bgr) - return psnr_val, size, data - - -def measure_webp_lossless_pil( - original_bgr: np.ndarray, - pil_image: Image.Image -) -> Tuple[float, int, bytes]: - """ - Compresses the given PIL Image in lossless WebP, - returns (psnr, compressed_size, compressed_data). + Returns a tuple (psnr, compressed_size, compressed_data). """ + if not lossless and quality is None: + raise ValueError("Either 'lossless' must be True or 'quality' must be provided for quality-based compression.") + buffer = io.BytesIO() icc_profile = pil_image.info.get("icc_profile") - # Save with original mode (RGB or RGBA) - pil_image.save( - buffer, - format="WEBP", - lossless=True, - icc_profile=icc_profile, - exact=True - ) + # Set up parameters for saving + save_kwargs = { + "format": "WEBP", + "icc_profile": icc_profile, + "exact": True, + } + if lossless: + save_kwargs["lossless"] = True + else: + save_kwargs["quality"] = quality + + pil_image.save(buffer, **save_kwargs) data = buffer.getvalue() size = len(data) buffer.seek(0) - # Open and convert to RGB only for PSNR calculation compressed_pil = Image.open(buffer) if compressed_pil.mode == 'RGBA': compressed_pil = compressed_pil.convert('RGB') @@ -134,7 +111,7 @@ def find_optimal_compression_binary_search( while left <= right: mid = (left + right) // 2 - psnr_val, size, _ = measure_webp_quality_pil(original_bgr, pil_image, mid) + psnr_val, size, _ = measure_webp_pil(original_bgr, pil_image, quality=mid, lossless=False) if psnr_val >= target_psnr: if size < best_size: @@ -246,7 +223,7 @@ def process_single_image( # Case 1: target_psnr == 0 => lossless if target_psnr == 0: try: - psnr_val, compressed_size, data = measure_webp_lossless_pil(original_bgr, pil_image) + psnr_val, compressed_size, data = measure_webp_pil(original_bgr, pil_image, lossless=True) if psnr_val == float("inf") and compressed_size < original_size: webp_path.parent.mkdir(parents=True, exist_ok=True) with open(webp_path, "wb") as f: @@ -283,10 +260,11 @@ def process_single_image( else: q = best_params["quality"] logger.info(f"Found best quality={q} for {img_path}") - psnr_val, compressed_size, _ = measure_webp_quality_pil(original_bgr, pil_image, q) + psnr_val, compressed_size, _ = measure_webp_pil(original_bgr, pil_image, quality=q, lossless=False) if compressed_size < original_size: webp_path.parent.mkdir(parents=True, exist_ok=True) with open(webp_path, "wb") as f: + icc_profile = pil_image.info.get("icc_profile", None) # quality 모드로 저장 (원본 모드 유지) pil_image.save( @@ -338,8 +316,9 @@ def save( valid_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp", ".avif"} # Check extension is same with codec - if dst_path.suffix.lower() != f".{codec}": - raise ValueError(f"Codec and target extension are not same. {dst_path.suffix} != {codec}") + if dst_path.suffix and not dst_path.is_dir(): + if dst_path.suffix.lower() != f".{codec}": + raise ValueError(f"Codec and target extension are not same. {dst_path.suffix} != {codec}") # Add extension check and warning if dst_path.suffix and dst_path.suffix.lower() in valid_exts and dst_path.suffix.lower() not in ['.webp', '.avif']: