@@ -77,6 +77,92 @@ def forward(
         return attn_out


+class MultiHeadSelfAttentionWithCache(nn.Module):
+    """
+    MultiHeadAttention module for self-attention (SA). Similar to MultiHeadAttentionWithCache,
+    but only supports self-attention and uses a fast path where the query, key, and value projections
+    are batched into a single matrix multiplication as opposed to three separate matmuls.
+    This class supports a cache mechanism for decoders to store previous states through
+    "past_key_value".
+
+    Args:
+        embed_dim (int): query, key, value embedding dimension
+        num_heads (int): number of attention heads
+        dropout (float): dropout rate
+        add_bias (bool): if ``True``, adds a learnable bias to the query, key, and value input projection matrix.
+            Defaults to True.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        add_bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.input_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=add_bias)
+        self.output_proj = nn.Linear(embed_dim, embed_dim)
+        self.dropout = dropout
+
+    def forward(
+        self,
+        query: Tensor,
+        attn_mask: Optional[Tensor] = None,
+        past_key_value: Optional[Tuple[Tensor, Tensor]] = None,
+        is_causal: bool = False,
+        use_cache: bool = False,
+    ) -> Union[Tensor, MHAWithCacheOutput]:
+        """
+        Args:
+            query (Tensor): input query of shape bsz x target_seq_len x embed_dim
+            attn_mask (optional Tensor): attention mask of shape bsz x num_heads x target_seq_len x source_seq_len.
+                Note that the num_heads dimension can equal 1, in which case the mask is broadcast to all heads.
+                Two types of masks are supported: a boolean mask, where a value of True
+                indicates that the element *should* take part in attention,
+                or a float mask of the same type as query, key, value that is added to the attention score.
+            past_key_value (optional tuple of tensors): cached key and value with the same shape as the key and value inputs.
+                The tuple should have size 2, where the first entry is the cached key and the second entry is the cached value.
+            is_causal (bool): if True, applies causal attention masking; attn_mask must be set to None when this is True.
+                is_causal is a hint that the mask is a causal mask; providing an incorrect hint can result in incorrect execution.
+            use_cache (bool): whether to use cache for key and value tensors
+
+        Returns:
+            if use_cache is False, returns an attn_output tensor of shape bsz x seq_len x embed_dim;
+            otherwise returns a namedtuple with attn_output and the cached key and value.
+        """
+        bsz = query.size(0)
+        embed_dim = query.size(-1)
+        head_dim = embed_dim // self.num_heads
+        projected_query = self.input_proj(query)
+        query, key, value = projected_query.chunk(3, dim=-1)
+
+        # bsz x seq_len x embed_dim => bsz x num_heads x seq_len x head_dim
+        query = query.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
+        if key.size(0) != bsz or value.size(0) != bsz:
+            raise ValueError("key and value should have the same bsz as query.")
+        key = key.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
+        value = value.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
+
+        # concat key value with cached values
+        if past_key_value is not None:
+            key = torch.cat([past_key_value[0], key], dim=2)
+            value = torch.cat([past_key_value[1], value], dim=2)
+
+        # attn_mask must be None when is_causal is True
+        attn = F.scaled_dot_product_attention(
+            query, key, value, attn_mask, self.dropout, is_causal
+        )
+        attn = attn.transpose(1, 2).reshape(bsz, -1, embed_dim)
+
+        # add dense layer after attention
+        attn_output = self.output_proj(attn)
+        if use_cache:
+            return MHAWithCacheOutput(attn_output, (key, value))
+        return attn_output
+
+
 class MultiHeadAttentionWithCache(nn.Module):
     """
     MultiHeadAttention module for both self-attention(SA) and cross-attention(CA).
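A minimal usage sketch of the new module (not part of the diff), assuming the class added above is in scope; the batch size, sequence lengths, and embedding size are arbitrary illustration values:

import torch

# Hypothetical example values, not taken from the PR.
embed_dim, num_heads = 64, 4
mha = MultiHeadSelfAttentionWithCache(embed_dim=embed_dim, num_heads=num_heads)

# Prefill: run the full prompt once with causal masking and keep the cache.
prompt = torch.randn(2, 5, embed_dim)  # bsz x seq_len x embed_dim
attn_output, past_key_value = mha(prompt, is_causal=True, use_cache=True)

# Decode step: feed only the newest token and reuse the cached key/value, which
# grow along dim=2 (the seq_len dimension of bsz x num_heads x seq_len x head_dim).
next_token = torch.randn(2, 1, embed_dim)
attn_output, past_key_value = mha(next_token, past_key_value=past_key_value, use_cache=True)
print(attn_output.shape)  # torch.Size([2, 1, 64])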