@@ -46,6 +46,7 @@ def __init__(
         layer_norm_epsilon=1e-6,
         rope_wavelength=10_000.0,
         rope_scaling_factor=1.0,
+        use_bidirectional_attention=False,
         dropout=0,
         **kwargs,
     ):
@@ -61,6 +62,7 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rope_wavelength = rope_wavelength
         self.rope_scaling_factor = rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.dropout = dropout
 
         self._kernel_initializer = keras.initializers.get(
@@ -240,12 +242,58 @@ def _compute_attention(
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))
 
+    def _compute_bidirectional_sliding_mask(self, batch_size, sequence_length):
+        """Computes a bidirectional sliding window attention mask.
+
+        A token can attend to any other token if their absolute distance is
+        within half the sliding window size. This mask is used in embedding
+        models like `EmbeddingGemma`.
+
+        Args:
+            batch_size: The batch size for the mask.
+            sequence_length: The length of the sequence.
+
+        Returns:
+            A boolean attention mask with shape
+            `(batch_size, sequence_length, sequence_length)`.
+        """
+        i = keras.ops.expand_dims(
+            keras.ops.arange(sequence_length, dtype="int32"), axis=1
+        )
+        j = keras.ops.arange(sequence_length, dtype="int32")
+
+        # If sliding window size is 4, the token in question attends to 1
+        # token before and 2 tokens after.
+        w_right = self.sliding_window_size // 2
+        w_left = self.sliding_window_size - w_right - 1
+
+        # Calculate the relative distance.
+        distance = i - j
+
+        mask = keras.ops.logical_and(distance <= w_left, distance >= -w_right)
+
+        mask = keras.ops.expand_dims(mask, axis=0)
+        return keras.ops.broadcast_to(
+            mask, (batch_size, sequence_length, sequence_length)
+        )
+
     def _mask_sliding_window(
         self,
         attention_mask,
         cache_update_index=0,
     ):
         batch_size, query_len, key_len = ops.shape(attention_mask)
+
+        if self.use_bidirectional_attention:
+            bidirectional_sliding_mask = (
+                self._compute_bidirectional_sliding_mask(
+                    batch_size=batch_size,
+                    # `query_len = key_len` for embedding models
+                    sequence_length=query_len,
+                )
+            )
+            return ops.logical_and(attention_mask, bidirectional_sliding_mask)
+
         # Compute the sliding window for square attention.
         all_ones = ops.ones((key_len, key_len), "bool")
         if keras.config.backend() == "tensorflow":
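
For reference, a minimal standalone sketch of the same bidirectional sliding-window masking logic added above, written against plain `keras.ops`; the helper name and the hard-coded window/sequence sizes below are illustrative only, not part of this change:

import keras

def bidirectional_sliding_mask(batch_size, sequence_length, sliding_window_size):
    # Same rule as `_compute_bidirectional_sliding_mask`: token i may attend
    # to token j when -(window // 2) <= i - j <= window - window // 2 - 1.
    i = keras.ops.expand_dims(
        keras.ops.arange(sequence_length, dtype="int32"), axis=1
    )
    j = keras.ops.arange(sequence_length, dtype="int32")
    w_right = sliding_window_size // 2
    w_left = sliding_window_size - w_right - 1
    mask = keras.ops.logical_and(i - j <= w_left, i - j >= -w_right)
    return keras.ops.broadcast_to(
        keras.ops.expand_dims(mask, axis=0),
        (batch_size, sequence_length, sequence_length),
    )

# With sliding_window_size=4 and 6 tokens, row 2 of the mask is True at
# positions 1, 2, 3, 4: the token itself, one token before, two after.
print(bidirectional_sliding_mask(1, 6, 4)[0])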