Avoid circular import: move cleanup method to training.py (#932)
fhieber authored Feb 5, 2021
1 parent f2d5f57 commit c3870e3
Showing 7 changed files with 92 additions and 99 deletions.
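
For context, a minimal sketch of the import cycle the commit title refers to. The utils-to-average edge is visible in the diff below (utils.py drops "from . import average"); that average imports utils back, directly or transitively, is an assumption inferred from the commit message rather than shown here.

# Sketch only: simplified module layout, not verbatim sockeye code.
# Before this commit:
#   sockeye/utils.py    : from . import average   # needed only by cleanup_params_files
#   sockeye/average.py  : imports utils (directly or transitively),
#                         closing the cycle utils -> average -> utils
# After this commit:
#   sockeye/training.py : from . import average / from . import utils
#                         and owns cleanup_params_files; nothing imports
#                         training back, so the cycle is gone.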
11 changes: 6 additions & 5 deletions sockeye/arguments.py
@@ -24,7 +24,6 @@

from . import constants as C
from . import data_io
from . import utils


class ConfigArgumentParser(argparse.ArgumentParser):
@@ -1085,26 +1084,28 @@ def add_training_args(params):
action="store_true",
help='In addition to keeping the last n params files, also keep params from checkpoint 0.')


train_params.add_argument('--cache-last-best-params',
required=False,
type=int,
default=0,
help='Cache the last n best params files, as distinct from the last n in sequence. Use 0 or negative to disable. Default: %(default)s')
help='Cache the last n best params files, as distinct from the last n in sequence. '
'Use 0 or negative to disable. Default: %(default)s')

train_params.add_argument('--cache-strategy',
required=False,
type=str,
default=C.AVERAGE_BEST,
choices=C.AVERAGE_CHOICES,
help='Strategy to use when deciding which are the "best" params files. Default: %(default)s')
help='Strategy to use when deciding which are the "best" params files. '
'Default: %(default)s')

train_params.add_argument('--cache-metric',
required=False,
type=str,
default=C.PERPLEXITY,
choices=C.METRICS,
help='Metric to use when deciding which are the "best" params files. Default: %(default)s')
help='Metric to use when deciding which are the "best" params files. '
'Default: %(default)s')

train_params.add_argument('--dry-run',
action='store_true',
6 changes: 3 additions & 3 deletions sockeye/average.py
@@ -75,15 +75,15 @@ def find_checkpoints(model_path: str, size=4, strategy="best", metric: str = C.P
param_path = os.path.join(model_path, C.PARAMS_NAME)
points = [(value, checkpoint) for value, checkpoint in points if os.path.exists(param_path % checkpoint)]

if strategy == "best":
if strategy == C.AVERAGE_BEST:
# N best scoring points
top_n = strategy_best(points, size, maximize)

elif strategy == "last":
elif strategy == C.AVERAGE_LAST:
# N sequential points ending with overall best
top_n = strategy_last(points, size, maximize)

elif strategy == "lifespan":
elif strategy == C.AVERAGE_LIFESPAN:
# Track lifespan of every "new best" point
# Points dominated by a previous better point have lifespan 0
top_n = strategy_lifespan(points, size, maximize)
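
As a hedged aside (assumes a working sockeye installation), the three strategies can be compared on a toy metrics series. The points use the same (value, checkpoint) shape that find_checkpoints builds above, and, as in cleanup_params_files further down, the checkpoint id is taken from index 1 of each returned entry.

import sockeye.average as average

# Toy validation points: (perplexity, checkpoint); lower perplexity is better.
points = [(10.0, 1), (8.0, 2), (9.0, 3), (7.5, 4), (7.9, 5)]
maximize = False

best = average.strategy_best(points, 3, maximize)      # 3 best-scoring checkpoints
last = average.strategy_last(points, 3, maximize)      # 3 sequential points ending with the overall best
life = average.strategy_lifespan(points, 3, maximize)  # "new best" points ranked by lifespan

print([p[1] for p in best], [p[1] for p in last], [p[1] for p in life])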
4 changes: 1 addition & 3 deletions sockeye/test_utils.py
@@ -18,11 +18,9 @@
import sys
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List
from unittest.mock import patch

import numpy as np

import sockeye.average
import sockeye.checkpoint_decoder
import sockeye.constants as C
78 changes: 69 additions & 9 deletions sockeye/training.py
@@ -14,6 +14,7 @@
"""
Code for training
"""
import glob
import logging
import os
import pickle
@@ -23,14 +23,13 @@
from collections import deque
from dataclasses import dataclass
from math import sqrt
from typing import Callable, Dict, List, Optional, Iterable, Tuple, Union

from pathlib import Path
from typing import Callable, Dict, List, Optional, Iterable, Tuple, Union, Set

import mxnet as mx
import numpy as np
from mxnet.contrib import amp

from . import average
from . import constants as C
from . import data_io
from . import horovod_mpi
@@ -565,9 +565,9 @@ def _save_params(self):
Saves model parameters at current checkpoint and optionally cleans up older parameter files to save disk space.
"""
self.model.save_parameters(self.current_params_fname)
utils.cleanup_params_files(self.config.output_dir, self.config.max_params_files_to_keep, self.state.checkpoint,
self.state.best_checkpoint, self.config.keep_initializations,
self.config.max_params_files_to_cache, self.config.cache_metric, self.config.cache_strategy)
cleanup_params_files(self.config.output_dir, self.config.max_params_files_to_keep, self.state.checkpoint,
self.state.best_checkpoint, self.config.keep_initializations,
self.config.max_params_files_to_cache, self.config.cache_metric, self.config.cache_strategy)

def _save_trainer_states(self, fname):
trainer_save_states_no_dump_optimizer(self.trainer, fname)
@@ -701,9 +701,9 @@ def _cleanup(self, keep_training_state=False):
"""
Cleans parameter files, training state directory and waits for remaining decoding processes.
"""
utils.cleanup_params_files(self.config.output_dir, self.config.max_params_files_to_keep,
self.state.checkpoint, self.state.best_checkpoint, self.config.keep_initializations,
self.config.max_params_files_to_cache, self.config.cache_metric, self.config.cache_strategy)
cleanup_params_files(self.config.output_dir, self.config.max_params_files_to_keep,
self.state.checkpoint, self.state.best_checkpoint, self.config.keep_initializations,
self.config.max_params_files_to_cache, self.config.cache_metric, self.config.cache_strategy)

if not keep_training_state:
if os.path.exists(self.training_state_dirname):
@@ -927,3 +927,63 @@ def trainer_save_states_no_dump_optimizer(trainer: mx.gluon.Trainer, fname: str)
else:
with open(fname, 'wb') as fout:
fout.write(trainer._updaters[0].get_states(dump_optimizer=False))


def cleanup_params_files(output_folder: str, max_to_keep: int, checkpoint: int, best_checkpoint: int, keep_first: bool,
max_params_files_to_cache: int, cache_metric: str, cache_strategy: str):
"""
Deletes oldest parameter files from a model folder.
:param output_folder: Folder where param files are located.
:param max_to_keep: Maximum number of files to keep, negative to keep all.
:param checkpoint: Current checkpoint (i.e. index of last params file created).
:param best_checkpoint: Best checkpoint. The parameter file corresponding to this checkpoint will not be deleted.
:param keep_first: Don't delete the first checkpoint.
:param max_params_files_to_cache: Maximum number of best param files to cache.
:param cache_metric: Metric to determine best param files.
:param cache_strategy: Strategy to select 'best' param files.
"""
if max_to_keep <= 0:
return

# make sure we keep N best params files from .metrics file according to strategy.
top_n: Set[int] = set()
metrics_path = os.path.join(output_folder, C.METRICS_NAME)

if max_params_files_to_cache > 0 and os.path.exists(metrics_path):
maximize = C.METRIC_MAXIMIZE[cache_metric]
points = utils.get_validation_metric_points(model_path=output_folder, metric=cache_metric)

if cache_strategy == C.AVERAGE_BEST:
# N best scoring points
top = average.strategy_best(points, max_params_files_to_cache, maximize)

elif cache_strategy == C.AVERAGE_LAST:
# N sequential points ending with overall best
top = average.strategy_last(points, max_params_files_to_cache, maximize)

elif cache_strategy == C.AVERAGE_LIFESPAN:
# Track lifespan of every "new best" point
# Points dominated by a previous better point have lifespan 0
top = average.strategy_lifespan(points, max_params_files_to_cache, maximize)
else:
raise RuntimeError("Unknown strategy, options are: %s" % C.AVERAGE_CHOICES)

top_n = set([x[1] for x in top])

# get rid of params files that are neither among the latest, nor among the best
existing_files = glob.glob(os.path.join(output_folder, C.PARAMS_PREFIX + "*"))
params_name_with_dir = os.path.join(output_folder, C.PARAMS_NAME)

for n in range(1 if keep_first else 0, max(1, checkpoint - max_to_keep + 1)):
if n != best_checkpoint:
param_fname_n = params_name_with_dir % n
if param_fname_n in existing_files and n not in top_n:
try:
os.remove(param_fname_n)
except FileNotFoundError:
# This can occur on file systems with higher latency,
# such as distributed file systems. While repeated
# occurrences of this warning may indicate a problem, seeing
# one or two warnings during training is usually fine.
logger.warning('File has already been removed: %s', param_fname_n)
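
A hedged usage sketch of the relocated function (assumes a working sockeye installation), mirroring the unit test further down in this diff. With no .metrics file in the folder, the cache-best logic is a no-op, so only the newest max_to_keep checkpoints and the given best checkpoint survive.

import os
import tempfile

import sockeye.constants as C
import sockeye.training

with tempfile.TemporaryDirectory() as tmp_dir:
    # Create empty parameter files for checkpoints 1..40.
    for n in range(1, 41):
        open(os.path.join(tmp_dir, C.PARAMS_NAME % n), "w").close()
    # Keep the last 5 params files; checkpoint 17 is "best"; checkpoint 0 is not kept.
    sockeye.training.cleanup_params_files(tmp_dir, 5, 40, 17, False, 8, "perplexity", "best")
    print(sorted(os.listdir(tmp_dir)))  # params for checkpoints 17 and 36-40 remain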
65 changes: 0 additions & 65 deletions sockeye/utils.py
@@ -16,7 +16,6 @@
"""
import binascii
import errno
import glob
import gzip
import itertools
import logging
@@ -37,7 +36,6 @@
from . import __version__, constants as C
from . import horovod_mpi
from .log import log_sockeye_version, log_mxnet_version
from . import average

logger = logging.getLogger(__name__)

@@ -689,69 +687,6 @@ def metric_value_is_better(new: float, old: float, metric: str) -> bool:
return new < old


def cleanup_params_files(output_folder: str, max_to_keep: int, checkpoint: int, best_checkpoint: int, keep_first: bool,
max_params_files_to_cache: int, cache_metric: str, cache_strategy: str):
"""
Deletes oldest parameter files from a model folder.
:param output_folder: Folder where param files are located.
:param max_to_keep: Maximum number of files to keep, negative to keep all.
:param checkpoint: Current checkpoint (i.e. index of last params file created).
:param best_checkpoint: Best checkpoint. The parameter file corresponding to this checkpoint will not be deleted.
:param keep_first: Don't delete the first checkpoint.
"""
if max_to_keep <= 0:
return

#
# make sure we keep N best params files from .metrics file according to strategy.
#

top_n : Set[int] = set()
metrics_path = os.path.join(output_folder, C.METRICS_NAME)

if max_params_files_to_cache > 0 and os.path.exists(metrics_path):
maximize = C.METRIC_MAXIMIZE[cache_metric]
points = get_validation_metric_points(model_path=output_folder, metric=cache_metric)

if cache_strategy == C.AVERAGE_BEST:
# N best scoring points
top = average.strategy_best(points, max_params_files_to_cache, maximize)

elif cache_strategy == C.AVERAGE_LAST:
# N sequential points ending with overall best
top = average.strategy_last(points, max_params_files_to_cache, maximize)

elif cache_strategy == C.AVERAGE_LIFESPAN:
# Track lifespan of every "new best" point
# Points dominated by a previous better point have lifespan 0
top = average.strategy_lifespan(points, max_params_files_to_cache, maximize)
else:
raise RuntimeError("Unknown strategy, options are: %s" % (C.AVERAGE_CHOICES))

top_n = set([x[1] for x in top])

#
# get rid of params files that are neither among the latest, nor among the best
#

existing_files = glob.glob(os.path.join(output_folder, C.PARAMS_PREFIX + "*"))
params_name_with_dir = os.path.join(output_folder, C.PARAMS_NAME)

for n in range(1 if keep_first else 0, max(1, checkpoint - max_to_keep + 1)):
if n != best_checkpoint:
param_fname_n = params_name_with_dir % n
if param_fname_n in existing_files and n not in top_n:
try:
os.remove(param_fname_n)
except FileNotFoundError:
# This can be occur on file systems with higher latency,
# such as distributed file systems. While repeated
# occurrences of this warning may indicate a problem, seeing
# one or two warnings during training is usually fine.
logger.warning('File has already been removed: %s', param_fname_n)


def split(data: mx.nd.NDArray,
num_outputs: int,
axis: int = 1,
22 changes: 11 additions & 11 deletions test/unit/test_lexicon.py
@@ -41,7 +41,7 @@ def test_topk_lexicon():
lex.create(input_lex_path, k)

# Test against known lexicon
expected = np.zeros((len(C.VOCAB_SYMBOLS) + len(vocab_list), k), dtype=np.int)
expected = np.zeros((len(C.VOCAB_SYMBOLS) + len(vocab_list), k), dtype=np.int32)
# a -> special + a b
expected[len(C.VOCAB_SYMBOLS), :2] = [len(C.VOCAB_SYMBOLS), len(C.VOCAB_SYMBOLS) + 1]
# b -> special + b
@@ -56,30 +56,30 @@
assert np.all(lex.lex == expected_sorted)

# Test lookup
trg_ids = lex.get_trg_ids(np.array([[vocab["a"], vocab["c"]]], dtype=np.int))
expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["a", "b"]], dtype=np.int)
trg_ids = lex.get_trg_ids(np.array([[vocab["a"], vocab["c"]]], dtype=np.int32))
expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["a", "b"]], dtype=np.int32)
assert np.all(trg_ids == expected)

trg_ids = lex.get_trg_ids(np.array([[vocab["b"]]], dtype=np.int))
expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["b"]], dtype=np.int)
trg_ids = lex.get_trg_ids(np.array([[vocab["b"]]], dtype=np.int32))
expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["b"]], dtype=np.int32)
assert np.all(trg_ids == expected)

trg_ids = lex.get_trg_ids(np.array([[vocab["c"]]], dtype=np.int))
expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS], dtype=np.int)
trg_ids = lex.get_trg_ids(np.array([[vocab["c"]]], dtype=np.int32))
expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS], dtype=np.int32)
assert np.all(trg_ids == expected)

# Test load with smaller k
small_k = k - 1
lex.load(json_lex_path, k=small_k)
assert lex.lex.shape[1] == small_k
trg_ids = lex.get_trg_ids(np.array([[vocab["a"]]], dtype=np.int))
expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["a"]], dtype=np.int)
trg_ids = lex.get_trg_ids(np.array([[vocab["a"]]], dtype=np.int32))
expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["a"]], dtype=np.int32)
assert np.all(trg_ids == expected)

# Test load with larger k
large_k = k + 1
lex.load(json_lex_path, k=large_k)
assert lex.lex.shape[1] == k
trg_ids = lex.get_trg_ids(np.array([[vocab["a"], vocab["c"]]], dtype=np.int))
expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["a", "b"]], dtype=np.int)
trg_ids = lex.get_trg_ids(np.array([[vocab["a"], vocab["c"]]], dtype=np.int32))
expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["a", "b"]], dtype=np.int32)
assert np.all(trg_ids == expected)
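
The np.int to np.int32 changes above are presumably to move off NumPy's np.int, which was only an alias for the builtin int and was deprecated in NumPy 1.20 (and later removed in 1.24). A minimal standalone illustration of the replacement:

import numpy as np

# Explicit-width integer dtype instead of the deprecated np.int alias.
a = np.zeros((2, 3), dtype=np.int32)
assert a.dtype == np.int32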
5 changes: 2 additions & 3 deletions test/unit/test_params.py
@@ -23,15 +23,14 @@
import sockeye.model
import sockeye.training
import sockeye.constants as C
import sockeye.utils


def test_cleanup_param_files():
with tempfile.TemporaryDirectory() as tmp_dir:
for n in itertools.chain(range(1, 20, 2), range(21, 41)):
# Create empty files
open(os.path.join(tmp_dir, C.PARAMS_NAME % n), "w").close()
sockeye.utils.cleanup_params_files(tmp_dir, 5, 40, 17, False, 8, "perplexity", "best")
sockeye.training.cleanup_params_files(tmp_dir, 5, 40, 17, False, 8, "perplexity", "best")

expectedSurviving = set([os.path.join(tmp_dir, C.PARAMS_NAME % n)
for n in [17, 36, 37, 38, 39, 40]])
@@ -44,7 +43,7 @@ def test_cleanup_param_files_keep_first():
for n in itertools.chain(range(0, 20, 2), range(21, 41)):
# Create empty files
open(os.path.join(tmp_dir, C.PARAMS_NAME % n), "w").close()
sockeye.utils.cleanup_params_files(tmp_dir, 5, 40, 16, True, 8, "perplexity", "best")
sockeye.training.cleanup_params_files(tmp_dir, 5, 40, 16, True, 8, "perplexity", "best")

expectedSurviving = set([os.path.join(tmp_dir, C.PARAMS_NAME % n)
for n in [0, 16, 36, 37, 38, 39, 40]])
