Commit a241fe0

bpe tokenizers
1 parent b09c776 commit a241fe0

File tree

1 file changed: +92 −33 lines changed

python_autocomplete/bpe.py

Lines changed: 92 additions & 33 deletions
@@ -1,5 +1,6 @@
 import string
 from heapq import heappush, heappop
+from typing import List, Tuple

 from labml import lab, monit


@@ -25,33 +26,20 @@ def calc_bpe_itos(self):
         return itos


-class BPELearner:
-    def __init__(self, data: str):
-        self.data = data
-        self.words = {}
-        self.heap = []
-        self.heap_modified = set()
-        self.char_itos = []
-        self.char_stoi = {}
-        self.bpe = []
-        self.word_codes = []
-        self.word_code_prev = {}
-        self.word_code_next = {}
+class Tokenizer:
+    def collect_words(self, data: str):
+        raise NotImplementedError

-        self.counts = {}
-        self.locations = {}
+    def get_words(self) -> Tuple[List[str], List[int]]:
+        raise NotImplementedError

-        self.collect_words()
-        self.build_vocab()
-        self.build_word_arrays()
-        self.collect_pairs()
+    def tokenize(self, data: str) -> List[str]:
+        raise NotImplementedError

-    def learn(self, merges: int):
-        for i in monit.iterate('BPE', merges):
-            while True:
-                res = self.merge_pair()
-                if res is not None:
-                    break
+
+class SourceCodeTokenizer(Tokenizer):
+    def __init__(self):
+        self.words = {}

     def add_word(self, word):
         if not word:
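
Any class that implements collect_words, get_words and tokenize can now stand in for SourceCodeTokenizer when feeding the learner. As an illustration only (this class is not part of the commit), a hypothetical whitespace-splitting tokenizer might look like:

    class WhitespaceTokenizer(Tokenizer):
        # Hypothetical example, not in the repository.
        def __init__(self):
            self.words = {}

        def collect_words(self, data: str):
            # Accumulate word frequencies across calls.
            for w in data.split():
                self.words[w] = self.words.get(w, 0) + 1

        def get_words(self) -> Tuple[List[str], List[int]]:
            # Sort by descending frequency, matching SourceCodeTokenizer.
            pairs = sorted(self.words.items(), key=lambda x: -x[1])
            return [w for w, _ in pairs], [f for _, f in pairs]

        def tokenize(self, data: str) -> List[str]:
            return data.split()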
@@ -62,28 +50,96 @@ def add_word(self, word):
         else:
             self.words[word] += 1

-    def collect_words(self):
+    def tokenize(self, data: str) -> List[str]:
         last_idx = 0
         is_id = False
+        res = []

-        for i, c in monit.enum('Collect words', self.data):
+        for i, c in monit.enum('Collect words', data):
             if c in ID_CHARS:
                 if not is_id:
-                    self.add_word(self.data[last_idx:i])
+                    if last_idx < i:
+                        res.append(data[last_idx:i])
                     last_idx = i
                     is_id = True
             else:
                 if is_id:
-                    self.add_word(self.data[last_idx:i])
+                    if last_idx < i:
+                        res.append(data[last_idx:i])
                     last_idx = i
                     is_id = False

-        self.add_word(self.data[last_idx:])
+        if last_idx < len(data):
+            res.append(data[last_idx:])
+
+        return res
+
+    def collect_words(self, data: str):
+        last_idx = 0
+        is_id = False
+
+        for i, c in monit.enum('Collect words', data):
+            if c in ID_CHARS:
+                if not is_id:
+                    self.add_word(data[last_idx:i])
+                    last_idx = i
+                    is_id = True
+            else:
+                if is_id:
+                    self.add_word(data[last_idx:i])
+                    last_idx = i
+                    is_id = False
+
+        self.add_word(data[last_idx:])
+
+    def get_words(self):
         words_list = [(f, w) for w, f in self.words.items()]
         words_list.sort(key=lambda x: -x[0])

-        self.words_list = [w for _, w in words_list]
-        self.word_freq = [f for f, _ in words_list]
+        return [w for _, w in words_list], [f for f, _ in words_list]
+
+
+class NoTokenizer(Tokenizer):
+    def __init__(self):
+        self.data = ''
+
+    def collect_words(self, data):
+        self.data += data
+
+    def get_words(self):
+        return [self.data], [1]
+
+    def tokenize(self, data: str) -> List[str]:
+        return [data]
+
+
+class BPELearner:
+    def __init__(self, words_list: List[str], word_freq: List[int]):
+        self.words_list = words_list
+        self.word_freq = word_freq
+
+        self.heap = []
+        self.heap_modified = set()
+        self.char_itos = []
+        self.char_stoi = {}
+        self.bpe = []
+        self.word_codes = []
+        self.word_code_prev = []
+        self.word_code_next = []
+
+        self.counts = {}
+        self.locations = {}
+
+        self.build_vocab()
+        self.build_word_arrays()
+        self.collect_pairs()
+
+    def learn(self, merges: int):
+        for i in monit.iterate('BPE', merges):
+            while True:
+                res = self.merge_pair()
+                if res is not None:
+                    break

     def build_vocab(self):
         vocab = set()
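
The hunk above separates the stateless tokenize (which returns the token list, keeping empty spans out via the last_idx < i checks) from collect_words (which accumulates frequencies in self.words for get_words). A rough usage sketch, assuming ID_CHARS covers identifier characters such as letters, digits and underscores:

    tokenizer = SourceCodeTokenizer()

    # tokenize() splits source into identifier and non-identifier runs
    # without touching the frequency table.
    print(tokenizer.tokenize('def add(a, b):'))
    # -> ['def', ' ', 'add', '(', 'a', ', ', 'b', '):']

    # collect_words() feeds the same runs into self.words instead;
    # add_word() discards the empty leading span.
    tokenizer.collect_words('def add(a, b):')
    words_list, word_freq = tokenizer.get_words()  # sorted by descending frequency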
@@ -230,11 +286,14 @@ def main():
     with open(str(path), 'r') as f:
         data = f.read()[:100_000]

-    bpe = BPELearner(data)
+    tokenizer = SourceCodeTokenizer()
+    tokenizer.collect_words(data)
+
+    bpe = BPELearner(*tokenizer.get_words())
     bpe.learn(1000)
     print(len(bpe.bpe))
     print(bpe.bpe_itos()[len(bpe.char_itos):])
-    print(len(bpe.data), bpe.get_length())
+    print(len(data), bpe.get_length())


 if __name__ == '__main__':
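
Since BPELearner now takes a word list and frequencies instead of raw text, the same learner can also run at the character level through NoTokenizer, which reports the whole input as a single word. A minimal sketch (assuming the input is long enough to supply the requested merges):

    # Character-level BPE: the whole input is one "word" with frequency 1.
    tokenizer = NoTokenizer()
    tokenizer.collect_words(data)

    bpe = BPELearner(*tokenizer.get_words())
    bpe.learn(1000)
    print(bpe.bpe_itos()[len(bpe.char_itos):])  # the learned merge tokens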
