[generate] handle support for cache classes when num enc layers != num dec layers #40277


Open
gante wants to merge 5 commits into main

Conversation

gante (Member) commented on Aug 19, 2025

What does this PR do?

I think this bug has been present since EncoderDecoderCache was added: we naively assumed that # encoder layers == # decoder layers, as we used config.num_hidden_layers (the # of encoder layers in encoder-decoder models) to instantiate both caches 👀

This PR adds the missing logic:

  • we can now specify which part of the config we want to pull with get_text_config() (previously we could only isolate the decoder). We also update num_hidden_layers accordingly, so that the caches see the right number of layers.
  • we pull the right config (encoder vs decoder) before parameterizing EncoderDecoderCache in generate (see the sketch after this list)
  • (adds tests)
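
To make the layer-count asymmetry concrete, here is a minimal sketch that only inspects the config of the checkpoint used in the reproduction below; it is illustrative and not the code added by this PR.

from transformers import AutoConfig

config = AutoConfig.from_pretrained("sshleifer/distilbart-cnn-12-6")

# The encoder and decoder stacks have different depths in this checkpoint.
print(config.encoder_layers)      # 12
print(config.decoder_layers)      # 6

# On `main`, both halves of the EncoderDecoderCache were sized from
# config.num_hidden_layers, which reports the encoder depth for
# encoder-decoder models like BART, so the decoder-side caches ended up
# with the wrong number of layers. With this PR, get_text_config() can
# isolate the encoder or the decoder side (adjusting num_hidden_layers)
# before each cache is built in generate.
print(config.num_hidden_layers)   # 12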

Fixes #40120


Example of a failing checkpoint, taken from #40120 (it needs num encoder layers != num decoder layers):

import torch
from transformers import AutoTokenizer, pipeline

torch_dtype = torch.float16
device_map = "cpu"
model_kwargs = dict(torch_dtype=torch_dtype, device_map=device_map)
model_id = "sshleifer/distilbart-cnn-12-6"  # 12 encoder layers, 6 decoder layers
tokenizer = AutoTokenizer.from_pretrained(model_id)
generator = pipeline(
    "summarization",
    model=model_id,
    tokenizer=tokenizer,
    model_kwargs=model_kwargs,
)
generation_config = generator.model.generation_config
generation_config.do_sample = True
generation_config.use_cache = True
generation_config.temperature = 1.0
generation_config.num_beams = 1
generation_config.max_new_tokens = 100
generation_config.min_new_tokens = 100
generation_config.top_p = 1.0
generation_config.cache_implementation = "static"

prompt = "I like math"

output = generator(prompt, batch_size=1, generation_config=generation_config)  # the cache will have an incorrect number of layers on `main`, but it runs
output = generator(prompt, batch_size=1, generation_config=generation_config)  # crashes on `main`

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

gante (Member, Author) commented on Aug 20, 2025

@Cyrilvallez CI green now 👌

(possibly there are a few more edge cases, but I don't think it's worth going the extra mile to find them all)

manueldeprada (Contributor) left a comment


LGTM, the only comment I have is that it would be clearer to name it cross_attention_cache instead of encoder_cache, and likewise self_attention_cache instead of decoder_cache.
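
For reference, those names match the arguments of EncoderDecoderCache itself; a minimal sketch (illustrative only, with DynamicCache standing in for whatever cache classes generate actually builds):

from transformers import DynamicCache, EncoderDecoderCache

self_attention_cache = DynamicCache()   # K/V of the decoder's self-attention layers
cross_attention_cache = DynamicCache()  # K/V projected from the encoder output
cache = EncoderDecoderCache(self_attention_cache, cross_attention_cache)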

gante (Member, Author) commented on Aug 20, 2025

@manueldeprada like this? (see latest changes)


[For maintainers] Suggested jobs to run (before merge)

run-slow: colqwen2, dia, qwen2_5_omni, t5gemma
