Qwen3 #2669

Merged
merged 20 commits into pytorch:main on May 21, 2025

Conversation

@prvnsmpth (Contributor) commented May 3, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Issue #2645

Changelog

What are the changes made in this PR?

  • Added model builders for Qwen3
  • Added special tokens
  • Added unit tests for the tokenizer (with reasoning and tool calls)

TODO

  • Create finetuning recipes

Test plan

  • Compare forwards of torchtune impl against the reference HF impl
  • Verify tokenizer and chat template against official Qwen3 implementation from huggingface
  • Verify that all models can be loaded and the finetune runs without issues
Comparing Forwards

Stubbed out the RoPE implementations in both torchtune and the HF implementation, and ran this script to compare output logits from a single forward pass:

torchtune on  qwen3 [!?] via  v3.12.3 (venv)
❯ python compare_forwards.py
tensor(0.)

(Ran this for 0.6B, 4B and 8B models)

Any differences between the two implementations are due only to differences in the RoPE implementations.
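
For reference, a minimal sketch of what this comparison looks like (the actual compare_forwards.py is not reproduced here; the builder name, checkpoint paths, and file names below are assumptions):

    # Hypothetical sketch only -- the real compare_forwards.py may differ.
    # Builder name, checkpoint paths, and file names are assumptions.
    import torch
    from transformers import AutoModelForCausalLM
    from torchtune.models.qwen3 import qwen3_0_6b_instruct  # assumed builder name
    from torchtune.training import FullModelHFCheckpointer

    hf_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-0.6B", torch_dtype=torch.float32
    ).eval()

    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir="/path/to/Qwen3-0.6B",    # assumed local snapshot
        checkpoint_files=["model.safetensors"],  # assumed single shard
        model_type="QWEN2",
        output_dir="/tmp/qwen3_parity",
    )
    tt_model = qwen3_0_6b_instruct()
    tt_model.load_state_dict(checkpointer.load_checkpoint()["model"])
    tt_model.eval()

    tokens = torch.randint(0, 1000, (1, 32))  # arbitrary prompt tokens
    with torch.no_grad():
        hf_logits = hf_model(tokens).logits
        tt_logits = tt_model(tokens)

    # With RoPE stubbed out in both implementations, this prints tensor(0.)
    print(torch.max(torch.abs(hf_logits - tt_logits)))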

Test finetuning recipes

(All finetuning runs: https://www.comet.com/prvnsmpth/qwen3-torchtune-finetuning)

Loss curves for the following configs:

  • qwen3/0.6B_full_single_device
  • qwen3/0.6B_lora_single_device
  • qwen3/1.7B_full_single_device
  • qwen3/1.7B_lora_single_device

[image: loss curves for the configs listed above]

Checklist

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot (bot) commented May 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2669

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 3 Pending

As of commit 2c081ba with merge base e11b313:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @prvnsmpth!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@prvnsmpth mentioned this pull request May 3, 2025
@facebook-github-bot added the CLA Signed label May 3, 2025
@krammnic (Contributor) commented May 3, 2025

Looks like a good start! Thanks for the PR. Make sure that the values in the builders are correct (i.e., that they align with the config.json values from HF); at first glance they look fine. A few things that I feel are important to highlight:

  1. We don't need a QWEN3 model type for the checkpointer; let's stay with QWEN2.
  2. Don't forget to register the new recipes in _recipe_registry.py.
  3. Provide at least two sanity checks: manual runs with the wandb metric logger (just some evidence that loss is going down and the model can be loaded).

@ebsmothers (Contributor) left a comment

Thanks for the PR! I left a bunch of inline comments, but overall all the components look pretty reasonable. Two main requests around testing:

  1. Can you verify parity with the version on Hugging Face? This should just involve loading checkpoints into the HF model and the one you've implemented here, then running a forward pass. Here is an example script using Llama3. (Note that there can sometimes be some small numerical differences due to differing RoPE implementations, if this is the only gap you see then we should be good.)
  2. Let's make sure all the configs run and give reasonable loss curves. They mostly passed the eyeball test, but if you're able to run them and attach e.g. WandB runs in the test plan that'd be great. (Let me know if you need any help from us testing any of these.)

A couple other general comments (these aren't blocking for your PR, but rather things we should clean up more generally):

  1. The fact that all our Qwen models are now QWEN2 ModelType is a bit confusing. I wonder whether we should just rename to QWEN to keep things general (it's a bit different than what we do for Llama models, but Qwen model architectures seem to be quite consistent across generations). If we add MoE models, we could use e.g. QWEN_MOE or QWEN3_MOE as its own model type.
  2. Regarding instruct vs base model builders.. I think with Qwen2.5 we set the standard that e.g. qwen2_5_7b is the base model builder and qwen2_5_7b_instruct is the instruct model builder. I wonder whether we should reverse this? I.e. make the instruct model the default and delineate the base model by e.g. qwen2_5_7b_base. This would be more in line with what's on HF; also the instruct models are much more frequently used so imo providing them as the default option makes sense.

    num_kv_heads=8,
    embed_dim=2048,
    intermediate_dim=6144,
    max_seq_len=32678,

Contributor (inline review comment)

Suggested change
    max_seq_len=32678,
    max_seq_len=32768,

@prvnsmpth (Contributor Author)

Fat fingers 😞

I think I should write a script to check all the values here against the Hugging Face configs.
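
Something like this, perhaps (a rough sketch; the kwarg-to-config-field mapping is an assumption, not part of this PR):

    # Hypothetical sketch: compare torchtune builder kwargs against the
    # corresponding HF config.json values. The field mapping is an assumption.
    import json
    from urllib.request import urlopen

    FIELD_MAP = {
        "num_heads": "num_attention_heads",
        "num_kv_heads": "num_key_value_heads",
        "embed_dim": "hidden_size",
        "intermediate_dim": "intermediate_size",
        "num_layers": "num_hidden_layers",
    }

    def check_against_hf(repo_id: str, builder_kwargs: dict) -> None:
        """Print OK/MISMATCH for each builder kwarg vs the HF config.json."""
        url = f"https://huggingface.co/{repo_id}/raw/main/config.json"
        hf_cfg = json.load(urlopen(url))
        for tt_key, hf_key in FIELD_MAP.items():
            if tt_key not in builder_kwargs:
                continue
            tt_val, hf_val = builder_kwargs[tt_key], hf_cfg.get(hf_key)
            status = "OK" if tt_val == hf_val else "MISMATCH"
            print(f"{status}: {tt_key}={tt_val} vs {hf_key}={hf_val}")

    # Example using the 0.6B values quoted later in this thread:
    check_against_hf(
        "Qwen/Qwen3-0.6B",
        {"num_heads": 16, "num_kv_heads": 8, "embed_dim": 1024},
    )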

@prvnsmpth (Contributor Author)

@ebsmothers Thanks for the detailed review! I'll work on testing and update the test plan.

About this:

  1. Regarding instruct vs base model builders.. I think with Qwen2.5 we set the standard that e.g. qwen2_5_7b is the base model builder and qwen2_5_7b_instruct is the instruct model builder. I wonder whether we should reverse this? I.e. make the instruct model the default and delineate the base model by e.g. qwen2_5_7b_base. This would be more in line with what's on HF; also the instruct models are much more frequently used so imo providing them as the default option makes sense.

In Qwen2.5, we have explicit _base and _instruct suffixes:

From qwen2_5/__init__.py:

    ...
    qwen2_5_14b_base,
    qwen2_5_14b_instruct,
    qwen2_5_1_5b_base,
    qwen2_5_1_5b_instruct,
    ...

I like the explicit naming because then I don't have to look into the model builder code and figure out which variant I'm dealing with. So I decided to keep the explicit naming convention with Qwen3 as well.

Let me know if you feel strongly that we should change this; we can use suffix-less names for the instruct models and _base for the base variants.

@prvnsmpth marked this pull request as ready for review May 6, 2025 06:43
@krammnic (Contributor) commented May 6, 2025

@prvnsmpth Great work! At first glance I don't see any problems, except for a small nit about using max_files in case we have 5 or more model files. Would you mind running some sanity checks as we discussed previously?

@prvnsmpth (Contributor Author)

@krammnic Thanks! Yes, I'll update the checkpointer configs with max_files where required.

I'm working on doing parity checks against HF and doing a fine-tuning run to check loss curves - will update the PR once done.

@prvnsmpth (Contributor Author)

@ebsmothers @krammnic I ran into an issue while loading the model weights due to head_dim mismatch. We compute head_dim like so:

    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads

This doesn't work for Qwen3. These are the values defined in the config (0.6B model):

  "head_dim": 128,
  "hidden_size": 1024,
  "model_type": "qwen3",
  "num_attention_heads": 16,
  "num_key_value_heads": 8,

So in our case, head_dim evaluates to (1024 // 16 =) 64 instead of 128, and the checkpoint loading fails.

I suggest adding a new optional head_dim param in the qwen2 component builder, so we can do this:

head_dim = head_dim or embed_dim // num_heads

This mirrors how HF transformers does it (modeling_qwen3.py):

self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)

Does this sound good, or is there a better way to handle this?
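
For concreteness, a rough sketch of what that change could look like (signature heavily trimmed; not the exact diff):

    # Rough sketch of the proposed change to the qwen2 component builder
    # (most arguments omitted; the actual diff may differ).
    from typing import Optional

    def qwen2(
        *,
        vocab_size: int,
        num_layers: int,
        num_heads: int,
        num_kv_heads: int,
        embed_dim: int,
        head_dim: Optional[int] = None,  # new: explicit override for Qwen3
        # ... remaining args unchanged ...
    ):
        # Existing behavior is preserved when head_dim isn't passed; Qwen3
        # builders would pass it explicitly (e.g. head_dim=128 for 0.6B,
        # where embed_dim // num_heads would give 1024 // 16 = 64).
        head_dim = head_dim or embed_dim // num_heads
        num_kv_heads = num_kv_heads if num_kv_heads else num_heads
        ...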

@krammnic (Contributor) commented May 6, 2025

@prvnsmpth Yes, this sounds acceptable to me.

@prvnsmpth (Contributor Author)

@ebsmothers Do we know how much of a variance in the output logits is expected due to the differences in RoPE implementation?

Here's what I observe:

>>> pos = 0
>>> torch.max(torch.abs(hf_out.logits[0][pos] - tt_out[0][pos]))
tensor(0.)

>>> pos = 10
>>> torch.max(torch.abs(hf_out.logits[0][pos] - tt_out[0][pos]))
tensor(6.4603)

>>> torch.median(torch.abs(hf_out.logits[0] - tt_out[0]), axis=1).values
tensor([0.0000, 0.4157, 0.4001, 0.8757, 2.4013, 1.3230, 1.4447, 1.3762, 0.9497,
        0.8879, 0.8413, 0.9732, 1.5417, 1.0064, 1.8529, 1.9433, 1.2108, 1.4595,
        1.2758, 1.5974, 1.3994, 1.2803, 3.1988, 2.5264, 1.6579, 1.4070, 1.3072, ...

>>> torch.max(torch.abs(hf_out.logits[0] - tt_out[0]), axis=1).values
tensor([ 0.0000,  2.8457,  2.9103,  6.2893, 16.3794, 10.7218, 10.5545,  9.7847,
         6.7109,  7.1883,  6.8754,  8.0629, 11.5977,  8.3435, 14.6050, 10.5954,
         8.3610, 12.5571,  9.7235, 11.8763, 12.1868, 10.1667, 22.2543, 17.2351, ...

The output logits match exactly only at position 0; at the other positions the median difference is small, but the max difference is quite high in some cases. I'm looking into the RoPE implementation in both places, but just wanted to check whether this is expected.

@ebsmothers (Contributor)

@prvnsmpth yes, this is not unreasonable. To test the hypothesis around RoPE, one suggestion is simply to patch the RoPE components to rule them out as the cause of the difference (e.g., replace them with nn.Identity() in both implementations). Then we can separately verify that RoPE is correct (though I think it should be, since IIUC it's the same as the Llama implementation). I can also take a proper look in a few hours.
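
Something like the following, for illustration (the attribute paths are assumptions about where the rotary embeddings live in each implementation):

    # Hypothetical sketch of stubbing RoPE before comparing forward passes.
    # tt_model is the torchtune Qwen3 model instantiated as in the comparison
    # script; the attribute paths below are assumptions.
    import torch.nn as nn

    class PassThroughRoPE(nn.Module):
        """Ignore position information and return the input unchanged."""
        def forward(self, x, *args, **kwargs):
            return x

    # torchtune side: each attention module owns its pos_embeddings.
    for layer in tt_model.layers:
        layer.attn.pos_embeddings = PassThroughRoPE()  # assumed attribute path

    # The HF side needs an analogous stub wherever its rotary embedding is
    # applied (the exact hook point depends on the transformers version).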

@ebsmothers (Contributor)

OK @prvnsmpth getting back to this.. monkey-patching both RoPE implementations as in this script yields identical outputs, so I think we should be good here in terms of forward parity (I only checked 0.6B, we can run for the other sizes as well just to be safe). The main remaining thing here should be verifying that the configs run end-to-end -- as @krammnic mentioned you will want to add each new config to _recipe_registry.py. Lmk if you need any assistance with any of this

@prvnsmpth (Contributor Author)

@ebsmothers Yes, I verified this as well, the outputs are identical if we stub out the RoPE impl. Will check the other model sizes too.

I'll post an update after I test the fine-tuning configs end-to-end. Thanks!

@prvnsmpth (Contributor Author)

@ebsmothers @krammnic I've updated the test plan with loss curves for the 0.6B and 1.7B models; they look OK to me. I had to make a few more fixes, especially in the LoRA component builders.

I couldn't test the larger models, as I don't have enough VRAM on my machine to run them. Do we have any GPUs available where we can test the configs for the larger models?

@krammnic (Contributor)

@prvnsmpth Sure! Ping me when all the required fixes are done, and I will run tests for the larger models.

@prvnsmpth (Contributor Author)

@krammnic Thanks! I've already pushed the fixes. Could you please test the recipes for the 4B and larger models?

@ebsmothers (Contributor)

Thanks @prvnsmpth! I will run some tests today as well to make sure everything looks good. Will share the results when I have them.

@neelsjain commented May 14, 2025

Hello! I found a weird edge case where checkpointing during training breaks the pipeline. Originally there were two safetensors shards, but it turns out the second shard only contains the lm_head, and since the embeddings and lm_head are tied, the save process only writes one shard. The current code assumes the number of shards stays the same.

I'm not sure what the best way to fix this is. One option might be to override self._weight_map in FullModelHFCheckpointer.save_checkpoint, but I'm not sure if that will break other things.

Thanks again for all your hard work!

@prvnsmpth (Contributor Author)

@neelsjain I haven't dug into the checkpointer code to fully understand the issue you pointed out, but is this specific to Qwen or is this an issue that you'd run into with any architecture that uses weight tying? Seems to me that it's the latter?

@neelsjain

It's specific to Qwen3-1.7B (https://huggingface.co/Qwen/Qwen3-1.7B), not even all the Qwen3s. However, this problem will likely arise with any model that is saved with only the lm_head in a separate shard. I'll see if I can come up with a good solution for this soon, but I think the torchtune people know better than I do how this repo works.

@krammnic (Contributor)

So it seems we have a blocker here, though it's probably not a blocker for most of the recipes and models in the family. I assume we can still run some tests. @ebsmothers

@joecummings (Contributor)

Hmmm it's strange that they have a key for lm_head.weight when they tie word embeddings. This is different from their implementation for Qwen2.

The quick and dirty solution is to modify the tune_to_hf() call for Qwen3 to create the lm_head.weight key to be the same as the embeddings key on save.
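
Roughly along these lines (a sketch only; the key-remapping helper and the exact key names are assumptions based on the discussion here):

    # Hypothetical sketch of the "quick and dirty" fix: when converting a
    # Qwen3 state dict to HF format on save, re-create lm_head.weight from
    # the tied token embeddings. Key names and the helper are assumptions.
    def remap_key(key: str) -> str:
        # Placeholder for the existing tune -> HF key remapping.
        return key

    def qwen3_tune_to_hf(state_dict: dict, tie_word_embeddings: bool = True) -> dict:
        converted = {remap_key(k): v for k, v in state_dict.items()}
        if tie_word_embeddings and "lm_head.weight" not in converted:
            # Point lm_head at the (tied) token embeddings on save.
            converted["lm_head.weight"] = converted["model.embed_tokens.weight"]
        return converted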

@ebsmothers (Contributor)

@prvnsmpth @neelsjain @krammnic apologies for the delay here as I was tied up with a few other things.

Regarding the 1.7B model, following up on @joecummings's suggestion, can we modify the tune_to_hf function to manually add back the lm_head weight when we save the checkpoint? E.g. I made this change and can confirm the checkpoint saves. As he pointed out, I guess they do this differently for Qwen2 vs Qwen3. In that case, maybe we should just create a separate QWEN3 ModelType to handle the discrepancy (not ideal, but this seems like the cleanest way). It can keep the same behavior of Qwen2 except for the handling of the tied weight save. Lmk if this sounds reasonable, and I will run the remaining tests today.

@joecummings mentioned this pull request Mar 30, 2025
@ebsmothers (Contributor)

Hi @prvnsmpth I took the liberty of pushing some changes to your PR. It seems that there are some inconsistencies in which Qwen models have lm_head.weight on the Hub. To summarize:

  • Qwen2: no tied-embedding models have lm_head.weight
  • Qwen3:
    • All non-4B tied-embedding models have lm_head.weight
    • 4B model (which is also tied-embedding) does not have lm_head.weight

I think the missing lm_head.weight for 4B is a bug. But as a workaround, I added a separate Qwen3 weight conversion (basically just a proxy for whether lm_head.weight is present on the Hub or not); for Qwen3-4B we can then fall back to the QWEN2 model type. A bit hacky, but at least this way we can support all the models.
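
In practice the split looks roughly like this at the checkpointer level (a sketch; paths and file lists are placeholders, and the QWEN3 model-type name follows the description above):

    # Hypothetical sketch illustrating the split described above.
    from torchtune.training import FullModelHFCheckpointer

    # Non-4B tied-embedding Qwen3 models: lm_head.weight exists on the Hub,
    # so use the Qwen3-specific conversion when saving.
    ckpt_1_7b = FullModelHFCheckpointer(
        checkpoint_dir="/path/to/Qwen3-1.7B",
        checkpoint_files=["model.safetensors"],  # placeholder file list
        model_type="QWEN3",
        output_dir="/tmp/qwen3_1_7b",
    )

    # Qwen3-4B: no lm_head.weight on the Hub, so fall back to QWEN2 behavior.
    ckpt_4b = FullModelHFCheckpointer(
        checkpoint_dir="/path/to/Qwen3-4B",
        checkpoint_files=["model.safetensors"],  # placeholder file list
        model_type="QWEN2",
        output_dir="/tmp/qwen3_4b",
    )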

I have also finished up (most) testing. Runs of the various configs can be found here. I have not yet run the KD or eval configs successfully, once we confirm those are working I think we can merge.

@ebsmothers (Contributor) left a comment

Thanks @prvnsmpth for your patience here! Really excited to be able to land this one. After testing everything, I think we're good to go. Will let CI run and then we can merge.

@prvnsmpth (Contributor Author)

@ebsmothers Thanks a lot for helping with this PR!

@ebsmothers merged commit eb01f07 into pytorch:main May 21, 2025
14 checks passed