adding variable length attention to llama3 8b #2000
Conversation
    self._sample_idx = 0
    self._token_buffer: list[int] = []
    self._boundary_buffer: list[int] = [0]
Can you add a comment on `boundary_buffer` explaining why it's needed?
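For illustration, a hedged sketch of the kind of comment being asked for. The class name and surrounding structure are stand-ins; only the three buffer attributes come from the diff above, and the rationale in the comments is an assumption about the varlen packing logic, not text from this PR.

```python
# Sketch only: `PackedTextDataset` is a stand-in name, not the PR's class.
class PackedTextDataset:
    def __init__(self, use_varlen_attn: bool = False):
        self.use_varlen_attn = use_varlen_attn
        self._sample_idx = 0
        # Tokens accumulated across samples until a full training window
        # can be emitted.
        self._token_buffer: list[int] = []
        # Cumulative token offsets marking where each packed sample begins.
        # Seeded with [0] so consecutive entries (start, end) always bracket
        # one sample; the varlen path needs these boundaries to derive
        # cu_seqlens and keep attention from crossing document boundaries.
        self._boundary_buffer: list[int] = [0]
```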
| yield {"input": input}, label | ||
|
|
||
| if self.use_varlen_attn: | ||
| boundaries_in_window = [ |
Also, maybe factor this into a function that gets called here, so you can document it better? A rough sketch of that refactor follows.
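A hedged sketch of the suggested helper. The function name, signature, and re-basing behavior are assumptions used for illustration, not code from this PR.

```python
# Illustrative sketch of the suggested refactor -- the helper name and
# signature are assumptions, not the PR's actual implementation.
def _boundaries_in_window(
    boundary_buffer: list[int], window_start: int, window_end: int
) -> list[int]:
    """Return the sample boundaries that fall inside [window_start, window_end),
    re-based so they are relative to the start of the window."""
    return [
        b - window_start
        for b in boundary_buffer
        if window_start <= b < window_end
    ]
```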
fegin left a comment
This implementation won't work with PP and is too intrusive to the model code. The packing logic should be hidden inside the inner attention.
    return path, config.loader, config.sample_processor


def varlen_collate_fn(batch):
This should not be done in the dataloader. If you always pack the input batch down to batch size 1, then pipeline parallelism won't work. You should instead perform the packing inside the inner attention, using the mask (namedtuple) data (see below) to pack to what you need.
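To make the suggestion concrete, a rough sketch of packing inside the inner attention rather than in the collate function. The shapes, the helper name, and the assumption that the per-document lengths sum to the flattened window are all illustrative, not code from this PR.

```python
import torch


def pack_for_varlen(q, k, v, seq_lens):
    """Flatten [batch, heads, seq, dim] q/k/v into packed varlen form.

    Returns packed (total_tokens, heads, dim) tensors plus the cu_seqlens and
    max_seqlen metadata a varlen kernel (such as the PR's compiled
    `varlen_attn`) would consume. Assumes sum(seq_lens) == batch * seq.
    """
    bsz, n_heads, seqlen, head_dim = q.shape

    def flatten(t):
        # [batch, heads, seq, dim] -> [batch * seq, heads, dim]
        return t.transpose(1, 2).reshape(bsz * seqlen, n_heads, head_dim)

    # Cumulative document boundaries over the flattened token dimension,
    # mirroring the cu_seq_q / cu_seq_k entries in the collate output below.
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device=q.device)
    cu_seqlens[1:] = torch.cumsum(
        torch.as_tensor(seq_lens, dtype=torch.int32, device=q.device), dim=0
    )
    max_seqlen = int(max(seq_lens))
    return flatten(q), flatten(k), flatten(v), cu_seqlens, max_seqlen
```

Doing this inside attention keeps the batch dimension intact for the rest of the model, which is what lets pipeline parallelism keep working.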
    return {
        "input": packed_input,
        "cu_seq_q": packed_cu_seqlens,
        "cu_seq_k": packed_cu_seqlens,
        "max_q": max_seqlen,
        "max_k": max_seqlen,
    }, packed_label
You should follow how we create BlockMask, by letting the model provide the attention mask. You can extend AttentionMasksType and https://github.com/pytorch/torchtitan/blob/main/torchtitan/protocols/model.py#L64. You can use a namedtuple for this.
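A hedged sketch of that direction. The NamedTuple fields mirror the keys in the collate output above, but the class name and the exact way AttentionMasksType would be extended are assumptions, not torchtitan's actual definitions.

```python
from typing import NamedTuple

import torch


class VarlenMetadata(NamedTuple):
    """Per-batch varlen metadata produced by the model, alongside BlockMask."""
    cu_seq_q: torch.Tensor  # [num_seqs + 1] cumulative query sequence starts
    cu_seq_k: torch.Tensor  # [num_seqs + 1] cumulative key sequence starts
    max_q: int              # longest query segment in the packed batch
    max_k: int              # longest key segment in the packed batch


# AttentionMasksType might then be widened to accept this, e.g.:
# AttentionMasksType = BlockMask | dict[str, BlockMask] | VarlenMetadata
```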
Summary
This PR adds variable length attention (varlen) support to the Llama 3 8B model in torchtitan. We add a flag `use_varlen_attn` to the model config; if it is set to True, the attention module calls a compiled `varlen_attn` defined here.
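A minimal sketch of the flag wiring described above, assuming a simplified config dataclass and dispatch function; everything except the `use_varlen_attn` flag itself is an illustrative stand-in rather than torchtitan's actual code.

```python
from dataclasses import dataclass

import torch.nn.functional as F


@dataclass
class AttentionConfig:
    use_varlen_attn: bool = False  # the flag this PR adds to the model config


def dispatch_attention(cfg, q, k, v, varlen_kernel=None, varlen_metadata=None):
    """Route to the compiled varlen kernel when the flag is on, else dense SDPA."""
    if cfg.use_varlen_attn and varlen_kernel is not None:
        # In the PR this would be the compiled `varlen_attn`; it is passed in
        # as a callable here so the sketch stays self-contained.
        return varlen_kernel(q, k, v, varlen_metadata)
    # Default dense path (flex attention / SDPA in the real model).
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```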
Testing
Ran loss and performance tests against flex attention. Loss is on par. Varlen is slightly slower than Flex due to CUDA kernel speeds (varlen calls into `flash_attention_forward`/`flash_attention_backward` today).