
Commit 5b13fbf

Add bidirectional masking for EmbeddingGemma (#2382)
1 parent e7dd9b0 commit 5b13fbf

File tree

keras_hub/src/models/gemma3/
  gemma3_attention.py
  gemma3_backbone.py
  gemma3_decoder_block.py

3 files changed: 64 additions & 0 deletions

keras_hub/src/models/gemma3/gemma3_attention.py

Lines changed: 48 additions & 0 deletions
@@ -46,6 +46,7 @@ def __init__(
         layer_norm_epsilon=1e-6,
         rope_wavelength=10_000.0,
         rope_scaling_factor=1.0,
+        use_bidirectional_attention=False,
         dropout=0,
         **kwargs,
     ):
@@ -61,6 +62,7 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rope_wavelength = rope_wavelength
         self.rope_scaling_factor = rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.dropout = dropout
 
         self._kernel_initializer = keras.initializers.get(
@@ -240,12 +242,58 @@ def _compute_attention(
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))
 
+    def _compute_bidirectional_sliding_mask(self, batch_size, sequence_length):
+        """Computes a bidirectional sliding window attention mask.
+
+        A token can attend to any other token if their absolute distance is
+        within half the sliding window size. This mask is used in embedding
+        models like `EmbeddingGemma`.
+
+        Args:
+            batch_size: The batch size for the mask.
+            sequence_length: The length of the sequence.
+
+        Returns:
+            A boolean attention mask with shape
+            `(batch_size, sequence_length, sequence_length)`.
+        """
+        i = keras.ops.expand_dims(
+            keras.ops.arange(sequence_length, dtype="int32"), axis=1
+        )
+        j = keras.ops.arange(sequence_length, dtype="int32")
+
+        # If sliding window size is 4, the token in question attends to 1
+        # token before and 2 tokens after.
+        w_right = self.sliding_window_size // 2
+        w_left = self.sliding_window_size - w_right - 1
+
+        # Calculate the relative distance.
+        distance = i - j
+
+        mask = keras.ops.logical_and(distance <= w_left, distance >= -w_right)
+
+        mask = keras.ops.expand_dims(mask, axis=0)
+        return keras.ops.broadcast_to(
+            mask, (batch_size, sequence_length, sequence_length)
+        )
+
     def _mask_sliding_window(
         self,
         attention_mask,
         cache_update_index=0,
     ):
         batch_size, query_len, key_len = ops.shape(attention_mask)
+
+        if self.use_bidirectional_attention:
+            bidirectional_sliding_mask = (
+                self._compute_bidirectional_sliding_mask(
+                    batch_size=batch_size,
+                    # `query_len = key_len` for embedding models
+                    sequence_length=query_len,
+                )
+            )
+            return ops.logical_and(attention_mask, bidirectional_sliding_mask)
+
         # Compute the sliding window for square attention.
         all_ones = ops.ones((key_len, key_len), "bool")
         if keras.config.backend() == "tensorflow":
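
As a quick illustration (not part of the commit), the sketch below reproduces the `_compute_bidirectional_sliding_mask` logic as a standalone helper, `bidirectional_sliding_mask` (a hypothetical name), so you can see which positions each token attends to for a window of 4. It uses only `keras.ops` calls that already appear in the diff and assumes Keras 3 is installed.

```python
import keras


def bidirectional_sliding_mask(batch_size, sequence_length, sliding_window_size):
    # Row index (queries) and column index (keys).
    i = keras.ops.expand_dims(
        keras.ops.arange(sequence_length, dtype="int32"), axis=1
    )
    j = keras.ops.arange(sequence_length, dtype="int32")
    # With a window of 4: attend to 1 token before, the token itself, and
    # 2 tokens after.
    w_right = sliding_window_size // 2
    w_left = sliding_window_size - w_right - 1
    distance = i - j
    mask = keras.ops.logical_and(distance <= w_left, distance >= -w_right)
    return keras.ops.broadcast_to(
        keras.ops.expand_dims(mask, axis=0),
        (batch_size, sequence_length, sequence_length),
    )


print(
    keras.ops.convert_to_numpy(bidirectional_sliding_mask(1, 5, 4))[0].astype(int)
)
# [[1 1 1 0 0]
#  [1 1 1 1 0]
#  [0 1 1 1 1]
#  [0 0 1 1 1]
#  [0 0 0 1 1]]
```

Unlike the causal sliding window, the mask is not lower-triangular: each token can also see tokens to its right, which is what an embedding model needs.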

keras_hub/src/models/gemma3/gemma3_backbone.py

Lines changed: 4 additions & 0 deletions
@@ -196,6 +196,7 @@ def __init__(
         global_rope_scaling_factor=1.0,
         vision_encoder=None,
         layer_norm_epsilon=1e-6,
+        use_bidirectional_attention=False,
         dropout=0,
         dtype=None,
         **kwargs,
@@ -251,6 +252,7 @@ def __init__(
                 sliding_window_size=sliding_window_size,
                 rope_wavelength=rope_wavelength,
                 rope_scaling_factor=rope_scaling_factor,
+                use_bidirectional_attention=use_bidirectional_attention,
                 dropout=dropout,
                 dtype=dtype,
                 name=f"decoder_block_{i}",
@@ -357,6 +359,7 @@ def __init__(
         self.sliding_window_size = sliding_window_size
         self.local_rope_scaling_factor = local_rope_scaling_factor
         self.global_rope_scaling_factor = global_rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
 
@@ -396,6 +399,7 @@ def get_config(self):
                 "vision_encoder": None
                 if self.vision_encoder is None
                 else keras.layers.serialize(self.vision_encoder),
+                "use_bidirectional_attention": self.use_bidirectional_attention,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "dropout": self.dropout,
             }
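
The backbone change itself is plumbing: the flag is accepted by the constructor, forwarded to each decoder block, stored on the instance, and written into `get_config()` so it survives serialization. A minimal sketch of that pattern follows, using hypothetical `TinyBlock` and `TinyBackbone` layers rather than the real keras-hub classes.

```python
import keras


class TinyBlock(keras.layers.Layer):
    """Hypothetical stand-in for `Gemma3DecoderBlock`."""

    def __init__(self, use_bidirectional_attention=False, **kwargs):
        super().__init__(**kwargs)
        self.use_bidirectional_attention = use_bidirectional_attention


class TinyBackbone(keras.layers.Layer):
    """Hypothetical stand-in for `Gemma3Backbone`."""

    def __init__(self, use_bidirectional_attention=False, **kwargs):
        super().__init__(**kwargs)
        self.use_bidirectional_attention = use_bidirectional_attention
        # Forward the flag to the block, as the backbone forwards it to each
        # of its decoder blocks.
        self.block = TinyBlock(
            use_bidirectional_attention=use_bidirectional_attention
        )

    def get_config(self):
        # Include the flag so the setting is restored when the model is
        # reloaded from its config.
        config = super().get_config()
        config["use_bidirectional_attention"] = self.use_bidirectional_attention
        return config


config = TinyBackbone(use_bidirectional_attention=True).get_config()
assert config["use_bidirectional_attention"] is True
```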

keras_hub/src/models/gemma3/gemma3_decoder_block.py

Lines changed: 12 additions & 0 deletions
@@ -45,6 +45,7 @@ def __init__(
         layer_norm_epsilon=1e-6,
         rope_wavelength=10_000.0,
         rope_scaling_factor=1.0,
+        use_bidirectional_attention=False,
         dropout=0,
         **kwargs,
     ):
@@ -66,6 +67,7 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rope_wavelength = rope_wavelength
         self.rope_scaling_factor = rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.dropout = dropout
 
         self.pre_attention_norm = RMSNormalization(
@@ -93,6 +95,7 @@ def __init__(
             rope_wavelength=rope_wavelength,
             rope_scaling_factor=rope_scaling_factor,
             dropout=dropout,
+            use_bidirectional_attention=use_bidirectional_attention,
             dtype=self.dtype_policy,
             name="attention",
         )
@@ -209,6 +212,14 @@ def _compute_attention_mask(
         if cache is not None:
             input_length = ops.shape(cache)[2]
 
+        if self.use_bidirectional_attention:
+            # `output_length` and `input_length` will be the same in this case
+            # because we use bidirectional attention for models like
+            # `EmbeddingGemma` which aren't used for text generation.
+            mask_1 = decoder_mask
+            mask_2 = ops.transpose(mask_1, (0, 2, 1))
+            return mask_1 * mask_2
+
         causal_mask = compute_causal_mask(
             batch_size=batch_size,
             input_length=input_length,
@@ -304,6 +315,7 @@ def get_config(self):
                 "dropout": self.dropout,
                 "rope_wavelength": self.rope_wavelength,
                 "rope_scaling_factor": self.rope_scaling_factor,
+                "use_bidirectional_attention": self.use_bidirectional_attention,
             }
         )
         return config
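
To see what the new bidirectional branch produces, here is a small standalone sketch (not from the commit). It assumes the decoder mask is derived from the padding mask and has shape `(batch, 1, seq_len)`; multiplying it by its transpose keeps position `(query, key)` only when both tokens are real, non-padding tokens, so every real token attends to every other real token instead of being restricted to a causal prefix.

```python
from keras import ops

# A batch of one sequence of length 5 where the last two tokens are padding.
padding_mask = ops.convert_to_tensor([[1, 1, 1, 0, 0]], dtype="int32")
mask_1 = ops.expand_dims(padding_mask, axis=1)  # (1, 1, 5), keys axis
mask_2 = ops.transpose(mask_1, (0, 2, 1))       # (1, 5, 1), queries axis
bidirectional_mask = mask_1 * mask_2            # broadcasts to (1, 5, 5)

print(ops.convert_to_numpy(bidirectional_mask)[0])
# [[1 1 1 0 0]
#  [1 1 1 0 0]
#  [1 1 1 0 0]
#  [0 0 0 0 0]
#  [0 0 0 0 0]]
```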
