[mxfp8 moe training] add compile support #2990
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2990
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 10 Pending as of commit d4a26bd with merge base 66384a9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Possible wrap_triton bug:
Command:
Error:
cc @zou3519 any idea what could be going on here? This is with torch built from source on 9/10.
Update: disabling dynamic shapes resolved the error.

Disabling dynamic shapes worked as a workaround.
Only forward is getting compiled for some reason, not backward: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp7B4CM7/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
This is with torch built from source with CUDA 12.9 at 53b8bdb97774114ca02948fed47f2fd49996c564 (Sept 12th).
Do you have a function in C++ that calls Tensor.sizes()? The main fix for that error message is to change Tensor.sizes() in C++ to Tensor.sym_sizes().
lambda x: ScaledGroupedMMTensor(x, scaling_type),
out,
)
# Only wrap tensor outputs, prevent double wrapping
What would cause double wrapping to have happened previously? I would expect that as long as args_unwrapped above only consists of plain tensors, then out here should be only plain tensors as well.
Tried to repro the original issue to answer this, but now, using the latest torch nightly build, it doesn't repro anymore... good news, I guess.
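For context, here is a minimal sketch of the "only wrap plain tensors" behavior being discussed. The wrap_outputs helper name is hypothetical, and the import path for ScaledGroupedMMTensor is assumed from torchao's moe_training prototype; the point is just that values which are already subclass instances (or not tensors at all) pass through untouched, so double wrapping cannot occur:

```python
import torch
from torch.utils import _pytree as pytree

# Assumed import path, for illustration only; the class comes from the diff above.
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor


def wrap_outputs(out, scaling_type):
    # Hypothetical helper: wrap only plain torch.Tensor outputs back into the
    # subclass; anything already wrapped (or not a tensor) is left as-is.
    def maybe_wrap(x):
        if isinstance(x, torch.Tensor) and not isinstance(x, ScaledGroupedMMTensor):
            return ScaledGroupedMMTensor(x, scaling_type)
        return x

    return pytree.tree_map(maybe_wrap, out)
```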
torch.ops.aten.view.default,
torch.ops.aten.as_strided.default,
torch.ops.aten._to_copy.default,  # for *.to(dtype)
It seems a bit strange to me that _to_copy should always preserve the subclass. For example, if I do aten._to_copy(mxfp8_tensor, dtype=torch.bfloat16), I would expect the output to not be a subclass, right?
Put another way, I'd imagine that you need to carefully implement _to_copy to only preserve the subclass if we are not doing a cross-dtype copy to a higher precision.
> For example, if I do aten._to_copy(mxfp8_tensor, dtype=torch.bfloat16), I would expect the output to not be a subclass, right?

It would seem that way, but we actually do want to preserve the subclass in this case, because of how torchtitan integrates with grouped_mm here.
Basically, the kernels that torch._grouped_mm and torch._scaled_grouped_mm dispatch to only accept bf16 as the input/output dtype, yet torchtitan wants to support training in full fp32 precision, so they do this cast here.
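As a rough illustration of the behavior under discussion (greatly simplified, not the actual torchao implementation), a wrapper subclass can special-case aten._to_copy in __torch_dispatch__ so that a dtype cast such as x.to(torch.bfloat16) still returns the subclass:

```python
import torch
from torch.utils import _pytree as pytree


class GroupedMMWrapper(torch.Tensor):
    """Hypothetical, minimal wrapper subclass used only for illustration."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap to plain tensors, run the op, then decide whether to re-wrap.
        args = pytree.tree_map_only(GroupedMMWrapper, lambda t: t._data, args)
        kwargs = pytree.tree_map_only(GroupedMMWrapper, lambda t: t._data, kwargs)
        out = func(*args, **kwargs)
        if func == torch.ops.aten._to_copy.default:
            # Preserve the subclass across dtype casts (e.g. fp32 -> bf16),
            # mirroring the torchtitan/grouped_mm interaction described above.
            return pytree.tree_map_only(torch.Tensor, GroupedMMWrapper, out)
        return out
```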
torch._dynamo.config.cache_size_limit = 1000

# Workaround for https://github.com/pytorch/ao/pull/2990#issuecomment-3285762681
torch._dynamo.config.automatic_dynamic_shapes = False
Is this needed for fast e2e perf? If yes, add it to README.md?
Oh, we can remove this workaround now; it seems like with the latest nightly the issue is resolved. Updated.
So actually yes: while the bug/crash no longer occurs with dynamic shapes, perf is worse with dynamic shapes: https://www.internalfb.com/phabricator/paste/view/P1949814669
So I think we should recommend disabling them. I will make a note of this to put in the README.
No, we don't need to recommend it in the README. Benchmarks artificially hit automatic dynamic shapes because we are sweeping certain dims. In practice the dims being swept won't be dynamic, or if they are, it's up to users to decide how they want to handle this: dynamic=False with more recompiles, or the default behavior.
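For reference, the two options mentioned above look roughly like this (illustrative only; my_moe_forward is a placeholder, not code from this PR):

```python
import torch


def my_moe_forward(x: torch.Tensor) -> torch.Tensor:
    # Placeholder standing in for the real MoE forward pass.
    return x * 2


# Option 1: globally disable automatic dynamic shapes, so swept dims always
# recompile with static shapes (the workaround discussed above).
torch._dynamo.config.automatic_dynamic_shapes = False

# Option 2: request static shapes for a single compiled function; each new
# input shape triggers a recompile instead of a dynamic-shape graph.
compiled_fn = torch.compile(my_moe_forward, dynamic=False)
out = compiled_fn(torch.randn(16, 32))
```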
# Final offset is the total number of rows in the tensor
- padded_rows = output_group_start_offsets[-1]
+ padded_rows = rows + num_groups * 128
Can you write a note: padding to max_possible: ...
and then link to that note throughout the codebase? This is a pretty common pattern.
# output_group_start_offsets always starts with 0 and ends with the total number of cols
- padded_cols = output_group_start_offsets[-1]
+ padded_cols = cols + num_groups * 4
output = scales_tensor.new_empty((padded_rows, padded_cols))
ditto
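To make the "padding to max possible" pattern concrete, here is a hedged sketch of the kind of note being requested. The helper name is hypothetical, and the example sizes (M=16640, K=5120, block_size=32, 4 groups) are taken from the benchmark shapes elsewhere in this PR; the 128/4 constants mirror the diffs above:

```python
import torch


def upper_bound_padded_scale_dims(rows: int, cols: int, num_groups: int) -> tuple[int, int]:
    # Note: padding to max possible size. Each group is padded independently
    # (rows up to a multiple of 128, scale cols up to a multiple of 4, per the
    # constants in the diffs above), and the exact per-group sizes live on the
    # GPU. Reading them (e.g. output_group_start_offsets[-1], which does a
    # .item() under the hood) would force a d2h sync, so we instead allocate
    # for the worst case: each group adds at most one extra block of padding.
    padded_rows = rows + num_groups * 128
    padded_cols = cols + num_groups * 4
    return padded_rows, padded_cols


# Illustrative numbers: 16640 rows of scales, K=5120 with block_size=32 gives
# 5120 // 32 = 160 scale columns, split across 4 groups.
padded_rows, padded_cols = upper_bound_padded_scale_dims(rows=16640, cols=160, num_groups=4)
output = torch.empty((padded_rows, padded_cols), dtype=torch.uint8)  # raw e8m0 bytes, purely for illustration
```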
assert x_scales.ndim == 2, "x_scales must be 2D"
assert block_size == 32, "Only block_size=32 is supported for now"
blocked_scales_list = []
M, total_K = x_scales.shape
Note for reviewers: the changes to the torch reference implementation torch_to_blocked_2d_K_groups (to account for upper-bound based padding, i.e. total_padded_K = cols + num_groups * 4) are not working properly yet. The reference does not match the triton kernel in unit tests, but I am sure the triton kernel is correct because the e2e training tests validate correct gradient numerics. As further evidence, if I replace the triton kernel with this torch impl in the training code, the gradient numerics are garbage.
I need to figure out the correct way to represent "row of blocks"-major PER GROUP in torch-native code before landing this.
Stacked PRs:
[mxfp8 moe training] add compile support
padded_rows = output_scale_group_offsets[-1]: by looking at the trace I found this was doing a .item() under the hood and causing a d2h sync, so now I'm instead just using the upper bound of the padding needed via padded_rows = rows + num_groups * 128, avoiding the d2h sync.

Test plan
pytest test/prototype/moe_training/test_training.py

Microbenchmarks with compile
Perf analysis for M=16640, G=4, N=8192, K=5120
Looking at the trace to see why it's slower, here are my initial takeaways:
Mxfp8 grouped GEMMs are only 1.57x faster than bf16 grouped GEMMs, but should be ~2.1x faster (mxfp8 avg = ~700us, bf16 avg = ~1.1ms). This contradicts the grouped GEMM microbenchmarking we did, which indicates that for shape M=16640, G=4, K=5120, N=8192 we should be getting a ~2.1x speedup, yet we are only getting ~1.57x. We need to determine what is going on here. cc @cthi @slayton58 who may be able to help with this.
Blocked scale swizzling kernels look okay: 2 of the 3 handwritten triton kernels for converting scales to blocked swizzled format are very fast. One may need more optimization (TBD):
- triton_scale_swizzle_M_groups = 18us avg
- triton_scale_swizzle_K_groups = 14us avg
- triton_scale_swizzle_per_group_3d = ~250us avg (not surprising it's longer, since the 3d tensor is much more data than the 2d activations; will get some mem bw benchmarks on these kernels though)

Mxfp8 dim1 cast CUDA kernel looks good: ~100us average, used on 2d RHS operands. This kernel achieves ~5300gbps mem bw, around 66% of the peak 8TB/s bandwidth, so it can potentially still be improved.
Torch inductor codegen kernels are pretty slow (worst offender is ~1.2ms, longer than the grouped GEMM itself). This is probably largely due to the stray .contiguous() calls I was forced to do by to_mx API limitations, but I'm guessing it's also that inductor codegen is slow, as it has been historically for various cases in fp8 rowwise, fp8 blockwise, and mxfp8. So we should try to get the mxfp8 dim1 CUDA kernel working for 3d tensors, so we get the double win of avoiding the .contiguous() calls and using the faster kernel, which achieves higher mem bw utilization than torch.compile / triton. I can do this next, or perhaps @slayton58 may be interested in this. (Perhaps we can just reshape 3d->2d, use the existing kernel, then reshape back to 3d? Need to think about this.)
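To make the parenthetical idea above concrete, here is a speculative sketch of the reshape approach. cast_2d_kernel is a stand-in for the existing 2d mxfp8 dim1 cast CUDA kernel, and the open question from the comment above (whether per-expert scaling-group boundaries survive the flatten/unflatten) is deliberately left as a TODO:

```python
import torch


def mxfp8_dim1_cast_3d_via_2d(w: torch.Tensor, cast_2d_kernel):
    # Speculative: flatten the expert dim so the existing 2d dim1 cast kernel
    # can be reused, then restore the 3d shape.
    E, N, K = w.shape
    w_2d = w.reshape(E * N, K)  # no copy if w is already contiguous
    data_2d, scales_2d = cast_2d_kernel(w_2d)
    data_3d = data_2d.reshape(E, N, K)
    # TODO: verify that per-expert scaling-group boundaries are preserved by
    # this flatten/unflatten; the scales in particular may need per-expert
    # reshaping ("need to think about this" per the comment above).
    return data_3d, scales_2d
```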