
Commit 173699e

rohan-varma authored and facebook-github-bot committed
Add MultiHeadAttentionWithCache for self attention with cache fast path (facebookresearch#449)
Summary:
Pull Request resolved: facebookresearch#449

When doing self-attention, an optimization is to combine the Q, K, V input projection matrices and perform a single matmul instead of three. This change adds that optimization for MHA with cache in a new module, `MultiHeadSelfAttentionWithCache`. Note: we primarily use a new module to avoid breaking checkpoint backward compatibility (BC) with respect to `MultiHeadAttentionWithCache`. In the future, we should consolidate these MHA implementations.

Differential Revision: D48418780

fbshipit-source-id: 0b20fb807527109a9a3ad419805e47e0f9ba2c74
1 parent 951a452 · commit 173699e
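For context on the optimization the summary describes, here is a minimal sketch (not part of this commit) of fusing the Q, K, V input projections: a single Linear of width 3 * embed_dim replaces three separate Linear layers, so one matmul produces all three projections, which are then split with chunk. The names and shapes below are illustrative assumptions, not code from this repository.

import torch
from torch import nn

embed_dim = 8
x = torch.randn(2, 4, embed_dim)  # bsz x seq_len x embed_dim (illustrative shapes)

# Unfused path: three separate projections, i.e. three matmuls.
q_proj, k_proj, v_proj = (nn.Linear(embed_dim, embed_dim) for _ in range(3))
q, k, v = q_proj(x), k_proj(x), v_proj(x)

# Fused path: one projection of width 3 * embed_dim, then split the result.
in_proj = nn.Linear(embed_dim, 3 * embed_dim)
q_f, k_f, v_f = in_proj(x).chunk(3, dim=-1)  # each is bsz x seq_len x embed_dim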

File tree: 2 files changed, +149 −10 lines

tests/modules/layers/test_multi_head_attention.py (63 additions, 10 deletions)

@@ -11,6 +11,7 @@
 from torchmultimodal.modules.layers.multi_head_attention import (
     MultiHeadAttentionWithCache,
     MultiHeadSelfAttention,
+    MultiHeadSelfAttentionWithCache,
 )


@@ -103,6 +104,13 @@ def multi_head_self_attn_use_cache(self, dim_q):
         mha.eval()
         return mha

+    @pytest.fixture
+    def multi_head_self_attn_module_with_cache(self, dim_q):
+        mha = MultiHeadSelfAttentionWithCache(dim_q, num_heads=2)
+        init_weights_with_constant(mha)
+        mha.eval()
+        return mha
+
     @pytest.fixture
     def multi_head_cross_attn(self, dim_q, dim_kv):
         mha = MultiHeadAttentionWithCache(dim_q, dim_kv, num_heads=2)

@@ -117,16 +125,7 @@ def multi_head_cross_attn_without_bias(self, dim_q, dim_kv):
         mha.eval()
         return mha

-    def test_multi_head_self_attention_use_cache(
-        self,
-        multi_head_self_attn_use_cache,
-        current_key_value,
-        past_key_value,
-        q,
-    ):
-        actual = multi_head_self_attn_use_cache(
-            q, q, q, past_key_value=(past_key_value, past_key_value), use_cache=True
-        )
+    def _assert_mha_self_attn_equal(self, actual, past_key_value, current_key_value):
         expected = torch.tensor(
             [
                 [

@@ -138,6 +137,7 @@ def test_multi_head_self_attention_use_cache(
         )
         assert_expected(actual.attn_output, expected, rtol=0, atol=1e-4)
         # Check that the cache is properly updated
+        torch.cat([past_key_value, current_key_value], dim=2)
         assert_expected(
             actual.past_key_value[0],
             torch.cat([past_key_value, current_key_value], dim=2),

@@ -147,6 +147,59 @@
             torch.cat([past_key_value, current_key_value], dim=2),
         )

+    def test_multi_head_self_attention_use_cache(
+        self,
+        multi_head_self_attn_use_cache,
+        current_key_value,
+        past_key_value,
+        q,
+    ):
+        actual = multi_head_self_attn_use_cache(
+            q, q, q, past_key_value=(past_key_value, past_key_value), use_cache=True
+        )
+        self._assert_mha_self_attn_equal(actual, past_key_value, current_key_value)
+
+    def test_multi_head_self_attn_module_with_cache(
+        self,
+        multi_head_self_attn_module_with_cache,
+        current_key_value,
+        past_key_value,
+        q,
+    ):
+        actual = multi_head_self_attn_module_with_cache(
+            q, past_key_value=(past_key_value, past_key_value), use_cache=True
+        )
+        self._assert_mha_self_attn_equal(actual, past_key_value, current_key_value)
+
+    def test_multi_head_attention_with_cache_modules_equal(
+        self,
+        multi_head_self_attn_use_cache,
+        multi_head_self_attn_module_with_cache,
+        current_key_value,
+        past_key_value,
+        q,
+    ):
+        mha_with_cache_cls_output = multi_head_self_attn_use_cache(
+            q, q, q, past_key_value=(past_key_value, past_key_value), use_cache=True
+        )
+        sa_with_cache_cls_output = multi_head_self_attn_module_with_cache(
+            q, past_key_value=(past_key_value, past_key_value), use_cache=True
+        )
+        assert_expected(
+            mha_with_cache_cls_output.attn_output,
+            sa_with_cache_cls_output.attn_output,
+            rtol=0,
+            atol=1e-4,
+        )
+        assert_expected(
+            mha_with_cache_cls_output.past_key_value[0],
+            sa_with_cache_cls_output.past_key_value[0],
+        )
+        assert_expected(
+            mha_with_cache_cls_output.past_key_value[1],
+            sa_with_cache_cls_output.past_key_value[1],
+        )
+
     def test_multi_head_cross_attention(self, multi_head_cross_attn, q, kv):
         actual = multi_head_cross_attn(q, kv, kv)
         expected = torch.tensor(

torchmultimodal/modules/layers/multi_head_attention.py (86 additions, 0 deletions)

@@ -77,6 +77,92 @@ def forward(
         return attn_out


+class MultiHeadSelfAttentionWithCache(nn.Module):
+    """
+    MultiHeadAttention module for self-attention (SA). Similar to MultiHeadAttentionWithCache,
+    but only supports self-attention and uses a fast path where the query, key, and value projections
+    are batched into a single matrix multiplication instead of three separate matmuls.
+    This class supports a cache mechanism for decoders to store previous states through
+    "past_key_value".
+
+    Args:
+        embed_dim (int): query, key, value embedding dimension
+        num_heads (int): number of attention heads
+        dropout (float): dropout rate
+        add_bias (bool): if ``True``, adds a learnable bias to the query, key, value input projection matrix.
+            Defaults to True.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        add_bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.input_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=add_bias)
+        self.output_proj = nn.Linear(embed_dim, embed_dim)
+        self.dropout = dropout
+
+    def forward(
+        self,
+        query: Tensor,
+        attn_mask: Optional[Tensor] = None,
+        past_key_value: Optional[Tuple[Tensor, Tensor]] = None,
+        is_causal: bool = False,
+        use_cache: bool = False,
+    ) -> Union[Tensor, MHAWithCacheOutput]:
+        """
+        Args:
+            query (Tensor): input query of shape bsz x target_seq_len x embed_dim
+            attn_mask (optional Tensor): attention mask of shape bsz x num_heads x target_seq_len x source_seq_len.
+                Note that the num_heads dimension can equal 1, in which case the mask is broadcast to all heads.
+                Two types of masks are supported: a boolean mask, where a value of True
+                indicates that the element *should* take part in attention, and
+                a float mask of the same type as query, key, value that is added to the attention score.
+            past_key_value (optional tuple of tensors): cached key and value with the same shape as the key, value inputs.
+                The tuple should have size 2, where the first entry is the cached key and the second entry is the cached value.
+            is_causal (bool): if True, applies causal attention masking; attn_mask should be set to None when this is True.
+                is_causal is a hint that the mask is a causal mask; providing incorrect hints can result in incorrect execution.
+            use_cache (bool): whether to use cache for key and value tensors
+
+        Returns:
+            If use_cache is off, returns an attn_output tensor of shape bsz x seq_len x embed_dim;
+            otherwise returns a namedtuple with attn_output and the cached key and value.
+        """
+        bsz = query.size(0)
+        embed_dim = query.size(-1)
+        head_dim = embed_dim // self.num_heads
+        projected_query = self.input_proj(query)
+        query, key, value = projected_query.chunk(3, dim=-1)
+
+        # bsz x seq_len x embed_dim => bsz x num_heads x seq_len x head_dim
+        query = query.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
+        if key.size(0) != bsz or value.size(0) != bsz:
+            raise ValueError("key and value should have the same bsz as query.")
+        key = key.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
+        value = value.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
+
+        # concat key and value with cached values
+        if past_key_value is not None:
+            key = torch.cat([past_key_value[0], key], dim=2)
+            value = torch.cat([past_key_value[1], value], dim=2)
+
+        # turn off causal attention inside scaled_dot_product_attention; we handle it separately with attn_mask.
+        attn = F.scaled_dot_product_attention(
+            query, key, value, attn_mask, self.dropout, is_causal
+        )
+        attn = attn.transpose(1, 2).reshape(bsz, -1, embed_dim)
+
+        # add dense layer after attention
+        attn_output = self.output_proj(attn)
+        if use_cache:
+            return MHAWithCacheOutput(attn_output, (key, value))
+        return attn_output
+
+
 class MultiHeadAttentionWithCache(nn.Module):
     """
     MultiHeadAttention module for both self-attention(SA) and cross-attention(CA).
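For reference, a minimal usage sketch of the new module, based on the forward signature and the MHAWithCacheOutput fields shown in this diff; the tensor shapes and the prefill/decode flow below are illustrative assumptions, not part of the commit.

import torch
from torchmultimodal.modules.layers.multi_head_attention import (
    MultiHeadSelfAttentionWithCache,
)

mha = MultiHeadSelfAttentionWithCache(embed_dim=8, num_heads=2)
mha.eval()

# Prefill: run the full prompt once and keep the returned key/value cache.
prompt = torch.randn(1, 5, 8)  # bsz x seq_len x embed_dim (illustrative)
out = mha(prompt, use_cache=True)
past_kv = out.past_key_value  # (key, value), each bsz x num_heads x seq_len x head_dim

# Decode step: feed only the newest token along with the cached key/value.
next_token = torch.randn(1, 1, 8)
out = mha(next_token, past_key_value=past_kv, use_cache=True)
print(out.attn_output.shape)  # torch.Size([1, 1, 8])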
