Conversation

@liangel-02 (Contributor) commented Nov 7, 2025

Summary
This PR adds variable-length attention (varlen) support to the Llama 3 8B model in torchtitan. We add a use_varlen_attn flag to the model config; when it is set to True, the attention module calls a compiled varlen_attn defined here.
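For illustration, here is a minimal sketch (not the PR's code) of how an attention forward might branch on use_varlen_attn. The varlen_attn callable and its (q, k, v, cu_seq_q, cu_seq_k, max_q, max_k) argument order are assumptions inferred from the metadata keys that appear later in this PR.

```python
# Hypothetical dispatch between varlen attention and SDPA; illustrative only.
# `varlen_attn` is passed in as a callable; its signature is an assumption.
import torch.nn.functional as F


def attention_forward(q, k, v, *, use_varlen_attn=False, varlen_meta=None, varlen_attn=None):
    if use_varlen_attn and varlen_meta is not None and varlen_attn is not None:
        # Varlen path: samples are packed along the sequence dimension and
        # delimited by cumulative sequence lengths (cu_seqlens).
        return varlen_attn(
            q, k, v,
            varlen_meta["cu_seq_q"], varlen_meta["cu_seq_k"],
            varlen_meta["max_q"], varlen_meta["max_k"],
        )
    # Default path: standard causal scaled dot-product attention.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```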

Testing
Ran loss and performance tests against flex attention. Loss is on par.

[Screenshot: loss curves comparing varlen and flex attention]

Varlen is slightly slower than flex due to CUDA kernel speeds (varlen currently calls into flash_attention_forward/flash_attention_backward).

|          | Varlen          | Flex            |
|----------|-----------------|-----------------|
| Forward  | 774us 357ns     | 722us 317ns     |
| Backward | 1ms 955us 916ns | 1ms 558us 747ns |

@meta-cla bot added the CLA Signed label (Nov 7, 2025)
@liangel-02 force-pushed the test_varlen branch 2 times, most recently from a96be88 to eeecb63 (Nov 12, 2025, 22:42)
@liangel-02 changed the title from "Test varlen" to "adding variable length attention to llama 3 8b" (Nov 12, 2025)
@liangel-02 changed the title from "adding variable length attention to llama 3 8b" to "adding variable length attention to llama3 8b" (Nov 12, 2025)
@liangel-02 requested a review from drisspg (Nov 12, 2025, 23:18)
self._sample_idx = 0
self._token_buffer: list[int] = []

self._boundary_buffer: list[int] = [0]
Contributor commented on the lines above:

Can you add a comment on boundary_buffer and why it's needed?
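For example, the kind of documentation being requested might look like the sketch below; the exact semantics of _boundary_buffer in the PR are an assumption (cu_seqlens-style cumulative offsets), and the class name is hypothetical.

```python
# Illustrative only; names and semantics are assumptions, not the PR's code.
class PackedTokenBuffer:
    """Accumulates tokens from multiple samples plus their boundaries."""

    def __init__(self) -> None:
        self.tokens: list[int] = []
        # Cumulative offsets at which each packed sample ends inside `tokens`.
        # Seeded with 0 so consecutive pairs give (start, end) per sample,
        # i.e. a cu_seqlens-style layout for variable-length attention.
        self.boundaries: list[int] = [0]

    def add_sample(self, sample: list[int]) -> None:
        self.tokens.extend(sample)
        self.boundaries.append(len(self.tokens))
```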

yield {"input": input}, label

if self.use_varlen_attn:
boundaries_in_window = [
Contributor commented on the lines above:

Also, maybe make this a function that gets called, so you can better document it?
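A hedged sketch of the suggested refactor; the helper name and exact windowing semantics are assumptions, not the PR's code.

```python
def boundaries_in_window(boundaries: list[int], start: int, end: int) -> list[int]:
    """Return sample boundaries that fall inside the token window [start, end),
    re-based to the window start. 0 and the window length are always included
    so the result is a valid cu_seqlens for variable-length attention.
    Hypothetical helper: name and behavior are assumptions."""
    inner = [b - start for b in boundaries if start < b < end]
    return [0] + inner + [end - start]
```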

@fegin (Contributor) left a review:

This implementation won't work with PP and is too model-intrusive. The packing logic should be hidden inside the inner attention.

return path, config.loader, config.sample_processor


def varlen_collate_fn(batch):
Contributor commented on varlen_collate_fn:

This should not be done in the dataloader. If you always pack the input batch to batch size 1, then pipeline parallelism won't work. You should instead perform the packing inside the inner attention, using the mask (namedtuple) data (see below), to pack to whatever layout you need.
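As one way to realize this suggestion, here is a hedged sketch of deriving the varlen metadata inside the model from per-sample boundaries rather than collating to batch size 1; the function name and input convention are assumptions.

```python
import torch


def cu_seqlens_from_boundaries(boundaries: torch.Tensor) -> tuple[torch.Tensor, int]:
    """boundaries: 1D tensor of cumulative sample offsets, e.g. [0, 128, 300, 512].
    Returns (int32 cu_seqlens on the same device, longest segment length).
    Illustrative only; not the PR's code."""
    cu = boundaries.to(torch.int32)
    max_len = int((cu[1:] - cu[:-1]).max().item())
    return cu, max_len
```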

Comment on lines +119 to +125
return {
"input": packed_input,
"cu_seq_q": packed_cu_seqlens,
"cu_seq_k": packed_cu_seqlens,
"max_q": max_seqlen,
"max_k": max_seqlen,
}, packed_label
Contributor commented:

You should follow how we create BlockMask, by letting the model provide the attention mask. You can extend AttentionMasksType and https://github.com/pytorch/torchtitan/blob/main/torchtitan/protocols/model.py#L64. You can use a namedtuple for this.
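A hedged sketch of the suggested namedtuple: the field names mirror the dict keys above, and extending AttentionMasksType this way is an assumption about how torchtitan could be adapted, not its actual API.

```python
from typing import NamedTuple

import torch


class VarlenMetadata(NamedTuple):
    """Illustrative varlen analogue of BlockMask; names are assumptions."""
    cu_seq_q: torch.Tensor  # cumulative query lengths, shape (num_seqs + 1,), int32
    cu_seq_k: torch.Tensor  # cumulative key lengths, shape (num_seqs + 1,), int32
    max_q: int              # longest query segment in the packed batch
    max_k: int              # longest key segment in the packed batch

# AttentionMasksType could then be extended to something like:
#   AttentionMasksType = BlockMask | VarlenMetadata | dict[str, BlockMask]
```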
