diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 16b1b91697..2810cc8d9f 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -28,7 +28,7 @@
 from fairseq.models.ema import build_ema
 from fairseq.nan_detector import NanDetector
 from fairseq.optim import lr_scheduler
-from fairseq.utils import safe_hasattr
+from fairseq.utils import safe_hasattr, SigtermHandler
 
 logger = logging.getLogger(__name__)
 
@@ -51,6 +51,8 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None):
             )
             cfg = convert_namespace_to_omegaconf(cfg)
 
+        self.sigterm = SigtermHandler()
+
         self.cfg = cfg
         self.task = task
 
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 4d4b350523..cb470e2f4b 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -14,6 +14,7 @@
 import warnings
 from itertools import accumulate
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+import signal
 
 import torch
 import torch.nn.functional as F
@@ -41,6 +42,16 @@
 MANIFOLD_PATH_SEP = "|"
 
 
+class SigtermHandler:
+    stop = False
+
+    def __init__(self):
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        self.stop = True
+
+
 class FileContentsAction(argparse.Action):
     def __init__(self, option_strings, dest, nargs=None, **kwargs):
         if nargs is not None:
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index f771bff654..1f67039594 100644
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -381,6 +381,8 @@ def validate_and_save(
     end_of_epoch: bool,
 ) -> Tuple[List[Optional[float]], bool]:
     num_updates = trainer.get_num_updates()
+    sigstop = trainer.sigterm.stop
+
     max_update = cfg.optimization.max_update or math.inf
 
     # Stopping conditions (and an additional one based on validation loss later
@@ -393,6 +395,11 @@ def validate_and_save(
             f"num_updates: {num_updates} >= max_update: {max_update}"
         )
 
+    if sigstop:
+        logger.info(
+            "Saving due to SIGTERM signal."
+        )
+
     training_time_hours = trainer.cumulative_training_time() / (60 * 60)
     if (
         cfg.optimization.stop_time_hours > 0
@@ -438,7 +445,7 @@ def validate_and_save(
         should_stop |= should_stop_early(cfg, valid_losses[0])
 
     # Save checkpoint
-    if do_save or should_stop:
+    if do_save or should_stop or sigstop:
         cp_path = checkpoint_utils.save_checkpoint(
             cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
         )