add qwen3 moe implementation basics #1688
base: main
Conversation
Hi @jthomy! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged accordingly. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Hi @jthomy, thanks for contributing!
Can you elaborate more on this? Are you comparing the following two forward-pass runs: (1) your titan Qwen3 with a single forward pass, and (2) HF Qwen3 with a single forward pass?
Hi @wwwjn, sorry for my late answer: yes, exactly. For 235B I only checked the first layer, because the forward pass in torchtitan doesn't work on CPU and otherwise I'd have to shard the model; for 30B I checked the full model forward pass.
Thank you so much for sharing the details! I was asking because I was wondering how different RoPE embeddings give the same result. We realized that the different RoPE embeddings (cos/sin version vs. complex version) correspond to different weight matrices for wq, wk, q_norm, k_norm, and the HF weights correspond to the cos/sin RoPE embedding. After running your test script, I realized you already handled the difference: you permute wq, wk, q_norm, k_norm when loading.
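For context, here is a minimal sketch of the kind of reordering involved: HF's rotate-half RoPE pairs channel `i` with channel `i + head_dim // 2`, while the complex formulation pairs channels `2i` and `2i + 1`, so the output channels of wq/wk (and the per-channel q_norm/k_norm scales) have to be interleaved. The helper name and signature below are illustrative assumptions, not the PR's actual adapter code.

```python
import torch

def hf_to_interleaved(w: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
    """Reorder HF-style (cos/sin RoPE) weights for the interleaved (complex) RoPE layout."""
    if w.dim() == 2:
        # wq / wk: [n_heads * head_dim, hidden] -- interleave the two halves of each head,
        # so old rows (j, j + head_dim // 2) land at new positions (2j, 2j + 1).
        return (
            w.view(n_heads, 2, head_dim // 2, w.shape[1])
            .transpose(1, 2)
            .reshape(n_heads * head_dim, w.shape[1])
        )
    # q_norm / k_norm: a single [head_dim] scale shared across heads, same interleaving.
    return w.view(2, head_dim // 2).t().reshape(head_dim)
```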
The implementation is mainly tested for MoE; the dense variant is implemented as well but needs more testing.
FSDP and EP are supported. PP is not tested yet; I believe it needs a fix in the auxiliary loss, which uses an all_reduce (see the sketch below).
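To illustrate the concern, here is a generic sketch of a load-balancing auxiliary loss that averages routing statistics with an all_reduce; the function, its arguments, and the group handling are assumptions for illustration, not torchtitan's actual loss code.

```python
import torch
import torch.distributed as dist

def load_balancing_loss(router_probs, expert_mask, group=None):
    # router_probs: [tokens, n_experts] softmax outputs of the gate
    # expert_mask:  [tokens, n_experts] one-hot top-k dispatch mask
    tokens_per_expert = expert_mask.float().mean(dim=0)
    prob_per_expert = router_probs.float().mean(dim=0)
    if group is not None and dist.is_initialized():
        # The collective below assumes every rank in `group` reaches this point;
        # with pipeline parallelism only some stages hold MoE layers, so the
        # group (or the loss itself) has to be adjusted before PP can work.
        dist.all_reduce(tokens_per_expert, op=dist.ReduceOp.AVG, group=group)
        dist.all_reduce(prob_per_expert, op=dist.ReduceOp.AVG, group=group)
    n_experts = router_probs.shape[-1]
    return n_experts * torch.sum(tokens_per_expert * prob_per_expert)
```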
The state dict adapter converts from and to the Hugging Face format.
I verified that the pretrained 30B and 235B model forward passes match down to machine precision (I observed that testing with random weights can lead to false positives with ok-ish looking errors).
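A rough sketch of such a parity check is below; `build_titan_qwen3` and `convert_hf_state_dict` are hypothetical stand-ins for this PR's model builder and state dict adapter, and the checkpoint name and dtype are illustrative.

```python
import torch
from transformers import AutoModelForCausalLM

torch.manual_seed(0)
tokens = torch.randint(0, 151_936, (1, 64))  # random token ids within the Qwen3 vocab

# Reference: HF model in eval mode, full precision so "machine precision" is meaningful.
hf_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-30B-A3B", torch_dtype=torch.float32
).eval()

# Candidate: torchtitan model with converted weights (hypothetical helpers).
titan_model = build_titan_qwen3("30b", dtype=torch.float32).eval()
titan_model.load_state_dict(convert_hf_state_dict(hf_model.state_dict()))

with torch.no_grad():
    ref = hf_model(tokens).logits
    out = titan_model(tokens)

print("max abs diff:", (ref - out).abs().max().item())
```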
I had to make small changes to the loss (accepting dictionaries) and to the mixture of experts (normalizing the softmax and returning the logits).
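A sketch of what the described routing change amounts to (not the actual MoE code in this PR): softmax over all experts, renormalize the selected top-k weights so they sum to 1 per token, and hand the raw logits back for the auxiliary loss.

```python
import torch
import torch.nn.functional as F

def route(gate_logits: torch.Tensor, top_k: int):
    # gate_logits: [tokens, n_experts]
    probs = F.softmax(gate_logits, dim=-1, dtype=torch.float32)
    topk_probs, topk_idx = torch.topk(probs, top_k, dim=-1)
    # Renormalize so the top-k routing weights sum to 1 per token
    # ("normalizing the softmax").
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
    # Return the raw logits as well ("returning the logits"), e.g. for the aux loss.
    return topk_probs, topk_idx, gate_logits
```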
parallelize.py duplicates code from the other model implementations.
The implementation uses torch.complex for the RoPE embeddings. Because casting the model (model.to(dtype)) drops the imaginary part of freqs_cis, silently corrupting the RoPE embeddings after a model cast, I avoid this footgun by storing the buffer via view_as_real, unlike the existing Llama implementations.
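A minimal sketch of the storage trick, assuming the usual Llama-style precompute (theta and shapes are illustrative): the buffer is kept as a real tensor of shape [..., 2], and the complex view is reconstructed at use time, so model.to(dtype) can no longer discard the imaginary part.

```python
import torch

def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 1_000_000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(max_seq_len).float(), freqs)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64
    # Store as a real tensor of shape [max_seq_len, head_dim // 2, 2] instead of
    # complex64, so a later model.to(dtype) cannot silently drop the imaginary part.
    return torch.view_as_real(freqs_cis)

# At use time, reconstruct the complex view before applying RoPE:
# freqs_cis = torch.view_as_complex(self.freqs_cis.float())
```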