Skip to content

Commit 83d4947

Browse files
author
Praveen Sampath
committed
Qwen3
1 parent d39fd9b commit 83d4947

File tree

6 files changed: +1415 additions, −3 deletions
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from tests.common import ASSETS
8+
9+
from torchtune.data import Message
10+
from torchtune.models.qwen3 import qwen3_tokenizer
11+
12+
13+
class Testqwen3Tokenizer: # noqa: N801
14+
def tokenizer(self):
15+
return qwen3_tokenizer(
16+
path=str(ASSETS / "tiny_bpe_vocab.json"),
17+
merges_file=str(ASSETS / "tiny_bpe_merges.txt"),
18+
)
19+
20+
def test_tokenize_messages(self):
21+
tokenizer = self.tokenizer()
22+
messages = [
23+
Message(role="system", content="You are a helpful assistant."),
24+
Message(role="user", content="Give me a short introduction to LLMs."),
25+
Message(role="assistant", content=""),
26+
]
27+
28+
# fmt: off
29+
expected_tokens = [
30+
151644, 82, 88, 479, 94, 56, 119, 230, 98, 374, 494, 1318, 249, 13, 151645, 94, 151644, 273, 105, 94,
31+
38, 229, 362, 98, 1695, 310, 1305, 165, 128, 432, 43, 44, 82, 13, 151645, 94, 151644, 397, 251, 249, 94,
32+
151645,
33+
] # noqa
34+
# fmt: on
35+
36+
expected_formatted_messages = (
37+
"<|im_start|>system\n"
38+
"You are a helpful assistant.<|im_end|>\n"
39+
"<|im_start|>user\n"
40+
"Give me a short introduction to LLMs.<|im_end|>\n"
41+
"<|im_start|>assistant\n"
42+
"<|im_end|>"
43+
)
44+
_test_tokenize_messages(
45+
tokenizer,
46+
messages,
47+
expected_tokens,
48+
expected_formatted_messages,
49+
)
50+
51+
def test_tool_call(self):
52+
tokenizer = self.tokenizer()
53+
messages = [
54+
Message(role="system", content="a"),
55+
Message(role="user", content="b"),
56+
Message(role="assistant", content="test call", ipython=True),
57+
Message(role="ipython", content="test response"),
58+
Message(role="assistant", content=""),
59+
]
60+
# fmt: off
61+
expected_tokens = [
62+
151644, 82, 88, 479, 94, 64, 151645, 94, 151644, 273, 105, 94, 65, 151645, 94, 151644, 397, 251, 249,
63+
94, 151657, 94, 83, 269, 107, 330, 94, 151658, 151645, 94, 151644, 273, 105, 94, 151665,
64+
94, 83, 269, 706, 102, 182, 94, 151666, 151645, 94, 151644, 397, 251, 249, 94, 151645,
65+
] # noqa
66+
# fmt: on
67+
68+
expected_formatted_messages = (
69+
"<|im_start|>system\n"
70+
"a<|im_end|>\n"
71+
"<|im_start|>user\n"
72+
"b<|im_end|>\n"
73+
"<|im_start|>assistant\n"
74+
"<tool_call>\n"
75+
"test call\n"
76+
"</tool_call><|im_end|>\n"
77+
"<|im_start|>user\n"
78+
"<tool_response>\n"
79+
"test response\n"
80+
"</tool_response><|im_end|>\n"
81+
"<|im_start|>assistant\n"
82+
"<|im_end|>"
83+
)
84+
_test_tokenize_messages(
85+
tokenizer,
86+
messages,
87+
expected_tokens,
88+
expected_formatted_messages,
89+
)
90+
91+
def test_reasoning(self):
92+
tokenizer = self.tokenizer()
93+
messages = [
94+
Message(role="system", content="You are a math assistant."),
95+
Message(role="user", content="What is 2 + 2?"),
96+
Message(
97+
role="assistant",
98+
content="<think>Using basic arithmetic</think>\n2 + 2 = 4",
99+
),
100+
]
101+
# fmt: off
102+
expected_tokens = [
103+
151644, 82, 88, 479, 94, 56, 119, 230, 98, 1077, 71, 1318, 249, 13,
104+
151645, 94, 151644, 273, 105, 94, 1221, 156, 95, 17, 714, 95, 17,
105+
30, 151645, 94, 151644, 397, 251, 249, 94, 151667, 52, 82, 114,
106+
1618, 140, 410, 185, 76, 157, 140, 151668, 94, 17, 714, 95, 17, 407,
107+
95, 19, 151645
108+
]
109+
# fmt: on
110+
111+
expected_formatted_messages = (
112+
"<|im_start|>system\n"
113+
"You are a math assistant.<|im_end|>\n"
114+
"<|im_start|>user\n"
115+
"What is 2 + 2?<|im_end|>\n"
116+
"<|im_start|>assistant\n"
117+
"<think>Using basic arithmetic</think>\n"
118+
"2 + 2 = 4<|im_end|>"
119+
)
120+
_test_tokenize_messages(
121+
tokenizer,
122+
messages,
123+
expected_tokens,
124+
expected_formatted_messages,
125+
)
126+
127+
def test_reasoning_with_tools(self):
128+
tokenizer = self.tokenizer()
129+
messages = [
130+
Message(role="system", content="You are a math assistant."),
131+
Message(role="user", content="What is 2 + 2?"),
132+
Message(
133+
role="assistant",
134+
content="<think>I should use the calculator tool.</think>",
135+
),
136+
Message(role="assistant", content="calc('2 + 2')", ipython=True),
137+
Message(role="ipython", content="4"),
138+
Message(role="assistant", content="2 + 2 = 4", eot=True),
139+
]
140+
141+
# fmt: off
142+
expected_tokens = [
143+
151644, 82, 88, 479, 94, 56, 119, 230, 98, 1077, 71,
144+
1318, 249, 13, 151645, 94, 151644, 273, 105, 94, 1221, 156, 95, 17, 714,
145+
95, 17, 30, 151645, 94, 151644, 397, 251, 249, 94, 151667, 40, 756, 462,
146+
103, 1569, 106, 118, 1189, 13, 151668, 151645, 94, 151644, 397, 251,
147+
249, 94, 151657, 94, 66, 126, 66, 7, 6, 17, 714, 95, 17, 6, 8, 94,
148+
151658, 151645, 94, 151644, 273, 105, 94, 151665, 94, 19, 94, 151666,
149+
151645, 94, 151644, 397, 251, 249, 94, 17, 714, 95, 17, 407, 95, 19,
150+
151645
151+
]
152+
# fmt: on
153+
154+
expected_formatted_messages = (
155+
"<|im_start|>system\n"
156+
"You are a math assistant.<|im_end|>\n"
157+
"<|im_start|>user\n"
158+
"What is 2 + 2?<|im_end|>\n"
159+
"<|im_start|>assistant\n"
160+
"<think>I should use the calculator tool.</think><|im_end|>\n"
161+
"<|im_start|>assistant\n"
162+
"<tool_call>\n"
163+
"calc('2 + 2')\n"
164+
"</tool_call><|im_end|>\n"
165+
"<|im_start|>user\n"
166+
"<tool_response>\n"
167+
"4\n"
168+
"</tool_response><|im_end|>\n"
169+
"<|im_start|>assistant\n"
170+
"2 + 2 = 4<|im_end|>"
171+
)
172+
173+
_test_tokenize_messages(
174+
tokenizer,
175+
messages,
176+
expected_tokens,
177+
expected_formatted_messages,
178+
)
179+
180+
def test_all_tokens_work(self):
181+
# Check if all tokens can be detokenized, separately and together
182+
tokenizer = self.tokenizer()
183+
184+
num_tokens_small = 2000
185+
num_normal_tokens = 151643 # Based on the first special token added in models/qwen2/_tokenizer.py
186+
num_all_tokens = 151936 # Based on the maximum vocab size in Qwen 3 definitions
187+
188+
normal_tokens = list(range(num_tokens_small))
189+
special_tokens = list(range(num_normal_tokens, num_all_tokens))
190+
191+
all_tokens = normal_tokens + special_tokens
192+
193+
for token in all_tokens:
194+
decoded = tokenizer.decode([token], skip_special_tokens=False)
195+
assert isinstance(decoded, str)
196+
decoded = tokenizer.decode(all_tokens, skip_special_tokens=False)
197+
198+
assert isinstance(decoded, str)
199+
200+
201+
def _test_tokenize_messages(
202+
tokenizer, messages, expected_tokens, expected_formatted_messages
203+
):
204+
tokens, mask = tokenizer.tokenize_messages(messages)
205+
assert len(tokens) == len(mask)
206+
assert expected_tokens == tokens
207+
formatted_messages = tokenizer.decode(tokens)
208+
assert expected_formatted_messages == formatted_messages

torchtune/models/qwen2/_component_builders.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ def qwen2(
4848
norm_eps: float = 1e-5,
4949
rope_base: float = 1_000_000.0,
5050
tie_word_embeddings: bool = False,
51+
q_proj_bias: bool = True,
52+
k_proj_bias: bool = True,
53+
v_proj_bias: bool = True,
54+
q_norm: bool = True,
55+
k_norm: bool = True,
5156
) -> TransformerDecoder:
5257
"""
5358
Build the decoder associated with the Qwen2 model. This includes:
@@ -74,6 +79,11 @@ def qwen2(
7479
norm_eps (float): epsilon in RMS norms.
7580
rope_base (float): the base period of the RoPE embeddings.
7681
tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.
82+
q_proj_bias (bool): whether to use bias in the query projection.
83+
k_proj_bias (bool): whether to use bias in the key projection.
84+
v_proj_bias (bool): whether to use bias in the value projection.
85+
q_norm (bool): whether to use normalization in the query projection.
86+
k_norm (bool): whether to use normalization in the key projection.
7787
7888
Returns:
7989
TransformerDecoder: Instantiation of Qwen2 model.
@@ -90,11 +100,13 @@ def qwen2(
90100
num_heads=num_heads,
91101
num_kv_heads=num_kv_heads,
92102
head_dim=head_dim,
93-
q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=True),
94-
k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True),
95-
v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True),
103+
q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=q_proj_bias),
104+
k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=k_proj_bias),
105+
v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=v_proj_bias),
96106
output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
97107
pos_embeddings=rope,
108+
q_norm=nn.RMSNorm(eps=norm_eps) if q_norm else None,
109+
k_norm=nn.RMSNorm(eps=norm_eps) if k_norm else None,
98110
kv_cache=None,
99111
max_seq_len=max_seq_len,
100112
attn_dropout=attn_dropout,

torchtune/models/qwen3/__init__.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from ._model_builders import (
8+
lora_qwen3_0_6b_base,
9+
lora_qwen3_0_6b_instruct,
10+
lora_qwen3_14b_base,
11+
lora_qwen3_14b_instruct,
12+
lora_qwen3_1_7b_base,
13+
lora_qwen3_1_7b_instruct,
14+
lora_qwen3_32b,
15+
lora_qwen3_4b_base,
16+
lora_qwen3_4b_instruct,
17+
lora_qwen3_8b_base,
18+
lora_qwen3_8b_instruct,
19+
qwen3_0_6b_base,
20+
qwen3_0_6b_instruct,
21+
qwen3_14b_base,
22+
qwen3_14b_instruct,
23+
qwen3_1_7b_base,
24+
qwen3_1_7b_instruct,
25+
qwen3_32b,
26+
qwen3_4b_base,
27+
qwen3_4b_instruct,
28+
qwen3_8b_base,
29+
qwen3_8b_instruct,
30+
qwen3_tokenizer,
31+
)
32+
33+
__all__ = [
34+
"lora_qwen3_0_6b_base",
35+
"lora_qwen3_0_6b_instruct",
36+
"lora_qwen3_1_7b_base",
37+
"lora_qwen3_1_7b_instruct",
38+
"lora_qwen3_4b_base",
39+
"lora_qwen3_4b_instruct",
40+
"lora_qwen3_8b_base",
41+
"lora_qwen3_8b_instruct",
42+
"lora_qwen3_14b_base",
43+
"lora_qwen3_14b_instruct",
44+
"lora_qwen3_32b",
45+
"qwen3_0_6b_base",
46+
"qwen3_0_6b_instruct",
47+
"qwen3_1_7b_base",
48+
"qwen3_1_7b_instruct",
49+
"qwen3_4b_base",
50+
"qwen3_4b_instruct",
51+
"qwen3_8b_base",
52+
"qwen3_8b_instruct",
53+
"qwen3_14b_base",
54+
"qwen3_14b_instruct",
55+
"qwen3_32b",
56+
"qwen3_tokenizer",
57+
]

0 commit comments

Comments
 (0)