From 7e8c7c2a6ac1d28265f27c861fc63ef72c1f0fae Mon Sep 17 00:00:00 2001 From: anderleich Date: Thu, 9 Sep 2021 12:31:03 +0200 Subject: [PATCH] Source features support for V2.0 (#2090) --- .github/workflows/push.yml | 30 ++++++++++ data/data_features/src-test.feat0 | 1 + data/data_features/src-test.txt | 1 + data/data_features/src-train.feat0 | 3 + data/data_features/src-train.txt | 3 + data/data_features/src-val.feat0 | 1 + data/data_features/src-val.txt | 1 + data/data_features/tgt-train.txt | 3 + data/data_features/tgt-val.txt | 1 + data/features_data.yaml | 11 ++++ docs/source/FAQ.md | 70 +++++++++++++++++++++++ onmt/bin/build_vocab.py | 7 ++- onmt/bin/translate.py | 16 +++++- onmt/constants.py | 1 + onmt/inputters/corpus.py | 65 ++++++++++++++++----- onmt/inputters/dataset_base.py | 8 +-- onmt/inputters/fields.py | 13 +++-- onmt/inputters/inputter.py | 12 ++-- onmt/inputters/text_dataset.py | 76 ++++++++++++++++++------- onmt/opts.py | 8 +++ onmt/tests/pull_request_chk.sh | 46 +++++++++++++-- onmt/tests/test_subword_marker.py | 33 ++++++++++- onmt/tests/test_text_dataset.py | 26 ++++++++- onmt/tests/test_transform.py | 22 ++++++++ onmt/transforms/features.py | 90 ++++++++++++++++++++++++++++++ onmt/translate/translator.py | 8 ++- onmt/utils/alignment.py | 42 ++++++++++---- onmt/utils/parse.py | 27 ++++++++- 28 files changed, 549 insertions(+), 76 deletions(-) create mode 100644 data/data_features/src-test.feat0 create mode 100644 data/data_features/src-test.txt create mode 100644 data/data_features/src-train.feat0 create mode 100644 data/data_features/src-train.txt create mode 100644 data/data_features/src-val.feat0 create mode 100644 data/data_features/src-val.txt create mode 100644 data/data_features/tgt-train.txt create mode 100644 data/data_features/tgt-val.txt create mode 100644 data/features_data.yaml create mode 100644 onmt/transforms/features.py diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 9780df63fc..66d892efff 100644 --- 
a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -42,6 +42,16 @@ jobs: -src_vocab /tmp/onmt.vocab.src \ -tgt_vocab /tmp/onmt.vocab.tgt \ && rm -rf /tmp/sample + - name: Test vocabulary build with features + run: | + python onmt/bin/build_vocab.py \ + -config data/features_data.yaml \ + -save_data /tmp/onmt_feat \ + -src_vocab /tmp/onmt_feat.vocab.src \ + -tgt_vocab /tmp/onmt_feat.vocab.tgt \ + -src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \ + -n_sample -1 \ + && rm -rf /tmp/sample - name: Test field/transform dump run: | # The dumped fields are used later when testing tools @@ -169,6 +179,26 @@ jobs: -state_dim 256 \ -n_steps 10 \ -n_node 64 + - name: Testing training with features + run: | + python onmt/bin/train.py \ + -config data/features_data.yaml \ + -src_vocab /tmp/onmt_feat.vocab.src \ + -tgt_vocab /tmp/onmt_feat.vocab.tgt \ + -src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \ + -src_vocab_size 1000 -tgt_vocab_size 1000 \ + -rnn_size 2 -batch_size 10 \ + -word_vec_size 5 -rnn_size 10 \ + -report_every 5 -train_steps 10 \ + -save_model /tmp/onmt.model \ + -save_checkpoint_steps 10 + - name: Testing translation with features + run: | + python translate.py \ + -model /tmp/onmt.model_step_10.pt \ + -src data/data_features/src-test.txt \ + -src_feats "{'feat0': 'data/data_features/src-test.feat0'}" \ + -verbose - name: Test RNN translation run: | head data/src-test.txt > /tmp/src-test.txt diff --git a/data/data_features/src-test.feat0 b/data/data_features/src-test.feat0 new file mode 100644 index 0000000000..4ab4a9e651 --- /dev/null +++ b/data/data_features/src-test.feat0 @@ -0,0 +1 @@ +C B A B \ No newline at end of file diff --git a/data/data_features/src-test.txt b/data/data_features/src-test.txt new file mode 100644 index 0000000000..0cc723ce39 --- /dev/null +++ b/data/data_features/src-test.txt @@ -0,0 +1 @@ +she is a hard-working. 
\ No newline at end of file diff --git a/data/data_features/src-train.feat0 b/data/data_features/src-train.feat0 new file mode 100644 index 0000000000..7e189f2c33 --- /dev/null +++ b/data/data_features/src-train.feat0 @@ -0,0 +1,3 @@ +A A A A B A A A C +A B C D E +C B A B \ No newline at end of file diff --git a/data/data_features/src-train.txt b/data/data_features/src-train.txt new file mode 100644 index 0000000000..8a3ec35c2b --- /dev/null +++ b/data/data_features/src-train.txt @@ -0,0 +1,3 @@ +however, according to the logs, she is a hard-working. +however, according to the logs, +she is a hard-working. \ No newline at end of file diff --git a/data/data_features/src-val.feat0 b/data/data_features/src-val.feat0 new file mode 100644 index 0000000000..4ab4a9e651 --- /dev/null +++ b/data/data_features/src-val.feat0 @@ -0,0 +1 @@ +C B A B \ No newline at end of file diff --git a/data/data_features/src-val.txt b/data/data_features/src-val.txt new file mode 100644 index 0000000000..0cc723ce39 --- /dev/null +++ b/data/data_features/src-val.txt @@ -0,0 +1 @@ +she is a hard-working. \ No newline at end of file diff --git a/data/data_features/tgt-train.txt b/data/data_features/tgt-train.txt new file mode 100644 index 0000000000..8a3ec35c2b --- /dev/null +++ b/data/data_features/tgt-train.txt @@ -0,0 +1,3 @@ +however, according to the logs, she is a hard-working. +however, according to the logs, +she is a hard-working. \ No newline at end of file diff --git a/data/data_features/tgt-val.txt b/data/data_features/tgt-val.txt new file mode 100644 index 0000000000..0cc723ce39 --- /dev/null +++ b/data/data_features/tgt-val.txt @@ -0,0 +1 @@ +she is a hard-working. 
\ No newline at end of file diff --git a/data/features_data.yaml b/data/features_data.yaml new file mode 100644 index 0000000000..fa9b665f9c --- /dev/null +++ b/data/features_data.yaml @@ -0,0 +1,11 @@ +# Corpus opts: +data: + corpus_1: + path_src: data/data_features/src-train.txt + path_tgt: data/data_features/tgt-train.txt + src_feats: + feat0: data/data_features/src-train.feat0 + transforms: [filterfeats, inferfeats] + valid: + path_src: data/data_features/src-val.txt + path_tgt: data/data_features/tgt-val.txt diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md index f40fada251..8f618f6c6e 100644 --- a/docs/source/FAQ.md +++ b/docs/source/FAQ.md @@ -477,3 +477,73 @@ Training options to perform vocabulary update are: * `-update_vocab`: set this option * `-reset_optim`: set the value to "states" * `-train_from`: checkpoint path + + +## How can I use source word features? + +Extra information can be added to the words in the source sentences by defining word features. + +Features should be defined in a separate file using blank spaces as a separator and with each row corresponding to a source sentence. An example of the input files: + +data.src +``` +however, according to the logs, she is hard-working. +``` + +feat0.txt +``` +A C C C C A A B +``` + +**Notes** +- Prior tokenization is not necessary, features will be inferred by using the `FeatInferTransform` transform. +- `FilterFeatsTransform` and `FeatInferTransform` are required in order to ensure the functionality. 
+- Shared embeddings are not possible (at least with the `feat_merge: concat` method). + +Sample config file: + +``` +data: + dummy: + path_src: data/train/data.src + path_tgt: data/train/data.tgt + src_feats: + feat_0: data/train/data.src.feat_0 + feat_1: data/train/data.src.feat_1 + transforms: [filterfeats, onmt_tokenize, inferfeats, filtertoolong] + weight: 1 + valid: + path_src: data/valid/data.src + path_tgt: data/valid/data.tgt + src_feats: + feat_0: data/valid/data.src.feat_0 + feat_1: data/valid/data.src.feat_1 + transforms: [filterfeats, onmt_tokenize, inferfeats] + +# # Vocab opts +src_vocab: exp/data.vocab.src +tgt_vocab: exp/data.vocab.tgt +src_feats_vocab: + feat_0: exp/data.vocab.feat_0 + feat_1: exp/data.vocab.feat_1 +feat_merge: "sum" + +``` + +During inference you can pass features by using the `--src_feats` argument. `src_feats` is expected to be a Python-like dict mapping each feature name to its data file. + +``` +{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'} +``` + +**Important note!** During inference, the input sentence is expected to be already tokenized. Therefore, feature inference must be handled before running the translate command. 
Example: + +```bash +python translate.py -model model_step_10.pt -src ../data.txt.tok -output ../data.out --src_feats "{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}" +``` + +When using the Transformer architecture make sure the following options are appropriately set: + +- `src_word_vec_size` and `tgt_word_vec_size` or `word_vec_size` +- `feat_merge`: how to handle features vecs +- `feat_vec_size` and maybe `feat_vec_exponent` diff --git a/onmt/bin/build_vocab.py b/onmt/bin/build_vocab.py index e106d92180..ed510f09d2 100644 --- a/onmt/bin/build_vocab.py +++ b/onmt/bin/build_vocab.py @@ -32,11 +32,13 @@ def build_vocab_main(opts): transforms = make_transforms(opts, transforms_cls, fields) logger.info(f"Counter vocab from {opts.n_sample} samples.") - src_counter, tgt_counter = build_vocab( + src_counter, tgt_counter, src_feats_counter = build_vocab( opts, transforms, n_sample=opts.n_sample) logger.info(f"Counters src:{len(src_counter)}") logger.info(f"Counters tgt:{len(tgt_counter)}") + for feat_name, feat_counter in src_feats_counter.items(): + logger.info(f"Counters {feat_name}:{len(feat_counter)}") def save_counter(counter, save_path): check_path(save_path, exist_ok=opts.overwrite, log=logger.warning) @@ -52,6 +54,9 @@ def save_counter(counter, save_path): else: save_counter(src_counter, opts.src_vocab) save_counter(tgt_counter, opts.tgt_vocab) + + for k, v in src_feats_counter.items(): + save_counter(v, opts.src_feats_vocab[k]) def _get_parser(): diff --git a/onmt/bin/translate.py b/onmt/bin/translate.py index 0b5434f89a..4e3e126ae2 100755 --- a/onmt/bin/translate.py +++ b/onmt/bin/translate.py @@ -6,6 +6,7 @@ import onmt.opts as opts from onmt.utils.parse import ArgumentParser +from collections import defaultdict def translate(opt): @@ -15,12 +16,21 @@ def translate(opt): translator = build_translator(opt, logger=logger, report_score=True) src_shards = split_corpus(opt.src, opt.shard_size) tgt_shards = split_corpus(opt.tgt, opt.shard_size) - 
shard_pairs = zip(src_shards, tgt_shards) - - for i, (src_shard, tgt_shard) in enumerate(shard_pairs): + features_shards = [] + features_names = [] + for feat_name, feat_path in opt.src_feats.items(): + features_shards.append(split_corpus(feat_path, opt.shard_size)) + features_names.append(feat_name) + shard_pairs = zip(src_shards, tgt_shards, *features_shards) + + for i, (src_shard, tgt_shard, *features_shard) in enumerate(shard_pairs): + features_shard_ = defaultdict(list) + for j, x in enumerate(features_shard): + features_shard_[features_names[j]] = x logger.info("Translating shard %d." % i) translator.translate( src=src_shard, + src_feats=features_shard_, tgt=tgt_shard, batch_size=opt.batch_size, batch_type=opt.batch_type, diff --git a/onmt/constants.py b/onmt/constants.py index fb6afb0252..2d5864137b 100644 --- a/onmt/constants.py +++ b/onmt/constants.py @@ -22,6 +22,7 @@ class CorpusName(object): class SubwordMarker(object): SPACER = '▁' JOINER = '■' + CASE_MARKUP = ["⦅mrk_case_modifier_C⦆", "⦅mrk_begin_case_region_U⦆", "⦅mrk_end_case_region_U⦆"] class ModelTask(object): diff --git a/onmt/inputters/corpus.py b/onmt/inputters/corpus.py index c8a559f9f8..87da65139b 100644 --- a/onmt/inputters/corpus.py +++ b/onmt/inputters/corpus.py @@ -7,10 +7,11 @@ from torchtext.data import Dataset as TorchtextDataset, \ Example as TorchtextExample -from collections import Counter +from collections import Counter, defaultdict from contextlib import contextmanager import multiprocessing as mp +from collections import defaultdict @contextmanager @@ -70,10 +71,20 @@ def _process(item, is_train): example, is_train=is_train, corpus_name=cid) if maybe_example is None: return None - maybe_example['src'] = ' '.join(maybe_example['src']) - maybe_example['tgt'] = ' '.join(maybe_example['tgt']) + + maybe_example['src'] = {"src": ' '.join(maybe_example['src'])} + + # Make features part of src as in TextMultiField + # {'src': {'src': ..., 'feat1': ...., 'feat2': ....}} + if 'src_feats' 
in maybe_example: + for feat_name, feat_value in maybe_example['src_feats'].items(): + maybe_example['src'][feat_name] = ' '.join(feat_value) + del maybe_example["src_feats"] + + maybe_example['tgt'] = {"tgt": ' '.join(maybe_example['tgt'])} if 'align' in maybe_example: maybe_example['align'] = ' '.join(maybe_example['align']) + return maybe_example def _maybe_add_dynamic_dict(self, example, fields): @@ -107,12 +118,13 @@ def __call__(self, bucket): class ParallelCorpus(object): """A parallel corpus file pair that can be loaded to iterate.""" - def __init__(self, name, src, tgt, align=None): + def __init__(self, name, src, tgt, align=None, src_feats=None): """Initialize src & tgt side file path.""" self.id = name self.src = src self.tgt = tgt self.align = align + self.src_feats = src_feats def load(self, offset=0, stride=1): """ @@ -120,10 +132,18 @@ def load(self, offset=0, stride=1): `offset` and `stride` allow to iterate only on every `stride` example, starting from `offset`. """ + if self.src_feats: + features_names = [] + features_files = [] + for feat_name, feat_path in self.src_feats.items(): + features_names.append(feat_name) + features_files.append(open(feat_path, mode='rb')) + else: + features_files = [] with exfile_open(self.src, mode='rb') as fs,\ exfile_open(self.tgt, mode='rb') as ft,\ exfile_open(self.align, mode='rb') as fa: - for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)): + for i, (sline, tline, align, *features) in enumerate(zip(fs, ft, fa, *features_files)): if (i % stride) == offset: sline = sline.decode('utf-8') tline = tline.decode('utf-8') @@ -133,12 +153,18 @@ def load(self, offset=0, stride=1): } if align is not None: example['align'] = align.decode('utf-8') + if features: + example["src_feats"] = dict() + for j, feat in enumerate(features): + example["src_feats"][features_names[j]] = feat.decode("utf-8") yield example + for f in features_files: + f.close() def __str__(self): cls_name = type(self).__name__ - return '{}({}, {}, 
align={})'.format( - cls_name, self.src, self.tgt, self.align) + return '{}({}, {}, align={}, src_feats={})'.format( + cls_name, self.src, self.tgt, self.align, self.src_feats) def get_corpora(opts, is_train=False): @@ -150,14 +176,16 @@ def get_corpora(opts, is_train=False): corpus_id, corpus_dict["path_src"], corpus_dict["path_tgt"], - corpus_dict["path_align"]) + corpus_dict["path_align"], + corpus_dict["src_feats"]) else: if CorpusName.VALID in opts.data.keys(): corpora_dict[CorpusName.VALID] = ParallelCorpus( CorpusName.VALID, opts.data[CorpusName.VALID]["path_src"], opts.data[CorpusName.VALID]["path_tgt"], - opts.data[CorpusName.VALID]["path_align"]) + opts.data[CorpusName.VALID]["path_align"], + opts.data[CorpusName.VALID]["src_feats"]) else: return None return corpora_dict @@ -193,6 +221,9 @@ def _tokenize(self, stream): example['src'], example['tgt'] = src, tgt if 'align' in example: example['align'] = example['align'].strip('\n').split() + if 'src_feats' in example: + for k in example['src_feats'].keys(): + example['src_feats'][k] = example['src_feats'][k].strip('\n').split() yield example def _transform(self, stream): @@ -286,6 +317,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset): """Build vocab on (strided) subpart of the data.""" sub_counter_src = Counter() sub_counter_tgt = Counter() + sub_counter_src_feats = defaultdict(Counter) datasets_iterables = build_corpora_iters( corpora, transforms, opts.data, skip_empty_level=opts.skip_empty_level, @@ -297,7 +329,10 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset): if opts.dump_samples: build_sub_vocab.queues[c_name][offset].put("blank") continue - src_line, tgt_line = maybe_example['src'], maybe_example['tgt'] + src_line, tgt_line = maybe_example['src']['src'], maybe_example['tgt']['tgt'] + for feat_name, feat_line in maybe_example["src"].items(): + if feat_name != "src": + sub_counter_src_feats[feat_name].update(feat_line.split(' ')) 
sub_counter_src.update(src_line.split(' ')) sub_counter_tgt.update(tgt_line.split(' ')) if opts.dump_samples: @@ -309,7 +344,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset): break if opts.dump_samples: build_sub_vocab.queues[c_name][offset].put("break") - return sub_counter_src, sub_counter_tgt + return sub_counter_src, sub_counter_tgt, sub_counter_src_feats def init_pool(queues): @@ -333,6 +368,7 @@ def build_vocab(opts, transforms, n_sample=3): corpora = get_corpora(opts, is_train=True) counter_src = Counter() counter_tgt = Counter() + counter_src_feats = defaultdict(Counter) from functools import partial queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size) for i in range(opts.num_threads)] @@ -349,13 +385,14 @@ def build_vocab(opts, transforms, n_sample=3): func = partial( build_sub_vocab, corpora, transforms, opts, n_sample, opts.num_threads) - for sub_counter_src, sub_counter_tgt in p.imap( + for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap( func, range(0, opts.num_threads)): counter_src.update(sub_counter_src) counter_tgt.update(sub_counter_tgt) + counter_src_feats.update(sub_counter_src_feats) if opts.dump_samples: write_process.join() - return counter_src, counter_tgt + return counter_src, counter_tgt, counter_src_feats def save_transformed_sample(opts, transforms, n_sample=3): @@ -387,7 +424,7 @@ def save_transformed_sample(opts, transforms, n_sample=3): maybe_example = DatasetAdapter._process(item, is_train=True) if maybe_example is None: continue - src_line, tgt_line = maybe_example['src'], maybe_example['tgt'] + src_line, tgt_line = maybe_example['src']['src'], maybe_example['tgt']['tgt'] f_src.write(src_line + '\n') f_tgt.write(tgt_line + '\n') if n_sample > 0 and i >= n_sample: diff --git a/onmt/inputters/dataset_base.py b/onmt/inputters/dataset_base.py index aeec428aaf..65322d9a4c 100644 --- a/onmt/inputters/dataset_base.py +++ b/onmt/inputters/dataset_base.py @@ -41,7 +41,7 @@ def 
_dynamic_dict(example, src_field, tgt_field): ``example``, changed as described. """ - src = src_field.tokenize(example["src"]) + src = src_field.tokenize(example["src"]["src"]) # make a small vocab containing just the tokens in the source sequence unk = src_field.unk_token pad = src_field.pad_token @@ -60,7 +60,7 @@ def _dynamic_dict(example, src_field, tgt_field): example["src_ex_vocab"] = src_ex_vocab if "tgt" in example: - tgt = tgt_field.tokenize(example["tgt"]) + tgt = tgt_field.tokenize(example["tgt"]["tgt"]) mask = torch.LongTensor( [unk_idx] + [src_ex_vocab.stoi[w] for w in tgt] + [unk_idx]) example["alignment"] = mask @@ -116,7 +116,7 @@ def __init__(self, fields, readers, data, sort_key, filter_pred=None): self.sort_key = sort_key can_copy = 'src_map' in fields and 'alignment' in fields - read_iters = [r.read(dat[1], dat[0]) for r, dat in zip(readers, data)] + read_iters = [r.read(dat, name, feats) for r, (name, dat, feats) in zip(readers, data)] # self.src_vocabs is used in collapse_copy_scores and Translator.py self.src_vocabs = [] @@ -162,5 +162,5 @@ def config(fields): for name, field in fields: if field["data"] is not None: readers.append(field["reader"]) - data.append((name, field["data"])) + data.append((name, field["data"], field["features"])) return readers, data diff --git a/onmt/inputters/fields.py b/onmt/inputters/fields.py index 50c4e6c17f..5f41a3a01f 100644 --- a/onmt/inputters/fields.py +++ b/onmt/inputters/fields.py @@ -8,11 +8,10 @@ def _get_dynamic_fields(opts): - # NOTE: not support nfeats > 0 yet - src_nfeats = 0 - tgt_nfeats = 0 + # NOTE: not support tgt feats yet + tgt_feats = None with_align = hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0 - fields = get_fields('text', src_nfeats, tgt_nfeats, + fields = get_fields('text', opts.src_feats_vocab, tgt_feats, dynamic_dict=opts.copy_attn, src_truncate=opts.src_seq_length_trunc, tgt_truncate=opts.tgt_seq_length_trunc, @@ -33,6 +32,12 @@ def build_dynamic_fields(opts, 
src_specials=None, tgt_specials=None): opts.src_vocab, 'src', counters, min_freq=opts.src_words_min_frequency) + if opts.src_feats_vocab: + for feat_name, filepath in opts.src_feats_vocab.items(): + _, _ = _load_vocab( + filepath, feat_name, counters, + min_freq=0) + if opts.tgt_vocab: _tgt_vocab, _tgt_vocab_size = _load_vocab( opts.tgt_vocab, 'tgt', counters, diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index f6b5c747d0..ffd8c77fb1 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -111,8 +111,8 @@ def get_task_spec_tokens(data_task, pad, bos, eos): def get_fields( src_data_type, - n_src_feats, - n_tgt_feats, + src_feats, + tgt_feats, pad=DefaultTokens.PAD, bos=DefaultTokens.BOS, eos=DefaultTokens.EOS, @@ -125,11 +125,11 @@ def get_fields( """ Args: src_data_type: type of the source input. Options are [text]. - n_src_feats (int): the number of source features (not counting tokens) + src_feats (Optional[Dict]): source features dict containing their names to create a :class:`torchtext.data.Field` for. (If ``src_data_type=="text"``, these fields are stored together as a ``TextMultiField``). - n_tgt_feats (int): See above. + tgt_feats (Optional[Dict]): See above. pad (str): Special pad symbol. Used on src and tgt side. bos (str): Special beginning of sequence symbol. Only relevant for tgt. 
@@ -158,7 +158,7 @@ def get_fields( task_spec_tokens = get_task_spec_tokens(data_task, pad, bos, eos) src_field_kwargs = { - "n_feats": n_src_feats, + "feats": src_feats, "include_lengths": True, "pad": task_spec_tokens["src"]["pad"], "bos": task_spec_tokens["src"]["bos"], @@ -169,7 +169,7 @@ def get_fields( fields["src"] = fields_getters[src_data_type](**src_field_kwargs) tgt_field_kwargs = { - "n_feats": n_tgt_feats, + "feats": tgt_feats, "include_lengths": False, "pad": task_spec_tokens["tgt"]["pad"], "bos": task_spec_tokens["tgt"]["bos"], diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py index a0621f6407..a55d2593b2 100644 --- a/onmt/inputters/text_dataset.py +++ b/onmt/inputters/text_dataset.py @@ -9,7 +9,7 @@ class TextDataReader(DataReaderBase): - def read(self, sequences, side): + def read(self, sequences, side, features={}): """Read text data from disk. Args: @@ -17,6 +17,9 @@ def read(self, sequences, side): path to text file or iterable of the actual text data. side (str): Prefix used in return dict. Usually ``"src"`` or ``"tgt"``. + features: (Dict[str or Iterable[str]]): + dictionary mapping feature names with the path to feature + file or iterable of the actual feature data. 
Yields: dictionaries whose keys are the names of fields and whose @@ -25,10 +28,25 @@ def read(self, sequences, side): """ if isinstance(sequences, str): sequences = DataReaderBase._read_file(sequences) - for i, seq in enumerate(sequences): + + features_names = [] + features_values = [] + for feat_name, v in features.items(): + features_names.append(feat_name) + if isinstance(v, str): + features_values.append(DataReaderBase._read_file(features)) + else: + features_values.append(v) + for i, (seq, *feats) in enumerate(zip(sequences, *features_values)): + ex_dict = {} if isinstance(seq, bytes): seq = seq.decode("utf-8") - yield {side: seq, "indices": i} + ex_dict[side] = seq + for i, f in enumerate(feats): + if isinstance(f, bytes): + f = f.decode("utf-8") + ex_dict[features_names[i]] = f + yield {side: ex_dict, "indices": i} def text_sort_key(ex): @@ -38,6 +56,7 @@ def text_sort_key(ex): return len(ex.src[0]) +# Legacy function. Currently it only truncates input if truncate is set. # mix this with partial def _feature_tokenize( string, layer=0, tok_delim=None, feat_delim=None, truncate=None): @@ -140,8 +159,7 @@ def preprocess(self, x): lists of tokens/feature tags for the sentence. The output is ordered like ``self.fields``. """ - - return [f.preprocess(x) for _, f in self.fields] + return [f.preprocess(x[fn]) for fn, f in self.fields] def __getitem__(self, item): return self.fields[item] @@ -152,7 +170,7 @@ def text_fields(**kwargs): Args: base_name (str): Name associated with the field. - n_feats (int): Number of word level feats (not counting the tokens) + feats (Optional[Dict]): Word level feats include_lengths (bool): Optionally return the sequence lengths. pad (str, optional): Defaults to ``""``. bos (str or NoneType, optional): Defaults to ``""``. 
@@ -163,7 +181,7 @@ def text_fields(**kwargs): TextMultiField """ - n_feats = kwargs["n_feats"] + feats = kwargs["feats"] include_lengths = kwargs["include_lengths"] base_name = kwargs["base_name"] pad = kwargs.get("pad", DefaultTokens.PAD) @@ -171,20 +189,36 @@ def text_fields(**kwargs): eos = kwargs.get("eos", DefaultTokens.EOS) truncate = kwargs.get("truncate", None) fields_ = [] - feat_delim = u"│" if n_feats > 0 else None - for i in range(n_feats + 1): - name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name - tokenize = partial( - _feature_tokenize, - layer=i, - truncate=truncate, - feat_delim=feat_delim) - use_len = i == 0 and include_lengths - feat = Field( - init_token=bos, eos_token=eos, - pad_token=pad, tokenize=tokenize, - include_lengths=use_len) - fields_.append((name, feat)) + + feat_delim = None #u"│" if n_feats > 0 else None + + # Base field + tokenize = partial( + _feature_tokenize, + layer=None, + truncate=truncate, + feat_delim=feat_delim) + feat = Field( + init_token=bos, eos_token=eos, + pad_token=pad, tokenize=tokenize, + include_lengths=include_lengths) + fields_.append((base_name, feat)) + + # Feats fields + if feats: + for feat_name in feats.keys(): + # Legacy function, it is not really necessary + tokenize = partial( + _feature_tokenize, + layer=None, + truncate=truncate, + feat_delim=feat_delim) + feat = Field( + init_token=bos, eos_token=eos, + pad_token=pad, tokenize=tokenize, + include_lengths=False) + fields_.append((feat_name, feat)) + assert fields_[0][0] == base_name # sanity check field = TextMultiField(fields_[0][0], fields_[0][1], fields_[1:]) return field diff --git a/onmt/opts.py b/onmt/opts.py index ec66f14e95..4c37ab952d 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -132,6 +132,11 @@ def _add_dynamic_fields_opts(parser, build_vocab_only=False): group.add("-share_vocab", "--share_vocab", action="store_true", help="Share source and target vocabulary.") + group.add("-src_feats_vocab", "--src_feats_vocab", + 
help=("List of paths to save" if build_vocab_only else "List of paths to") + + " src features vocabulary files. " + "Files format: one or \t per line.") + if not build_vocab_only: group.add("-src_vocab_size", "--src_vocab_size", type=int, default=50000, @@ -755,6 +760,9 @@ def translate_opts(parser): group.add('--src', '-src', required=True, help="Source sequence to decode (one line per " "sequence)") + group.add("-src_feats", "--src_feats", required=False, + help="Source sequence features (dict format). " + "Ex: {'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}") group.add('--tgt', '-tgt', help='True target sequence (optional)') group.add('--tgt_prefix', '-tgt_prefix', action='store_true', diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh index b282cc7f1e..70cd76823a 100755 --- a/onmt/tests/pull_request_chk.sh +++ b/onmt/tests/pull_request_chk.sh @@ -67,10 +67,22 @@ PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH} ${PYTHON} onmt/bin/build_vocab.py \ -save_data $TMP_OUT_DIR/onmt \ -src_vocab $TMP_OUT_DIR/onmt.vocab.src \ -tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \ - -n_sample 5000 >> ${LOG_FILE} 2>&1 + -n_sample 5000 -overwrite >> ${LOG_FILE} 2>&1 [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} -rm -r $TMP_OUT_DIR/sample +rm -f -r $TMP_OUT_DIR/sample + +echo -n "[+] Testing vocabulary building with features..." +PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH} ${PYTHON} onmt/bin/build_vocab.py \ + -config ${DATA_DIR}/features_data.yaml \ + -save_data $TMP_OUT_DIR/onmt_feat \ + -src_vocab $TMP_OUT_DIR/onmt_feat.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/onmt_feat.vocab.tgt \ + -src_feats_vocab '{"feat0": "${TMP_OUT_DIR}/onmt_feat.vocab.feat0"}' \ + -n_sample -1 -overwrite>> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} +rm -f -r $TMP_OUT_DIR/sample # # Training test @@ -254,8 +266,24 @@ ${PYTHON} onmt/bin/train.py \ [ "$?" 
-eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} -rm $TMP_OUT_DIR/onmt.vocab* -rm $TMP_OUT_DIR/onmt.model* +echo -n " [+] Testing training with features..." +${PYTHON} onmt/bin/train.py \ + -config ${DATA_DIR}/features_data.yaml \ + -src_vocab $TMP_OUT_DIR/onmt_feat.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/onmt_feat.vocab.tgt \ + -src_feats_vocab '{"feat0": "${TMP_OUT_DIR}/onmt_feat.vocab.feat0"}' \ + -src_vocab_size 1000 -tgt_vocab_size 1000 \ + -rnn_size 2 -batch_size 10 \ + -word_vec_size 5 -rnn_size 10 \ + -report_every 5 -train_steps 10 \ + -save_model $TMP_OUT_DIR/onmt.features.model \ + -save_checkpoint_steps 10 >> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} + +rm -f $TMP_OUT_DIR/onmt.vocab* +rm -f $TMP_OUT_DIR/onmt.model* +rm -f $TMP_OUT_DIR/onmt_feat.vocab.* # # Translation test @@ -269,6 +297,16 @@ ${PYTHON} translate.py -model ${TEST_DIR}/test_model.pt -src $TMP_OUT_DIR/src-te echo "Succeeded" | tee -a ${LOG_FILE} rm $TMP_OUT_DIR/src-test.txt +echo -n " [+] Testing NMT translation with features..." +${PYTHON} translate.py \ + -model ${TMP_OUT_DIR}/onmt.features.model_step_10.pt \ + -src ${DATA_DIR}/data_features/src-test.txt \ + -src_feats "{'feat0': '${DATA_DIR}/data_features/src-test.feat0'}" \ + -verbose >> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} +rm -f $TMP_OUT_DIR/onmt.features.model* + echo -n " [+] Testing NMT ensemble translation..." 
head ${DATA_DIR}/src-test.txt > $TMP_OUT_DIR/src-test.txt ${PYTHON} translate.py -model ${TEST_DIR}/test_model.pt ${TEST_DIR}/test_model.pt \ diff --git a/onmt/tests/test_subword_marker.py b/onmt/tests/test_subword_marker.py index e827d52ffa..1b8337b56e 100644 --- a/onmt/tests/test_subword_marker.py +++ b/onmt/tests/test_subword_marker.py @@ -2,6 +2,7 @@ from onmt.transforms.bart import word_start_finder from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer +from onmt.constants import DefaultTokens, SubwordMarker class TestWordStartFinder(unittest.TestCase): @@ -37,7 +38,25 @@ class TestSubwordGroup(unittest.TestCase): def test_subword_group_joiner(self): data_in = ['however', '■,', 'according', 'to', 'the', 'logs', '■,', 'she', 'is', 'hard', '■-■', 'working', '■.'] # noqa: E501 true_out = [0, 0, 1, 2, 3, 4, 4, 5, 6, 7, 7, 7, 7] - out = subword_map_by_joiner(data_in) + out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP) + self.assertEqual(out, true_out) + + def test_subword_group_joiner_with_case_markup(self): + data_in = ['⦅mrk_case_modifier_C⦆', 'however', '■,', 'according', 'to', 'the', 'logs', '■,', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■-■', 'working', '■.', '⦅mrk_end_case_region_U⦆'] # noqa: E501 + true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7] + out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP) + self.assertEqual(out, true_out) + + def test_subword_group_joiner_with_new_joiner(self): + data_in = ['⦅mrk_case_modifier_C⦆', 'however', '■', ',', 'according', 'to', 'the', 'logs', '■', ',', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■', '-', '■', 'working', '■', '.', '⦅mrk_end_case_region_U⦆'] # noqa: E501 + true_out = [0, 0, 0, 0, 1, 2, 3, 4, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7] + out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP) + 
self.assertEqual(out, true_out) + + def test_subword_group_naive(self): + data_in = ['however', ',', 'according', 'to', 'the', 'logs', ',', 'she', 'is', 'hard', '-', 'working', '.'] # noqa: E501 + true_out = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP) self.assertEqual(out, true_out) def test_subword_group_spacer(self): @@ -50,6 +69,18 @@ def test_subword_group_spacer(self): no_dummy_out = subword_map_by_spacer(no_dummy) self.assertEqual(no_dummy_out, true_out) + def test_subword_group_spacer_with_case_markup(self): + data_in = ['⦅mrk_case_modifier_C⦆', '▁however', ',', '▁according', '▁to', '▁the', '▁logs', ',', '▁⦅mrk_begin_case_region_U⦆', '▁she', '▁is', '▁hard', '-', 'working', '.', '▁⦅mrk_end_case_region_U⦆'] # noqa: E501 + true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7] + out = subword_map_by_spacer(data_in) + self.assertEqual(out, true_out) + + def test_subword_group_spacer_with_spacer_new(self): + data_in = ['⦅mrk_case_modifier_C⦆', '▁', 'however', ',', '▁', 'according', '▁', 'to', '▁', 'the', '▁', 'logs', ',', '▁', '⦅mrk_begin_case_region_U⦆', '▁', 'she', '▁', 'is', '▁', 'hard', '-', 'working', '.', '▁', '⦅mrk_end_case_region_U⦆'] # noqa: E501 + true_out = [0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7] + out = subword_map_by_spacer(data_in) + self.assertEqual(out, true_out) + if __name__ == '__main__': unittest.main() diff --git a/onmt/tests/test_text_dataset.py b/onmt/tests/test_text_dataset.py index e4d22e9c0a..4477bca7fe 100644 --- a/onmt/tests/test_text_dataset.py +++ b/onmt/tests/test_text_dataset.py @@ -79,7 +79,8 @@ def test_preprocess_shape(self): self.INIT_CASES, self.PARAMS): init_case = self.initialize_case(init_case, params) mf = TextMultiField(**init_case) - sample_str = "dummy input here ." 
+ + sample_str = {"base_field": "dummy input here .", "a": "A A B D", "r": "C C C C", "b": "D F E D", "zbase_field": "another dummy input ."} proc = mf.preprocess(sample_str) self.assertEqual(len(proc), len(init_case["feats_fields"]) + 1) @@ -147,7 +148,7 @@ def test_read(self): ] rdr = TextDataReader() for i, ex in enumerate(rdr.read(strings, "src")): - self.assertEqual(ex["src"], strings[i].decode("utf-8")) + self.assertEqual(ex["src"], {"src": strings[i].decode("utf-8")}) class TestTextDataReaderFromFS(unittest.TestCase): @@ -174,4 +175,23 @@ def tearDownClass(cls): def test_read(self): rdr = TextDataReader() for i, ex in enumerate(rdr.read(self.FILE_NAME, "src")): - self.assertEqual(ex["src"], self.STRINGS[i].decode("utf-8")) + self.assertEqual(ex["src"], {"src": self.STRINGS[i].decode("utf-8")}) + +class TestTextDataReaderWithFeatures(unittest.TestCase): + def test_read(self): + strings = [ + "hello world".encode("utf-8"), + "this's a string with punctuation .".encode("utf-8"), + "ThIs Is A sTrInG wItH oDD CapitALIZAtion".encode("utf-8") + ] + features = { + "feat_0": [ + "A A".encode("utf-8"), + "A A B B C".encode("utf-8"), + "A A D D E E".encode("utf-8") + ] + } + + rdr = TextDataReader() + for i, ex in enumerate(rdr.read(strings, "src", features)): + self.assertEqual(ex["src"], {"src": strings[i].decode("utf-8"), "feat_0": features["feat_0"][i].decode("utf-8")}) \ No newline at end of file diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py index 4bfa8be3bc..d99bc607de 100644 --- a/onmt/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -509,3 +509,25 @@ def test_span_infilling(self): # n_masked = math.ceil(n_words * bart_noise.mask_ratio) # print(f"Text Span Infilling: {infillied} / {tokens}") # print(n_words, n_masked) + +class TestFeaturesTransform(unittest.TestCase): + def test_inferfeats(self): + inferfeats_cls = get_transforms_cls(["inferfeats"])["inferfeats"] + opt = Namespace(reversible_tokenization="joiner") + 
inferfeats_transform = inferfeats_cls(opt) + + ex_in = { + "src": ['however', '■,', 'according', 'to', 'the', 'logs', '■,', 'she', 'is', 'hard', '■-■', 'working', '■.'], + "tgt": ['however', '■,', 'according', 'to', 'the', 'logs', '■,', 'she', 'is', 'hard', '■-■', 'working', '■.'] + } + ex_out = inferfeats_transform.apply(ex_in) + self.assertIs(ex_out, ex_in) + + ex_in["src_feats"] = {"feat_0": ["A", "A", "A", "A", "B", "A", "A", "C"]} + ex_out = inferfeats_transform.apply(ex_in) + self.assertEqual(ex_out["src_feats"]["feat_0"], ["A", "", "A", "A", "A", "B", "", "A", "A", "C", "", "C", ""]) + + ex_in["src"] = ['⦅mrk_case_modifier_C⦆', 'however', '■,', 'according', 'to', 'the', 'logs', '■,', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■-■', 'working', '■.', '⦅mrk_end_case_region_U⦆'] + ex_in["src_feats"] = {"feat_0": ["A", "A", "A", "A", "B", "A", "A", "C"]} + ex_out = inferfeats_transform.apply(ex_in) + self.assertEqual(ex_out["src_feats"]["feat_0"], ["", "A", "", "A", "A", "A", "B", "", "", "A", "A", "C", "", "C", "", ""]) diff --git a/onmt/transforms/features.py b/onmt/transforms/features.py new file mode 100644 index 0000000000..24f02e30fe --- /dev/null +++ b/onmt/transforms/features.py @@ -0,0 +1,90 @@ +from onmt.utils.logging import logger +from onmt.transforms import register_transform +from .transform import Transform, ObservableStats +from onmt.constants import DefaultTokens, SubwordMarker +from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer +import re +from collections import defaultdict + + +@register_transform(name='filterfeats') +class FilterFeatsTransform(Transform): + """Filter out examples with a mismatch between source and features.""" + + def __init__(self, opts): + super().__init__(opts) + + @classmethod + def add_options(cls, parser): + pass + + def _parse_opts(self): + pass + + def apply(self, example, is_train=False, stats=None, **kwargs): + """Return None if mismatch""" + + if 'src_feats' not in example: + # Do 
nothing + return example + + for feat_name, feat_values in example['src_feats'].items(): + if len(example['src']) != len(feat_values): + logger.warning(f"Skipping example due to mismatch between source and feature {feat_name}") + return None + return example + + def _repr_args(self): + return '' + + +@register_transform(name='inferfeats') +class InferFeatsTransform(Transform): + """Infer features for subword tokenization.""" + + def __init__(self, opts): + super().__init__(opts) + + @classmethod + def add_options(cls, parser): + """Avalilable options related to this Transform.""" + group = parser.add_argument_group("Transform/InferFeats") + group.add("--reversible_tokenization", "-reversible_tokenization", default="joiner", + choices=["joiner", "spacer"], help="Type of reversible tokenization applied on the tokenizer.") + + def _parse_opts(self): + super()._parse_opts() + self.reversible_tokenization = self.opts.reversible_tokenization + + def apply(self, example, is_train=False, stats=None, **kwargs): + + if "src_feats" not in example: + # Do nothing + return example + + if self.reversible_tokenization == "joiner": + word_to_subword_mapping = subword_map_by_joiner(example["src"]) + else: #Spacer + word_to_subword_mapping = subword_map_by_spacer(example["src"]) + + inferred_feats = defaultdict(list) + for subword, word_id in zip(example["src"], word_to_subword_mapping): + for feat_name, feat_values in example["src_feats"].items(): + # If case markup placeholder + if subword in SubwordMarker.CASE_MARKUP: + inferred_feat = "" + # Punctuation only (assumes joiner is also some punctuation token) + elif not re.sub(r'(\W)+', '', subword).strip(): + inferred_feat = "" + else: + inferred_feat = feat_values[word_id] + + inferred_feats[feat_name].append(inferred_feat) + + for feat_name, feat_values in inferred_feats.items(): + example["src_feats"][feat_name] = inferred_feats[feat_name] + + return example + + def _repr_args(self): + return '' \ No newline at end of file diff 
--git a/onmt/translate/translator.py b/onmt/translate/translator.py index 9bdd4ee4c7..4d37e982fa 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -333,6 +333,7 @@ def _gold_score( def translate( self, src, + src_feats={}, tgt=None, batch_size=None, batch_type="sents", @@ -345,6 +346,7 @@ def translate( Args: src: See :func:`self.src_reader.read()`. tgt: See :func:`self.tgt_reader.read()`. + src_feats: See :func`self.src_reader.read()`. batch_size (int): size of examples per mini-batch attn_debug (bool): enables the attention logging align_debug (bool): enables the word alignment logging @@ -363,8 +365,8 @@ def translate( if self.tgt_prefix and tgt is None: raise ValueError("Prefix should be feed to tgt if -tgt_prefix.") - src_data = {"reader": self.src_reader, "data": src} - tgt_data = {"reader": self.tgt_reader, "data": tgt} + src_data = {"reader": self.src_reader, "data": src, "features": src_feats} + tgt_data = {"reader": self.tgt_reader, "data": tgt, "features": {}} _readers, _data = inputters.Dataset.config( [("src", src_data), ("tgt", tgt_data)] ) @@ -925,6 +927,7 @@ def _align_forward(self, batch, predictions): def translate( self, src, + src_feats={}, tgt=None, batch_size=None, batch_type="sents", @@ -945,6 +948,7 @@ def translate( return super(GeneratorLM, self).translate( src, + src_feats, tgt, batch_size=1, batch_type=batch_type, diff --git a/onmt/utils/alignment.py b/onmt/utils/alignment.py index 0a70edb33e..d775cf920c 100644 --- a/onmt/utils/alignment.py +++ b/onmt/utils/alignment.py @@ -120,25 +120,43 @@ def to_word_align(src, tgt, subword_align, m_src='joiner', m_tgt='joiner'): return " ".join(word_align) -def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER): +def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP): """Return word id for each subword token (annotate by joiner).""" - flags = [0] * len(subwords) + flags = [1] * len(subwords) for i, tok in 
enumerate(subwords): - if tok.endswith(marker): - flags[i] = 1 - if tok.startswith(marker): - assert i >= 1 and flags[i-1] != 1, \ + if tok.endswith(marker) or (tok in case_markup and tok.find("end")<0): + flags[i] = 0 + if tok.startswith(marker) or (tok in case_markup and tok.find("end")>=0): + assert i >= 1 and flags[i-1] != 0, \ "Sentence `{}` not correct!".format(" ".join(subwords)) - flags[i-1] = 1 - marker_acc = list(accumulate([0] + flags[:-1])) - word_group = [(i - maker_sofar) for i, maker_sofar - in enumerate(marker_acc)] + flags[i-1] = 0 + word_group = list(accumulate([0] + flags[:-1])) return word_group -def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER): +def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER, case_markup=SubwordMarker.CASE_MARKUP): """Return word id for each subword token (annotate by spacer).""" - word_group = list(accumulate([int(marker in x) for x in subwords])) + flags = [0] * len(subwords) + for i, tok in enumerate(subwords): + if marker in tok: + if tok.replace(marker, "") in case_markup: + if i < len(subwords)-1: + flags[i] = 1 + else: + if i > 0: + previous = subwords[i-1].replace(marker, "") + if previous not in case_markup: + flags[i] = 1 + + # In case there is a final case_markup when new_spacer is on + for i in range(1,len(subwords)-1): + if subwords[-i] in case_markup: + flags[-i] = 0 + elif subwords[-i] == marker: + flags[-i] = 0 + break + + word_group = list(accumulate(flags)) if word_group[0] == 1: # when dummy prefix is set word_group = [item - 1 for item in word_group] return word_group diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index 2f4f1e1c45..4a12a5fe4d 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -75,6 +75,19 @@ def _validate_data(cls, opt): logger.warning(f"Corpus {cname}'s weight should be given." 
" We default it to 1 for you.") corpus['weight'] = 1 + + # Check features + src_feats = corpus.get("src_feats", None) + if src_feats is not None: + for feature_name, feature_file in src_feats.items(): + cls._validate_file(feature_file, info=f'{cname}/path_{feature_name}') + if 'inferfeats' not in corpus["transforms"]: + raise ValueError(f"'inferfeats' transform is required when setting source features") + if 'filterfeats' not in corpus["transforms"]: + raise ValueError(f"'filterfeats' transform is required when setting source features") + else: + corpus["src_feats"] = None + logger.info(f"Parsed {len(corpora)} corpora from -data.") opt.data = corpora @@ -107,6 +120,18 @@ def _get_all_transform(cls, opt): @classmethod def _validate_fields_opts(cls, opt, build_vocab_only=False): """Check options relate to vocab and fields.""" + + for cname, corpus in opt.data.items(): + if cname != CorpusName.VALID and corpus["src_feats"] is not None: + assert opt.src_feats_vocab, \ + "-src_feats_vocab is required if using source features." + import yaml + opt.src_feats_vocab = yaml.safe_load(opt.src_feats_vocab) + + for feature in corpus["src_feats"].keys(): + assert feature in opt.src_feats_vocab, \ + f"No vocab file set for feature {feature}" + if build_vocab_only: if not opt.share_vocab: assert opt.tgt_vocab, \ @@ -295,4 +320,4 @@ def validate_train_opts(cls, opt): @classmethod def validate_translate_opts(cls, opt): - pass + opt.src_feats = eval(opt.src_feats) if opt.src_feats else {}