diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 7ad590a19b..03df789f7e 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -124,9 +124,9 @@ def unk_string(self, escape=False): else: return self.unk_word - def add_symbol(self, word, n=1, overwrite=False): + def add_symbol(self, word, n=1, overwrite=True): """Adds a word to the dictionary""" - if word in self.indices and not overwrite: + if word in self.indices and overwrite: idx = self.indices[word] self.count[idx] = self.count[idx] + n return idx @@ -218,11 +218,17 @@ def unk(self): def load(cls, f, add_special_symbols=True): """Loads the dictionary from a text file with the format: - ``` - - - ... - ``` + Example:: + ``` + [] + [] + ... + ``` + + Note: + Possible flags are `#fairseq:overwrite` to overwrite duplicates + and `#fairseq:duplicate` to keep them (for backward compatibility + after bug fix) """ d = cls(add_special_symbols=add_special_symbols) d.add_from_file(f) @@ -255,31 +261,36 @@ def add_from_file(self, f): if field == "#fairseq:overwrite": overwrite = True line, field = line.rsplit(" ", 1) - else: + elif field == "#fairseq:duplicate": overwrite = False + line, field = line.rsplit(" ", 1) + else: + if line in self: + raise RuntimeError( + "Duplicate word found when loading Dictionary: '{}'. " + "Duplicate words can overwrite earlier ones by adding the " + "#fairseq:overwrite flag at the end of the corresponding row " + "in the dictionary file. Use the #fairseq:duplicate flag " + "to keep duplicates in the dictionary (backward compatibility " + "after bug fix). If using the Camembert model, please " + "download an updated copy of the model file.".format(word) + ) + overwrite = True # default behaviour count = int(field) word = line - if word in self and not overwrite: - raise RuntimeError( - "Duplicate word found when loading Dictionary: '{}'. " - "Duplicate words can overwrite earlier ones by adding the " - "#fairseq:overwrite flag at the end of the corresponding row " - "in the dictionary file. If using the Camembert model, please " - "download an updated copy of the model file.".format(word) - ) self.add_symbol(word, n=count, overwrite=overwrite) except ValueError: raise ValueError( f"Incorrect dictionary format, expected ' [flags]': \"{line}\"" ) - def _save(self, f, kv_iterator): + def _save(self, f, kvf_iterator): if isinstance(f, str): PathManager.mkdirs(os.path.dirname(f)) with PathManager.open(f, "w", encoding="utf-8") as fd: return self.save(fd) - for k, v in kv_iterator: - print("{} {}".format(k, v), file=f) + for k, v, flag in kvf_iterator: + print("{} {} {}".format(k, v, flag), file=f) def _get_meta(self): return [], [] @@ -295,6 +306,12 @@ def save(self, f): zip( ex_keys + self.symbols[self.nspecial :], ex_vals + self.count[self.nspecial :], + [ + '#fairseq:duplicate' + if s in self.symbols[:self.nspecial+i] + else '' + for i, s in enumerate(self.symbols[self.nspecial:]) + ] ), ) diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py index dc9d71b3c7..8f367e8ab3 100644 --- a/tests/test_dictionary.py +++ b/tests/test_dictionary.py @@ -78,7 +78,6 @@ def assertMatch(ids, ref_ids): assertMatch(finalized_ids, reload_ids) def test_overwrite(self): - # for example, Camembert overwrites , and dict_file = io.StringIO( " 999 #fairseq:overwrite\n" " 999 #fairseq:overwrite\n" @@ -88,6 +87,33 @@ def test_overwrite(self): ) d = Dictionary() d.add_from_file(dict_file) + self.assertEqual(d.bos(), 0) + self.assertEqual(d.pad(), 1) + self.assertEqual(d.eos(), 2) + self.assertEqual(d.unk(), 3) + self.assertEqual(d.index(""), 1) + self.assertEqual(d.index("foo"), 3) + self.assertEqual(d.index(""), 3) + self.assertEqual(d.index(""), 0) + self.assertEqual(d.index(""), 2) + self.assertEqual(d.index(","), 4) + self.assertEqual(d.index("▁de"), 5) + + def test_duplicate(self): + # for example, Camembert duplicates , and + dict_file = io.StringIO( + " 999 #fairseq:duplicate\n" + " 999 #fairseq:duplicate\n" + " 999 #fairseq:duplicate\n" + ", 999\n" + "▁de 999\n" + ) + d = Dictionary() + d.add_from_file(dict_file) + self.assertEqual(d.bos(), 0) + self.assertEqual(d.pad(), 1) + self.assertEqual(d.eos(), 2) + self.assertEqual(d.unk(), 3) self.assertEqual(d.index(""), 1) self.assertEqual(d.index("foo"), 3) self.assertEqual(d.index(""), 4) @@ -96,8 +122,8 @@ def test_overwrite(self): self.assertEqual(d.index(","), 7) self.assertEqual(d.index("▁de"), 8) - def test_no_overwrite(self): - # for example, Camembert overwrites , and + def test_no_overwrite_nor_duplicate(self): + # for example, Camembert duplicates , and dict_file = io.StringIO( " 999\n" " 999\n" " 999\n" ", 999\n" "▁de 999\n" )