@@ -10,15 +10,15 @@ import Foundation
10
10
class CoreBPE {
11
11
private let encoder : [ [ UInt8 ] : Int ]
12
12
private let specialTokensEncoder : [ String : Int ]
13
- private let decoder : [ Int : Data ]
13
+ private let decoder : [ Int : [ UInt8 ] ]
14
14
private let specialTokensDecoder : [ Int : Data ]
15
15
private let regexTls : [ NSRegularExpression ]
16
16
private let specialRegexTls : [ NSRegularExpression ]
17
17
private let sortedTokenBytes : [ Data ]
18
18
19
19
init ( encoder: [ [ UInt8 ] : Int ] = . init( ) ,
20
20
specialTokensEncoder: [ String : Int ] = . init( ) ,
21
- decoder: [ Int : Data ] = . init( ) ,
21
+ decoder: [ Int : [ UInt8 ] ] = . init( ) ,
22
22
specialTokensDecoder: [ Int : Data ] = . init( ) ,
23
23
regexTls: [ NSRegularExpression ] = . init( ) ,
24
24
specialRegexTls: [ NSRegularExpression ] = . init( ) ,
@@ -35,13 +35,8 @@ class CoreBPE {
35
35
func encodeOrdinaryNative( text: String ) -> [ Int ] {
36
36
let regex = regexTls. first!
37
37
var ret = [ Int] ( )
38
- // var newEncoder = [[UInt8]: Int]()
39
- // encoder.forEach({
40
- // newEncoder[[UInt8]($0.key)] = $0.value
41
- // })
42
38
for mat in regex. matches ( in: text, range: NSRange ( text. startIndex... , in: text) ) {
43
39
if let range = Range ( mat. range, in: text) {
44
- // if let piece = Range(mat.range, in: text).map({ String(text[$0]) })?.data(using: .utf8) {
45
40
let piece = Array ( text [ range] . utf8)
46
41
if let token = encoder [ piece] {
47
42
ret. append ( token)
@@ -53,6 +48,15 @@ class CoreBPE {
53
48
}
54
49
return ret
55
50
}
51
+
52
+ func decodeNative( tokens: [ Int ] ) -> String {
53
+ let data = tokens. reduce ( into: Data ( ) , {
54
+ if let tokenBytes = decoder [ $1] {
55
+ $0. append ( contentsOf: tokenBytes)
56
+ }
57
+ } )
58
+ return String ( data: data, encoding: . utf8) ?? " "
59
+ }
56
60
}
57
61
58
62
private extension CoreBPE {
@@ -63,37 +67,6 @@ private extension CoreBPE {
63
67
// func _get_tl_special_regex() -> NSRegularExpression {
64
68
// specialRegexTls[hash_current_thread() % MAX_NUM_THREADS]
65
69
// }
66
-
67
- func decodeNative( tokens: [ Int ] ) -> Data {
68
- var data = Data ( )
69
- data. reserveCapacity ( tokens. count * 2 )
70
-
71
- for token in tokens {
72
- guard let tokenBytes = decoder [ token] ?? specialTokensDecoder [ token] else { break }
73
- data. append ( tokenBytes)
74
- }
75
- return data
76
- }
77
-
78
- // func encodeOrdinaryNative(text: String) -> [Int] {
79
- // let regex = regexTls.first!
80
- // var ret = [Int]()
81
- // var newEncoder = [[UInt8]: Int]()
82
- // encoder.forEach({
83
- // newEncoder[[UInt8]($0.key)] = $0.value
84
- // })
85
- // for mat in regex.matches(in: text, range: NSRange(text.startIndex..., in: text)) {
86
- // let piece = Range(mat.range, in: text).map({ String(text[$0]) })?.data(using: .utf8) ?? Data() // WARNING
87
- // if let token = encoder[piece] {
88
- // ret.append(token)
89
- // continue
90
- // }
91
- //
92
- // ret.append(contentsOf: bytePairEncode([UInt8](piece), newEncoder))
93
- // }
94
- // return ret
95
- // }
96
-
97
70
// func encodeNative(text: String, allowedSpecial: Set<String>) -> ([Int], Int) {
98
71
// let specialRegex = specialRegexTls.first!
99
72
// let regex = regexTls.first!
@@ -331,22 +304,8 @@ private extension CoreBPE {
331
304
}
332
305
}
333
306
334
- var out = [ T] ( )
335
- out. reserveCapacity ( parts. count - 1 )
336
- // for i in 0..<(parts.count - 1) {
337
- // out.append(completion(parts[i].0..<parts[i + 1].0))
338
- // }
339
-
340
307
// TODO: Use ranks
341
- parts. prevCurrent ( {
342
- // if let result = completion($0.0..<$1.0) {
343
- // out.append(result)
344
- // }
345
-
346
- let result = completion ( $0. 0 ..< $1. 0 )
347
- out. append ( result)
348
- } )
349
- return out
308
+ return parts. prevCurrent ( { completion ( $0. 0 ..< $1. 0 ) } )
350
309
}
351
310
352
311
func bytePairEncode( _ piece: [ UInt8 ] , _ ranks: [ [ UInt8 ] : Int ] ) -> [ Int ] {
@@ -355,7 +314,6 @@ private extension CoreBPE {
355
314
}
356
315
return bytePairMerge ( piece, ranks, completion: { p in
357
316
let chunk = Array ( piece [ p] )
358
- let characters = chunk. map ( { Array ( Character ( Int ( $0) ) . utf8) } ) . flatMap ( { $0 } )
359
317
return ranks [ chunk] ?? 0
360
318
} )
361
319
}
@@ -369,11 +327,11 @@ private extension CoreBPE {
369
327
}
370
328
371
329
extension Array {
372
- func prevCurrent( _ body: ( Element , Element ) -> Void ) {
373
- enumerated ( ) . forEach ( { index, element in
374
- guard index > 0 else { return }
330
+ func prevCurrent< T > ( _ body: ( Element , Element ) throws -> T ) rethrows -> [ T ] {
331
+ enumerated ( ) . compactMap ( { index, element in
332
+ guard index > 0 else { return nil }
375
333
let prev = self [ index- 1 ]
376
- body ( prev, element)
334
+ return try ? body ( prev, element)
377
335
} )
378
336
}
379
337
}
0 commit comments