15 changes: 15 additions & 0 deletions keras_hub/src/layers/modeling/transformer_layer_utils.py
@@ -90,3 +90,18 @@ def merge_padding_and_attention_mask(
else:
return ops.minimum(mask, attention_mask)
return mask


def compute_positions_from_mask(mask):
"""Computes positions from provided padding mask.

Args:
mask: Tensor of shape `(batch_size, sequence_length)`. Padding mask,
1 for non-padding tokens, 0 for padding tokens.

Returns:
positions: Tensor of the same shape as `mask`, which contains indices
corresponding to positions of tokens in the sequence.
"""
positions = ops.cumsum(mask, axis=-1)
return ops.subtract(positions, ops.greater_equal(positions, 1))
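
For reference, a minimal usage sketch of the new helper; the inputs and expected output mirror the unit test added in the next file, and the import path is taken from this diff.

from keras import ops
from keras_hub.src.layers.modeling.transformer_layer_utils import (
    compute_positions_from_mask,
)

# 1 marks a real token, 0 marks padding; the second row has interior padding.
mask = ops.array([[0, 0, 1, 1, 0], [1, 0, 1, 0, 1]])
positions = compute_positions_from_mask(mask)
# Cumulative count of real tokens, shifted so the first real token in each
# row lands on position 0:
# positions == [[0, 0, 0, 1, 1], [0, 0, 1, 1, 2]]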
12 changes: 12 additions & 0 deletions keras_hub/src/layers/modeling/transformer_layer_utils_test.py
@@ -41,3 +41,15 @@ def test_bad_mask_shapes(self):
padding_mask,
attention_mask,
)

def test_compute_positions_from_mask(self):
mask = ops.array(
[
[False, False, True, True, False],
[True, False, True, False, True],
]
)
output = utils.compute_positions_from_mask(mask)

expected_output = ops.array([[0, 0, 0, 1, 1], [0, 0, 1, 1, 2]])
self.assertAllEqual(output, expected_output)
9 changes: 5 additions & 4 deletions keras_hub/src/models/gemma/gemma_attention.py
@@ -97,9 +97,9 @@ def build(self, inputs_shape):

self.built = True

def _apply_rope(self, x, start_index):
def _apply_rope(self, x, start_index, positions=None):
"""Rope rotate q or k."""
x = self.rope_layer(x, start_index=start_index)
x = self.rope_layer(x, start_index=start_index, positions=positions)
# Gemma uses a different layout for positional embeddings.
# The transformation below ensures the embeddings are numerically
# equivalent to the original gemma implementation.
@@ -230,12 +230,13 @@ def call(
self,
x,
attention_mask=None,
positions=None,
cache=None,
cache_update_index=0,
training=False,
):
query = self.query_dense(x)
query = self._apply_rope(query, cache_update_index)
query = self._apply_rope(query, cache_update_index, positions=positions)

if cache is not None:
key_cache = cache[:, 0, ...]
@@ -249,7 +250,7 @@ def call(
cache = ops.stack((key, value), axis=1)
else:
key = self.key_dense(x)
key = self._apply_rope(key, cache_update_index)
key = self._apply_rope(key, cache_update_index, positions=positions)
value = self.value_dense(x)

attention_vec = self._compute_attention(
14 changes: 14 additions & 0 deletions keras_hub/src/models/gemma/gemma_backbone_test.py
@@ -31,6 +31,13 @@ def test_backbone_basics(self):
expected_output_shape=(2, 5, 16),
)

def test_flexible_positions(self):
self.run_positions_test(
cls=GemmaBackbone,
init_kwargs=self.init_kwargs,
vocabulary_size=self.init_kwargs["vocabulary_size"],
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
@@ -188,6 +195,13 @@ def test_backbone_basics(self):
expected_output_shape=(2, 10, 16),
)

def test_flexible_positions(self):
self.run_positions_test(
cls=GemmaBackbone,
init_kwargs=self.init_kwargs,
vocabulary_size=self.init_kwargs["vocabulary_size"],
)

def test_sliding_window(self):
# Test sliding window correctness by hand.
backbone = GemmaBackbone(**self.init_kwargs)
8 changes: 8 additions & 0 deletions keras_hub/src/models/gemma/gemma_decoder_block.py
@@ -4,6 +4,9 @@
from keras_hub.src.layers.modeling.transformer_layer_utils import (
compute_causal_mask,
)
from keras_hub.src.layers.modeling.transformer_layer_utils import (
compute_positions_from_mask,
)
from keras_hub.src.layers.modeling.transformer_layer_utils import (
merge_padding_and_attention_mask,
)
@@ -178,9 +181,14 @@ def call(
cache_update_index=cache_update_index,
)
else:
positions = None
if padding_mask is not None:
positions = compute_positions_from_mask(padding_mask)

attention = self.attention(
normalized_x,
attention_mask=attention_mask,
positions=positions,
)

if self.use_post_attention_norm:
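Taken together, the decoder block now derives per-token positions from the padding mask whenever no cache is passed, so sequences with leading padding are rotated as if their real tokens started at position 0. A rough end-to-end sketch of what this enables; the tiny config below is made up purely for illustration, and the behaviour it demonstrates is exactly what the run_positions_test helper added further down checks.

import numpy as np
from keras_hub.models import GemmaBackbone

# Tiny, made-up config for illustration only.
backbone = GemmaBackbone(
    vocabulary_size=256,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=1,
    hidden_dim=16,
    intermediate_dim=32,
    head_dim=4,
)

# The second sequence is left-padded. Because positions are computed from
# the padding mask, its two real tokens get positions 0 and 1 (not 2 and 3).
x = {
    "token_ids": np.array([[5, 6, 7, 8], [0, 0, 9, 10]]),
    "padding_mask": np.array([[1, 1, 1, 1], [0, 0, 1, 1]]),
}
hidden_states = backbone.predict(x)  # shape: (2, 4, 16)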
41 changes: 41 additions & 0 deletions keras_hub/src/tests/test_case.py
@@ -720,6 +720,47 @@ def compare(actual, expected):
output = ops.argmax(output, axis=-1)
self.assertAllEqual(output, expected_labels)

Review thread on run_positions_test:
Member: do we think this is generic and will extend beyond gemma? if so ok to leave here. if not I might park this directly in the gemma tests.
Collaborator (Author): Yeah, we should do this for all CausalLMs. I've done it only for Gemma for now, will extend to other models later.

def run_positions_test(
self,
cls,
init_kwargs,
vocabulary_size,
):
"""Tests that conventional and flexible positions give same output."""
model = cls(**init_kwargs)

rng = np.random.default_rng(seed=42)
x1 = {
"token_ids": rng.integers(low=1, high=vocabulary_size, size=(2, 5)),
"padding_mask": np.array(
[
[True] * 3 + [False] * 2,
[True] * 2 + [False] * 3,
]
),
}
# Convert token_ids to list for easier manipulation.
token_ids_lst = x1["token_ids"].tolist()
x2 = {
"token_ids": np.array(
[
[0] + token_ids_lst[0][:3] + [0],
[0] * 2 + token_ids_lst[1][:2] + [0],
]
),
"padding_mask": np.array(
[
[False] + [True] * 3 + [False],
[False] * 2 + [True] * 2 + [False],
]
),
}

output_1 = model.predict(x1)
output_2 = model.predict(x2)
self.assertAllClose(output_1[0][:3], output_2[0][1:4])
self.assertAllClose(output_1[1][:2], output_2[1][2:4])

def get_test_data_dir(self):
return str(pathlib.Path(__file__).parent / "test_data")
