
[Model] Ultravox: Support Llama 4 and Gemma 3 backends #17818

Open · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -77,6 +77,7 @@
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
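For context: each entry in this registry maps the architecture string found in a checkpoint's config.json to a (module name, class name) pair that vLLM imports lazily. A minimal sketch of that resolution pattern, using a stand-in excerpt of the table (the real loader in registry.py carries more machinery):

import importlib

_MODELS = {
    # (module under vllm.model_executor.models, class name)
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
}

def resolve_model_cls(architecture: str):
    module_name, cls_name = _MODELS[architecture]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)
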
37 changes: 22 additions & 15 deletions vllm/model_executor/models/ultravox.py
@@ -39,9 +39,7 @@
     merge_multimodal_embeddings,
     merge_multimodal_embeddings_from_map)
 
-_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
-_AUDIO_PLACEHOLDER_TOKEN = 128002
-_AUDIO_TOKENS_PER_SECOND = 6.25
+_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
 _MAX_ENCODER_BATCH_SIZE = 16


@@ -80,14 +78,15 @@ def get_hf_processor(
         sampling_rate: Optional[int] = None,
         **kwargs: object,
     ) -> ProcessorMixin:
+        config = self.ctx.model_config.hf_config
         hf_processor = self.ctx.get_hf_processor(**kwargs)
 
         # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
         # placeholder that will cause confusion with the actual end of turn
-        # token, thus we override placeholder with a reserved special
-        # token.
+        # token, thus we override placeholder with a reserved token.
         hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
-        hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN
+        hf_processor.audio_replacement_token_id = config.audio_token_index
 
         return hf_processor
 
     def get_feature_extractor(
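To see why the NOTE above overrides the placeholder, consider scanning token ids for audio slots: if the placeholder reused <|eot_id|>, audio slots and real end-of-turn tokens would be indistinguishable. A toy illustration with made-up ids:

EOT_ID, AUDIO_ID = 128009, 128002  # illustrative ids, not tied to any real vocab
tokens = [9906, EOT_ID, AUDIO_ID, AUDIO_ID, EOT_ID]

# With a dedicated placeholder id, the audio positions are unambiguous.
audio_positions = [i for i, t in enumerate(tokens) if t == AUDIO_ID]
assert audio_positions == [2, 3]
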
@@ -268,7 +267,7 @@ def __init__(self, config: UltravoxConfig):
         else:
             self.act = get_act_fn(config.projector_act)
 
-        dim_out = config.text_config.hidden_size
+        dim_out = config.text_hidden_size
         self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
 
         # Ultravox v0.4.1 and below use layer_norm after the second linear layer
@@ -559,9 +558,13 @@ def get_input_embeddings(
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
-        inputs_embeds = self.language_model.get_input_embeddings(input_ids)

farzadab (Contributor, Author) commented on May 8, 2025:

@DarkLight1337 When using V1, I noticed that the output was also completely garbled.

After some debugging, I printed input_ids here for the same sample (conditioned on len(input_ids) > 1 to skip single-token decode steps), and this is what I got:

# with VLLM_USE_V1=0
>>> t.decode([200000, 200005, 15651, 200006, 368, 4662, 583, 262, 19933, 43910, 26, 200008, 200005, 1556, 200006, 368, 4984, 290, 2182, 4097, 38, 7283, 201133, 200008, 200005, 140680, 200006, 368])
'<|begin_of_text|><|header_start|>system<|header_end|>\n\nYou are a helpful assistant.<|eot|><|header_start|>user<|header_end|>\n\nAnswer the following question: \n\n<|vision_reserved_special_token_1047|><|eot|><|header_start|>assistant<|header_end|>\n\n'

# with VLLM_USE_V1=1
>>> t.decode([24, 4984, 290, 2182, 4097, 38, 7283, 201133, 200008, 200005, 140680, 200006, 368])
',Answer the following question: \n\n<|vision_reserved_special_token_1047|><|eot|><|header_start|>assistant<|header_end|>\n\n'

With V1, the input_ids seemed to be missing part of the beginning.

farzadab (Contributor, Author) commented:

I believe I first hit this issue around v0.8.4. I'll try verifying it on v0.8.5.post1.

Resolved by upgrading to v0.9.1
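
A rough sketch of the debug probe described in the thread above, assuming it sits at the top of get_input_embeddings (the length check skips single-token decode steps):

if len(input_ids) > 1:  # prefill only; decode passes one token at a time
    print(input_ids.tolist())  # decode offline with the tokenizer to inspect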

-        if multimodal_embeddings is not None \
-            and len(multimodal_embeddings) != 0:
+        # The audio token index is not included in the embedding table
+        # We need to remove it before embedding lookup
+        safe_input_ids = input_ids.clone()
+        safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0
+        inputs_embeds = self.language_model.get_input_embeddings(
+            safe_input_ids)
+        if multimodal_embeddings is not None:
 
             # TODO(ywang96): remove this block after v0 is deprecated.
             if not envs.VLLM_USE_V1:
@@ -572,7 +575,7 @@ def get_input_embeddings(
             else:
                 inputs_embeds = merge_multimodal_embeddings(
                     input_ids, inputs_embeds, multimodal_embeddings,
-                    _AUDIO_PLACEHOLDER_TOKEN)
+                    self.config.audio_token_index)
         return inputs_embeds
 
     def forward(self,
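The clone-and-mask step above is needed because an embedding table raises an index error for ids at or beyond its size; the placeholder rows are overwritten by the audio embeddings during the merge, so mapping them to 0 is harmless. A standalone illustration with toy sizes:

import torch
import torch.nn as nn

vocab_size, audio_token_index = 8, 100  # toy values; real ones come from config
embed = nn.Embedding(vocab_size, 4)

input_ids = torch.tensor([1, 2, audio_token_index, 3])
safe_input_ids = input_ids.clone()
safe_input_ids[safe_input_ids == audio_token_index] = 0
inputs_embeds = embed(safe_input_ids)  # embed(input_ids) would raise IndexError
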
@@ -610,10 +613,14 @@ def forward(self,
                                                   multimodal_embeddings)
             input_ids = None
 
-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        language_model = self.language_model
+        if hasattr(language_model, "language_model"):
+            language_model = language_model.language_model
+
+        hidden_states = language_model.model(input_ids,
+                                             positions,
+                                             intermediate_tensors,
+                                             inputs_embeds=inputs_embeds)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
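The unwrapping above accounts for backends whose vLLM model class nests the causal LM under a language_model attribute, which is presumably how the Llama 4 and Gemma 3 multimodal wrappers are laid out; one level of unwrap covers both the plain and wrapped cases. An equivalent sketch of the same idea using getattr:

# Falls back to the model itself when there is no nested language_model
# attribute (plain causal LMs such as LlamaForCausalLM).
language_model = getattr(self.language_model, "language_model",
                         self.language_model)
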
2 changes: 1 addition & 1 deletion vllm/transformers_utils/config.py
@@ -800,7 +800,7 @@ def get_hf_text_config(config: PretrainedConfig):
         # thinker_config.text_config.
         return config.thinker_config.text_config
 
-    text_config = config.get_text_config()
+    text_config = config.get_text_config().get_text_config()
 
     if text_config is not config:
         # The code operates under the assumption that text_config should have
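The second get_text_config() call unwraps one extra level of nesting: for a composite backend, the text config returned by the first call can itself be composite (for example a Llama 4 config holding an inner text config), while for flat configs the call is a no-op because get_text_config() returns the config itself when there is nothing left to unwrap. A toy demonstration with stand-in classes (not the transformers API):

class FlatConfig:
    def get_text_config(self):
        return self  # leaf config: nothing to unwrap

class CompositeConfig:
    def __init__(self, text_config):
        self.text_config = text_config

    def get_text_config(self):
        return self.text_config

leaf = FlatConfig()
nested = CompositeConfig(CompositeConfig(leaf))
assert leaf.get_text_config().get_text_config() is leaf
assert nested.get_text_config().get_text_config() is leaf
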
22 changes: 13 additions & 9 deletions vllm/transformers_utils/configs/ultravox.py
@@ -45,6 +45,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
"""

model_type = "ultravox"
audio_token = "<|audio|>"
is_composition = False

def __init__(
@@ -80,29 +81,32 @@ def __init__(
             # Avoid circular import
             from vllm.transformers_utils.config import get_config
 
-            self.text_config = get_config(text_model_id,
-                                          trust_remote_code=False)
+            text_config_obj = get_config(text_model_id,
+                                         trust_remote_code=False)
         else:
             text_config = text_config or {}
-            self.text_config = transformers.CONFIG_MAPPING[text_config.get(
+            text_config_obj = transformers.CONFIG_MAPPING[text_config.get(
                 "model_type", "llama")](**text_config)
 
+        inner_text_config = text_config_obj.get_text_config()
+
         if audio_model_id is not None:
             # Avoid circular import
             from vllm.transformers_utils.config import get_config
 
-            self.audio_config = get_config(audio_model_id,
-                                           trust_remote_code=False)
+            audio_config = get_config(audio_model_id, trust_remote_code=False)
         else:
             audio_config = audio_config or {}
-            self.audio_config = transformers.CONFIG_MAPPING[audio_config.get(
+            audio_config = transformers.CONFIG_MAPPING[audio_config.get(
                 "model_type", "whisper")](**audio_config)
 
+        self.text_config = text_config_obj
+        self.audio_config = audio_config
         self.text_model_lora_config = text_model_lora_config or {}
         self.audio_model_lora_config = audio_model_lora_config or {}
 
-        self.vocab_size = self.text_config.vocab_size
-
-        self.initializer_range = self.text_config.initializer_range
+        self.vocab_size = inner_text_config.vocab_size
+        self.initializer_range = inner_text_config.initializer_range
+        self.text_hidden_size = inner_text_config.hidden_size
 
         super().__init__(**kwargs)
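
Net effect: text_config keeps the full (possibly composite) backend config, while the scalar attributes the rest of the model needs (vocab_size, initializer_range, and text_hidden_size, which feeds dim_out in the projector) are read from the inner text config. A minimal sketch with stand-in objects and illustrative sizes:

from types import SimpleNamespace

inner = SimpleNamespace(vocab_size=202048, hidden_size=5120,
                        initializer_range=0.02)
outer = SimpleNamespace(get_text_config=lambda: inner)  # composite stand-in

inner_text_config = outer.get_text_config()
text_hidden_size = inner_text_config.hidden_size  # projector output width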