danielvegamyhre (Contributor) commented Sep 13, 2025

Stacked PRs:


[mxfp8 moe training] use dim1 cast cuda kernel for 3d weights by reshaping to 2d

  • Doing to_mx(B_t.contiguous()) is unspeakably slow (see the perf analysis in the previous PR in the stack).
  • As a workaround, we can use the faster dim1 cast CUDA kernel by reshaping the 3d weights to 2d, casting, then reshaping back to 3d. I wasn't able to find a way to reshape the 2d, column-major quantized tensor directly into a 3d column-major tensor, so I was forced to use the .t().contiguous().t() pattern (see the sketch below), which is not ideal for perf, yet still faster than doing to_mx(B_t.contiguous()).
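
A minimal sketch of the workaround, with a hypothetical `dim1_cast` stand-in for the actual CUDA kernel (the real kernel returns quantized mxfp8 data plus scales; only the reshape and layout handling is the point here, and dims are small for illustration):

```python
import torch

def dim1_cast(x_2d: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the dim1 cast CUDA kernel (name assumed).
    # The real kernel quantizes along dim1 and returns mxfp8 data + scales;
    # identity here, since only the reshape/layout handling is illustrated.
    return x_2d

device = "cuda" if torch.cuda.is_available() else "cpu"
E, N, K = 2, 128, 64  # small dims for illustration
B_t = torch.randn(E, N, K, device=device, dtype=torch.bfloat16)

# Flatten the 3d expert weights to 2d so the fast 2d dim1 cast can be used.
B_q = dim1_cast(B_t.reshape(-1, K))  # (E*N, K)

# Reshape back to 3d, then force each expert matrix into column-major
# memory via transpose -> contiguous -> transpose (the 3d analogue of the
# .t().contiguous().t() pattern), since there is no direct reshape from a
# 2d column-major tensor to a 3d column-major one.
B_q = B_q.reshape(E, N, K)
B_q = B_q.transpose(-2, -1).contiguous().transpose(-2, -1)

# Each expert matrix now has column-major strides over its last two dims.
assert B_q.stride() == (N * K, 1, N)
```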

Next steps

  • We should update the dim1 cast CUDA kernel to handle 3d inputs, writing directly to column-major format, so we can avoid this expensive transformation to column major (the cost is illustrated in the sketch below).
    • We could also update the CUTLASS grouped gemm to handle NT/TN/TT/NN layouts, but passing in args with different memory layouts could affect kernel perf; this needs more thought.
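
To make the cost concrete, a small illustration (my own sketch, not code from this PR) showing that the contiguous() step materializes a full copy of the quantized weights, which is exactly what a kernel writing directly to column-major output would avoid:

```python
import torch

E, N, K = 2, 128, 64  # small dims for illustration
x = torch.zeros(E, N, K, dtype=torch.uint8)  # proxy for quantized mxfp8 bytes

# transpose() only swaps strides; no data moves and storage is shared.
xt = x.transpose(-2, -1)
assert xt.data_ptr() == x.data_ptr()

# contiguous() must materialize a full E*N*K copy into a new buffer in the
# transposed layout; this is the round trip through memory that a kernel
# writing its output directly in column-major format would skip.
xc = xt.contiguous()
assert xc.data_ptr() != x.data_ptr()
```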

Test plan

  • pytest test/prototype/moe_training/test_training.py

Benchmarks

Before:

A_shape        B_shape           recipe                  bf16_e2e_us    scaled_e2e_us  scaled_e2e_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-------------  ----------------  --------------------  -------------  ---------------  --------------------  -------------  ---------------  --------------------
(16640, 5120)  (1, 8192, 5120)   MoEScalingType.MXFP8        4268.5           3402.75  1.254x                      1513.76          1675.81  0.903x
(16640, 5120)  (4, 8192, 5120)   MoEScalingType.MXFP8        3968.88          4282.53  0.927x                      1126.21          2222.37  0.507x
(16640, 5120)  (16, 8192, 5120)  MoEScalingType.MXFP8        4900.77          8091.55  0.606x                      1262.66          9047.7   0.14x
(16640, 5120)  (64, 8192, 5120)  MoEScalingType.MXFP8        8432.61         21453.3   0.393x                      1788.94         14476.4   0.124x

After:

A_shape        B_shape           recipe                  bf16_e2e_us    scaled_e2e_us  scaled_e2e_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-------------  ----------------  --------------------  -------------  ---------------  --------------------  -------------  ---------------  --------------------
(16640, 5120)  (1, 8192, 5120)   MoEScalingType.MXFP8        4920.32          3057.79  1.609x                      1299.92          1091.58  1.191x
(16640, 5120)  (4, 8192, 5120)   MoEScalingType.MXFP8        3886.21          3402.82  1.142x                      1087.39           931.36  1.168x
(16640, 5120)  (16, 8192, 5120)  MoEScalingType.MXFP8        5769.02          5384.91  1.071x                      1411.1           1222.7   1.154x
(16640, 5120)  (64, 8192, 5120)  MoEScalingType.MXFP8        8455.23         12846.2   0.658x                      1796.1           2968.21  0.605x

pytorch-bot commented Sep 13, 2025

✅ No failures as of commit 5d874ed with merge base 66384a9. See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2998.

danielvegamyhre added a commit that referenced this pull request Sep 13, 2025
…aping to 2d

stack-info: PR: #2998, branch: danielvegamyhre/stack/67
meta-cla bot added the CLA Signed label Sep 13, 2025
danielvegamyhre added the mx, moe, and topic: not user facing labels Sep 13, 2025