Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔓 Remove lm_head check in AutoModelForCausalLMWithValueHead #2398

Merged
merged 3 commits into the base branch from the feature branch
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions tests/test_modeling_value_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,6 @@ def test_generate(self, model_name):
# Just check if the generation works
_ = model.generate(input_ids, generation_config=generation_config)

def test_raise_error_not_causallm(self):
# Test with a model without a LM head
model_id = "trl-internal-testing/tiny-GPT2LMHeadModel"
# This should raise a ValueError
with self.assertRaises(ValueError):
pretrained_model = AutoModelForCausalLM.from_pretrained(model_id)
_ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer)

def test_transformers_bf16_kwargs(self):
r"""
Test if the transformers kwargs are correctly passed
Expand All @@ -283,10 +275,11 @@ def test_transformers_bf16_kwargs(self):
for model_name in self.all_model_names:
trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16)

lm_head_namings = self.trl_model_class.lm_head_namings
lm_head_namings = ["lm_head", "embed_out", "output_layer"]

self.assertTrue(
any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings),
"Can't test the model because it doesn't have any of the expected lm_head namings",
)

for lm_head_naming in lm_head_namings:
Expand Down
9 changes: 0 additions & 9 deletions trl/models/modeling_value_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
Class attributes:
- **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
should be set to `transformers.AutoModelForCausalLM` for this class.
- **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the
wrapped model. This is set to `("lm_head", "embed_out", "output_layer")` for this class but can be changed
for other models in the future
- **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
by the `ValueHead` class. Currently, the supported args are:
- **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the
Expand All @@ -86,7 +83,6 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
"""

transformers_parent_class = AutoModelForCausalLM
lm_head_namings = ["lm_head", "embed_out", "output_layer"]
supported_args = (
"summary_dropout_prob",
"v_head_initializer_range",
Expand All @@ -106,12 +102,7 @@ def __init__(self, pretrained_model, **kwargs):
"""
super().__init__(pretrained_model, **kwargs)
v_head_kwargs, _, _ = self._split_kwargs(kwargs)

if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
raise ValueError("The model does not have a language model head, please use a model that has one.")

self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)

self._init_weights(**v_head_kwargs)

def _init_weights(self, **kwargs):
Expand Down