diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py index 334221712..2be44fbc1 100644 --- a/src/lighteval/models/transformers/transformers_model.py +++ b/src/lighteval/models/transformers/transformers_model.py @@ -456,13 +456,16 @@ def _init_max_length(self) -> int: return self.config.max_length # Try to get the sequence length from the model config. + text_model_config = self.transformers_config.get_text_config() + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") for attr in seqlen_config_attrs: - if hasattr(self.transformers_config, attr): - return getattr(self.transformers_config, attr) + if hasattr(text_model_config, attr): + return getattr(text_model_config, attr) logger.warning( - "No max_length attribute found in the model config. Using the default max sequence length setting {2048}. It is recomended to set max_length through the model args" + "No max_length attribute found in the model config. Using the default max sequence length setting `2048`. " + "It is recommended to set max_length through the model args: max_length=..." ) return 2048