Commit 0d3217d

[mxfp8 moe training] increase num_warps in mxfp8 a2a comms kernel (pytorch#3087)
[mxfp8 moe training] increase num_warps in mxfp8 a2a kernel
1 parent 0b96757 commit 0d3217d

1 file changed: +1 -1 lines changed
  • torchao/prototype/moe_training/kernels/mxfp8/comms.py

torchao/prototype/moe_training/kernels/mxfp8/comms.py

Lines changed: 1 addition & 1 deletion

@@ -275,7 +275,7 @@ def _mxfp8_on_device_all_to_all_v(
         world_size=input_hdl.world_size,
         BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=1,
+        num_warps=16,
     )
 
     return output
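
For context, a rough sketch of what the `num_warps` change means in thread terms. It assumes an NVIDIA GPU, where one warp is 32 threads; the helper `threads_per_program` is hypothetical and not part of torchao or Triton. It only illustrates the arithmetic and does not measure the kernel.

```python
# Assumption: NVIDIA GPU, where a warp is 32 threads.
# (Other backends may use a different warp/wavefront size.)
WARP_SIZE = 32


def threads_per_program(num_warps: int, warp_size: int = WARP_SIZE) -> int:
    """Hypothetical helper: threads cooperating on one kernel program instance."""
    return num_warps * warp_size


before = threads_per_program(1)   # launch setting before this commit
after = threads_per_program(16)   # launch setting after this commit
print(f"threads per program: {before} -> {after}")
```

Under that assumption, each program instance of the a2a kernel goes from a single warp to sixteen warps cooperating on the same work, which stays within the 1024-threads-per-block limit of current NVIDIA GPUs.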
