diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 6c312b7a5a3f..2ffb4ae41b33 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ b/src/diffusers/models/transformers/transformer_cosmos.py @@ -186,9 +186,9 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) # 4. Prepare for GQA - query_idx = torch.tensor(query.size(3), device=query.device) - key_idx = torch.tensor(key.size(3), device=key.device) - value_idx = torch.tensor(value.size(3), device=value.device) + query_idx = query.size(3) + key_idx = key.size(3) + value_idx = value.size(3) key = key.repeat_interleave(query_idx // key_idx, dim=3) value = value.repeat_interleave(query_idx // value_idx, dim=3)