1
1
import string
2
2
from heapq import heappush , heappop
3
+ from typing import List , Tuple
3
4
4
5
from labml import lab , monit
5
6
@@ -25,33 +26,20 @@ def calc_bpe_itos(self):
25
26
return itos
26
27
27
28
28
- class BPELearner :
29
- def __init__ (self , data : str ):
30
- self .data = data
31
- self .words = {}
32
- self .heap = []
33
- self .heap_modified = set ()
34
- self .char_itos = []
35
- self .char_stoi = {}
36
- self .bpe = []
37
- self .word_codes = []
38
- self .word_code_prev = {}
39
- self .word_code_next = {}
29
class Tokenizer:
    """Interface for splitting raw text into word tokens and collecting word statistics."""

    def collect_words(self, data: str):
        """Accumulate word statistics from ``data``; may be called repeatedly."""
        raise NotImplementedError

    def get_words(self) -> Tuple[List[str], List[int]]:
        """Return the collected words and their matching frequency counts."""
        raise NotImplementedError

    def tokenize(self, data: str) -> List[str]:
        """Split ``data`` into a list of token strings."""
        raise NotImplementedError
48
38
49
- def learn (self , merges : int ):
50
- for i in monit .iterate ('BPE' , merges ):
51
- while True :
52
- res = self .merge_pair ()
53
- if res is not None :
54
- break
39
+
40
+ class SourceCodeTokenizer (Tokenizer ):
41
    def __init__(self):
        # word -> occurrence count; populated by collect_words() via add_word()
        self.words = {}
55
43
56
44
def add_word (self , word ):
57
45
if not word :
@@ -62,28 +50,96 @@ def add_word(self, word):
62
50
else :
63
51
self .words [word ] += 1
64
52
65
- def collect_words (self ) :
53
+ def tokenize (self , data : str ) -> List [ str ] :
66
54
last_idx = 0
67
55
is_id = False
56
+ res = []
68
57
69
- for i , c in monit .enum ('Collect words' , self . data ):
58
+ for i , c in monit .enum ('Collect words' , data ):
70
59
if c in ID_CHARS :
71
60
if not is_id :
72
- self .add_word (self .data [last_idx :i ])
61
+ if last_idx < i :
62
+ res .append (data [last_idx :i ])
73
63
last_idx = i
74
64
is_id = True
75
65
else :
76
66
if is_id :
77
- self .add_word (self .data [last_idx :i ])
67
+ if last_idx < i :
68
+ res .append (data [last_idx :i ])
78
69
last_idx = i
79
70
is_id = False
80
71
81
- self .add_word (self .data [last_idx :])
72
+ if last_idx < len (data ):
73
+ res .append (data [last_idx :])
74
+
75
+ return res
76
+
77
+ def collect_words (self , data : str ):
78
+ last_idx = 0
79
+ is_id = False
80
+
81
+ for i , c in monit .enum ('Collect words' , data ):
82
+ if c in ID_CHARS :
83
+ if not is_id :
84
+ self .add_word (data [last_idx :i ])
85
+ last_idx = i
86
+ is_id = True
87
+ else :
88
+ if is_id :
89
+ self .add_word (data [last_idx :i ])
90
+ last_idx = i
91
+ is_id = False
92
+
93
+ self .add_word (data [last_idx :])
94
+
95
+ def get_words (self ):
82
96
words_list = [(f , w ) for w , f in self .words .items ()]
83
97
words_list .sort (key = lambda x : - x [0 ])
84
98
85
- self .words_list = [w for _ , w in words_list ]
86
- self .word_freq = [f for f , _ in words_list ]
99
+ return [w for _ , w in words_list ], [f for f , _ in words_list ]
100
+
101
+
102
class NoTokenizer(Tokenizer):
    """Trivial tokenizer: performs no splitting and treats all collected text as one word."""

    def __init__(self):
        # All text passed to collect_words(), concatenated in call order.
        self.data = ''

    def collect_words(self, data):
        """Append ``data`` to the accumulated text."""
        self.data = self.data + data

    def get_words(self):
        """The entire accumulated text is a single word with frequency 1."""
        return [self.data], [1]

    def tokenize(self, data: str) -> List[str]:
        """No splitting: the whole input is one token."""
        return [data]
114
+
115
+
116
+ class BPELearner :
117
    def __init__(self, words_list: List[str], word_freq: List[int]):
        """
        Set up BPE learning state.

        :param words_list: unique words (parallel to ``word_freq``)
        :param word_freq: occurrence count for each word in ``words_list``
        """
        self.words_list = words_list
        self.word_freq = word_freq

        # Pair-merge bookkeeping; filled in by the build/collect calls below.
        self.heap = []               # candidate pairs (heapq) — ordering semantics live in collect_pairs/merge_pair
        self.heap_modified = set()   # pairs whose heap entries are stale — TODO confirm against merge_pair
        self.char_itos = []          # index -> character vocabulary
        self.char_stoi = {}          # character -> index vocabulary
        self.bpe = []                # learned merge results
        self.word_codes = []         # per-word code sequences
        self.word_code_prev = []     # linked-list prev pointers within word codes — verify in build_word_arrays
        self.word_code_next = []     # linked-list next pointers within word codes — verify in build_word_arrays

        self.counts = {}             # pair -> occurrence count
        self.locations = {}          # pair -> where it occurs

        # Populate the structures above from words_list/word_freq.
        self.build_vocab()
        self.build_word_arrays()
        self.collect_pairs()
136
+
137
+ def learn (self , merges : int ):
138
+ for i in monit .iterate ('BPE' , merges ):
139
+ while True :
140
+ res = self .merge_pair ()
141
+ if res is not None :
142
+ break
87
143
88
144
def build_vocab (self ):
89
145
vocab = set ()
@@ -230,11 +286,14 @@ def main():
230
286
with open (str (path ), 'r' ) as f :
231
287
data = f .read ()[:100_000 ]
232
288
233
- bpe = BPELearner (data )
289
+ tokenizer = SourceCodeTokenizer ()
290
+ tokenizer .collect_words (data )
291
+
292
+ bpe = BPELearner (* tokenizer .get_words ())
234
293
bpe .learn (1000 )
235
294
print (len (bpe .bpe ))
236
295
print (bpe .bpe_itos ()[len (bpe .char_itos ):])
237
- print (len (bpe . data ), bpe .get_length ())
296
+ print (len (data ), bpe .get_length ())
238
297
239
298
240
299
if __name__ == '__main__' :
0 commit comments