diff --git a/onmt/transforms/misc.py b/onmt/transforms/misc.py
index 3885eb97f0..875438255b 100644
--- a/onmt/transforms/misc.py
+++ b/onmt/transforms/misc.py
@@ -1,6 +1,8 @@
 from onmt.utils.logging import logger
 from onmt.transforms import register_transform
 from .transform import Transform
+import random
+from onmt.constants import ModelTask, SubwordMarker
 
 
 @register_transform(name='filtertoolong')
@@ -41,6 +43,81 @@ def _repr_args(self):
         )
 
 
+@register_transform(name="joiner_dropout")
+class JoinerDropoutTransform(Transform):
+    """Disjoin joiner with probability dropout."""
+
+    def __init__(self, opts):
+        super().__init__(opts)
+
+    @classmethod
+    def add_options(cls, parser):
+        """Available options relate to this Transform."""
+        group = parser.add_argument_group("Transform/JoinerDropout")
+        group.add(
+            "--src_joiner_dropout",
+            "-src_joiner_dropout",
+            type=float,
+            default=0.0,
+            help="Source dropout probability.",
+        )
+        group.add(
+            "--tgt_joiner_dropout",
+            "-tgt_joiner_dropout",
+            type=float,
+            default=0.0,
+            help="Target dropout probability.",
+        )
+
+    def _parse_opts(self):
+        self.src_joiner_dropout = self.opts.src_joiner_dropout
+        self.tgt_joiner_dropout = self.opts.tgt_joiner_dropout
+        self.model_task = getattr(self.opts, "model_task", None)
+
+    def dropout_separate_joiner(self, seq, side="src"):
+        out_seq = []
+        dropout = (
+            self.src_joiner_dropout
+            if side == "src"
+            else self.tgt_joiner_dropout
+        )
+        for elem in seq:
+            if len(elem) > 1 and elem.startswith(SubwordMarker.JOINER):
+                if random.random() < dropout:
+                    out_seq.append(SubwordMarker.JOINER)
+                    elem = elem[1:]
+            if len(elem) > 1 and elem.endswith(SubwordMarker.JOINER):
+                if random.random() < dropout:
+                    out_seq.append(elem[:-1])
+                    elem = elem[-1:]
+            out_seq.append(elem)
+
+        return out_seq
+
+    def apply(self, example, is_train=False, stats=None, **kwargs):
+        """Apply joiner dropout to src (and tgt) at training time."""
+        if not is_train:
+            return example
+        else:
+            src_out = self.dropout_separate_joiner(example["src"], "src")
+            example["src"] = src_out
+            if self.model_task == ModelTask.LANGUAGE_MODEL:
+                example["tgt"] = src_out
+            else:
+                tgt_out = self.dropout_separate_joiner(example["tgt"], "tgt")
+                example["tgt"] = tgt_out
+            return example
+
+    def _repr_args(self):
+        """Return str represent key arguments for class."""
+        return "{}={}, {}={}".format(
+            "src_joiner_dropout",
+            self.src_joiner_dropout,
+            "tgt_joiner_dropout",
+            self.tgt_joiner_dropout,
+        )
+
+
 @register_transform(name='prefix')
 class PrefixTransform(Transform):
     """Add Prefix to src (& tgt) sentence."""
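The `joiner_dropout` transform above is a subword-regularization trick: with
probability `*_joiner_dropout`, a pyonmttok joiner marker attached to a token
is split off into a token of its own. A minimal standalone sketch of the
splitting rule, with the joiner hard-coded to "￭" (the value of
`SubwordMarker.JOINER` in `onmt/constants.py`) and `dropout=1.0` so that every
joiner is detached deterministically:

    import random

    JOINER = "￭"

    def dropout_separate_joiner(seq, dropout):
        out_seq = []
        for elem in seq:
            # with probability `dropout`, detach a leading joiner
            if len(elem) > 1 and elem.startswith(JOINER):
                if random.random() < dropout:
                    out_seq.append(JOINER)
                    elem = elem[1:]
            # with probability `dropout`, detach a trailing joiner
            if len(elem) > 1 and elem.endswith(JOINER):
                if random.random() < dropout:
                    out_seq.append(elem[:-1])
                    elem = elem[-1:]
            out_seq.append(elem)
        return out_seq

    tokens = ["improve", "￭ment", "ahead", "￭."]
    print(dropout_separate_joiner(tokens, dropout=1.0))
    # ['improve', '￭', 'ment', 'ahead', '￭', '.']

Note that for a `LANGUAGE_MODEL` task the dropped-out source is reused as the
target, so both sides always stay token-identical.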
diff --git a/onmt/transforms/tokenize.py b/onmt/transforms/tokenize.py
index 4f343e476c..9254a2c26d 100644
--- a/onmt/transforms/tokenize.py
+++ b/onmt/transforms/tokenize.py
@@ -2,6 +2,7 @@
 from onmt.utils.logging import logger
 from onmt.transforms import register_transform
 from .transform import Transform
+from onmt.constants import ModelTask
 
 
 class TokenizerTransform(Transform):
@@ -90,6 +91,7 @@ def _parse_opts(self):
         self.tgt_subword_vocab = self.opts.tgt_subword_vocab
         self.src_vocab_threshold = self.opts.src_vocab_threshold
         self.tgt_vocab_threshold = self.opts.tgt_vocab_threshold
+        self.model_task = getattr(self.opts, "model_task", None)
 
     def _repr_args(self):
         """Return str represent key arguments for TokenizerTransform."""
@@ -169,7 +171,10 @@ def _tokenize(self, tokens, side='src', is_train=False):
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply sentencepiece subword encode to src & tgt."""
         src_out = self._tokenize(example['src'], 'src', is_train)
-        tgt_out = self._tokenize(example['tgt'], 'tgt', is_train)
+        if self.model_task == ModelTask.LANGUAGE_MODEL:
+            tgt_out = src_out
+        else:
+            tgt_out = self._tokenize(example['tgt'], 'tgt', is_train)
         if stats is not None:
             n_words = len(example['src']) + len(example['tgt'])
             n_subwords = len(src_out) + len(tgt_out)
@@ -243,7 +248,10 @@ def _tokenize(self, tokens, side='src', is_train=False):
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply bpe subword encode to src & tgt."""
         src_out = self._tokenize(example['src'], 'src', is_train)
-        tgt_out = self._tokenize(example['tgt'], 'tgt', is_train)
+        if self.model_task == ModelTask.LANGUAGE_MODEL:
+            tgt_out = src_out
+        else:
+            tgt_out = self._tokenize(example['tgt'], 'tgt', is_train)
         if stats is not None:
             n_words = len(example['src']) + len(example['tgt'])
             n_subwords = len(src_out) + len(tgt_out)
@@ -327,7 +335,7 @@ def get_specials(cls, opts):
             tgt_specials.update(_case_specials)
         return (set(), set())
 
-    def _get_subword_kwargs(self, side='src'):
+    def _get_subword_kwargs(self, side='src', is_train=False):
         """Return a dict containing kwargs relate to `side` subwords."""
         subword_type = self.tgt_subword_type if side == 'tgt' \
             else self.src_subword_type
@@ -338,6 +346,10 @@ def _get_subword_kwargs(self, side='src'):
         subword_alpha = self.tgt_subword_alpha if side == 'tgt' \
             else self.src_subword_alpha
         kwopts = dict()
+        if not is_train:
+            # disable random aspects during validation
+            subword_alpha = 0
+            subword_nbest = 1
         if subword_type == 'bpe':
             kwopts['bpe_model_path'] = subword_model
             kwopts['bpe_dropout'] = subword_alpha
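With `is_train=False`, `_get_subword_kwargs` zeroes out the stochastic
settings (BPE dropout, SentencePiece n-best sampling), so validation data is
always segmented deterministically. The `warm_up` hunk below then builds one
tokenizer per mode per side. A hypothetical equivalent with made-up paths,
using the real pyonmttok API:

    import pyonmttok

    train_tok = pyonmttok.Tokenizer(
        "aggressive", joiner_annotate=True,
        bpe_model_path="codes.bpe", bpe_dropout=0.1,  # stochastic merges
    )
    valid_tok = pyonmttok.Tokenizer(
        "aggressive", joiner_annotate=True,
        bpe_model_path="codes.bpe", bpe_dropout=0,    # deterministic
    )
    tokens, _ = valid_tok.tokenize("Tokenization example.")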
@@ -360,42 +372,65 @@ def warm_up(self, vocabs=None):
         """Initialize Tokenizer models."""
         super().warm_up(None)
         import pyonmttok
-        src_subword_kwargs = self._get_subword_kwargs(side='src')
+
+        src_subword_kwargs = self._get_subword_kwargs(
+            side="src", is_train=True
+        )
+        valid_src_subword_kwargs = self._get_subword_kwargs(
+            side="src", is_train=False
+        )
         src_tokenizer = pyonmttok.Tokenizer(
             **src_subword_kwargs, **self.src_other_kwargs
         )
-        tgt_subword_kwargs = self._get_subword_kwargs(side='tgt')
-        _diff_vocab = (
-            src_subword_kwargs.get('vocabulary_path', '') !=
-            tgt_subword_kwargs.get('vocabulary_path', '') or
-            src_subword_kwargs.get('vocabulary_threshold', 0) !=
-            tgt_subword_kwargs.get('vocabulary_threshold', 0))
+        valid_src_tokenizer = pyonmttok.Tokenizer(
+            **valid_src_subword_kwargs, **self.src_other_kwargs
+        )
+        tgt_subword_kwargs = self._get_subword_kwargs(
+            side="tgt", is_train=True
+        )
+        _diff_vocab = src_subword_kwargs.get(
+            "vocabulary_path", ""
+        ) != tgt_subword_kwargs.get(
+            "vocabulary_path", ""
+        ) or src_subword_kwargs.get(
+            "vocabulary_threshold", 0
+        ) != tgt_subword_kwargs.get(
+            "vocabulary_threshold", 0
+        )
         if self.share_vocab and not _diff_vocab:
             self.load_models = {
-                'src': src_tokenizer,
-                'tgt': src_tokenizer
+                "src": {"train": src_tokenizer, "valid": valid_src_tokenizer},
+                "tgt": {"train": src_tokenizer, "valid": valid_src_tokenizer},
             }
         else:
-            tgt_subword_kwargs = self._get_subword_kwargs(side='tgt')
             tgt_tokenizer = pyonmttok.Tokenizer(
                 **tgt_subword_kwargs, **self.tgt_other_kwargs
             )
+            valid_tgt_subword_kwargs = self._get_subword_kwargs(
+                side="tgt", is_train=False
+            )
+            valid_tgt_tokenizer = pyonmttok.Tokenizer(
+                **valid_tgt_subword_kwargs, **self.tgt_other_kwargs
+            )
             self.load_models = {
-                'src': src_tokenizer,
-                'tgt': tgt_tokenizer
+                "src": {"train": src_tokenizer, "valid": valid_src_tokenizer},
+                "tgt": {"train": tgt_tokenizer, "valid": valid_tgt_tokenizer},
             }
 
     def _tokenize(self, tokens, side='src', is_train=False):
         """Do OpenNMT Tokenizer's tokenize."""
-        tokenizer = self.load_models[side]
+        tokenizer = self.load_models[side]['train' if is_train else 'valid']
         sentence = ' '.join(tokens)
         segmented, _ = tokenizer.tokenize(sentence)
         return segmented
 
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply OpenNMT Tokenizer to src & tgt."""
-        src_out = self._tokenize(example['src'], 'src')
-        tgt_out = self._tokenize(example['tgt'], 'tgt')
+        src_out = self._tokenize(example['src'], 'src', is_train)
+        if self.model_task == ModelTask.LANGUAGE_MODEL:
+            tgt_out = src_out
+        else:
+            tgt_out = self._tokenize(example['tgt'], 'tgt', is_train)
         if stats is not None:
             n_words = len(example['src']) + len(example['tgt'])
             n_subwords = len(src_out) + len(tgt_out)
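Across all three subclasses (SentencePiece, BPE, OpenNMT Tokenizer), `apply`
now reuses `src_out` as the target when `model_task` is
`ModelTask.LANGUAGE_MODEL` instead of tokenizing `example['tgt']` a second
time. This matters once sampling is enabled: two independent tokenizations of
the same sentence need not agree, which would silently misalign the two sides
of a language-model example. An illustrative snippet (the model path is a
placeholder; the `sp_*` kwargs are the same ones `_get_subword_kwargs` emits
for SentencePiece models):

    import pyonmttok

    tok = pyonmttok.Tokenizer(
        "none", sp_model_path="sp.model",
        sp_nbest_size=64, sp_alpha=0.1,  # subword sampling enabled
    )
    a, _ = tok.tokenize("the quick brown fox")
    b, _ = tok.tokenize("the quick brown fox")
    # `a` and `b` may differ token-by-token, so for LANGUAGE_MODEL
    # examples the patch assigns tgt_out = src_out to keep both sides
    # identical.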