[mxfp8 moe training] add compile support #2990
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2990
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 01bfca9 with merge base 66384a9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
stack-info: PR: #2990, branch: danielvegamyhre/stack/66 (force-pushed 7660825 to b294802)
stack-info: PR: #2990, branch: danielvegamyhre/stack/66 (force-pushed b294802 to bceaa66)
Possible wrap_triton bug:
Command:
Update: disabling dynamic shapes resolved the error.
Error:
cc @zou3519 any idea what could be going on here? This is with torch built from source on 9/10.
Disabling dynamic shapes worked as a workaround.
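For reference, a minimal sketch of that workaround, assuming the model is compiled via `torch.compile` (the stand-in module and shapes below are illustrative, not the actual MoE training code):

```python
import torch
import torch.nn as nn

# Illustrative stand-in module; the actual PR compiles the MoE training path.
model = nn.Linear(5120, 8192, device="cuda", dtype=torch.bfloat16)

# Passing dynamic=False disables dynamic-shape tracing, which is the workaround
# that avoided the wrap_triton error described above.
compiled = torch.compile(model, dynamic=False)

x = torch.randn(16640, 5120, device="cuda", dtype=torch.bfloat16)
y = compiled(x)
```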
Only the forward pass is getting compiled for some reason, not the backward pass: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp7B4CM7/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
This is with torch built from source with CUDA 12.9 at 53b8bdb97774114ca02948fed47f2fd49996c564 (Sept 12th).
stack-info: PR: #2990, branch: danielvegamyhre/stack/66 (force-pushed bceaa66 to 1f692de)
stack-info: PR: #2990, branch: danielvegamyhre/stack/66 (force-pushed 1f692de to bc4aef6)
stack-info: PR: #2990, branch: danielvegamyhre/stack/66 (force-pushed bc4aef6 to 6327f2e)
Do you have a function in C++ that calls Tensor.sizes()? The main fix for that error message is to change Tensor.sizes() in C++ to Tensor.sym_sizes().
stack-info: PR: #2990, branch: danielvegamyhre/stack/66 (force-pushed 6327f2e to d6e908b)
stack-info: PR: #2990, branch: danielvegamyhre/stack/66 (force-pushed d6e908b to 01bfca9)
Stacked PRs:
[mxfp8 moe training] add compile support
Previously, the number of padded rows was computed as `padded_rows = output_scale_group_offsets[-1]`. Looking at the trace, I found this was doing a `.item()` under the hood and causing a d2h sync, so now I'm instead just using the upper bound of the padding needed, `padded_rows = rows + num_groups * 128`, avoiding the d2h sync.
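For illustration, a minimal sketch of the padding change described above (the shapes and offsets are made up; `rows`, `num_groups`, and `output_scale_group_offsets` follow the names used in the description):

```python
import torch

rows, num_groups = 16640, 4
# Illustrative per-group row offsets for the blocked scales, living on the GPU.
output_scale_group_offsets = torch.tensor([4224, 8448, 12672, 16896], device="cuda")

# Before: reading the last offset as a Python int triggers an implicit .item(),
# i.e. a blocking device-to-host sync.
padded_rows_old = int(output_scale_group_offsets[-1])

# After: use the upper bound of the padding instead. Assuming each group is padded
# up to a multiple of 128 rows, rows + num_groups * 128 is always sufficient and is
# computed purely on the host, so there is no d2h sync.
padded_rows = rows + num_groups * 128
```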
Test plan
pytest test/prototype/moe_training/test_training.py
Microbenchmarks with compile
Perf analysis for M=16640, G=4, N=8192, K=5120
Looking at the trace to see why it's slower, here are my initial takeaways:
- Mxfp8 grouped GEMMs are 1.57x faster than bf16 grouped GEMMs, but should be ~2.1x faster (mxfp8 avg = ~700us, bf16 avg = ~1.1ms). This contradicts the grouped GEMM microbenchmarking we did, which indicates that for shape M=16640, G=4, K=5120, N=8192 we should be getting a ~2.1x speedup, yet we are only getting ~1.57x. We need to determine what is going on here. cc @cthi @slayton58 who may be able to help with this.
- Blocked scale swizzling kernels look okay: 2 of the 3 handwritten triton kernels for converting scales to blocked swizzled format are very fast. One may need more optimization (TBD):
  - `triton_scale_swizzle_M_groups` = 18us avg
  - `triton_scale_swizzle_K_groups` = 14us avg
  - `triton_scale_swizzle_per_group_3d` = ~250us avg (not surprising it's longer, since the 3d tensor is much more data than the 2d activations; will get some mem bw benchmarks on these kernels though)
- Mxfp8 dim1 cast CUDA kernel looks good: ~100us average, used on 2d RHS operands. This kernel achieves ~5300 GB/s mem bw, around 66% of the peak 8 TB/s, so it can potentially still be improved.
- Torch inductor codegen kernels are pretty slow (worst offender is ~1.2ms, longer than the grouped GEMM itself). This is probably largely due to a stray `.contiguous()` call I was forced to add by `to_mx` API limitations, but I'm guessing it's also that inductor codegen is slow, as it has been historically for various cases in fp8 rowwise, fp8 blockwise, and mxfp8. So we should try to get the mxfp8 dim1 cast CUDA kernel working for 3d tensors, so we get the double win of avoiding the `.contiguous()` calls and using the faster kernel, which achieves higher mem bw utilization than torch.compile / triton. I can do this next, or perhaps @slayton58 may be interested in this. (Perhaps we can just reshape 3d->2d, use the existing kernel, then reshape back to 3d? Need to think about this; see the sketch below.)
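A rough sketch of that reshape idea, under the assumption that the existing 2d dim1 cast kernel can be wrapped as a callable (`dim1_cast_2d` below is a hypothetical stand-in, not the actual kernel binding):

```python
import torch

def dim1_cast_3d_via_2d(x_3d: torch.Tensor, dim1_cast_2d):
    """Sketch: apply an existing 2d dim1 cast kernel to a 3d [E, M, K] tensor by
    flattening the expert dimension, then restoring it afterwards.

    dim1_cast_2d is a hypothetical stand-in that returns (quantized_data, scales)
    for a 2d input.
    """
    e, m, k = x_3d.shape
    # Collapse experts into rows; no copy is made if x_3d is already contiguous.
    x_2d = x_3d.reshape(e * m, k)
    data_2d, scales_2d = dim1_cast_2d(x_2d)
    # Restore the expert dimension on the quantized data. Whether the per-group
    # scale layout stays valid after flattening (e.g. scale blocks straddling
    # expert boundaries when M is not a multiple of the block size) is exactly
    # what still needs to be thought through.
    return data_2d.reshape(e, m, k), scales_2d
```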
Torch inductor codegen kernels are pretty slow (worst offender is ~1.2ms, longer than the grouped GEMM itself). This is probably largely due to stray *.contiguous() call I was forced to do by to_mx API limitations but I'm guessing it's also that inductor codegen is slow, as it has been historically for various cases in fp8 rowwise, fp8 blockwise, and mxfp8. So we should try to get the mxfp8 dim1 cuda kernel working for 3d tensors so we can get the double win of avoiding the .contiguous() calls and using the faster kernel / achieves higher mem bw utilization than torch.compile / triton. I can do this next, or perhaps @slayton58 may be interested in this. (Perhaps we can just reshape 3d->2d, use the existing kernel, then reshape back to 3d? Need to think about this.)