From 3ed453aca9a909fe0a6c62af7ac51f906efcc883 Mon Sep 17 00:00:00 2001 From: Lydia Nishimwe Date: Fri, 15 Sep 2023 19:45:08 +0200 Subject: [PATCH 01/10] fix overwrite bug when adding symbol to dictionary This bug ignored the tokens that were meant to be overwritten and appended them to the end of the dictionary symbols instead. For example, a dictionary with 50K tokens that already has `<s>`, `<pad>`, `</s>` and `<unk>` with the #fairseq:overwrite tag will end up having 50004 tokens when loaded. --- fairseq/data/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 7ad590a19b..3b8b741c4d 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -126,7 +126,7 @@ def unk_string(self, escape=False): def add_symbol(self, word, n=1, overwrite=False): """Adds a word to the dictionary""" - if word in self.indices and not overwrite: + if word in self.indices and overwrite: idx = self.indices[word] self.count[idx] = self.count[idx] + n return idx From 576602d18691aec1db3029a72ab2f9e7f8805bcb Mon Sep 17 00:00:00 2001 From: Lydia Nishimwe Date: Thu, 21 Sep 2023 18:19:51 +0200 Subject: [PATCH 02/10] Fix test_overwrite in test_dictionary.py Assert that overwrite works as expected (i.e. 
ignoring the duplicates) --- tests/test_dictionary.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py index dc9d71b3c7..994aa7b13c 100644 --- a/tests/test_dictionary.py +++ b/tests/test_dictionary.py @@ -90,11 +90,11 @@ def test_overwrite(self): d.add_from_file(dict_file) self.assertEqual(d.index(""), 1) self.assertEqual(d.index("foo"), 3) - self.assertEqual(d.index(""), 4) - self.assertEqual(d.index(""), 5) - self.assertEqual(d.index(""), 6) - self.assertEqual(d.index(","), 7) - self.assertEqual(d.index("▁de"), 8) + self.assertEqual(d.index(""), 3) + self.assertEqual(d.index(""), 0) + self.assertEqual(d.index(""), 2) + self.assertEqual(d.index(","), 4) + self.assertEqual(d.index("▁de"), 5) def test_no_overwrite(self): # for example, Camembert overwrites , and From c7535b06209c39dfe180c1797efcdf785013ee84 Mon Sep 17 00:00:00 2001 From: Lydia Nishimwe Date: Thu, 21 Sep 2023 22:01:29 +0200 Subject: [PATCH 03/10] Add support for fairseq:duplicate flag in dictionary For backward compatibility with the existing models/pipelines that use a flawed dictionary loaded from file (before the bug fix) --- fairseq/data/dictionary.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 3b8b741c4d..01c0f2f960 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -219,10 +219,13 @@ def load(cls, f, add_special_symbols=True): """Loads the dictionary from a text file with the format: ``` - - + [] + [] ... 
``` + Possible flags are `#fairseq:overwrite` to overwrite duplicates + and `#fairseq:duplicate` to keep them (for backwards compatibility + after bug fix) """ d = cls(add_special_symbols=add_special_symbols) d.add_from_file(f) @@ -253,18 +256,23 @@ def add_from_file(self, f): try: line, field = line.rstrip().rsplit(" ", 1) if field == "#fairseq:overwrite": - overwrite = True + overwrite, duplicate = True, False + line, field = line.rsplit(" ", 1) + elif field == "#fairseq:duplicate": + overwrite, duplicate = False, True line, field = line.rsplit(" ", 1) else: - overwrite = False + overwrite, duplicate = False, False count = int(field) word = line - if word in self and not overwrite: + if word in self and not overwrite and not duplicate: raise RuntimeError( "Duplicate word found when loading Dictionary: '{}'. " "Duplicate words can overwrite earlier ones by adding the " "#fairseq:overwrite flag at the end of the corresponding row " - "in the dictionary file. If using the Camembert model, please " + "in the dictionary file. Use the #fairseq:duplicate flag " + "to keep duplicates in the dictionary (backward compatibility " + "after bug fix). 
If using the Camembert model, please " "download an updated copy of the model file.".format(word) ) self.add_symbol(word, n=count, overwrite=overwrite) @@ -273,13 +281,13 @@ def add_from_file(self, f): f"Incorrect dictionary format, expected ' [flags]': \"{line}\"" ) - def _save(self, f, kv_iterator): + def _save(self, f, kvf_iterator): if isinstance(f, str): PathManager.mkdirs(os.path.dirname(f)) with PathManager.open(f, "w", encoding="utf-8") as fd: return self.save(fd) - for k, v in kv_iterator: - print("{} {}".format(k, v), file=f) + for k, v, flag in kvf_iterator: + print("{} {} {}".format(k, v, flag), file=f) def _get_meta(self): return [], [] @@ -295,6 +303,12 @@ def save(self, f): zip( ex_keys + self.symbols[self.nspecial :], ex_vals + self.count[self.nspecial :], + [ + '#fairseq:duplicate' + if s in self.symbols[:self.nspecial+i] + else '' + for i, s in enumerate(self.symbols[self.nspecial:]) + ] ), ) From 89878963c3c03851d3241de32baa708950d5ba55 Mon Sep 17 00:00:00 2001 From: Lydia Nishimwe Date: Thu, 21 Sep 2023 22:28:00 +0200 Subject: [PATCH 04/10] Write unit tests for overwrite and duplicate in dictionary --- tests/test_dictionary.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py index 994aa7b13c..a7b71b705d 100644 --- a/tests/test_dictionary.py +++ b/tests/test_dictionary.py @@ -78,7 +78,6 @@ def assertMatch(ids, ref_ids): assertMatch(finalized_ids, reload_ids) def test_overwrite(self): - # for example, Camembert overwrites , and dict_file = io.StringIO( " 999 #fairseq:overwrite\n" " 999 #fairseq:overwrite\n" @@ -88,6 +87,10 @@ def test_overwrite(self): ) d = Dictionary() d.add_from_file(dict_file) + self.assertEqual(d.bos(), 0) + self.assertEqual(d.pad(), 1) + self.assertEqual(d.eos(), 2) + self.assertEqual(d.unk(), 3) self.assertEqual(d.index(""), 1) self.assertEqual(d.index("foo"), 3) self.assertEqual(d.index(""), 3) @@ -95,9 +98,32 @@ def 
test_overwrite(self): self.assertEqual(d.index(""), 2) self.assertEqual(d.index(","), 4) self.assertEqual(d.index("▁de"), 5) + + def test_duplicate(self): + # for example, Camembert duplicates , and + dict_file = io.StringIO( + " 999 #fairseq:duplicate\n" + " 999 #fairseq:duplicate\n" + " 999 #fairseq:duplicate\n" + ", 999\n" + "▁de 999\n" + ) + d = Dictionary() + d.add_from_file(dict_file) + self.assertEqual(d.bos(), 0) + self.assertEqual(d.pad(), 1) + self.assertEqual(d.eos(), 2) + self.assertEqual(d.unk(), 3) + self.assertEqual(d.index(""), 1) + self.assertEqual(d.index("foo"), 3) + self.assertEqual(d.index(""), 4) + self.assertEqual(d.index(""), 5) + self.assertEqual(d.index(""), 6) + self.assertEqual(d.index(","), 7) + self.assertEqual(d.index("▁de"), 8) def test_no_overwrite(self): - # for example, Camembert overwrites , and + # for example, Camembert duplicates , and dict_file = io.StringIO( " 999\n" " 999\n" " 999\n" ", 999\n" "▁de 999\n" ) From 0968083e09c1153e1be535e2c9ce71767fa0ac01 Mon Sep 17 00:00:00 2001 From: Lydia Nishimwe Date: Thu, 21 Sep 2023 23:10:30 +0200 Subject: [PATCH 05/10] Update dictionary.py load function documentation --- fairseq/data/dictionary.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 01c0f2f960..3695458363 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -218,14 +218,17 @@ def unk(self): def load(cls, f, add_special_symbols=True): """Loads the dictionary from a text file with the format: - ``` - [] - [] - ... - ``` - Possible flags are `#fairseq:overwrite` to overwrite duplicates - and `#fairseq:duplicate` to keep them (for backwards compatibility - after bug fix) + Example:: + ``` + [] + [] + ... 
+ ``` + + Note: + Possible flags are `#fairseq:overwrite` to overwrite duplicates + and `#fairseq:duplicate` to keep them (for backward compatibility + after bug fix) """ d = cls(add_special_symbols=add_special_symbols) d.add_from_file(f) From eed21c03ba0a8dc77933e06ede262382c483e7cb Mon Sep 17 00:00:00 2001 From: Lydia Nishimwe Date: Thu, 21 Sep 2023 23:47:02 +0200 Subject: [PATCH 06/10] Adding symbols with overwrite=True in encode_line and add_file_to_dictionary After fixing the behaviour of add_symbol, two of the unit tests were failing because they called the function with the default value of overwrite (False). --- fairseq/data/dictionary.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 3695458363..bd29be5b94 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -337,7 +337,7 @@ def encode_line( for i, word in enumerate(words): if add_if_not_exist: - idx = self.add_symbol(word) + idx = self.add_symbol(word, overwrite=True) else: idx = self.index(word) if consumer is not None: @@ -367,7 +367,7 @@ def _add_file_to_dictionary_single_worker( def add_file_to_dictionary(filename, dict, tokenize, num_workers): def merge_result(counter): for w, c in sorted(counter.items()): - dict.add_symbol(w, c) + dict.add_symbol(w, c, overwrite=True) local_file = PathManager.get_local_path(filename) offsets = find_offsets(local_file, num_workers) From b291c8d09e5c09ec1aa5b31454f7105050088d9c Mon Sep 17 00:00:00 2001 From: Lydia Nishimwe Date: Fri, 22 Sep 2023 09:59:53 +0200 Subject: [PATCH 07/10] rename test_no_overwrite to test_no_overwrite_nor_duplicate --- tests/test_dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py index a7b71b705d..8f367e8ab3 100644 --- a/tests/test_dictionary.py +++ b/tests/test_dictionary.py @@ -122,7 +122,7 @@ def test_duplicate(self): self.assertEqual(d.index(","), 7) 
self.assertEqual(d.index("▁de"), 8) - def test_no_overwrite(self): + def test_no_overwrite_nor_duplicate(self): # for example, Camembert duplicates , and dict_file = io.StringIO( " 999\n" " 999\n" " 999\n" ", 999\n" "▁de 999\n" From 174331444852a6912d9ea7638d83b2ae4c43dffd Mon Sep 17 00:00:00 2001 From: Lydia Nishimwe Date: Fri, 8 Mar 2024 13:42:17 +0100 Subject: [PATCH 08/10] set overwrite default value to True in add_symbol This ensures compatibility with all the calls to add_symbol across the repo (which overwrite by default, as in the original implementation). The only place where the value is explicitly changed is when loading the dictionary from file (which was the source of the bug). In a file you have to explicitly say whether the tokens should be overwritten or duplicated --- fairseq/data/dictionary.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index bd29be5b94..ff4b6f8980 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -124,7 +124,7 @@ def unk_string(self, escape=False): else: return self.unk_word - def add_symbol(self, word, n=1, overwrite=False): + def add_symbol(self, word, n=1, overwrite=True): """Adds a word to the dictionary""" if word in self.indices and overwrite: idx = self.indices[word] @@ -337,7 +337,7 @@ def encode_line( for i, word in enumerate(words): if add_if_not_exist: - idx = self.add_symbol(word, overwrite=True) + idx = self.add_symbol(word) else: idx = self.index(word) if consumer is not None: @@ -367,7 +367,7 @@ def _add_file_to_dictionary_single_worker( def add_file_to_dictionary(filename, dict, tokenize, num_workers): def merge_result(counter): for w, c in sorted(counter.items()): - dict.add_symbol(w, c, overwrite=True) + dict.add_symbol(w, c) local_file = PathManager.get_local_path(filename) offsets = find_offsets(local_file, num_workers) From 5c40fd37f96bc011090741729d6fe019e69939ad Mon Sep 17 00:00:00 2001 From: Lydia 
Nishimwe Date: Fri, 8 Mar 2024 13:57:04 +0100 Subject: [PATCH 09/10] remove redundant duplicate variable when loading dictionary from file --- fairseq/data/dictionary.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index ff4b6f8980..03df789f7e 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -259,25 +259,25 @@ def add_from_file(self, f): try: line, field = line.rstrip().rsplit(" ", 1) if field == "#fairseq:overwrite": - overwrite, duplicate = True, False + overwrite = True line, field = line.rsplit(" ", 1) elif field == "#fairseq:duplicate": - overwrite, duplicate = False, True + overwrite = False line, field = line.rsplit(" ", 1) else: - overwrite, duplicate = False, False + if line in self: + raise RuntimeError( + "Duplicate word found when loading Dictionary: '{}'. " + "Duplicate words can overwrite earlier ones by adding the " + "#fairseq:overwrite flag at the end of the corresponding row " + "in the dictionary file. Use the #fairseq:duplicate flag " + "to keep duplicates in the dictionary (backward compatibility " + "after bug fix). If using the Camembert model, please " + "download an updated copy of the model file.".format(word) + ) + overwrite = True # default behaviour count = int(field) word = line - if word in self and not overwrite and not duplicate: - raise RuntimeError( - "Duplicate word found when loading Dictionary: '{}'. " - "Duplicate words can overwrite earlier ones by adding the " - "#fairseq:overwrite flag at the end of the corresponding row " - "in the dictionary file. Use the #fairseq:duplicate flag " - "to keep duplicates in the dictionary (backward compatibility " - "after bug fix). 
If using the Camembert model, please " - "download an updated copy of the model file.".format(word) - ) self.add_symbol(word, n=count, overwrite=overwrite) except ValueError: raise ValueError( From 552fb216d6dd586fd933c0a3850976bfd7f5e3cb Mon Sep 17 00:00:00 2001 From: Lydia Nishimwe Date: Wed, 24 Jul 2024 20:58:09 +0200 Subject: [PATCH 10/10] Update dictionary.py with fairseq-overwrite bug fix