@@ -77,6 +77,92 @@ def forward(
        return attn_out


class MultiHeadSelfAttentionWithCache(nn.Module):
    """
    MultiHeadAttention module for self-attention (SA). Similar to MultiHeadAttentionWithCache,
    but only supports self-attention and uses a fast path where the query, key, and value
    projections are batched into a single matrix multiplication as opposed to three separate matmuls.
    This class supports a cache mechanism for decoders to store previous states through
    ``past_key_value``.

    Args:
        embed_dim (int): query, key, value embedding dimension
        num_heads (int): number of attention heads
        dropout (float): dropout rate
        add_bias (bool): if ``True``, adds a learnable bias to the query, key, value input projection matrix.
            Defaults to True.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        add_bias: bool = True,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.input_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=add_bias)
        self.output_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = dropout

    def forward(
        self,
        query: Tensor,
        attn_mask: Optional[Tensor] = None,
        past_key_value: Optional[Tuple[Tensor, Tensor]] = None,
        is_causal: bool = False,
        use_cache: bool = False,
    ) -> Union[Tensor, MHAWithCacheOutput]:
        """
        Args:
            query (Tensor): input query of shape bsz x target_seq_len x embed_dim
            attn_mask (optional Tensor): attention mask of shape bsz x num_heads x target_seq_len x source_seq_len.
                Note that the num_heads dimension can equal 1, in which case the mask is broadcast to all heads.
                Two types of masks are supported: a boolean mask, where a value of True
                indicates that the element *should* take part in attention, and a float mask
                of the same dtype as query that is added to the attention score.
            past_key_value (optional tuple of Tensors): cached key and value tensors of shape
                bsz x num_heads x past_seq_len x head_dim. The tuple should have two entries,
                where the first entry is the cached key and the second entry is the cached value.
            is_causal (bool): if ``True``, applies causal attention masking; attn_mask should be set to None
                when this is True. ``is_causal`` is a hint that the mask is a causal mask;
                providing an incorrect hint can result in incorrect execution.
            use_cache (bool): whether to return the cached key and value tensors

        Returns:
            If use_cache is off, returns an attn_output tensor of shape bsz x seq_len x embed_dim;
            otherwise returns a namedtuple with attn_output and the cached key and value.
        """
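        # Illustrative attn_mask construction (an assumption about intended usage, with
        # hypothetical tgt_len/src_len sizes; the num_heads dimension of 1 broadcasts):
        #   boolean mask, True means "attend":
        #       torch.ones(bsz, 1, tgt_len, src_len, dtype=torch.bool)
        #   additive float mask, 0.0 keeps a position, float("-inf") removes it:
        #       torch.zeros(bsz, 1, tgt_len, src_len, dtype=query.dtype)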
        bsz = query.size(0)
        embed_dim = query.size(-1)
        head_dim = embed_dim // self.num_heads
        projected_query = self.input_proj(query)
        query, key, value = projected_query.chunk(3, dim=-1)

        # bsz x seq_len x embed_dim => bsz x num_heads x seq_len x head_dim
        query = query.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
        if key.size(0) != bsz or value.size(0) != bsz:
            raise ValueError("key and value should have the same bsz as query.")
        key = key.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)

        # concat key value with cached values
        if past_key_value is not None:
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)

        # compute scaled dot product attention; attn_mask and is_causal are passed through
        attn = F.scaled_dot_product_attention(
            query, key, value, attn_mask, self.dropout, is_causal
        )
        attn = attn.transpose(1, 2).reshape(bsz, -1, embed_dim)

        # add dense layer after attention
        attn_output = self.output_proj(attn)
        if use_cache:
            return MHAWithCacheOutput(attn_output, (key, value))
        return attn_output


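# Usage sketch (an illustrative assumption about the intended cache workflow): the first
# call projects the full prompt and returns the key/value states via ``use_cache``; the
# follow-up call passes them back through ``past_key_value`` so only the new position is
# projected. Tuple unpacking is used rather than assuming the namedtuple field names.
#
#     mha = MultiHeadSelfAttentionWithCache(embed_dim=64, num_heads=4).eval()
#     prompt = torch.randn(2, 5, 64)                 # bsz x seq_len x embed_dim
#     attn_out, (k_cache, v_cache) = mha(prompt, is_causal=True, use_cache=True)
#     next_token = torch.randn(2, 1, 64)             # one new position per sequence
#     attn_out, (k_cache, v_cache) = mha(
#         next_token, past_key_value=(k_cache, v_cache), use_cache=True
#     )
#     # attn_out: 2 x 1 x 64; k_cache/v_cache now cover all 6 positions
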
class MultiHeadAttentionWithCache(nn.Module):
    """
    MultiHeadAttention module for both self-attention(SA) and cross-attention(CA).