
Commit dfd2ec6

rohan-varma authored and facebook-github-bot committed
Batch matmul fast path in MHAWithCache (#449)
Summary:
Pull Request resolved: #449

When doing self attention, an optimization is to combine the Q, K, V input projection matrices and do a single matmul, instead of 3. Adding this optimization in MHAWithCache.

Differential Revision: D48418780

fbshipit-source-id: e8001eb870e827b05146221bb66f82939deae0c6
1 parent: 0dc3c21
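The equivalence behind this fast path can be checked in isolation: stacking the Q, K, and V projection weights into one weight matrix and doing a single matmul, then splitting the result, matches three separate projections. A minimal sketch (the names dim, fused_proj, and x are illustrative and not part of this change):

import torch
from torch import nn

dim = 8
q_proj = nn.Linear(dim, dim)
k_proj = nn.Linear(dim, dim)
v_proj = nn.Linear(dim, dim)

# One Linear producing 3 * dim outputs, with the three weight matrices stacked.
fused_proj = nn.Linear(dim, 3 * dim)
with torch.no_grad():
    fused_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    fused_proj.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))

x = torch.randn(2, 5, dim)  # bsz x seq_len x dim
q, k, v = fused_proj(x).chunk(3, dim=-1)  # single matmul, then split

assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)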

File tree: 1 file changed (+26, -9 lines)


torchmultimodal/modules/layers/multi_head_attention.py

Lines changed: 26 additions & 9 deletions
@@ -7,7 +7,6 @@
 from typing import NamedTuple, Optional, Tuple, Union
 
 import torch
-
 import torch.nn.functional as F
 from torch import nn, Tensor

@@ -17,6 +16,14 @@ class MHAWithCacheOutput(NamedTuple):
     past_key_value: Tuple[Tensor, Tensor]
 
 
+def _batched_input_proj(
+    query: Tensor, input_proj: nn.Module
+) -> Tuple[Tensor, Tensor, Tensor]:
+    projected_query = input_proj(query)
+    query, key, value = projected_query.chunk(3, dim=-1)
+    return query, key, value
+
+
 class MultiHeadSelfAttention(nn.Module):
     """
     Multihead self attention.
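The helper added above just applies the fused projection and chunks the output into three equal pieces along the last dimension. A quick shape check, assuming the private helper is imported from this module (the tensor sizes are arbitrary):

import torch
from torch import nn
# _batched_input_proj is a private helper introduced by this diff.
from torchmultimodal.modules.layers.multi_head_attention import _batched_input_proj

embed_dim = 16
input_proj = nn.Linear(embed_dim, 3 * embed_dim)
query = torch.randn(2, 4, embed_dim)  # bsz x seq_len x embed_dim

q, k, v = _batched_input_proj(query=query, input_proj=input_proj)
assert q.shape == k.shape == v.shape == (2, 4, embed_dim)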
@@ -59,8 +66,7 @@ def forward(
 
         bsz = query.size(0)
        embed_dim = query.size(-1)
-        projected_query = self.input_proj(query)
-        query, key, value = projected_query.chunk(3, dim=-1)
+        query, key, value = _batched_input_proj(query=query, input_proj=self.input_proj)
 
         head_dim = embed_dim // self.num_heads
         # bsz x seq len x embed_dim => bsz x num_heads x seq len x head_dim
@@ -105,9 +111,15 @@ def __init__(
     ) -> None:
         super().__init__()
         self.num_heads = num_heads
-        self.q_proj = nn.Linear(dim_q, dim_q, bias=add_bias)
-        self.k_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
-        self.v_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
+        if dim_kv == dim_q:
+            # Module is being used for self-attention, so batch the matmuls
+            self.input_proj = nn.Linear(dim_q, 3 * dim_q, bias=add_bias)
+            self.is_self_attn = True
+        else:
+            self.q_proj = nn.Linear(dim_q, dim_q, bias=add_bias)
+            self.k_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
+            self.v_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
+            self.is_self_attn = False
         self.output_proj = nn.Linear(dim_q, dim_q)
         self.dropout = dropout

@@ -144,9 +156,14 @@ def forward(
         bsz = query.size(0)
         embed_dim = query.size(-1)
         head_dim = embed_dim // self.num_heads
-        query = self.q_proj(query)
-        key = self.k_proj(key)
-        value = self.v_proj(value)
+        if self.is_self_attn:
+            query, key, value = _batched_input_proj(
+                query=query, input_proj=self.input_proj
+            )
+        else:
+            query = self.q_proj(query)
+            key = self.k_proj(key)
+            value = self.v_proj(value)
 
         # bsz x seq_len x embed_dim => bsz x num_heads x seq_len x head_dim
         query = query.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
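With the constructor change above, the fused projection is only built when dim_kv == dim_q, i.e. when the module is used for self attention; cross attention keeps the three separate projections. An instantiation sketch, assuming the class exported by this module is MultiHeadAttentionWithCache and that its constructor accepts dim_q, dim_kv, and num_heads as shown in the diff (other arguments left at their defaults):

from torchmultimodal.modules.layers.multi_head_attention import MultiHeadAttentionWithCache

# dim_kv == dim_q: self attention, so a single fused input_proj is created.
self_attn = MultiHeadAttentionWithCache(dim_q=512, dim_kv=512, num_heads=8)
assert self_attn.is_self_attn and hasattr(self_attn, "input_proj")

# dim_kv != dim_q: cross attention, so separate q/k/v projections are kept.
cross_attn = MultiHeadAttentionWithCache(dim_q=512, dim_kv=256, num_heads=8)
assert not cross_attn.is_self_attn and hasattr(cross_attn, "q_proj")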
