Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix overwrite bug when adding symbol to dictionary #5329

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
55 changes: 36 additions & 19 deletions fairseq/data/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def unk_string(self, escape=False):
else:
return self.unk_word

def add_symbol(self, word, n=1, overwrite=False):
def add_symbol(self, word, n=1, overwrite=True):
"""Adds a word to the dictionary"""
if word in self.indices and not overwrite:
if word in self.indices and overwrite:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
Expand Down Expand Up @@ -218,11 +218,17 @@ def unk(self):
def load(cls, f, add_special_symbols=True):
"""Loads the dictionary from a text file with the format:

```
<symbol0> <count0>
<symbol1> <count1>
...
```
Example::
```
<symbol0> <count0> [<flag0>]
<symbol1> <count1> [<flag1>]
...
```

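Note:
The optional flag is either `#fairseq:overwrite`, which lets a
duplicate symbol overwrite the earlier one, or `#fairseq:duplicate`,
which keeps the duplicates in the dictionary (for backward
compatibility with dictionary files that relied on the behaviour
before the bug fix).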
"""
d = cls(add_special_symbols=add_special_symbols)
d.add_from_file(f)
Expand Down Expand Up @@ -255,31 +261,36 @@ def add_from_file(self, f):
if field == "#fairseq:overwrite":
overwrite = True
line, field = line.rsplit(" ", 1)
else:
elif field == "#fairseq:duplicate":
overwrite = False
line, field = line.rsplit(" ", 1)
else:
if line in self:
raise RuntimeError(
"Duplicate word found when loading Dictionary: '{}'. "
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. Use the #fairseq:duplicate flag "
"to keep duplicates in the dictionary (backward compatibility "
"after bug fix). If using the Camembert model, please "
"download an updated copy of the model file.".format(word)
)
overwrite = True # default behaviour
count = int(field)
word = line
if word in self and not overwrite:
raise RuntimeError(
"Duplicate word found when loading Dictionary: '{}'. "
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. If using the Camembert model, please "
"download an updated copy of the model file.".format(word)
)
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
raise ValueError(
f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
)

def _save(self, f, kv_iterator):
def _save(self, f, kvf_iterator):
if isinstance(f, str):
PathManager.mkdirs(os.path.dirname(f))
with PathManager.open(f, "w", encoding="utf-8") as fd:
return self.save(fd)
for k, v in kv_iterator:
print("{} {}".format(k, v), file=f)
for k, v, flag in kvf_iterator:
print("{} {} {}".format(k, v, flag), file=f)

def _get_meta(self):
return [], []
Expand All @@ -295,6 +306,12 @@ def save(self, f):
zip(
ex_keys + self.symbols[self.nspecial :],
ex_vals + self.count[self.nspecial :],
[
'#fairseq:duplicate'
if s in self.symbols[:self.nspecial+i]
else ''
for i, s in enumerate(self.symbols[self.nspecial:])
]
),
)

Expand Down
32 changes: 29 additions & 3 deletions tests/test_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def assertMatch(ids, ref_ids):
assertMatch(finalized_ids, reload_ids)

def test_overwrite(self):
# for example, Camembert overwrites <unk>, <s> and </s>
dict_file = io.StringIO(
"<unk> 999 #fairseq:overwrite\n"
"<s> 999 #fairseq:overwrite\n"
Expand All @@ -88,6 +87,33 @@ def test_overwrite(self):
)
d = Dictionary()
d.add_from_file(dict_file)
self.assertEqual(d.bos(), 0)
self.assertEqual(d.pad(), 1)
self.assertEqual(d.eos(), 2)
self.assertEqual(d.unk(), 3)
self.assertEqual(d.index("<pad>"), 1)
self.assertEqual(d.index("foo"), 3)
self.assertEqual(d.index("<unk>"), 3)
self.assertEqual(d.index("<s>"), 0)
self.assertEqual(d.index("</s>"), 2)
self.assertEqual(d.index(","), 4)
self.assertEqual(d.index("▁de"), 5)

def test_duplicate(self):
# for example, Camembert duplicates <unk>, <s> and </s>
dict_file = io.StringIO(
"<unk> 999 #fairseq:duplicate\n"
"<s> 999 #fairseq:duplicate\n"
"</s> 999 #fairseq:duplicate\n"
", 999\n"
"▁de 999\n"
)
d = Dictionary()
d.add_from_file(dict_file)
self.assertEqual(d.bos(), 0)
self.assertEqual(d.pad(), 1)
self.assertEqual(d.eos(), 2)
self.assertEqual(d.unk(), 3)
self.assertEqual(d.index("<pad>"), 1)
self.assertEqual(d.index("foo"), 3)
self.assertEqual(d.index("<unk>"), 4)
Expand All @@ -96,8 +122,8 @@ def test_overwrite(self):
self.assertEqual(d.index(","), 7)
self.assertEqual(d.index("▁de"), 8)

def test_no_overwrite(self):
# for example, Camembert overwrites <unk>, <s> and </s>
def test_no_overwrite_nor_duplicate(self):
# for example, Camembert duplicates <unk>, <s> and </s>
dict_file = io.StringIO(
"<unk> 999\n" "<s> 999\n" "</s> 999\n" ", 999\n" "▁de 999\n"
)
Expand Down