from typing import NamedTuple, Optional, Tuple, Union

import torch
-
import torch.nn.functional as F
from torch import nn, Tensor

@@ -17,6 +16,14 @@ class MHAWithCacheOutput(NamedTuple):
    past_key_value: Tuple[Tensor, Tensor]


+def _batched_input_proj(
+    query: Tensor, input_proj: nn.Module
+) -> Tuple[Tensor, Tensor, Tensor]:
+    projected_query = input_proj(query)
+    query, key, value = projected_query.chunk(3, dim=-1)
+    return query, key, value
+
+
class MultiHeadSelfAttention(nn.Module):
    """
    Multihead self attention.
@@ -59,8 +66,7 @@ def forward(

        bsz = query.size(0)
        embed_dim = query.size(-1)
-        projected_query = self.input_proj(query)
-        query, key, value = projected_query.chunk(3, dim=-1)
+        query, key, value = _batched_input_proj(query=query, input_proj=self.input_proj)

        head_dim = embed_dim // self.num_heads
        # bsz x seq len x embed_dim => bsz x num_heads x seq len x head_dim
@@ -105,9 +111,15 @@ def __init__(
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
-        self.q_proj = nn.Linear(dim_q, dim_q, bias=add_bias)
-        self.k_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
-        self.v_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
+        if dim_kv == dim_q:
+            # Module is being used for self-attention, so batch the matmuls
+            self.input_proj = nn.Linear(dim_q, 3 * dim_q, bias=add_bias)
+            self.is_self_attn = True
+        else:
+            self.q_proj = nn.Linear(dim_q, dim_q, bias=add_bias)
+            self.k_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
+            self.v_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
+            self.is_self_attn = False
        self.output_proj = nn.Linear(dim_q, dim_q)
        self.dropout = dropout

@@ -144,9 +156,14 @@ def forward(
        bsz = query.size(0)
        embed_dim = query.size(-1)
        head_dim = embed_dim // self.num_heads
-        query = self.q_proj(query)
-        key = self.k_proj(key)
-        value = self.v_proj(value)
+        if self.is_self_attn:
+            query, key, value = _batched_input_proj(
+                query=query, input_proj=self.input_proj
+            )
+        else:
+            query = self.q_proj(query)
+            key = self.k_proj(key)
+            value = self.v_proj(value)

        # bsz x seq_len x embed_dim => bsz x num_heads x seq_len x head_dim
        query = query.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
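As a side note (not part of the commit), the sketch below illustrates why the fused projection is equivalent to three separate q/k/v projections when dim_kv == dim_q: concatenating the three weight matrices (and biases) along the output dimension and splitting the fused output with chunk(3, dim=-1) reproduces the individual projections. All sizes and variable names here are illustrative.

import torch
from torch import nn

# Illustrative sizes only.
dim_q, bsz, seq_len = 8, 2, 4

q_proj = nn.Linear(dim_q, dim_q)
k_proj = nn.Linear(dim_q, dim_q)
v_proj = nn.Linear(dim_q, dim_q)

# Fused projection: stack the three weight matrices (and biases) along the
# output dimension, in the same q/k/v order that chunk(3, dim=-1) splits
# them back out.
input_proj = nn.Linear(dim_q, 3 * dim_q)
with torch.no_grad():
    input_proj.weight.copy_(
        torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
    )
    input_proj.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))

x = torch.randn(bsz, seq_len, dim_q)
query, key, value = input_proj(x).chunk(3, dim=-1)

# The single batched matmul matches the three separate projections.
torch.testing.assert_close(query, q_proj(x))
torch.testing.assert_close(key, k_proj(x))
torch.testing.assert_close(value, v_proj(x))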