From c79b5b044abe69b724541c8d36bca5c5c0b8ecf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 26 Nov 2024 15:07:58 +0000 Subject: [PATCH 1/3] Remove lm_head check in `AutoModelForCausalLMWithValueHead` --- tests/test_modeling_value_head.py | 5 +++-- trl/models/modeling_value_head.py | 9 --------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index ddc4eb850c..c412e7a1b9 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -283,10 +283,11 @@ def test_transformers_bf16_kwargs(self): for model_name in self.all_model_names: trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16) - lm_head_namings = self.trl_model_class.lm_head_namings + lm_head_namings = ["lm_head", "embed_out", "output_layer"] self.assertTrue( - any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) + any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), + "Can't test the model because it doesn't have any of the expected lm_head namings", ) for lm_head_naming in lm_head_namings: diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py index 0797794013..592879ae3e 100644 --- a/trl/models/modeling_value_head.py +++ b/trl/models/modeling_value_head.py @@ -69,9 +69,6 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): Class attributes: - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This should be set to `transformers.AutoModelForCausalLM` for this class. - - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the - wrapped model. This is set to `("lm_head", "embed_out", "output_layer")` for this class but can be changed - for other models in the future - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported by the `ValueHead` class. Currently, the supported args are: - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the @@ -86,7 +83,6 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): """ transformers_parent_class = AutoModelForCausalLM - lm_head_namings = ["lm_head", "embed_out", "output_layer"] supported_args = ( "summary_dropout_prob", "v_head_initializer_range", @@ -106,12 +102,7 @@ def __init__(self, pretrained_model, **kwargs): """ super().__init__(pretrained_model, **kwargs) v_head_kwargs, _, _ = self._split_kwargs(kwargs) - - if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings): - raise ValueError("The model does not have a language model head, please use a model that has one.") - self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) - self._init_weights(**v_head_kwargs) def _init_weights(self, **kwargs): From 24f84a54e36e5d15f905d3b3ff777184a8de58af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 26 Nov 2024 15:20:43 +0000 Subject: [PATCH 2/3] Style --- tests/test_modeling_value_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index c412e7a1b9..fa6edb4c51 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -287,7 +287,7 @@ def test_transformers_bf16_kwargs(self): self.assertTrue( any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), - "Can't test the model because it doesn't have any of the expected lm_head namings", + "Can't test the model because it doesn't have any of the expected lm_head namings", ) for lm_head_naming in lm_head_namings: From 81c544b477aec67a7bd58c22aa4a4b0f67cc056a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 26 Nov 2024 15:38:00 +0000 Subject: [PATCH 3/3] Remove test --- tests/test_modeling_value_head.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index fa6edb4c51..be4932e62f 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -265,14 +265,6 @@ def test_generate(self, model_name): # Just check if the generation works _ = model.generate(input_ids, generation_config=generation_config) - def test_raise_error_not_causallm(self): - # Test with a model without a LM head - model_id = "trl-internal-testing/tiny-GPT2LMHeadModel" - # This should raise a ValueError - with self.assertRaises(ValueError): - pretrained_model = AutoModelForCausalLM.from_pretrained(model_id) - _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer) - def test_transformers_bf16_kwargs(self): r""" Test if the transformers kwargs are correctly passed