Skip to content

[Models] Remove GPU-CPU sync when do_pan_and_scan=false in Gemma3 #19999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions vllm/model_executor/models/gemma3_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch import nn
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
from typing_extensions import NotRequired

import vllm.envs as envs
from vllm.config import VllmConfig
Expand Down Expand Up @@ -52,7 +53,7 @@ class Gemma3ImagePixelInputs(TypedDict):
over each image over each prompt in the batch.
"""

num_patches: torch.Tensor
num_patches: NotRequired[torch.Tensor]
"""Shape: `(batch_size * num_images)`"""


Expand Down Expand Up @@ -266,6 +267,12 @@ def _call_hf_processor(
mm_kwargs,
)

hf_processor = self.info.get_hf_processor(**mm_kwargs)
images_kwargs = self.info._resolve_image_kwargs(
hf_processor, {"do_pan_and_scan"})
if not images_kwargs["do_pan_and_scan"]:
return processed_outputs

# HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:
parsed_images = (self._get_data_parser().parse_mm_data({
Expand All @@ -276,7 +283,6 @@ def _call_hf_processor(
parsed_images.get_image_size(i)
for i in range(len(parsed_images))
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)

num_crops = [
self.info.get_num_crops(image_width=size.width,
Expand All @@ -293,6 +299,8 @@ def _get_mm_fields_config(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
if "num_crops" not in hf_inputs:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
num_crops = hf_inputs.get("num_crops", torch.empty(0))

return dict(
Expand Down Expand Up @@ -535,17 +543,22 @@ def _parse_and_validate_image_input(
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_values = flatten_bn(pixel_values, concat=True)
pixel_values = self._validate_pixel_values(pixel_values)

if num_crops is None:
return Gemma3ImagePixelInputs(type="pixel_values",
pixel_values=pixel_values)

if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")

pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True)

return Gemma3ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
pixel_values=pixel_values,
num_patches=num_crops + 1,
)

Expand All @@ -563,14 +576,16 @@ def _process_image_input(
assert self.vision_tower is not None

pixel_values = image_input["pixel_values"]
num_patches = image_input["num_patches"]
num_patches = image_input.get("num_patches")

image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
image_embeds = self.multi_modal_projector(image_features)

if num_patches is None:
return image_embeds
return [
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
]
Expand Down