[mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise #3002
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3002.
✅ No failures as of commit bb8b07f with merge base f75b251. (This comment was automatically generated by Dr. CI.)
@slayton58 @ngimel I would be curious to get your thoughts on ways to improve this kernel for quantizing 3d expert weights (E,N,K) along the N dim, where the weights are contiguous. It uses nearly identical logic to the 2d dim1 cast kernel (which achieves ~85% memory bandwidth utilization), yet the perf is much worse (~8% to 62% of peak memory bandwidth, depending on input size; see benchmarks in the PR description). I think the culprit might be how I'm allocating all the TMA descriptors and passing them in; the overhead might be too much for small E? NCU has not flagged anything particularly helpful so far. Strangely, for E=2 it shows the kernel is compute bound, with 78% compute throughput and 38% memory bandwidth utilization.

Additional context: torch.compile and handwritten Triton kernels were both slow for mxfp8 quantization of RHS operands, where we scale colwise with 32x1 granularity (e.g., Triton hit 56% of peak memory bandwidth). So I added a CUDA kernel derived from a TE kernel, which achieves ~85% of peak memory bandwidth (#2513). Basically, we stripped out internal TE types, added support for different scale calculation modes (floor, rceil) to align with torchao numerics, then resolved some perf issues resulting from those changes to get reasonable perf.

Now I'm finding that quantizing 3d expert weights along dim1 scales extremely poorly as the number of experts increases (see this PR's description for details, and #2999 for benchmarks). So I added a similar CUDA kernel to our mxfp8_cuda extension specifically for quantizing 3d expert weights colwise and writing directly to the col-major format we need it in.

The first approach I tried was updating the 2d kernel to handle 3d tensors by treating them as 2d tensors of shape (E*N, K), but the coordinate mapping / pointer arithmetic became a complicated mess that wasn't working. So I made a new kernel, similar to the 2d kernel, that passes in separate input/output TMA descriptors for each expert; the kernel then operates on each 2d expert with logical separation, in parallel.
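For reference, here is a minimal pure-PyTorch sketch of what the colwise quantization computes. This is illustrative only, not the extension's actual code; the floor-mode scale recipe and the saturation behavior below are simplified assumptions, and the function name is made up:

```python
import torch

def mxfp8_quantize_colwise_ref(w: torch.Tensor, block: int = 32):
    # Illustrative reference only -- NOT the mxfp8_cuda extension code.
    # Quantizes (E, N, K) expert weights along N (dim1) with 32x1 scale
    # granularity, using a simplified "floor" scale-calculation mode.
    E, N, K = w.shape
    assert N % block == 0
    # Group dim1 into 32-element blocks: (E, N//block, block, K).
    blocks = w.float().view(E, N // block, block, K)
    amax = blocks.abs().amax(dim=2, keepdim=True).clamp(min=2.0**-127)
    # Floor mode: power-of-two scale from the block amax exponent,
    # shifted by the e4m3 max exponent (8), per the MX recipe.
    scales = torch.exp2(torch.floor(torch.log2(amax)) - 8)
    q = (blocks / scales).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    # The CUDA kernel additionally writes q directly in the col-major
    # layout needed downstream; that detail is omitted here.
    return q.view(E, N, K), scales.squeeze(2)
```

The CUDA kernel fuses this across experts in a single launch (one pair of TMA descriptors per expert) instead of materializing any intermediate tensors.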
Stacked PRs:
- [mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise

Summary

Avoids `.contiguous()` calls: `to_mx` only scales along the last dim and requires contiguous inputs, so scaling along the N dim (needed for backwards) previously required transposing the contiguous (E,N,K) tensor to (E,K,N) and then calling `.contiguous()`. A sketch of that workaround is shown below.
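For illustration, a sketch of the workaround path this kernel replaces. `to_mx` is the real torchao op, but the import path and call signature below are simplified assumptions and may differ from the actual API:

```python
import torch
# Assumed import path, for illustration; the real location may differ.
from torchao.prototype.mx_formats.mx_tensor import to_mx

E, N, K = 8, 8192, 5120
w = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)

# to_mx scales along the last dim and needs contiguous input, so scaling
# along N forces a transpose plus a full materializing copy:
w_t = w.transpose(-2, -1).contiguous()  # (E, K, N): extra global-memory round trip
scale, w_mx = to_mx(w_t, torch.float8_e4m3fn, block_size=32)  # signature simplified

# The new CUDA kernel instead quantizes (E, N, K) along N directly and
# writes the col-major output layout we need, with no .contiguous() copy.
```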
Test plan

Kernel microbenchmarks
Perf is decent for large E and abysmal for small E. Need to investigate this.
Update (9/15): NCU shows the 3d kernel operating on a (2,8192,5120) tensor is actually compute bound (??)