Conversation

@manman-ren (Contributor) commented Sep 2, 2025

python run.py --op blackwell_attentions --only triton_tutorial_flash_dp_blackwell, --seq-len 1024 --batch 1152 --n-heads 4 --d-head 128

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@meta-cla meta-cla bot added the cla signed label Sep 2, 2025
@manman-ren manman-ren marked this pull request as draft September 2, 2025 20:21
@manman-ren manman-ren marked this pull request as ready for review September 10, 2025 21:38
@manman-ren manman-ren requested a review from njriasan September 10, 2025 21:38
@njriasan (Contributor) left a comment

LGTM! If you could clarify my understanding of the kernel, that would be great.

desc_q = TensorDescriptor(
q,
shape=[y_dim, HEAD_DIM_K],
strides=[HEAD_DIM_K, 1],

Contributor:
This assumes that q, k, and v are all contiguous and not transposed. Can we add an assert to enforce this requirement?
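
A minimal sketch of the kind of assert being requested, assuming q, k, and v are torch tensors that get viewed as [y_dim, HEAD_DIM_K]; the exact check and where it lives in the setup code are up to the author:

    # Hypothetical host-side check; names mirror the snippet above.
    for name, t in (("q", q), ("k", k), ("v", v)):
        assert t.is_contiguous(), (
            f"{name} must be contiguous (not transposed) for TMA strides [HEAD_DIM_K, 1]"
        )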

strides=[HEAD_DIM_K, 1],
block_shape=dummy_block,
)
else:

Contributor:
Just error on the else case? The kernel requires TMA

Contributor Author:
This is to support on-device TMA. I will need to call _maybe_make_tensor_desc.
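
For reference, a hedged sketch of such a helper, modeled on the pattern used in the upstream FA tutorial; the exact signature is an assumption, and tl.make_tensor_descriptor needs a recent Triton:

    import triton
    import triton.language as tl

    @triton.jit
    def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
        if isinstance(desc_or_ptr, tl.tensor_descriptor):
            # Host-side TMA: a TensorDescriptor was already built on the host.
            return desc_or_ptr
        else:
            # On-device TMA: build the descriptor inside the kernel from a raw pointer.
            return tl.make_tensor_descriptor(desc_or_ptr, shape=shape,
                                             strides=strides, block_shape=block_shape)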

OUTER_LOOP: tl.constexpr,
dtype: tl.constexpr,
):
n_tile_num = tl.cdiv(N_CTX, BLOCK_M)

Contributor:
Is this correct? This stands out to me as a possible typo, possibly just in the variable name.

Contributor Author:
Yeah, we should rename n_tile_num to num_pid_m.
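
A hedged sketch of the surrounding persistent-scheduling bookkeeping with the suggested rename; Z and H (batch and head counts) are assumed kernel arguments here, and this mirrors the common persistent-kernel pattern rather than the exact code in this PR:

    num_pid_m = tl.cdiv(N_CTX, BLOCK_M)        # number of query tiles per (batch, head)
    prog_id = tl.program_id(0)
    num_progs = tl.num_programs(0)             # persistent launch: one program per SM
    total_tiles = num_pid_m * Z * H            # tiles across all batches and heads
    tiles_per_sm = total_tiles // num_progs
    if prog_id < total_tiles % num_progs:
        tiles_per_sm += 1                      # distribute the remainder tiles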


tile_idx = prog_id
# inner loop warpspec vs. outer loop warpspec
for _ in tl.range(0, tiles_per_sm, warp_specialize=warp_specialize and OUTER_LOOP):

Contributor:
For my understanding, why don't we need FLATTEN=True here? Is it not viable with FA because it's too complicated and we actually need more complex AutoWS in the compiler?

Contributor Author:
FLATTEN only works on very simple cases such as GEMM. For FA, FLATTEN doesn't work and we will need to handle nested control flow. In OSS, NVIDIA is driving that work.

Contributor:
Thanks for clarifying/confirming my understanding.

offs_m1 = start_m * BLOCK_M + tl.arange(BLOCK_M//2, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)

m_i0 = tl.zeros([BLOCK_M//2], dtype=tl.float32) - float("inf")

Contributor:
To clarify, this is the explicit data partitioning to enable subtiling + ping-pong?

Contributor Author:
This is to set up the outputs for data partitioning; we will call _attn_fwd_inner_oss_dp twice, each call working on one half of the full block size. One data partition works on q0 and the other on q1, where q0 and q1 together make up the original q.
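
A hedged sketch of the setup being described; apart from the names quoted in the diff and the reply (offs_m1, m_i0, _attn_fwd_inner_oss_dp), the variable names, HEAD_DIM, and the inner-loop call shape are assumptions:

    # Each partition owns half of the BLOCK_M query rows.
    offs_m0 = start_m * BLOCK_M + tl.arange(0, BLOCK_M // 2)          # rows for q0
    offs_m1 = start_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M)    # rows for q1

    # Independent online-softmax state per partition.
    m_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
    m_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
    l_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
    l_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
    acc0 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)
    acc1 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)

    # _attn_fwd_inner_oss_dp is then invoked once per partition, roughly:
    #   acc0, l_i0, m_i0 = _attn_fwd_inner_oss_dp(acc0, l_i0, m_i0, q0, ...)
    #   acc1, l_i1, m_i1 = _attn_fwd_inner_oss_dp(acc1, l_i1, m_i1, q1, ...)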

BN: tl.constexpr = acc.shape[1]

acc0, acc1 = acc.reshape([BM, 2, BN//2]).permute(0, 2, 1).split()
acc0 = acc0 * alpha[:, None]

Contributor:
This is another form of partitioning/subtiling?

Contributor Author:
Yes, this is subtiling for the correction step.
The FA tutorial has this too: https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py#L88
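
For reference, the full correction step in the linked tutorial looks roughly like the sketch below, where the tl.join/reshape at the end reassembles the halves; treat this as an illustration rather than the exact code in this PR:

    BM: tl.constexpr = acc.shape[0]
    BN: tl.constexpr = acc.shape[1]
    # Split the accumulator into two column halves so the rescale runs on
    # smaller subtiles, then stitch the halves back together.
    acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
    acc0 = acc0 * alpha[:, None]    # apply the online-softmax correction factor
    acc1 = acc1 * alpha[:, None]
    acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])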
