[generate] handle support for cache classes when num enc layers != num dec layers #40277

Merged
42 changes: 35 additions & 7 deletions src/transformers/configuration_utils.py
@@ -1168,21 +1168,34 @@ def _get_non_default_generation_parameters(self) -> dict[str, Any]:

         return non_default_generation_parameters
 
-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=None, encoder=None) -> "PretrainedConfig":
         """
-        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
-        itself. On specific composite models, it is under a set of valid names.
+        Returns the text config related to the text input (encoder) or text output (decoder) of the model. The
+        `decoder` and `encoder` input arguments can be used to specify which end of the model we are interested in,
+        which is useful on models that have both text input and output modalities.
+
+        There are three possible outcomes of using this method:
+        1. On most models, it returns the original config instance itself.
+        2. On newer (2024+) composite models, it returns the text section of the config, which is nested under a set
+            of valid names.
+        3. On older (2023-) composite models, it discards decoder-only parameters when `encoder=True` and vice-versa.
 
         Args:
-            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+            decoder (`Optional[bool]`, *optional*):
                 If set to `True`, then only search for decoder config names.
+            encoder (`Optional[bool]`, *optional*):
+                If set to `True`, then only search for encoder config names.
         """
+        return_both = decoder == encoder  # both unset or both set -> search all possible names
+
         decoder_possible_text_config_names = ("decoder", "generator", "text_config")
         encoder_possible_text_config_names = ("text_encoder",)
-        if decoder:
+        if return_both:
+            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+        elif decoder:
             possible_text_config_names = decoder_possible_text_config_names
         else:
-            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+            possible_text_config_names = encoder_possible_text_config_names
 
         valid_text_config_names = []
         for text_config_name in possible_text_config_names:
@@ -1194,12 +1207,27 @@ def get_text_config(self, decoder=False) -> "PretrainedConfig":
         if len(valid_text_config_names) > 1:
             raise ValueError(
                 f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
-                "case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
+                "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
+                "e.g. `text_config = config.sub_config_name`"
             )
         elif len(valid_text_config_names) == 1:
             config_to_return = getattr(self, valid_text_config_names[0])
         else:
             config_to_return = self
 
+        # handle legacy models with flat config structure, when we only want one of the configs
+        if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
+            config_to_return = copy.deepcopy(config_to_return)
+            prefix_to_discard = "encoder" if decoder else "decoder"
+            for key in config_to_return.to_dict():
+                if key.startswith(prefix_to_discard):
+                    delattr(config_to_return, key)
+            # old encoder/decoder models may use "encoder_layers"/"decoder_layers" instead of "num_hidden_layers"
+            if decoder and hasattr(config_to_return, "decoder_layers"):
+                config_to_return.num_hidden_layers = config_to_return.decoder_layers
+            elif encoder and hasattr(config_to_return, "encoder_layers"):
+                config_to_return.num_hidden_layers = config_to_return.encoder_layers
+
         return config_to_return
 
     @classmethod
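In short, the new `encoder`/`decoder` arguments let a legacy flat encoder-decoder config be viewed from one side only. A minimal usage sketch, mirroring the test added at the bottom of this PR (the tiny BART checkpoint is the one that test uses; no API beyond this diff is assumed):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bart")

# Both arguments unset (or both set): the config itself is returned, as before.
assert config.get_text_config() == config

# Decoder-only view: a deep copy with the encoder_* attributes dropped and
# num_hidden_layers remapped from decoder_layers.
decoder_config = config.get_text_config(decoder=True)
assert decoder_config.num_hidden_layers == config.decoder_layers
assert getattr(decoder_config, "encoder_layers", None) is None

# Encoder-only view, symmetrically.
encoder_config = config.get_text_config(encoder=True)
assert encoder_config.num_hidden_layers == config.encoder_layers
```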
17 changes: 12 additions & 5 deletions src/transformers/generation/utils.py
@@ -1844,12 +1844,19 @@ def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len:
         )
 
         if need_new_cache:
-            cache_kwargs = {"config": self.config, "max_cache_len": max_cache_len, "offloading": offload_cache}
-            self._cache = StaticCache(**cache_kwargs)
+            self_attention_cache_kwargs = {
+                "config": self.config.get_text_config(decoder=True),
+                "max_cache_len": max_cache_len,
+                "offloading": offload_cache,
+            }
+            self._cache = StaticCache(**self_attention_cache_kwargs)
             if requires_cross_attention_cache:
-                encoder_kwargs = cache_kwargs.copy()
-                encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
-                self._cache = EncoderDecoderCache(self._cache, StaticCache(**encoder_kwargs))
+                cross_attention_cache_kwargs = {
+                    "config": self.config.get_text_config(encoder=True),
+                    "max_cache_len": model_kwargs["encoder_outputs"][0].shape[1],
+                    "offloading": offload_cache,
+                }
+                self._cache = EncoderDecoderCache(self._cache, StaticCache(**cross_attention_cache_kwargs))
         else:
             self._cache.reset()
         return self._cache
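Routing through `get_text_config` means the self-attention cache is shaped by the decoder's layer count and the cross-attention cache by the encoder's, so the two depths no longer have to match. A rough standalone sketch of the same construction, reusing only the kwargs visible in the diff above (the checkpoint and cache lengths are placeholders; `offloading` is left at its default):

```python
from transformers import AutoConfig, EncoderDecoderCache, StaticCache

config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bart")

# Self-attention cache: sized from the decoder side of the config.
self_attention_cache = StaticCache(config=config.get_text_config(decoder=True), max_cache_len=32)

# Cross-attention cache: sized from the encoder side (in generate(), the encoder output length).
cross_attention_cache = StaticCache(config=config.get_text_config(encoder=True), max_cache_len=16)

cache = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
```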
4 changes: 2 additions & 2 deletions src/transformers/models/colqwen2/configuration_colqwen2.py
@@ -87,8 +87,8 @@ def __init__(
         self.initializer_range = initializer_range
         super().__init__(**kwargs)
 
-    def get_text_config(self, decoder=False) -> PretrainedConfig:
-        return self.vlm_config.get_text_config(decoder=decoder)
+    def get_text_config(self, *args, **kwargs) -> PretrainedConfig:
+        return self.vlm_config.get_text_config(*args, **kwargs)
 
 
 __all__ = ["ColQwen2Config"]
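Wrapper configs like `ColQwen2Config` now forward whatever they receive, so an `encoder=`/`decoder=` selection made by the caller survives the extra nesting level. The same delegation pattern on a hypothetical wrapper config (the class below is illustrative only, not part of this PR):

```python
from transformers import PretrainedConfig


class MyWrapperConfig(PretrainedConfig):
    """Hypothetical config that wraps another composite config under `inner_config`."""

    model_type = "my-wrapper"

    def __init__(self, inner_config=None, **kwargs):
        self.inner_config = inner_config
        super().__init__(**kwargs)

    def get_text_config(self, *args, **kwargs) -> PretrainedConfig:
        # Forward *args/**kwargs verbatim instead of pinning decoder=False, so that
        # decoder=True / encoder=True reaches the nested config untouched.
        return self.inner_config.get_text_config(*args, **kwargs)
```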
2 changes: 1 addition & 1 deletion src/transformers/models/dia/configuration_dia.py
@@ -368,7 +368,7 @@ def __init__(
             **kwargs,
         )
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         """Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
         return self.decoder_config
 
4 changes: 2 additions & 2 deletions src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
@@ -1073,7 +1073,7 @@ def __init__(

         super().__init__(**kwargs)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         """
         Returns the config that is meant to be used with text IO. On most models, it is the original config instance
         itself. On specific composite models, it is under a set of valid names.
@@ -1085,7 +1085,7 @@ def get_text_config(self, decoder=False):
         # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
         # except for Qwen yet. This has to be generalized if more deeply nested configs are
         # added. NOTE: currently method used only by vLLM
-        return self.thinker_config.get_text_config()
+        return self.thinker_config.get_text_config(*args, **kwargs)
 
 
 __all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"]
4 changes: 2 additions & 2 deletions src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -1108,7 +1108,7 @@ def __init__(

         super().__init__(**kwargs)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         """
         Returns the config that is meant to be used with text IO. On most models, it is the original config instance
         itself. On specific composite models, it is under a set of valid names.
@@ -1120,7 +1120,7 @@ def get_text_config(self, decoder=False):
         # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
         # except for Qwen yet. This has to be generalized if more deeply nested configs are
         # added. NOTE: currently method used only by vLLM
-        return self.thinker_config.get_text_config()
+        return self.thinker_config.get_text_config(*args, **kwargs)
 
 
 class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
3 changes: 1 addition & 2 deletions src/transformers/models/t5gemma/configuration_t5gemma.py
@@ -323,9 +323,8 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         # Always return self, regardless of the decoder option.
-        del decoder
         return self
 
 
3 changes: 1 addition & 2 deletions src/transformers/models/t5gemma/modular_t5gemma.py
@@ -213,9 +213,8 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         # Always return self, regardless of the decoder option.
-        del decoder
         return self
 
 
34 changes: 33 additions & 1 deletion tests/utils/test_configuration_utils.py
@@ -24,7 +24,7 @@
 from huggingface_hub import HfFolder
 from requests.exceptions import HTTPError
 
-from transformers import AutoConfig, BertConfig, GPT2Config
+from transformers import AutoConfig, BertConfig, Florence2Config, GPT2Config
 from transformers.configuration_utils import PretrainedConfig
 from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test
 
Expand Down Expand Up @@ -300,3 +300,35 @@ def test_loading_config_do_not_raise_future_warnings(self):
         with warnings.catch_warnings():
             warnings.simplefilter("error")
             PretrainedConfig.from_pretrained("bert-base-uncased")
+
+    def test_get_text_config(self):
+        """Tests the `get_text_config` method."""
+        # 1. model with only text input -> returns the original config instance
+        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        self.assertEqual(config.get_text_config(), config)
+        self.assertEqual(config.get_text_config(decoder=True), config)
+
+        # 2. composite model (VLM) -> returns the text component
+        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlavaForConditionalGeneration")
+        self.assertEqual(config.get_text_config(), config.text_config)
+        self.assertEqual(config.get_text_config(decoder=True), config.text_config)
+
+        # 3. ! corner case! : composite model whose sub-config is an old composite model (should behave as above)
+        config = Florence2Config()
+        self.assertEqual(config.get_text_config(), config.text_config)
+        self.assertEqual(config.get_text_config(decoder=True), config.text_config)
+
+        # 4. old composite model -> may remove components based on the `decoder` or `encoder` argument
+        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bart")
+        self.assertEqual(config.get_text_config(), config)
+        # both encoder_layers and decoder_layers exist
+        self.assertTrue(getattr(config, "encoder_layers", None) is not None)
+        self.assertTrue(getattr(config, "decoder_layers", None) is not None)
+        decoder_config = config.get_text_config(decoder=True)
+        self.assertNotEqual(decoder_config, config)
+        self.assertEqual(decoder_config.num_hidden_layers, config.decoder_layers)
+        self.assertTrue(getattr(decoder_config, "encoder_layers", None) is None)  # encoder_layers is removed
+        encoder_config = config.get_text_config(encoder=True)
+        self.assertNotEqual(encoder_config, config)
+        self.assertEqual(encoder_config.num_hidden_layers, config.encoder_layers)
+        self.assertTrue(getattr(encoder_config, "decoder_layers", None) is None)  # decoder_layers is removed
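End to end, the user-visible effect is that `generate()` with a static cache now works on encoder-decoder models whose encoder and decoder depths differ. A rough sketch of that path (the Whisper checkpoint and the dummy audio are placeholders picked because such models typically have far more encoder than decoder layers; they are not part of this PR):

```python
import numpy as np

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

model_id = "openai/whisper-large-v3-turbo"  # placeholder: encoder much deeper than decoder
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)

# One second of silence as dummy input, just to drive the generation loop.
inputs = processor(np.zeros(16_000, dtype=np.float32), sampling_rate=16_000, return_tensors="pt")

# With this PR, the static self-attention and cross-attention caches are built from the
# decoder and encoder configs respectively, so mismatched layer counts no longer break.
generated = model.generate(**inputs, cache_implementation="static", max_new_tokens=20)
print(processor.batch_decode(generated, skip_special_tokens=True))
```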