Skip to content

Commit c1949da

Browse files
committed
Add model
1 parent 65a7883 commit c1949da

File tree

5 files changed

+145
-2
lines changed

5 files changed

+145
-2
lines changed

.gitignore

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@
44
/*.xcodeproj
55
xcuserdata/
66
DerivedData/
7-
.swiftpm/config/registries.json
8-
.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
7+
.swiftpm
98
.netrc

Sources/Tiktoken/Model.swift

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
//
2+
// Model.swift
3+
//
4+
//
5+
// Created by Alberto Espinilla Garrido on 20/3/23.
6+
//
7+
8+
import Foundation
9+
10+
/// Namespace for resolving a model name to the name of its encoding.
enum Model {
    /// Returns the encoding name associated with a model name.
    ///
    /// An exact entry in `MODEL_TO_ENCODING` takes precedence; only when there is
    /// none is the prefix lookup in `findPrefix(with:)` consulted.
    ///
    /// - Parameter name: The model name to look up, e.g. `"gpt-4"`.
    /// - Returns: The encoding name, or `nil` when neither lookup matches.
    static func getEncoding(_ name: String) -> String? {
        MODEL_TO_ENCODING[name] ?? findPrefix(with: name)
    }
}
18+
19+
// MARK: - Lookup tables

private extension Model {
    /// Maps model-name prefixes to encoding names.
    ///
    /// Consulted by `findPrefix(with:)` for names that carry a suffix after one
    /// of these prefixes.
    static let MODEL_PREFIX_TO_ENCODING: [String: String] = [
        // chat
        "gpt-4-": "cl100k_base", // e.g., gpt-4-0314, etc., plus gpt-4-32k
        "gpt-3.5-turbo-": "cl100k_base", // e.g., gpt-3.5-turbo-0301, -0401, etc.
    ]

    /// Maps exact model names to encoding names.
    static let MODEL_TO_ENCODING: [String: String] = [
        // chat
        "gpt-4": "cl100k_base",
        "gpt-3.5-turbo": "cl100k_base",
        // text
        "text-davinci-003": "p50k_base",
        "text-davinci-002": "p50k_base",
        "text-davinci-001": "r50k_base",
        "text-curie-001": "r50k_base",
        "text-babbage-001": "r50k_base",
        "text-ada-001": "r50k_base",
        "davinci": "r50k_base",
        "curie": "r50k_base",
        "babbage": "r50k_base",
        "ada": "r50k_base",
        // code
        "code-davinci-002": "p50k_base",
        "code-davinci-001": "p50k_base",
        "code-cushman-002": "p50k_base",
        "code-cushman-001": "p50k_base",
        "davinci-codex": "p50k_base",
        "cushman-codex": "p50k_base",
        // edit
        "text-davinci-edit-001": "p50k_edit",
        "code-davinci-edit-001": "p50k_edit",
        // embeddings
        "text-embedding-ada-002": "cl100k_base",
        // old embeddings
        "text-similarity-davinci-001": "r50k_base",
        "text-similarity-curie-001": "r50k_base",
        "text-similarity-babbage-001": "r50k_base",
        "text-similarity-ada-001": "r50k_base",
        "text-search-davinci-doc-001": "r50k_base",
        "text-search-curie-doc-001": "r50k_base",
        "text-search-babbage-doc-001": "r50k_base",
        "text-search-ada-doc-001": "r50k_base",
        "code-search-babbage-code-001": "r50k_base",
        "code-search-ada-code-001": "r50k_base",
        // open source
        "gpt2": "gpt2",
    ]

    /// Returns the encoding name for the longest known prefix that `name` starts with.
    ///
    /// Dictionary iteration order is unspecified, so taking the first matching key
    /// would give run-dependent results as soon as two prefixes overlap. Preferring
    /// the longest match keeps the result deterministic; two distinct prefixes of
    /// the same name can never have the same length, so the maximum is unique.
    ///
    /// - Parameter name: The model name to match against the known prefixes.
    /// - Returns: The encoding name of the longest matching prefix, or `nil` when
    ///   no prefix matches.
    static func findPrefix(with name: String) -> String? {
        let longestMatch = MODEL_PREFIX_TO_ENCODING
            .filter { name.starts(with: $0.key) }
            .max { $0.key.count < $1.key.count }
        return longestMatch?.value
    }
}

Sources/Tiktoken/Tiktoken.swift

+14
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,18 @@ public struct Tiktoken {
33

44
public init() {
55
}
6+
7+
/// Stub: always returns `nil`.
///
/// - Parameter name: Currently unused.
/// - Returns: `nil`.
public func getEncoding(_ name: String) -> Encoding? {
    return nil
}
10+
11+
/// Stub: always returns `nil`.
///
/// - Parameter model: Currently unused.
/// - Returns: `nil`.
public func getEncoding(for model: String) -> Encoding? {
    return nil
}
14+
}
15+
16+
17+
/// A type that converts between strings and integer arrays.
///
/// NOTE(review): the integers are presumably token identifiers — confirm
/// against the conforming types.
public protocol Encoding {
    /// Converts a string into an array of integers.
    ///
    /// - Parameter value: The string to encode.
    /// - Returns: The integers produced for `value`.
    func encode(value: String) -> [Int]

    /// Converts an array of integers back into a string.
    ///
    /// - Parameter value: The integers to decode.
    /// - Returns: The string produced for `value`.
    func decode(value: [Int]) -> String
}

Tests/TiktokenTests/ModelTests.swift

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//
2+
// ModelTests.swift
3+
//
4+
//
5+
// Created by Alberto Espinilla Garrido on 20/3/23.
6+
//
7+
8+
import XCTest
9+
@testable import Tiktoken
10+
11+
/// Tests for the model-name to encoding-name lookup in `Model.getEncoding(_:)`.
final class ModelTests: XCTestCase {

    /// Model names that have an exact table entry resolve to that entry's encoding.
    func testGivenModelNamesWhenGetEncodingThenMatch() throws {
        let cases = [
            Test(input: "gpt-4", output: "cl100k_base"),
            Test(input: "gpt-3.5-turbo", output: "cl100k_base"),
            Test(input: "davinci", output: "r50k_base"),
            Test(input: "text-davinci-edit-001", output: "p50k_edit"),
        ]
        for testCase in cases {
            let encoding = Model.getEncoding(testCase.input)
            // Unwrapping inside the assertion records a nil as a failure for this
            // case without aborting the remaining cases.
            XCTAssertEqual(try XCTUnwrap(encoding), testCase.output)
        }
    }

    /// Model names that only match a known prefix resolve to the prefix's encoding.
    func testGivenModelNamesWithPrefisWhenGetEncodingThenMatch() throws {
        let cases = [
            Test(input: "gpt-4-0314", output: "cl100k_base"),
            Test(input: "gpt-4-32k", output: "cl100k_base"),
            Test(input: "gpt-3.5-turbo-0301", output: "cl100k_base"),
            Test(input: "gpt-3.5-turbo-0401", output: "cl100k_base"),
        ]
        for testCase in cases {
            let encoding = Model.getEncoding(testCase.input)
            XCTAssertEqual(try XCTUnwrap(encoding), testCase.output)
        }
    }

    /// Names that match neither an exact entry nor a prefix yield `nil`.
    func testGivenUnknowModelNamesWhenGetEncodingThenMatchNil() throws {
        let unknownNames = ["sample", "chatgpt", "invalid", "test"]
        for name in unknownNames {
            XCTAssertNil(Model.getEncoding(name))
        }
    }
}

Tests/TiktokenTests/Test.swift

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//
2+
// Test.swift
3+
//
4+
//
5+
// Created by Alberto Espinilla Garrido on 20/3/23.
6+
//
7+
8+
import Foundation
9+
10+
/// A table-driven test case pairing an input with the output expected for it.
struct Test<Input, Output> {
    /// The value handed to the code under test.
    let input: Input

    /// The value the code under test is expected to produce for `input`.
    let output: Output
}

0 commit comments

Comments
 (0)