Added a listener for the SIGTERM signal sent by Slurm, so the model is saved before the job is killed
William-N-Havard committed Jul 31, 2024
1 parent 11b72ca commit 3aa2bf9
Showing 3 changed files with 22 additions and 2 deletions.
4 changes: 3 additions & 1 deletion fairseq/trainer.py
@@ -28,7 +28,7 @@
from fairseq.models.ema import build_ema
from fairseq.nan_detector import NanDetector
from fairseq.optim import lr_scheduler
from fairseq.utils import safe_hasattr
from fairseq.utils import safe_hasattr, SigtermHandler

logger = logging.getLogger(__name__)

@@ -51,6 +51,8 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None):
)
cfg = convert_namespace_to_omegaconf(cfg)

self.sigterm = SigtermHandler()

self.cfg = cfg
self.task = task

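For context, installing the handler in Trainer.__init__ means SIGTERM no longer kills the process outright: Python's default action (terminate) is replaced by a callback that only records that the signal arrived, so the current step can finish and a checkpoint can still be written. A minimal sketch of that behaviour, not fairseq code; os.kill stands in for Slurm signalling the job:

import os
import signal

# Replace the default SIGTERM action (terminate) with a callback that
# only records the signal.
received = []
signal.signal(signal.SIGTERM, lambda signum, frame: received.append(signum))

os.kill(os.getpid(), signal.SIGTERM)  # stand-in for Slurm's warning signal

# The process is still alive; the request was merely recorded for later.
print(received)  # the signal number, e.g. [15]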
11 changes: 11 additions & 0 deletions fairseq/utils.py
@@ -14,6 +14,7 @@
import warnings
from itertools import accumulate
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import signal

import torch
import torch.nn.functional as F
@@ -41,6 +42,16 @@
MANIFOLD_PATH_SEP = "|"


class SigtermHandler:
stop = False

def __init__(self):
signal.signal(signal.SIGTERM, self.exit_gracefully)

def exit_gracefully(self, signum, frame):
self.stop = True


class FileContentsAction(argparse.Action):
def __init__(self, option_strings, dest, nargs=None, **kwargs):
if nargs is not None:
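The class keeps a single boolean that flips to True once SIGTERM is delivered; nothing is raised and the process keeps running. A short usage sketch, assuming a fairseq checkout that contains this commit is importable (otherwise paste the class body from the diff above):

import os
import signal
import time

from fairseq.utils import SigtermHandler  # added by this commit

handler = SigtermHandler()
print(handler.stop)   # False

os.kill(os.getpid(), signal.SIGTERM)  # stand-in for Slurm's warning signal
time.sleep(0.1)       # let the interpreter invoke the Python-level handler
print(handler.stop)   # True

Note that Python only allows installing signal handlers from the main thread of the process.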
9 changes: 8 additions & 1 deletion fairseq_cli/train.py
@@ -381,6 +381,8 @@ def validate_and_save(
end_of_epoch: bool,
) -> Tuple[List[Optional[float]], bool]:
num_updates = trainer.get_num_updates()
sigstop = trainer.sigterm.stop

max_update = cfg.optimization.max_update or math.inf

# Stopping conditions (and an additional one based on validation loss later
@@ -393,6 +395,11 @@
f"num_updates: {num_updates} >= max_update: {max_update}"
)

if sigstop:
logger.info(
f"Saving due to SIGTERM signal."
)

training_time_hours = trainer.cumulative_training_time() / (60 * 60)
if (
cfg.optimization.stop_time_hours > 0
@@ -438,7 +445,7 @@ def validate_and_save(
should_stop |= should_stop_early(cfg, valid_losses[0])

# Save checkpoint
if do_save or should_stop:
if do_save or should_stop or sigstop:
cp_path = checkpoint_utils.save_checkpoint(
cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
)
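In validate_and_save the flag is read once per call, logged, and ORed into the checkpoint decision, so a checkpoint is forced even when neither the save interval nor another stopping condition has been reached. A self-contained sketch of that control flow, using stand-in names rather than the real fairseq signatures:

import logging
from types import SimpleNamespace

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def validate_and_save(trainer, do_save, should_stop):
    sigstop = trainer.sigterm.stop

    if sigstop:
        logger.info("Saving due to SIGTERM signal.")

    # Save checkpoint
    if do_save or should_stop or sigstop:
        logger.info("writing checkpoint")

    return should_stop


# Simulate a trainer whose handler has already caught SIGTERM.
trainer = SimpleNamespace(sigterm=SimpleNamespace(stop=True))
validate_and_save(trainer, do_save=False, should_stop=False)  # still saves

In the hunks shown, sigstop feeds only the logging and the save condition; whether training then stops is still decided by the existing stop conditions.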
