From 4a8566be1ef2fdb8ff19b2092b3074483fcafa7f Mon Sep 17 00:00:00 2001 From: Rahul Yedida Date: Wed, 7 May 2025 21:54:59 -0400 Subject: [PATCH 01/11] feat(phi4): add phi4_backbone --- keras_hub/api/models/__init__.py | 1 + keras_hub/src/models/phi4/__init__.py | 1 + keras_hub/src/models/phi4/phi4_attention.py | 268 ++++++++++++++++++ keras_hub/src/models/phi4/phi4_backbone.py | 214 ++++++++++++++ .../src/models/phi4/phi4_backbone_test.py | 92 ++++++ keras_hub/src/models/phi4/phi4_decoder.py | 246 ++++++++++++++++ keras_hub/src/models/phi4/phi4_layernorm.py | 35 +++ .../src/models/phi4/phi4_rotary_embedding.py | 124 ++++++++ 8 files changed, 981 insertions(+) create mode 100644 keras_hub/src/models/phi4/__init__.py create mode 100644 keras_hub/src/models/phi4/phi4_attention.py create mode 100644 keras_hub/src/models/phi4/phi4_backbone.py create mode 100644 keras_hub/src/models/phi4/phi4_backbone_test.py create mode 100644 keras_hub/src/models/phi4/phi4_decoder.py create mode 100644 keras_hub/src/models/phi4/phi4_layernorm.py create mode 100644 keras_hub/src/models/phi4/phi4_rotary_embedding.py diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 2a78362e9a..6334e2a4dd 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -397,6 +397,7 @@ from keras_hub.src.models.phi3.phi3_tokenizer import ( Phi3Tokenizer as Phi3Tokenizer, ) +from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone as Phi4Backbone from keras_hub.src.models.preprocessor import Preprocessor as Preprocessor from keras_hub.src.models.qwen.qwen_backbone import ( QwenBackbone as Qwen2Backbone, diff --git a/keras_hub/src/models/phi4/__init__.py b/keras_hub/src/models/phi4/__init__.py new file mode 100644 index 0000000000..c03faf3c17 --- /dev/null +++ b/keras_hub/src/models/phi4/__init__.py @@ -0,0 +1 @@ +# TODO: Add a register_presets call once phi4_presets.py is implemented. 
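Regarding the TODO that closes `keras_hub/src/models/phi4/__init__.py` above: once a `phi4_presets.py` module lands, the registration would typically follow the same pattern the other model packages in keras_hub use in their `__init__.py`. A minimal sketch, assuming a hypothetical `phi4_presets.py` that defines a `backbone_presets` dict (neither exists yet in this PR):

```python
# Sketch only -- phi4_presets.py and its `backbone_presets` dict are assumed,
# not part of this patch. The registration call mirrors other keras_hub models.
from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone
from keras_hub.src.models.phi4.phi4_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

# Associate the preset configs with the backbone class so that
# `Phi4Backbone.from_preset(...)` can resolve registered preset names.
register_presets(backbone_presets, Phi4Backbone)
```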
diff --git a/keras_hub/src/models/phi4/phi4_attention.py b/keras_hub/src/models/phi4/phi4_attention.py new file mode 100644 index 0000000000..f5563d6b92 --- /dev/null +++ b/keras_hub/src/models/phi4/phi4_attention.py @@ -0,0 +1,268 @@ +import math + +import keras +from keras import ops + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.models.phi4.phi4_rotary_embedding import ( + Phi4SuScaledRotaryEmbedding, +) +from keras_hub.src.utils.keras_utils import clone_initializer +from keras_hub.src.utils.keras_utils import fused_attention_op_available + + +class Phi4Attention(keras.layers.Layer): + """A cached grounded query attention layer.""" + + def __init__( + self, + num_query_heads=40, + num_key_value_heads=10, + kernel_initializer="glorot_uniform", + dropout=0, + max_sequence_length=16_384, + pretraining_sequence_length=16_384, + rope_max_wavelength=250_000, + rope_scaling_type=None, + rope_scaling_short_factor=None, + rope_scaling_long_factor=None, + **kwargs, + ): + super().__init__(**kwargs) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = num_query_heads // num_key_value_heads + self.dropout = dropout + + self.max_sequence_length = max_sequence_length + self.pretraining_sequence_length = pretraining_sequence_length + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_type = rope_scaling_type + self.rope_scaling_short_factor = rope_scaling_short_factor + self.rope_scaling_long_factor = rope_scaling_long_factor + + self.kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + + def build(self, inputs_shape): + # Einsum variables: + # b = batch size + # q = query length + # k = key/value length + # m = model dim + # u = num query heads + # v = num key/value heads + # h = head dim + hidden_dim = inputs_shape[-1] + head_dim = hidden_dim // self.num_query_heads + self._inv_norm_factor = 1.0 / math.sqrt(head_dim) + + self.query_dense = keras.layers.EinsumDense( + equation="bqm,muh->bquh", + output_shape=(None, self.num_query_heads, head_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="query", + ) + self.query_dense.build(inputs_shape) + + self.key_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self.num_key_value_heads, + head_dim, + ), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="key", + ) + self.key_dense.build(inputs_shape) + + self.value_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self.num_key_value_heads, + head_dim, + ), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="value", + ) + self.value_dense.build(inputs_shape) + + self.softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) + + self.dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + + self.output_dense = keras.layers.EinsumDense( + equation="bquh,uhm->bqm", + output_shape=(None, hidden_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="attention_output", + ) + self.output_dense.build((None, None, self.num_query_heads, head_dim)) + + if self.rope_scaling_type is None: + self.rotary_embedding_layer = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + dtype=self.dtype_policy, + ) + elif self.rope_scaling_type == "su": + if 
len(self.rope_scaling_short_factor) != head_dim // 2: + raise ValueError( + "`rope_scaling_short_factor` must be of length " + "`hidden_dim//num_query_heads//2`. " + "`len(rope_scaling_short_factor)` is " + f"{len(self.rope_scaling_short_factor)} " + f"while it should be {head_dim // 2}." + ) + if len(self.rope_scaling_long_factor) != head_dim // 2: + raise ValueError( + "`rope_scaling_long_factor` must be of length " + "`hidden_dim//num_query_heads//2`. " + "`len(rope_scaling_long_factor)` is " + f"{len(self.rope_scaling_long_factor)} " + f"while it should be {head_dim // 2}." + ) + self.rotary_embedding_layer = Phi4SuScaledRotaryEmbedding( + inverese_freq_short_factor=self.rope_scaling_short_factor, + inverese_freq_long_factor=self.rope_scaling_long_factor, + max_sequence_length=self.max_sequence_length, + pretraining_sequence_length=self.pretraining_sequence_length, + max_wavelength=self.rope_max_wavelength, + dtype=self.dtype_policy, + ) + else: + raise ValueError( + '`rope_scaling_type` must be `None` or `"su"`.' + "if `None` is choosed, `RotaryEmbedding` will be used." + 'if `"su"` is choosed, `Phi4SuScaledRotaryEmbedding` will be ' + "used." + ) + + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + start_index = ( + cache_update_index if cache_update_index is not None else 0 + ) + + query = self.query_dense(hidden_states) + key = self.key_dense(hidden_states) + value = self.value_dense(hidden_states) + + # Compute RoPE for queries + query = self.rotary_embedding_layer(query, start_index=start_index) + key = self.rotary_embedding_layer(key, start_index=start_index) + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key) + value = ops.slice_update(value_cache, start, value) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. Received: cache={cache}, " + f"cache_update_index={cache_update_index}" + ) + + # [batch_shape, seq_len, num_key_value_heads, head_dim] + # -> [batch_shape, seq_len, num_heads, head_dim] + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) + + attention_output = self._compute_attention( + query, key, value, attention_mask + ) + + attention_output = self.dropout_layer( + attention_output, training=training + ) + + attention_output = self.output_dense(attention_output) + + if cache is not None: + return attention_output, cache + return attention_output + + def _masked_softmax(self, attention_scores, attention_mask=None): + if attention_mask is not None: + return self.softmax(attention_scores, attention_mask[:, None, :, :]) + return self.softmax(attention_scores) + + def _compute_attention(self, query, key, value, attention_mask=None): + if fused_attention_op_available(): + # Use `dot_product_attention` with Flash Attention support if + # available. 
+ if attention_mask is not None: + attention_mask = ops.expand_dims(attention_mask, axis=1) + attention_mask = ops.cast(attention_mask, dtype="bool") + attention_output = ops.dot_product_attention( + query, + key, + value, + mask=attention_mask, + scale=self._inv_norm_factor, + ) + return attention_output + + attention_scores = ops.einsum("bquh,bkuh->buqk", query, key) + attention_scores = ops.multiply( + attention_scores, + ops.cast(self._inv_norm_factor, self.compute_dtype), + ) + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + attention_scores = ops.cast(attention_scores, self.compute_dtype) + attention_output = ops.einsum( + "buqk,bkuh->bquh", attention_scores, value + ) + + return attention_output + + def get_config(self): + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout, + "max_sequence_length": self.max_sequence_length, + "pretraining_sequence_length": self.pretraining_sequence_length, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_type": self.rope_scaling_type, + "rope_scaling_short_factor": self.rope_scaling_short_factor, + "rope_scaling_long_factor": self.rope_scaling_long_factor, + } + ) + return config diff --git a/keras_hub/src/models/phi4/phi4_backbone.py b/keras_hub/src/models/phi4/phi4_backbone.py new file mode 100644 index 0000000000..1074853a22 --- /dev/null +++ b/keras_hub/src/models/phi4/phi4_backbone.py @@ -0,0 +1,214 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.phi4.phi4_decoder import Phi4Decoder +from keras_hub.src.models.phi4.phi4_layernorm import Phi4LayerNorm + + +def _phi4_kernel_initializer(stddev=0.02): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras_hub_export("keras_hub.models.Phi4Backbone") +class Phi4Backbone(Backbone): + """Phi-4 core network with hyperparameters. + + This network implements a Transformer-based decoder network, + Phi-4, as described in ["Phi-4 Technical Report"](https://arxiv.org/pdf/2412.08905). + It includes the embedding lookups and transformer layers. + + The default constructor gives a fully customizable, randomly initialized + phi-4 model with any number of layers, heads, and embedding + dimensions. To load preset architectures and weights, use the `from_preset` + constructor. + + Args: + vocabulary_size (int): The size of the token vocabulary. + num_layers (int): The number of transformer layers. + hidden_dim (int): The size of the embeddings and the hidden states of + the transformer layers. + intermediate_dim (int): The output dimension of the first Dense layer in + a three-layer feedforward network for each transformer. + num_query_heads (int): The number of query attention heads for each + transformer layer. + num_key_value_heads (int): The number of key and value attention heads + for each transformer layer. + layer_norm_epsilon (float, optional): Epsilon for the RMS layernorm + layers in the transformer decoder. Defaults to `1e-6`. + dropout: (float, optional): Dropout probability for the Transformer + decoder. + max_sequence_length (int, optional): The maximum sequence length + that this model might ever be used with. Defaults to `4096`. 
+ pretraining_sequence_length (int, optional): The maximum sequence length + that the model was pretrained with. Defaults to `4096`. + rope_max_wavelength (int, optional): The maximum angular wavelength of + the sine/cosine curves, for rotary embeddings. Defaults to `10000`. + rope_scaling_type (str, optional): The type of the rope scaling. Can be + either `None` or `"su"`. `None` is for no rope scaling, `"su"` is + for SuScaled rope, `"su"` is used when `max_sequence_length` is + larger than `original_max_sequence_length`. Defaults to `None`. + rope_scaling_short_factor List[float]: List of factors used to adjust + rope frequencies when the `rope_scaling_type` is `"su"`. List must + be of length `hidden_dim//num_query_heads//2`. It is used when + `sequence_length` is smaller than `original_max_sequence_length`. + Defaults to `None`. + rope_scaling_long_factor List[float]: List of factors used to adjust + rope frequencies when the `rope_scaling_type` is `"su"`. List must + be of length `hidden_dim//num_query_heads//2`. It is used when + `sequence_length` is larger than `original_max_sequence_length`. + Defaults to `None`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. + + Examples: + + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Phi4 decoder. + model = keras_hub.models.Phi4Backbone.from_preset( + "phi4_mini_4k_instruct_en" + ) + model(input_data) + + # Randomly initialized Phi4 decoder with custom config. + model = keras_hub.models.Phi4Backbone( + vocabulary_size=10, + num_layers=2, + hidden_dim=512, + intermediate_dim=1024, + num_query_heads=32, + num_key_value_heads=8, + layer_norm_epsilon=1e-6, + dtype="bfloat16" + ) + model(input_data) + ``` + + References: + - [Phi-4 Original Implementation](https://huggingface.co/microsoft/phi-4/blob/main/config.json) + """ + + def __init__( + self, + vocabulary_size=100_352, + num_layers=40, + hidden_dim=5120, + intermediate_dim=17_920, + num_query_heads=40, + num_key_value_heads=10, + layer_norm_epsilon=1e-5, + dropout=0.0, + max_sequence_length=16_384, + pretraining_sequence_length=16_384, + rope_max_wavelength=250_000, + rope_scaling_type=None, + rope_scaling_short_factor=None, + rope_scaling_long_factor=None, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=False, + embeddings_initializer=_phi4_kernel_initializer(stddev=0.01), + dtype=dtype, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = Phi4Decoder( + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + rope_max_wavelength=rope_max_wavelength, + layer_norm_epsilon=layer_norm_epsilon, + activation="silu", + kernel_initializer=_phi4_kernel_initializer(stddev=0.02), + dropout=dropout, + max_sequence_length=max_sequence_length, + pretraining_sequence_length=pretraining_sequence_length, + rope_scaling_type=rope_scaling_type, + rope_scaling_short_factor=rope_scaling_short_factor, + rope_scaling_long_factor=rope_scaling_long_factor, + dtype=dtype, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + 
self.layer_norm = Phi4LayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="sequence_output_layernorm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, decoder_padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.rope_scaling_type = rope_scaling_type + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.max_sequence_length = max_sequence_length + self.pretraining_sequence_length = pretraining_sequence_length + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_type = rope_scaling_type + self.rope_scaling_short_factor = rope_scaling_short_factor + self.rope_scaling_long_factor = rope_scaling_long_factor + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_key_value_heads": self.num_key_value_heads, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "max_sequence_length": self.max_sequence_length, + "pretraining_sequence_length": self.pretraining_sequence_length, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_type": self.rope_scaling_type, + "rope_scaling_short_factor": self.rope_scaling_short_factor, + "rope_scaling_long_factor": self.rope_scaling_long_factor, + } + ) + return config diff --git a/keras_hub/src/models/phi4/phi4_backbone_test.py b/keras_hub/src/models/phi4/phi4_backbone_test.py new file mode 100644 index 0000000000..adac55054e --- /dev/null +++ b/keras_hub/src/models/phi4/phi4_backbone_test.py @@ -0,0 +1,92 @@ +import pytest +from keras import ops + +from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone +from keras_hub.src.tests.test_case import TestCase + + +class Phi4Test(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 2, + "hidden_dim": 8, + "intermediate_dim": 8, + } + self.su_rotary_init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_query_heads": 2, + "num_key_value_heads": 1, + "hidden_dim": 8, + "intermediate_dim": 12, + "max_sequence_length": 10, + "pretraining_sequence_length": 5, + "rope_scaling_type": "su", + "rope_scaling_short_factor": [1.2, 1.4], + "rope_scaling_long_factor": [0.8, 0.6], + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=Phi4Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 8), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Phi4Backbone, + 
init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_backbone_basics_with_su_rotary(self): + self.run_backbone_test( + cls=Phi4Backbone, + init_kwargs=self.su_rotary_init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 8), + ) + + @pytest.mark.large + def test_saved_model_with_su_rotary(self): + self.run_model_saving_test( + cls=Phi4Backbone, + init_kwargs=self.su_rotary_init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_smallest_preset(self): + self.run_preset_test( + cls=Phi4Backbone, + preset="phi4_mini_4k_instruct_en", + input_data={ + "token_ids": ops.array([[1, 450, 4996, 1701, 29916, 29889]]), + "padding_mask": ops.ones((1, 6), dtype="int32"), + }, + expected_output_shape=(1, 6, 3072), + # The forward pass from a preset should be stable! + # Reference values computed using PyTorch HF model. + expected_partial_output=ops.array( + [-0.21222, 0.04004, -0.02759, 0.02200] + ), + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Phi4Backbone.presets: + self.run_preset_test( + cls=Phi4Backbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/phi4/phi4_decoder.py b/keras_hub/src/models/phi4/phi4_decoder.py new file mode 100644 index 0000000000..a55e8f5e41 --- /dev/null +++ b/keras_hub/src/models/phi4/phi4_decoder.py @@ -0,0 +1,246 @@ +import keras +from keras import ops + +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_hub.src.models.phi4.phi4_attention import Phi4Attention +from keras_hub.src.models.phi4.phi4_layernorm import Phi4LayerNorm +from keras_hub.src.utils.keras_utils import clone_initializer + + +class Phi4Decoder(keras.layers.Layer): + """A Transformer decoder layer for the Phi-4 backbone.""" + + def __init__( + self, + hidden_dim=5120, + intermediate_dim=17_920, + num_query_heads=40, + num_key_value_heads=10, + activation="silu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + dropout=0, + max_sequence_length=16_384, + pretraining_sequence_length=16_384, + rope_max_wavelength=250_000, + rope_scaling_type=None, + rope_scaling_short_factor=None, + rope_scaling_long_factor=None, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + + self.max_sequence_length = max_sequence_length + self.pretraining_sequence_length = pretraining_sequence_length + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_type = rope_scaling_type + self.rope_scaling_short_factor = rope_scaling_short_factor + self.rope_scaling_long_factor = rope_scaling_long_factor + + self.dropout = dropout + + self.layer_norm_epsilon = layer_norm_epsilon + self.activation = keras.activations.get(activation) + self.kernel_initializer = keras.initializers.get(kernel_initializer) + + def build(self, decoder_sequence_shape): + # Pre-attention layernorm. + self.pre_attention_layernorm = Phi4LayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_attention_layernorm", + ) + self.pre_attention_layernorm.build(decoder_sequence_shape) + + # Self attention layer. 
+ self.attention = Phi4Attention( + num_query_heads=self.num_query_heads, + num_key_value_heads=self.num_key_value_heads, + kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, + max_sequence_length=self.max_sequence_length, + pretraining_sequence_length=self.pretraining_sequence_length, + rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_type=self.rope_scaling_type, + rope_scaling_short_factor=self.rope_scaling_short_factor, + rope_scaling_long_factor=self.rope_scaling_long_factor, + dtype=self.dtype_policy, + name="attention", + ) + self.attention.build(decoder_sequence_shape) + + # Post-attention layernorm. + self.post_attention_layernorm = Phi4LayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="post_attention_layernorm", + ) + self.post_attention_layernorm.build(decoder_sequence_shape) + + # feedforward layers. + self.feedforward_intermediate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_intermediate_dense", + ) + self.feedforward_intermediate_dense.build(decoder_sequence_shape) + + self.feedforward_gate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_gate_dense", + ) + self.feedforward_gate_dense.build(decoder_sequence_shape) + + self.feedforward_output_dense = keras.layers.Dense( + self.hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_output_dense", + ) + + self.feedforward_output_dense.build( + self.feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) + ) + + # Dropout + self.attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="attention_dropout", + ) + self.feedforward_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="feedforward_dropout", + ) + + self.built = True + + def call( + self, + decoder_sequence, + decoder_padding_mask=None, + decoder_attention_mask=None, + attention_cache=None, + attention_cache_update_index=None, + ): + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=decoder_sequence, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + attention_cache=attention_cache, + attention_cache_update_index=attention_cache_update_index, + ) + residual = decoder_sequence + x = self.pre_attention_layernorm(decoder_sequence) + x = self.attention( + hidden_states=x, + attention_mask=self_attention_mask, + cache=attention_cache, + cache_update_index=attention_cache_update_index, + ) + if attention_cache is not None: + x, attention_cache = x + x = self.attention_dropout(x) + x = x + residual + + residual = x + x = self.post_attention_layernorm(x) + # Note that we run the activation function in full 32-bit + # precision since this is what `torch.nn.functional.silu` + # does. Internally, `torch.nn.functional.silu` converts the + # inputs to float32, computes SiLU, and converts the outputs + # back to compute dtype. 
+ # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501 + # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501 + gate_output = self.feedforward_gate_dense(x) + gate_output = ops.cast(gate_output, "float32") + gate_output = self.activation(gate_output) + gate_output = ops.cast(gate_output, self.compute_dtype) + x = self.feedforward_intermediate_dense(x) + x = self.feedforward_output_dense(ops.multiply(x, gate_output)) + x = self.feedforward_dropout(x) + decoder_output = x + residual + + if attention_cache is not None: + return decoder_output, attention_cache + return decoder_output + + def _compute_self_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + attention_cache, + attention_cache_update_index, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. + if attention_cache is not None: + input_length = ops.shape(attention_cache)[2] + + cache_update_index = ( + 0 + if attention_cache_update_index is None + else attention_cache_update_index + ) + + causal_mask = compute_causal_mask( + batch_size, input_length, output_length, cache_update_index + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "activation": keras.activations.serialize(self.activation), + "layer_norm_epsilon": self.layer_norm_epsilon, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout, + "max_sequence_length": self.max_sequence_length, + "pretraining_sequence_length": self.pretraining_sequence_length, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_type": self.rope_scaling_type, + "rope_scaling_short_factor": self.rope_scaling_short_factor, + "rope_scaling_long_factor": self.rope_scaling_long_factor, + } + ) + return config diff --git a/keras_hub/src/models/phi4/phi4_layernorm.py b/keras_hub/src/models/phi4/phi4_layernorm.py new file mode 100644 index 0000000000..e6238c4f0d --- /dev/null +++ b/keras_hub/src/models/phi4/phi4_layernorm.py @@ -0,0 +1,35 @@ +import keras +from keras import ops + + +# TODO: Deprecate this in favor of +# `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is +# removed. 
+class Phi4LayerNorm(keras.layers.Layer): + """A normalization layer for Phi-4 that implements RMS normalization.""" + + def __init__(self, epsilon=1e-5, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + dim = input_shape[-1] + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(dim,), + initializer="ones", + dtype=self.variable_dtype, + ) + self.built = True + + def call(self, x): + x = ops.cast(x, "float32") + var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + x = x * ops.rsqrt(var + self.epsilon) + return ops.cast(x * self.scale, self.compute_dtype) + + def get_config(self): + config = super().get_config() + config.update({"epsilon": self.epsilon}) + return config diff --git a/keras_hub/src/models/phi4/phi4_rotary_embedding.py b/keras_hub/src/models/phi4/phi4_rotary_embedding.py new file mode 100644 index 0000000000..77b2dadcc2 --- /dev/null +++ b/keras_hub/src/models/phi4/phi4_rotary_embedding.py @@ -0,0 +1,124 @@ +import math + +from keras import ops + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding + + +class Phi4SuScaledRotaryEmbedding(RotaryEmbedding): + """SuRotary positional encoding layer. + + Args: + inverese_freq_short_factor List[float]: List of factors used to adjust + rope frequencies when the `rope_scaling_type` is `"su"`. List must + be of length `hidden_dim//num_query_heads//2`. It is used when + `sequence_length` is smaller than `original_max_sequence_length`. + inverese_freq_long_factor List[float]: List of factors used to adjust + rope frequencies when the `rope_scaling_type` is `"su"`. List must + be of length `hidden_dim//num_query_heads//2`. It is used when + `sequence_length` is larger than `original_max_sequence_length`. + max_sequence_length: int. The maximum sequence length that this + model might ever be used with. + pretraining_sequence_length: int. The maximum sequence length that + this model was pretrained with. + max_wavelength: int. The maximum angular wavelength of the sine/cosine + curves. + + Call arguments: + inputs: The tensor inputs to apply the embedding to. This can have + any shape, but must contain both a sequence and feature axis. The + rotary embedding will be applied to `inputs` and returned. + start_index: An integer or integer tensor. The starting position to + compute the rotary embedding from. This is useful during cached + decoding, where each position is predicted separately in a loop. 
+ + References: + - [Phi-3-Medium-128k-Instruct Implementation (Since Phi-4 is based on Phi-3-Medium)](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct/blob/main/modeling_phi3.py) + """ + + def __init__( + self, + inverese_freq_short_factor, + inverese_freq_long_factor, + max_sequence_length=16_384, + pretraining_sequence_length=16_384, + max_wavelength=250_000, + **kwargs, + ): + super().__init__(max_wavelength=max_wavelength, **kwargs) + self.max_sequence_length = max_sequence_length + self.pretraining_sequence_length = pretraining_sequence_length + + scaling_factor = ( + self.max_sequence_length / self.pretraining_sequence_length + ) + if scaling_factor <= 1.0: + self.embedding_scaling_factor = 1.0 + else: + self.embedding_scaling_factor = math.sqrt( + 1 + + math.log(scaling_factor) + / math.log(self.pretraining_sequence_length) + ) + + self.inverese_freq_short_factor = inverese_freq_short_factor + self.inverese_freq_long_factor = inverese_freq_long_factor + + def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None): + feature_axis = len(inputs.shape) - 1 + sequence_axis = 1 + + rotary_dim = ops.shape(inputs)[feature_axis] + inverse_freq = self._get_inverse_freq(rotary_dim) + + # Multiply inverse_freq by a factor. + if ops.shape(inputs)[sequence_axis] > self.pretraining_sequence_length: + inverse_freq = ops.divide( + inverse_freq, + ops.convert_to_tensor(self.inverese_freq_long_factor), + ) + else: + inverse_freq = ops.divide( + inverse_freq, + ops.convert_to_tensor(self.inverese_freq_short_factor), + ) + + if positions is None: + positions = self._compute_positions(inputs, start_index) + else: + positions = ops.cast(positions, "float32") + + freq = ops.einsum("i,j->ij", positions, inverse_freq) + embedding = ops.stack((freq, freq), axis=-2) + embedding = ops.reshape( + embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2) + ) + + # Reshape the embedding to be broadcastable with input shape. + if feature_axis < sequence_axis: + embedding = ops.transpose(embedding) + for axis in range(len(inputs.shape)): + if axis != sequence_axis and axis != feature_axis: + embedding = ops.expand_dims(embedding, axis) + + cos_emb = ops.cast( + ops.cos(embedding) * self.embedding_scaling_factor, + self.compute_dtype, + ) + sin_emb = ops.cast( + ops.sin(embedding) * self.embedding_scaling_factor, + self.compute_dtype, + ) + return cos_emb, sin_emb + + def get_config(self): + config = super().get_config() + config.update( + { + "max_sequence_length": self.max_sequence_length, + "pretraining_sequence_length": self.pretraining_sequence_length, + "inverese_freq_short_factor": self.inverese_freq_short_factor, + "inverese_freq_long_factor": self.inverese_freq_long_factor, + } + ) + return config From 69f66fff0cc9e5850d9c2919c96fbc487da93517 Mon Sep 17 00:00:00 2001 From: Rahul Yedida Date: Wed, 7 May 2025 22:02:21 -0400 Subject: [PATCH 02/11] docs(phi4): update defaults in docstring --- keras_hub/src/models/phi4/phi4_backbone.py | 25 ++++++++++++---------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/keras_hub/src/models/phi4/phi4_backbone.py b/keras_hub/src/models/phi4/phi4_backbone.py index 1074853a22..481778ffad 100644 --- a/keras_hub/src/models/phi4/phi4_backbone.py +++ b/keras_hub/src/models/phi4/phi4_backbone.py @@ -27,26 +27,29 @@ class Phi4Backbone(Backbone): constructor. Args: - vocabulary_size (int): The size of the token vocabulary. - num_layers (int): The number of transformer layers. + vocabulary_size (int): The size of the token vocabulary. 
Defaults to + `100_352`. + num_layers (int): The number of transformer layers. Defaults to `40`. hidden_dim (int): The size of the embeddings and the hidden states of - the transformer layers. + the transformer layers. Defaults to `5120`. intermediate_dim (int): The output dimension of the first Dense layer in - a three-layer feedforward network for each transformer. + a three-layer feedforward network for each transformer. Defaults to + `17_920`. num_query_heads (int): The number of query attention heads for each - transformer layer. + transformer layer. Defaults to `40`. num_key_value_heads (int): The number of key and value attention heads - for each transformer layer. + for each transformer layer. Defaults to `10`. layer_norm_epsilon (float, optional): Epsilon for the RMS layernorm - layers in the transformer decoder. Defaults to `1e-6`. + layers in the transformer decoder. Defaults to `1e-5`. dropout: (float, optional): Dropout probability for the Transformer - decoder. + decoder. Defaults to `0.0`. max_sequence_length (int, optional): The maximum sequence length - that this model might ever be used with. Defaults to `4096`. + that this model might ever be used with. Defaults to `16_384`. pretraining_sequence_length (int, optional): The maximum sequence length - that the model was pretrained with. Defaults to `4096`. + that the model was pretrained with. Defaults to `16_384`. rope_max_wavelength (int, optional): The maximum angular wavelength of - the sine/cosine curves, for rotary embeddings. Defaults to `10000`. + the sine/cosine curves, for rotary embeddings. Defaults to + `250_000`. rope_scaling_type (str, optional): The type of the rope scaling. Can be either `None` or `"su"`. `None` is for no rope scaling, `"su"` is for SuScaled rope, `"su"` is used when `max_sequence_length` is From 8b7146e46933fefc02e3f5daf7b9afb54a1217a1 Mon Sep 17 00:00:00 2001 From: Rahul Yedida Date: Wed, 16 Jul 2025 00:35:58 -0400 Subject: [PATCH 03/11] feat(phi4): refactor Phi4Backbone to inherit from Phi-3 --- keras_hub/api/models/__init__.py | 9 + keras_hub/api/tokenizers/__init__.py | 3 + keras_hub/src/models/phi4/phi4_backbone.py | 214 +++------------------ 3 files changed, 42 insertions(+), 184 deletions(-) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 73211ac349..6425d4c0f0 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -422,6 +422,15 @@ Phi3Tokenizer as Phi3Tokenizer, ) from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone as Phi4Backbone +from keras_hub.src.models.phi4.phi4_causal_lm import ( + Phi4CausalLM as Phi4CausalLM, +) +from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import ( + Phi4CausalLMPreprocessor as Phi4CausalLMPreprocessor, +) +from keras_hub.src.models.phi4.phi4_tokenizer import ( + Phi4Tokenizer as Phi4Tokenizer, +) from keras_hub.src.models.preprocessor import Preprocessor as Preprocessor from keras_hub.src.models.qwen.qwen_backbone import ( QwenBackbone as Qwen2Backbone, diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 082078184f..822840ee61 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -68,6 +68,9 @@ from keras_hub.src.models.phi3.phi3_tokenizer import ( Phi3Tokenizer as Phi3Tokenizer, ) +from keras_hub.src.models.phi4.phi4_tokenizer import ( + Phi4Tokenizer as Phi4Tokenizer, +) from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) diff --git 
a/keras_hub/src/models/phi4/phi4_backbone.py b/keras_hub/src/models/phi4/phi4_backbone.py index 481778ffad..3970dcf2dd 100644 --- a/keras_hub/src/models/phi4/phi4_backbone.py +++ b/keras_hub/src/models/phi4/phi4_backbone.py @@ -1,20 +1,9 @@ -import keras - from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.layers.modeling.reversible_embedding import ( - ReversibleEmbedding, -) -from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.phi4.phi4_decoder import Phi4Decoder -from keras_hub.src.models.phi4.phi4_layernorm import Phi4LayerNorm - - -def _phi4_kernel_initializer(stddev=0.02): - return keras.initializers.RandomNormal(stddev=stddev) +from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone @keras_hub_export("keras_hub.models.Phi4Backbone") -class Phi4Backbone(Backbone): +class Phi4Backbone(Phi3Backbone): """Phi-4 core network with hyperparameters. This network implements a Transformer-based decoder network, @@ -26,40 +15,41 @@ class Phi4Backbone(Backbone): dimensions. To load preset architectures and weights, use the `from_preset` constructor. + Note that the defaults here are the Phi-3 defaults, because the Phi-4 model + follows the Phi-3-medium architecture but with different hyper-parameters. + Use `keras_hub.models.Backbone.from_preset` to get the Phi-4 defaults. + Args: - vocabulary_size (int): The size of the token vocabulary. Defaults to - `100_352`. - num_layers (int): The number of transformer layers. Defaults to `40`. - hidden_dim (int): The size of the embeddings and the hidden states of - the transformer layers. Defaults to `5120`. - intermediate_dim (int): The output dimension of the first Dense layer in - a three-layer feedforward network for each transformer. Defaults to - `17_920`. - num_query_heads (int): The number of query attention heads for each - transformer layer. Defaults to `40`. - num_key_value_heads (int): The number of key and value attention heads - for each transformer layer. Defaults to `10`. - layer_norm_epsilon (float, optional): Epsilon for the RMS layernorm - layers in the transformer decoder. Defaults to `1e-5`. - dropout: (float, optional): Dropout probability for the Transformer - decoder. Defaults to `0.0`. - max_sequence_length (int, optional): The maximum sequence length - that this model might ever be used with. Defaults to `16_384`. - pretraining_sequence_length (int, optional): The maximum sequence length - that the model was pretrained with. Defaults to `16_384`. - rope_max_wavelength (int, optional): The maximum angular wavelength of - the sine/cosine curves, for rotary embeddings. Defaults to - `250_000`. - rope_scaling_type (str, optional): The type of the rope scaling. Can be + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + hidden_dim: int. The size of the embeddings and the hidden states of + the transformer layers. + intermediate_dim: int. The output dimension of the first Dense layer in + a three-layer feedforward network for each transformer. + num_query_heads: int. The number of query attention heads for each + transformer layer. + num_key_value_heads: int. The number of key and value attention heads + for each transformer layer. + layer_norm_epsilon: float, optional. Epsilon for the RMS layernorm + layers in the transformer decoder. Defaults to `1e-6`. + dropout:: float, optional. Dropout probability for the Transformer + decoder. + max_sequence_length: int, optional. 
The maximum sequence length + that this model might ever be used with. Defaults to `4096`. + pretraining_sequence_length: int, optional. The maximum sequence length + that the model was pretrained with. Defaults to `4096`. + rope_max_wavelength: int, optional. The maximum angular wavelength of + the sine/cosine curves, for rotary embeddings. Defaults to `10000`. + rope_scaling_type: str, optional. The type of the rope scaling. Can be either `None` or `"su"`. `None` is for no rope scaling, `"su"` is for SuScaled rope, `"su"` is used when `max_sequence_length` is larger than `original_max_sequence_length`. Defaults to `None`. - rope_scaling_short_factor List[float]: List of factors used to adjust + rope_scaling_short_factor: list[float]. List of factors used to adjust rope frequencies when the `rope_scaling_type` is `"su"`. List must be of length `hidden_dim//num_query_heads//2`. It is used when `sequence_length` is smaller than `original_max_sequence_length`. Defaults to `None`. - rope_scaling_long_factor List[float]: List of factors used to adjust + rope_scaling_long_factor: list[float]. List of factors used to adjust rope frequencies when the `rope_scaling_type` is `"su"`. List must be of length `hidden_dim//num_query_heads//2`. It is used when `sequence_length` is larger than `original_max_sequence_length`. @@ -68,150 +58,6 @@ class Phi4Backbone(Backbone): for model computations and weights. Note that some computations, such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - - Examples: - - ```python - input_data = { - "token_ids": np.ones(shape=(1, 12), dtype="int32"), - "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), - } - - # Pretrained Phi4 decoder. - model = keras_hub.models.Phi4Backbone.from_preset( - "phi4_mini_4k_instruct_en" - ) - model(input_data) - - # Randomly initialized Phi4 decoder with custom config. 
- model = keras_hub.models.Phi4Backbone( - vocabulary_size=10, - num_layers=2, - hidden_dim=512, - intermediate_dim=1024, - num_query_heads=32, - num_key_value_heads=8, - layer_norm_epsilon=1e-6, - dtype="bfloat16" - ) - model(input_data) - ``` - - References: - - [Phi-4 Original Implementation](https://huggingface.co/microsoft/phi-4/blob/main/config.json) """ - def __init__( - self, - vocabulary_size=100_352, - num_layers=40, - hidden_dim=5120, - intermediate_dim=17_920, - num_query_heads=40, - num_key_value_heads=10, - layer_norm_epsilon=1e-5, - dropout=0.0, - max_sequence_length=16_384, - pretraining_sequence_length=16_384, - rope_max_wavelength=250_000, - rope_scaling_type=None, - rope_scaling_short_factor=None, - rope_scaling_long_factor=None, - dtype=None, - **kwargs, - ): - # === Layers === - self.token_embedding = ReversibleEmbedding( - input_dim=vocabulary_size, - output_dim=hidden_dim, - tie_weights=False, - embeddings_initializer=_phi4_kernel_initializer(stddev=0.01), - dtype=dtype, - name="token_embedding", - ) - self.transformer_layers = [] - for i in range(num_layers): - layer = Phi4Decoder( - hidden_dim=hidden_dim, - intermediate_dim=intermediate_dim, - num_query_heads=num_query_heads, - num_key_value_heads=num_key_value_heads, - rope_max_wavelength=rope_max_wavelength, - layer_norm_epsilon=layer_norm_epsilon, - activation="silu", - kernel_initializer=_phi4_kernel_initializer(stddev=0.02), - dropout=dropout, - max_sequence_length=max_sequence_length, - pretraining_sequence_length=pretraining_sequence_length, - rope_scaling_type=rope_scaling_type, - rope_scaling_short_factor=rope_scaling_short_factor, - rope_scaling_long_factor=rope_scaling_long_factor, - dtype=dtype, - name=f"transformer_layer_{i}", - ) - self.transformer_layers.append(layer) - self.layer_norm = Phi4LayerNorm( - epsilon=layer_norm_epsilon, - dtype=dtype, - name="sequence_output_layernorm", - ) - - # === Functional Model === - token_id_input = keras.Input( - shape=(None,), dtype="int32", name="token_ids" - ) - padding_mask_input = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - x = self.token_embedding(token_id_input) - for transformer_layer in self.transformer_layers: - x = transformer_layer(x, decoder_padding_mask=padding_mask_input) - sequence_output = self.layer_norm(x) - super().__init__( - inputs={ - "token_ids": token_id_input, - "padding_mask": padding_mask_input, - }, - outputs=sequence_output, - dtype=dtype, - **kwargs, - ) - - # === Config === - self.vocabulary_size = vocabulary_size - self.num_layers = num_layers - self.num_query_heads = num_query_heads - self.num_key_value_heads = num_key_value_heads - self.hidden_dim = hidden_dim - self.intermediate_dim = intermediate_dim - self.rope_scaling_type = rope_scaling_type - self.layer_norm_epsilon = layer_norm_epsilon - self.dropout = dropout - self.max_sequence_length = max_sequence_length - self.pretraining_sequence_length = pretraining_sequence_length - self.rope_max_wavelength = rope_max_wavelength - self.rope_scaling_type = rope_scaling_type - self.rope_scaling_short_factor = rope_scaling_short_factor - self.rope_scaling_long_factor = rope_scaling_long_factor - - def get_config(self): - config = super().get_config() - config.update( - { - "vocabulary_size": self.vocabulary_size, - "num_layers": self.num_layers, - "num_query_heads": self.num_query_heads, - "hidden_dim": self.hidden_dim, - "intermediate_dim": self.intermediate_dim, - "num_key_value_heads": self.num_key_value_heads, - "layer_norm_epsilon": 
self.layer_norm_epsilon, - "dropout": self.dropout, - "max_sequence_length": self.max_sequence_length, - "pretraining_sequence_length": self.pretraining_sequence_length, - "rope_max_wavelength": self.rope_max_wavelength, - "rope_scaling_type": self.rope_scaling_type, - "rope_scaling_short_factor": self.rope_scaling_short_factor, - "rope_scaling_long_factor": self.rope_scaling_long_factor, - } - ) - return config + pass From 3df73afe0d591a5ad33512813942efdba604d851 Mon Sep 17 00:00:00 2001 From: Rahul Yedida Date: Wed, 16 Jul 2025 22:26:11 -0400 Subject: [PATCH 04/11] feat(phi4): add phi-4 tokenizer --- keras_hub/src/models/phi4/phi4_tokenizer.py | 86 +++++++++++++++++++ .../src/models/phi4/phi4_tokenizer_test.py | 61 +++++++++++++ 2 files changed, 147 insertions(+) create mode 100644 keras_hub/src/models/phi4/phi4_tokenizer.py create mode 100644 keras_hub/src/models/phi4/phi4_tokenizer_test.py diff --git a/keras_hub/src/models/phi4/phi4_tokenizer.py b/keras_hub/src/models/phi4/phi4_tokenizer.py new file mode 100644 index 0000000000..de0ef3c3ef --- /dev/null +++ b/keras_hub/src/models/phi4/phi4_tokenizer.py @@ -0,0 +1,86 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_hub_export( + [ + "keras_hub.tokenizers.Phi4Tokenizer", + "keras_hub.models.Phi4Tokenizer", + ] +) +class Phi4Tokenizer(BytePairTokenizer): + """Phi4 tokenizer using Byte-Pair Encoding subword segmentation. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_hub.tokenizers.BytePairTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + Phi4 models and provides a `from_preset()` method to automatically + download a matching vocabulary for a Phi4 preset. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + vocabulary: string or dict, maps token to integer ids. If it is a + string, it should be the file path to a json file. + merges: string or list, contains the merge rule. If it is a string, + it should be the file path to merge rules. The merge rule file + should have one merge rule per line. Every merge rule contains + merge entities separated by a space. + sequence_length: int. If set, the output will be + padded or truncated to the `sequence_length`. Defaults to 100,352 + based on the [Phi-4 Technical Report](https://arxiv.org/pdf/2412.08905) + + Examples: + ```python + # Unbatched input. + tokenizer = keras_hub.models.Phi4Tokenizer.from_preset( + "phi4_mini_4k_instruct_en", + ) + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. 
+ tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + ``` + + # References + + - [Phi-4 tokenizer config](https://huggingface.co/microsoft/phi-4/raw/main/tokenizer.json) + """ + + backbone_cls = Phi4Backbone + + def __init__( + self, + vocabulary=None, + merges=None, + sequence_length=100_352, + **kwargs, + ): + self._add_special_token("", "start_token") + self._add_special_token("", "end_token") + self._add_special_token("", "pad_token") + + # FIM = Fill-in-the-middle, which uses special tokens to identify + # the prefix/middle/suffix part of the input/output for coding tasks. + self._add_special_token("", "fim_prefix") + self._add_special_token("", "fim_middle") + self._add_special_token("", "fix_suffix") + + self._add_special_token("", "input_message_start") + self._add_special_token("", "input_message_separator") + self._add_special_token("", "input_message_end") + + super().__init__( + vocabulary=vocabulary, + merges=merges, + sequence_length=sequence_length, + **kwargs, + ) diff --git a/keras_hub/src/models/phi4/phi4_tokenizer_test.py b/keras_hub/src/models/phi4/phi4_tokenizer_test.py new file mode 100644 index 0000000000..2e91922d05 --- /dev/null +++ b/keras_hub/src/models/phi4/phi4_tokenizer_test.py @@ -0,0 +1,61 @@ +import pytest + +from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer +from keras_hub.src.tests.test_case import TestCase + + +class Phi4TokenizerTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += [ + "", + "", + "", + "", + "", + "", + ] + self.vocab += ["", "", ""] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.init_kwargs = { + "vocabulary": self.vocab, + "merges": self.merges, + "sequence_length": None, + } + self.input_data = [ + " airplane at airport", + " airplane airport", + ] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=Phi4Tokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[6, 2, 3, 4, 2, 5, 7, 8], [2, 3, 2, 5]], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + Phi4Tokenizer(vocabulary={"foo": 0, "bar": 1}, merges=["fo o"]) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=Phi4Tokenizer, + preset="phi4_8b_en", + input_data=["The quick brown fox."], + expected_output=[[791, 4062, 14198, 39935, 13]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Phi4Tokenizer.presets: + self.run_preset_test( + cls=Phi4Tokenizer, + preset=preset, + input_data=self.input_data, + ) From 4aceea3b29f59d0936ac92f1b72685c79ecc8b24 Mon Sep 17 00:00:00 2001 From: Rahul Yedida Date: Wed, 16 Jul 2025 22:26:52 -0400 Subject: [PATCH 05/11] feat(phi4): add phi-4 causal_lm files --- keras_hub/src/models/phi4/phi4_causal_lm.py | 33 +++++ .../phi4/phi4_causal_lm_preprocessor.py | 76 ++++++++++++ .../phi4/phi4_causal_lm_preprocessor_test.py | 81 ++++++++++++ .../src/models/phi4/phi4_causal_lm_test.py | 117 ++++++++++++++++++ 4 files changed, 307 insertions(+) create mode 100644 keras_hub/src/models/phi4/phi4_causal_lm.py create mode 100644 keras_hub/src/models/phi4/phi4_causal_lm_preprocessor.py create mode 100644 keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py create mode 100644 
keras_hub/src/models/phi4/phi4_causal_lm_test.py diff --git a/keras_hub/src/models/phi4/phi4_causal_lm.py b/keras_hub/src/models/phi4/phi4_causal_lm.py new file mode 100644 index 0000000000..8f014d81e2 --- /dev/null +++ b/keras_hub/src/models/phi4/phi4_causal_lm.py @@ -0,0 +1,33 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM +from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone +from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import ( + Phi4CausalLMPreprocessor, +) + + +@keras_hub_export("keras_hub.models.Phi4CausalLM") +class Phi4CausalLM(Phi3CausalLM): + """An end-to-end Phi4 model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a Phi-4 model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_hub.samplers` objects to control the generation. By + default, `"top_k"` sampling will be used. + + Args: + backbone: A `keras_hub.models.Phi4Backbone` instance. + preprocessor: A `keras_hub.models.Phi4CausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. + """ + + backbone_cls = Phi4Backbone + preprocessor_cls = Phi4CausalLMPreprocessor diff --git a/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor.py b/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor.py new file mode 100644 index 0000000000..63d874daa3 --- /dev/null +++ b/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor.py @@ -0,0 +1,76 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone +from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer + + +@keras_hub_export("keras_hub.models.Phi4CausalLMPreprocessor") +class Phi4CausalLMPreprocessor(CausalLMPreprocessor): + """Phi4 Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.Phi4CausalLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_hub.models.Phi4CausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_hub.models.Phi4Tokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Default is `True`. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. Default is `False`. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. 
Should always be `None` as the layer generates labels.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_hub.models.Phi4CausalLMPreprocessor.from_preset(
+        "phi4_mini_4k_instruct_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    sentence = tf.constant("League of legends")
+    preprocessor(sentence)
+    # Same output.
+    preprocessor("League of legends")
+
+    # Tokenize a batch of sentences.
+    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
+    preprocessor(sentences)
+    # Same output.
+    preprocessor(["Taco tuesday", "Fish taco please!"])
+
+    # Map a dataset to preprocess a single sentence.
+    features = tf.constant(
+        [
+            "Avatar 2 is amazing!",
+            "Well, I am not sure.",
+        ]
+    )
+    labels = tf.constant([1, 0])
+    ds = tf.data.Dataset.from_tensor_slices((features, labels))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map a dataset to preprocess unlabeled sentences.
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    ```
+    """
+
+    backbone_cls = Phi4Backbone
+    tokenizer_cls = Phi4Tokenizer
diff --git a/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py b/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py
new file mode 100644
index 0000000000..ff39cbbaa1
--- /dev/null
+++ b/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py
@@ -0,0 +1,81 @@
+import os
+
+import pytest
+
+from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import (
+    Phi4CausalLMPreprocessor,
+)
+from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer
+from keras_hub.src.tests.test_case import TestCase
+
+
+class Phi4CausalLMPreprocessorTest(TestCase):
+    def setUp(self):
+        self.tokenizer = Phi4Tokenizer(
+            # Generated using create_phi4_test_proto.py
+            proto=os.path.join(self.get_test_data_dir(), "phi4_test_vocab.spm")
+        )
+        self.init_kwargs = {
+            "tokenizer": self.tokenizer,
+            "sequence_length": 10,
+        }
+        # [3, 5, 6, 4, 3, 9, 7, 11]
+        self.input_data = (["the fox"],)
+
+    def test_preprocessor_basics(self):
+        self.run_preprocessor_test(
+            cls=Phi4CausalLMPreprocessor,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=(
+                {
+                    "token_ids": [[1, 3, 5, 6, 4, 3, 9, 7, 11, 15]],
+                    "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
+                },
+                [[3, 5, 6, 4, 3, 9, 7, 11, 15, 0]],
+                [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0]],
+            ),
+        )
+
+    def test_no_start_end_token(self):
+        input_data = ["the fox"] * 4
+
+        preprocessor = Phi4CausalLMPreprocessor(
+            **self.init_kwargs,
+            add_start_token=False,
+            add_end_token=False,
+        )
+        x, y, sw = preprocessor(input_data)
+        self.assertAllEqual(
+            x["token_ids"], [[3, 5, 6, 4, 3, 9, 7, 11, 0, 0]] * 4
+        )
+        self.assertAllEqual(
+            x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 4
+        )
+        self.assertAllEqual(y, [[5, 6, 4, 3, 9, 7, 11, 0, 0, 0]] * 4)
+        self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4)
+
+    def test_generate_preprocess(self):
+        input_data = "the fox"
+        preprocessor = Phi4CausalLMPreprocessor(**self.init_kwargs)
+        x = preprocessor.generate_preprocess(input_data)
+        self.assertAllEqual(x["token_ids"], [1, 3, 5, 6, 4, 3, 9, 7, 11, 0])
+        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 0])
+
+    def test_generate_postprocess(self):
+        input_data = {
+            "token_ids": [1, 3, 5, 6, 4, 3, 9, 7,
11, 0], + "padding_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + } + preprocessor = Phi4CausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "the fox") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Phi4CausalLMPreprocessor.presets: + self.run_preset_test( + cls=Phi4CausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/phi4/phi4_causal_lm_test.py b/keras_hub/src/models/phi4/phi4_causal_lm_test.py new file mode 100644 index 0000000000..f375aea782 --- /dev/null +++ b/keras_hub/src/models/phi4/phi4_causal_lm_test.py @@ -0,0 +1,117 @@ +import os +from unittest.mock import patch + +import pytest +from keras import ops + +from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone +from keras_hub.src.models.phi4.phi4_causal_lm import Phi4CausalLM +from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import ( + Phi4CausalLMPreprocessor, +) +from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer +from keras_hub.src.tests.test_case import TestCase + + +class Phi4CausalLMTest(TestCase): + def setUp(self): + self.preprocessor = Phi4CausalLMPreprocessor( + Phi4Tokenizer( + # Generated using create_phi4_test_proto.py + proto=os.path.join( + self.get_test_data_dir(), "phi4_test_vocab.spm" + ) + ), + sequence_length=12, + ) + self.vocab_size = self.preprocessor.tokenizer.vocabulary_size() + self.backbone = Phi4Backbone( + vocabulary_size=self.vocab_size, + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=8, + intermediate_dim=16, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = (["the quick brown fox", "the earth is round"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=Phi4CausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 12, self.vocab_size), + ) + + def test_generate(self): + causal_lm = Phi4CausalLM(**self.init_kwargs) + # String input. + prompt = "the fox" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids, stop_token_ids=None) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_early_stopping(self): + causal_lm = Phi4CausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the fox", "the earth"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. 
+ self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = Phi4CausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + # @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Phi4CausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Phi4CausalLM.presets: + self.run_preset_test( + cls=Phi4CausalLM, + preset=preset, + input_data=self.input_data, + ) From 17a30cea6cff1876b90d414f3784b914ce18036f Mon Sep 17 00:00:00 2001 From: Rahul Yedida Date: Wed, 16 Jul 2025 22:34:28 -0400 Subject: [PATCH 06/11] fix(phi4): update docstring to use correct variable names --- keras_hub/src/models/phi4/phi4_backbone.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/phi4/phi4_backbone.py b/keras_hub/src/models/phi4/phi4_backbone.py index 3970dcf2dd..0222f50e3e 100644 --- a/keras_hub/src/models/phi4/phi4_backbone.py +++ b/keras_hub/src/models/phi4/phi4_backbone.py @@ -47,12 +47,12 @@ class Phi4Backbone(Phi3Backbone): rope_scaling_short_factor: list[float]. List of factors used to adjust rope frequencies when the `rope_scaling_type` is `"su"`. List must be of length `hidden_dim//num_query_heads//2`. It is used when - `sequence_length` is smaller than `original_max_sequence_length`. + `sequence_length` is smaller than `pretraining_sequence_length`. Defaults to `None`. rope_scaling_long_factor: list[float]. List of factors used to adjust rope frequencies when the `rope_scaling_type` is `"su"`. List must be of length `hidden_dim//num_query_heads//2`. It is used when - `sequence_length` is larger than `original_max_sequence_length`. + `sequence_length` is larger than `pretraining_sequence_length`. Defaults to `None`. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use for model computations and weights. 
Note that some computations, From 82b29122da7d6eb51e5725a26ffa6a82b1e9d67c Mon Sep 17 00:00:00 2001 From: Rahul Yedida Date: Wed, 16 Jul 2025 22:35:13 -0400 Subject: [PATCH 07/11] fix(phi4): remove dedicated attention and decoder modules --- keras_hub/src/models/phi4/phi4_attention.py | 268 -------------------- keras_hub/src/models/phi4/phi4_decoder.py | 246 ------------------ 2 files changed, 514 deletions(-) delete mode 100644 keras_hub/src/models/phi4/phi4_attention.py delete mode 100644 keras_hub/src/models/phi4/phi4_decoder.py diff --git a/keras_hub/src/models/phi4/phi4_attention.py b/keras_hub/src/models/phi4/phi4_attention.py deleted file mode 100644 index f5563d6b92..0000000000 --- a/keras_hub/src/models/phi4/phi4_attention.py +++ /dev/null @@ -1,268 +0,0 @@ -import math - -import keras -from keras import ops - -from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding -from keras_hub.src.models.phi4.phi4_rotary_embedding import ( - Phi4SuScaledRotaryEmbedding, -) -from keras_hub.src.utils.keras_utils import clone_initializer -from keras_hub.src.utils.keras_utils import fused_attention_op_available - - -class Phi4Attention(keras.layers.Layer): - """A cached grounded query attention layer.""" - - def __init__( - self, - num_query_heads=40, - num_key_value_heads=10, - kernel_initializer="glorot_uniform", - dropout=0, - max_sequence_length=16_384, - pretraining_sequence_length=16_384, - rope_max_wavelength=250_000, - rope_scaling_type=None, - rope_scaling_short_factor=None, - rope_scaling_long_factor=None, - **kwargs, - ): - super().__init__(**kwargs) - self.num_query_heads = num_query_heads - self.num_key_value_heads = num_key_value_heads - self.num_key_value_groups = num_query_heads // num_key_value_heads - self.dropout = dropout - - self.max_sequence_length = max_sequence_length - self.pretraining_sequence_length = pretraining_sequence_length - self.rope_max_wavelength = rope_max_wavelength - self.rope_scaling_type = rope_scaling_type - self.rope_scaling_short_factor = rope_scaling_short_factor - self.rope_scaling_long_factor = rope_scaling_long_factor - - self.kernel_initializer = keras.initializers.get( - clone_initializer(kernel_initializer) - ) - - def build(self, inputs_shape): - # Einsum variables: - # b = batch size - # q = query length - # k = key/value length - # m = model dim - # u = num query heads - # v = num key/value heads - # h = head dim - hidden_dim = inputs_shape[-1] - head_dim = hidden_dim // self.num_query_heads - self._inv_norm_factor = 1.0 / math.sqrt(head_dim) - - self.query_dense = keras.layers.EinsumDense( - equation="bqm,muh->bquh", - output_shape=(None, self.num_query_heads, head_dim), - kernel_initializer=self.kernel_initializer, - dtype=self.dtype_policy, - name="query", - ) - self.query_dense.build(inputs_shape) - - self.key_dense = keras.layers.EinsumDense( - equation="bkm,mvh->bkvh", - output_shape=( - None, - self.num_key_value_heads, - head_dim, - ), - kernel_initializer=self.kernel_initializer, - dtype=self.dtype_policy, - name="key", - ) - self.key_dense.build(inputs_shape) - - self.value_dense = keras.layers.EinsumDense( - equation="bkm,mvh->bkvh", - output_shape=( - None, - self.num_key_value_heads, - head_dim, - ), - kernel_initializer=self.kernel_initializer, - dtype=self.dtype_policy, - name="value", - ) - self.value_dense.build(inputs_shape) - - self.softmax = keras.layers.Softmax( - axis=-1, - dtype="float32", - name="attention_softmax", - ) - - self.dropout_layer = keras.layers.Dropout( - rate=self.dropout, - 
dtype=self.dtype_policy, - ) - - self.output_dense = keras.layers.EinsumDense( - equation="bquh,uhm->bqm", - output_shape=(None, hidden_dim), - kernel_initializer=self.kernel_initializer, - dtype=self.dtype_policy, - name="attention_output", - ) - self.output_dense.build((None, None, self.num_query_heads, head_dim)) - - if self.rope_scaling_type is None: - self.rotary_embedding_layer = RotaryEmbedding( - max_wavelength=self.rope_max_wavelength, - dtype=self.dtype_policy, - ) - elif self.rope_scaling_type == "su": - if len(self.rope_scaling_short_factor) != head_dim // 2: - raise ValueError( - "`rope_scaling_short_factor` must be of length " - "`hidden_dim//num_query_heads//2`. " - "`len(rope_scaling_short_factor)` is " - f"{len(self.rope_scaling_short_factor)} " - f"while it should be {head_dim // 2}." - ) - if len(self.rope_scaling_long_factor) != head_dim // 2: - raise ValueError( - "`rope_scaling_long_factor` must be of length " - "`hidden_dim//num_query_heads//2`. " - "`len(rope_scaling_long_factor)` is " - f"{len(self.rope_scaling_long_factor)} " - f"while it should be {head_dim // 2}." - ) - self.rotary_embedding_layer = Phi4SuScaledRotaryEmbedding( - inverese_freq_short_factor=self.rope_scaling_short_factor, - inverese_freq_long_factor=self.rope_scaling_long_factor, - max_sequence_length=self.max_sequence_length, - pretraining_sequence_length=self.pretraining_sequence_length, - max_wavelength=self.rope_max_wavelength, - dtype=self.dtype_policy, - ) - else: - raise ValueError( - '`rope_scaling_type` must be `None` or `"su"`.' - "if `None` is choosed, `RotaryEmbedding` will be used." - 'if `"su"` is choosed, `Phi4SuScaledRotaryEmbedding` will be ' - "used." - ) - - self.built = True - - def call( - self, - hidden_states, - attention_mask=None, - cache=None, - cache_update_index=None, - training=None, - ): - start_index = ( - cache_update_index if cache_update_index is not None else 0 - ) - - query = self.query_dense(hidden_states) - key = self.key_dense(hidden_states) - value = self.value_dense(hidden_states) - - # Compute RoPE for queries - query = self.rotary_embedding_layer(query, start_index=start_index) - key = self.rotary_embedding_layer(key, start_index=start_index) - - if cache is not None: - key_cache = cache[:, 0, ...] - value_cache = cache[:, 1, ...] - if cache_update_index is None: - key = key_cache - value = value_cache - else: - start = [0, cache_update_index, 0, 0] - key = ops.slice_update(key_cache, start, key) - value = ops.slice_update(value_cache, start, value) - cache = ops.stack((key, value), axis=1) - else: - if cache_update_index is not None: - raise ValueError( - "`cache_update_index` should not be set if `cache` is " - f"`None`. 
Received: cache={cache}, " - f"cache_update_index={cache_update_index}" - ) - - # [batch_shape, seq_len, num_key_value_heads, head_dim] - # -> [batch_shape, seq_len, num_heads, head_dim] - key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) - value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) - - attention_output = self._compute_attention( - query, key, value, attention_mask - ) - - attention_output = self.dropout_layer( - attention_output, training=training - ) - - attention_output = self.output_dense(attention_output) - - if cache is not None: - return attention_output, cache - return attention_output - - def _masked_softmax(self, attention_scores, attention_mask=None): - if attention_mask is not None: - return self.softmax(attention_scores, attention_mask[:, None, :, :]) - return self.softmax(attention_scores) - - def _compute_attention(self, query, key, value, attention_mask=None): - if fused_attention_op_available(): - # Use `dot_product_attention` with Flash Attention support if - # available. - if attention_mask is not None: - attention_mask = ops.expand_dims(attention_mask, axis=1) - attention_mask = ops.cast(attention_mask, dtype="bool") - attention_output = ops.dot_product_attention( - query, - key, - value, - mask=attention_mask, - scale=self._inv_norm_factor, - ) - return attention_output - - attention_scores = ops.einsum("bquh,bkuh->buqk", query, key) - attention_scores = ops.multiply( - attention_scores, - ops.cast(self._inv_norm_factor, self.compute_dtype), - ) - attention_scores = self._masked_softmax( - attention_scores, attention_mask - ) - attention_scores = ops.cast(attention_scores, self.compute_dtype) - attention_output = ops.einsum( - "buqk,bkuh->bquh", attention_scores, value - ) - - return attention_output - - def get_config(self): - config = super().get_config() - config.update( - { - "num_query_heads": self.num_query_heads, - "num_key_value_heads": self.num_key_value_heads, - "kernel_initializer": keras.initializers.serialize( - self.kernel_initializer - ), - "dropout": self.dropout, - "max_sequence_length": self.max_sequence_length, - "pretraining_sequence_length": self.pretraining_sequence_length, - "rope_max_wavelength": self.rope_max_wavelength, - "rope_scaling_type": self.rope_scaling_type, - "rope_scaling_short_factor": self.rope_scaling_short_factor, - "rope_scaling_long_factor": self.rope_scaling_long_factor, - } - ) - return config diff --git a/keras_hub/src/models/phi4/phi4_decoder.py b/keras_hub/src/models/phi4/phi4_decoder.py deleted file mode 100644 index a55e8f5e41..0000000000 --- a/keras_hub/src/models/phi4/phi4_decoder.py +++ /dev/null @@ -1,246 +0,0 @@ -import keras -from keras import ops - -from keras_hub.src.layers.modeling.transformer_layer_utils import ( - compute_causal_mask, -) -from keras_hub.src.layers.modeling.transformer_layer_utils import ( - merge_padding_and_attention_mask, -) -from keras_hub.src.models.phi4.phi4_attention import Phi4Attention -from keras_hub.src.models.phi4.phi4_layernorm import Phi4LayerNorm -from keras_hub.src.utils.keras_utils import clone_initializer - - -class Phi4Decoder(keras.layers.Layer): - """A Transformer decoder layer for the Phi-4 backbone.""" - - def __init__( - self, - hidden_dim=5120, - intermediate_dim=17_920, - num_query_heads=40, - num_key_value_heads=10, - activation="silu", - layer_norm_epsilon=1e-5, - kernel_initializer="glorot_uniform", - dropout=0, - max_sequence_length=16_384, - pretraining_sequence_length=16_384, - rope_max_wavelength=250_000, - 
rope_scaling_type=None, - rope_scaling_short_factor=None, - rope_scaling_long_factor=None, - **kwargs, - ): - super().__init__(**kwargs) - self.hidden_dim = hidden_dim - self.intermediate_dim = intermediate_dim - self.num_query_heads = num_query_heads - self.num_key_value_heads = num_key_value_heads - - self.max_sequence_length = max_sequence_length - self.pretraining_sequence_length = pretraining_sequence_length - self.rope_max_wavelength = rope_max_wavelength - self.rope_scaling_type = rope_scaling_type - self.rope_scaling_short_factor = rope_scaling_short_factor - self.rope_scaling_long_factor = rope_scaling_long_factor - - self.dropout = dropout - - self.layer_norm_epsilon = layer_norm_epsilon - self.activation = keras.activations.get(activation) - self.kernel_initializer = keras.initializers.get(kernel_initializer) - - def build(self, decoder_sequence_shape): - # Pre-attention layernorm. - self.pre_attention_layernorm = Phi4LayerNorm( - epsilon=self.layer_norm_epsilon, - dtype=self.dtype_policy, - name="pre_attention_layernorm", - ) - self.pre_attention_layernorm.build(decoder_sequence_shape) - - # Self attention layer. - self.attention = Phi4Attention( - num_query_heads=self.num_query_heads, - num_key_value_heads=self.num_key_value_heads, - kernel_initializer=clone_initializer(self.kernel_initializer), - dropout=self.dropout, - max_sequence_length=self.max_sequence_length, - pretraining_sequence_length=self.pretraining_sequence_length, - rope_max_wavelength=self.rope_max_wavelength, - rope_scaling_type=self.rope_scaling_type, - rope_scaling_short_factor=self.rope_scaling_short_factor, - rope_scaling_long_factor=self.rope_scaling_long_factor, - dtype=self.dtype_policy, - name="attention", - ) - self.attention.build(decoder_sequence_shape) - - # Post-attention layernorm. - self.post_attention_layernorm = Phi4LayerNorm( - epsilon=self.layer_norm_epsilon, - dtype=self.dtype_policy, - name="post_attention_layernorm", - ) - self.post_attention_layernorm.build(decoder_sequence_shape) - - # feedforward layers. 
- self.feedforward_intermediate_dense = keras.layers.Dense( - self.intermediate_dim, - kernel_initializer=clone_initializer(self.kernel_initializer), - use_bias=False, - dtype=self.dtype_policy, - name="feedforward_intermediate_dense", - ) - self.feedforward_intermediate_dense.build(decoder_sequence_shape) - - self.feedforward_gate_dense = keras.layers.Dense( - self.intermediate_dim, - kernel_initializer=clone_initializer(self.kernel_initializer), - use_bias=False, - dtype=self.dtype_policy, - name="feedforward_gate_dense", - ) - self.feedforward_gate_dense.build(decoder_sequence_shape) - - self.feedforward_output_dense = keras.layers.Dense( - self.hidden_dim, - kernel_initializer=clone_initializer(self.kernel_initializer), - use_bias=False, - dtype=self.dtype_policy, - name="feedforward_output_dense", - ) - - self.feedforward_output_dense.build( - self.feedforward_gate_dense.compute_output_shape( - decoder_sequence_shape - ) - ) - - # Dropout - self.attention_dropout = keras.layers.Dropout( - rate=self.dropout, - dtype=self.dtype_policy, - name="attention_dropout", - ) - self.feedforward_dropout = keras.layers.Dropout( - rate=self.dropout, - dtype=self.dtype_policy, - name="feedforward_dropout", - ) - - self.built = True - - def call( - self, - decoder_sequence, - decoder_padding_mask=None, - decoder_attention_mask=None, - attention_cache=None, - attention_cache_update_index=None, - ): - self_attention_mask = self._compute_self_attention_mask( - decoder_sequence=decoder_sequence, - decoder_padding_mask=decoder_padding_mask, - decoder_attention_mask=decoder_attention_mask, - attention_cache=attention_cache, - attention_cache_update_index=attention_cache_update_index, - ) - residual = decoder_sequence - x = self.pre_attention_layernorm(decoder_sequence) - x = self.attention( - hidden_states=x, - attention_mask=self_attention_mask, - cache=attention_cache, - cache_update_index=attention_cache_update_index, - ) - if attention_cache is not None: - x, attention_cache = x - x = self.attention_dropout(x) - x = x + residual - - residual = x - x = self.post_attention_layernorm(x) - # Note that we run the activation function in full 32-bit - # precision since this is what `torch.nn.functional.silu` - # does. Internally, `torch.nn.functional.silu` converts the - # inputs to float32, computes SiLU, and converts the outputs - # back to compute dtype. 
- # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501 - # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501 - gate_output = self.feedforward_gate_dense(x) - gate_output = ops.cast(gate_output, "float32") - gate_output = self.activation(gate_output) - gate_output = ops.cast(gate_output, self.compute_dtype) - x = self.feedforward_intermediate_dense(x) - x = self.feedforward_output_dense(ops.multiply(x, gate_output)) - x = self.feedforward_dropout(x) - decoder_output = x + residual - - if attention_cache is not None: - return decoder_output, attention_cache - return decoder_output - - def _compute_self_attention_mask( - self, - decoder_sequence, - decoder_padding_mask, - decoder_attention_mask, - attention_cache, - attention_cache_update_index, - ): - decoder_mask = merge_padding_and_attention_mask( - decoder_sequence, decoder_padding_mask, decoder_attention_mask - ) - batch_size = ops.shape(decoder_sequence)[0] - input_length = output_length = ops.shape(decoder_sequence)[1] - # We need to handle a rectangular causal mask when doing cached - # decoding. For generative inference, `decoder_sequence` will - # generally be length 1, and `cache` will be the full generation length. - if attention_cache is not None: - input_length = ops.shape(attention_cache)[2] - - cache_update_index = ( - 0 - if attention_cache_update_index is None - else attention_cache_update_index - ) - - causal_mask = compute_causal_mask( - batch_size, input_length, output_length, cache_update_index - ) - - return ( - ops.minimum(decoder_mask, causal_mask) - if decoder_mask is not None - else causal_mask - ) - - def compute_output_shape(self, decoder_sequence_shape): - return decoder_sequence_shape - - def get_config(self): - config = super().get_config() - config.update( - { - "hidden_dim": self.hidden_dim, - "intermediate_dim": self.intermediate_dim, - "num_query_heads": self.num_query_heads, - "num_key_value_heads": self.num_key_value_heads, - "activation": keras.activations.serialize(self.activation), - "layer_norm_epsilon": self.layer_norm_epsilon, - "kernel_initializer": keras.initializers.serialize( - self.kernel_initializer - ), - "dropout": self.dropout, - "max_sequence_length": self.max_sequence_length, - "pretraining_sequence_length": self.pretraining_sequence_length, - "rope_max_wavelength": self.rope_max_wavelength, - "rope_scaling_type": self.rope_scaling_type, - "rope_scaling_short_factor": self.rope_scaling_short_factor, - "rope_scaling_long_factor": self.rope_scaling_long_factor, - } - ) - return config From cbdf6ced0a8e2fa9e4430d253eac862a04bbf917 Mon Sep 17 00:00:00 2001 From: Rahul Yedida Date: Wed, 16 Jul 2025 22:37:24 -0400 Subject: [PATCH 08/11] fix(phi4): remove unused layernorm and rotary embedding layers --- keras_hub/src/models/phi4/phi4_layernorm.py | 35 ----- .../src/models/phi4/phi4_rotary_embedding.py | 124 ------------------ 2 files changed, 159 deletions(-) delete mode 100644 keras_hub/src/models/phi4/phi4_layernorm.py delete mode 100644 keras_hub/src/models/phi4/phi4_rotary_embedding.py diff --git a/keras_hub/src/models/phi4/phi4_layernorm.py b/keras_hub/src/models/phi4/phi4_layernorm.py deleted file mode 100644 index e6238c4f0d..0000000000 --- a/keras_hub/src/models/phi4/phi4_layernorm.py +++ /dev/null @@ -1,35 +0,0 @@ -import keras -from keras import ops - - -# TODO: Deprecate 
this in favor of -# `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is -# removed. -class Phi4LayerNorm(keras.layers.Layer): - """A normalization layer for Phi-4 that implements RMS normalization.""" - - def __init__(self, epsilon=1e-5, **kwargs): - super().__init__(**kwargs) - self.epsilon = epsilon - - def build(self, input_shape): - dim = input_shape[-1] - self.scale = self.add_weight( - name="scale", - trainable=True, - shape=(dim,), - initializer="ones", - dtype=self.variable_dtype, - ) - self.built = True - - def call(self, x): - x = ops.cast(x, "float32") - var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) - x = x * ops.rsqrt(var + self.epsilon) - return ops.cast(x * self.scale, self.compute_dtype) - - def get_config(self): - config = super().get_config() - config.update({"epsilon": self.epsilon}) - return config diff --git a/keras_hub/src/models/phi4/phi4_rotary_embedding.py b/keras_hub/src/models/phi4/phi4_rotary_embedding.py deleted file mode 100644 index 77b2dadcc2..0000000000 --- a/keras_hub/src/models/phi4/phi4_rotary_embedding.py +++ /dev/null @@ -1,124 +0,0 @@ -import math - -from keras import ops - -from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding - - -class Phi4SuScaledRotaryEmbedding(RotaryEmbedding): - """SuRotary positional encoding layer. - - Args: - inverese_freq_short_factor List[float]: List of factors used to adjust - rope frequencies when the `rope_scaling_type` is `"su"`. List must - be of length `hidden_dim//num_query_heads//2`. It is used when - `sequence_length` is smaller than `original_max_sequence_length`. - inverese_freq_long_factor List[float]: List of factors used to adjust - rope frequencies when the `rope_scaling_type` is `"su"`. List must - be of length `hidden_dim//num_query_heads//2`. It is used when - `sequence_length` is larger than `original_max_sequence_length`. - max_sequence_length: int. The maximum sequence length that this - model might ever be used with. - pretraining_sequence_length: int. The maximum sequence length that - this model was pretrained with. - max_wavelength: int. The maximum angular wavelength of the sine/cosine - curves. - - Call arguments: - inputs: The tensor inputs to apply the embedding to. This can have - any shape, but must contain both a sequence and feature axis. The - rotary embedding will be applied to `inputs` and returned. - start_index: An integer or integer tensor. The starting position to - compute the rotary embedding from. This is useful during cached - decoding, where each position is predicted separately in a loop. 
- - References: - - [Phi-3-Medium-128k-Instruct Implementation (Since Phi-4 is based on Phi-3-Medium)](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct/blob/main/modeling_phi3.py) - """ - - def __init__( - self, - inverese_freq_short_factor, - inverese_freq_long_factor, - max_sequence_length=16_384, - pretraining_sequence_length=16_384, - max_wavelength=250_000, - **kwargs, - ): - super().__init__(max_wavelength=max_wavelength, **kwargs) - self.max_sequence_length = max_sequence_length - self.pretraining_sequence_length = pretraining_sequence_length - - scaling_factor = ( - self.max_sequence_length / self.pretraining_sequence_length - ) - if scaling_factor <= 1.0: - self.embedding_scaling_factor = 1.0 - else: - self.embedding_scaling_factor = math.sqrt( - 1 - + math.log(scaling_factor) - / math.log(self.pretraining_sequence_length) - ) - - self.inverese_freq_short_factor = inverese_freq_short_factor - self.inverese_freq_long_factor = inverese_freq_long_factor - - def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None): - feature_axis = len(inputs.shape) - 1 - sequence_axis = 1 - - rotary_dim = ops.shape(inputs)[feature_axis] - inverse_freq = self._get_inverse_freq(rotary_dim) - - # Multiply inverse_freq by a factor. - if ops.shape(inputs)[sequence_axis] > self.pretraining_sequence_length: - inverse_freq = ops.divide( - inverse_freq, - ops.convert_to_tensor(self.inverese_freq_long_factor), - ) - else: - inverse_freq = ops.divide( - inverse_freq, - ops.convert_to_tensor(self.inverese_freq_short_factor), - ) - - if positions is None: - positions = self._compute_positions(inputs, start_index) - else: - positions = ops.cast(positions, "float32") - - freq = ops.einsum("i,j->ij", positions, inverse_freq) - embedding = ops.stack((freq, freq), axis=-2) - embedding = ops.reshape( - embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2) - ) - - # Reshape the embedding to be broadcastable with input shape. - if feature_axis < sequence_axis: - embedding = ops.transpose(embedding) - for axis in range(len(inputs.shape)): - if axis != sequence_axis and axis != feature_axis: - embedding = ops.expand_dims(embedding, axis) - - cos_emb = ops.cast( - ops.cos(embedding) * self.embedding_scaling_factor, - self.compute_dtype, - ) - sin_emb = ops.cast( - ops.sin(embedding) * self.embedding_scaling_factor, - self.compute_dtype, - ) - return cos_emb, sin_emb - - def get_config(self): - config = super().get_config() - config.update( - { - "max_sequence_length": self.max_sequence_length, - "pretraining_sequence_length": self.pretraining_sequence_length, - "inverese_freq_short_factor": self.inverese_freq_short_factor, - "inverese_freq_long_factor": self.inverese_freq_long_factor, - } - ) - return config From ce07951cdae63793025d1c178c9f16af32584dec Mon Sep 17 00:00:00 2001 From: Rahul Yedida Date: Sun, 20 Jul 2025 19:19:39 -0400 Subject: [PATCH 09/11] fix(phi4): fix unit tests --- .../src/models/phi4/phi4_causal_lm_test.py | 39 ++++++++++++------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/keras_hub/src/models/phi4/phi4_causal_lm_test.py b/keras_hub/src/models/phi4/phi4_causal_lm_test.py index f375aea782..31651a56e6 100644 --- a/keras_hub/src/models/phi4/phi4_causal_lm_test.py +++ b/keras_hub/src/models/phi4/phi4_causal_lm_test.py @@ -1,4 +1,3 @@ -import os from unittest.mock import patch import pytest @@ -15,14 +14,26 @@ class Phi4CausalLMTest(TestCase): def setUp(self): + # Move to index 0 since the tokenizer sets pad_token_id to 0. 
+ self.vocab = ["", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += [ + "", + "", + "!", + "", + "", + "", + # Necessary since `Phi3CausalLM` requires this in `generate()` + "<|end|>", + ] + self.vocab += ["", "", ""] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] self.preprocessor = Phi4CausalLMPreprocessor( - Phi4Tokenizer( - # Generated using create_phi4_test_proto.py - proto=os.path.join( - self.get_test_data_dir(), "phi4_test_vocab.spm" - ) - ), - sequence_length=12, + Phi4Tokenizer(vocabulary=self.vocab, merges=self.merges), + sequence_length=15, ) self.vocab_size = self.preprocessor.tokenizer.vocabulary_size() self.backbone = Phi4Backbone( @@ -37,7 +48,7 @@ def setUp(self): "preprocessor": self.preprocessor, "backbone": self.backbone, } - self.train_data = (["the quick brown fox", "the earth is round"],) + self.train_data = ([" airplane at airport", " airplane at airport"],) self.input_data = self.preprocessor(*self.train_data)[0] def test_causal_lm_basics(self): @@ -45,14 +56,14 @@ def test_causal_lm_basics(self): cls=Phi4CausalLM, init_kwargs=self.init_kwargs, train_data=self.train_data, - expected_output_shape=(2, 12, self.vocab_size), + expected_output_shape=(2, 15, self.vocab_size), ) def test_generate(self): causal_lm = Phi4CausalLM(**self.init_kwargs) # String input. - prompt = "the fox" - output = causal_lm.generate(prompt) + prompt = " airplane at airport" + output = causal_lm.generate(" airplane at airport") self.assertTrue(prompt in output) # Int tensor input. prompt_ids = self.preprocessor.generate_preprocess([prompt]) @@ -82,8 +93,8 @@ def wrapper(*args, **kwargs): return logits, hidden_states, cache with patch.object(causal_lm, "call_with_cache", wraps=wrapper): - prompt = ["the fox", "the earth"] - output = causal_lm.generate(prompt) + prompt = [" airplane at airport", " airplane"] + output = causal_lm.generate(prompt, max_length=7) # We should immediately abort and output the prompt. 
self.assertEqual(prompt, output) From 0d130498228c2003e8fd1c787f91e3af7d2fe53a Mon Sep 17 00:00:00 2001 From: Rahul Yedida Date: Tue, 22 Jul 2025 21:04:02 -0400 Subject: [PATCH 10/11] fix(phi4): fix unit tests --- .../phi4/phi4_causal_lm_preprocessor_test.py | 51 +++++++++++-------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py b/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py index ff39cbbaa1..96f207e817 100644 --- a/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py +++ b/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py @@ -1,5 +1,3 @@ -import os - import pytest from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import ( @@ -11,16 +9,29 @@ class Phi4CausalLMPreprocessorTest(TestCase): def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += [ + "", + "", + "", + "", + "", + "", + ] + self.vocab += ["", "", ""] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] self.tokenizer = Phi4Tokenizer( - # Generated using create_phi4_test_proto.py - proto=os.path.join(self.get_test_data_dir(), "phi4_test_vocab.spm") + vocabulary=self.vocab, merges=self.merges ) self.init_kwargs = { "tokenizer": self.tokenizer, "sequence_length": 10, } - # [3, 5, 6, 4, 3, 9, 7, 11] - self.input_data = (["the fox"],) + # [1, 3, 4, 2, 5] + self.input_data = (["airplane at airport"],) def test_preprocessor_basics(self): self.run_preprocessor_test( @@ -29,16 +40,16 @@ def test_preprocessor_basics(self): input_data=self.input_data, expected_output=( { - "token_ids": [[1, 3, 5, 6, 4, 3, 9, 7, 11, 15]], + "token_ids": [[6, 1, 3, 4, 2, 5, 0, 0, 0, 0]], "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], }, - [[3, 5, 6, 4, 3, 9, 7, 11, 15, 0]], - [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0]], + [[1, 3, 4, 2, 5, 0, 0, 0, 0, 7]], + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], ), ) def test_no_start_end_token(self): - input_data = ["the fox"] * 4 + input_data = ["airplane at airport"] * 4 preprocessor = Phi4CausalLMPreprocessor( **self.init_kwargs, @@ -47,29 +58,29 @@ def test_no_start_end_token(self): ) x, y, sw = preprocessor(input_data) self.assertAllEqual( - x["token_ids"], [[3, 5, 6, 4, 3, 9, 7, 11, 0, 0]] * 4 + x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0, 0, 0]] * 4 ) self.assertAllEqual( - x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 4 + x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] * 4 ) - self.assertAllEqual(y, [[5, 6, 4, 3, 9, 7, 11, 0, 0, 0]] * 4) - self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] * 4) def test_generate_preprocess(self): - input_data = "the fox" + input_data = "airplane at airport" preprocessor = Phi4CausalLMPreprocessor(**self.init_kwargs) x = preprocessor.generate_preprocess(input_data) - self.assertAllEqual(x["token_ids"], [1, 3, 5, 6, 4, 3, 9, 7, 11, 0]) - self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]) + self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) def test_generate_postprocess(self): input_data = { - "token_ids": [1, 3, 5, 6, 4, 3, 9, 7, 11, 0], - "padding_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + 
"token_ids": [1, 3, 4, 2, 5, 3, 9, 7, 11, 0], + "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], } preprocessor = Phi4CausalLMPreprocessor(**self.init_kwargs) x = preprocessor.generate_postprocess(input_data) - self.assertAllEqual(x, "the fox") + self.assertAllEqual(x, "airplane at airport") @pytest.mark.extra_large def test_all_presets(self): From e3fafbfd65fbb12cae57d3d56c430b70e1142968 Mon Sep 17 00:00:00 2001 From: Rahul Yedida Date: Tue, 23 Sep 2025 23:33:44 -0400 Subject: [PATCH 11/11] chore(phi4): change test preset model, uncomment test marker --- keras_hub/src/models/phi4/phi4_backbone.py | 2 -- keras_hub/src/models/phi4/phi4_causal_lm_test.py | 2 +- keras_hub/src/models/phi4/phi4_tokenizer_test.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/models/phi4/phi4_backbone.py b/keras_hub/src/models/phi4/phi4_backbone.py index 0222f50e3e..23328c3f78 100644 --- a/keras_hub/src/models/phi4/phi4_backbone.py +++ b/keras_hub/src/models/phi4/phi4_backbone.py @@ -59,5 +59,3 @@ class Phi4Backbone(Phi3Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. """ - - pass diff --git a/keras_hub/src/models/phi4/phi4_causal_lm_test.py b/keras_hub/src/models/phi4/phi4_causal_lm_test.py index 31651a56e6..28a7c2e797 100644 --- a/keras_hub/src/models/phi4/phi4_causal_lm_test.py +++ b/keras_hub/src/models/phi4/phi4_causal_lm_test.py @@ -110,7 +110,7 @@ def test_generate_compilation(self): causal_lm.compile(sampler="greedy") self.assertIsNone(causal_lm.generate_function) - # @pytest.mark.large + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( cls=Phi4CausalLM, diff --git a/keras_hub/src/models/phi4/phi4_tokenizer_test.py b/keras_hub/src/models/phi4/phi4_tokenizer_test.py index 2e91922d05..e5bafa416c 100644 --- a/keras_hub/src/models/phi4/phi4_tokenizer_test.py +++ b/keras_hub/src/models/phi4/phi4_tokenizer_test.py @@ -46,7 +46,7 @@ def test_errors_missing_special_tokens(self): def test_smallest_preset(self): self.run_preset_test( cls=Phi4Tokenizer, - preset="phi4_8b_en", + preset="phi4_mini_4k_instruct_en", input_data=["The quick brown fox."], expected_output=[[791, 4062, 14198, 39935, 13]], )