
Conversation

@jthomy commented Sep 8, 2025

The implementation is mainly tested for MoE; dense is implemented as well (needs more testing).
FSDP and EP are supported. PP is not tested yet; I believe it needs a fix in the auxiliary loss, which uses an all_reduce.
The state dict adapter converts to and from Hugging Face.
I verified that the pretrained 30B and 235B model forward passes match down to machine precision (I observed that testing with random weights can lead to false positives with ok-ish looking errors).
I had to make small changes to the loss (accepting dictionaries) and the mixture of experts (normalizing the softmax and returning the router logits); see the router sketch after these notes.
parallelize.py has code duplication compared to other model implementations.
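
For reference, a minimal sketch of the routing behavior described above, assuming Qwen3-style renormalization of the top-k weights. This is not the torchtitan MoE code; the class name and arguments are illustrative:

```python
import torch
import torch.nn.functional as F

class TopKRouter(torch.nn.Module):
    """Illustrative router: softmax over expert logits, top-k selection,
    renormalization of the selected weights, and the raw logits returned
    so an auxiliary load-balancing loss can be computed from them."""

    def __init__(self, dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = torch.nn.Linear(dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        logits = self.gate(x)                                   # (tokens, num_experts)
        probs = F.softmax(logits, dim=-1, dtype=torch.float32)
        weights, expert_idx = probs.topk(self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)   # normalize over the selected experts
        return weights.type_as(x), expert_idx, logits
```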

The implementation uses torch.complex for RoPE embeddings. Casting the model (model.to(dtype)) would discard the imaginary part of freqs_cis, silently corrupting the RoPE embeddings after a cast. To avoid this footgun I store the buffer via view_as_real, unlike the existing Llama implementations.
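
A minimal sketch of that storage scheme, assuming the usual complex-number RoPE formulation; the function names and default theta below are illustrative, not the PR's actual code:

```python
import torch

def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 1_000_000.0) -> torch.Tensor:
    # Build the complex RoPE frequencies, then store them as an interleaved
    # real tensor of shape (max_seq_len, head_dim // 2, 2) so that
    # model.to(dtype) cannot silently discard the imaginary part.
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(max_seq_len).float(), freqs)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64
    return torch.view_as_real(freqs_cis)

def apply_rotary_emb(x: torch.Tensor, freqs_cis_real: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq, heads, head_dim). The stored real buffer is converted
    # back to complex (in float32) only at use time.
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = torch.view_as_complex(freqs_cis_real.float())[: x.shape[1]]
    freqs_cis = freqs_cis.view(1, x.shape[1], 1, -1)
    return torch.view_as_real(x_ * freqs_cis).flatten(-2).type_as(x)
```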

meta-cla bot commented Sep 8, 2025

Hi @jthomy!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@wwwjn (Contributor) commented Sep 8, 2025

Hi @jthomy, thanks for contributing!

I verified that the pretrained 30B and 235B model forward passes match down to machine precision (I observed that testing with random weights can lead to false positives with ok-ish looking errors).

Can you elaborate on this? Are you comparing the following two forward passes: (1) a single forward pass of your titan Qwen3, and (2) a single forward pass of the HF Qwen3?

@jthomy (Author) commented Sep 8, 2025

Hi @wwwjn, sorry for my late answer; yes, exactly. For 235B I only checked the first layer, because the forward pass in torchtitan doesn't work on CPU and I would otherwise have to shard the model; for 30B I checked the full model forward pass.
I just added a very R&D-like script that checks this in the tests folder, only for your reference (we probably should not push this script to the main repo).
The script loads the HF model with pretrained weights, uses the state dict adapter to load them into the TT equivalent, and performs a forward pass on both to check the difference in the logits/hidden states.
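
Roughly, the kind of parity check described here looks like the sketch below; the torchtitan constructor and adapter names (build_titan_qwen3, Qwen3StateDictAdapter, from_hf) are placeholders, not the PR's actual API:

```python
import torch
from transformers import AutoModelForCausalLM

# Reference model with pretrained weights (fp32 to make the comparison strict).
hf_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B", torch_dtype=torch.float32)
hf_model.eval()

# Hypothetical torchtitan side: build the model and load the converted state dict.
tt_model = build_titan_qwen3()                    # placeholder constructor
adapter = Qwen3StateDictAdapter()                 # placeholder HF <-> TT adapter
tt_model.load_state_dict(adapter.from_hf(hf_model.state_dict()))
tt_model.eval()

tokens = torch.randint(0, hf_model.config.vocab_size, (1, 128))
with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tt_logits = tt_model(tokens)

print("max abs diff:", (hf_logits - tt_logits).abs().max().item())
```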

@wwwjn (Contributor) commented Sep 9, 2025


Thank you so much for sharing the details! I was asking because I was wondering how a different RoPE embedding gives the same result: we realized that the different RoPE embeddings (cos/sin version vs. complex version) correspond to different weight matrices for wq, wk, q_norm, k_norm, and the HF weights correspond to the cos/sin RoPE embedding.

After running your test script, I realized you already handled the difference: you permute wq, wk, q_norm, k_norm when loading in state_dict_adapter.py, and apply reverse_permute() when comparing results. Thanks again for sharing this! A common form of this permutation is sketched below.
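
For reference, this is the reshape used in HF's Llama weight-conversion scripts, assuming the HF checkpoint uses the rotate-half RoPE layout and the complex formulation expects interleaved pairs; the exact function names and signatures in the PR's state_dict_adapter.py may differ:

```python
import torch

def hf_to_interleaved(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Rows within each head move from HF's rotate-half order
    # [x0..x_{d/2-1}, y0..y_{d/2-1}] to the interleaved order [x0, y0, x1, y1, ...]
    # expected by the complex-number RoPE.
    dim1, dim2 = w.shape
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def interleaved_to_hf(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Inverse mapping, back to HF's rotate-half layout (used when comparing).
    dim1, dim2 = w.shape
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def hf_norm_to_interleaved(v: torch.Tensor) -> torch.Tensor:
    # Same idea for the 1-D q_norm/k_norm scale vectors of length head_dim.
    return v.view(2, -1).transpose(0, 1).reshape(-1)
```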

meta-cla bot added the CLA Signed label on Sep 9, 2025