Qwen3 #2669
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2669
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 3 Pending (as of commit 2c081ba with merge base e11b313).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @prvnsmpth! Thank you for your pull request and welcome to our community.
Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.
Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA Signed.
If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Looks like a good start! Thanks for the PR. Make sure that the values in the builders are correct (that they align with the reference Hugging Face configs for each model size).
Thanks for the PR! I left a bunch of inline comments, but overall all the components look pretty reasonable. Two main requests around testing:
- Can you verify parity with the version on Hugging Face? This should just involve loading checkpoints into the HF model and the one you've implemented here, then running a forward pass. Here is an example script using Llama3. (Note that there can sometimes be some small numerical differences due to differing RoPE implementations, if this is the only gap you see then we should be good.)
- Let's make sure all the configs run and give reasonable loss curves. They mostly passed the eyeball test, but if you're able to run them and attach e.g. WandB runs in the test plan that'd be great. (Let me know if you need any help from us testing any of these.)
A couple other general comments (these aren't blocking for your PR, but rather things we should clean up more generally):
- The fact that all our Qwen models now use the `QWEN2` ModelType is a bit confusing. I wonder whether we should just rename it to `QWEN` to keep things general (it's a bit different from what we do for Llama models, but Qwen model architectures seem to be quite consistent across generations). If we add MoE models, we could use e.g. `QWEN_MOE` or `QWEN3_MOE` as its own model type.
- Regarding instruct vs base model builders: I think with Qwen2.5 we set the standard that e.g. `qwen2_5_7b` is the base model builder and `qwen2_5_7b_instruct` is the instruct model builder. I wonder whether we should reverse this, i.e. make the instruct model the default and delineate the base model by e.g. `qwen2_5_7b_base`. This would be more in line with what's on HF; also, the instruct models are much more frequently used, so imo providing them as the default option makes sense.
```python
num_kv_heads=8,
embed_dim=2048,
intermediate_dim=6144,
max_seq_len=32678,
```
Suggested change:
```diff
- max_seq_len=32678,
+ max_seq_len=32768,
```
Fat fingers 😞
I think I should write a script to check all the values here against huggingface configs.
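Something along these lines would do it; a minimal sketch, assuming the mapping below between torchtune builder kwargs and HF config.json keys (the mapping and the example values are illustrative, not pulled from the final builders):

```python
import json
from urllib.request import urlopen

# Assumed mapping from torchtune builder kwargs to HF config.json keys.
KEY_MAP = {
    "num_layers": "num_hidden_layers",
    "num_heads": "num_attention_heads",
    "num_kv_heads": "num_key_value_heads",
    "embed_dim": "hidden_size",
    "intermediate_dim": "intermediate_size",
    "max_seq_len": "max_position_embeddings",
}

def check_builder(builder_kwargs: dict, hf_repo: str) -> None:
    """Flag any builder value that disagrees with the HF config.json."""
    url = f"https://huggingface.co/{hf_repo}/resolve/main/config.json"
    hf_config = json.load(urlopen(url))
    for tune_key, hf_key in KEY_MAP.items():
        if tune_key in builder_kwargs and hf_key in hf_config:
            if builder_kwargs[tune_key] != hf_config[hf_key]:
                print(f"{hf_repo}: {tune_key}={builder_kwargs[tune_key]} "
                      f"!= {hf_key}={hf_config[hf_key]}")

# Example usage with the values from the snippet above (illustrative):
check_builder(
    {"num_kv_heads": 8, "embed_dim": 2048, "intermediate_dim": 6144, "max_seq_len": 32768},
    "Qwen/Qwen3-1.7B",
)
```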
@ebsmothers Thanks for the detailed review! I'll work on testing and update the test plan. About this:
In Qwen2.5, we have explicit builder names for both variants (the base and instruct builders are spelled out separately in the Qwen2.5 model builders).
I like the explicit naming because then I don't have to look into the model builder code to figure out which variant I'm dealing with, so I decided to keep the explicit naming convention with Qwen3 as well. Let me know if you feel strongly that we should change this; we can use suffix-less names for the instruct models and a `_base` suffix for the base models.
@prvnsmpth Great work! At first glance I don't see any problems, except for the small nit about using `max_files` in case we have 5 or more model files. Would you mind attempting the sanity checks we discussed previously?
@krammnic Thanks! Yes, I'll update the checkpointer configs with that. I'm working on the parity checks against HF and a fine-tuning run to check loss curves; I will update the PR once done.
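For reference, a sharded HF checkpoint can be pointed at via a filename format instead of listing every shard; a rough sketch of the config stanza (directory, shard count, and model type below are placeholders, assuming the existing torchtune `filename_format` / `max_filename` convention):

```yaml
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen3-8B          # placeholder path
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00005"                # e.g. 5 shards
  model_type: QWEN3                      # whichever ModelType this PR settles on
  output_dir: ${output_dir}
```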
@ebsmothers @krammnic I ran into an issue while loading the model weights: we derive the attention head dimension as `embed_dim // num_heads`, and this doesn't work for Qwen3. These are the relevant values defined in the config (0.6B model): `hidden_size=1024`, `num_attention_heads=16`, `head_dim=128`, so `1024 // 16 = 64` does not match the actual head dimension of 128.
So in our case, I suggest adding a new optional `head_dim` argument to the model builders that overrides the `embed_dim // num_heads` default when provided.
This mirrors how HF transformers does it (modeling_qwen3.py).
Does this sound good, or is there a better way to handle this?
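To make the proposal concrete, here is a rough sketch of the kind of override I mean (the function name and call sites are illustrative, not the final API):

```python
from typing import Optional

def resolve_head_dim(
    embed_dim: int,
    num_heads: int,
    head_dim: Optional[int] = None,
) -> int:
    """Return the per-head dimension used for the q/k/v projections.

    If ``head_dim`` is given explicitly (as in the Qwen3 HF configs), use it;
    otherwise fall back to the usual ``embed_dim // num_heads``.
    """
    return head_dim if head_dim is not None else embed_dim // num_heads

# Qwen3-0.6B: embed_dim=1024, num_heads=16, but head_dim=128 in the HF config,
# so the q-projection has shape (num_heads * 128, embed_dim) rather than
# (embed_dim, embed_dim).
assert resolve_head_dim(1024, 16, head_dim=128) == 128
assert resolve_head_dim(1024, 16) == 64
```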
@prvnsmpth Yes, this sounds acceptable to me.
@ebsmothers Do we know how much of a variance in the output logits is expected due to the differences in RoPE implementation? Here's what I observe:
The output logits match exactly only at position 0 and differ at other positions by a small amount, but the max difference is quite high in some cases. I'm looking into the RoPE implementation in both places, but just wanted to check whether this is expected.
@prvnsmpth yes, this is not unreasonable. To test the hypothesis around RoPE, one suggestion is simply to patch the RoPE components to rule them out as the cause of the difference (e.g. replace them with an identity/no-op in both implementations and compare the outputs again).
OK @prvnsmpth getting back to this... monkey-patching both RoPE implementations as in this script yields identical outputs, so I think we should be good here in terms of forward parity (I only checked 0.6B, we can run for the other sizes as well just to be safe). The main remaining thing here should be verifying that the configs run end-to-end -- as @krammnic mentioned you will want to add each new config to _recipe_registry.py. Lmk if you need any assistance with any of this.
@ebsmothers Yes, I verified this as well: the outputs are identical if we stub out the RoPE impl. Will check the other model sizes too. I'll post an update after I test the fine-tuning configs end-to-end. Thanks!
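For anyone reproducing this, the stubbing can be as simple as the sketch below (the torchtune RoPE class is passed in as an argument because the exact class used by the Qwen3 builders may differ; the HF patch assumes the usual `apply_rotary_pos_emb` helper in modeling_qwen3.py):

```python
import torch.nn as nn
import transformers.models.qwen3.modeling_qwen3 as hf_qwen3

class IdentityRoPE(nn.Module):
    """Drop-in stub that returns the query/key tensor unrotated."""
    def forward(self, x, *args, **kwargs):
        return x

def stub_rope_modules(model: nn.Module, rope_cls: type) -> int:
    """Replace every submodule of type ``rope_cls`` with an identity op."""
    replaced = 0
    for module in model.modules():
        for name, child in module.named_children():
            if isinstance(child, rope_cls):
                setattr(module, name, IdentityRoPE())
                replaced += 1
    return replaced

# Skip rotation on the HF side as well, so neither model applies RoPE.
hf_qwen3.apply_rotary_pos_emb = lambda q, k, cos, sin, *args, **kwargs: (q, k)
```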
@ebsmothers @krammnic I've updated the test plan with loss curves for the 0.6B and 1.7B models; they look OK to me. I had to make a few more fixes, especially in the LoRA component builders. I couldn't test the larger models since I don't have enough VRAM on my machine to run them. Do we have any GPUs available where we can test the configs for the larger models?
@prvnsmpth Sure! Ping me when all the required fixes are done, and I will run tests for the larger models.
@krammnic Thanks! I've already pushed the fixes; could you please test the recipes for the 4B and larger models?
Thanks @prvnsmpth! I will run some tests today as well to make sure everything looks good. Will share the results when I have them.
Hello! I found a weird edge case for checkpointing during training: the pipeline breaks. What happened is that originally there were two shards of safetensors; however, it turns out that the second shard only contains the lm_head, and since the embeddings and lm_head are tied, the save process only writes one shard. The current code assumes that the number of shards stays the same. I am not sure what the best way to fix this is; it might require overriding that assumption in the save path. Thanks again for all your hard work!
@neelsjain I haven't dug into the checkpointer code to fully understand the issue you pointed out, but is this specific to Qwen, or is this an issue you'd run into with any architecture that uses weight tying? Seems to me that it's the latter?
It's specific to Qwen3-1.7B (https://huggingface.co/Qwen/Qwen3-1.7B), not even all the Qwen3s. However, this problem will likely arise with any model whose checkpoint keeps only the lm_head in a separate shard. I'll see if I can come up with a good solution for this soon, but I think the torchtune folks know better than I do how this repo works.
So it seems we have a blocker here, but it is probably not a blocker for most of the recipes and models in the family. I assume we can still run some tests. @ebsmothers
Hmmm, it's strange that they have a key for the lm_head at all when the weights are tied. The quick and dirty solution is to modify the weight conversion so that the tied lm_head shard is handled explicitly.
@prvnsmpth @neelsjain @krammnic apologies for the delay here as I was tied up with a few other things. Regarding the 1.7B model, following up on @joecummings's suggestion, can we modify the weight conversion to account for the tied lm_head when saving checkpoints?
Hi @prvnsmpth I took the liberty of pushing some changes to your PR. It seems that there are some inconsistencies in which Qwen3 models have tied word embeddings.
I think 4B is a bug. But as a workaround, I added a separate Qwen3 weight conversion (basically just a proxy for whether the embeddings are tied, i.e. whether an lm_head weight should be expected in the checkpoint). I have also finished up (most) testing. Runs of the various configs can be found here. I have not yet run the KD or eval configs successfully; once we confirm those are working I think we can merge.
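For readers following along, the gist of the tied-embedding handling is roughly the following (a simplified sketch with an abbreviated key map, not the exact conversion that was pushed):

```python
# Illustrative subset of the HF -> torchtune key mapping.
_QWEN3_KEY_MAP = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.norm.weight": "norm.scale",
    # ... per-layer attention / MLP keys omitted for brevity
}

def qwen3_hf_to_tune(hf_state_dict: dict, tie_word_embeddings: bool) -> dict:
    """Convert an HF Qwen3 state dict to torchtune keys, handling tied embeddings."""
    converted = {}
    for key, value in hf_state_dict.items():
        if key == "lm_head.weight":
            if tie_word_embeddings:
                # The output projection is just a view of the token embeddings, so
                # drop the redundant key rather than emitting an extra shard
                # (this is the shard-count mismatch seen with Qwen3-1.7B on save).
                continue
            converted["output.weight"] = value
        else:
            converted[_QWEN3_KEY_MAP.get(key, key)] = value
    return converted
```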
Thanks @prvnsmpth for your patience here! Really excited to be able to land this one. After testing everything, I think we're good to go here. Will let CI run and then we can merge.
@ebsmothers Thanks a lot for helping with this PR!
Context
What is the purpose of this PR? It adds a new feature: support for the Qwen3 model family.
Issue #2645
Changelog
What are the changes made in this PR?
TODO
Test plan
Comparing Forwards
Stubbed out the RoPE implementations in both torchtune and the HF implementation, and ran this script to compare output logits from a single forward pass:
(Ran this for 0.6B, 4B and 8B models)
Any differences between the two implementations are due only to the differing RoPE implementations.
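The comparison boils down to something like the sketch below (condensed; loading the converted checkpoint into the torchtune builder is elided, and the model ids and builder names in the comments are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM

@torch.inference_mode()
def compare_logits(tune_model, hf_model_id: str, seq_len: int = 128) -> None:
    """Run the same token ids through both models and report the logit gap."""
    hf_model = AutoModelForCausalLM.from_pretrained(hf_model_id, torch_dtype=torch.float32)
    hf_model.eval()
    tune_model.eval()

    tokens = torch.randint(0, hf_model.config.vocab_size, (1, seq_len))
    hf_logits = hf_model(tokens).logits
    tune_logits = tune_model(tokens)

    diff = (hf_logits - tune_logits).abs()
    print(f"max abs diff: {diff.max().item():.3e}, mean abs diff: {diff.mean().item():.3e}")

# Usage (illustrative): build e.g. qwen3_0_6b(), load the converted checkpoint into it,
# then: compare_logits(tune_model, "Qwen/Qwen3-0.6B")
```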
Test finetuning recipes
(All finetuning runs: https://www.comet.com/prvnsmpth/qwen3-torchtune-finetuning)
Loss curves for the following configs:
Checklist
- run pre-commit hooks and linters (install them first via pre-commit install)
- run unit tests via pytest tests
- run recipe tests via pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.