diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 415a8dbdcf8..befa3482116 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -8,6 +8,7 @@
 from torch import nn
 from transformers import BatchFeature, Gemma3Config, Gemma3Processor
 from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
+from typing_extensions import NotRequired
 
 import vllm.envs as envs
 from vllm.config import VllmConfig
@@ -52,7 +53,7 @@ class Gemma3ImagePixelInputs(TypedDict):
     over each image over each prompt in the batch.
     """
 
-    num_patches: torch.Tensor
+    num_patches: NotRequired[torch.Tensor]
     """Shape: `(batch_size * num_images)`"""
 
 
@@ -266,6 +267,12 @@ def _call_hf_processor(
             mm_kwargs,
         )
 
+        hf_processor = self.info.get_hf_processor(**mm_kwargs)
+        images_kwargs = self.info._resolve_image_kwargs(
+            hf_processor, {"do_pan_and_scan"})
+        if not images_kwargs["do_pan_and_scan"]:
+            return processed_outputs
+
         # HF processor pops the `num_crops` kwarg, which is needed by vLLM
         if (images := mm_data.get("images")) is not None:
             parsed_images = (self._get_data_parser().parse_mm_data({
@@ -276,7 +283,6 @@
                 parsed_images.get_image_size(i)
                 for i in range(len(parsed_images))
             ]
-            hf_processor = self.info.get_hf_processor(**mm_kwargs)
 
             num_crops = [
                 self.info.get_num_crops(image_width=size.width,
@@ -293,6 +299,8 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        if "num_crops" not in hf_inputs:
+            return dict(pixel_values=MultiModalFieldConfig.batched("image"))
         num_crops = hf_inputs.get("num_crops", torch.empty(0))
 
         return dict(
@@ -535,17 +543,22 @@ def _parse_and_validate_image_input(
         if not isinstance(pixel_values, (torch.Tensor, list)):
             raise ValueError("Incorrect type of pixel values. "
                              f"Got type: {type(pixel_values)}")
+        pixel_values = flatten_bn(pixel_values, concat=True)
+        pixel_values = self._validate_pixel_values(pixel_values)
+
+        if num_crops is None:
+            return Gemma3ImagePixelInputs(type="pixel_values",
+                                          pixel_values=pixel_values)
 
         if not isinstance(num_crops, (torch.Tensor, list)):
             raise ValueError("Incorrect type of num_crops. "
                              f"Got type: {type(num_crops)}")
 
-        pixel_values = flatten_bn(pixel_values, concat=True)
         num_crops = flatten_bn(num_crops, concat=True)
 
         return Gemma3ImagePixelInputs(
             type="pixel_values",
-            pixel_values=self._validate_pixel_values(pixel_values),
+            pixel_values=pixel_values,
             num_patches=num_crops + 1,
         )
 
@@ -563,7 +576,7 @@ def _process_image_input(
         assert self.vision_tower is not None
 
         pixel_values = image_input["pixel_values"]
-        num_patches = image_input["num_patches"]
+        num_patches = image_input.get("num_patches")
 
         image_features = self._image_pixels_to_features(
             self.vision_tower,
@@ -571,6 +584,8 @@
         )
         image_embeds = self.multi_modal_projector(image_features)
 
+        if num_patches is None:
+            return image_embeds
        return [
            e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
        ]
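
For context, here is a minimal standalone sketch (not part of the diff, and not vLLM's actual runtime path) of how the now-optional `num_patches` key is expected to behave downstream. The `split_image_embeds` helper and the toy tensor shapes are assumptions for illustration only; the TypedDict and the branching mirror the logic in `_process_image_input` above.

```python
# Sketch only: illustrates why `num_patches` becomes NotRequired and why
# downstream code switches from `image_input["num_patches"]` to `.get(...)`.
import torch
from typing import Literal
from typing_extensions import NotRequired, TypedDict


class Gemma3ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    # Absent when pan-and-scan is disabled (no `num_crops` from the processor).
    num_patches: NotRequired[torch.Tensor]


def split_image_embeds(image_embeds: torch.Tensor,
                       image_input: Gemma3ImagePixelInputs):
    # Hypothetical helper mirroring the new branch in `_process_image_input`:
    # without pan-and-scan there is exactly one patch per image, so the
    # batched embeddings pass through unchanged.
    num_patches = image_input.get("num_patches")
    if num_patches is None:
        return image_embeds
    # With pan-and-scan, each image owns `num_crops + 1` patches; split the
    # batch per image and flatten the patch and token dims together.
    return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]


# Toy example: two images with 3 and 1 patches, 16 tokens of dim 8 per patch.
embeds = torch.randn(4, 16, 8)
inputs = Gemma3ImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.randn(4, 3, 224, 224),
    num_patches=torch.tensor([3, 1]),
)
per_image = split_image_embeds(embeds, inputs)
assert [e.shape for e in per_image] == [torch.Size([48, 8]), torch.Size([16, 8])]
```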