# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from tests.common import ASSETS

from torchtune.data import Message
from torchtune.models.qwen3 import qwen3_tokenizer

class Testqwen3Tokenizer:  # noqa: N801
    """Tests for the Qwen3 tokenizer's chat templating and decoding."""

    def tokenizer(self):
        # Tokenizer backed by the tiny BPE vocab/merges shipped in test assets.
        return qwen3_tokenizer(
            path=str(ASSETS / "tiny_bpe_vocab.json"),
            merges_file=str(ASSETS / "tiny_bpe_merges.txt"),
        )

    def test_tokenize_messages(self):
        """A plain system/user/assistant exchange tokenizes and decodes as expected."""
        tok = self.tokenizer()
        conversation = [
            Message(role="system", content="You are a helpful assistant."),
            Message(role="user", content="Give me a short introduction to LLMs."),
            Message(role="assistant", content=""),
        ]

        # fmt: off
        tokens_wanted = [
            151644, 82, 88, 479, 94, 56, 119, 230, 98, 374, 494, 1318, 249,
            13, 151645, 94, 151644, 273, 105, 94, 38, 229, 362, 98, 1695,
            310, 1305, 165, 128, 432, 43, 44, 82, 13, 151645, 94, 151644,
            397, 251, 249, 94, 151645,
        ]  # noqa
        # fmt: on

        formatted_wanted = (
            "<|im_start|>system\n"
            "You are a helpful assistant.<|im_end|>\n"
            "<|im_start|>user\n"
            "Give me a short introduction to LLMs.<|im_end|>\n"
            "<|im_start|>assistant\n"
            "<|im_end|>"
        )
        _test_tokenize_messages(
            tok,
            conversation,
            tokens_wanted,
            formatted_wanted,
        )

    def test_tool_call(self):
        """Tool-call (ipython) messages are wrapped in tool_call/tool_response tags."""
        tok = self.tokenizer()
        conversation = [
            Message(role="system", content="a"),
            Message(role="user", content="b"),
            Message(role="assistant", content="test call", ipython=True),
            Message(role="ipython", content="test response"),
            Message(role="assistant", content=""),
        ]

        # fmt: off
        tokens_wanted = [
            151644, 82, 88, 479, 94, 64, 151645, 94, 151644, 273, 105, 94,
            65, 151645, 94, 151644, 397, 251, 249, 94, 151657, 94, 83, 269,
            107, 330, 94, 151658, 151645, 94, 151644, 273, 105, 94, 151665,
            94, 83, 269, 706, 102, 182, 94, 151666, 151645, 94, 151644, 397,
            251, 249, 94, 151645,
        ]  # noqa
        # fmt: on

        formatted_wanted = (
            "<|im_start|>system\n"
            "a<|im_end|>\n"
            "<|im_start|>user\n"
            "b<|im_end|>\n"
            "<|im_start|>assistant\n"
            "<tool_call>\n"
            "test call\n"
            "</tool_call><|im_end|>\n"
            "<|im_start|>user\n"
            "<tool_response>\n"
            "test response\n"
            "</tool_response><|im_end|>\n"
            "<|im_start|>assistant\n"
            "<|im_end|>"
        )
        _test_tokenize_messages(
            tok,
            conversation,
            tokens_wanted,
            formatted_wanted,
        )

    def test_reasoning(self):
        """<think>...</think> content in an assistant message survives round-trip."""
        tok = self.tokenizer()
        conversation = [
            Message(role="system", content="You are a math assistant."),
            Message(role="user", content="What is 2 + 2?"),
            Message(
                role="assistant",
                content="<think>Using basic arithmetic</think>\n2 + 2 = 4",
            ),
        ]

        # fmt: off
        tokens_wanted = [
            151644, 82, 88, 479, 94, 56, 119, 230, 98, 1077, 71, 1318, 249,
            13, 151645, 94, 151644, 273, 105, 94, 1221, 156, 95, 17, 714,
            95, 17, 30, 151645, 94, 151644, 397, 251, 249, 94, 151667, 52,
            82, 114, 1618, 140, 410, 185, 76, 157, 140, 151668, 94, 17, 714,
            95, 17, 407, 95, 19, 151645,
        ]
        # fmt: on

        formatted_wanted = (
            "<|im_start|>system\n"
            "You are a math assistant.<|im_end|>\n"
            "<|im_start|>user\n"
            "What is 2 + 2?<|im_end|>\n"
            "<|im_start|>assistant\n"
            "<think>Using basic arithmetic</think>\n"
            "2 + 2 = 4<|im_end|>"
        )
        _test_tokenize_messages(
            tok,
            conversation,
            tokens_wanted,
            formatted_wanted,
        )

    def test_reasoning_with_tools(self):
        """A reasoning turn followed by a tool call, tool response, and answer."""
        tok = self.tokenizer()
        conversation = [
            Message(role="system", content="You are a math assistant."),
            Message(role="user", content="What is 2 + 2?"),
            Message(
                role="assistant",
                content="<think>I should use the calculator tool.</think>",
            ),
            Message(role="assistant", content="calc('2 + 2')", ipython=True),
            Message(role="ipython", content="4"),
            Message(role="assistant", content="2 + 2 = 4", eot=True),
        ]

        # fmt: off
        tokens_wanted = [
            151644, 82, 88, 479, 94, 56, 119, 230, 98, 1077, 71, 1318, 249,
            13, 151645, 94, 151644, 273, 105, 94, 1221, 156, 95, 17, 714,
            95, 17, 30, 151645, 94, 151644, 397, 251, 249, 94, 151667, 40,
            756, 462, 103, 1569, 106, 118, 1189, 13, 151668, 151645, 94,
            151644, 397, 251, 249, 94, 151657, 94, 66, 126, 66, 7, 6, 17,
            714, 95, 17, 6, 8, 94, 151658, 151645, 94, 151644, 273, 105, 94,
            151665, 94, 19, 94, 151666, 151645, 94, 151644, 397, 251, 249,
            94, 17, 714, 95, 17, 407, 95, 19, 151645,
        ]
        # fmt: on

        formatted_wanted = (
            "<|im_start|>system\n"
            "You are a math assistant.<|im_end|>\n"
            "<|im_start|>user\n"
            "What is 2 + 2?<|im_end|>\n"
            "<|im_start|>assistant\n"
            "<think>I should use the calculator tool.</think><|im_end|>\n"
            "<|im_start|>assistant\n"
            "<tool_call>\n"
            "calc('2 + 2')\n"
            "</tool_call><|im_end|>\n"
            "<|im_start|>user\n"
            "<tool_response>\n"
            "4\n"
            "</tool_response><|im_end|>\n"
            "<|im_start|>assistant\n"
            "2 + 2 = 4<|im_end|>"
        )
        _test_tokenize_messages(
            tok,
            conversation,
            tokens_wanted,
            formatted_wanted,
        )

    def test_all_tokens_work(self):
        """Every sampled token id decodes to a string, singly and in one batch."""
        tok = self.tokenizer()

        # Only a prefix of the normal vocab is checked to keep the test fast.
        sample_size = 2000
        # First special token id added in models/qwen2/_tokenizer.py.
        first_special = 151643
        # Maximum vocab size in the Qwen 3 model definitions.
        vocab_ceiling = 151936

        candidates = list(range(sample_size)) + list(
            range(first_special, vocab_ceiling)
        )

        for token_id in candidates:
            single = tok.decode([token_id], skip_special_tokens=False)
            assert isinstance(single, str)

        batch = tok.decode(candidates, skip_special_tokens=False)
        assert isinstance(batch, str)
| 199 | + |
| 200 | + |
| 201 | +def _test_tokenize_messages( |
| 202 | + tokenizer, messages, expected_tokens, expected_formatted_messages |
| 203 | +): |
| 204 | + tokens, mask = tokenizer.tokenize_messages(messages) |
| 205 | + assert len(tokens) == len(mask) |
| 206 | + assert expected_tokens == tokens |
| 207 | + formatted_messages = tokenizer.decode(tokens) |
| 208 | + assert expected_formatted_messages == formatted_messages |