diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 30b1f3ca2611..f52bd780bcec 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -1168,21 +1168,34 @@ def _get_non_default_generation_parameters(self) -> dict[str, Any]:
 
         return non_default_generation_parameters
 
-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=None, encoder=None) -> "PretrainedConfig":
         """
-        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
-        itself. On specific composite models, it is under a set of valid names.
+        Returns the text config related to the text input (encoder) or text output (decoder) of the model. The
+        `decoder` and `encoder` input arguments can be used to specify which end of the model we are interested in,
+        which is useful on models that have both text input and output modalities.
+
+        There are three possible outcomes of using this method:
+            1. On most models, it returns the original config instance itself.
+            2. On newer (2024+) composite models, it returns the text section of the config, which is nested under a set
+               of valid names.
+            3. On older (2023-) composite models, it discards decoder-only parameters when `encoder=True` and vice-versa.
 
         Args:
-            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+            decoder (`Optional[bool]`, *optional*):
                 If set to `True`, then only search for decoder config names.
+            encoder (`Optional[bool]`, *optional*):
+                If set to `True`, then only search for encoder config names.
         """
+        return_both = decoder == encoder  # both unset or both set -> search all possible names
+
         decoder_possible_text_config_names = ("decoder", "generator", "text_config")
         encoder_possible_text_config_names = ("text_encoder",)
-        if decoder:
+        if return_both:
+            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+        elif decoder:
             possible_text_config_names = decoder_possible_text_config_names
         else:
-            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+            possible_text_config_names = encoder_possible_text_config_names
 
         valid_text_config_names = []
         for text_config_name in possible_text_config_names:
@@ -1194,12 +1207,27 @@ def get_text_config(self, decoder=False) -> "PretrainedConfig":
         if len(valid_text_config_names) > 1:
             raise ValueError(
                 f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
-                "case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
+                "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
+                "e.g. `text_config = config.sub_config_name`"
             )
         elif len(valid_text_config_names) == 1:
             config_to_return = getattr(self, valid_text_config_names[0])
         else:
             config_to_return = self
+
+        # handle legacy models with flat config structure, when we only want one of the configs
+        if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
+            config_to_return = copy.deepcopy(config_to_return)
+            prefix_to_discard = "encoder" if decoder else "decoder"
+            for key in config_to_return.to_dict():
+                if key.startswith(prefix_to_discard):
+                    delattr(config_to_return, key)
+            # old encoder/decoder models may use "encoder_layers"/"decoder_layers" instead of "num_hidden_layers"
+            if decoder and hasattr(config_to_return, "decoder_layers"):
+                config_to_return.num_hidden_layers = config_to_return.decoder_layers
+            elif encoder and hasattr(config_to_return, "encoder_layers"):
+                config_to_return.num_hidden_layers = config_to_return.encoder_layers
+
         return config_to_return
 
     @classmethod
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 9ba3e2a6d277..864d68d7b646 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1844,12 +1844,19 @@ def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len:
         )
 
         if need_new_cache:
-            cache_kwargs = {"config": self.config, "max_cache_len": max_cache_len, "offloading": offload_cache}
-            self._cache = StaticCache(**cache_kwargs)
+            self_attention_cache_kwargs = {
+                "config": self.config.get_text_config(decoder=True),
+                "max_cache_len": max_cache_len,
+                "offloading": offload_cache,
+            }
+            self._cache = StaticCache(**self_attention_cache_kwargs)
             if requires_cross_attention_cache:
-                encoder_kwargs = cache_kwargs.copy()
-                encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
-                self._cache = EncoderDecoderCache(self._cache, StaticCache(**encoder_kwargs))
+                cross_attention_cache_kwargs = {
+                    "config": self.config.get_text_config(encoder=True),
+                    "max_cache_len": model_kwargs["encoder_outputs"][0].shape[1],
+                    "offloading": offload_cache,
+                }
+                self._cache = EncoderDecoderCache(self._cache, StaticCache(**cross_attention_cache_kwargs))
         else:
             self._cache.reset()
         return self._cache
diff --git a/src/transformers/models/colqwen2/configuration_colqwen2.py b/src/transformers/models/colqwen2/configuration_colqwen2.py
index bab31fae74fe..d9a42df4c97e 100644
--- a/src/transformers/models/colqwen2/configuration_colqwen2.py
+++ b/src/transformers/models/colqwen2/configuration_colqwen2.py
@@ -87,8 +87,8 @@ def __init__(
         self.initializer_range = initializer_range
         super().__init__(**kwargs)
 
-    def get_text_config(self, decoder=False) -> PretrainedConfig:
-        return self.vlm_config.get_text_config(decoder=decoder)
+    def get_text_config(self, *args, **kwargs) -> PretrainedConfig:
+        return self.vlm_config.get_text_config(*args, **kwargs)
 
 
 __all__ = ["ColQwen2Config"]
diff --git a/src/transformers/models/dia/configuration_dia.py b/src/transformers/models/dia/configuration_dia.py
index 90ace73b3c96..d4dec60b3e48 100644
--- a/src/transformers/models/dia/configuration_dia.py
+++ b/src/transformers/models/dia/configuration_dia.py
@@ -368,7 +368,7 @@ def __init__(
             **kwargs,
         )
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         """Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
         return self.decoder_config
 
diff --git a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
index 552249f0ed2b..5df1b10a6528 100644
--- a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
@@ -1073,7 +1073,7 @@ def __init__(
 
         super().__init__(**kwargs)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         """
         Returns the config that is meant to be used with text IO. On most models, it is the original config instance
         itself. On specific composite models, it is under a set of valid names.
@@ -1085,7 +1085,7 @@ def get_text_config(self, decoder=False):
         # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
         # except for Qwen yet. This has to be generalized if more deeply nested configs are
         # added. NOTE: currently method used only by vLLM
-        return self.thinker_config.get_text_config()
+        return self.thinker_config.get_text_config(*args, **kwargs)
 
 
 __all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"]
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index e65d0597f197..2ef432f7e171 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -1108,7 +1108,7 @@ def __init__(
 
         super().__init__(**kwargs)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         """
         Returns the config that is meant to be used with text IO. On most models, it is the original config instance
         itself. On specific composite models, it is under a set of valid names.
@@ -1120,7 +1120,7 @@ def get_text_config(self, decoder=False):
         # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
         # except for Qwen yet. This has to be generalized if more deeply nested configs are
         # added. NOTE: currently method used only by vLLM
-        return self.thinker_config.get_text_config()
+        return self.thinker_config.get_text_config(*args, **kwargs)
 
 
 class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
diff --git a/src/transformers/models/t5gemma/configuration_t5gemma.py b/src/transformers/models/t5gemma/configuration_t5gemma.py
index bc195d562f2c..86e367413ace 100644
--- a/src/transformers/models/t5gemma/configuration_t5gemma.py
+++ b/src/transformers/models/t5gemma/configuration_t5gemma.py
@@ -323,9 +323,8 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         # Always return self, regardless of the decoder option.
-        del decoder
         return self
 
 
diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py
index ebddaa5c4521..1d74fe8b33f6 100644
--- a/src/transformers/models/t5gemma/modular_t5gemma.py
+++ b/src/transformers/models/t5gemma/modular_t5gemma.py
@@ -213,9 +213,8 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, decoder=False):
+    def get_text_config(self, *args, **kwargs):
         # Always return self, regardless of the decoder option.
-        del decoder
         return self
 
 
diff --git a/tests/utils/test_configuration_utils.py b/tests/utils/test_configuration_utils.py
index 2bfd49399390..dac7669dd797 100644
--- a/tests/utils/test_configuration_utils.py
+++ b/tests/utils/test_configuration_utils.py
@@ -24,7 +24,7 @@
 from huggingface_hub import HfFolder
 from requests.exceptions import HTTPError
 
-from transformers import AutoConfig, BertConfig, GPT2Config
+from transformers import AutoConfig, BertConfig, Florence2Config, GPT2Config
 from transformers.configuration_utils import PretrainedConfig
 from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test
 
@@ -300,3 +300,35 @@ def test_loading_config_do_not_raise_future_warnings(self):
         with warnings.catch_warnings():
             warnings.simplefilter("error")
             PretrainedConfig.from_pretrained("bert-base-uncased")
+
+    def test_get_text_config(self):
+        """Tests the `get_text_config` method."""
+        # 1. model with only text input -> returns the original config instance
+        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        self.assertEqual(config.get_text_config(), config)
+        self.assertEqual(config.get_text_config(decoder=True), config)
+
+        # 2. composite model (VLM) -> returns the text component
+        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlavaForConditionalGeneration")
+        self.assertEqual(config.get_text_config(), config.text_config)
+        self.assertEqual(config.get_text_config(decoder=True), config.text_config)
+
+        # 3. corner case: composite model whose sub-config is an old composite model (should behave as above)
+        config = Florence2Config()
+        self.assertEqual(config.get_text_config(), config.text_config)
+        self.assertEqual(config.get_text_config(decoder=True), config.text_config)
+
+        # 4. old composite model -> may remove components based on the `decoder` or `encoder` argument
+        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bart")
+        self.assertEqual(config.get_text_config(), config)
+        # both encoder_layers and decoder_layers exist
+        self.assertTrue(getattr(config, "encoder_layers", None) is not None)
+        self.assertTrue(getattr(config, "decoder_layers", None) is not None)
+        decoder_config = config.get_text_config(decoder=True)
+        self.assertNotEqual(decoder_config, config)
+        self.assertEqual(decoder_config.num_hidden_layers, config.decoder_layers)
+        self.assertTrue(getattr(decoder_config, "encoder_layers", None) is None)  # encoder_layers is removed
+        encoder_config = config.get_text_config(encoder=True)
+        self.assertNotEqual(encoder_config, config)
+        self.assertEqual(encoder_config.num_hidden_layers, config.encoder_layers)
+        self.assertTrue(getattr(encoder_config, "decoder_layers", None) is None)  # decoder_layers is removed
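Below is a minimal usage sketch of the behaviour introduced by this patch, mirroring the new test above. It is not part of the diff: it assumes the patch is applied, and `LlavaConfig`/`BartConfig` are used purely as illustrative stand-ins for a newer nested composite config and a legacy flat encoder-decoder config.

```python
# Sketch of the updated `get_text_config` API (assumes this patch is applied).
from transformers import BartConfig, LlavaConfig

# Newer composite models nest their text config: the nested section is returned directly.
vlm_config = LlavaConfig()
assert vlm_config.get_text_config() is vlm_config.text_config
assert vlm_config.get_text_config(decoder=True) is vlm_config.text_config

# Legacy flat encoder-decoder configs are pruned on demand: requesting one side discards the
# other side's parameters and maps `encoder_layers`/`decoder_layers` onto `num_hidden_layers`.
bart_config = BartConfig(encoder_layers=6, decoder_layers=3)
decoder_config = bart_config.get_text_config(decoder=True)
assert decoder_config.num_hidden_layers == 3
assert getattr(decoder_config, "encoder_layers", None) is None

encoder_config = bart_config.get_text_config(encoder=True)
assert encoder_config.num_hidden_layers == 6
assert getattr(encoder_config, "decoder_layers", None) is None

# With neither (or both) flags set, the original config instance is returned unchanged.
assert bart_config.get_text_config() is bart_config
```

This is also why `_get_cache` above can pass `get_text_config(decoder=True)` to the self-attention `StaticCache` and `get_text_config(encoder=True)` to the cross-attention cache inside `EncoderDecoderCache`: each cache is sized from the parameters of the side it actually serves.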