Skip to content

Commit

Permalink
Simplify immutable TrainingArgs fix using dataclasses.replace (#682)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen authored Aug 24, 2023
1 parent b095245 commit 1c27224
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import FrozenInstanceError
from dataclasses import FrozenInstanceError, replace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -133,14 +133,7 @@ def __init__(
try: # for bc before https://github.com/huggingface/transformers/pull/25435
args.remove_unused_columns = False
except FrozenInstanceError:
args_dict = args.to_dict()
args_dict["remove_unused_columns"] = False

new_args = TrainingArguments(
**args_dict,
)

args = new_args
args = replace(args, remove_unused_columns=False)
# warn users
warnings.warn(
"When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
Expand Down

0 comments on commit 1c27224

Please sign in to comment.