
Conversation

wwwjn
Contributor

@wwwjn wwwjn commented Sep 5, 2025

As the Qwen3 dense and MoE models share a lot of common parts (e.g., Attention), I added the MoE module on top of the Qwen3 dense model.

Initial verification with FSDP=8, EP=2

@meta-cla meta-cla bot added the CLA Signed label Sep 5, 2025
@wwwjn wwwjn changed the title from "[Qwen3] Qwen3 MoE support" to "[Qwen3] Qwen3 MoE initial support" Sep 5, 2025
@jthomy

jthomy commented Sep 5, 2025

Great to see!
I was about to open a PR with a Qwen3 MoE implementation of my own, but I'm happy to see it here as well (it is similar, except that it uses the complex-number RoPE implementation and has a state dict adapter from/to HF).
Looking over the code, I see that the MFU implementation still needs to take into account the sparse activations.
Are you interested in adding sequence packing support as well?
If not, I also have an implementation for flash attention with sequence packing using the flash_attn_varlen_func and could open a PR on it, but I am not sure if the torchtitan repo wants to stick with flex attention only.

@tianyu-l
Contributor

tianyu-l commented Sep 5, 2025

@jthomy
IIUC these are two separate questions

Looking over the code, I see that the MFU implementation still needs to take into account the sparse activations.
Are you interested in adding sequence packing support as well?

We do have sequence packing / document masking support in torchtitan using FlexAttention, which can (and should) be added to Qwen MoE. cc @wwwjn
The MFU computation is more about following convention, lol.
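
For anyone curious what document masking with FlexAttention looks like, here is a minimal sketch (not torchtitan's actual implementation; the shapes and the document_ids tensor are illustrative, and it assumes PyTorch >= 2.5 with a CUDA device):

```python
# Minimal sketch of causal attention with document masking via FlexAttention.
# `document_ids` marks which packed document each token position belongs to.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 1, 8, 256, 64
q = k = v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# Example packing: two documents of 128 tokens each.
document_ids = torch.arange(S, device="cuda") // 128

def document_causal_mask(b, h, q_idx, kv_idx):
    # Causal within a document; never attend across document boundaries.
    return (q_idx >= kv_idx) & (document_ids[q_idx] == document_ids[kv_idx])

block_mask = create_block_mask(document_causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)
out = flex_attention(q, k, v, block_mask=block_mask)
```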

If not, I also have an implementation for flash attention with sequence packing using the flash_attn_varlen_func and could open a PR on it, but I am not sure if the torchtitan repo wants to stick with flex attention only.

There used to be discussions on this. IIRC the worry was around supporting flash_attn_varlen_func with CP? Even if we support flash_attn_varlen_func, IMO that should go into pytorch SDPA instead of into torchtitan directly.
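
For context, a rough sketch of what sequence packing with flash_attn_varlen_func looks like (this is the flash-attn package's API, not anything in torchtitan; shapes and lengths are illustrative):

```python
# Rough sketch of packed-sequence attention with flash_attn_varlen_func.
# Requires the flash-attn package and a CUDA device; values are illustrative.
import torch
from flash_attn import flash_attn_varlen_func

nheads, headdim = 8, 64
seqlens = [128, 96, 32]                 # three documents packed into one batch
total = sum(seqlens)

q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence lengths delimit the documents inside the packed batch.
cu_seqlens = torch.tensor([0, 128, 224, 256], device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    causal=True,  # causal within each document
)
```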

cc @fegin @drisspg

@vwxyzjn

vwxyzjn commented Sep 6, 2025

@jthomy I'm quite curious about your implementation as well, especially the HF state dict adapter. Would you mind sharing your branch?

@jthomy

jthomy commented Sep 8, 2025

@vwxyzjn sure, I made a draft pull request here, feel free to have a look: #1688
Let me know if there's anything I can help with.

@fegin
Contributor

fegin commented Sep 9, 2025

@jthomy FlexAttention has a document masking implementation. The main blocker now is composability with SAC, for which we are working on a workaround. As for the SDPA version, @drisspg has prototyped one, pytorch/pytorch#162326. There may also be changes to how we provide these API calls due to the composability issues when enabling CP with FlexAttention. I can provide more detail later this week once I try out the proposal.

- Supports FSDP/HSDP, TP, DDP, EP.
- Supports AC, torch.compile.
- MoE models use Token Choice routing with an auxiliary-loss-free load balancing algorithm (see the sketch after this list).
- [WIP] CP is not currently supported because we use a different RoPE embedding.
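
For readers unfamiliar with the term, here is a minimal sketch of token-choice top-k routing with auxiliary-loss-free load balancing (a per-expert bias used only for expert selection and nudged toward balanced load each step). This is illustrative only, not torchtitan's router; the names and hyperparameters are made up:

```python
# Minimal sketch of token-choice top-k routing with auxiliary-loss-free
# load balancing. Illustrative only; not torchtitan's actual MoE router.
import torch

num_experts, top_k, dim, bias_lr = 8, 2, 16, 1e-3
gate = torch.nn.Linear(dim, num_experts, bias=False)
expert_bias = torch.zeros(num_experts)  # adjusted outside autograd

def route(x):  # x: [num_tokens, dim]
    scores = torch.sigmoid(gate(x))                  # [tokens, experts]
    # The bias influences which experts are selected, not the combine weights.
    _, expert_idx = (scores + expert_bias).topk(top_k, dim=-1)
    weights = scores.gather(-1, expert_idx)          # [tokens, top_k]

    # Auxiliary-loss-free balancing: lower the bias of overloaded experts and
    # raise it for underloaded ones, instead of adding a load-balancing loss.
    load = torch.bincount(expert_idx.flatten(), minlength=num_experts).float()
    with torch.no_grad():
        expert_bias.add_(bias_lr * torch.sign(load.mean() - load))
    return expert_idx, weights

idx, w = route(torch.randn(32, dim))
```
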
Contributor

What's our status on this? Why does it work on dense but not MoE? I thought RoPE only matters in the Attention layer.

When you say we used a different RoPE, doesn't that mean we could've switched to the (alternative but also correct) complex-number-based RoPE (e.g. #1688) and CP would work automatically?

Contributor

"CP is not supported" is correct, but is it due to Flex?

Contributor

@fegin oh is Flex only enabled for Qwen MoE but not Qwen dense? Either way we should update both to be consistent.

Contributor Author

I wrote this because of a minor issue that hasn't been addressed: in Qwen, we use cos/sin RoPE embeddings, so there's no freqs_cis field, while CP explicitly adds freqs_cis here:

cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],

This is a known issue but I plan to address it later.

For Flex support - Qwen3 MoE only has one test model here, and I'm testing a bigger model. I will update the bigger MoE model to use FlexAttention, and update the README to state the CP issue with FlexAttention.
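
For reference, a minimal sketch of the cos/sin ("rotate-half") style of RoPE mentioned above, as opposed to a complex-valued freqs_cis buffer; the names and shapes are illustrative, not the code in this PR:

```python
# Minimal sketch of cos/sin ("rotate-half") RoPE. Illustrative only.
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # q, k: [batch, heads, seq, head_dim]; cos, sin: [seq, head_dim]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

head_dim, seq = 64, 16
theta = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq).float(), theta)   # [seq, head_dim // 2]
cos = torch.cat((angles.cos(), angles.cos()), dim=-1)    # [seq, head_dim]
sin = torch.cat((angles.sin(), angles.sin()), dim=-1)

q = torch.randn(1, 8, seq, head_dim)
k = torch.randn(1, 8, seq, head_dim)
q_rot, k_rot = apply_rope(q, k, cos, sin)
```

The relevant point for CP is that with this style there is no freqs_cis module buffer to pass in cp_buffers; either the cos/sin tables need to be exposed to CP, or the model switched to a freqs_cis-based implementation.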

Contributor

What's your plan on

This is a known issue but I plan to address it later.

I think I don't mind switching to freqs_cis based RoPE, if that makes more sense to you.

Do we know whether Qwen is using Flash attention or Mem Efficient attention under SDPA? @fegin

Since (1) Flex + CP is WIP and (2) SDPA may not be the bottleneck, personally for Qwen I think it's OK to use SDPA until Flex + CP is ready. But I don't have a strong opinion on this.

Contributor

@fegin oh is Flex only enabled for Qwen MoE but not Qwen dense? Either way we should update both to be consistent.

The whole model can only use one type of attention; that is the design, unless the Qwen implementation explicitly allows users to configure it. But I would not suggest that approach.

Since Qwen is not using Flex, I think CP would work once the freqs_cis issue @wwwjn mentioned is fixed.

@wwwjn wwwjn requested a review from tianyu-l September 11, 2025 02:59
- MoE alternatives
## To be added
- MoE model
- `StateDictAdapter` support for MoE model
Contributor

this is the next step, right?

Contributor Author

Yes, this is on my plate

Contributor

@tianyu-l tianyu-l left a comment

LGTM!

@jthomy

jthomy commented Sep 11, 2025

Did you verify that this model is the same as HF Qwen3?
E.g., I see that no scaling factor is passed to the attention, whereas HF uses self.scaling = self.head_dim**-0.5, if I am not mistaken.

@wwwjn
Contributor Author

wwwjn commented Sep 11, 2025

Did you verify that this model is the same as HF Qwen3? E.g., I see that no scaling factor is passed to the attention, whereas HF uses self.scaling = self.head_dim**-0.5, if I am not mistaken.

Thanks for pointing that out! For the dense model, yes, we did some checks on end-to-end forward results, e.g. the description in #1590. But we haven't done finer-granularity checks on intermediate results. I saw you have great test scripts in your PR and I'm thinking of leveraging those! For MoE, numerical verification is in progress.

I agree the Attention part is not the same as Qwen3 - nice catch! I will create a fix for that.
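
If the fix ends up being to pass the scale explicitly, here is a minimal sketch of what that looks like with SDPA (illustrative, not the actual patch):

```python
# Sketch of passing an explicit attention scale to SDPA, mirroring HF's
# self.scaling = self.head_dim**-0.5. Illustrative; not the actual fix.
import torch
import torch.nn.functional as F

batch, n_heads, seq, head_dim = 1, 8, 16, 128
q = torch.randn(batch, n_heads, seq, head_dim)
k = torch.randn(batch, n_heads, seq, head_dim)
v = torch.randn(batch, n_heads, seq, head_dim)

out = F.scaled_dot_product_attention(
    q, k, v,
    is_causal=True,
    scale=head_dim**-0.5,  # explicit rather than relying on the default
)
```

Note that SDPA already defaults to 1/sqrt of the query's last dimension, so an explicit scale only changes behavior when the model's intended scaling differs from that default.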

@jthomy

jthomy commented Sep 11, 2025

Nice, thank you! Yes, feel free to use any of my code.

@wwwjn wwwjn merged commit bd3850b into pytorch:main Sep 11, 2025
8 checks passed