From 0ed81fd75f66d4fc6413e1e9a6985860cf151568 Mon Sep 17 00:00:00 2001 From: Tobias Domhan Date: Mon, 18 Dec 2017 17:36:59 +0100 Subject: [PATCH] Sharded data iterator. (#241) * Sharded data iterator. * Added remaining sockeye/*.py files to typechecked files (#242) * Tests to see we get the right number of batches. * Improved log message about vocabs a little bit * Factored validation iter creation into separate function * Covering prepare data in the system tests. * Writing a data version. --- CHANGELOG.md | 20 +- setup.py | 9 +- sockeye/arguments.py | 227 +++-- sockeye/constants.py | 15 + sockeye/convolution.py | 2 +- sockeye/data_io.py | 1533 ++++++++++++++++++++++-------- sockeye/decoder.py | 5 +- sockeye/encoder.py | 5 +- sockeye/inference.py | 28 +- sockeye/model.py | 12 +- sockeye/prepare_data.py | 76 ++ sockeye/train.py | 179 +++- sockeye/training.py | 27 +- sockeye/utils.py | 108 ++- sockeye/vocab.py | 75 +- test/common.py | 67 +- test/system/test_seq_copy_sys.py | 23 +- test/unit/test_arguments.py | 97 +- test/unit/test_checkpoint.py | 82 -- test/unit/test_data_io.py | 520 +++++++++- test/unit/test_translate.py | 3 +- test/unit/test_utils.py | 40 + typechecked-files | 7 +- 23 files changed, 2354 insertions(+), 806 deletions(-) create mode 100644 sockeye/prepare_data.py delete mode 100644 test/unit/test_checkpoint.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 528d2e3f3..fb255f6b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,17 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [1.15.6] +### Added + - New CLI `sockeye.prepare_data` for preprocessing the training data only once before training, + potentially splitting large datasets into shards. At training time only one shard is loaded into memory at a time, + limiting the maximum memory usage. + +### Changed + - Instead of using the ```--source``` and ```--target``` arguments ```sockeye.train``` now accepts a + ```--prepared-data``` argument pointing to the folder containing the preprocessed and sharded data. Using the raw + training data is still possible and now consumes less memory. + ## [1.15.5] ### Added - Optionally apply query, key and value projections to the source and target hidden vectors in the CNN model @@ -33,8 +44,13 @@ Each version section may have have subsections for: _Added_, _Changed_, _Removed ## [1.15.0] ### Added -- Added support for Swish-1 (SiLU) activation to transformer models ([Ramachandran et al. 2017: Searching for Activation Functions](https://arxiv.org/pdf/1710.05941.pdf), [Elfwing et al. 2017: Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning](https://arxiv.org/pdf/1702.03118.pdf)). Use `--transformer-activation-type swish1`. -- Added support for GELU activation to transformer models ([Hendrycks and Gimpel 2016: Bridging Nonlinearities and Stochastic Regularizers with Gaussian Error Linear Units](https://arxiv.org/pdf/1606.08415.pdf). Use `--transformer-activation-type gelu`. +- Added support for Swish-1 (SiLU) activation to transformer models +([Ramachandran et al. 2017: Searching for Activation Functions](https://arxiv.org/pdf/1710.05941.pdf), +[Elfwing et al. 2017: Sigmoid-Weighted Linear Units for Neural Network Function Approximation +in Reinforcement Learning](https://arxiv.org/pdf/1702.03118.pdf)). Use `--transformer-activation-type swish1`. 
+- Added support for GELU activation to transformer models ([Hendrycks and Gimpel 2016: Bridging Nonlinearities and +Stochastic Regularizers with Gaussian Error Linear Units](https://arxiv.org/pdf/1606.08415.pdf). +Use `--transformer-activation-type gelu`. ## [1.14.3] ### Changed diff --git a/setup.py b/setup.py index 931900b85..fa5041c0f 100644 --- a/setup.py +++ b/setup.py @@ -110,15 +110,16 @@ def get_requirements(filename): entry_points={ 'console_scripts': [ - 'sockeye-train = sockeye.train:main', - 'sockeye-translate = sockeye.translate:main', 'sockeye-average = sockeye.average:main', 'sockeye-embeddings = sockeye.embeddings:main', 'sockeye-evaluate = sockeye.evaluate:main', - 'sockeye-vocab = sockeye.vocab:main', + 'sockeye-extract-parameters = sockeye.extract_parameters:main', 'sockeye-lexicon = sockeye.lexicon:main', - 'sockeye-extract = sockeye.extract_parameters:main', 'sockeye-init-embed = sockeye.init_embedding:main' + 'sockeye-prepare-data = sockeye.prepare_data:main' + 'sockeye-train = sockeye.train:main', + 'sockeye-translate = sockeye.translate:main', + 'sockeye-vocab = sockeye.vocab:main' ], }, diff --git a/sockeye/arguments.py b/sockeye/arguments.py index 9b9ff84c5..39ce6cb76 100644 --- a/sockeye/arguments.py +++ b/sockeye/arguments.py @@ -41,6 +41,22 @@ def check_regular_file(value_to_check): return check_regular_file +def regular_folder() -> Callable: + """ + Returns a method that can be used in argument parsing to check the argument is a directory. + + :return: A method that can be used as a type in argparse. + """ + + def check_regular_directory(value_to_check): + value_to_check = str(value_to_check) + if not os.path.isdir(value_to_check): + raise argparse.ArgumentTypeError("must be a directory.") + return value_to_check + + return check_regular_directory + + def int_greater_or_equal(threshold: int) -> Callable: """ Returns a method that can be used in argument parsing to check that the argument is greater or equal to `threshold`. @@ -235,62 +251,117 @@ def add_logging_args(params): help='Suppress console logging.') -def add_io_args(params): - data_params = params.add_argument_group("Data & I/O") +def add_training_data_args(params, required=False): + params.add_argument(C.TRAINING_ARG_SOURCE, '-s', + required=required, + type=regular_file(), + help='Source side of parallel training data.') + params.add_argument(C.TRAINING_ARG_TARGET, '-t', + required=required, + type=regular_file(), + help='Target side of parallel training data.') + + +def add_validation_data_params(params): + params.add_argument('--validation-source', '-vs', + required=True, + type=regular_file(), + help='Source side of validation data.') + params.add_argument('--validation-target', '-vt', + required=True, + type=regular_file(), + help='Target side of validation data.') + + +def add_prepared_data_args(params): + params.add_argument(C.TRAINING_ARG_PREPARED_DATA, '-d', + type=regular_folder(), + help='Prepared training data directory created through python -m sockeye.prepare_data.') + + +def add_monitoring_args(params): + params.add_argument('--use-tensorboard', + action='store_true', + help='Track metrics through tensorboard. Requires installed tensorboard.') + + params.add_argument('--monitor-pattern', + default=None, + type=str, + help="Pattern to match outputs/weights/gradients to monitor. '.*' monitors everything. 
" + "Default: %(default)s.") + + params.add_argument('--monitor-stat-func', + default=C.STAT_FUNC_DEFAULT, + choices=list(C.MONITOR_STAT_FUNCS.keys()), + help="Statistics function to run on monitored outputs/weights/gradients. " + "Default: %(default)s.") + + +def add_training_output_args(params): + params.add_argument('--output', '-o', + required=True, + help='Folder where model & training results are written to.') + params.add_argument('--overwrite-output', + action='store_true', + help='Delete all contents of the model directory if it already exists.') + + +def add_training_io_args(params): + params = params.add_argument_group("Data & I/O") + + # Unfortunately we must set --source/--target to not required as we either accept these parameters + # or --prepared-data which can not easily be encoded in argparse. + add_training_data_args(params, required=False) + add_prepared_data_args(params) + add_validation_data_params(params) + add_bucketing_args(params) + add_vocab_args(params) + add_training_output_args(params) + add_monitoring_args(params) - data_params.add_argument('--source', '-s', - required=True, - type=regular_file(), - help='Source side of parallel training data.') - data_params.add_argument('--target', '-t', - required=True, - type=regular_file(), - help='Target side of parallel training data.') - data_params.add_argument('--limit', - default=None, - type=int, - help="Maximum number of training sequences to read. Default: %(default)s.") - - data_params.add_argument('--validation-source', '-vs', - required=True, - type=regular_file(), - help='Source side of validation data.') - data_params.add_argument('--validation-target', '-vt', - required=True, - type=regular_file(), - help='Target side of validation data.') - data_params.add_argument('--output', '-o', - required=True, - help='Folder where model & training results are written to.') - data_params.add_argument('--overwrite-output', - action='store_true', - help='Delete all contents of the model directory if it already exists.') - - data_params.add_argument('--source-vocab', - required=False, - default=None, - help='Existing source vocabulary (JSON)') - data_params.add_argument('--target-vocab', - required=False, - default=None, - help='Existing target vocabulary (JSON)') - - data_params.add_argument('--use-tensorboard', - action='store_true', - help='Track metrics through tensorboard. Requires installed tensorboard.') - - data_params.add_argument('--monitor-pattern', - default=None, - type=str, - help="Pattern to match outputs/weights/gradients to monitor. '.*' monitors everything. " - "Default: %(default)s.") +def add_bucketing_args(params): + params.add_argument('--no-bucketing', + action='store_true', + help='Disable bucketing: always unroll the graph to --max-seq-len. Default: %(default)s.') - data_params.add_argument('--monitor-stat-func', - default=C.STAT_FUNC_DEFAULT, - choices=list(C.MONITOR_STAT_FUNCS.keys()), - help="Statistics function to run on monitored outputs/weights/gradients. " - "Default: %(default)s.") + params.add_argument('--bucket-width', + type=int_greater_or_equal(1), + default=10, + help='Width of buckets in tokens. Default: %(default)s.') + + params.add_argument('--max-seq-len', + type=multiple_values(num_values=2, greater_or_equal=1), + default=(100, 100), + help='Maximum sequence length in tokens. Note that the target side will be extended by ' + 'the (beginning of sentence) token, increasing the effective target length. ' + 'Use "x:x" to specify separate values for src&tgt. 
Default: %(default)s.') + + +def add_prepare_data_cli_args(params): + params = params.add_argument_group("Data preparation.") + add_training_data_args(params, required=True) + add_vocab_args(params) + add_bucketing_args(params) + + params.add_argument('--num-samples-per-shard', + default=1000000, + help='The approximate number of samples per shard. Default: %(default)s.') + + params.add_argument('--min-num-shards', + default=1, + type=int_greater_or_equal(1), + help='The minimum number of shards to use, even if they would not ' + 'reach the desired number of samples per shard. Default: %(default)s.') + + params.add_argument('--seed', + type=int, + default=13, + help='Random seed used that makes shard assignments deterministic. Default: %(default)s.') + + params.add_argument('--output', '-o', + required=True, + help='Folder where the prepared and possibly sharded data is written to.') def add_device_args(params): @@ -319,16 +390,29 @@ def add_device_args(params): 'write permissions.') -def add_vocab_args(model_params): - model_params.add_argument('--num-words', - type=multiple_values(num_values=2, greater_or_equal=0), - default=(50000, 50000), - help='Maximum vocabulary size. Use "x:x" to specify separate values for src&tgt. ' - 'Default: %(default)s.') - model_params.add_argument('--word-min-count', - type=multiple_values(num_values=2, greater_or_equal=1), - default=(1, 1), - help='Minimum frequency of words to be included in vocabularies. Default: %(default)s.') +def add_vocab_args(params): + params.add_argument('--source-vocab', + required=False, + default=None, + help='Existing source vocabulary (JSON).') + params.add_argument('--target-vocab', + required=False, + default=None, + help='Existing target vocabulary (JSON).') + params.add_argument(C.VOCAB_ARG_SHARED_VOCAB, + action='store_true', + default=False, + help='Share source and target vocabulary. ' + 'Will be automatically turned on when using weight tying. Default: %(default)s.') + params.add_argument('--num-words', + type=multiple_values(num_values=2, greater_or_equal=0), + default=(50000, 50000), + help='Maximum vocabulary size. Use "x:x" to specify separate values for src&tgt. ' + 'Default: %(default)s.') + params.add_argument('--word-min-count', + type=multiple_values(num_values=2, greater_or_equal=1), + default=(1, 1), + help='Minimum frequency of words to be included in vocabularies. Default: %(default)s.') def add_model_parameters(params): @@ -344,8 +428,6 @@ def add_model_parameters(params): help="Allow misssing parameters when initializing model parameters from file. " "Default: %(default)s.") - add_vocab_args(model_params) - model_params.add_argument('--encoder', choices=C.ENCODERS, default=C.RNN_NAME, @@ -539,12 +621,6 @@ def add_model_parameters(params): help='The type of weight tying. source embeddings=src, target embeddings=trg, ' 'target softmax weight matrix=softmax. Default: %(default)s.') - model_params.add_argument('--max-seq-len', - type=multiple_values(num_values=2, greater_or_equal=1), - default=(100, 100), - help='Maximum sequence length in tokens. ' - 'Use "x:x" to specify separate values for src&tgt. Default: %(default)s.') - model_params.add_argument('--layer-normalization', action="store_true", help="Adds layer normalization before non-linear activations. 
" "This includes MLP attention, RNN decoder state initialization, " @@ -576,13 +652,6 @@ def add_training_args(params): type=str, default='replicate', help=argparse.SUPPRESS) - train_params.add_argument('--no-bucketing', - action='store_true', - help='Disable bucketing: always unroll to the max_len.') - train_params.add_argument('--bucket-width', - type=int_greater_or_equal(1), - default=10, - help='Width of buckets in tokens. Default: %(default)s.') train_params.add_argument('--loss', default=C.CROSS_ENTROPY, @@ -824,7 +893,7 @@ def add_training_args(params): def add_train_cli_args(params): - add_io_args(params) + add_training_io_args(params) add_model_parameters(params) add_training_args(params) add_device_args(params) diff --git a/sockeye/constants.py b/sockeye/constants.py index 5591c397f..dc860399f 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -222,6 +222,12 @@ "keep_last_params"] # Other argument constants +TRAINING_ARG_SOURCE = "--source" +TRAINING_ARG_TARGET = "--target" +TRAINING_ARG_PREPARED_DATA = "--prepared-data" + +VOCAB_ARG_SHARED_VOCAB = "--shared-vocab" + INFERENCE_ARG_INPUT_LONG = "--input" INFERENCE_ARG_INPUT_SHORT = "-i" INFERENCE_ARG_OUTPUT_LONG = "--output" @@ -316,3 +322,12 @@ LARGE_POSITIVE_VALUE = 99999999. LARGE_NEGATIVE_VALUE = -LARGE_POSITIVE_VALUE + +# data sharding +SHARD_NAME = "shard.%05d" +SHARD_SOURCE = SHARD_NAME + ".source" +SHARD_TARGET = SHARD_NAME + ".target" +DATA_CONFIG = "data.config" +PREPARED_DATA_VERSION_FILE = "data.version" +PREPARED_DATA_VERSION = 1 + diff --git a/sockeye/convolution.py b/sockeye/convolution.py index db6289698..3b6492a1f 100644 --- a/sockeye/convolution.py +++ b/sockeye/convolution.py @@ -35,7 +35,7 @@ def __init__(self, kernel_width: int, num_hidden: int, act_type: str = C.GLU, - weight_normalization: bool = False): + weight_normalization: bool = False) -> None: super().__init__() self.kernel_width = kernel_width self.num_hidden = num_hidden diff --git a/sockeye/data_io.py b/sockeye/data_io.py index 2741dd9bd..f49151ee6 100644 --- a/sockeye/data_io.py +++ b/sockeye/data_io.py @@ -15,20 +15,23 @@ Implements data iterators and I/O related functions for sequence-to-sequence models. """ import bisect -import gzip import logging -import math +import os import pickle import random +from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Any, Dict, Iterator, Iterable, List, NamedTuple, Optional, Tuple +from contextlib import ExitStack +from typing import Any, cast, Dict, Iterator, Iterable, List, Optional, Sized, Tuple +import math import mxnet as mx import numpy as np -from sockeye.utils import check_condition from . import config from . import constants as C +from . 
import vocab +from .utils import check_condition, smart_open, get_tokens, OnlineMeanAndVariance logger = logging.getLogger(__name__) @@ -67,10 +70,10 @@ def define_parallel_buckets(max_seq_len_source: int, target_step_size = bucket_width if length_ratio >= 1.0: # target side is longer -> scale source - source_step_size = max(1, int(bucket_width / length_ratio)) + source_step_size = max(1, int(round(bucket_width / length_ratio))) else: # source side is longer, -> scale target - target_step_size = max(1, int(bucket_width * length_ratio)) + target_step_size = max(1, int(round(bucket_width * length_ratio))) source_buckets = define_buckets(max_seq_len_source, step=source_step_size) target_buckets = define_buckets(max_seq_len_target, step=target_step_size) # Extra buckets @@ -83,7 +86,9 @@ def define_parallel_buckets(max_seq_len_source: int, target_buckets = [max(2, b) for b in target_buckets] parallel_buckets = list(zip(source_buckets, target_buckets)) # deduplicate for return - return list(OrderedDict.fromkeys(parallel_buckets)) + buckets = list(OrderedDict.fromkeys(parallel_buckets)) + buckets.sort() + return buckets def get_bucket(seq_len: int, buckets: List[int]) -> Optional[int]: @@ -100,25 +105,597 @@ def get_bucket(seq_len: int, buckets: List[int]) -> Optional[int]: return buckets[bucket_idx] -def length_statistics(source_sentences: Iterable[List[Any]], - target_sentences: Iterable[List[Any]]) -> Tuple[float, float]: +class BucketBatchSize: + """ + :param bucket: The corresponding bucket. + :param batch_size: Number of sequences in each batch. + :param average_words_per_batch: Approximate number of non-padding tokens in each batch. + """ + + def __init__(self, bucket: Tuple[int, int], batch_size: int, average_words_per_batch: float) -> None: + self.bucket = bucket + self.batch_size = batch_size + self.average_words_per_batch = average_words_per_batch + + +def define_bucket_batch_sizes(buckets: List[Tuple[int, int]], + batch_size: int, + batch_by_words: bool, + batch_num_devices: int, + data_target_average_len: List[Optional[float]]) -> List[BucketBatchSize]: + """ + Computes bucket-specific batch sizes (sentences, average_words). + + If sentence-based batching: number of sentences is the same for each batch, determines the + number of words. Hence all batch sizes for each bucket are equal. + + If word-based batching: number of sentences for each batch is set to the multiple of number + of devices that produces the number of words closest to the target batch size. Average + target sentence length (non-padding symbols) is used for word number calculations. + + :param buckets: Bucket list. + :param batch_size: Batch size. + :param batch_by_words: Batch by words. + :param batch_num_devices: Number of devices. + :param data_target_average_len: Optional average target length for each bucket. 
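+    :return: List of BucketBatchSize, one per bucket.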
+ """ + check_condition(len(data_target_average_len) == len(buckets), + "Must provide None or average target length for each bucket") + data_target_average_len = list(data_target_average_len) + bucket_batch_sizes = [] # type: List[BucketBatchSize] + largest_total_num_words = 0 + for buck_idx, bucket in enumerate(buckets): + # Target/label length with padding + padded_seq_len = bucket[1] + # Average target/label length excluding padding + if data_target_average_len[buck_idx] is None: + data_target_average_len[buck_idx] = padded_seq_len + average_seq_len = data_target_average_len[buck_idx] + + # Word-based: num words determines num sentences + # Sentence-based: num sentences determines num words + if batch_by_words: + check_condition(padded_seq_len <= batch_size, "Word batch size must cover sequence lengths for all" + " buckets: (%d > %d)" % (padded_seq_len, batch_size)) + # Multiple of number of devices (int) closest to target number of words, assuming each sentence is of + # average length + batch_size_seq = batch_num_devices * round((batch_size / average_seq_len) / batch_num_devices) + batch_size_word = batch_size_seq * average_seq_len + else: + batch_size_seq = batch_size + batch_size_word = batch_size_seq * average_seq_len + bucket_batch_sizes.append(BucketBatchSize(bucket, batch_size_seq, batch_size_word)) + # Track largest number of word samples in a batch + largest_total_num_words = max(largest_total_num_words, batch_size_seq * max(*bucket)) + + # Final step: guarantee that largest bucket by sequence length also has largest total batch size. + # When batching by sentences, this will already be the case. + if batch_by_words: + padded_seq_len = max(*buckets[-1]) + average_seq_len = data_target_average_len[-1] + while bucket_batch_sizes[-1].batch_size * padded_seq_len < largest_total_num_words: + bucket_batch_sizes[-1] = BucketBatchSize( + bucket_batch_sizes[-1].bucket, + bucket_batch_sizes[-1].batch_size + batch_num_devices, + bucket_batch_sizes[-1].average_words_per_batch + batch_num_devices * average_seq_len) + return bucket_batch_sizes + + +def calculate_length_statistics(source_sentences: Iterable[List[Any]], + target_sentences: Iterable[List[Any]], + max_seq_len_source: int, + max_seq_len_target: int) -> 'LengthStatistics': """ Returns mean and standard deviation of target-to-source length ratios of parallel corpus. :param source_sentences: Source sentences. :param target_sentences: Target sentences. - :return: Mean and standard deviation of length ratios. + :param max_seq_len_source: Maximum source sequence length. + :param max_seq_len_target: Maximum target sequence length. + :return: The number of sentences as well as the mean and standard deviation of target to source length ratios. 
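+    Pairs whose source or target side exceeds the respective maximum length are excluded from the statistics.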
+ """ + mean_and_variance = OnlineMeanAndVariance() + + for target, source in zip(target_sentences, source_sentences): + source_len = len(source) + target_len = len(target) + if source_len > max_seq_len_source or target_len > max_seq_len_target: + continue + + length_ratio = target_len / source_len + mean_and_variance.update(length_ratio) + + num_sents = mean_and_variance.count + mean = mean_and_variance.mean + std = math.sqrt(mean_and_variance.variance) + return LengthStatistics(num_sents, mean, std) + + +def analyze_sequence_lengths(source: str, + target: str, + vocab_source: Dict[str, int], + vocab_target: Dict[str, int], + max_seq_len_source: int, + max_seq_len_target: int) -> 'LengthStatistics': + train_source_sentences = SequenceReader(source, vocab_source, add_bos=False) + # Length statistics are calculated on the raw sentences without special tokens, such as the BOS, as these can + # have a a large impact on the length ratios, especially with lots of short sequences. + train_target_sentences = SequenceReader(target, vocab_target, add_bos=False) + + length_statistics = calculate_length_statistics(train_source_sentences, train_target_sentences, + max_seq_len_source, + # Take into account the BOS symbol that is added later + max_seq_len_target - 1) + check_condition(train_source_sentences.is_done() and train_target_sentences.is_done(), + "Different number of lines in source and target data.") + + logger.info("%d sequences of maximum length (%d, %d) in '%s' and '%s'.", + length_statistics.num_sents, max_seq_len_source, max_seq_len_target, source, target) + logger.info("Mean training target/source length ratio: %.2f (+-%.2f)", + length_statistics.length_ratio_mean, + length_statistics.length_ratio_std) + return length_statistics + + +class DataStatisticsAccumulator: + + def __init__(self, + buckets: List[Tuple[int, int]], + vocab_source: Dict[str, int], + vocab_target: Dict[str, int], + length_ratio_mean: float, + length_ratio_std: float) -> None: + self.buckets = buckets + num_buckets = len(buckets) + self.length_ratio_mean = length_ratio_mean + self.length_ratio_std = length_ratio_std + self.unk_id_source = vocab_source[C.UNK_SYMBOL] + self.unk_id_target = vocab_target[C.UNK_SYMBOL] + self.size_vocab_source = len(vocab_source) + self.size_vocab_target = len(vocab_target) + self.num_sents = 0 + self.num_discarded = 0 + self.num_tokens_source = 0 + self.num_tokens_target = 0 + self.num_unks_source = 0 + self.num_unks_target = 0 + self.max_observed_len_source = 0 + self.max_observed_len_target = 0 + self._mean_len_target_per_bucket = [OnlineMeanAndVariance() for _ in range(num_buckets)] + + def sequence_pair(self, + source: List[int], + target: List[int], + bucket_idx: Optional[int]): + source_len = len(source) + target_len = len(target) + + if bucket_idx is None: + self.num_discarded += 1 + return + + self._mean_len_target_per_bucket[bucket_idx].update(target_len) + + self.num_sents += 1 + self.num_tokens_source += source_len + self.num_tokens_target += target_len + self.max_observed_len_source = max(source_len, self.max_observed_len_source) + self.max_observed_len_target = max(target_len, self.max_observed_len_target) + + self.num_unks_source += source.count(self.unk_id_source) + self.num_unks_target += target.count(self.unk_id_target) + + @property + def mean_len_target_per_bucket(self) -> List[Optional[float]]: + return [mean_and_variance.mean if mean_and_variance.count > 0 else None + for mean_and_variance in self._mean_len_target_per_bucket] + + @property + def statistics(self): + 
num_sents_per_bucket = [mean_and_variance.count for mean_and_variance in self._mean_len_target_per_bucket] + return DataStatistics(num_sents=self.num_sents, + num_discarded=self.num_discarded, + num_tokens_source=self.num_tokens_source, + num_tokens_target=self.num_tokens_target, + num_unks_source=self.num_unks_source, + num_unks_target=self.num_unks_target, + max_observed_len_source=self.max_observed_len_source, + max_observed_len_target=self.max_observed_len_target, + size_vocab_source=self.size_vocab_source, + size_vocab_target=self.size_vocab_target, + length_ratio_mean=self.length_ratio_mean, + length_ratio_std=self.length_ratio_std, + buckets=self.buckets, + num_sents_per_bucket=num_sents_per_bucket, + mean_len_target_per_bucket=self.mean_len_target_per_bucket) + + +def shard_data(source_fname: str, + target_fname: str, + vocab_source: Dict[str, int], + vocab_target: Dict[str, int], + num_shards: int, + buckets: List[Tuple[int, int]], + length_ratio_mean: float, + length_ratio_std: float, + output_prefix: str) -> Tuple[List[Tuple[str, str, 'DataStatistics']], 'DataStatistics']: + """ + Assign int-coded source/target sentence pairs to shards at random. + + :param source_fname: The file name of the source file. + :param target_fname: The file name of the target file. + :param vocab_source: Source vocabulary. + :param vocab_target: Target vocabulary. + :param num_shards: The total number of shards. + :param buckets: Bucket list. + :param length_ratio_mean: Mean length ratio. + :param length_ratio_std: Standard deviation of length ratios. + :param output_prefix: The prefix under which the shard files will be created. + :return: Tuple of source, target file names and statistics for each shard as well as global statistics. + """ + os.makedirs(output_prefix, exist_ok=True) + source_shard_fnames = [os.path.join(output_prefix, C.SHARD_SOURCE % i) + for i in range(num_shards)] # type: List[str] + target_shard_fnames = [os.path.join(output_prefix, C.SHARD_TARGET % i) + for i in range(num_shards)] # type: List[str] + + data_stats_accumulator = DataStatisticsAccumulator(buckets, vocab_source, vocab_target, + length_ratio_mean, length_ratio_std) + per_shard_stat_accumulators = [DataStatisticsAccumulator(buckets, vocab_source, vocab_target, length_ratio_mean, + length_ratio_std) for shard_idx in range(num_shards)] + + with ExitStack() as exit_stack: + source_shards = [] + target_shards = [] + # create shard files: + for fname in source_shard_fnames: + source_shards.append(exit_stack.enter_context(smart_open(fname, mode="wt"))) + for fname in target_shard_fnames: + target_shards.append(exit_stack.enter_context(smart_open(fname, mode="wt"))) + shards = list(zip(source_shards, target_shards, per_shard_stat_accumulators)) + + source_iter = SequenceReader(source_fname, vocab_source, add_bos=False) + target_iter = SequenceReader(target_fname, vocab_target, add_bos=True) + + random_shard_iter = iter(lambda: random.choice(shards), None) + + for source, target, (source_shard, target_shard, shard_stats) in zip(source_iter, target_iter, + random_shard_iter): + source_len = len(source) + target_len = len(target) + + buck_idx, buck = get_parallel_bucket(buckets, source_len, target_len) + data_stats_accumulator.sequence_pair(source, target, buck_idx) + shard_stats.sequence_pair(source, target, buck_idx) + + if buck is None: + continue + + source_shard.write(ids2strids(source) + "\n") + target_shard.write(ids2strids(target) + "\n") + + per_shard_stats = [shard_stat_accumulator.statistics for 
shard_stat_accumulator in per_shard_stat_accumulators] + + return list(zip(source_shard_fnames, target_shard_fnames, per_shard_stats)), data_stats_accumulator.statistics + + +class RawParallelDatasetLoader: + """ + Loads a data set of variable-length parallel source/target sequences into buckets of NDArrays. + + :param buckets: Bucket list. + :param eos_id: End-of-sentence id. + :param pad_id: Padding id. + :param eos_id: Unknown id. + :param dtype: Data type. """ - length_ratios = np.array([len(t)/float(len(s)) for t, s in zip(target_sentences, source_sentences)]) - mean = np.asscalar(np.mean(length_ratios)) - std = np.asscalar(np.std(length_ratios)) - return mean, std + + def __init__(self, + buckets: List[Tuple[int, int]], + eos_id: int, + pad_id: int, + dtype: str = 'float32') -> None: + self.buckets = buckets + self.eos_id = eos_id + self.pad_id = pad_id + self.dtype = dtype + + def load(self, + source_sentences: Iterable[List[Any]], + target_sentences: Iterable[List[Any]], + num_samples_per_bucket: List[int]) -> 'ParallelDataSet': + + assert len(num_samples_per_bucket) == len(self.buckets) + + data_source = [np.full((num_samples, source_len), self.pad_id, dtype=self.dtype) + for (source_len, target_len), num_samples in zip(self.buckets, num_samples_per_bucket)] + data_target = [np.full((num_samples, target_len), self.pad_id, dtype=self.dtype) + for (source_len, target_len), num_samples in zip(self.buckets, num_samples_per_bucket)] + data_label = [np.full((num_samples, target_len), self.pad_id, dtype=self.dtype) + for (source_len, target_len), num_samples in zip(self.buckets, num_samples_per_bucket)] + + bucket_sample_index = [0 for buck in self.buckets] + + # track amount of padding introduced through bucketing + num_tokens_source = 0 + num_tokens_target = 0 + num_pad_source = 0 + num_pad_target = 0 + + # Bucket sentences as padded np arrays + for source, target in zip(source_sentences, target_sentences): + source_len = len(source) + target_len = len(target) + buck_index, buck = get_parallel_bucket(self.buckets, source_len, target_len) + if buck is None: + continue # skip this sentence pair + + num_tokens_source += buck[0] + num_tokens_target += buck[1] + num_pad_source += buck[0] - source_len + num_pad_target += buck[1] - target_len + + sample_index = bucket_sample_index[buck_index] + data_source[buck_index][sample_index, :source_len] = source + data_target[buck_index][sample_index, :target_len] = target + # NOTE(fhieber): while this is wasteful w.r.t memory, we need to explicitly create the label sequence + # with the EOS symbol here sentence-wise and not per-batch due to variable sequence length within a batch. + # Once MXNet allows item assignments given a list of indices (probably MXNet 1.0): e.g a[[0,1,5,2]] = x, + # we can try again to compute the label sequence on the fly in next(). + data_label[buck_index][sample_index, :target_len] = target[1:] + [self.eos_id] + + bucket_sample_index[buck_index] += 1 + + for i in range(len(data_source)): + data_source[i] = mx.nd.array(data_source[i], dtype=self.dtype) + data_target[i] = mx.nd.array(data_target[i], dtype=self.dtype) + data_label[i] = mx.nd.array(data_label[i], dtype=self.dtype) + + if num_tokens_source > 0 and num_tokens_target > 0: + logger.info("Created bucketed parallel data set. 
Introduced padding: source=%.1f%% target=%.1f%%)", + num_pad_source / num_tokens_source * 100, + num_pad_target / num_tokens_target * 100) + + return ParallelDataSet(data_source, data_target, data_label) + + +def get_num_shards(num_samples: int, samples_per_shard: int, min_num_shards: int) -> int: + """ + Returns the number of shards. + + :param num_samples: Number of training data samples. + :param samples_per_shard: Samples per shard. + :param min_num_shards: Minimum number of shards. + :return: Number of shards. + """ + return max(int(math.ceil(num_samples / samples_per_shard)), min_num_shards) + + +def prepare_data(source: str, target: str, + vocab_source: Dict[str, int], vocab_target: Dict[str, int], + vocab_source_path: Optional[str], vocab_target_path: Optional[str], + shared_vocab: bool, + max_seq_len_source: int, + max_seq_len_target: int, + bucketing: bool, + bucket_width: int, + samples_per_shard: int, + min_num_shards: int, + output_prefix: str, + keep_tmp_shard_files: bool = False): + logger.info("Preparing data.") + + # write vocabularies + vocab.vocab_to_json(vocab_source, os.path.join(output_prefix, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX) + vocab.vocab_to_json(vocab_target, os.path.join(output_prefix, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX) + + # Pass 1: get target/source length ratios. + length_statistics = analyze_sequence_lengths(source, target, vocab_source, vocab_target, + max_seq_len_source, max_seq_len_target) + + # define buckets + buckets = define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width, + length_statistics.length_ratio_mean) if bucketing else [ + (max_seq_len_source, max_seq_len_target)] + logger.info("Buckets: %s", buckets) + + # Pass 2: Randomly assign data to data shards + # no pre-processing yet, just write the sentences to different files + num_shards = get_num_shards(length_statistics.num_sents, samples_per_shard, min_num_shards) + logger.info("%d samples will be split into %d shard(s) (requested samples/shard=%d, min_num_shards=%d)." + % (length_statistics.num_sents, num_shards, samples_per_shard, min_num_shards)) + shards, data_statistics = shard_data(source_fname=source, + target_fname=target, + vocab_source=vocab_source, + vocab_target=vocab_target, + num_shards=num_shards, + buckets=buckets, + length_ratio_mean=length_statistics.length_ratio_mean, + length_ratio_std=length_statistics.length_ratio_std, + output_prefix=output_prefix) + data_statistics.log() + + data_loader = RawParallelDatasetLoader(buckets=buckets, + eos_id=vocab_target[C.EOS_SYMBOL], + pad_id=C.PAD_ID) + + # 3. 
convert each shard to serialized ndarrays + for shard_idx, (shard_source, shard_target, shard_stats) in enumerate(shards): + source_sentences = SequenceReader(shard_source, vocab=None) + target_sentences = SequenceReader(shard_target, vocab=None) + dataset = data_loader.load(source_sentences, target_sentences, shard_stats.num_sents_per_bucket) + shard_fname = os.path.join(output_prefix, C.SHARD_NAME % shard_idx) + shard_stats.log() + logger.info("Writing '%s'", shard_fname) + dataset.save(shard_fname) + + if not keep_tmp_shard_files: + os.remove(shard_source) + os.remove(shard_target) + + config_data = DataConfig(source=os.path.abspath(source), + target=os.path.abspath(target), + vocab_source=vocab_source_path, + vocab_target=vocab_target_path, + shared_vocab=shared_vocab, + num_shards=num_shards, + data_statistics=data_statistics, + max_seq_len_source=max_seq_len_source, + max_seq_len_target=max_seq_len_target) + data_config_fname = os.path.join(output_prefix, C.DATA_CONFIG) + logger.info("Writing data config to '%s'", data_config_fname) + config_data.save(data_config_fname) + + version_file = os.path.join(output_prefix, C.PREPARED_DATA_VERSION_FILE) + + with open(version_file, "w") as version_out: + version_out.write(str(C.PREPARED_DATA_VERSION)) + + +def get_data_statistics(source_sentences: Iterable[List[int]], + target_sentences: Iterable[List[int]], + buckets: List[Tuple[int, int]], + length_ratio_mean: float, + length_ratio_std: float, + vocab_source: vocab.Vocab, + vocab_target: vocab.Vocab) -> 'DataStatistics': + data_stats_accumulator = DataStatisticsAccumulator(buckets, vocab_source, vocab_target, + length_ratio_mean, length_ratio_std) + + for source, target in zip(source_sentences, target_sentences): + buck_idx, buck = get_parallel_bucket(buckets, len(source), len(target)) + data_stats_accumulator.sequence_pair(source, target, buck_idx) + + return data_stats_accumulator.statistics + + +def get_validation_data_iter(data_loader: RawParallelDatasetLoader, + validation_source: str, + validation_target: str, + buckets: List[Tuple[int, int]], + bucket_batch_sizes: List[BucketBatchSize], + vocab_source: vocab.Vocab, + vocab_target: vocab.Vocab, + max_seq_len_source: int, + max_seq_len_target: int, + batch_size: int, + fill_up: str) -> 'ParallelSampleIter': + """ + Returns a ParallelSampleIter for the validation data. 
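+
+    :param data_loader: Loader that converts the sequences into bucketed NDArrays.
+    :param validation_source: Source side of validation data.
+    :param validation_target: Target side of validation data.
+    :param buckets: Bucket list.
+    :param bucket_batch_sizes: Bucket batch sizes.
+    :param vocab_source: Source vocabulary.
+    :param vocab_target: Target vocabulary.
+    :param max_seq_len_source: Maximum source sequence length.
+    :param max_seq_len_target: Maximum target sequence length.
+    :param batch_size: Batch size.
+    :param fill_up: Fill-up strategy for incomplete batches.
+    :return: The validation data iterator.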
+ """ + logger.info("=================================") + logger.info("Creating validation data iterator") + logger.info("=================================") + validation_length_statistics = analyze_sequence_lengths(validation_source, validation_target, + vocab_source, vocab_target, + max_seq_len_source, max_seq_len_target) + + validation_source_sentences = SequenceReader(validation_source, vocab_source, add_bos=False, limit=None) + validation_target_sentences = SequenceReader(validation_target, vocab_target, add_bos=True, limit=None) + + validation_data_statistics = get_data_statistics(validation_source_sentences, + validation_target_sentences, + buckets, + validation_length_statistics.length_ratio_mean, + validation_length_statistics.length_ratio_std, + vocab_source, vocab_target) + + validation_data_statistics.log(bucket_batch_sizes) + + validation_data = data_loader.load(validation_source_sentences, + validation_target_sentences, + validation_data_statistics.num_sents_per_bucket).fill_up(bucket_batch_sizes, + fill_up) + + return ParallelSampleIter(data=validation_data, + buckets=buckets, + batch_size=batch_size, + bucket_batch_sizes=bucket_batch_sizes) + + +def get_prepared_data_iters(prepared_data_dir: str, + validation_source: str, validation_target: str, + shared_vocab: bool, + batch_size: int, + batch_by_words: bool, + batch_num_devices: int, + fill_up: str) -> Tuple['BaseParallelSampleIter', + 'BaseParallelSampleIter', + 'DataConfig', vocab.Vocab, vocab.Vocab]: + logger.info("===============================") + logger.info("Creating training data iterator") + logger.info("===============================") + + version_file = os.path.join(prepared_data_dir, C.PREPARED_DATA_VERSION_FILE) + with open(version_file) as version_in: + version = int(version_in.read()) + check_condition(version == C.PREPARED_DATA_VERSION, + "The dataset %s was written in an old and incompatible format. Please rerun data " + "preparation with a current version of Sockeye." % prepared_data_dir) + config_file = os.path.join(prepared_data_dir, C.DATA_CONFIG) + check_condition(os.path.exists(config_file), + "Could not find data config %s. Are you sure %s is a directory created with " + "python -m sockeye.prepare_data?" % (config_file, prepared_data_dir)) + data_config = cast(DataConfig, DataConfig.load(config_file)) + shard_fnames = [os.path.join(prepared_data_dir, + C.SHARD_NAME % shard_idx) for shard_idx in range(data_config.num_shards)] + for shard_fname in shard_fnames: + check_condition(os.path.exists(shard_fname), "Shard %s does not exist." % shard_fname) + + source_vocab_fname = os.path.join(prepared_data_dir, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX + target_vocab_fname = os.path.join(prepared_data_dir, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX + check_condition(bool(source_vocab_fname), "Source vocabulary %s does not exist." % source_vocab_fname) + check_condition(bool(target_vocab_fname), "Target vocabulary %s does not exist." % target_vocab_fname) + + check_condition(shared_vocab == data_config.shared_vocab, "Shared config needed (e.g. for weight tying), but " + "data was prepared without a shared vocab. Use %s when " + "preparing the data." 
% C.VOCAB_ARG_SHARED_VOCAB) + + vocab_source = vocab.vocab_from_json(source_vocab_fname) + vocab_target = vocab.vocab_from_json(target_vocab_fname) + + buckets = data_config.data_statistics.buckets + max_seq_len_source = data_config.max_seq_len_source + max_seq_len_target = data_config.max_seq_len_target + + bucket_batch_sizes = define_bucket_batch_sizes(buckets, + batch_size, + batch_by_words, + batch_num_devices, + data_config.data_statistics.average_len_target_per_bucket) + + data_config.data_statistics.log(bucket_batch_sizes) + + train_iter = ShardedParallelSampleIter(shard_fnames, + buckets, + batch_size, + bucket_batch_sizes, + fill_up) + + data_loader = RawParallelDatasetLoader(buckets=buckets, + eos_id=vocab_target[C.EOS_SYMBOL], + pad_id=C.PAD_ID) + + validation_iter = get_validation_data_iter(data_loader=data_loader, + validation_source=validation_source, + validation_target=validation_target, + buckets=buckets, + bucket_batch_sizes=bucket_batch_sizes, + vocab_source=vocab_source, + vocab_target=vocab_target, + max_seq_len_source=max_seq_len_source, + max_seq_len_target=max_seq_len_target, + batch_size=batch_size, + fill_up=fill_up) + + return train_iter, validation_iter, data_config, vocab_source, vocab_target def get_training_data_iters(source: str, target: str, validation_source: str, validation_target: str, - vocab_source: Dict[str, int], vocab_target: Dict[str, int], + vocab_source: vocab.Vocab, vocab_target: vocab.Vocab, vocab_source_path: Optional[str], vocab_target_path: Optional[str], + shared_vocab: bool, batch_size: int, batch_by_words: bool, batch_num_devices: int, @@ -126,10 +703,9 @@ def get_training_data_iters(source: str, target: str, max_seq_len_source: int, max_seq_len_target: int, bucketing: bool, - bucket_width: int, - sequence_limit: Optional[int] = None) -> Tuple['ParallelBucketSentenceIter', - 'ParallelBucketSentenceIter', - 'DataConfig']: + bucket_width: int) -> Tuple['BaseParallelSampleIter', + 'BaseParallelSampleIter', + 'DataConfig']: """ Returns data iterators for training and validation data. @@ -141,6 +717,7 @@ def get_training_data_iters(source: str, target: str, :param vocab_target: Target vocabulary. :param vocab_source_path: Path to source vocabulary. :param vocab_target_path: Path to target vocabulary. + :param shared_vocab: Whether the vocabularies are shared. :param batch_size: Batch size. :param batch_by_words: Size batches by words rather than sentences. :param batch_num_devices: Number of devices batches will be parallelized across. @@ -149,117 +726,173 @@ def get_training_data_iters(source: str, target: str, :param max_seq_len_target: Maximum target sequence length. :param bucketing: Whether to use bucketing. :param bucket_width: Size of buckets. - :param sequence_limit: Maximum number of training sequences to read. :return: Tuple of (training data iterator, validation data iterator, data config). 
""" - logger.info("Creating train data iterator") - # streams id-coded sentences from disk - train_source_sentences = SentenceReader(source, vocab_source, add_bos=False, limit=sequence_limit) - train_target_sentences = SentenceReader(target, vocab_target, add_bos=True, limit=sequence_limit) - - # reads the id-coded sentences from disk once - lr_mean, lr_std = length_statistics(train_source_sentences, train_target_sentences) - check_condition(train_source_sentences.is_done() and train_target_sentences.is_done(), - "Different number of lines in source and target data.") - logger.info("%d source sentences in '%s'", train_source_sentences.count, source) - logger.info("%d target sentences in '%s'", train_target_sentences.count, target) - logger.info("Mean training target/source length ratio: %.2f (+-%.2f)", lr_mean, lr_std) - + logger.info("===============================") + logger.info("Creating training data iterator") + logger.info("===============================") + # Pass 1: get target/source length ratios. + length_statistics = analyze_sequence_lengths(source, target, vocab_source, vocab_target, + max_seq_len_source, max_seq_len_target) # define buckets - buckets = define_parallel_buckets(max_seq_len_source, - max_seq_len_target, - bucket_width, - lr_mean) if bucketing else [ + buckets = define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width, + length_statistics.length_ratio_mean) if bucketing else [ (max_seq_len_source, max_seq_len_target)] - train_iter = ParallelBucketSentenceIter(train_source_sentences, - train_target_sentences, - buckets, - batch_size, - batch_by_words, - batch_num_devices, - vocab_target[C.EOS_SYMBOL], - C.PAD_ID, - vocab_target[C.UNK_SYMBOL], - bucket_batch_sizes=None, - fill_up=fill_up) + source_sentences = SequenceReader(source, vocab_source, add_bos=False) + target_sentences = SequenceReader(target, vocab_target, add_bos=True) + + # 2. 
pass: Get data statistics + data_statistics = get_data_statistics(source_sentences, target_sentences, buckets, + length_statistics.length_ratio_mean, length_statistics.length_ratio_std, + vocab_source, vocab_target) + + bucket_batch_sizes = define_bucket_batch_sizes(buckets, + batch_size, + batch_by_words, + batch_num_devices, + data_statistics.average_len_target_per_bucket) + + data_statistics.log(bucket_batch_sizes) + + data_loader = RawParallelDatasetLoader(buckets=buckets, + eos_id=vocab_target[C.EOS_SYMBOL], + pad_id=C.PAD_ID) + + training_data = data_loader.load(source_sentences, target_sentences, + data_statistics.num_sents_per_bucket).fill_up(bucket_batch_sizes, fill_up) + + config_data = DataConfig(source=source, + target=target, + vocab_source=vocab_source_path, + vocab_target=vocab_target_path, + shared_vocab=shared_vocab, + num_shards=1, + data_statistics=data_statistics, + max_seq_len_source=max_seq_len_source, + max_seq_len_target=max_seq_len_target) + + train_iter = ParallelSampleIter(training_data, + buckets, + batch_size, + bucket_batch_sizes) + + validation_iter = get_validation_data_iter(data_loader=data_loader, + validation_source=validation_source, + validation_target=validation_target, + buckets=buckets, + bucket_batch_sizes=bucket_batch_sizes, + vocab_source=vocab_source, + vocab_target=vocab_target, + max_seq_len_source=max_seq_len_source, + max_seq_len_target=max_seq_len_target, + batch_size=batch_size, + fill_up=fill_up) + + return train_iter, validation_iter, config_data + + +class LengthStatistics(config.Config): + + def __init__(self, + num_sents: int, + length_ratio_mean: float, + length_ratio_std: float) -> None: + super().__init__() + self.num_sents = num_sents + self.length_ratio_mean = length_ratio_mean + self.length_ratio_std = length_ratio_std - logger.info("Creating validation data iterator") - val_source_sentences = SentenceReader(validation_source, vocab_source, add_bos=False, limit=None) - val_target_sentences = SentenceReader(validation_target, vocab_target, add_bos=True, limit=None) - - val_iter = ParallelBucketSentenceIter(val_source_sentences, - val_target_sentences, - buckets, - batch_size, - batch_by_words, - batch_num_devices, - vocab_target[C.EOS_SYMBOL], - C.PAD_ID, - vocab_target[C.UNK_SYMBOL], - bucket_batch_sizes=train_iter.bucket_batch_sizes, - fill_up=fill_up) - - check_condition(val_source_sentences.is_done() and val_target_sentences.is_done(), - "Different number of lines in source and target validation data.") - logger.info("%d validation source sentences in '%s'", val_source_sentences.count, source) - logger.info("%d validation target sentences in '%s'", val_target_sentences.count, target) - - config_data = DataConfig(source, target, - validation_source, validation_target, - vocab_source_path, vocab_target_path, - lr_mean, lr_std, train_iter.max_observed_source_len, train_iter.max_observed_target_len) - - return train_iter, val_iter, config_data + +class DataStatistics(config.Config): + + def __init__(self, + num_sents: int, + num_discarded, + num_tokens_source, + num_tokens_target, + num_unks_source, + num_unks_target, + max_observed_len_source, + max_observed_len_target, + size_vocab_source, + size_vocab_target, + length_ratio_mean, + length_ratio_std, + buckets: List[Tuple[int, int]], + num_sents_per_bucket: List[int], + mean_len_target_per_bucket: List[Optional[float]]) -> None: + super().__init__() + self.num_sents = num_sents + self.num_discarded = num_discarded + self.num_tokens_source = num_tokens_source + 
self.num_tokens_target = num_tokens_target + self.num_unks_source = num_unks_source + self.num_unks_target = num_unks_target + self.max_observed_len_source = max_observed_len_source + self.max_observed_len_target = max_observed_len_target + self.size_vocab_source = size_vocab_source + self.size_vocab_target = size_vocab_target + self.length_ratio_mean = length_ratio_mean + self.length_ratio_std = length_ratio_std + self.buckets = buckets + self.num_sents_per_bucket = num_sents_per_bucket + self.average_len_target_per_bucket = mean_len_target_per_bucket + + def log(self, bucket_batch_sizes: Optional[List[BucketBatchSize]] = None): + logger.info("Tokens: source %d target %d", self.num_tokens_source, self.num_tokens_target) + if self.num_tokens_source > 0 and self.num_tokens_target > 0: + logger.info("Vocabulary coverage: source %.0f%% target %.0f%%", + (1 - self.num_unks_source / self.num_tokens_source) * 100, + (1 - self.num_unks_target / self.num_tokens_target) * 100) + logger.info("%d sequences across %d buckets", self.num_sents, len(self.num_sents_per_bucket)) + logger.info("%d sequences did not fit into buckets and were discarded", self.num_discarded) + if bucket_batch_sizes is not None: + describe_data_and_buckets(self, bucket_batch_sizes) + + +def describe_data_and_buckets(data_statistics: DataStatistics, bucket_batch_sizes: List[BucketBatchSize]): + """ + Describes statistics across buckets + """ + check_condition(len(bucket_batch_sizes) == len(data_statistics.buckets), + "Number of bucket batch sizes (%d) does not match number of buckets in statistics (%d)." + % (len(bucket_batch_sizes), len(data_statistics.buckets))) + for bucket_batch_size, num_seq in zip(bucket_batch_sizes, data_statistics.num_sents_per_bucket): + if num_seq > 0: + logger.info("Bucket %s: %d samples in %d batches of %d, ~%.1f tokens/batch.", + bucket_batch_size.bucket, + num_seq, + math.ceil(num_seq / bucket_batch_size.batch_size), + bucket_batch_size.batch_size, + bucket_batch_size.average_words_per_batch) class DataConfig(config.Config): """ Stores data paths from training. """ + def __init__(self, source: str, target: str, - validation_source: str, - validation_target: str, vocab_source: Optional[str], vocab_target: Optional[str], - length_ratio_mean: float = C.TARGET_MAX_LENGTH_FACTOR, - length_ratio_std: float = 0.0, - max_observed_source_seq_len: Optional[int] = None, - max_observed_target_seq_len: Optional[int] = None) -> None: + shared_vocab: bool, + num_shards: int, + data_statistics: DataStatistics, + max_seq_len_source: int, + max_seq_len_target: int) -> None: super().__init__() self.source = source self.target = target - self.validation_source = validation_source - self.validation_target = validation_target self.vocab_source = vocab_source self.vocab_target = vocab_target - self.length_ratio_mean = length_ratio_mean - self.length_ratio_std = length_ratio_std - self.max_observed_source_seq_len = max_observed_source_seq_len - self.max_observed_target_seq_len = max_observed_target_seq_len - - -def smart_open(filename: str, mode: str = "rt", ftype: str = "auto", errors:str = 'replace'): - """ - Returns a file descriptor for filename with UTF-8 encoding. - If mode is "rt", file is opened read-only. - If ftype is "auto", uses gzip iff filename endswith .gz. - If ftype is {"gzip","gz"}, uses gzip. - - Note: encoding error handling defaults to "replace" - - :param filename: The filename to open. - :param mode: Reader mode. - :param ftype: File type. 
If 'auto' checks filename suffix for gz to try gzip.open - :param errors: Encoding error handling during reading. Defaults to 'replace' - :return: File descriptor - """ - if ftype == 'gzip' or ftype == 'gz' or (ftype == 'auto' and filename.endswith(".gz")): - return gzip.open(filename, mode=mode, encoding='utf-8', errors=errors) - else: - return open(filename, mode=mode, encoding='utf-8', errors=errors) + self.shared_vocab = shared_vocab + self.num_shards = num_shards + self.data_statistics = data_statistics + self.max_seq_len_source = max_seq_len_source + self.max_seq_len_target = max_seq_len_target def read_content(path: str, limit: Optional[int] = None) -> Iterator[List[str]]: @@ -277,57 +910,76 @@ def read_content(path: str, limit: Optional[int] = None) -> Iterator[List[str]]: yield list(get_tokens(line)) -def get_tokens(line: str) -> Iterator[str]: +def tokens2ids(tokens: Iterable[str], vocab: Dict[str, int]) -> List[int]: """ - Yields tokens from input string. + Returns sequence of integer ids given a sequence of tokens and vocab. - :param line: Input string. - :return: Iterator over tokens. + :param tokens: List of string tokens. + :param vocab: Vocabulary (containing UNK symbol). + :return: List of word ids. """ - for token in line.rstrip().split(): - if len(token) > 0: - yield token + return [vocab.get(w, vocab[C.UNK_SYMBOL]) for w in tokens] -def tokens2ids(tokens: Iterable[str], vocab: Dict[str, int]) -> List[int]: +def strids2ids(tokens: Iterable[str]) -> List[int]: """ - Returns sequence of ids given a sequence of tokens and vocab. + Returns sequence of integer ids given a sequence of string ids. - :param tokens: List of tokens. - :param vocab: Vocabulary (containing UNK symbol). + + :param tokens: List of integer tokens. :return: List of word ids. """ - return [vocab.get(w, vocab[C.UNK_SYMBOL]) for w in tokens] + return list(map(int, tokens)) -class SentenceReader(Iterator): +def ids2strids(ids: Iterable[int]) -> str: """ - Reads sentences from path and creates word id sentences. - Streams from disk, instead of loading all sentences into memory. + Returns a string representation of a sequence of integers. + + :param ids: Sequence of integers. + :return: String sequence + """ + return " ".join(map(str, ids)) + + +class SequenceReader(Iterator): + """ + Reads sequence samples from path and creates integer id sequences. + Streams from disk, instead of loading all samples into memory. + If vocab is None, the sequences in path are assumed to be integers coded as strings. :param path: Path to read data from. - :param vocab: Vocabulary mapping. + :param vocab: Optional mapping from strings to integer ids. :param add_bos: Whether to add Beginning-Of-Sentence (BOS) symbol. :param limit: Read limit. 
""" - def __init__(self, path: str, vocab: Dict[str, int], add_bos: bool = False, limit: Optional[int] = None) -> None: + def __init__(self, + path: str, + vocab: Optional[Dict[str, int]], + add_bos: bool = False, + limit: Optional[int] = None) -> None: self.path = path self.vocab = vocab + self.bos_id = None + if vocab is not None: + assert C.UNK_SYMBOL in vocab + assert vocab[C.PAD_SYMBOL] == C.PAD_ID + assert C.BOS_SYMBOL in vocab + assert C.EOS_SYMBOL in vocab + self.bos_id = vocab[C.BOS_SYMBOL] + else: + check_condition(not add_bos, "Adding a BOS symbol requires a vocabulary") self.add_bos = add_bos self.limit = limit - assert C.UNK_SYMBOL in vocab - assert C.UNK_SYMBOL in vocab - assert vocab[C.PAD_SYMBOL] == C.PAD_ID - assert C.BOS_SYMBOL in vocab - assert C.EOS_SYMBOL in vocab + self._iter = None # type: Optional[Iterator] self._iterated_once = False self.count = 0 self._next = None def __iter__(self): - assert self._next is None, "Can not iterate multiple times simultaneously." + check_condition(self._next is None, "Can not iterate multiple times simultaneously.") self._iter = read_content(self.path, self.limit) self._next = next(self._iter, None) return self @@ -336,11 +988,15 @@ def __next__(self): if self._next is None: raise StopIteration - sentence_tokens = self._next - sentence = tokens2ids(sentence_tokens, self.vocab) - check_condition(bool(sentence), "Empty sentence in file %s" % self.path) - if self.add_bos: - sentence.insert(0, self.vocab[C.BOS_SYMBOL]) + tokens = self._next + if self.vocab is not None: + sequence = tokens2ids(tokens, self.vocab) + else: + sequence = strids2ids(tokens) + check_condition(bool(sequence), "Empty sequence in file %s" % self.path) + + if vocab is not None and self.add_bos: + sequence.insert(0, self.vocab[C.BOS_SYMBOL]) if not self._iterated_once: self.count += 1 @@ -352,7 +1008,7 @@ def __next__(self): if not self._iterated_once: self._iterated_once = True - return sentence + return sequence def is_done(self): return self._iterated_once and self._next is None @@ -388,90 +1044,179 @@ def get_parallel_bucket(buckets: List[Tuple[int, int]], return bucket -BucketBatchSize = NamedTuple("BucketBatchSize", [ - ("batch_size", int), - ("average_words_per_batch", float) -]) -""" -:param batch_size: Number of sentences in each batch. -:param average_words_per_batch: Approximate number of non-padding tokens in each batch. -""" +class ParallelDataSet(Sized): + """ + Bucketed parallel data set with labels + """ + + def __init__(self, + source: List[mx.nd.array], + target: List[mx.nd.array], + label: List[mx.nd.array]) -> None: + check_condition(len(source) == len(target) == len(label), + "Number of buckets for source/target/label must match.") + self.source = source + self.target = target + self.label = label + + def __len__(self) -> int: + return len(self.source) + + def get_bucket_counts(self): + return [len(self.source[buck_idx]) for buck_idx in range(len(self))] + + def save(self, fname: str): + """ + Saves the dataset to a binary .npy file. + """ + mx.nd.save(fname, self.source + self.target + self.label) + + @staticmethod + def load(fname: str) -> 'ParallelDataSet': + """ + Loads a dataset from a binary .npy file. 
+ """ + data = mx.nd.load(fname) + n = len(data) // 3 + source = data[:n] + target = data[n:2 * n] + label = data[2 * n:] + assert len(source) == len(target) == len(label) + return ParallelDataSet(source, target, label) + + def fill_up(self, + bucket_batch_sizes: List[BucketBatchSize], + fill_up: str, + seed: int = 42) -> 'ParallelDataSet': + """ + Returns a new dataset with buckets filled up using the specified fill-up strategy. + + :param bucket_batch_sizes: Bucket batch sizes. + :param fill_up: Fill-up strategy. + :param seed: The random seed used for sampling sentences to fill up. + :return: New dataset with buckets filled up to the next multiple of batch size + """ + source = list(self.source) + target = list(self.target) + label = list(self.label) + + rs = np.random.RandomState(seed) + + for bucket_idx in range(len(self)): + bucket = bucket_batch_sizes[bucket_idx].bucket + bucket_batch_size = bucket_batch_sizes[bucket_idx].batch_size + bucket_source = self.source[bucket_idx] + bucket_target = self.target[bucket_idx] + bucket_label = self.label[bucket_idx] + num_samples = bucket_source.shape[0] + + if num_samples % bucket_batch_size != 0: + if fill_up == 'replicate': + rest = bucket_batch_size - num_samples % bucket_batch_size + logger.info("Replicating %d random samples from %d samples in bucket %s " + "to size it to multiple of %d", + rest, num_samples, bucket, bucket_batch_size) + random_indices = mx.nd.array(rs.randint(num_samples, size=rest)) + source[bucket_idx] = mx.nd.concat(bucket_source, bucket_source.take(random_indices), dim=0) + target[bucket_idx] = mx.nd.concat(bucket_target, bucket_target.take(random_indices), dim=0) + label[bucket_idx] = mx.nd.concat(bucket_label, bucket_label.take(random_indices), dim=0) + else: + raise NotImplementedError('Unknown fill-up strategy') + + return ParallelDataSet(source, target, label) + + def permute(self, permutations: List[mx.nd.NDArray]) -> 'ParallelDataSet': + assert len(self) == len(permutations) + source = [] + target = [] + label = [] + for buck_idx in range(len(self)): + num_samples = self.source[buck_idx].shape[0] + if num_samples: # not empty bucket + permutation = permutations[buck_idx] + source.append(self.source[buck_idx].take(permutation)) + target.append(self.target[buck_idx].take(permutation)) + label.append(self.label[buck_idx].take(permutation)) + else: + source.append(self.source[buck_idx]) + target.append(self.target[buck_idx]) + label.append(self.label[buck_idx]) + return ParallelDataSet(source, target, label) -# TODO: consider more memory-efficient batch creation (load from disk on demand) -# TODO: consider using HDF5 format for language data -class ParallelBucketSentenceIter(mx.io.DataIter): +def get_permutations(bucket_counts: List[int]) -> Tuple[List[mx.nd.NDArray], List[mx.nd.NDArray]]: """ - A bucketing parallel sentence iterator. - Data is read into NDArrays for the buckets defined in buckets. - Randomly shuffles the data after every call to reset(). - Data is stored in NDArrays for each epoch for fast indexing during iteration. + Returns the indices of a random permutation for each bucket and the corresponding inverse permutations that can + restore the original order of the data if applied to the permuted data. - :param source_sentences: Iterable of source sentences (integer-coded). - :param target_sentences: Iterable of target sentences (integer-coded). - :param buckets: List of buckets. - :param batch_size: Batch_size of generated data batches. 
- Incomplete batches are discarded if fill_up == None, or filled up according to the fill_up strategy. - :param batch_by_words: Size batches by words rather than sentences. - :param batch_num_devices: Number of devices batches will be parallelized across. - :param fill_up: If not None, fill up bucket data to a multiple of batch_size to avoid discarding incomplete batches. - for each bucket. If set to 'replicate', sample examples from the bucket and use them to fill up. - :param eos_id: Word id for end-of-sentence. - :param pad_id: Word id for padding symbols. - :param unk_id: Word id for unknown symbols. - :param bucket_batch_sizes: Pre-computed bucket batch sizes (used to keep iterators consistent for train/validation). - :param dtype: Data type of generated NDArrays. + :param bucket_counts: The number of elements per bucket. + :return: For each bucket a permutation and inverse permutation is returned. + """ + data_permutations = [] # type: List[mx.nd.NDArray] + inverse_data_permutations = [] # type: List[mx.nd.NDArray] + for num_samples in bucket_counts: + if num_samples == 0: + num_samples = 1 + # new random order: + data_permutation = np.random.permutation(num_samples) + inverse_data_permutation = np.empty(num_samples, np.int32) + inverse_data_permutation[data_permutation] = np.arange(num_samples) + inverse_data_permutation = mx.nd.array(inverse_data_permutation) + data_permutation = mx.nd.array(data_permutation) + + data_permutations.append(data_permutation) + inverse_data_permutations.append(inverse_data_permutation) + return data_permutations, inverse_data_permutations + + +def get_batch_indices(data: ParallelDataSet, + bucket_batch_sizes: List[BucketBatchSize]) -> List[Tuple[int, int]]: + """ + Returns a list of index tuples that index into the bucket and the start index inside a bucket given + the batch size for a bucket. These indices are valid for the given dataset. + + :param data: Data to create indices for. + :param bucket_batch_sizes: Bucket batch sizes. + :return: List of 2d indices. + """ + # create index tuples (i,j) into buckets: i := bucket index ; j := row index of bucket array + idxs = [] # type: List[Tuple[int, int]] + for buck_idx, buck in enumerate(data.source): + bucket = bucket_batch_sizes[buck_idx].bucket + batch_size = bucket_batch_sizes[buck_idx].batch_size + num_samples = data.source[buck_idx].shape[0] + rest = num_samples % batch_size + if rest > 0: + logger.info("Ignoring %d samples from bucket %s with %d samples due to incomplete batch", + rest, bucket, num_samples) + idxs.extend([(buck_idx, j) for j in range(0, num_samples - batch_size + 1, batch_size)]) + return idxs + + +class BaseParallelSampleIter(mx.io.DataIter, ABC): + """ + Base parallel sample iterator. 
""" def __init__(self, - source_sentences: Iterable[List[int]], - target_sentences: Iterable[List[int]], - buckets: List[Tuple[int, int]], - batch_size: int, - batch_by_words: bool, - batch_num_devices: int, - eos_id: int, - pad_id: int, - unk_id: int, - bucket_batch_sizes: Optional[List[BucketBatchSize]] = None, - fill_up: Optional[str] = None, - source_data_name=C.SOURCE_NAME, - target_data_name=C.TARGET_NAME, - label_name=C.TARGET_LABEL_NAME, + buckets, + batch_size, + bucket_batch_sizes, + source_data_name, + target_data_name, + label_name, dtype='float32') -> None: - super(ParallelBucketSentenceIter, self).__init__() + super().__init__(batch_size=batch_size) self.buckets = list(buckets) - self.buckets.sort() self.default_bucket_key = get_default_bucket_key(self.buckets) - self.batch_size = batch_size - self.batch_by_words = batch_by_words - self.batch_num_devices = batch_num_devices - self.eos_id = eos_id - self.pad_id = pad_id - self.unk_id = unk_id - self.dtype = dtype + self.bucket_batch_sizes = bucket_batch_sizes + self.source_data_name = source_data_name self.target_data_name = target_data_name self.label_name = label_name - self.fill_up = fill_up - - self.data_source = [[] for _ in self.buckets] # type: ignore - self.data_target = [[] for _ in self.buckets] # type: ignore - self.data_label = [[] for _ in self.buckets] # type: ignore - self.data_target_average_len = [0 for _ in self.buckets] - - # Per-bucket batch sizes (num seq, num word) - # If not None, populated as part of assigning to buckets - self.bucket_batch_sizes = bucket_batch_sizes - # assign sentence pairs to buckets - self.max_observed_source_len = 0 - self.max_observed_target_len = 0 - self._assign_to_buckets(source_sentences, target_sentences) - - # convert to single numpy array for each bucket - self._convert_to_array() + self.dtype = dtype # "Staging area" that needs to fit any size batch we're using by total number of elements. 
# When computing per-bucket batch sizes, we guarantee that the default bucket will have the @@ -495,211 +1240,166 @@ def __init__(self, self.data_names = [self.source_data_name, self.target_data_name] self.label_names = [self.label_name] - # create index tuples (i,j) into buckets: i := bucket index ; j := row index of bucket array - self.idx = [] # type: List[Tuple[int, int]] - for i, buck in enumerate(self.data_source): - batch_size_seq = self.bucket_batch_sizes[i].batch_size - rest = len(buck) % batch_size_seq - if rest > 0: - logger.info("Discarding %d samples from bucket %s due to incomplete batch", rest, self.buckets[i]) - idxs = [(i, j) for j in range(0, len(buck) - batch_size_seq + 1, batch_size_seq)] - self.idx.extend(idxs) - self.curr_idx = 0 - - # holds NDArrays - self.indices = [] # type: List[List[int]] - self.nd_source = [] # type: List[mx.ndarray] - self.nd_target = [] # type: List[mx.ndarray] - self.nd_label = [] # type: List[mx.ndarray] + @abstractmethod + def reset(self): + pass - self.reset() + @abstractmethod + def iter_next(self) -> bool: + pass - def _assign_to_buckets(self, source_sentences, target_sentences): - ndiscard = 0 - tokens_source = 0 - tokens_target = 0 - num_of_unks_source = 0 - num_of_unks_target = 0 + @abstractmethod + def next(self) -> mx.io.DataBatch: + pass - # Bucket sentences as padded np arrays - for source, target in zip(source_sentences, target_sentences): - source_len = len(source) - target_len = len(target) - buck_idx, buck = get_parallel_bucket(self.buckets, source_len, target_len) - if buck is None: - ndiscard += 1 - continue # skip this sentence pair + @abstractmethod + def save_state(self, fname: str): + pass - tokens_source += source_len - tokens_target += target_len - if source_len > self.max_observed_source_len: - self.max_observed_source_len = source_len - if target_len > self.max_observed_target_len: - self.max_observed_target_len = target_len + @abstractmethod + def load_state(self, fname: str): + pass - num_of_unks_source += source.count(self.unk_id) - num_of_unks_target += target.count(self.unk_id) - buff_source = np.full((buck[0],), self.pad_id, dtype=self.dtype) - buff_target = np.full((buck[1],), self.pad_id, dtype=self.dtype) - # NOTE(fhieber): while this is wasteful w.r.t memory, we need to explicitly create the label sequence - # with the EOS symbol here sentence-wise and not per-batch due to variable sequence length within a batch. - # Once MXNet allows item assignments given a list of indices (probably MXNet 0.13): e.g a[[0,1,5,2]] = x, - # we can try again to compute the label sequence on the fly in next(). 
- buff_label = np.full((buck[1],), self.pad_id, dtype=self.dtype) - buff_source[:source_len] = source - buff_target[:target_len] = target - buff_label[:len(target)] = target[1:] + [self.eos_id] - self.data_source[buck_idx].append(buff_source) - self.data_target[buck_idx].append(buff_target) - self.data_label[buck_idx].append(buff_label) - self.data_target_average_len[buck_idx] += target_len - - # Average number of non-padding elements in target sequence per bucket - for buck_idx, buck in enumerate(self.buckets): - # Case of empty bucket -> use default padded length - if self.data_target_average_len[buck_idx] == 0: - self.data_target_average_len[buck_idx] = buck[1] - else: - self.data_target_average_len[buck_idx] /= len(self.data_target[buck_idx]) - - # We now have sufficient information to populate bucket batch sizes - self._populate_bucket_batch_sizes() - - logger.info("Source words: %d", tokens_source) - logger.info("Target words: %d", tokens_target) - logger.info("Vocab coverage source: %.0f%%", (1 - num_of_unks_source / tokens_source) * 100) - logger.info("Vocab coverage target: %.0f%%", (1 - num_of_unks_target / tokens_target) * 100) - logger.info("Total: %d samples in %d buckets", sum(len(b) for b in self.data_source), len(self.buckets)) - nsamples = 0 - for bkt, buck, batch_size_seq, average_seq_len in zip(self.buckets, - self.data_source, - (bbs.batch_size for bbs in self.bucket_batch_sizes), - self.data_target_average_len): - logger.info("Bucket of %s : %d samples in %d batches of %d, approx %0.1f words/batch", - bkt, - len(buck), - math.ceil(len(buck) / batch_size_seq), - batch_size_seq, - batch_size_seq * average_seq_len) - nsamples += len(buck) - check_condition(nsamples > 0, "0 data points available in the data iterator. " - "%d data points have been discarded because they " - "didn't fit into any bucket. Consider increasing " - "--max-seq-len to fit your data." % ndiscard) - logger.info("%d sentence pairs out of buckets", ndiscard) - logger.info("fill up mode: %s", self.fill_up) - logger.info("") - - def _populate_bucket_batch_sizes(self): - """ - Compute bucket-specific batch sizes (sentences, average_words) and default bucket batch - size. +class ShardedParallelSampleIter(BaseParallelSampleIter): + """ + Goes through the data one shard at a time. The memory consumption is limited by the memory consumption of the + largest shard. The order in which shards are traversed is changed with each reset. + """ + + def __init__(self, + shards_fnames: List[str], + buckets, + batch_size, + bucket_batch_sizes, + fill_up: str, + source_data_name=C.SOURCE_NAME, + target_data_name=C.TARGET_NAME, + label_name=C.TARGET_LABEL_NAME, + dtype='float32') -> None: + super().__init__(buckets, batch_size, bucket_batch_sizes, + source_data_name, target_data_name, label_name, dtype) + assert len(shards_fnames) > 0 + self.shards_fnames = list(shards_fnames) + self.shard_index = -1 + self.fill_up = fill_up - If sentence-based batching: number of sentences is the same for each batch, determines the - number of words. + self.reset() - If word-based batching: number of sentences for each batch is set to the multiple of number - of devices that produces the number of words closest to the target batch size. Average - target sentence length (non-padding symbols) is used for word number calculations. 
+ def _load_shard(self): + shard_fname = self.shards_fnames[self.shard_index] + logger.info("Loading shard %s.", shard_fname) + dataset = ParallelDataSet.load(self.shards_fnames[self.shard_index]).fill_up(self.bucket_batch_sizes, + self.fill_up, + seed=self.shard_index) + self.shard_iter = ParallelSampleIter(dataset, + self.buckets, + self.batch_size, + self.bucket_batch_sizes) - Sets: self.bucket_batch_sizes - """ - # Pre-defined bucket batch sizes - if self.bucket_batch_sizes is not None: - return - # Otherwise compute here - self.bucket_batch_sizes = [None for _ in self.buckets] - largest_total_batch_size = 0 - for buck_idx, bucket_shape in enumerate(self.buckets): - # Target/label length with padding - padded_seq_len = bucket_shape[1] - # Average target/label length excluding padding - average_seq_len = self.data_target_average_len[buck_idx] - # Word-based: num words determines num sentences - # Sentence-based: num sentences determines num words - if self.batch_by_words: - check_condition(padded_seq_len <= self.batch_size, "Word batch size must cover sequence lengths for all" - " buckets: (%d > %d)" % (padded_seq_len, self.batch_size)) - # Multiple of number of devices (int) closest to target number of words, assuming each sentence is of - # average length - batch_size_seq = self.batch_num_devices * round((self.batch_size / average_seq_len) - / self.batch_num_devices) - batch_size_word = batch_size_seq * average_seq_len + def reset(self): + if len(self.shards_fnames) > 1: + logger.info("Shuffling the shards.") + # Making sure to not repeat a shard: + if self.shard_index < 0: + current_shard_fname = "" else: - batch_size_seq = self.batch_size - batch_size_word = batch_size_seq * average_seq_len - self.bucket_batch_sizes[buck_idx] = BucketBatchSize(batch_size_seq, batch_size_word) - # Track largest batch size by total elements - largest_total_batch_size = max(largest_total_batch_size, batch_size_seq * max(*bucket_shape)) - # Final step: guarantee that largest bucket by sequence length also has largest total batch size. - # When batching by sentences, this will already be the case. 
- if self.batch_by_words: - padded_seq_len = max(*self.buckets[-1]) - average_seq_len = self.data_target_average_len[-1] - while self.bucket_batch_sizes[-1].batch_size * padded_seq_len < largest_total_batch_size: - self.bucket_batch_sizes[-1] = BucketBatchSize( - self.bucket_batch_sizes[-1].batch_size + self.batch_num_devices, - self.bucket_batch_sizes[-1].average_words_per_batch + self.batch_num_devices * average_seq_len) - - def _convert_to_array(self): - for i in range(len(self.data_source)): - self.data_source[i] = np.asarray(self.data_source[i], dtype=self.dtype) - self.data_target[i] = np.asarray(self.data_target[i], dtype=self.dtype) - self.data_label[i] = np.asarray(self.data_label[i], dtype=self.dtype) - - n = len(self.data_source[i]) - batch_size_seq = self.bucket_batch_sizes[i].batch_size - if n % batch_size_seq != 0: - buck_shape = self.buckets[i] - rest = batch_size_seq - n % batch_size_seq - if self.fill_up == 'pad': - raise NotImplementedError - elif self.fill_up == 'replicate': - logger.info("Replicating %d random sentences from bucket %s to size it to multiple of %d", rest, - buck_shape, batch_size_seq) - random_indices = np.random.randint(self.data_source[i].shape[0], size=rest) - self.data_source[i] = np.concatenate((self.data_source[i], self.data_source[i][random_indices, :]), - axis=0) - self.data_target[i] = np.concatenate((self.data_target[i], self.data_target[i][random_indices, :]), - axis=0) - self.data_label[i] = np.concatenate((self.data_label[i], self.data_label[i][random_indices, :]), - axis=0) + current_shard_fname = self.shards_fnames[self.shard_index] + remaining_shards = [shard for shard in self.shards_fnames if shard != current_shard_fname] + next_shard_fname = random.choice(remaining_shards) + remaining_shards = [shard for shard in self.shards_fnames if shard != next_shard_fname] + random.shuffle(remaining_shards) + + self.shards_fnames = [next_shard_fname] + remaining_shards + + self.shard_index = 0 + self._load_shard() + else: + if self.shard_index < 0: + self.shard_index = 0 + self._load_shard() + # We can just reset the shard_iter as we only have a single shard + self.shard_iter.reset() + + def iter_next(self) -> bool: + next_shard_index = self.shard_index + 1 + return self.shard_iter.iter_next() or next_shard_index < len(self.shards_fnames) + + def next(self) -> mx.io.DataBatch: + if not self.shard_iter.iter_next(): + if self.shard_index < len(self.shards_fnames) - 1: + self.shard_index += 1 + self._load_shard() + else: + raise StopIteration + return self.shard_iter.next() + + def save_state(self, fname: str): + with open(fname, "wb") as fp: + pickle.dump(self.shards_fnames, fp) + pickle.dump(self.shard_index, fp) + self.shard_iter.save_state(fname + ".sharditer") + + def load_state(self, fname: str): + with open(fname, "rb") as fp: + self.shards_fnames = pickle.load(fp) + self.shard_index = pickle.load(fp) + self._load_shard() + self.shard_iter.load_state(fname + ".sharditer") + + +class ParallelSampleIter(BaseParallelSampleIter): + """ + Data iterator on a bucketed ParallelDataSet. Shuffles data at every reset and supports saving and loading the + iterator state. 
+ """ + + def __init__(self, + data: ParallelDataSet, + buckets, + batch_size, + bucket_batch_sizes, + source_data_name=C.SOURCE_NAME, + target_data_name=C.TARGET_NAME, + label_name=C.TARGET_LABEL_NAME, + dtype='float32') -> None: + super().__init__(buckets, batch_size, bucket_batch_sizes, + source_data_name, target_data_name, label_name, dtype) + + # create independent lists to be shuffled + self.data = ParallelDataSet(list(data.source), list(data.target), list(data.label)) + + # create index tuples (buck_idx, batch_start_pos) into buckets. These will be shuffled. + self.batch_indices = get_batch_indices(self.data, bucket_batch_sizes) + self.curr_batch_index = 0 + + self.inverse_data_permutations = [mx.nd.arange(0, max(1, self.data.source[i].shape[0])) + for i in range(len(self.data))] + self.data_permutations = [mx.nd.arange(0, max(1, self.data.source[i].shape[0])) + for i in range(len(self.data))] + + self.reset() def reset(self): """ Resets and reshuffles the data. """ - self.curr_idx = 0 - # shuffle indices - random.shuffle(self.idx) - - self.nd_source = [] - self.nd_target = [] - self.nd_label = [] - self.indices = [] - for i in range(len(self.data_source)): - # shuffle indices within each bucket - self.indices.append(np.random.permutation(len(self.data_source[i]))) - self._append_ndarrays(i, self.indices[-1]) - - def _append_ndarrays(self, bucket: int, shuffled_indices: np.array): - """ - Appends the actual data, selected by the given indices, to the NDArrays - of the appropriate bucket. Use when reshuffling the data. + self.curr_batch_index = 0 + # shuffle batch start indices + random.shuffle(self.batch_indices) - :param bucket: Current bucket. - :param shuffled_indices: Indices indicating which data to select. - """ - self.nd_source.append(mx.nd.array(self.data_source[bucket].take(shuffled_indices, axis=0), dtype=self.dtype)) - self.nd_target.append(mx.nd.array(self.data_target[bucket].take(shuffled_indices, axis=0), dtype=self.dtype)) - self.nd_label.append(mx.nd.array(self.data_label[bucket].take(shuffled_indices, axis=0), dtype=self.dtype)) + # restore + self.data = self.data.permute(self.inverse_data_permutations) + + self.data_permutations, self.inverse_data_permutations = get_permutations(self.data.get_bucket_counts()) + + self.data = self.data.permute(self.data_permutations) def iter_next(self) -> bool: """ True if iterator can return another batch """ - return self.curr_idx != len(self.idx) + return self.curr_batch_index != len(self.batch_indices) def next(self) -> mx.io.DataBatch: """ @@ -708,15 +1408,15 @@ def next(self) -> mx.io.DataBatch: if not self.iter_next(): raise StopIteration - i, j = self.idx[self.curr_idx] - self.curr_idx += 1 + i, j = self.batch_indices[self.curr_batch_index] + self.curr_batch_index += 1 - batch_size_seq = self.bucket_batch_sizes[i].batch_size - source = self.nd_source[i][j:j + batch_size_seq] - target = self.nd_target[i][j:j + batch_size_seq] + batch_size = self.bucket_batch_sizes[i].batch_size + source = self.data.source[i][j:j + batch_size] + target = self.data.target[i][j:j + batch_size] data = [source, target] - label = [self.nd_label[i][j:j + batch_size_seq]] + label = [self.data.label[i][j:j + batch_size]] provide_data = [mx.io.DataDesc(name=n, shape=x.shape, layout=C.BATCH_MAJOR) for n, x in zip(self.data_names, data)] @@ -737,9 +1437,10 @@ def save_state(self, fname: str): :param fname: File name to save the information to. 
""" with open(fname, "wb") as fp: - pickle.dump(self.idx, fp) - pickle.dump(self.curr_idx, fp) - np.save(fp, self.indices) + pickle.dump(self.batch_indices, fp) + pickle.dump(self.curr_batch_index, fp) + np.save(fp, [a.asnumpy() for a in self.inverse_data_permutations]) + np.save(fp, [a.asnumpy() for a in self.data_permutations]) def load_state(self, fname: str): """ @@ -747,19 +1448,31 @@ def load_state(self, fname: str): :param fname: File name to load the information from. """ + + # restore order + self.data = self.data.permute(self.inverse_data_permutations) + with open(fname, "rb") as fp: - self.idx = pickle.load(fp) - self.curr_idx = pickle.load(fp) - self.indices = np.load(fp) + self.batch_indices = pickle.load(fp) + self.curr_batch_index = pickle.load(fp) + inverse_data_permutations = np.load(fp) + data_permutations = np.load(fp) # Because of how checkpointing is done (pre-fetching the next batch in - # each iteration), curr_idx should be always >= 1 - assert self.curr_idx >= 1 + # each iteration), curr_idx should always be >= 1 + assert self.curr_batch_index >= 1 # Right after loading the iterator state, next() should be called - self.curr_idx -= 1 + self.curr_batch_index -= 1 + + # load previous permutations + self.inverse_data_permutations = [] + self.data_permutations = [] + + for bucket in range(len(self.data)): + inverse_permutation = mx.nd.array(inverse_data_permutations[bucket]) + self.inverse_data_permutations.append(inverse_permutation) + + permutation = mx.nd.array(data_permutations[bucket]) + self.data_permutations.append(permutation) - self.nd_source = [] - self.nd_target = [] - self.nd_label = [] - for i in range(len(self.data_source)): - self._append_ndarrays(i, self.indices[i]) + self.data = self.data.permute(self.data_permutations) diff --git a/sockeye/decoder.py b/sockeye/decoder.py index 778acbef6..3ba6771d0 100644 --- a/sockeye/decoder.py +++ b/sockeye/decoder.py @@ -16,7 +16,7 @@ """ import logging from abc import ABC, abstractmethod -from typing import Callable, Dict, List, NamedTuple, Tuple +from typing import Callable, Dict, List, NamedTuple, Tuple, Union from typing import Optional import mxnet as mx @@ -32,9 +32,10 @@ from . import utils logger = logging.getLogger(__name__) +DecoderConfig = Union['RecurrentDecoderConfig', transformer.TransformerConfig, 'ConvolutionalDecoderConfig'] -def get_decoder(config: Config) -> 'Decoder': +def get_decoder(config: DecoderConfig) -> 'Decoder': if isinstance(config, RecurrentDecoderConfig): return RecurrentDecoder(config=config, prefix=C.RNN_DECODER_PREFIX) elif isinstance(config, ConvolutionalDecoderConfig): diff --git a/sockeye/encoder.py b/sockeye/encoder.py index e76e9c1fa..b8e1372a0 100644 --- a/sockeye/encoder.py +++ b/sockeye/encoder.py @@ -18,7 +18,7 @@ import logging from math import ceil, floor from abc import ABC, abstractmethod -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import mxnet as mx @@ -30,9 +30,10 @@ from . 
import utils logger = logging.getLogger(__name__) +EncoderConfigs = Union['RecurrentEncoderConfig', transformer.TransformerConfig, 'ConvolutionalEncoderConfig'] -def get_encoder(config: Config): +def get_encoder(config: EncoderConfigs) -> 'Encoder': if isinstance(config, RecurrentEncoderConfig): return get_recurrent_encoder(config) elif isinstance(config, transformer.TransformerConfig): diff --git a/sockeye/inference.py b/sockeye/inference.py index 0185931b1..7d7161118 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -322,16 +322,16 @@ def run_decoder(self, @property def training_max_seq_len_source(self) -> int: """ The maximum sequence length on the source side during training. """ - if self.config.config_data.max_observed_source_seq_len is not None: - return self.config.config_data.max_observed_source_seq_len + if self.config.config_data.data_statistics.max_observed_len_source is not None: + return self.config.config_data.data_statistics.max_observed_len_source else: return self.config.max_seq_len_source @property def training_max_seq_len_target(self) -> int: """ The maximum sequence length on the target side during training. """ - if self.config.config_data.max_observed_target_seq_len is not None: - return self.config.config_data.max_observed_target_seq_len + if self.config.config_data.data_statistics.max_observed_len_target is not None: + return self.config.config_data.data_statistics.max_observed_len_target else: return self.config.max_seq_len_target @@ -347,11 +347,11 @@ def max_supported_seq_len_target(self) -> Optional[int]: @property def length_ratio_mean(self) -> float: - return self.config.config_data.length_ratio_mean + return self.config.config_data.data_statistics.length_ratio_mean @property def length_ratio_std(self) -> float: - return self.config.config_data.length_ratio_std + return self.config.config_data.data_statistics.length_ratio_std def load_models(context: mx.context.Context, @@ -397,10 +397,8 @@ def load_models(context: mx.context.Context, cache_output_layer_w_b=cache_output_layer_w_b) models.append(model) - utils.check_condition(all(set(vocab.items()) == set(source_vocabs[0].items()) for vocab in source_vocabs), - "Source vocabulary ids do not match") - utils.check_condition(all(set(vocab.items()) == set(target_vocabs[0].items()) for vocab in target_vocabs), - "Target vocabulary ids do not match") + utils.check_condition(vocab.are_identical(*source_vocabs), "Source vocabulary ids do not match") + utils.check_condition(vocab.are_identical(*target_vocabs), "Target vocabulary ids do not match") # set a common max_output length for all models. max_input_len, get_max_output_length = get_max_input_output_length(models, @@ -465,7 +463,15 @@ def get_max_input_output_length(models: List[InferenceModel], num_stds: int, max_input_len = training_max_seq_len_source def get_max_output_length(input_length: int): - return int(np.ceil(factor * input_length)) + """ + Returns the maximum output length for inference given the input length. + Explicitly includes space for BOS and EOS sentence symbols in the target sequence, because we assume + that the mean length ratio computed on the training data do not include these special symbols. 
+ (see data_io.analyze_sequence_lengths) + """ + space_for_bos = 1 + space_for_eos = 1 + return int(np.ceil(factor * input_length)) + space_for_bos + space_for_eos return max_input_len, get_max_output_length diff --git a/sockeye/model.py b/sockeye/model.py index b7512b9b9..686236652 100644 --- a/sockeye/model.py +++ b/sockeye/model.py @@ -14,7 +14,7 @@ import copy import logging import os -from typing import Dict, List, Optional, Tuple +from typing import cast, Dict, Optional, Tuple import mxnet as mx @@ -56,8 +56,8 @@ def __init__(self, max_seq_len_target: int, vocab_source_size: int, vocab_target_size: int, - config_embed_source: Config, - config_embed_target: Config, + config_embed_source: encoder.EmbeddingConfig, + config_embed_target: encoder.EmbeddingConfig, config_encoder: Config, config_decoder: Config, config_loss: loss.LossConfig, @@ -130,7 +130,7 @@ def load_config(fname: str) -> ModelConfig: """ config = ModelConfig.load(fname) logger.info('ModelConfig loaded from "%s"', fname) - return config # type: ignore + return cast(ModelConfig, config) # type: ignore def save_params_to_file(self, fname: str): """ @@ -172,8 +172,6 @@ def _get_embed_weights(self) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, mx.sym.Symbo :return: Tuple of source and target parameter symbols. """ - assert isinstance(self.config.config_embed_source, encoder.EmbeddingConfig) - assert isinstance(self.config.config_embed_target, encoder.EmbeddingConfig) w_embed_source = mx.sym.Variable(C.SOURCE_EMBEDDING_PREFIX + "weight", shape=(self.config.config_embed_source.vocab_size, self.config.config_embed_source.num_embed)) @@ -211,8 +209,6 @@ def _build_model_components(self): # source & target embeddings embed_weight_source, embed_weight_target, out_weight_target = self._get_embed_weights() - assert isinstance(self.config.config_embed_source, encoder.EmbeddingConfig) - assert isinstance(self.config.config_embed_target, encoder.EmbeddingConfig) self.embedding_source = encoder.Embedding(self.config.config_embed_source, prefix=C.SOURCE_EMBEDDING_PREFIX, embed_weight=embed_weight_source) diff --git a/sockeye/prepare_data.py b/sockeye/prepare_data.py new file mode 100644 index 000000000..361f1d1d4 --- /dev/null +++ b/sockeye/prepare_data.py @@ -0,0 +1,76 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import argparse +import os + +from . import arguments +from . import constants as C +from . import data_io +from . import utils +from . 
import vocab +from .log import setup_main_logger + +logger = setup_main_logger(__name__, file_logging=False, console=True) + + +def main(): + params = argparse.ArgumentParser(description='Preprocesses and shards training data.') + arguments.add_prepare_data_cli_args(params) + args = params.parse_args() + + output_folder = os.path.abspath(args.output) + os.makedirs(output_folder, exist_ok=True) + global logger + logger = setup_main_logger(__name__, file_logging=True, path=os.path.join(output_folder, C.LOG_NAME)) + + utils.seedRNGs(args.seed) + + minimum_num_shards = args.min_num_shards + samples_per_shard = args.num_samples_per_shard + bucketing = not args.no_bucketing + bucket_width = args.bucket_width + + shared_vocab = args.shared_vocab + vocab_source_path = args.source_vocab + vocab_target_path = args.target_vocab + num_words_source, num_words_target = args.num_words + word_min_count_source, word_min_count_target = args.word_min_count + max_len_source, max_len_target = args.max_seq_len + + vocab_source, vocab_target = vocab.load_or_create_vocabs(source=args.source, + target=args.target, + source_vocab_path=args.source_vocab, + target_vocab_path=args.target_vocab, + shared_vocab=args.shared_vocab, + num_words_source=num_words_source, + word_min_count_source=word_min_count_source, + num_words_target=num_words_target, + word_min_count_target=word_min_count_target) + + data_io.prepare_data(args.source, args.target, + vocab_source, vocab_target, + vocab_source_path, vocab_target_path, + shared_vocab, + max_len_source, + max_len_target, + bucketing=bucketing, + bucket_width=bucket_width, + samples_per_shard=samples_per_shard, + min_num_shards=minimum_num_shards, + output_prefix=output_folder) + + +if __name__ == "__main__": + main() + diff --git a/sockeye/train.py b/sockeye/train.py index 97ccef6d6..ce041ca56 100644 --- a/sockeye/train.py +++ b/sockeye/train.py @@ -21,7 +21,7 @@ import shutil import sys from contextlib import ExitStack -from typing import Optional, Dict, List, Tuple +from typing import Optional, Dict, List, Tuple, cast import mxnet as mx @@ -54,16 +54,6 @@ def none_if_negative(val): return None if val < 0 else val -def _build_or_load_vocab(existing_vocab_path: Optional[str], data_paths: List[str], num_words: int, - word_min_count: int) -> Dict: - if existing_vocab_path is None: - vocabulary = vocab.build_from_paths(paths=data_paths, - num_words=num_words, - min_count=word_min_count) - else: - vocabulary = vocab.vocab_from_json(existing_vocab_path) - return vocabulary - def _list_to_tuple(v): """Convert v to a tuple if it is a list.""" @@ -101,7 +91,6 @@ def check_arg_compatibility(args: argparse.Namespace): % (args.transformer_model_size, args.num_embed[1])) - def check_resume(args: argparse.Namespace, output_folder: str) -> Tuple[bool, str]: """ Check if we should resume a broken training run. 
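# Illustration (not part of the patch): the per-bucket shuffling introduced in data_io above
# (get_permutations, ParallelDataSet.permute, ParallelSampleIter.reset/load_state) keeps both a
# random permutation and its inverse so the iterator can restore the original row order before
# applying a new or a saved permutation. A minimal numpy-only sketch of that round trip; the
# variable names here are placeholders, not identifiers from the patch.
import numpy as np

num_samples = 5
bucket_rows = np.arange(num_samples) * 10                  # stand-in for one bucket's samples

permutation = np.random.permutation(num_samples)           # new random order
inverse_permutation = np.empty(num_samples, dtype=np.int32)
inverse_permutation[permutation] = np.arange(num_samples)  # maps shuffled positions back

shuffled = bucket_rows[permutation]                        # order used while iterating
restored = shuffled[inverse_permutation]                   # order recovered before re-shuffling
assert (restored == bucket_rows).all()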
@@ -250,48 +239,129 @@ def load_or_create_vocabs(args: argparse.Namespace, resume_training: bool, outpu num_words_source, word_min_count_source) vocab_target = _build_or_load_vocab(args.target_vocab, [args.target], num_words_target, word_min_count_target) + return vocab_source, vocab_target - # write vocabularies - vocab.vocab_to_json(vocab_source, os.path.join(output_folder, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX) - vocab.vocab_to_json(vocab_target, os.path.join(output_folder, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX) - return vocab_source, vocab_target +def _build_or_load_vocab(existing_vocab_path: Optional[str], data_paths: List[str], + num_words: int, word_min_count: int) -> vocab.Vocab: + if existing_vocab_path is None: + vocabulary = vocab.build_from_paths(paths=data_paths, + num_words=num_words, + min_count=word_min_count) + else: + vocabulary = vocab.vocab_from_json(existing_vocab_path) + return vocabulary -def create_data_iters(args: argparse.Namespace, - vocab_source: Dict, - vocab_target: Dict) -> Tuple['data_io.ParallelBucketSentenceIter', - 'data_io.ParallelBucketSentenceIter', - 'data_io.DataConfig']: +def use_shared_vocab(args: argparse.Namespace) -> bool: + """ Determine whether the source and target vocabulary should be shared. """ + weight_tying = args.weight_tying + weight_tying_type = args.weight_tying_type + shared_vocab = args.shared_vocab + if weight_tying and C.WEIGHT_TYING_SRC in weight_tying_type and C.WEIGHT_TYING_TRG in weight_tying_type: + if not shared_vocab: + logger.info("A shared source/target vocabulary will be used as weight tying source/target weight tying " + "is enabled") + shared_vocab = True + return shared_vocab + + +def create_data_iters_and_vocab(args: argparse.Namespace, + shared_vocab: bool, + resume_training: bool, + output_folder: str) -> Tuple['data_io.BaseParallelSampleIter', + 'data_io.BaseParallelSampleIter', + 'data_io.DataConfig', Dict, Dict]: """ - Create the data iterators. + Create the data iterators and the vocabularies. :param args: Arguments as returned by argparse. - :param vocab_source: The source vocabulary. - :param vocab_target: The target vocabulary. - :return: The data iterators (train, validation, config_data). + :param shared_vocab: Whether to create a shared vocabulary. + :param resume_training: Whether to resume training. + :param output_folder: Output folder. + :return: The data iterators (train, validation, config_data) as well as the source and target vocabularies. 
""" + max_seq_len_source, max_seq_len_target = args.max_seq_len + num_words_source, num_words_target = args.num_words + word_min_count_source, word_min_count_target = args.word_min_count batch_num_devices = 1 if args.use_cpu else sum(-di if di < 0 else 1 for di in args.device_ids) - return data_io.get_training_data_iters(source=os.path.abspath(args.source), - target=os.path.abspath(args.target), - validation_source=os.path.abspath( - args.validation_source), - validation_target=os.path.abspath( - args.validation_target), - vocab_source=vocab_source, - vocab_target=vocab_target, - vocab_source_path=args.source_vocab, - vocab_target_path=args.target_vocab, - batch_size=args.batch_size, - batch_by_words=args.batch_type == C.BATCH_TYPE_WORD, - batch_num_devices=batch_num_devices, - fill_up=args.fill_up, - max_seq_len_source=max_seq_len_source, - max_seq_len_target=max_seq_len_target, - bucketing=not args.no_bucketing, - bucket_width=args.bucket_width, - sequence_limit=args.limit) + batch_by_words = args.batch_type == C.BATCH_TYPE_WORD + + either_raw_or_prepared_error_msg = "Either specify a raw training corpus with %s and %s or a preprocessed corpus " \ + "with %s." % (C.TRAINING_ARG_SOURCE, + C.TRAINING_ARG_TARGET, + C.TRAINING_ARG_PREPARED_DATA) + if args.prepared_data is not None: + utils.check_condition(args.source is None and args.target is None, either_raw_or_prepared_error_msg) + if not resume_training: + utils.check_condition(args.source_vocab is None and args.target_vocab is None, + "You are using a prepared data folder, which is tied to a vocabulary. " + "To change it you need to rerun data preparation with a different vocabulary.") + train_iter, validation_iter, data_config, vocab_source, vocab_target = data_io.get_prepared_data_iters( + prepared_data_dir=args.prepared_data, + validation_source=os.path.abspath(args.validation_source), + validation_target=os.path.abspath(args.validation_target), + shared_vocab=shared_vocab, + batch_size=args.batch_size, + batch_by_words=batch_by_words, + batch_num_devices=batch_num_devices, + fill_up=args.fill_up) + if resume_training: + # resuming training. Making sure the vocabs in the model and in the prepared data match up + model_vocab_source = vocab.vocab_from_json(os.path.join(output_folder, C.VOCAB_SRC_NAME + C.JSON_SUFFIX)) + model_vocab_target = vocab.vocab_from_json(os.path.join(output_folder, C.VOCAB_TRG_NAME + C.JSON_SUFFIX)) + utils.check_condition(vocab.are_identical(vocab_source, model_vocab_source), + "Prepared data and resumed model source vocabs do not match.") + utils.check_condition(vocab.are_identical(vocab_target, model_vocab_target), + "Prepared data and resumed model target vocabs do not match.") + + return train_iter, validation_iter, data_config, vocab_source, vocab_target + else: + utils.check_condition(args.prepared_data is None and args.source is not None and args.target is not None, + either_raw_or_prepared_error_msg) + + if resume_training: + # Load the existing vocab created when starting the training run. 
+ vocab_source = vocab.vocab_from_json(os.path.join(output_folder, C.VOCAB_SRC_NAME + C.JSON_SUFFIX)) + vocab_target = vocab.vocab_from_json(os.path.join(output_folder, C.VOCAB_TRG_NAME + C.JSON_SUFFIX)) + + # Recover the vocabulary path from the existing config file: + orig_config = cast(model.ModelConfig, Config.load(os.path.join(output_folder, C.CONFIG_NAME))) + vocab_source_path = orig_config.config_data.vocab_source + vocab_target_path = orig_config.config_data.vocab_target + else: + # Load vocab: + vocab_source_path = args.source_vocab + vocab_target_path = args.target_vocab + vocab_source, vocab_target = vocab.load_or_create_vocabs(source=args.source, target=args.target, + source_vocab_path=vocab_source_path, + target_vocab_path=vocab_target_path, + shared_vocab=shared_vocab, + num_words_source=num_words_source, + num_words_target=num_words_target, + word_min_count_source=word_min_count_source, + word_min_count_target=word_min_count_target) + + train_iter, validation_iter, config_data = data_io.get_training_data_iters( + source=os.path.abspath(args.source), + target=os.path.abspath(args.target), + validation_source=os.path.abspath(args.validation_source), + validation_target=os.path.abspath(args.validation_target), + vocab_source=vocab_source, + vocab_target=vocab_target, + vocab_source_path=vocab_source_path, + vocab_target_path=vocab_target_path, + shared_vocab=shared_vocab, + batch_size=args.batch_size, + batch_by_words=batch_by_words, + batch_num_devices=batch_num_devices, + fill_up=args.fill_up, + max_seq_len_source=max_seq_len_source, + max_seq_len_target=max_seq_len_target, + bucketing=not args.no_bucketing, + bucket_width=args.bucket_width) + return train_iter, validation_iter, config_data, vocab_source, vocab_target def create_lr_scheduler(args: argparse.Namespace, resume_training: bool, @@ -557,7 +627,7 @@ def create_model_config(args: argparse.Namespace, def create_training_model(model_config: model.ModelConfig, args: argparse.Namespace, context: List[mx.Context], - train_iter: data_io.ParallelBucketSentenceIter, + train_iter: data_io.BaseParallelSampleIter, lr_scheduler_instance: lr_scheduler.LearningRateScheduler, resume_training: bool, training_state_dir: str) -> training.TrainingModel: @@ -638,7 +708,7 @@ def main(): arguments.add_train_cli_args(params) args = params.parse_args() - utils.seedRNGs(args) + utils.seedRNGs(args.seed) check_arg_compatibility(args) output_folder = os.path.abspath(args.output) @@ -654,11 +724,22 @@ def main(): with ExitStack() as exit_stack: context = determine_context(args, exit_stack) - vocab_source, vocab_target = load_or_create_vocabs(args, resume_training, output_folder) + + shared_vocab = use_shared_vocab(args) + + train_iter, eval_iter, config_data, vocab_source, vocab_target = create_data_iters_and_vocab( + args=args, + shared_vocab=shared_vocab, + resume_training=resume_training, + output_folder=output_folder) + + if not resume_training: + vocab.vocab_to_json(vocab_source, os.path.join(output_folder, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX) + vocab.vocab_to_json(vocab_target, os.path.join(output_folder, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX) + vocab_source_size = len(vocab_source) vocab_target_size = len(vocab_target) logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size, vocab_target_size) - train_iter, eval_iter, config_data = create_data_iters(args, vocab_source, vocab_target) lr_scheduler_instance = create_lr_scheduler(args, resume_training, training_state_dir) model_config = create_model_config(args, vocab_source_size, 
vocab_target_size, config_data) @@ -715,6 +796,8 @@ def main(): min_num_epochs=min_num_epochs, max_num_epochs=max_num_epochs, decode_and_evaluate=decode_and_evaluate, + decode_and_evaluate_fname_source=args.validation_source, + decode_and_evaluate_fname_target=args.validation_target, decode_and_evaluate_context=decode_and_evaluate_context, use_tensorboard=args.use_tensorboard, mxmonitor_pattern=args.monitor_pattern, diff --git a/sockeye/training.py b/sockeye/training.py index 6129cc03a..6468ac27b 100644 --- a/sockeye/training.py +++ b/sockeye/training.py @@ -83,7 +83,7 @@ class TrainingModel(model.SockeyeModel): def __init__(self, config: model.ModelConfig, context: List[mx.context.Context], - train_iter: data_io.ParallelBucketSentenceIter, + train_iter: data_io.BaseParallelSampleIter, bucketing: bool, lr_scheduler) -> None: super().__init__(config) @@ -94,11 +94,11 @@ def __init__(self, self.module = self._build_module(train_iter) self.training_monitor = None # type: Optional[callback.TrainingMonitor] - def _build_module(self, train_iter: data_io.ParallelBucketSentenceIter): + def _build_module(self, train_iter: data_io.BaseParallelSampleIter): """ Initializes model components, creates training symbol and module, and binds it. """ - utils.check_condition(train_iter.pad_id == C.PAD_ID == 0, "pad id should be 0") + #utils.check_condition(train_iter.pad_id == C.PAD_ID == 0, "pad id should be 0") source = mx.sym.Variable(C.SOURCE_NAME) source_length = utils.compute_lengths(source) target = mx.sym.Variable(C.TARGET_NAME) @@ -189,8 +189,8 @@ def create_eval_metric_composite(metric_names: List[AnyStr]) -> mx.metric.Compos return mx.metric.create(metrics) def fit(self, - train_iter: data_io.ParallelBucketSentenceIter, - val_iter: data_io.ParallelBucketSentenceIter, + train_iter: data_io.BaseParallelSampleIter, + val_iter: data_io.BaseParallelSampleIter, output_folder: str, max_params_files_to_keep: int, metrics: List[AnyStr], @@ -208,6 +208,8 @@ def fit(self, min_num_epochs: Optional[int] = None, max_num_epochs: Optional[int] = None, decode_and_evaluate: int = 0, + decode_and_evaluate_fname_source: Optional[str] = None, + decode_and_evaluate_fname_target: Optional[str] = None, decode_and_evaluate_context: Optional[mx.Context] = None, use_tensorboard: bool = False, mxmonitor_pattern: Optional[str] = None, @@ -237,6 +239,9 @@ def fit(self, :param max_num_epochs: Optional maximum number of epochs to train. :param decode_and_evaluate: Monitor BLEU during training (0: off, >=0: the number of sentences to decode for BLEU evaluation, -1: decode the full validation set.). + :param decode_and_evaluate_fname_source: Filename of source data to decode and evaluate. + :param decode_and_evaluate_fname_target: Filename of target data (references) to decode and evaluate. + :param decode_and_evaluate_context: Optional MXNet context for decode and evaluate. :param use_tensorboard: If True write tensorboard compatible logs for monitoring training and validation metrics. 
:param mxmonitor_pattern: Optional pattern to match to monitor weights/gradients/outputs @@ -267,8 +272,8 @@ def fit(self, self.module.init_optimizer(kvstore=kvstore, optimizer=optimizer, optimizer_params=optimizer_params) cp_decoder = checkpoint_decoder.CheckpointDecoder(decode_and_evaluate_context, - self.config.config_data.validation_source, - self.config.config_data.validation_target, + decode_and_evaluate_fname_source, + decode_and_evaluate_fname_target, output_folder, sample_size=decode_and_evaluate) \ if decode_and_evaluate else None @@ -335,8 +340,8 @@ def _get_optimizer(self): return self._get_curr_module()._optimizer def _fit(self, - train_iter: data_io.ParallelBucketSentenceIter, - val_iter: data_io.ParallelBucketSentenceIter, + train_iter: data_io.BaseParallelSampleIter, + val_iter: data_io.BaseParallelSampleIter, output_folder: str, kvstore: str, max_params_files_to_keep: int, @@ -611,7 +616,7 @@ def _evaluate(self, training_state, val_iter, val_metric): return self.training_monitor.eval_end_callback(training_state.checkpoint, val_metric) def _checkpoint(self, training_state: _TrainingState, output_folder: str, - train_iter: data_io.ParallelBucketSentenceIter): + train_iter: data_io.BaseParallelSampleIter): """ Saves checkpoint. Note that the parameters are saved in _save_params. """ @@ -702,7 +707,7 @@ def load_optimizer_states(self, fname: str): """ self._get_curr_module().load_optimizer_states(fname) - def load_checkpoint(self, directory: str, train_iter: data_io.ParallelBucketSentenceIter) -> _TrainingState: + def load_checkpoint(self, directory: str, train_iter: data_io.BaseParallelSampleIter) -> _TrainingState: """ Loads the full training state from disk. This includes optimizer, random number generators and everything needed. Note that params diff --git a/sockeye/utils.py b/sockeye/utils.py index b191bb09b..33a576894 100644 --- a/sockeye/utils.py +++ b/sockeye/utils.py @@ -14,10 +14,9 @@ """ A set of utility methods. """ -import argparse -import collections import errno -import fcntl +import gzip +import itertools import logging import os import random @@ -25,10 +24,10 @@ import subprocess import sys import time -import itertools from contextlib import contextmanager, ExitStack -from typing import Mapping, NamedTuple, Any, List, Iterator, Iterable, Set, TextIO, Tuple, Dict, Optional +from typing import Mapping, Any, List, Iterator, Iterable, Set, Tuple, Dict, Optional, Union, IO +import fcntl import mxnet as mx import numpy as np @@ -95,15 +94,15 @@ def log_basic_info(args) -> None: logger.info("Arguments: %s", args) -def seedRNGs(args: argparse.Namespace) -> None: +def seedRNGs(seed: int) -> None: """ Seed the random number generators (Python, Numpy and MXNet) - :param args: Arguments as returned by argparse. + :param seed: The random seed. """ - np.random.seed(args.seed) - random.seed(args.seed) - mx.random.seed(args.seed) + np.random.seed(seed) + random.seed(seed) + mx.random.seed(seed) def check_condition(condition: bool, error_message: str): @@ -214,6 +213,35 @@ def update(self, labels, preds): self.num_inst += n +class OnlineMeanAndVariance: + def __init__(self) -> None: + self._count = 0 + self._mean = 0. + self._M2 = 0. 
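# Note: the update() method below follows Welford's online algorithm for a running mean and variance:
#   mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n
#   M2_n   = M2_{n-1} + (x_n - mean_{n-1}) * (x_n - mean_n)
# so the variance can be read off as M2_n / n (the population variance) without storing the samples.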
+ + def update(self, value: Union[float, int]) -> None: + self._count += 1 + delta = value - self._mean + self._mean += delta / self._count + delta2 = value - self._mean + self._M2 += delta * delta2 + + @property + def count(self) -> int: + return self._count + + @property + def mean(self) -> float: + return self._mean + + @property + def variance(self) -> float: + if self._count < 2: + return float('nan') + else: + return self._M2 / self._count + + def smallest_k(matrix: np.ndarray, k: int, only_first_row: bool = False) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]: """ @@ -262,6 +290,39 @@ def chunks(some_list: List, n: int) -> Iterable[List]: yield some_list[i:i + n] +def get_tokens(line: str) -> Iterator[str]: + """ + Yields tokens from input string. + + :param line: Input string. + :return: Iterator over tokens. + """ + for token in line.rstrip().split(): + if len(token) > 0: + yield token + + +def smart_open(filename: str, mode: str = "rt", ftype: str = "auto", errors: str = 'replace'): + """ + Returns a file descriptor for filename with UTF-8 encoding. + If mode is "rt", file is opened read-only. + If ftype is "auto", uses gzip iff filename endswith .gz. + If ftype is {"gzip","gz"}, uses gzip. + + Note: encoding error handling defaults to "replace" + + :param filename: The filename to open. + :param mode: Reader mode. + :param ftype: File type. If 'auto' checks filename suffix for gz to try gzip.open + :param errors: Encoding error handling during reading. Defaults to 'replace' + :return: File descriptor + """ + if ftype == 'gzip' or ftype == 'gz' or (ftype == 'auto' and filename.endswith(".gz")): + return gzip.open(filename, mode=mode, encoding='utf-8', errors=errors) + else: + return open(filename, mode=mode, encoding='utf-8', errors=errors) + + def plot_attention(attention_matrix: np.ndarray, source_tokens: List[str], target_tokens: List[str], filename: str): """ Uses matplotlib for creating a visualization of the attention matrix. @@ -504,7 +565,7 @@ def acquire_gpus(requested_device_ids: List[int], lock_dir: str = "/tmp", acquired_gpus = [] # type: List[int] any_failed = False for candidates in candidates_to_request: - gpu_id = exit_stack.enter_context(GpuFileLock(candidates=candidates, lock_dir=lock_dir)) + gpu_id = exit_stack.enter_context(GpuFileLock(candidates=candidates, lock_dir=lock_dir)) # type: ignore if gpu_id is not None: acquired_gpus.append(gpu_id) else: @@ -541,7 +602,7 @@ class GpuFileLock: def __init__(self, candidates: List[int], lock_dir: str) -> None: self.candidates = candidates self.lock_dir = lock_dir - self.lock_file = None # type: Optional[TextIO] + self.lock_file = None # type: Optional[IO[Any]] self.lock_file_path = None # type: Optional[str] self.gpu_id = None # type: Optional[int] self._acquired_lock = False @@ -584,25 +645,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): os.remove(self.lockfile_path) -def namedtuple_with_defaults(typename, field_names, default_values: Mapping[str, Any] = ()) -> NamedTuple: - """ - Create a named tuple with default values. - - :param typename: The name of the new type. - :param field_names: The fields the type will have. - :param default_values: A mapping from field names to default values. - :return: The new named tuple with default values. 
- """ - T = collections.namedtuple(typename, field_names) - T.__new__.__defaults__ = (None,) * len(T._fields) - if isinstance(default_values, collections.Mapping): - prototype = T(**default_values) - else: - prototype = T(*default_values) - T.__new__.__defaults__ = tuple(prototype) - return T - - def read_metrics_file(path: str) -> List[Dict[str, Any]]: """ Reads lines metrics file and returns list of mappings of key and values. @@ -662,7 +704,7 @@ class PrintValue(mx.operator.CustomOp): the system logger and 'print_grad=True' for printing information about the gradient (out_grad, i.e. "upper part" of the graph). """ - def __init__(self, print_name, print_grad: str, use_logger: str): + def __init__(self, print_name, print_grad: str, use_logger: str) -> None: super().__init__() self.print_name = print_name # Note that all the parameters are serialized as strings @@ -690,7 +732,7 @@ def backward(self, req, out_grad, in_data, out_data, in_grad, aux): @mx.operator.register("PrintValue") class PrintValueProp(mx.operator.CustomOpProp): - def __init__(self, print_name: str, print_grad: bool = False, use_logger: bool = False): + def __init__(self, print_name: str, print_grad: bool = False, use_logger: bool = False) -> None: super().__init__(need_top_grad=True) self.print_name = print_name self.print_grad = print_grad diff --git a/sockeye/vocab.py b/sockeye/vocab.py index 479a00002..5ffedefc6 100644 --- a/sockeye/vocab.py +++ b/sockeye/vocab.py @@ -19,18 +19,20 @@ from collections import Counter from contextlib import ExitStack from itertools import chain, islice -from typing import Dict, Iterable, List, Mapping +from typing import Dict, Iterable, List, Mapping, Optional, Tuple -from sockeye.data_io import get_tokens, smart_open from . import utils -from . import arguments from . import constants as C from . import log logger = logging.getLogger(__name__) -def build_from_paths(paths: List[str], num_words: int = 50000, min_count: int = 1) -> Dict[str, int]: +Vocab = Dict[str, int] +InverseVocab = Dict[int, str] + + +def build_from_paths(paths: List[str], num_words: int = 50000, min_count: int = 1) -> Vocab: """ Creates vocabulary from paths to a file in sentence-per-line format. A sentence is just a whitespace delimited list of tokens. Note that special symbols like the beginning of sentence (BOS) symbol will be added to the @@ -43,11 +45,11 @@ def build_from_paths(paths: List[str], num_words: int = 50000, min_count: int = """ with ExitStack() as stack: logger.info("Building vocabulary from dataset(s): %s", paths) - files = (stack.enter_context(smart_open(path)) for path in paths) + files = (stack.enter_context(utils.smart_open(path)) for path in paths) return build_vocab(chain(*files), num_words, min_count) -def build_vocab(data: Iterable[str], num_words: int = 50000, min_count: int = 1) -> Dict[str, int]: +def build_vocab(data: Iterable[str], num_words: int = 50000, min_count: int = 1) -> Vocab: """ Creates a vocabulary mapping from words to ids. Increasing integer ids are assigned by word frequency, using lexical sorting as a tie breaker. The only exception to this are special symbols such as the padding symbol @@ -59,19 +61,18 @@ def build_vocab(data: Iterable[str], num_words: int = 50000, min_count: int = 1) :return: Word-to-id mapping. 
""" vocab_symbols_set = set(C.VOCAB_SYMBOLS) - raw_vocab = Counter(token for line in data for token in get_tokens(line) + raw_vocab = Counter(token for line in data for token in utils.get_tokens(line) if token not in vocab_symbols_set) - logger.info("Initial vocabulary: %d types" % len(raw_vocab)) - # For words with the same count, they will be ordered reverse alphabetically. # Not an issue since we only care for consistency pruned_vocab = sorted(((c, w) for w, c in raw_vocab.items() if c >= min_count), reverse=True) - logger.info("Pruned vocabulary: %d types (min frequency %d)", len(pruned_vocab), min_count) vocab = islice((w for c, w in pruned_vocab), num_words) word_to_id = {word: idx for idx, word in enumerate(chain(C.VOCAB_SYMBOLS, vocab))} - logger.info("Final vocabulary: %d types (min frequency %d, top %d types)", + logger.info("Vocabulary: types: %d/%d/%d/%d (initial/min_pruned/max_pruned/+special) " + + "[min_frequency=%d, max_num_types=%d]", + len(raw_vocab), len(pruned_vocab), len(word_to_id) - len(C.VOCAB_SYMBOLS), len(word_to_id), min_count, num_words) # Important: pad symbol becomes index 0 @@ -103,7 +104,7 @@ def vocab_to_json(vocab: Mapping, path: str): logger.info('Vocabulary saved to "%s"', path) -def vocab_from_json_or_pickle(path) -> Dict: +def vocab_from_json_or_pickle(path) -> Vocab: """ Try loading the json version of the vocab and fall back to pickle for backwards compatibility. @@ -117,7 +118,7 @@ def vocab_from_json_or_pickle(path) -> Dict: return vocab_from_pickle(path) -def vocab_from_pickle(path: str) -> Dict: +def vocab_from_pickle(path: str) -> Vocab: """ Saves vocabulary in pickle format. @@ -130,7 +131,7 @@ def vocab_from_pickle(path: str) -> Dict: return vocab -def vocab_from_json(path: str, encoding: str = C.VOCAB_ENCODING) -> Dict: +def vocab_from_json(path: str, encoding: str = C.VOCAB_ENCODING) -> Vocab: """ Saves vocabulary in json format. @@ -143,7 +144,45 @@ def vocab_from_json(path: str, encoding: str = C.VOCAB_ENCODING) -> Dict: return vocab -def reverse_vocab(vocab: Mapping) -> Dict: +def load_or_create_vocab(data: str, vocab_path: Optional[str], + num_words: int, word_min_count: int): + return build_from_paths(paths=[data], + num_words=num_words, + min_count=word_min_count) if vocab_path is None else vocab_from_json(vocab_path) + + +def load_or_create_vocabs(source: str, target: str, source_vocab_path: Optional[str], target_vocab_path: Optional[str], + shared_vocab: bool, + num_words_source: int, word_min_count_source: int, + num_words_target: int, word_min_count_target: int) -> Tuple[Vocab, Vocab]: + if shared_vocab: + if source_vocab_path and target_vocab_path: + vocab_source = vocab_from_json(source_vocab_path) + vocab_target = vocab_from_json(target_vocab_path) + utils.check_condition(are_identical(vocab_source, vocab_target), + "Shared vocabulary requires identical source and target vocabularies. " + "The vocabularies in %s and %s are not identical." 
% (source_vocab_path, + target_vocab_path)) + elif source_vocab_path is None and target_vocab_path is None: + utils.check_condition(num_words_source == num_words_target, + "A shared vocabulary requires the number of source and target words to be the same.") + utils.check_condition(word_min_count_source == word_min_count_target, + "A shared vocabulary requires the minimum word count for source and target " + "to be the same.") + vocab_source = vocab_target = build_from_paths(paths=[source, target], + num_words=num_words_source, + min_count=word_min_count_source) + else: + vocab_path = source_vocab_path if source_vocab_path is not None else target_vocab_path + logger.info("Using %s as a shared source/target vocabulary." % vocab_path) + vocab_source = vocab_target = vocab_from_json(vocab_path) + else: + vocab_source = load_or_create_vocab(source, source_vocab_path, num_words_source, word_min_count_source) + vocab_target = load_or_create_vocab(target, target_vocab_path, num_words_target, word_min_count_target) + return vocab_source, vocab_target + + +def reverse_vocab(vocab: Mapping) -> InverseVocab: """ Returns value-to-key mapping from key-to-value-mapping. @@ -153,7 +192,13 @@ def reverse_vocab(vocab: Mapping) -> Dict: return {v: k for k, v in vocab.items()} +def are_identical(*vocabs: Vocab): + assert len(vocabs) > 0, "At least one vocabulary needed." + return all(set(vocab.items()) == set(vocabs[0].items()) for vocab in vocabs) + + def main(): + from . import arguments params = argparse.ArgumentParser(description='CLI to build source and target vocab(s).') arguments.add_build_vocab_args(params) args = params.parse_args() diff --git a/test/common.py b/test/common.py index 012c8d91c..b28e50c4e 100644 --- a/test/common.py +++ b/test/common.py @@ -11,10 +11,10 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
+import logging
 import os
 import random
 import sys
-import logging
 from contextlib import contextmanager
 from tempfile import TemporaryDirectory
 from typing import Optional, Tuple
@@ -27,12 +27,12 @@
 import sockeye.constants as C
 import sockeye.evaluate
 import sockeye.lexicon
+import sockeye.prepare_data
 import sockeye.train
 import sockeye.translate
 import sockeye.utils
-
-from sockeye.evaluate import raw_corpus_bleu
 from sockeye.chrf import corpus_chrf
+from sockeye.evaluate import raw_corpus_bleu
 
 logger = logging.getLogger(__name__)
 
@@ -164,6 +164,12 @@ def tmp_digits_dataset(prefix: str,
 _TRAIN_PARAMS_COMMON = "--use-cpu --max-seq-len {max_len} --source {train_source} --target {train_target}" \
                        " --validation-source {dev_source} --validation-target {dev_target} --output {model} {quiet}"
 
+_PREPARE_DATA_COMMON = " --max-seq-len {max_len} --source {train_source} --target {train_target}" \
+                       " --output {output} {quiet}"
+
+_TRAIN_PARAMS_PREPARED_DATA_COMMON = "--use-cpu --max-seq-len {max_len} --prepared-data {prepared_data}" \
+                                     " --validation-source {dev_source} --validation-target {dev_target} --output {model} {quiet}"
+
 _TRANSLATE_PARAMS_COMMON = "--use-cpu --models {model} --input {input} --output {output} {quiet}"
 
 _TRANSLATE_PARAMS_RESTRICT = "--restrict-lexicon {json}"
 
@@ -180,6 +186,7 @@ def run_train_translate(train_params: str,
                         dev_target_path: str,
                         test_source_path: str,
                         test_target_path: str,
+                        use_prepared_data: bool = False,
                         max_seq_len: int = 10,
                         restrict_lexicon: bool = False,
                         work_dir: Optional[str] = None,
@@ -207,20 +214,46 @@ def run_train_translate(train_params: str,
     else:
         quiet_arg = ""
     with TemporaryDirectory(dir=work_dir, prefix="test_train_translate.") as work_dir:
-        # Train model
-        model_path = os.path.join(work_dir, "model")
-        params = "{} {} {}".format(sockeye.train.__file__,
-                                   _TRAIN_PARAMS_COMMON.format(train_source=train_source_path,
-                                                               train_target=train_target_path,
-                                                               dev_source=dev_source_path,
-                                                               dev_target=dev_target_path,
-                                                               model=model_path,
-                                                               max_len=max_seq_len,
-                                                               quiet=quiet_arg),
-                                   train_params)
-        logger.info("Starting training with parameters %s.", train_params)
-        with patch.object(sys, "argv", params.split()):
-            sockeye.train.main()
+        # Optionally create prepared data directory
+        if use_prepared_data:
+            prepared_data_path = os.path.join(work_dir, "prepared_data")
+            params = "{} {}".format(sockeye.prepare_data.__file__,
+                                    _PREPARE_DATA_COMMON.format(train_source=train_source_path,
+                                                                train_target=train_target_path,
+                                                                output=prepared_data_path,
+                                                                max_len=max_seq_len,
+                                                                quiet=quiet_arg))
+            logger.info("Creating prepared data folder.")
+            with patch.object(sys, "argv", params.split()):
+                sockeye.prepare_data.main()
+            # Train model
+            model_path = os.path.join(work_dir, "model")
+            params = "{} {} {}".format(sockeye.train.__file__,
+                                       _TRAIN_PARAMS_PREPARED_DATA_COMMON.format(prepared_data=prepared_data_path,
+                                                                                 dev_source=dev_source_path,
+                                                                                 dev_target=dev_target_path,
+                                                                                 model=model_path,
+                                                                                 max_len=max_seq_len,
+                                                                                 quiet=quiet_arg),
+                                       train_params)
+            logger.info("Starting training with parameters %s.", train_params)
+            with patch.object(sys, "argv", params.split()):
+                sockeye.train.main()
+        else:
+            # Train model
+            model_path = os.path.join(work_dir, "model")
+            params = "{} {} {}".format(sockeye.train.__file__,
+                                       _TRAIN_PARAMS_COMMON.format(train_source=train_source_path,
+                                                                   train_target=train_target_path,
+                                                                   dev_source=dev_source_path,
+                                                                   dev_target=dev_target_path,
+                                                                   model=model_path,
+                                                                   max_len=max_seq_len,
+                                                                   quiet=quiet_arg),
+                                       train_params)
+            logger.info("Starting 
training with parameters %s.", train_params) + with patch.object(sys, "argv", params.split()): + sockeye.train.main() logger.info("Translating with parameters %s.", translate_params) # Translate corpus with the 1st params diff --git a/test/system/test_seq_copy_sys.py b/test/system/test_seq_copy_sys.py index ebfbac2b0..be5ae8476 100644 --- a/test/system/test_seq_copy_sys.py +++ b/test/system/test_seq_copy_sys.py @@ -29,7 +29,7 @@ _SEED_DEV = 17 -@pytest.mark.parametrize("name, train_params, translate_params, perplexity_thresh, bleu_thresh", [ +@pytest.mark.parametrize("name, train_params, translate_params, use_prepared_data, perplexity_thresh, bleu_thresh", [ ("Copy:lstm:lstm", "--encoder rnn --num-layers 1 --rnn-cell-type lstm --rnn-num-hidden 64 --num-embed 32 --rnn-attention-type mlp" " --rnn-attention-num-hidden 32 --batch-size 16 --loss cross-entropy --optimized-metric perplexity" @@ -37,6 +37,7 @@ " --rnn-dropout-states 0.0:0.1 --embed-dropout 0.1:0.0 --max-updates 4000 --weight-normalization" " --gradient-clipping-type norm --gradient-clipping-threshold 10", "--beam-size 5 ", + True, 1.02, 0.99), ("Copy:chunking", @@ -45,6 +46,7 @@ " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001" " --rnn-dropout-states 0.0:0.1 --embed-dropout 0.1:0.0 --max-updates 5000", "--beam-size 5 --max-input-len 4", + False, 1.01, 0.99), ("Copy:word-based-batching", @@ -53,6 +55,7 @@ " --optimized-metric perplexity --max-updates 5000 --checkpoint-frequency 1000 --optimizer adam " " --initial-learning-rate 0.001 --rnn-dropout-states 0.0:0.1 --embed-dropout 0.1:0.0 --layer-normalization", "--beam-size 5", + True, 1.01, 0.99), ("Copy:transformer:lstm", @@ -63,6 +66,7 @@ " --transformer-feed-forward-num-hidden 64 --transformer-activation-type gelu" " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 5", + False, 1.01, 0.99), ("Copy:lstm:transformer", @@ -73,6 +77,7 @@ " --transformer-feed-forward-num-hidden 64 --transformer-activation-type swish1" " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 5", + True, 1.01, 0.98), ("Copy:transformer:transformer", @@ -82,6 +87,7 @@ " --transformer-feed-forward-num-hidden 64 --num-embed 32" " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 1", + False, 1.01, 0.99), ("Copy:cnn:cnn", @@ -90,10 +96,11 @@ " --cnn-num-hidden 32 --cnn-positional-embedding-type fixed --cnn-project-qkv " " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 1", + True, 1.02, 0.98) ]) -def test_seq_copy(name, train_params, translate_params, perplexity_thresh, bleu_thresh): +def test_seq_copy(name, train_params, translate_params, use_prepared_data, perplexity_thresh, bleu_thresh): """Task: copy short sequences of digits""" with tmp_digits_dataset("test_seq_copy.", _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, _DEV_LINE_COUNT, _LINE_MAX_LENGTH, _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH, @@ -108,6 +115,7 @@ def test_seq_copy(name, train_params, translate_params, perplexity_thresh, bleu_ data['validation_target'], data['test_source'], data['test_target'], + use_prepared_data=use_prepared_data, max_seq_len=_LINE_MAX_LENGTH + 1, restrict_lexicon=True, work_dir=data['work_dir']) @@ -118,12 +126,13 @@ def test_seq_copy(name, train_params, translate_params, perplexity_thresh, bleu_ assert bleu_restrict >= bleu_thresh -@pytest.mark.parametrize("name, train_params, translate_params, perplexity_thresh, bleu_thresh", [ 
+@pytest.mark.parametrize("name, train_params, translate_params, use_prepared_data, perplexity_thresh, bleu_thresh", [ ("Sort:lstm", "--encoder rnn --num-layers 1 --rnn-cell-type lstm --rnn-num-hidden 64 --num-embed 32 --rnn-attention-type mlp" " --rnn-attention-num-hidden 32 --batch-size 16 --loss cross-entropy --optimized-metric perplexity" " --max-updates 5000 --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 5", + True, 1.04, 0.98), ("Sort:word-based-batching", @@ -132,6 +141,7 @@ def test_seq_copy(name, train_params, translate_params, perplexity_thresh, bleu_ " --optimized-metric perplexity --max-updates 5000 --checkpoint-frequency 1000 --optimizer adam " " --initial-learning-rate 0.001 --rnn-dropout-states 0.0:0.1 --embed-dropout 0.1:0.0", "--beam-size 5", + False, 1.01, 0.99), ("Sort:transformer:lstm", @@ -142,6 +152,7 @@ def test_seq_copy(name, train_params, translate_params, perplexity_thresh, bleu_ " --transformer-feed-forward-num-hidden 64 --transformer-activation-type gelu" " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 5", + True, 1.02, 0.99), ("Sort:lstm:transformer", @@ -152,6 +163,7 @@ def test_seq_copy(name, train_params, translate_params, perplexity_thresh, bleu_ " --transformer-feed-forward-num-hidden 64 --transformer-activation-type swish1" " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 5", + False, 1.02, 0.99), ("Sort:transformer", @@ -161,6 +173,7 @@ def test_seq_copy(name, train_params, translate_params, perplexity_thresh, bleu_ " --transformer-feed-forward-num-hidden 64" " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 1", + True, 1.02, 0.99), ("Sort:cnn", @@ -169,10 +182,11 @@ def test_seq_copy(name, train_params, translate_params, perplexity_thresh, bleu_ " --cnn-num-hidden 32 --cnn-positional-embedding-type fixed" " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 1", + False, 1.07, 0.96) ]) -def test_seq_sort(name, train_params, translate_params, perplexity_thresh, bleu_thresh): +def test_seq_sort(name, train_params, translate_params, use_prepared_data, perplexity_thresh, bleu_thresh): """Task: sort short sequences of digits""" with tmp_digits_dataset("test_seq_sort.", _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, _DEV_LINE_COUNT, _LINE_MAX_LENGTH, _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH, @@ -187,6 +201,7 @@ def test_seq_sort(name, train_params, translate_params, perplexity_thresh, bleu_ data['validation_target'], data['test_source'], data['test_target'], + use_prepared_data=use_prepared_data, max_seq_len=_LINE_MAX_LENGTH + 1, restrict_lexicon=True, work_dir=data['work_dir']) diff --git a/test/unit/test_arguments.py b/test/unit/test_arguments.py index 96249be0e..e4bdbf04b 100644 --- a/test/unit/test_arguments.py +++ b/test/unit/test_arguments.py @@ -21,41 +21,34 @@ from itertools import zip_longest +# note that while --prepared-data and --source/--target are mutually exclusive this is not the case at the CLI level @pytest.mark.parametrize("test_params, expected_params", [ # mandatory parameters - ('--source test_src --target test_tgt ' + ('--source test_src --target test_tgt --prepared-data prep_data ' '--validation-source test_validation_src --validation-target test_validation_tgt ' '--output test_output', - dict(source='test_src', target='test_tgt', limit=None, + dict(source='test_src', target='test_tgt', + prepared_data='prep_data', 
validation_source='test_validation_src', validation_target='test_validation_tgt', output='test_output', overwrite_output=False, - source_vocab=None, target_vocab=None, use_tensorboard=False, - monitor_pattern=None, monitor_stat_func='mx_default')), - - # all parameters - ('--source test_src --target test_tgt --limit 10 ' - '--validation-source test_validation_src --validation-target test_validation_tgt ' - '--output test_output ' - '--source-vocab test_src_vocab --target-vocab test_tgt_vocab ' - '--use-tensorboard --overwrite-output', - dict(source='test_src', target='test_tgt', limit=10, - validation_source='test_validation_src', validation_target='test_validation_tgt', - output='test_output', overwrite_output=True, - source_vocab='test_src_vocab', target_vocab='test_tgt_vocab', use_tensorboard=True, - monitor_pattern=None, monitor_stat_func='mx_default')), + source_vocab=None, target_vocab=None, shared_vocab=False, num_words=(50000, 50000), word_min_count=(1,1), + no_bucketing=False, bucket_width=10, max_seq_len=(100, 100), + monitor_pattern=None, monitor_stat_func='mx_default', use_tensorboard=False)), # short parameters - ('-s test_src -t test_tgt ' + ('-s test_src -t test_tgt -d prep_data ' '-vs test_validation_src -vt test_validation_tgt ' '-o test_output', - dict(source='test_src', target='test_tgt', limit=None, + dict(source='test_src', target='test_tgt', + prepared_data='prep_data', validation_source='test_validation_src', validation_target='test_validation_tgt', output='test_output', overwrite_output=False, - source_vocab=None, target_vocab=None, use_tensorboard=False, - monitor_pattern=None, monitor_stat_func='mx_default')) + source_vocab=None, target_vocab=None, shared_vocab=False, num_words=(50000, 50000), word_min_count=(1,1), + no_bucketing=False, bucket_width=10, max_seq_len=(100, 100), + monitor_pattern=None, monitor_stat_func='mx_default', use_tensorboard=False)) ]) def test_io_args(test_params, expected_params): - _test_args(test_params, expected_params, arguments.add_io_args) + _test_args(test_params, expected_params, arguments.add_training_io_args) @pytest.mark.parametrize("test_params, expected_params", [ @@ -77,8 +70,6 @@ def test_device_args(test_params, expected_params): @pytest.mark.parametrize("test_params, expected_params", [ ('', dict(params=None, allow_missing_params=False, - num_words=(50000, 50000), - word_min_count=(1, 1), num_layers=(1, 1), num_embed=(512, 512), rnn_attention_type='mlp', @@ -87,7 +78,6 @@ def test_device_args(test_params, expected_params): rnn_attention_coverage_num_hidden=1, weight_tying=False, weight_tying_type="trg_softmax", - max_seq_len=(100, 100), rnn_attention_mhdot_heads=None, transformer_attention_heads=8, transformer_feed_forward_num_hidden=2048, @@ -128,8 +118,6 @@ def test_model_parameters(test_params, expected_params): ('', dict(batch_size=64, batch_type="sentence", fill_up='replicate', - no_bucketing=False, - bucket_width=10, loss=C.CROSS_ENTROPY, label_smoothing=0.0, loss_normalization_type='valid', @@ -312,7 +300,28 @@ def test_tutorial_averaging_args(test_params, expected_params, expected_params_p _test_args_subset(test_params, expected_params, expected_params_present, arguments.add_average_args) -def _create_argument_values_that_must_be_files(params): +@pytest.mark.parametrize("test_params, expected_params", [ + ('--source test_src --target test_tgt --output prepared_data ', + dict(source='test_src', target='test_tgt', + source_vocab=None, + target_vocab=None, + shared_vocab=False, + num_words=(50000, 50000), + 
word_min_count=(1,1), + no_bucketing=False, + bucket_width=10, + max_seq_len=(100, 100), + min_num_shards=1, + num_samples_per_shard=1000000, + seed=13, + output='prepared_data' + )) +]) +def test_prepare_data_cli_args(test_params, expected_params): + _test_args(test_params, expected_params, arguments.add_prepare_data_cli_args) + + +def _create_argument_values_that_must_be_files_or_dirs(params): """ Loop over test_params and create temporary files for training/validation sources/targets. """ @@ -323,29 +332,39 @@ def grouper(iterable, n, fillvalue=None): return zip_longest(fillvalue=fillvalue, *args) params = params.split() - regular_files_params = {'-vs', '-vt', '-t', '-s', '--source', '--target', '--validation-source', '--validation-target'} + regular_files_params = {'-vs', '-vt', '-t', '-s', '--source', '--target', + '--validation-source', '--validation-target'} + folder_params = {'--prepared-data', '-d'} to_unlink = set() for arg, val in grouper(params, 2): if arg in regular_files_params and not os.path.isfile(val): - to_unlink.add((val, open(val, 'w'))) + open(val, 'w').close() + to_unlink.add(val) + if arg in folder_params: + os.mkdir(val) + to_unlink.add(val) return to_unlink -def _delete_argument_values_that_must_be_files(to_unlink): +def _delete_argument_values_that_must_be_files_or_dirs(to_unlink): """ - Close and delete previously created files. + Close and delete previously created files or directories. """ - for name, f in to_unlink: - f.close() - os.unlink(name) + for name in to_unlink: + if os.path.isfile(name): + os.unlink(name) + else: + os.rmdir(name) def _test_args(test_params, expected_params, args_func): test_parser = argparse.ArgumentParser() args_func(test_parser) - created = _create_argument_values_that_must_be_files(test_params) - parsed_params = test_parser.parse_args(test_params.split()) - _delete_argument_values_that_must_be_files(created) + created = _create_argument_values_that_must_be_files_or_dirs(test_params) + try: + parsed_params = test_parser.parse_args(test_params.split()) + finally: + _delete_argument_values_that_must_be_files_or_dirs(created) assert dict(vars(parsed_params)) == expected_params @@ -360,9 +379,9 @@ def _test_args_subset(test_params, expected_params, expected_params_present, arg """ test_parser = argparse.ArgumentParser() args_func(test_parser) - created = _create_argument_values_that_must_be_files(test_params) + created = _create_argument_values_that_must_be_files_or_dirs(test_params) parsed_params = dict(vars(test_parser.parse_args(test_params.split()))) - _delete_argument_values_that_must_be_files(created) + _delete_argument_values_that_must_be_files_or_dirs(created) parsed_params_subset = {k: v for k, v in parsed_params.items() if k in expected_params} assert parsed_params_subset == expected_params for expected_param_present in expected_params_present: diff --git a/test/unit/test_checkpoint.py b/test/unit/test_checkpoint.py deleted file mode 100644 index 412608d71..000000000 --- a/test/unit/test_checkpoint.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not -# use this file except in compliance with the License. A copy of the License -# is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. 
See the License for the specific language governing -# permissions and limitations under the License. - -import tempfile - -import numpy as np -import pytest - -import sockeye.data_io -from test.common import generate_random_sentence - - -def create_parallel_sentence_iter(source_sentences, target_sentences, max_len, batch_size, batch_by_words): - buckets = sockeye.data_io.define_parallel_buckets(max_len, max_len, 10) - batch_num_devices = 1 - eos = 0 - pad = 1 - unk = 2 - bucket_iterator = sockeye.data_io.ParallelBucketSentenceIter(source_sentences, - target_sentences, - buckets, - batch_size, - batch_by_words, - batch_num_devices, - eos, pad, unk) - return bucket_iterator - - -def data_batches_equal(db1, db2): - # We just compare the data, should probably be enough - equal = True - for data1, data2 in zip(db1.data, db2.data): - equal = equal and np.allclose(data1.asnumpy(), data2.asnumpy()) - return equal - - -@pytest.mark.parametrize("batch_size, batch_by_words", [ - (50, False), - (123, True), -]) -def test_parallel_sentence_iter(batch_size, batch_by_words): - # Create random sentences - vocab_size = 100 - max_len = 100 - source_sentences = [] - target_sentences = [] - for _ in range(1000): - source_sentences.append(generate_random_sentence(vocab_size, max_len)) - target_sentences.append(generate_random_sentence(vocab_size, max_len)) - - ori_iterator = create_parallel_sentence_iter(source_sentences, target_sentences, max_len, batch_size, batch_by_words) - ori_iterator.reset() # Random order - # Simulate some iterations - ori_iterator.next() - ori_iterator.next() - ori_iterator.next() - ori_iterator.next() - expected_output = ori_iterator.next() - # expected_output because the user is expected to call next() after loading - - # Save the state to disk - tmp_file = tempfile.NamedTemporaryFile() - ori_iterator.save_state(tmp_file.name) - - # Load the state in a new iterator - load_iterator = create_parallel_sentence_iter(source_sentences, target_sentences, max_len, batch_size, batch_by_words) - load_iterator.reset() # Random order - load_iterator.load_state(tmp_file.name) - - # Compare the outputs - loaded_output = load_iterator.next() - assert data_batches_equal(expected_output, loaded_output) diff --git a/test/unit/test_data_io.py b/test/unit/test_data_io.py index 4869503c3..ffa04bef4 100644 --- a/test/unit/test_data_io.py +++ b/test/unit/test_data_io.py @@ -11,14 +11,23 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
+import os +import random +from tempfile import TemporaryDirectory +from typing import Optional, List, Tuple + +import mxnet as mx import numpy as np import pytest from sockeye import constants as C from sockeye import data_io from sockeye import vocab +from sockeye.utils import SockeyeError, get_tokens, seedRNGs from test.common import tmp_digits_dataset +seedRNGs(12) + define_bucket_tests = [(50, 10, [10, 20, 30, 40, 50]), (50, 20, [20, 40, 50]), (50, 50, [50]), @@ -74,16 +83,6 @@ def test_get_bucket(buckets, length, expected_bucket): assert bucket == expected_bucket -get_tokens_tests = [("this is a line \n", ["this", "is", "a", "line"]), - (" a \tb \r \n", ["a", "b"])] - - -@pytest.mark.parametrize("line, expected_tokens", get_tokens_tests) -def test_get_tokens(line, expected_tokens): - tokens = list(data_io.get_tokens(line)) - assert tokens == expected_tokens - - tokens2ids_tests = [(["a", "b", "c"], {"a": 1, "b": 0, "c": 300, C.UNK_SYMBOL: 12}, [1, 0, 300]), (["a", "x", "c"], {"a": 1, "b": 0, "c": 300, C.UNK_SYMBOL: 12}, [1, 12, 300])] @@ -94,6 +93,238 @@ def test_tokens2ids(tokens, vocab, expected_ids): assert ids == expected_ids +@pytest.mark.parametrize("tokens, expected_ids", [(["1", "2", "3", "0"], [1, 2, 3, 0]), ([], [])]) +def test_strids2ids(tokens, expected_ids): + ids = data_io.strids2ids(tokens) + assert ids == expected_ids + + +@pytest.mark.parametrize("ids, expected_string", [([1, 2, 3, 0], "1 2 3 0"), ([], "")]) +def test_ids2strids(ids, expected_string): + string = data_io.ids2strids(ids) + assert string == expected_string + + +sequence_reader_tests = [(["1 2 3", "2", "14", "2 2 2"], False, False), + (["a b c", "c"], True, False), + (["a b c", "c"], True, True)] + + +@pytest.mark.parametrize("sequences, use_vocab, add_bos", sequence_reader_tests) +def test_sequence_reader(sequences, use_vocab, add_bos): + with TemporaryDirectory() as work_dir: + path = os.path.join(work_dir, 'input') + with open(path, 'w') as f: + for sequence in sequences: + f.write(sequence + "\n") + + vocabulary = vocab.build_vocab(sequences) if use_vocab else None + + reader = data_io.SequenceReader(path, vocab=vocabulary, add_bos=add_bos) + + read_sequences = [s for s in reader] + assert reader.is_done() + assert len(read_sequences) == reader.count + + if vocabulary is None: + with pytest.raises(SockeyeError) as e: + _ = data_io.SequenceReader(path, vocab=vocabulary, add_bos=True) + assert str(e.value) == "Adding a BOS symbol requires a vocabulary" + + expected_sequences = [data_io.strids2ids(get_tokens(s)) for s in sequences] + assert read_sequences == expected_sequences + else: + expected_sequences = [data_io.tokens2ids(get_tokens(s), vocabulary) for s in sequences] + if add_bos: + expected_sequences = [[vocabulary[C.BOS_SYMBOL]] + s for s in expected_sequences] + assert read_sequences == expected_sequences + + # check raise for multiple concurrent iters + _ = iter(reader) + with pytest.raises(SockeyeError) as e: + iter(reader) + assert str(e.value) == "Can not iterate multiple times simultaneously." 
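
(Illustration only, not part of the patch: a minimal sketch of how the SequenceReader exercised by the test above is driven, assuming sockeye is importable and a sentence-per-line file named corpus.txt exists; the file name and its contents are hypothetical.)

from sockeye import data_io, vocab

# Build a word-to-id mapping from the raw corpus; special symbols get the lowest ids.
with open('corpus.txt') as lines:
    vocabulary = vocab.build_vocab(lines)

# Stream the corpus back as lists of token ids, prepending the BOS symbol to every sequence.
reader = data_io.SequenceReader('corpus.txt', vocab=vocabulary, add_bos=True)
sequences = list(reader)

assert reader.is_done()
assert len(sequences) == reader.count
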
+ + +def test_sample_based_define_bucket_batch_sizes(): + batch_by_words = False + batch_size = 32 + max_seq_len = 100 + buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, 1.5) + bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets=buckets, + batch_size=batch_size, + batch_by_words=batch_by_words, + batch_num_devices=1, + data_target_average_len=[None] * len(buckets)) + for bbs in bucket_batch_sizes: + assert bbs.batch_size == batch_size + assert bbs.average_words_per_batch == bbs.bucket[1] * batch_size + + +def test_word_based_define_bucket_batch_sizes(): + batch_by_words = True + batch_num_devices = 1 + batch_size = 200 + max_seq_len = 100 + buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, 1.5) + bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets=buckets, + batch_size=batch_size, + batch_by_words=batch_by_words, + batch_num_devices=batch_num_devices, + data_target_average_len=[None] * len(buckets)) + # last bucket batch size is different + for bbs in bucket_batch_sizes[:-1]: + expected_batch_size = round((batch_size / bbs.bucket[1]) / batch_num_devices) + assert bbs.batch_size == expected_batch_size + expected_average_words_per_batch = expected_batch_size * bbs.bucket[1] + assert bbs.average_words_per_batch == expected_average_words_per_batch + + +def _get_random_bucketed_data(buckets: List[Tuple[int, int]], + min_count: int, + max_count: int, + bucket_counts: Optional[List[Optional[int]]] = None): + """ + Get random bucket data. + + :param buckets: The list of buckets. + :param min_count: The minimum number of samples that will be sampled if no exact count is given. + :param max_count: The maximum number of samples that will be sampled if no exact count is given. + :param bucket_counts: For each bucket an optional exact example count can be given. If it is not given it will be + sampled. + :return: The random source, target and label arrays. 
+ """ + if bucket_counts is None: + bucket_counts = [None for _ in buckets] + bucket_counts = [random.randint(min_count, max_count) if given_count is None else given_count + for given_count in bucket_counts] + source = [mx.nd.array(np.random.randint(0, 10, (count, random.randint(1, bucket[0])))) for count, bucket in + zip(bucket_counts, buckets)] + target = [mx.nd.array(np.random.randint(0, 10, (count, random.randint(1, bucket[1])))) for count, bucket in + zip(bucket_counts, buckets)] + label = target + return source, target, label + + +def test_parallel_data_set(): + buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0) + source, target, label = _get_random_bucketed_data(buckets, min_count=0, max_count=5) + + def check_equal(arrays1, arrays2): + assert len(arrays1) == len(arrays2) + for a1, a2 in zip(arrays1, arrays2): + assert np.array_equal(a1.asnumpy(), a2.asnumpy()) + + with TemporaryDirectory() as work_dir: + dataset = data_io.ParallelDataSet(source, target, label) + fname = os.path.join(work_dir, 'dataset') + dataset.save(fname) + dataset_loaded = data_io.ParallelDataSet.load(fname) + check_equal(dataset.source, dataset_loaded.source) + check_equal(dataset.target, dataset_loaded.target) + check_equal(dataset.label, dataset_loaded.label) + + +def test_parallel_data_set_fill_up(): + batch_size = 32 + buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0) + bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets, + batch_size, + batch_by_words=False, + batch_num_devices=1, + data_target_average_len=[None] * len(buckets)) + dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=1, max_count=5)) + + dataset_filled_up = dataset.fill_up(bucket_batch_sizes, 'replicate') + assert len(dataset_filled_up.source) == len(dataset.source) + assert len(dataset_filled_up.target) == len(dataset.target) + assert len(dataset_filled_up.label) == len(dataset.label) + for bidx in range(len(dataset)): + bucket_batch_size = bucket_batch_sizes[bidx].batch_size + assert dataset_filled_up.source[bidx].shape[0] == bucket_batch_size + assert dataset_filled_up.target[bidx].shape[0] == bucket_batch_size + assert dataset_filled_up.label[bidx].shape[0] == bucket_batch_size + + +def test_get_permutations(): + data = [list(range(3)), list(range(1)), list(range(7)), []] + bucket_counts = [len(d) for d in data] + + permutation, inverse_permutation = data_io.get_permutations(bucket_counts) + assert len(permutation) == len(inverse_permutation) == len(bucket_counts) == len(data) + + for d, p, pi in zip(data, permutation, inverse_permutation): + p = p.asnumpy().astype(np.int) + pi = pi.asnumpy().astype(np.int) + p_set = set(p) + pi_set = set(pi) + assert len(p_set) == len(p) + assert len(pi_set) == len(pi) + assert p_set - pi_set == set() + if d: + d = np.array(d) + assert (d[p][pi] == d).all() + else: + assert len(p_set) == 1 + + +def test_parallel_data_set_permute(): + batch_size = 5 + buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0) + bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets, + batch_size, + batch_by_words=False, + batch_num_devices=1, + data_target_average_len=[None] * len(buckets)) + dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5)).fill_up( + bucket_batch_sizes, 'replicate') + + permutations, inverse_permutations = data_io.get_permutations(dataset.get_bucket_counts()) + + assert len(permutations) == len(inverse_permutations) == len(dataset) + dataset_restored = 
dataset.permute(permutations).permute(inverse_permutations) + assert len(dataset) == len(dataset_restored) + for buck_idx in range(len(dataset)): + num_samples = dataset.source[buck_idx].shape[0] + if num_samples: + assert (dataset.source[buck_idx] == dataset_restored.source[buck_idx]).asnumpy().all() + assert (dataset.target[buck_idx] == dataset_restored.target[buck_idx]).asnumpy().all() + assert (dataset.label[buck_idx] == dataset_restored.label[buck_idx]).asnumpy().all() + else: + assert not dataset_restored.source[buck_idx] + assert not dataset_restored.target[buck_idx] + assert not dataset_restored.label[buck_idx] + + +def test_get_batch_indices(): + max_bucket_size = 50 + batch_size = 10 + buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0) + bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets, + batch_size, + batch_by_words=False, + batch_num_devices=1, + data_target_average_len=[None] * len(buckets)) + dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets=buckets, + min_count=1, + max_count=max_bucket_size)) + + indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes) + + # check for valid indices + for buck_idx, start_pos in indices: + assert 0 <= buck_idx < len(dataset) + assert 0 <= start_pos < len(dataset.source[buck_idx]) - batch_size + 1 + + # check that all indices are used for a filled-up dataset + dataset = dataset.fill_up(bucket_batch_sizes, fill_up='replicate') + indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes) + all_bucket_indices = set(list(range(len(dataset)))) + computed_bucket_indices = set([i for i, j in indices]) + + assert not all_bucket_indices - computed_bucket_indices + + @pytest.mark.parametrize("buckets, expected_default_bucket_key", [([(10, 10), (20, 20), (30, 30), (40, 40), (50, 50)], (50, 50)), ([(5, 10), (10, 20), (15, 30), (25, 50), (20, 40)], (25, 50))]) @@ -119,17 +350,19 @@ def test_get_parallel_bucket(buckets, source_length, target_length, expected_buc assert bucket == expected_bucket -@pytest.mark.parametrize("source, target, expected_mean, expected_std", +@pytest.mark.parametrize("source, target, expected_num_sents, expected_mean, expected_std", [([[1, 1, 1], [2, 2, 2], [3, 3, 3]], - [[1, 1, 1], [2, 2, 2], [3, 3, 3]], 1.0, 0.0), + [[1, 1, 1], [2, 2, 2], [3, 3, 3]], 3, 1.0, 0.0), ([[1, 1], [2, 2], [3, 3]], - [[1, 1, 1], [2, 2, 2], [3, 3, 3]], 1.5, 0.0), - ([[1, 1, 1], [2, 2]], - [[1, 1, 1], [2], [3, 3, 3]], 0.75, 0.25)]) -def test_length_statistics(source, target, expected_mean, expected_std): - mean, std = data_io.length_statistics(source, target) - assert np.isclose(mean, expected_mean) - assert np.isclose(std, expected_std) + [[1, 1, 1], [2, 2, 2], [3, 3, 3]], 3, 1.5, 0.0), + ([[1, 1, 1], [2, 2], [3, 3, 3, 3, 3, 3, 3]], + [[1, 1, 1], [2], [3, 3, 3]], 2, 0.75, 0.25)]) +def test_calculate_length_statistics(source, target, expected_num_sents, expected_mean, expected_std): + length_statistics = data_io.calculate_length_statistics(source, target, 5, 5) + assert len(source) == len(target) + assert length_statistics.num_sents == expected_num_sents + assert np.isclose(length_statistics.length_ratio_mean, expected_mean) + assert np.isclose(length_statistics.length_ratio_std, expected_std) def test_get_training_data_iters(): @@ -137,11 +370,11 @@ def test_get_training_data_iters(): train_max_length = 30 dev_line_count = 20 dev_max_length = 30 + expected_mean = 1.0 + expected_std = 0.0 test_line_count = 20 test_line_count_empty = 0 test_max_length = 30 - 
expected_mean = 1.1476392401276574 - expected_std = 0.2318455878853099 batch_size = 5 with tmp_digits_dataset("tmp_corpus", train_line_count, train_max_length, dev_line_count, dev_max_length, @@ -156,6 +389,7 @@ def test_get_training_data_iters(): vocab_target=vcb, vocab_source_path=None, vocab_target_path=None, + shared_vocab=True, batch_size=batch_size, batch_by_words=False, batch_num_devices=1, @@ -164,27 +398,23 @@ def test_get_training_data_iters(): max_seq_len_target=train_max_length, bucketing=True, bucket_width=10) + assert isinstance(train_iter, data_io.ParallelSampleIter) + assert isinstance(val_iter, data_io.ParallelSampleIter) + assert isinstance(config_data, data_io.DataConfig) assert config_data.source == data['source'] assert config_data.target == data['target'] - assert config_data.validation_source == data['validation_source'] - assert config_data.validation_target == data['validation_target'] assert config_data.vocab_source is None assert config_data.vocab_target is None - assert config_data.max_observed_source_seq_len == train_max_length - 1 - assert config_data.max_observed_target_seq_len == train_max_length - assert np.isclose(config_data.length_ratio_mean, expected_mean) - assert np.isclose(config_data.length_ratio_std, expected_std) + assert config_data.data_statistics.max_observed_len_source == train_max_length - 1 + assert config_data.data_statistics.max_observed_len_target == train_max_length + assert np.isclose(config_data.data_statistics.length_ratio_mean, expected_mean) + assert np.isclose(config_data.data_statistics.length_ratio_std, expected_std) assert train_iter.batch_size == batch_size assert val_iter.batch_size == batch_size assert train_iter.default_bucket_key == (train_max_length, train_max_length) assert val_iter.default_bucket_key == (dev_max_length, dev_max_length) - assert train_iter.max_observed_source_len == config_data.max_observed_source_seq_len - assert train_iter.max_observed_target_len == config_data.max_observed_target_seq_len - assert train_iter.pad_id == vcb[C.PAD_SYMBOL] assert train_iter.dtype == 'float32' - assert not train_iter.batch_by_words - assert train_iter.fill_up == 'replicate' # test some batches bos_id = vcb[C.BOS_SYMBOL] @@ -198,9 +428,7 @@ def test_get_training_data_iters(): source = batch.data[0].asnumpy() target = batch.data[1].asnumpy() label = batch.label[0].asnumpy() - assert source.shape[0] == batch_size - assert target.shape[0] == batch_size - assert label.shape[0] == batch_size + assert source.shape[0] == target.shape[0] == label.shape[0] == batch_size # target first symbol should be BOS assert np.array_equal(target[:, 0], expected_first_target_symbols) # label first symbol should be 2nd target symbol @@ -208,3 +436,225 @@ def test_get_training_data_iters(): # each label sequence contains one EOS symbol assert np.sum(label == vcb[C.EOS_SYMBOL]) == batch_size train_iter.reset() + + +def _data_batches_equal(db1, db2): + # We just compare the data, should probably be enough + equal = True + for data1, data2 in zip(db1.data, db2.data): + equal = equal and np.allclose(data1.asnumpy(), data2.asnumpy()) + return equal + + +def test_parallel_sample_iter(): + batch_size = 2 + buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0) + # The first bucket is going to be empty: + bucket_counts = [0] + [None] * (len(buckets) - 1) + bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets, + batch_size, + batch_by_words=False, + batch_num_devices=1, + data_target_average_len=[None] * len(buckets)) + + dataset = 
data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5, + bucket_counts=bucket_counts)) + it = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes) + + with TemporaryDirectory() as work_dir: + # Test 1 + it.next() + expected_batch = it.next() + + fname = os.path.join(work_dir, "saved_iter") + it.save_state(fname) + + it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes) + it_loaded.reset() + it_loaded.load_state(fname) + loaded_batch = it_loaded.next() + assert _data_batches_equal(expected_batch, loaded_batch) + + # Test 2 + it.reset() + expected_batch = it.next() + it.save_state(fname) + + it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes) + it_loaded.reset() + it_loaded.load_state(fname) + + loaded_batch = it_loaded.next() + assert _data_batches_equal(expected_batch, loaded_batch) + + # Test 3 + it.reset() + expected_batch = it.next() + it.save_state(fname) + it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes) + it_loaded.reset() + it_loaded.load_state(fname) + + loaded_batch = it_loaded.next() + assert _data_batches_equal(expected_batch, loaded_batch) + + while it.iter_next(): + it.next() + it_loaded.next() + assert not it_loaded.iter_next() + + +def test_sharded_parallel_sample_iter(): + batch_size = 2 + buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0) + # The first bucket is going to be empty: + bucket_counts = [0] + [None] * (len(buckets) - 1) + bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets, + batch_size, + batch_by_words=False, + batch_num_devices=1, + data_target_average_len=[None] * len(buckets)) + + dataset1 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5, + bucket_counts=bucket_counts)) + dataset2 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5, + bucket_counts=bucket_counts)) + + with TemporaryDirectory() as work_dir: + shard1_fname = os.path.join(work_dir, 'shard1') + shard2_fname = os.path.join(work_dir, 'shard2') + dataset1.save(shard1_fname) + dataset2.save(shard2_fname) + shard_fnames = [shard1_fname, shard2_fname] + + it = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes, 'replicate') + + with TemporaryDirectory() as work_dir: + # Test 1 + it.next() + expected_batch = it.next() + + fname = os.path.join(work_dir, "saved_iter") + it.save_state(fname) + + it_loaded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes, + 'replicate') + it_loaded.reset() + it_loaded.load_state(fname) + loaded_batch = it_loaded.next() + assert _data_batches_equal(expected_batch, loaded_batch) + + # Test 2 + it.reset() + expected_batch = it.next() + it.save_state(fname) + + it_loaded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes, + 'replicate') + it_loaded.reset() + it_loaded.load_state(fname) + + loaded_batch = it_loaded.next() + assert _data_batches_equal(expected_batch, loaded_batch) + + # Test 3 + it.reset() + expected_batch = it.next() + it.save_state(fname) + it_loaded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes, + 'replicate') + it_loaded.reset() + it_loaded.load_state(fname) + + loaded_batch = it_loaded.next() + assert _data_batches_equal(expected_batch, loaded_batch) + + while it.iter_next(): + it.next() + it_loaded.next() + assert not it_loaded.iter_next() 
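
(Illustration only, not part of the patch: the checkpoint-and-resume pattern that the two iterator tests above verify, sketched the way a training job might use it; shard_fnames, buckets, batch_size and bucket_batch_sizes are assumed to be constructed exactly as in the tests, and 'iter.state' is a hypothetical file name.)

it = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes, 'replicate')

# Consume a few batches, then persist the iterator position (e.g. at a checkpoint).
for _ in range(3):
    if it.iter_next():
        batch = it.next()
it.save_state('iter.state')

# After a restart: rebuild the iterator over the same shards and restore its position.
it_resumed = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes, 'replicate')
it_resumed.reset()
it_resumed.load_state('iter.state')
batch = it_resumed.next()  # continues from where the saved iterator stopped
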
+
+
+def test_sharded_parallel_sample_iter_num_batches():
+    num_shards = 2
+    batch_size = 2
+    num_batches_per_bucket = 10
+    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
+    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
+    num_batches_per_shard = num_batches_per_bucket * len(buckets)
+    num_batches = num_shards * num_batches_per_shard
+    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
+                                                           batch_size,
+                                                           batch_by_words=False,
+                                                           batch_num_devices=1,
+                                                           data_target_average_len=[None] * len(buckets))
+
+    dataset1 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
+                                                                  bucket_counts=bucket_counts))
+    dataset2 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
+                                                                  bucket_counts=bucket_counts))
+    with TemporaryDirectory() as work_dir:
+        shard1_fname = os.path.join(work_dir, 'shard1')
+        shard2_fname = os.path.join(work_dir, 'shard2')
+        dataset1.save(shard1_fname)
+        dataset2.save(shard2_fname)
+        shard_fnames = [shard1_fname, shard2_fname]
+
+        it = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
+                                               'replicate')
+
+        num_batches_seen = 0
+        while it.iter_next():
+            it.next()
+            num_batches_seen += 1
+        assert num_batches_seen == num_batches
+
+
+def test_sharded_and_parallel_iter_same_num_batches():
+    """ Tests that a sharded data iterator with just a single shard produces as many batches as an iterator directly
+    using the same dataset. """
+    batch_size = 2
+    num_batches_per_bucket = 10
+    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
+    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
+    num_batches = num_batches_per_bucket * len(buckets)
+    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
+                                                           batch_size,
+                                                           batch_by_words=False,
+                                                           batch_num_devices=1,
+                                                           data_target_average_len=[None] * len(buckets))
+
+    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
+                                                                 bucket_counts=bucket_counts))
+
+    with TemporaryDirectory() as work_dir:
+        shard_fname = os.path.join(work_dir, 'shard1')
+        dataset.save(shard_fname)
+        shard_fnames = [shard_fname]
+
+        it_sharded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
+                                                       'replicate')
+
+        it_parallel = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
+
+        num_batches_seen = 0
+        while it_parallel.iter_next():
+            assert it_sharded.iter_next()
+            it_parallel.next()
+            it_sharded.next()
+            num_batches_seen += 1
+        assert num_batches_seen == num_batches
+
+        print("Resetting...")
+        it_sharded.reset()
+        it_parallel.reset()
+
+        num_batches_seen = 0
+        while it_parallel.iter_next():
+            assert it_sharded.iter_next()
+            it_parallel.next()
+            it_sharded.next()
+
+            num_batches_seen += 1
+
+        assert num_batches_seen == num_batches
diff --git a/test/unit/test_translate.py b/test/unit/test_translate.py
index 3eb5c7030..b18474d8a 100644
--- a/test/unit/test_translate.py
+++ b/test/unit/test_translate.py
@@ -60,7 +60,8 @@ def test_translate_by_file(mock_file, mock_translator, mock_output_handler):
 def test_translate_by_stdin_chunk2(mock_translator, mock_output_handler):
     mock_translator.translate.return_value = ['', '']
     mock_translator.batch_size = 1
-    sockeye.translate.read_and_translate(translator=mock_translator, output_handler=mock_output_handler,
+    sockeye.translate.read_and_translate(translator=mock_translator,
+                                         output_handler=mock_output_handler,
                                          chunk_size=2)
 
     # Ensure that our translator 
has the correct input passed to it. diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index ac51c5a26..0335dc3a3 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -14,6 +14,7 @@ import os import tempfile +import math import mxnet as mx import numpy as np import pytest @@ -33,6 +34,7 @@ def test_chunks(some_list, expected): chunked_list = list(utils.chunks(some_list, chunk_size)) assert chunked_list == expected + def test_get_alignments(): attention_matrix = np.asarray([[0.1, 0.4, 0.5], [0.2, 0.8, 0.0], @@ -194,6 +196,44 @@ def test_check_version_checks_major(): assert "Given major version (%s) does not match major code version (%s)" % (version, __version__) == str(e.value) +@pytest.mark.parametrize("samples,expected_mean, expected_variance", + [ + ([1, 2], 1.5, 0.25), + ([4., 100., 12., -3, 1000, 1., -200], 130.57142857142858, 132975.38775510204), + ]) +def test_online_mean_and_variance(samples, expected_mean, expected_variance): + mean_and_variance = utils.OnlineMeanAndVariance() + for sample in samples: + mean_and_variance.update(sample) + + assert np.isclose(mean_and_variance.mean, expected_mean) + assert np.isclose(mean_and_variance.variance, expected_variance) + + +@pytest.mark.parametrize("samples,expected_mean", + [ + ([], 0.), + ([5.], 5.), + ]) +def test_online_mean_and_variance_nan(samples, expected_mean): + mean_and_variance = utils.OnlineMeanAndVariance() + for sample in samples: + mean_and_variance.update(sample) + + assert np.isclose(mean_and_variance.mean, expected_mean) + assert math.isnan(mean_and_variance.variance) + + +get_tokens_tests = [("this is a line \n", ["this", "is", "a", "line"]), + (" a \tb \r \n", ["a", "b"])] + + +@pytest.mark.parametrize("line, expected_tokens", get_tokens_tests) +def test_get_tokens(line, expected_tokens): + tokens = list(utils.get_tokens(line)) + assert tokens == expected_tokens + + def test_average_arrays(): n = 4 shape = (12, 14) diff --git a/typechecked-files b/typechecked-files index e293d594e..c0400234c 100644 --- a/typechecked-files +++ b/typechecked-files @@ -1,13 +1,13 @@ sockeye/__init__.py sockeye/arguments.py -sockeye/rnn_attention.py sockeye/average.py sockeye/callback.py sockeye/checkpoint_decoder.py +sockeye/chrf.py sockeye/config.py sockeye/constants.py +sockeye/convolution.py sockeye/coverage.py -sockeye/chrf.py sockeye/data_io.py sockeye/decoder.py sockeye/embeddings.py @@ -24,9 +24,12 @@ sockeye/lr_scheduler.py sockeye/model.py sockeye/optimizers.py sockeye/output_handler.py +sockeye/prepare_data.py sockeye/rnn.py +sockeye/rnn_attention.py sockeye/train.py sockeye/training.py sockeye/transformer.py sockeye/translate.py +sockeye/utils.py sockeye/vocab.py
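
Usage note (illustration only, not part of the patch): the two-step prepare-then-train workflow exercised by the parameter templates in test/common.py above, sketched with hypothetical file names (train.src, train.trg, dev.src, dev.trg, model_dir); a real run would add model and training hyperparameters on top of these flags.

import sys
from unittest.mock import patch

import sockeye.prepare_data
import sockeye.train

# Step 1: preprocess (and possibly shard) the raw parallel data once.
prepare_args = "prepare_data --source train.src --target train.trg --output prepared_data"
with patch.object(sys, "argv", prepare_args.split()):
    sockeye.prepare_data.main()

# Step 2: train from the prepared folder instead of passing --source/--target again.
train_args = ("train --use-cpu --prepared-data prepared_data"
              " --validation-source dev.src --validation-target dev.trg"
              " --output model_dir")
with patch.object(sys, "argv", train_args.split()):
    sockeye.train.main()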