[WIP][FA][Blackwell] Implementation with explicit data partitioning #384
base: main
Conversation
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
LGTM! If you could clarify my understanding of the kernel that would be great.
desc_q = TensorDescriptor(
    q,
    shape=[y_dim, HEAD_DIM_K],
    strides=[HEAD_DIM_K, 1],
This assumes that q, k, and v are all contiguous and not transposed. Can we add an assert to enforce this requirement?
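A minimal sketch of such a check, assuming it lives in the host-side wrapper right before the descriptors are built (the placement and message strings are illustrative, not this patch):

    # Hedged sketch: reject transposed / non-contiguous inputs before building TMA descriptors.
    for name, t in (("q", q), ("k", k), ("v", v)):
        assert t.is_contiguous(), f"{name} must be contiguous for the TMA descriptor path"
        assert t.stride(-1) == 1, f"{name} must have unit stride in the head dimension"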
    strides=[HEAD_DIM_K, 1],
    block_shape=dummy_block,
)
else:
Just error on the else case? The kernel requires TMA
This is to support on-device TMA. I will need to call _maybe_make_tensor_desc.
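For context, a hedged sketch of what the on-device path could look like; tl.make_tensor_descriptor is the device-side TMA descriptor API in recent Triton, and the helper's signature below follows the upstream FA tutorial rather than this patch:

    # Hedged sketch (signature assumed from the upstream tutorial, not this PR):
    @triton.jit
    def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
        if isinstance(desc_or_ptr, tl.tensor_descriptor):
            return desc_or_ptr  # host already built a TensorDescriptor
        # otherwise build the descriptor on device (on-device TMA)
        return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)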
    OUTER_LOOP: tl.constexpr,
    dtype: tl.constexpr,
):
    n_tile_num = tl.cdiv(N_CTX, BLOCK_M)
Is this correct? It stands out to me as a possible typo, perhaps just in the variable name.
Yeah, we should rename n_tile_num to num_pid_m.
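For reference, a hedged sketch of the persistent scheduling this value feeds into, following the upstream FA tutorial pattern (Z, H, and the prologue names other than those visible in the diff are assumptions, not this patch):

    # Hedged sketch: num_pid_m is the per-(batch, head) count of M tiles,
    # which is why the rename from n_tile_num reads more clearly.
    num_pid_m = tl.cdiv(N_CTX, BLOCK_M)
    prog_id = tl.program_id(0)
    num_progs = tl.num_programs(0)
    total_tiles = num_pid_m * Z * H          # Z = batch size, H = number of heads
    tiles_per_sm = total_tiles // num_progs
    if prog_id < total_tiles % num_progs:
        tiles_per_sm += 1
    tile_idx = prog_id                       # each program then strides by num_progs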
    tile_idx = prog_id
    # inner loop warpspec vs. outer loop warpspec
    for _ in tl.range(0, tiles_per_sm, warp_specialize=warp_specialize and OUTER_LOOP):
For my understanding, why don't we need FLATTEN=True here? Is it not viable with FA because it's too complicated, and we actually need more complex AutoWS in the compiler?
FLATTEN only works for very simple cases such as GEMM. For FA, FLATTEN doesn't work and we will need to handle nested control flow. In OSS, NVIDIA is driving that work.
Thanks for clarifying/confirming my understanding.
    offs_m1 = start_m * BLOCK_M + tl.arange(BLOCK_M//2, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)

    m_i0 = tl.zeros([BLOCK_M//2], dtype=tl.float32) - float("inf")
To clarify, is this the explicit data partitioning that enables subtiling + ping-pong?
This is to set up the outputs for data partitioning; we will call _attn_fwd_inner_oss_dp twice, each call working on one half of the full block size.
We will have one data partition working on q0 and the other working on q1, where q0 + q1 is the original q.
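A hedged sketch of that setup, following the naming visible in the snippet above (the accumulator names, HEAD_DIM, and the exact call sites are assumptions):

    # Hedged sketch: split the M block into two halves so the two data partitions
    # (q0 and q1) can proceed independently and ping-pong.
    offs_m0 = start_m * BLOCK_M + tl.arange(0, BLOCK_M // 2)        # rows handled by partition 0 (q0)
    offs_m1 = start_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M)  # rows handled by partition 1 (q1)
    m_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
    m_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
    l_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
    l_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
    acc0 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)
    acc1 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)
    # _attn_fwd_inner_oss_dp is then invoked twice, once per (acc, l_i, m_i, offs_m) set.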
    BN: tl.constexpr = acc.shape[1]

    acc0, acc1 = acc.reshape([BM, 2, BN//2]).permute(0, 2, 1).split()
    acc0 = acc0 * alpha[:, None]
This is another form of partitioning/subtiling?
Yes, this is subtiling for correction.
FA tutorial has this too: https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py#L88
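The tutorial pattern being referenced, as a hedged sketch (the rejoin at the end is how the tutorial stitches the halves back together; the exact surrounding context in this patch may differ):

    # Hedged sketch of subtiling the correction step: split acc along the head
    # dimension, rescale each half by alpha, then join the halves back.
    BM: tl.constexpr = acc.shape[0]
    BN: tl.constexpr = acc.shape[1]
    acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
    acc0 = acc0 * alpha[:, None]
    acc1 = acc1 * alpha[:, None]
    acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])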
python run.py --op blackwell_attentions --only triton_tutorial_flash_dp_blackwell, --seq-len 1024 --batch 1152 --n-heads 4 --d-head 128