Conversation

danielvegamyhre (Contributor) commented Sep 12, 2025

Stacked PRs:


[mxfp8 moe training] add compile support

  • wrap the Triton kernels in custom ops
  • update the tensor subclass __torch_dispatch__ to avoid double wrapping
  • fix a d2h sync caused by padded_rows = output_scale_group_offsets[-1]. The trace showed this was calling .item() under the hood, which forces a device-to-host sync. Instead, I now allocate using an upper bound on the padding needed, padded_rows = rows + num_groups * 128, which avoids the sync (see the sketch below).
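
For reference, here is a minimal, hypothetical sketch of the custom-op wrapping and the d2h-sync fix described above. The op name, argument names, and shape math are illustrative assumptions, not the exact torchao kernel API.

```python
import torch
from torch.library import custom_op

# Hypothetical sketch only: "torchao::rearrange_scales_example" and its
# arguments are stand-ins, not the real torchao kernel interface.
@custom_op("torchao::rearrange_scales_example", mutates_args=())
def rearrange_scales_example(
    scales: torch.Tensor, group_offsets: torch.Tensor, num_groups: int
) -> torch.Tensor:
    rows, cols = scales.shape
    # Before: padded_rows = group_offsets[-1] read an element of a device
    # tensor, which calls .item() under the hood and forces a d2h sync.
    # After: allocate with a host-side upper bound. Each group is padded to a
    # multiple of 128 rows, so padding can add at most num_groups * 128 rows.
    padded_rows = rows + num_groups * 128
    out = scales.new_empty((padded_rows, cols))
    # ... launch the Triton rearrange kernel to fill `out` here ...
    return out

# A fake (meta) implementation lets torch.compile trace the op without running
# the kernel; it must compute output shapes without touching tensor data.
@rearrange_scales_example.register_fake
def _(scales, group_offsets, num_groups):
    rows, cols = scales.shape
    return scales.new_empty((rows + num_groups * 128, cols))
```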

Test plan

  • pytest test/prototype/moe_training/test_training.py

Microbenchmarks with compile

A_shape        B_shape           recipe                  bf16_e2e_us    scaled_e2e_us  scaled_e2e_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-------------  ----------------  --------------------  -------------  ---------------  --------------------  -------------  ---------------  --------------------
(16640, 5120)  (1, 8192, 5120)   MoEScalingType.MXFP8        4268.5           3402.75  1.254x                      1513.76          1675.81  0.903x
(16640, 5120)  (4, 8192, 5120)   MoEScalingType.MXFP8        3968.88          4282.53  0.927x                      1126.21          2222.37  0.507x
(16640, 5120)  (16, 8192, 5120)  MoEScalingType.MXFP8        4900.77          8091.55  0.606x                      1262.66          9047.7   0.14x
(16640, 5120)  (64, 8192, 5120)  MoEScalingType.MXFP8        8432.61         21453.3   0.393x                      1788.94         14476.4   0.124x

Perf analysis for M=16640, G=4, N=8192, K=5120

Looking at the trace to see why it's slower, here are my initial takeaways:

  1. Mxfp8 grouped GEMMs are ~1.57x faster than bf16 grouped GEMMs (mxfp8 avg ≈ 700us, bf16 avg ≈ 1.1ms), but should be ~2.1x faster. This contradicts the grouped GEMM microbenchmarking we did, which indicates that for shape M=16640, G=4, K=5120, N=8192 we should be getting a ~2.1x speedup, yet we are only getting ~1.57x. We need to determine what is going on here. cc @cthi @slayton58 who may be able to help with this

  2. Blocked scale swizzling kernels look okay: 2 of the 3 handwritten Triton kernels for converting scales to the blocked, swizzled format are very fast. One may need more optimization (TBD):

    • triton_scale_swizzle_M_groups = 18us avg
    • triton_scale_swizzle_K_groups = 14us avg
    • triton_scale_swizzle_per_group_3d = ~250us avg (not surprising it takes longer, since the 3d tensor is much more data than the 2d activations; I will get some mem bw benchmarks on these kernels though)
  3. Mxfp8 dim1 cast CUDA kernel looks good: ~100us average, used on 2d RHS operands. This kernel achieves ~5300 GB/s of memory bandwidth, around 66% of the 8 TB/s peak, so it can potentially still be improved.

  4. Torch inductor codegen kernels are pretty slow (the worst offender is ~1.2ms, longer than the grouped GEMM itself). This is probably largely due to a stray .contiguous() call that to_mx API limitations forced me to add, but I'm guessing it's also that inductor codegen is slow, as it has been historically for various cases in fp8 rowwise, fp8 blockwise, and mxfp8. So we should try to get the mxfp8 dim1 CUDA kernel working for 3d tensors, to get the double win of avoiding the .contiguous() calls and using the faster kernel, which achieves higher mem bw utilization than torch.compile / triton. I can do this next, or perhaps @slayton58 may be interested in this. (Perhaps we can just reshape 3d->2d, use the existing kernel, then reshape back to 3d, as in the rough sketch below? Need to think about this.)
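
To make the reshape idea in point 4 concrete, here is a rough, hypothetical sketch. `mxfp8_quantize_dim1_2d` is a placeholder name for the existing 2d dim1 cast kernel, not its real API, and whether the resulting scale layout is correct per expert is exactly the open question flagged above.

```python
import torch

def mxfp8_dim1_cast_3d(weights_3d: torch.Tensor):
    # weights_3d: (num_experts, N, K) expert weights.
    E, N, K = weights_3d.shape
    # Flatten the leading dims so the existing 2d kernel sees an (E*N, K) tensor.
    weights_2d = weights_3d.reshape(E * N, K)
    data_2d, scales_2d = mxfp8_quantize_dim1_2d(weights_2d)  # hypothetical call
    # Reshape the quantized data back to a per-expert layout.
    data_3d = data_2d.reshape(E, N, K)
    # The scale tensor from the 2d kernel would also need to be mapped back to
    # a per-expert layout, and scale blocks must not straddle expert
    # boundaries; that is the "need to think about this" part.
    return data_3d, scales_2d
```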

pytorch-bot bot commented Sep 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2990

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 01bfca9 with merge base 66384a9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Sep 12, 2025
stack-info: PR: #2990, branch: danielvegamyhre/stack/66
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 12, 2025
danielvegamyhre added a commit that referenced this pull request Sep 12, 2025
stack-info: PR: #2990, branch: danielvegamyhre/stack/66
danielvegamyhre (Author) commented Sep 12, 2025

Possible wrap_triton bug:

  • Unit tests using compile pass: pytest test/prototype/moe_training/test_training.py -s
  • Benchmarks using compile hit the error below, inside the custom op (??):

Command: python benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py --compile

Update: disabling dynamic shapes resolved the error

Error:


  File "/home/danvm/ao/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py", line 228, in triton_mx_block_rearrange_2d_M_groups
    output = scales_tensor.new_empty((padded_rows, padded_cols))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function torchao.triton_mx_block_rearrange_2d_M_groups.default(*(FakeTensor(..., device='cuda:0', size=(16640, 160), dtype=torch.float8_e8m0fnu), FakeTensor(..., device='cuda:0', size=(s21,), dtype=torch.int32), FakeTensor(..., device='cuda:0', size=(s21 + 1,), dtype=torch.int64)), **{}): got RuntimeError('Cannot call numel() on tensor with symbolic sizes/strides\nException raised from throw_cannot_call_with_symbolic at /home/danvm/pytorch/c10/core/TensorImpl.cpp:291 (most recent call first):\nframe #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x88 (0x7f35c53946c8 in /home/danvm/pytorch/torch/lib/libc10.so)\nframe #1: c10::TensorImpl::throw_cannot_call_with_symbolic(char const*) const + 0x78 (0x7f35c5325476 in /home/danvm/pytorch/torch/lib/libc10.so)\nframe #2: <unknown function> + 0x7938f (0x7f35c537038f in /home/danvm/pytorch/torch/lib/libc10.so)\nframe #3: <unknown function> + 0x49f856 (0x7f35c589f856 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #4: <unknown function> + 0x484fde (0x7f35c5884fde in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #5: python() [0x545869]\n<omitting python frames>\nframe #11: python() [0x42c4c2]\nframe #13: python() [0x625fb6]\nframe #23: python() [0x457d8d]\nframe #24: python() [0x5d6bb6]\nframe #26: torch::handle_torch_function_no_python_arg_parser(c10::ArrayRef<_object*>, _object*, _object*, char const*, _object*, char const*, torch::TorchFunctionName) + 0x405 (0x7f35c603d6d5 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #27: <unknown function> + 0x6a1865 (0x7f35c5aa1865 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #28: <unknown function> + 0x17bd9d3 (0x7f35b93bd9d3 in /home/danvm/pytorch/torch/lib/libtorch_cpu.so)\nframe #29: <unknown function> + 0x6a7d90 (0x7f35c5aa7d90 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #30: <unknown function> + 0x6a4c8d (0x7f35c5aa4c8d in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #31: <unknown function> + 0xc56c70 (0x7f35c6056c70 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #32: <unknown function> + 0xc5708d (0x7f35c605708d in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #33: <unknown function> + 0x3b7dc6 (0x7f35c57b7dc6 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #34: python() [0x5429c4]\nframe #36: python() [0x56b000]\nframe #38: python() [0x457d8d]\nframe #41: <unknown function> + 0xc66254 (0x7f35c6066254 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #42: <unknown function> + 0xc58d6b (0x7f35c6058d6b in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #43: <unknown function> + 0xc66a96 (0x7f35c6066a96 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #44: <unknown function> + 0x6a7d90 (0x7f35c5aa7d90 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #45: <unknown function> + 0x6a4c8d (0x7f35c5aa4c8d in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #46: <unknown function> + 0x17bd7cb (0x7f35b93bd7cb in /home/danvm/pytorch/torch/lib/libtorch_cpu.so)\nframe #47: <unknown function> + 0x6a7d90 (0x7f35c5aa7d90 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #48: <unknown function> + 0x6a4c8d (0x7f35c5aa4c8d in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #49: <unknown function> + 0x596736b (0x7f35bd56736b in 
/home/danvm/pytorch/torch/lib/libtorch_cpu.so)\nframe #50: <unknown function> + 0x9e685f (0x7f35c5de685f in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #51: <unknown function> + 0x9e6b5c (0x7f35c5de6b5c in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #52: <unknown function> + 0x8de7e9 (0x7f35c5cde7e9 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #53: <unknown function> + 0x3b7dc6 (0x7f35c57b7dc6 in /home/danvm/pytorch/torch/lib/libtorch_python.so)\nframe #54: python() [0x5429c4]\nframe #57: python() [0x42c4c2]\nframe #59: python() [0x625fb6]\n')

from user code:
   File "/home/danvm/ao/benchmarks/utils.py", line 12, in fwd_bwd
    out = fn(*args, **kwargs)
  File "/home/danvm/ao/torchao/prototype/moe_training/scaled_grouped_mm.py", line 71, in _scaled_grouped_mm
    return _MXFP8GroupedMM.apply(
  File "/home/danvm/ao/torchao/prototype/moe_training/scaled_grouped_mm.py", line 323, in forward
    A_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
  File "/home/danvm/pytorch/torch/_library/custom_ops.py", line 676, in __call__
    return self._opoverload(*args, **kwargs)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

cc @zou3519 any idea what could be going on here? This is with torch built from source on 9/10.

danielvegamyhre (Author):
Disabling dynamic shapes worked as a workaround.
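
For anyone reproducing this, a minimal illustration of the workaround (not the benchmark's actual code), assuming the benchmarked function is wrapped with torch.compile:

```python
import torch

# Stand-in for the compiled forward in the benchmark.
def fwd(x: torch.Tensor) -> torch.Tensor:
    return x * 2

# dynamic=False tells torch.compile to specialize on observed shapes rather
# than tracing with symbolic sizes, which sidestepped the error above.
compiled_fwd = torch.compile(fwd, dynamic=False)
```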

danielvegamyhre (Author) commented Sep 12, 2025

Only the forward is getting compiled for some reason, not the backward: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp7B4CM7/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

This is with torch built from source with CUDA 12.9 at 53b8bdb97774114ca02948fed47f2fd49996c564 (Sept 12th).
[Screenshot 2025-09-12 at 1 21 38 PM]

danielvegamyhre added a commit that referenced this pull request Sep 12, 2025
stack-info: PR: #2990, branch: danielvegamyhre/stack/66
danielvegamyhre added a commit that referenced this pull request Sep 12, 2025
stack-info: PR: #2990, branch: danielvegamyhre/stack/66
@danielvegamyhre danielvegamyhre added the topic: not user facing Use this tag if you don't want this PR to show up in release notes label Sep 12, 2025
danielvegamyhre added a commit that referenced this pull request Sep 12, 2025
stack-info: PR: #2990, branch: danielvegamyhre/stack/66
zou3519 (Contributor) commented Sep 12, 2025

Do you have a function in C++ that calls Tensor.sizes()? The main fix for that error message is to change Tensor.sizes() in C++ to Tensor.sym_sizes()

@danielvegamyhre danielvegamyhre changed the title [mxfp8 moe training] add compile support [mxfp8 moe training] add compile support and fix d2h sync Sep 12, 2025
danielvegamyhre added a commit that referenced this pull request Sep 12, 2025
stack-info: PR: #2990, branch: danielvegamyhre/stack/66
@danielvegamyhre danielvegamyhre changed the title [mxfp8 moe training] add compile support and fix d2h sync [mxfp8 moe training] add compile support Sep 12, 2025