Skip to content

Commit e8cf860

Browse files
authored
Consider escaped characters as single characters in BPE (#322)
1 parent 128b900 commit e8cf860

File tree

4 files changed

+42
-11
lines changed

4 files changed

+42
-11
lines changed

bindings/python/test/test.py

+12
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,18 @@ def test_bpe_learner_no_pairs(tmpdir):
341341
tokenizer = learner.learn(model_path)
342342

343343

344+
def test_bpe_learner_escaped_character(tmpdir):
345+
text = "คุณอาจจะทำอย ่ างนั ้ นไปซักพัก จนคุณเริ ่ มจะรู ้ สึกถึงมันจริงๆ"
346+
347+
tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)
348+
learner = pyonmttok.BPELearner(tokenizer=tokenizer, symbols=5, min_frequency=1)
349+
learner.ingest(text)
350+
tokenizer = learner.learn(str(tmpdir.join("bpe.model")))
351+
352+
tokens = tokenizer(text)
353+
assert "■%0020่" in tokens
354+
355+
344356
@pytest.mark.parametrize("keep_vocab", [False, True])
345357
def test_sp_learner(tmpdir, keep_vocab):
346358
learner = pyonmttok.SentencePieceLearner(

include/onmt/Tokenizer.h

+2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ namespace onmt
7070
static const std::string spacer_marker;
7171
static const std::string ph_marker_open;
7272
static const std::string ph_marker_close;
73+
static const std::string escaped_character_prefix;
74+
static const size_t escaped_character_width;
7375

7476
Tokenizer(Options options,
7577
const std::shared_ptr<const SubwordEncoder>& subword_encoder = nullptr);

src/BPE.cc

+15-1
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,23 @@ namespace onmt
138138
std::vector<std::string> pieces;
139139
pieces.reserve(chars.size());
140140

141+
static const auto escaped_character_prefix = (
142+
unicode::utf8_to_cp(Tokenizer::escaped_character_prefix.c_str()));
143+
size_t escaped_character_length = 0;
144+
141145
for (const auto& c : chars)
142146
{
143-
if (c.char_type == unicode::CharType::Mark)
147+
if (escaped_character_length > 0)
148+
{
149+
pieces.back().append(c.data, c.length);
150+
escaped_character_length--;
151+
}
152+
else if (c.value == escaped_character_prefix)
153+
{
154+
pieces.emplace_back(c.data, c.length);
155+
escaped_character_length = Tokenizer::escaped_character_width;
156+
}
157+
else if (c.char_type == unicode::CharType::Mark)
144158
{
145159
if (pieces.empty())
146160
pieces.emplace_back(c.data, c.length);

src/Tokenizer.cc

+13-10
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ namespace onmt
1818
const std::string Tokenizer::spacer_marker = "";
1919
const std::string Tokenizer::ph_marker_open = "";
2020
const std::string Tokenizer::ph_marker_close = "";
21+
const std::string Tokenizer::escaped_character_prefix = "";
22+
const size_t Tokenizer::escaped_character_width = 4;
2123
static const unicode::code_point_t ph_marker_open_cp = 0xFF5F;
2224
static const unicode::code_point_t ph_marker_close_cp = 0xFF60;
23-
static const std::string protected_character = "";
2425
static const std::vector<std::pair<unicode::code_point_t, std::string>> substitutes = {
2526
{0x2581 /**/, "_"},
2627
{0xFFED /**/, ""},
@@ -32,7 +33,6 @@ namespace onmt
3233

3334
static const int placeholder_alphabet = -2;
3435
static const int number_alphabet = -3;
35-
static const int hex_value_width = 4;
3636

3737
Tokenizer::Mode Tokenizer::str_to_mode(const std::string& mode)
3838
{
@@ -290,21 +290,23 @@ namespace onmt
290290

291291
static inline void unescape_characters(std::string& str)
292292
{
293+
const auto& prefix = Tokenizer::escaped_character_prefix;
294+
const auto& width = Tokenizer::escaped_character_width;
295+
293296
for (size_t offset = 0;;)
294297
{
295-
const size_t index = str.find(protected_character, offset);
296-
if (index == std::string::npos
297-
|| index + protected_character.size() + hex_value_width > str.size())
298+
const size_t index = str.find(prefix, offset);
299+
if (index == std::string::npos || index + prefix.size() + width > str.size())
298300
break;
299301

300-
const std::string code = str.substr(index + protected_character.size(), hex_value_width);
302+
const std::string code = str.substr(index + prefix.size(), width);
301303
const int v = hex_to_int(code);
302304
const std::string c = unicode::cp_to_utf8(v);
303-
if (c.empty() || !c[0] || int_to_hex(v, hex_value_width) != code)
304-
offset = index + protected_character.size();
305+
if (c.empty() || !c[0] || int_to_hex(v, width) != code)
306+
offset = index + prefix.size();
305307
else
306308
{
307-
str.replace(index, protected_character.size() + hex_value_width, c);
309+
str.replace(index, prefix.size() + width, c);
308310
offset = index + 1;
309311
}
310312
}
@@ -635,7 +637,8 @@ namespace onmt
635637
if (_no_substitution)
636638
append(character);
637639
else
638-
append(protected_character + int_to_hex(character.value, hex_value_width));
640+
append(Tokenizer::escaped_character_prefix
641+
+ int_to_hex(character.value, Tokenizer::escaped_character_width));
639642
}
640643

641644
void flush_feature()

0 commit comments

Comments
 (0)