- 
                Notifications
    
You must be signed in to change notification settings  - Fork 307
 
Add Phi-4 Backbone #2272
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Open
      
      
            yrahul3910
  wants to merge
  12
  commits into
  keras-team:master
  
    
      
        
          
  
    
      Choose a base branch
      
     
    
      
        
      
      
        
          
          
        
        
          
            
              
              
              
  
           
        
        
          
            
              
              
           
        
       
     
  
        
          
            
          
            
          
        
       
    
      
from
yrahul3910:master
  
      
      
   
  
    
  
  
  
 
  
      
    base: master
Could not load branches
            
              
  
    Branch not found: {{ refName }}
  
            
                
      Loading
              
            Could not load tags
            
            
              Nothing to show
            
              
  
            
                
      Loading
              
            Are you sure you want to change the base?
            Some commits from the old base branch may be removed from the timeline,
            and old review comments may become outdated.
          
          
  
     Open
                    Add Phi-4 Backbone #2272
Changes from 11 commits
      Commits
    
    
            Show all changes
          
          
            12 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      4a8566b
              
                feat(phi4): add phi4_backbone
              
              
                yrahul3910 69f66ff
              
                docs(phi4): update defaults in docstring
              
              
                yrahul3910 1bfe756
              
                Merge branch 'keras-team:master' into master
              
              
                yrahul3910 8b7146e
              
                feat(phi4): refactor Phi4Backbone to inherit from Phi-3
              
              
                yrahul3910 3df73af
              
                feat(phi4): add phi-4 tokenizer
              
              
                yrahul3910 4aceea3
              
                feat(phi4): add phi-4 causal_lm files
              
              
                yrahul3910 17a30ce
              
                fix(phi4): update docstring to use correct variable names
              
              
                yrahul3910 82b2912
              
                fix(phi4): remove dedicated attention and decoder modules
              
              
                yrahul3910 cbdf6ce
              
                fix(phi4): remove unused layernorm and rotary embedding layers
              
              
                yrahul3910 ce07951
              
                fix(phi4): fix unit tests
              
              
                yrahul3910 0d13049
              
                fix(phi4): fix unit tests
              
              
                yrahul3910 e3fafbf
              
                chore(phi4): change test preset model, uncomment test marker
              
              
                yrahul3910 File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1 @@ | ||
| # TODO: Add a register_presets call once phi4_presets.py is implemented. | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| from keras_hub.src.api_export import keras_hub_export | ||
| from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone | ||
| 
     | 
||
| 
     | 
||
| @keras_hub_export("keras_hub.models.Phi4Backbone") | ||
| class Phi4Backbone(Phi3Backbone): | ||
| """Phi-4 core network with hyperparameters. | ||
| 
     | 
||
| This network implements a Transformer-based decoder network, | ||
| Phi-4, as described in ["Phi-4 Technical Report"](https://arxiv.org/pdf/2412.08905). | ||
| It includes the embedding lookups and transformer layers. | ||
| 
     | 
||
| The default constructor gives a fully customizable, randomly initialized | ||
| phi-4 model with any number of layers, heads, and embedding | ||
| dimensions. To load preset architectures and weights, use the `from_preset` | ||
| constructor. | ||
| 
     | 
||
| Note that the defaults here are the Phi-3 defaults, because the Phi-4 model | ||
| follows the Phi-3-medium architecture but with different hyper-parameters. | ||
| Use `keras_hub.models.Backbone.from_preset` to get the Phi-4 defaults. | ||
| 
     | 
||
| Args: | ||
| vocabulary_size: int. The size of the token vocabulary. | ||
| num_layers: int. The number of transformer layers. | ||
| hidden_dim: int. The size of the embeddings and the hidden states of | ||
| the transformer layers. | ||
| intermediate_dim: int. The output dimension of the first Dense layer in | ||
| a three-layer feedforward network for each transformer. | ||
| num_query_heads: int. The number of query attention heads for each | ||
| transformer layer. | ||
| num_key_value_heads: int. The number of key and value attention heads | ||
| for each transformer layer. | ||
| layer_norm_epsilon: float, optional. Epsilon for the RMS layernorm | ||
| layers in the transformer decoder. Defaults to `1e-6`. | ||
| dropout:: float, optional. Dropout probability for the Transformer | ||
| decoder. | ||
| max_sequence_length: int, optional. The maximum sequence length | ||
| that this model might ever be used with. Defaults to `4096`. | ||
| pretraining_sequence_length: int, optional. The maximum sequence length | ||
| that the model was pretrained with. Defaults to `4096`. | ||
| rope_max_wavelength: int, optional. The maximum angular wavelength of | ||
| the sine/cosine curves, for rotary embeddings. Defaults to `10000`. | ||
| rope_scaling_type: str, optional. The type of the rope scaling. Can be | ||
| either `None` or `"su"`. `None` is for no rope scaling, `"su"` is | ||
| for SuScaled rope, `"su"` is used when `max_sequence_length` is | ||
| larger than `original_max_sequence_length`. Defaults to `None`. | ||
| rope_scaling_short_factor: list[float]. List of factors used to adjust | ||
| rope frequencies when the `rope_scaling_type` is `"su"`. List must | ||
| be of length `hidden_dim//num_query_heads//2`. It is used when | ||
| `sequence_length` is smaller than `pretraining_sequence_length`. | ||
| Defaults to `None`. | ||
| rope_scaling_long_factor: list[float]. List of factors used to adjust | ||
| rope frequencies when the `rope_scaling_type` is `"su"`. List must | ||
| be of length `hidden_dim//num_query_heads//2`. It is used when | ||
| `sequence_length` is larger than `pretraining_sequence_length`. | ||
| Defaults to `None`. | ||
| dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use | ||
| for model computations and weights. Note that some computations, | ||
| such as softmax and layer normalization, will always be done at | ||
| float32 precision regardless of dtype. | ||
| """ | ||
| 
     | 
||
| pass | ||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| import pytest | ||
| from keras import ops | ||
| 
     | 
||
| from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone | ||
| from keras_hub.src.tests.test_case import TestCase | ||
| 
     | 
||
| 
     | 
||
| class Phi4Test(TestCase): | ||
| def setUp(self): | ||
| self.init_kwargs = { | ||
| "vocabulary_size": 10, | ||
| "num_layers": 2, | ||
| "num_query_heads": 4, | ||
| "num_key_value_heads": 2, | ||
| "hidden_dim": 8, | ||
| "intermediate_dim": 8, | ||
| } | ||
| self.su_rotary_init_kwargs = { | ||
| "vocabulary_size": 10, | ||
| "num_layers": 2, | ||
| "num_query_heads": 2, | ||
| "num_key_value_heads": 1, | ||
| "hidden_dim": 8, | ||
| "intermediate_dim": 12, | ||
| "max_sequence_length": 10, | ||
| "pretraining_sequence_length": 5, | ||
| "rope_scaling_type": "su", | ||
| "rope_scaling_short_factor": [1.2, 1.4], | ||
| "rope_scaling_long_factor": [0.8, 0.6], | ||
| } | ||
| self.input_data = { | ||
| "token_ids": ops.ones((2, 5), dtype="int32"), | ||
| "padding_mask": ops.ones((2, 5), dtype="int32"), | ||
| } | ||
| 
     | 
||
| def test_backbone_basics(self): | ||
| self.run_backbone_test( | ||
| cls=Phi4Backbone, | ||
| init_kwargs=self.init_kwargs, | ||
| input_data=self.input_data, | ||
| expected_output_shape=(2, 5, 8), | ||
| ) | ||
| 
     | 
||
| @pytest.mark.large | ||
| def test_saved_model(self): | ||
| self.run_model_saving_test( | ||
| cls=Phi4Backbone, | ||
| init_kwargs=self.init_kwargs, | ||
| input_data=self.input_data, | ||
| ) | ||
| 
     | 
||
| def test_backbone_basics_with_su_rotary(self): | ||
| self.run_backbone_test( | ||
| cls=Phi4Backbone, | ||
| init_kwargs=self.su_rotary_init_kwargs, | ||
| input_data=self.input_data, | ||
| expected_output_shape=(2, 5, 8), | ||
| ) | ||
| 
     | 
||
| @pytest.mark.large | ||
| def test_saved_model_with_su_rotary(self): | ||
| self.run_model_saving_test( | ||
| cls=Phi4Backbone, | ||
| init_kwargs=self.su_rotary_init_kwargs, | ||
| input_data=self.input_data, | ||
| ) | ||
| 
     | 
||
| @pytest.mark.extra_large | ||
| def test_smallest_preset(self): | ||
| self.run_preset_test( | ||
| cls=Phi4Backbone, | ||
| preset="phi4_mini_4k_instruct_en", | ||
| input_data={ | ||
| "token_ids": ops.array([[1, 450, 4996, 1701, 29916, 29889]]), | ||
| "padding_mask": ops.ones((1, 6), dtype="int32"), | ||
| }, | ||
| expected_output_shape=(1, 6, 3072), | ||
| # The forward pass from a preset should be stable! | ||
| # Reference values computed using PyTorch HF model. | ||
| expected_partial_output=ops.array( | ||
| [-0.21222, 0.04004, -0.02759, 0.02200] | ||
| ), | ||
| ) | ||
| 
     | 
||
| @pytest.mark.extra_large | ||
| def test_all_presets(self): | ||
| for preset in Phi4Backbone.presets: | ||
| self.run_preset_test( | ||
| cls=Phi4Backbone, | ||
| preset=preset, | ||
| input_data=self.input_data, | ||
| ) | ||
| 
         
      Comment on lines
    
      +85
     to 
      +92
    
   
  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Usually how big these models will be and how many presets are we testing here?  | 
||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| from keras_hub.src.api_export import keras_hub_export | ||
| from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM | ||
| from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone | ||
| from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import ( | ||
| Phi4CausalLMPreprocessor, | ||
| ) | ||
| 
     | 
||
| 
     | 
||
| @keras_hub_export("keras_hub.models.Phi4CausalLM") | ||
| class Phi4CausalLM(Phi3CausalLM): | ||
| """An end-to-end Phi4 model for causal language modeling. | ||
| 
     | 
||
| A causal language model (LM) predicts the next token based on previous | ||
| tokens. This task setup can be used to train the model unsupervised on | ||
| plain text input, or to autoregressively generate plain text similar to | ||
| the data used for training. This task can be used for pre-training or | ||
| fine-tuning a Phi-4 model, simply by calling `fit()`. | ||
| 
     | 
||
| This model has a `generate()` method, which generates text based on a | ||
| prompt. The generation strategy used is controlled by an additional | ||
| `sampler` argument on `compile()`. You can recompile the model with | ||
| different `keras_hub.samplers` objects to control the generation. By | ||
| default, `"top_k"` sampling will be used. | ||
| 
     | 
||
| Args: | ||
| backbone: A `keras_hub.models.Phi4Backbone` instance. | ||
| preprocessor: A `keras_hub.models.Phi4CausalLMPreprocessor` or `None`. | ||
| If `None`, this model will not apply preprocessing, and inputs | ||
| should be preprocessed before calling the model. | ||
| """ | ||
| 
     | 
||
| backbone_cls = Phi4Backbone | ||
| preprocessor_cls = Phi4CausalLMPreprocessor | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| from keras_hub.src.api_export import keras_hub_export | ||
| from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor | ||
| from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone | ||
| from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer | ||
| 
     | 
||
| 
     | 
||
| @keras_hub_export("keras_hub.models.Phi4CausalLMPreprocessor") | ||
| class Phi4CausalLMPreprocessor(CausalLMPreprocessor): | ||
| """Phi4 Causal LM preprocessor. | ||
| 
     | 
||
| This preprocessing layer is meant for use with | ||
| `keras_hub.models.Phi4CausalLM`. By default, it will take in batches of | ||
| strings, and return outputs in a `(x, y, sample_weight)` format, where the | ||
| `y` label is the next token id in the `x` sequence. | ||
| 
     | 
||
| For use with generation, the layer also exposes two methods | ||
| `generate_preprocess()` and `generate_postprocess()`. When this preprocessor | ||
| is attached to a `keras_hub.models.Phi4CausalLM` instance, these methods | ||
| will be called implicitly in `generate()`. They can also be called | ||
| standalone (e.g. to precompute preprocessing inputs for generation in a | ||
| separate process). | ||
| 
     | 
||
| Args: | ||
| tokenizer: A `keras_hub.models.Phi4Tokenizer` instance. | ||
| sequence_length: The length of the packed inputs. | ||
| add_start_token: If `True`, the preprocessor will prepend the tokenizer | ||
| start token to each input sequence. Default is `True`. | ||
| add_end_token: If `True`, the preprocessor will append the tokenizer | ||
| end token to each input sequence. Default is `False`. | ||
| 
     | 
||
| Call arguments: | ||
| x: A string, `tf.Tensor` or list of python strings. | ||
| y: Label data. Should always be `None` as the layer generates labels. | ||
| sample_weight: Label weights. Should always be `None` as the layer | ||
| generates label weights. | ||
| sequence_length: Pass to override the configured `sequence_length` of | ||
| the layer. | ||
| 
     | 
||
| Examples: | ||
| ```python | ||
| # Load the preprocessor from a preset. | ||
| preprocessor = keras_hub.models.Phi4CausalLMPreprocessor.from_preset( | ||
| "phi4_mini_4k_instruct_en" | ||
| ) | ||
| 
     | 
||
| # Tokenize and pack a single sentence. | ||
| sentence = tf.constant("League of legends") | ||
| preprocessor(sentence) | ||
| # Same output. | ||
| preprocessor("League of legends") | ||
| 
     | 
||
| # Tokenize a batch of sentences. | ||
| sentences = tf.constant(["Taco tuesday", "Fish taco please!"]) | ||
| preprocessor(sentences) | ||
| # Same output. | ||
| preprocessor(["Taco tuesday", "Fish taco please!"]) | ||
| 
     | 
||
| # Map a dataset to preprocess a single sentence. | ||
| features = tf.constant( | ||
| [ | ||
| "Avatar 2 is amazing!", | ||
| "Well, I am not sure.", | ||
| ] | ||
| ) | ||
| labels = tf.constant([1, 0]) | ||
| ds = tf.data.Dataset.from_tensor_slices((features, labels)) | ||
| ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) | ||
| 
     | 
||
| # Map a dataset to preprocess unlabled sentences. | ||
| ds = tf.data.Dataset.from_tensor_slices(features) | ||
| ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) | ||
| ``` | ||
| """ | ||
| 
     | 
||
| backbone_cls = Phi4Backbone | ||
| tokenizer_cls = Phi4Tokenizer | 
        
          
          
            92 changes: 92 additions & 0 deletions
          
          92 
        
  keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py
  
  
      
      
   
        
      
      
    
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| import pytest | ||
| 
     | 
||
| from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import ( | ||
| Phi4CausalLMPreprocessor, | ||
| ) | ||
| from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer | ||
| from keras_hub.src.tests.test_case import TestCase | ||
| 
     | 
||
| 
     | 
||
| class Phi4CausalLMPreprocessorTest(TestCase): | ||
| def setUp(self): | ||
| self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] | ||
| self.vocab += [ | ||
| "<s>", | ||
| "</s>", | ||
| "<pad>", | ||
| "<im_start>", | ||
| "<im_sep>", | ||
| "<im_end>", | ||
| ] | ||
| self.vocab += ["<fim_prefix>", "<fim_middle>", "<fim_suffix>"] | ||
| self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) | ||
| self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] | ||
| self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] | ||
| self.merges += ["Ġai r", "Ġa i", "pla ne"] | ||
| self.tokenizer = Phi4Tokenizer( | ||
| vocabulary=self.vocab, merges=self.merges | ||
| ) | ||
| self.init_kwargs = { | ||
| "tokenizer": self.tokenizer, | ||
| "sequence_length": 10, | ||
| } | ||
| # [1, 3, 4, 2, 5] | ||
| self.input_data = (["airplane at airport"],) | ||
| 
     | 
||
| def test_preprocessor_basics(self): | ||
| self.run_preprocessor_test( | ||
| cls=Phi4CausalLMPreprocessor, | ||
| init_kwargs=self.init_kwargs, | ||
| input_data=self.input_data, | ||
| expected_output=( | ||
| { | ||
| "token_ids": [[6, 1, 3, 4, 2, 5, 0, 0, 0, 0]], | ||
| "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], | ||
| }, | ||
| [[1, 3, 4, 2, 5, 0, 0, 0, 0, 7]], | ||
| [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], | ||
| ), | ||
| ) | ||
| 
     | 
||
| def test_no_start_end_token(self): | ||
| input_data = ["airplane at airport"] * 4 | ||
| 
     | 
||
| preprocessor = Phi4CausalLMPreprocessor( | ||
| **self.init_kwargs, | ||
| add_start_token=False, | ||
| add_end_token=False, | ||
| ) | ||
| x, y, sw = preprocessor(input_data) | ||
| self.assertAllEqual( | ||
| x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0, 0, 0]] * 4 | ||
| ) | ||
| self.assertAllEqual( | ||
| x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] * 4 | ||
| ) | ||
| self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0, 0, 0]] * 4) | ||
| self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] * 4) | ||
| 
     | 
||
| def test_generate_preprocess(self): | ||
| input_data = "airplane at airport" | ||
| preprocessor = Phi4CausalLMPreprocessor(**self.init_kwargs) | ||
| x = preprocessor.generate_preprocess(input_data) | ||
| self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0, 0, 0]) | ||
| self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) | ||
| 
     | 
||
| def test_generate_postprocess(self): | ||
| input_data = { | ||
| "token_ids": [1, 3, 4, 2, 5, 3, 9, 7, 11, 0], | ||
| "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], | ||
| } | ||
| preprocessor = Phi4CausalLMPreprocessor(**self.init_kwargs) | ||
| x = preprocessor.generate_postprocess(input_data) | ||
| self.assertAllEqual(x, "airplane at airport") | ||
| 
     | 
||
| @pytest.mark.extra_large | ||
| def test_all_presets(self): | ||
| for preset in Phi4CausalLMPreprocessor.presets: | ||
| self.run_preset_test( | ||
| cls=Phi4CausalLMPreprocessor, | ||
| preset=preset, | ||
| input_data=self.input_data, | ||
| ) | 
      
      Oops, something went wrong.
        
    
  
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not required