diff --git a/requirements/common.txt b/requirements/common.txt index 639abe51101..9a9ae1d9389 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -44,3 +44,4 @@ watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/others/logging_configuration.md scipy # Required for phi-4-multimodal-instruct ninja # Required for xgrammar, rocm, tpu, xpu +pybase64 # fast base64 implementation diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index e673632d436..dce4c4c1cad 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import base64 from io import BytesIO from pathlib import Path +import pybase64 import torch from PIL import Image @@ -55,7 +55,7 @@ def load_bytes(self, data: bytes) -> Image.Image: return convert_image_mode(image, self.image_mode) def load_base64(self, media_type: str, data: str) -> Image.Image: - return self.load_bytes(base64.b64decode(data)) + return self.load_bytes(pybase64.b64decode(data, validate=True)) def load_file(self, filepath: Path) -> Image.Image: image = Image.open(filepath) @@ -75,7 +75,7 @@ def encode_base64( image.save(buffer, image_format) data = buffer.getvalue() - return base64.b64encode(data).decode('utf-8') + return pybase64.b64encode(data).decode('utf-8') class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]): @@ -88,10 +88,10 @@ def load_bytes(self, data: bytes) -> torch.Tensor: return torch.load(buffer, weights_only=True) def load_base64(self, media_type: str, data: str) -> torch.Tensor: - return self.load_bytes(base64.b64decode(data)) + return self.load_bytes(pybase64.b64decode(data, validate=True)) def load_file(self, filepath: Path) -> torch.Tensor: return torch.load(filepath, weights_only=True) def encode_base64(self, media: torch.Tensor) -> str: - return base64.b64encode(media.numpy()).decode('utf-8') + return pybase64.b64encode(media.numpy()).decode('utf-8')