@@ -96,38 +96,45 @@ def fix_valuehead_checkpoint(
96
96
97
97
98
98
class FixValueHeadModelCallback (TrainerCallback ):
99
+ r"""
100
+ A callback for fixing the checkpoint for valuehead models.
101
+ """
102
+
99
103
@override
100
104
def on_save (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
101
105
r"""
102
106
Event called after a checkpoint save.
103
107
"""
104
108
if args .should_save :
109
+ output_dir = os .path .join (args .output_dir , "{}-{}" .format (PREFIX_CHECKPOINT_DIR , state .global_step ))
105
110
fix_valuehead_checkpoint (
106
- model = kwargs .pop ("model" ),
107
- output_dir = os .path .join (args .output_dir , "{}-{}" .format (PREFIX_CHECKPOINT_DIR , state .global_step )),
108
- safe_serialization = args .save_safetensors ,
111
+ model = kwargs .pop ("model" ), output_dir = output_dir , safe_serialization = args .save_safetensors
109
112
)
110
113
111
114
112
115
class SaveProcessorCallback (TrainerCallback ):
116
+ r"""
117
+ A callback for saving the processor.
118
+ """
119
+
113
120
def __init__ (self , processor : "ProcessorMixin" ) -> None :
114
- r"""
115
- Initializes a callback for saving the processor.
116
- """
117
121
self .processor = processor
118
122
123
+ @override
124
+ def on_save (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
125
+ if args .should_save :
126
+ output_dir = os .path .join (args .output_dir , "{}-{}" .format (PREFIX_CHECKPOINT_DIR , state .global_step ))
127
+ getattr (self .processor , "image_processor" ).save_pretrained (output_dir )
128
+
119
129
@override
120
130
def on_train_end (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
121
- r"""
122
- Event called at the end of training.
123
- """
124
131
if args .should_save :
125
132
getattr (self .processor , "image_processor" ).save_pretrained (args .output_dir )
126
133
127
134
128
135
class PissaConvertCallback (TrainerCallback ):
129
136
r"""
130
- Initializes a callback for converting the PiSSA adapter to a normal one.
137
+ A callback for converting the PiSSA adapter to a normal one.
131
138
"""
132
139
133
140
@override
@@ -147,9 +154,6 @@ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", contr
147
154
148
155
@override
149
156
def on_train_end (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
150
- r"""
151
- Event called at the end of training.
152
- """
153
157
if args .should_save :
154
158
model = kwargs .pop ("model" )
155
159
pissa_init_dir = os .path .join (args .output_dir , "pissa_init" )
@@ -177,21 +181,22 @@ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control
177
181
178
182
179
183
class LogCallback (TrainerCallback ):
184
+ r"""
185
+ A callback for logging training and evaluation status.
186
+ """
187
+
180
188
def __init__ (self ) -> None :
181
- r"""
182
- Initializes a callback for logging training and evaluation status.
183
- """
184
- """ Progress """
189
+ # Progress
185
190
self .start_time = 0
186
191
self .cur_steps = 0
187
192
self .max_steps = 0
188
193
self .elapsed_time = ""
189
194
self .remaining_time = ""
190
195
self .thread_pool : Optional ["ThreadPoolExecutor" ] = None
191
- """ Status """
196
+ # Status
192
197
self .aborted = False
193
198
self .do_train = False
194
- """ Web UI """
199
+ # Web UI
195
200
self .webui_mode = os .environ .get ("LLAMABOARD_ENABLED" , "0" ).lower () in ["true" , "1" ]
196
201
if self .webui_mode :
197
202
signal .signal (signal .SIGABRT , self ._set_abort )
@@ -233,9 +238,6 @@ def _close_thread_pool(self) -> None:
233
238
234
239
@override
235
240
def on_init_end (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
236
- r"""
237
- Event called at the end of the initialization of the `Trainer`.
238
- """
239
241
if (
240
242
args .should_save
241
243
and os .path .exists (os .path .join (args .output_dir , TRAINER_LOG ))
@@ -246,60 +248,39 @@ def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control:
246
248
247
249
@override
248
250
def on_train_begin (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
249
- r"""
250
- Event called at the beginning of training.
251
- """
252
251
if args .should_save :
253
252
self .do_train = True
254
253
self ._reset (max_steps = state .max_steps )
255
254
self ._create_thread_pool (output_dir = args .output_dir )
256
255
257
256
@override
258
257
def on_train_end (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
259
- r"""
260
- Event called at the end of training.
261
- """
262
258
self ._close_thread_pool ()
263
259
264
260
@override
265
261
def on_substep_end (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
266
- r"""
267
- Event called at the end of an substep during gradient accumulation.
268
- """
269
262
if self .aborted :
270
263
control .should_epoch_stop = True
271
264
control .should_training_stop = True
272
265
273
266
@override
274
267
def on_step_end (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
275
- r"""
276
- Event called at the end of a training step.
277
- """
278
268
if self .aborted :
279
269
control .should_epoch_stop = True
280
270
control .should_training_stop = True
281
271
282
272
@override
283
273
def on_evaluate (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
284
- r"""
285
- Event called after an evaluation phase.
286
- """
287
274
if not self .do_train :
288
275
self ._close_thread_pool ()
289
276
290
277
@override
291
278
def on_predict (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
292
- r"""
293
- Event called after a successful prediction.
294
- """
295
279
if not self .do_train :
296
280
self ._close_thread_pool ()
297
281
298
282
@override
299
283
def on_log (self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs ):
300
- r"""
301
- Event called after logging the last logs.
302
- """
303
284
if not args .should_save :
304
285
return
305
286
@@ -342,9 +323,6 @@ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "Tra
342
323
def on_prediction_step (
343
324
self , args : "TrainingArguments" , state : "TrainerState" , control : "TrainerControl" , ** kwargs
344
325
):
345
- r"""
346
- Event called after a prediction step.
347
- """
348
326
if self .do_train :
349
327
return
350
328
0 commit comments