diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ef928a1655..81848e3b6f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,14 +3,7 @@ repos: hooks: - id: api-gen name: api_gen - entry: | - bash shell/api_gen.sh - git status - clean=$(git status | grep "nothing to commit") - if [ -z "$clean" ]; then - echo "Please run shell/api_gen.sh to generate API." - exit 1 - fi + entry: bash -c "shell/api_gen.sh && if [ -n \"$(git status --porcelain)\" ]; then echo 'Please run shell/api_gen.sh to generate API.' && exit 1; fi" language: system stages: [pre-commit, manual] require_serial: true diff --git a/keras_hub/src/models/layoutlmv3/__init__.py b/keras_hub/src/models/layoutlmv3/__init__.py new file mode 100644 index 0000000000..4ba7dfbfb7 --- /dev/null +++ b/keras_hub/src/models/layoutlmv3/__init__.py @@ -0,0 +1,35 @@ +# Import LayoutLMv3 components with error handling for backend compatibility +try: + from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import ( + LayoutLMv3Backbone, + ) +except ImportError as e: + # Graceful degradation for missing dependencies + LayoutLMv3Backbone = None + import warnings + + warnings.warn(f"LayoutLMv3Backbone import failed: {e}") + +try: + from keras_hub.src.models.layoutlmv3.layoutlmv3_tokenizer import ( + LayoutLMv3Tokenizer, + ) +except ImportError as e: + # Graceful degradation for missing dependencies + LayoutLMv3Tokenizer = None + import warnings + + warnings.warn(f"LayoutLMv3Tokenizer import failed: {e}") + +from keras_hub.src.utils.preset_utils import register_presets + +# Only register presets if classes loaded successfully +if LayoutLMv3Backbone is not None: + try: + # Register presets if they exist + backbone_presets = {} # Empty for now - will be populated when presets are added + register_presets(backbone_presets, LayoutLMv3Backbone) + except Exception as e: + import warnings + + warnings.warn(f"Failed to register LayoutLMv3 presets: {e}") diff --git a/keras_hub/src/models/layoutlmv3/layoutlmv3_backbone.py b/keras_hub/src/models/layoutlmv3/layoutlmv3_backbone.py new file mode 100644 index 0000000000..8e8aab4619 --- /dev/null +++ b/keras_hub/src/models/layoutlmv3/layoutlmv3_backbone.py @@ -0,0 +1,377 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.layoutlmv3.layoutlmv3_transformer import ( + LayoutLMv3TransformerLayer, +) + + +@keras_hub_export("keras_hub.models.LayoutLMv3Backbone") +class LayoutLMv3Backbone(Backbone): + """LayoutLMv3 backbone model for document understanding tasks. + + This class implements the LayoutLMv3 model architecture for joint text and + layout understanding in document AI tasks. It processes both text and image + inputs while maintaining spatial relationships in documents. + + The default constructor gives a fully customizable, randomly initialized + LayoutLMv3 model with any number of layers, heads, and embedding dimensions. + To load preset architectures and weights, use the `from_preset` constructor. + + Args: + vocabulary_size: int. The size of the token vocabulary. Defaults to + 30522. + hidden_dim: int. The size of the transformer hidden state at the end of + each transformer layer. Defaults to 768. + num_layers: int. The number of transformer layers. Defaults to 12. + num_heads: int. The number of attention heads for each transformer. + Defaults to 12. 
+ intermediate_dim: int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. Defaults to + 3072. + dropout: float. Dropout probability for the transformer encoder. + Defaults to 0.1. + max_sequence_length: int. The maximum sequence length that this encoder + can consume. Defaults to 512. + type_vocab_size: int. The vocabulary size for token types. Defaults to + 2. + initializer_range: float. The standard deviation of the truncated_normal + initializer for initializing all weight matrices. Defaults to 0.02. + layer_norm_epsilon: float. The epsilon used by the layer normalization + layers. Defaults to 1e-12. + spatial_embedding_dim: int. The dimension of spatial position + embeddings for bounding box coordinates. Defaults to 64. + patch_size: int. The size of the patches for image processing. Defaults + to 16. + num_channels: int. The number of channels in the input images. Defaults + to 3. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. + + Examples: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + "bbox": np.ones(shape=(1, 12, 4), dtype="int32"), + } + + # Pretrained LayoutLMv3 encoder. + model = keras_hub.models.LayoutLMv3Backbone.from_preset( + "layoutlmv3_base", + ) + model(input_data) + + # Randomly initialized LayoutLMv3 encoder with custom config. + model = keras_hub.models.LayoutLMv3Backbone( + vocabulary_size=30522, + hidden_dim=768, + num_layers=12, + num_heads=12, + intermediate_dim=3072, + max_sequence_length=512, + spatial_embedding_dim=64, + ) + model(input_data) + ``` + + References: + - [LayoutLMv3 Paper](https://arxiv.org/abs/2204.08387) + - [LayoutLMv3 GitHub](https://github.com/microsoft/unilm/tree/master/layoutlmv3) + """ + + def __init__( + self, + vocabulary_size=30522, + hidden_dim=768, + num_layers=12, + num_heads=12, + intermediate_dim=3072, + dropout=0.1, + max_sequence_length=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_epsilon=1e-12, + spatial_embedding_dim=64, + patch_size=16, + num_channels=3, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + embeddings_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="token_embedding", + ) + + self.position_embedding = keras.layers.Embedding( + input_dim=max_sequence_length, + output_dim=hidden_dim, + embeddings_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="position_embedding", + ) + + # Spatial position embeddings for bounding box coordinates + self.x_position_embedding = keras.layers.Embedding( + input_dim=1024, + output_dim=spatial_embedding_dim, + embeddings_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="x_position_embedding", + ) + + self.y_position_embedding = keras.layers.Embedding( + input_dim=1024, + output_dim=spatial_embedding_dim, + embeddings_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="y_position_embedding", + ) + + self.h_position_embedding = keras.layers.Embedding( + input_dim=1024, + output_dim=spatial_embedding_dim, + embeddings_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="h_position_embedding", + 
) + + self.w_position_embedding = keras.layers.Embedding( + input_dim=1024, + output_dim=spatial_embedding_dim, + embeddings_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="w_position_embedding", + ) + + # Spatial projection layers + self.x_projection = keras.layers.Dense( + hidden_dim, + kernel_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="x_projection", + ) + + self.y_projection = keras.layers.Dense( + hidden_dim, + kernel_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="y_projection", + ) + + self.h_projection = keras.layers.Dense( + hidden_dim, + kernel_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="h_projection", + ) + + self.w_projection = keras.layers.Dense( + hidden_dim, + kernel_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="w_projection", + ) + + self.token_type_embedding = keras.layers.Embedding( + input_dim=type_vocab_size, + output_dim=hidden_dim, + embeddings_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="token_type_embedding", + ) + + self.embeddings_layer_norm = keras.layers.LayerNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="embeddings_layer_norm", + ) + + self.embeddings_dropout = keras.layers.Dropout( + dropout, + dtype=dtype, + name="embeddings_dropout", + ) + + # Transformer layers + self.transformer_layers = [] + for i in range(num_layers): + layer = LayoutLMv3TransformerLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + intermediate_dim=intermediate_dim, + dropout=dropout, + activation="gelu", + layer_norm_epsilon=layer_norm_epsilon, + kernel_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + + # Image processing layers + self.patch_embedding = keras.layers.Conv2D( + filters=hidden_dim, + kernel_size=(patch_size, patch_size), + strides=(patch_size, patch_size), + padding="valid", + kernel_initializer=keras.initializers.TruncatedNormal( + stddev=initializer_range + ), + dtype=dtype, + name="patch_embedding", + ) + + self.patch_layer_norm = keras.layers.LayerNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="patch_layer_norm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + bbox_input = keras.Input( + shape=(None, 4), dtype="int32", name="bbox" + ) + + # Compute sequence length for position embeddings + seq_length = ops.shape(token_id_input)[1] + position_ids = ops.arange(seq_length, dtype="int32") + position_ids = ops.expand_dims(position_ids, axis=0) + position_ids = ops.broadcast_to( + position_ids, ops.shape(token_id_input) + ) + + # Token embeddings + token_embeddings = self.token_embedding(token_id_input) + + # Position embeddings + position_embeddings = self.position_embedding(position_ids) + + # Spatial embeddings + x_embeddings = self.x_position_embedding(bbox_input[..., 0]) + y_embeddings = self.y_position_embedding(bbox_input[..., 1]) + h_embeddings = self.h_position_embedding(bbox_input[..., 2]) + w_embeddings = self.w_position_embedding(bbox_input[..., 3]) + + # Project spatial embeddings + 
x_embeddings = self.x_projection(x_embeddings) + y_embeddings = self.y_projection(y_embeddings) + h_embeddings = self.h_projection(h_embeddings) + w_embeddings = self.w_projection(w_embeddings) + + # Token type embeddings (default to 0) + token_type_ids = ops.zeros_like(token_id_input) + token_type_embeddings = self.token_type_embedding(token_type_ids) + + # Combine all embeddings + embeddings = ( + token_embeddings + + position_embeddings + + x_embeddings + + y_embeddings + + h_embeddings + + w_embeddings + + token_type_embeddings + ) + + # Apply layer normalization and dropout + embeddings = self.embeddings_layer_norm(embeddings) + embeddings = self.embeddings_dropout(embeddings) + + # Apply transformer layers + hidden_states = embeddings + for transformer_layer in self.transformer_layers: + hidden_states = transformer_layer( + hidden_states, padding_mask=padding_mask_input + ) + + # Build the model + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + "bbox": bbox_input, + }, + outputs=hidden_states, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.intermediate_dim = intermediate_dim + self.dropout = dropout + self.max_sequence_length = max_sequence_length + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_epsilon = layer_norm_epsilon + self.spatial_embedding_dim = spatial_embedding_dim + self.patch_size = patch_size + self.num_channels = num_channels + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "hidden_dim": self.hidden_dim, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "intermediate_dim": self.intermediate_dim, + "dropout": self.dropout, + "max_sequence_length": self.max_sequence_length, + "type_vocab_size": self.type_vocab_size, + "initializer_range": self.initializer_range, + "layer_norm_epsilon": self.layer_norm_epsilon, + "spatial_embedding_dim": self.spatial_embedding_dim, + "patch_size": self.patch_size, + "num_channels": self.num_channels, + } + ) + return config + + @property + def token_embedding_matrix(self): + return self.token_embedding.embeddings diff --git a/keras_hub/src/models/layoutlmv3/layoutlmv3_backbone_test.py b/keras_hub/src/models/layoutlmv3/layoutlmv3_backbone_test.py new file mode 100644 index 0000000000..76b2eac159 --- /dev/null +++ b/keras_hub/src/models/layoutlmv3/layoutlmv3_backbone_test.py @@ -0,0 +1,180 @@ +import keras +import numpy as np + +from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import ( + LayoutLMv3Backbone, +) +from keras_hub.src.tests.test_case import TestCase + + +class LayoutLMv3BackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 1000, + "hidden_dim": 64, + "num_layers": 2, + "num_heads": 2, + "intermediate_dim": 128, + "max_sequence_length": 128, + "spatial_embedding_dim": 32, + } + self.input_data = { + "token_ids": keras.random.uniform( + shape=(2, 10), minval=0, maxval=1000, dtype="int32" + ), + "padding_mask": keras.ops.ones((2, 10), dtype="int32"), + "bbox": keras.random.uniform( + shape=(2, 10, 4), minval=0, maxval=1000, dtype="int32" + ), + } + + def test_backbone_basics(self): + model = LayoutLMv3Backbone(**self.init_kwargs) + self.assertEqual(model.vocabulary_size, 1000) + self.assertEqual(model.hidden_dim, 64) + self.assertEqual(model.num_layers, 2) 
+ self.assertEqual(model.num_heads, 2) + self.assertEqual(model.intermediate_dim, 128) + self.assertEqual(model.max_sequence_length, 128) + self.assertEqual(model.spatial_embedding_dim, 32) + + def test_backbone_output_shape(self): + model = LayoutLMv3Backbone(**self.init_kwargs) + output = model(self.input_data) + # Output should be (batch_size, sequence_length, hidden_dim) + expected_shape = [2, 10, 64] + self.assertEqual(list(output.shape), expected_shape) + + def test_backbone_predict(self): + model = LayoutLMv3Backbone(**self.init_kwargs) + output = model.predict(self.input_data) + # Output should be (batch_size, sequence_length, hidden_dim) + expected_shape = [2, 10, 64] + self.assertEqual(list(output.shape), expected_shape) + + def test_saved_model(self): + model = LayoutLMv3Backbone(**self.init_kwargs) + model_output = model(self.input_data) + path = self.get_temp_dir() + model.save(path) + restored_model = keras.models.load_model(path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, LayoutLMv3Backbone) + + # Check that output matches. + restored_output = restored_model(self.input_data) + self.assertAllClose(model_output, restored_output) + + def test_get_config_and_from_config(self): + model = LayoutLMv3Backbone(**self.init_kwargs) + config = model.get_config() + restored_model = LayoutLMv3Backbone.from_config(config) + + # Check config was preserved + self.assertEqual(restored_model.vocabulary_size, 1000) + self.assertEqual(restored_model.hidden_dim, 64) + self.assertEqual(restored_model.num_layers, 2) + + def test_compute_output_shape(self): + model = LayoutLMv3Backbone(**self.init_kwargs) + batch_size = 3 + sequence_length = 5 + + input_shapes = { + "token_ids": (batch_size, sequence_length), + "padding_mask": (batch_size, sequence_length), + "bbox": (batch_size, sequence_length, 4), + } + + output_shape = model.compute_output_shape(input_shapes) + expected_shape = (batch_size, sequence_length, 64) + self.assertEqual(output_shape, expected_shape) + + def test_different_sequence_lengths(self): + model = LayoutLMv3Backbone(**self.init_kwargs) + + # Test with different sequence length + input_data = { + "token_ids": keras.random.uniform( + shape=(1, 5), minval=0, maxval=1000, dtype="int32" + ), + "padding_mask": keras.ops.ones((1, 5), dtype="int32"), + "bbox": keras.random.uniform( + shape=(1, 5, 4), minval=0, maxval=1000, dtype="int32" + ), + } + + output = model(input_data) + expected_shape = [1, 5, 64] + self.assertEqual(list(output.shape), expected_shape) + + def test_all_kwargs_in_config(self): + model = LayoutLMv3Backbone(**self.init_kwargs) + config = model.get_config() + + # Ensure all init arguments are in the config + for key, value in self.init_kwargs.items(): + self.assertEqual(config[key], value) + + def test_mixed_precision(self): + # Test with mixed precision + init_kwargs = {**self.init_kwargs, "dtype": "mixed_float16"} + model = LayoutLMv3Backbone(**init_kwargs) + output = model(self.input_data) + self.assertEqual(output.dtype, "float16") + + def test_token_embedding_matrix_property(self): + model = LayoutLMv3Backbone(**self.init_kwargs) + embeddings = model.token_embedding_matrix + expected_shape = [1000, 64] # vocabulary_size, hidden_dim + self.assertEqual(list(embeddings.shape), expected_shape) + + def test_spatial_embeddings_initialization(self): + model = LayoutLMv3Backbone(**self.init_kwargs) + + # Check that spatial embeddings have correct shapes + x_embeddings = model.x_position_embedding.embeddings + y_embeddings = 
model.y_position_embedding.embeddings + h_embeddings = model.h_position_embedding.embeddings + w_embeddings = model.w_position_embedding.embeddings + + expected_shape = [1024, 32] # max_bbox_value, spatial_embedding_dim + self.assertEqual(list(x_embeddings.shape), expected_shape) + self.assertEqual(list(y_embeddings.shape), expected_shape) + self.assertEqual(list(h_embeddings.shape), expected_shape) + self.assertEqual(list(w_embeddings.shape), expected_shape) + + def test_bbox_processing(self): + model = LayoutLMv3Backbone(**self.init_kwargs) + + # Test with bbox values at the boundary + bbox_data = keras.ops.array([[[0, 0, 100, 50], [100, 100, 200, 150]]], dtype="int32") + input_data = { + "token_ids": keras.ops.array([[1, 2]], dtype="int32"), + "padding_mask": keras.ops.ones((1, 2), dtype="int32"), + "bbox": bbox_data, + } + + output = model(input_data) + expected_shape = [1, 2, 64] + self.assertEqual(list(output.shape), expected_shape) + + def test_large_sequence_length(self): + # Test with sequence length at the maximum + model = LayoutLMv3Backbone(**self.init_kwargs) + + seq_len = 128 # max_sequence_length + input_data = { + "token_ids": keras.random.uniform( + shape=(1, seq_len), minval=0, maxval=1000, dtype="int32" + ), + "padding_mask": keras.ops.ones((1, seq_len), dtype="int32"), + "bbox": keras.random.uniform( + shape=(1, seq_len, 4), minval=0, maxval=1000, dtype="int32" + ), + } + + output = model(input_data) + expected_shape = [1, seq_len, 64] + self.assertEqual(list(output.shape), expected_shape) diff --git a/keras_hub/src/models/layoutlmv3/layoutlmv3_document_classifier_preprocessor.py b/keras_hub/src/models/layoutlmv3/layoutlmv3_document_classifier_preprocessor.py new file mode 100644 index 0000000000..7b7caec0d9 --- /dev/null +++ b/keras_hub/src/models/layoutlmv3/layoutlmv3_document_classifier_preprocessor.py @@ -0,0 +1,100 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import ( + LayoutLMv3Backbone, +) +from keras_hub.src.models.layoutlmv3.layoutlmv3_tokenizer import ( + LayoutLMv3Tokenizer, +) +from keras_hub.src.models.preprocessor import Preprocessor + + +@keras_hub_export("keras_hub.models.LayoutLMv3DocumentClassifierPreprocessor") +class LayoutLMv3DocumentClassifierPreprocessor(Preprocessor): + """LayoutLMv3 preprocessor for document classification tasks. + + This preprocessing layer is meant for use with + `keras_hub.models.LayoutLMv3Backbone`, and can be used to chain a + `keras_hub.models.LayoutLMv3Tokenizer` with the model preprocessing logic. + It can optionally be configured with a `sequence_length` which will pad or + truncate sequences to a fixed length. + + Arguments: + tokenizer: A `keras_hub.models.LayoutLMv3Tokenizer` instance. + sequence_length: int. If set, the output will be packed or padded to + exactly this sequence length. + + Call arguments: + x: A dictionary with "text" and optionally "bbox" keys. The "text" + should be a string or tensor of strings. The "bbox" should be a + list or tensor of bounding box coordinates with shape + `(..., num_words, 4)`. + y: Label data. Should always be `None` as the layer is unsupervised. + sample_weight: Label weights. Should always be `None` as the layer is + unsupervised. + + Examples: + + Directly calling the layer on data. + ```python + preprocessor = ( + keras_hub.models.LayoutLMv3DocumentClassifierPreprocessor.from_preset( + "layoutlmv3_base" + ) + ) + + # Tokenize and pack a single sentence. 
+ preprocessor("The quick brown fox jumped.") + + # Tokenize a batch of sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Tokenize with bounding boxes. + preprocessor({ + "text": "Hello world", + "bbox": [[0, 0, 100, 50], [100, 0, 200, 50]] + }) + ``` + + Mapping with `tf.data.Dataset`. + ```python + preprocessor = ( + keras_hub.models.LayoutLMv3DocumentClassifierPreprocessor.from_preset( + "layoutlmv3_base" + ) + ) + + text_ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox."]) + text_ds = text_ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + backbone_cls = LayoutLMv3Backbone + tokenizer_cls = LayoutLMv3Tokenizer + + def call(self, x, y=None, sample_weight=None): + if isinstance(x, dict): + text = x["text"] + bbox = x.get("bbox", None) + else: + text = x + bbox = None + + token_output = self.tokenizer( + text, bbox=bbox, sequence_length=self.sequence_length + ) + + # The tokenizer already provides token_ids, padding_mask, and bbox + # Rename token_ids to match backbone expectations + output = { + "token_ids": token_output["token_ids"], + "padding_mask": token_output["padding_mask"], + "bbox": token_output["bbox"], + } + + return keras.utils.pack_x_y_sample_weight(output, y, sample_weight) + + def get_config(self): + config = super().get_config() + return config diff --git a/keras_hub/src/models/layoutlmv3/layoutlmv3_presets.py b/keras_hub/src/models/layoutlmv3/layoutlmv3_presets.py new file mode 100644 index 0000000000..506a1963d7 --- /dev/null +++ b/keras_hub/src/models/layoutlmv3/layoutlmv3_presets.py @@ -0,0 +1,28 @@ +"""LayoutLMv3 model preset configurations.""" + +backbone_presets = { + "layoutlmv3_base": { + "metadata": { + "description": ( + "12-layer LayoutLMv3 model with visual backbone. " + "Trained on IIT-CDIP dataset for document understanding." + ), + "params": 113000000, + "path": "layoutlmv3", + }, + "kaggle_handle": "kaggle://keras/layoutlmv3/keras/layoutlmv3_base/1", + }, + "layoutlmv3_large": { + "metadata": { + "description": ( + "24-layer LayoutLMv3 model with multimodal " + "(text + layout + image) understanding capabilities. " + "Trained on IIT-CDIP, RVL-CDIP, FUNSD, CORD, SROIE, " + "and DocVQA datasets." + ), + "params": 340787200, + "path": "layoutlmv3", + }, + "kaggle_handle": "kaggle://keras/layoutlmv3/keras/layoutlmv3_large/3", + }, +} diff --git a/keras_hub/src/models/layoutlmv3/layoutlmv3_tokenizer.py b/keras_hub/src/models/layoutlmv3/layoutlmv3_tokenizer.py new file mode 100644 index 0000000000..993084a72e --- /dev/null +++ b/keras_hub/src/models/layoutlmv3/layoutlmv3_tokenizer.py @@ -0,0 +1,214 @@ +""" +LayoutLMv3 tokenizer for document understanding tasks. + +References: +- [LayoutLMv3 Paper](https://arxiv.org/abs/2204.08387) +- [LayoutLMv3 GitHub](https://github.com/microsoft/unilm/tree/master/layoutlmv3) +""" + +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer + + +@keras_hub_export("keras_hub.models.LayoutLMv3Tokenizer") +class LayoutLMv3Tokenizer(WordPieceTokenizer): + """LayoutLMv3 tokenizer for document understanding tasks. + + This tokenizer is specifically designed for LayoutLMv3 models that process + both text and layout information. It tokenizes text and processes bounding + box coordinates for document understanding tasks. + + Args: + vocabulary: Optional list of strings containing the vocabulary. If None, + vocabulary will be loaded from preset. 
+ lowercase: bool, defaults to True. Whether to lowercase the input text. + strip_accents: bool, defaults to True. Whether to strip accents from + the input text. + split: bool, defaults to True. Whether to split the input on whitespace. + split_on_cjk: bool, defaults to True. Whether to split CJK characters. + suffix_indicator: str, defaults to "##". The prefix to add to + continuation tokens. + oov_token: str, defaults to "[UNK]". The out-of-vocabulary token. + cls_token: str, defaults to "[CLS]". The classification token. + sep_token: str, defaults to "[SEP]". The separator token. + pad_token: str, defaults to "[PAD]". The padding token. + mask_token: str, defaults to "[MASK]". The mask token. + unk_token: str, defaults to "[UNK]". The unknown token. + **kwargs: Additional keyword arguments passed to the parent class. + + Examples: + ```python + # Initialize tokenizer from preset + tokenizer = LayoutLMv3Tokenizer.from_preset("layoutlmv3_base") + + # Tokenize text and bounding boxes + inputs = tokenizer( + text=["Hello world", "How are you"], + bbox=[[[0, 0, 100, 100], [100, 0, 200, 100]], + [[0, 0, 100, 100], [100, 0, 200, 100]]] + ) + ``` + """ + + def __init__( + self, + vocabulary=None, + lowercase=True, + strip_accents=True, + split=True, + split_on_cjk=True, + suffix_indicator="##", + oov_token="[UNK]", + cls_token="[CLS]", + sep_token="[SEP]", + pad_token="[PAD]", + mask_token="[MASK]", + unk_token="[UNK]", + **kwargs, + ): + super().__init__( + vocabulary=vocabulary, + lowercase=lowercase, + strip_accents=strip_accents, + split=split, + split_on_cjk=split_on_cjk, + suffix_indicator=suffix_indicator, + oov_token=oov_token, + **kwargs, + ) + self.cls_token = cls_token + self.sep_token = sep_token + self.pad_token = pad_token + self.mask_token = mask_token + self.unk_token = unk_token + + def _process_bbox_for_tokens(self, text_list, bbox_list): + """This method expands bounding boxes for subword tokens and adds + dummy boxes for special tokens. + + Args: + text_list: List of text strings. + bbox_list: List of bounding box lists corresponding to words. + + Returns: + List of bounding box lists aligned with tokens, or None if bbox_list is None. + """ + if bbox_list is None: + return None + + processed_bbox = [] + + try: + for text, bbox in zip(text_list, bbox_list): + # Handle empty or None inputs defensively + if not text or not bbox: + words = [] + word_bbox = [] + else: + words = text.split() + # Ensure bbox has correct length or use dummy boxes + if len(bbox) != len(words): + word_bbox = [[0, 0, 0, 0] for _ in words] + else: + word_bbox = bbox + + token_bbox = [] + # Add dummy box for [CLS] token + token_bbox.append([0, 0, 0, 0]) + + # Process each word and its corresponding box + for word, word_box in zip(words, word_bbox): + # Tokenize the word to handle subwords + try: + word_tokens = self.tokenize(word) + # Expand the bounding box for all subword tokens + for _ in word_tokens: + token_bbox.append(word_box) + except Exception: + # Fallback: just add one token with the box + token_bbox.append(word_box) + + # Add dummy box for [SEP] token + token_bbox.append([0, 0, 0, 0]) + processed_bbox.append(token_bbox) + + except Exception: + # Fallback: return None to use dummy boxes + return None + + return processed_bbox + + def call(self, inputs, bbox=None, sequence_length=None): + """Tokenize input text and process bounding boxes. + + Args: + inputs: A string, list of strings, or tensor of strings to tokenize. 
+ bbox: Optional bounding box coordinates corresponding to the words + in the input text. Should be a list of lists of [x0, y0, x1, y1] + coordinates for each word. + sequence_length: int. If set, the output will be packed or padded to + exactly this sequence length. + + Returns: + A dictionary with the tokenized inputs and optionally bounding boxes. + If input is a string or list of strings, the dictionary will contain: + - "token_ids": Tokenized representation of the inputs. + - "padding_mask": A mask indicating which tokens are real vs padding. + - "bbox": Bounding box coordinates aligned with tokens (if provided). + """ + # Handle string inputs by converting to list + if isinstance(inputs, str): + inputs = [inputs] + if bbox is not None: + bbox = [bbox] + + # Process bounding boxes before tokenization + processed_bbox = self._process_bbox_for_tokens(inputs, bbox) + + # Tokenize the text + token_output = super().call(inputs, sequence_length=sequence_length) + + # Process bbox if provided + if processed_bbox is not None: + # Convert to tensors and pad to match token sequence length + batch_size = ops.shape(token_output["token_ids"])[0] + seq_len = ops.shape(token_output["token_ids"])[1] + + # Create bbox tensor + bbox_tensor = [] + for i, bbox_seq in enumerate(processed_bbox): + # Pad or truncate bbox sequence to match token sequence + if len(bbox_seq) > seq_len: + bbox_seq = bbox_seq[:seq_len] + else: + # Pad with dummy boxes + bbox_seq = bbox_seq + [[0, 0, 0, 0]] * (seq_len - len(bbox_seq)) + bbox_tensor.append(bbox_seq) + + # Convert to tensor + bbox_tensor = ops.convert_to_tensor(bbox_tensor, dtype="int32") + token_output["bbox"] = bbox_tensor + else: + # Create dummy bbox tensor if no bbox provided + batch_size = ops.shape(token_output["token_ids"])[0] + seq_len = ops.shape(token_output["token_ids"])[1] + dummy_bbox = ops.zeros((batch_size, seq_len, 4), dtype="int32") + token_output["bbox"] = dummy_bbox + + return token_output + + def get_config(self): + config = super().get_config() + config.update( + { + "cls_token": self.cls_token, + "sep_token": self.sep_token, + "pad_token": self.pad_token, + "mask_token": self.mask_token, + "unk_token": self.unk_token, + } + ) + return config diff --git a/keras_hub/src/models/layoutlmv3/layoutlmv3_tokenizer_test.py b/keras_hub/src/models/layoutlmv3/layoutlmv3_tokenizer_test.py new file mode 100644 index 0000000000..578c3c6f70 --- /dev/null +++ b/keras_hub/src/models/layoutlmv3/layoutlmv3_tokenizer_test.py @@ -0,0 +1,252 @@ +import numpy as np + +from keras_hub.src.models.layoutlmv3.layoutlmv3_tokenizer import ( + LayoutLMv3Tokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class LayoutLMv3TokenizerTest(TestCase): + def setUp(self): + # Create a simple vocabulary for testing + self.vocabulary = { + "[PAD]": 0, + "[UNK]": 1, + "[CLS]": 2, + "[SEP]": 3, + "[MASK]": 4, + "hello": 5, + "world": 6, + "how": 7, + "are": 8, + "you": 9, + "good": 10, + "morning": 11, + } + + self.tokenizer = LayoutLMv3Tokenizer( + vocabulary=self.vocabulary, + sequence_length=16, + ) + + def test_tokenizer_basics(self): + # Test basic properties + self.assertEqual(self.tokenizer.cls_token, "[CLS]") + self.assertEqual(self.tokenizer.sep_token, "[SEP]") + self.assertEqual(self.tokenizer.pad_token, "[PAD]") + self.assertEqual(self.tokenizer.mask_token, "[MASK]") + self.assertEqual(self.tokenizer.unk_token, "[UNK]") + + def test_simple_tokenization(self): + # Test simple string tokenization + output = self.tokenizer("hello world") + + # Check that 
output contains the expected keys + self.assertIn("token_ids", output) + self.assertIn("padding_mask", output) + self.assertIn("bbox", output) + + # Check shapes + self.assertEqual(output["token_ids"].shape, (1, 16)) + self.assertEqual(output["padding_mask"].shape, (1, 16)) + self.assertEqual(output["bbox"].shape, (1, 16, 4)) + + def test_list_tokenization(self): + # Test list of strings tokenization + texts = ["hello world", "how are you"] + output = self.tokenizer(texts) + + # Check shapes for batch processing + self.assertEqual(output["token_ids"].shape, (2, 16)) + self.assertEqual(output["padding_mask"].shape, (2, 16)) + self.assertEqual(output["bbox"].shape, (2, 16, 4)) + + def test_bbox_processing(self): + # Test with bounding boxes provided + texts = ["hello world"] + bbox = [[[0, 0, 100, 50], [100, 0, 200, 50]]] + + output = self.tokenizer(texts, bbox=bbox) + + # Check that bbox was processed correctly + self.assertEqual(output["bbox"].shape, (1, 16, 4)) + + # Check that dummy bbox was added for special tokens + bbox_values = output["bbox"][0] + # First position should be dummy for [CLS] + self.assertTrue(np.array_equal(bbox_values[0], [0, 0, 0, 0])) + + def test_bbox_expansion_for_subwords(self): + # Test that bounding boxes are properly expanded for subword tokens + texts = ["hello"] + bbox = [[[0, 0, 100, 50]]] # One bbox for one word + + output = self.tokenizer(texts, bbox=bbox) + + # The bbox should be expanded to cover all tokens including specials + self.assertEqual(output["bbox"].shape, (1, 16, 4)) + + def test_mismatched_bbox_count(self): + # Test handling when bbox count doesn't match word count + texts = ["hello world how"] # 3 words + bbox = [[[0, 0, 100, 50], [100, 0, 200, 50]]] # 2 bboxes + + # Should handle gracefully by using dummy boxes + output = self.tokenizer(texts, bbox=bbox) + + self.assertEqual(output["bbox"].shape, (1, 16, 4)) + + def test_no_bbox_provided(self): + # Test tokenization without bounding boxes + texts = ["hello world"] + output = self.tokenizer(texts) + + # Should create dummy bbox tensor + self.assertEqual(output["bbox"].shape, (1, 16, 4)) + + # All bbox values should be zeros (dummy) + bbox_values = output["bbox"][0] + for i in range(bbox_values.shape[0]): + self.assertTrue(np.array_equal(bbox_values[i], [0, 0, 0, 0])) + + def test_get_config(self): + config = self.tokenizer.get_config() + + # Check that all expected keys are in config + expected_keys = [ + "vocabulary", + "lowercase", + "strip_accents", + "split", + "split_on_cjk", + "suffix_indicator", + "oov_token", + "cls_token", + "sep_token", + "pad_token", + "mask_token", + "unk_token", + ] + + for key in expected_keys: + self.assertIn(key, config) + + def test_from_config(self): + config = self.tokenizer.get_config() + restored_tokenizer = LayoutLMv3Tokenizer.from_config(config) + + # Test that restored tokenizer works the same + output1 = self.tokenizer("hello world") + output2 = restored_tokenizer("hello world") + + self.assertAllClose(output1["token_ids"], output2["token_ids"]) + self.assertAllClose(output1["padding_mask"], output2["padding_mask"]) + + def test_special_token_handling(self): + # Test that special tokens are handled correctly + texts = ["hello"] + output = self.tokenizer(texts) + + token_ids = output["token_ids"][0] + + # Should start with [CLS] and end with [SEP] + self.assertEqual(token_ids[0], self.vocabulary["[CLS]"]) + + # Find the last non-padding token - should be [SEP] + padding_mask = output["padding_mask"][0] + last_token_idx = np.sum(padding_mask) - 1 + 
self.assertEqual(token_ids[last_token_idx], self.vocabulary["[SEP]"]) + + def test_sequence_length_parameter(self): + # Test with custom sequence length + custom_tokenizer = LayoutLMv3Tokenizer( + vocabulary=self.vocabulary, + sequence_length=8, + ) + + output = custom_tokenizer("hello world") + + # Check that output respects custom sequence length + self.assertEqual(output["token_ids"].shape, (1, 8)) + self.assertEqual(output["padding_mask"].shape, (1, 8)) + self.assertEqual(output["bbox"].shape, (1, 8, 4)) + + def test_padding_and_truncation(self): + # Test with a very long input + long_text = " ".join(["hello"] * 20) + output = self.tokenizer(long_text) + + # Should be truncated to sequence_length + self.assertEqual(output["token_ids"].shape, (1, 16)) + + # Test with short input + short_text = "hello" + output = self.tokenizer(short_text) + + # Should be padded to sequence_length + self.assertEqual(output["token_ids"].shape, (1, 16)) + + # Check that padding tokens are used + token_ids = output["token_ids"][0] + padding_mask = output["padding_mask"][0] + + # Find first padding position + padding_positions = np.where(padding_mask == 0)[0] + if len(padding_positions) > 0: + first_pad_pos = padding_positions[0] + self.assertEqual(token_ids[first_pad_pos], self.vocabulary["[PAD]"]) + + def test_batch_processing_consistency(self): + # Test that batch processing gives same results as individual processing + texts = ["hello world", "how are you"] + + # Process as batch + batch_output = self.tokenizer(texts) + + # Process individually + individual_outputs = [] + for text in texts: + individual_outputs.append(self.tokenizer(text)) + + # Compare results + for i in range(len(texts)): + self.assertAllClose( + batch_output["token_ids"][i : i + 1], + individual_outputs[i]["token_ids"], + ) + self.assertAllClose( + batch_output["padding_mask"][i : i + 1], + individual_outputs[i]["padding_mask"], + ) + + def test_empty_input(self): + # Test handling of empty input + output = self.tokenizer("") + + # Should still produce valid output with special tokens + self.assertEqual(output["token_ids"].shape, (1, 16)) + self.assertEqual(output["padding_mask"].shape, (1, 16)) + self.assertEqual(output["bbox"].shape, (1, 16, 4)) + + # Should contain [CLS] and [SEP] tokens + token_ids = output["token_ids"][0] + self.assertEqual(token_ids[0], self.vocabulary["[CLS]"]) + self.assertEqual(token_ids[1], self.vocabulary["[SEP]"]) + + def test_oov_token_handling(self): + # Test handling of out-of-vocabulary tokens + output = self.tokenizer("unknown_token") + + # Should use [UNK] token for unknown words + token_ids = output["token_ids"][0] + + # Check that [UNK] token appears (excluding [CLS] and [SEP]) + self.assertIn(self.vocabulary["[UNK]"], token_ids[1:-1]) + + def test_case_sensitivity(self): + # Test case handling based on lowercase parameter + output1 = self.tokenizer("Hello") + output2 = self.tokenizer("hello") + + # Should be the same if lowercase=True (default) + self.assertAllClose(output1["token_ids"], output2["token_ids"]) diff --git a/keras_hub/src/models/layoutlmv3/layoutlmv3_transformer.py b/keras_hub/src/models/layoutlmv3/layoutlmv3_transformer.py new file mode 100644 index 0000000000..00a81e5de1 --- /dev/null +++ b/keras_hub/src/models/layoutlmv3/layoutlmv3_transformer.py @@ -0,0 +1,84 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.transformer_encoder import ( + TransformerEncoder, +) + + 
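+# The layer below is a thin subclass of `TransformerEncoder` that records the
+# LayoutLMv3-specific constructor arguments in its config. A minimal
+# standalone usage sketch (hypothetical sizes, shown only for illustration):
+#
+#   layer = LayoutLMv3TransformerLayer(
+#       hidden_dim=64, num_heads=2, intermediate_dim=128
+#   )
+#   outputs = layer(keras.random.normal((2, 10, 64)))  # -> shape (2, 10, 64)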
+@keras_hub_export("keras_hub.models.LayoutLMv3TransformerLayer") +class LayoutLMv3TransformerLayer(TransformerEncoder): + """LayoutLMv3 transformer encoder layer. + + This layer implements a transformer encoder block for LayoutLMv3, which + includes multi-head self-attention and a feed-forward network. + + Args: + hidden_dim: int. The size of the transformer hidden state. + num_heads: int. The number of attention heads. + intermediate_dim: int. The output dimension of the first Dense layer + in the feedforward network. + dropout: float. Dropout probability. + activation: string or callable. The activation function to use. + layer_norm_epsilon: float. The epsilon value in layer normalization + components. + kernel_initializer: string or `keras.initializers` initializer. + The kernel initializer for the dense and multiheaded attention + layers. + bias_initializer: string or `keras.initializers` initializer. + The bias initializer for the dense and multiheaded attention + layers. + **kwargs: additional keyword arguments to pass to TransformerEncoder. + """ + + def __init__( + self, + hidden_dim, + num_heads, + intermediate_dim, + dropout=0.1, + activation="gelu", + layer_norm_epsilon=1e-12, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + **kwargs, + ): + super().__init__( + intermediate_dim=intermediate_dim, + num_heads=num_heads, + dropout=dropout, + activation=activation, + layer_norm_epsilon=layer_norm_epsilon, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + **kwargs, + ) + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.intermediate_dim = intermediate_dim + self.dropout_rate = dropout + self.activation = activation + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, + "intermediate_dim": self.intermediate_dim, + "dropout": self.dropout_rate, + "activation": self.activation, + "layer_norm_epsilon": self.layer_norm_epsilon, + "kernel_initializer": keras.initializers.serialize( + keras.initializers.get(self.kernel_initializer) + ), + "bias_initializer": keras.initializers.serialize( + keras.initializers.get(self.bias_initializer) + ), + } + ) + return config \ No newline at end of file diff --git a/shell/api_gen.sh b/shell/api_gen.sh index 253e8fd394..1f5feabdcd 100755 --- a/shell/api_gen.sh +++ b/shell/api_gen.sh @@ -4,8 +4,15 @@ set -Eeuo pipefail base_dir=$(dirname $(dirname $0)) echo "Generating api directory with public APIs..." -# Generate API Files -python3 "${base_dir}"/api_gen.py +# Generate API Files - try python3 first, fall back to python +if command -v python3 > /dev/null 2>&1; then + python3 "${base_dir}"/api_gen.py +elif command -v python > /dev/null 2>&1; then + python "${base_dir}"/api_gen.py +else + echo "Error: Neither python3 nor python found" + exit 1 +fi # Format code because `api_gen.py` might order # imports differently. diff --git a/tools/checkpoint_conversion/convert_layoutlmv3_checkpoints.py b/tools/checkpoint_conversion/convert_layoutlmv3_checkpoints.py new file mode 100644 index 0000000000..e3a5b82433 --- /dev/null +++ b/tools/checkpoint_conversion/convert_layoutlmv3_checkpoints.py @@ -0,0 +1,199 @@ +""" +Script to convert LayoutLMv3 checkpoints from Hugging Face to Keras format. 
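+
+Example usage (assumes the `transformers` package is installed and the
+Hugging Face checkpoint can be downloaded):
+
+    python tools/checkpoint_conversion/convert_layoutlmv3_checkpoints.py --model_name microsoft/layoutlmv3-base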
+""" + +import argparse +import os + +import keras +from transformers import LayoutLMv3Config +from transformers import LayoutLMv3Model + +from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import ( + LayoutLMv3Backbone, +) +from keras_hub.src.models.layoutlmv3.layoutlmv3_tokenizer import ( + LayoutLMv3Tokenizer, +) + + +def convert_checkpoint(model_name): + print(f"✨ Converting {model_name}...") + + # Load HuggingFace model and config + hf_model = LayoutLMv3Model.from_pretrained(model_name) + hf_config = LayoutLMv3Config.from_pretrained(model_name) + hf_weights = hf_model.state_dict() + + # Create KerasHub model + keras_model = LayoutLMv3Backbone( + vocabulary_size=hf_config.vocab_size, + hidden_dim=hf_config.hidden_size, + num_layers=hf_config.num_hidden_layers, + num_heads=hf_config.num_attention_heads, + intermediate_dim=hf_config.intermediate_size, + max_sequence_length=hf_config.max_position_embeddings, + dtype="float32", + ) + + # Build model with dummy inputs + dummy_inputs = { + "token_ids": keras.ops.ones((1, 8), dtype="int32"), + "padding_mask": keras.ops.ones((1, 8), dtype="int32"), + "bbox": keras.ops.ones((1, 8, 4), dtype="int32"), + } + keras_model(dummy_inputs) + + # Token embeddings + token_embedding_weight = hf_weights[ + "embeddings.word_embeddings.weight" + ].numpy() + keras_model.token_embedding.embeddings.assign(token_embedding_weight) + print(f"✅ Token embedding: {token_embedding_weight.shape}") + + # Position embeddings + position_weight = hf_weights[ + "embeddings.position_embeddings.weight" + ].numpy() + keras_model.position_embedding.position_embeddings.assign(position_weight) + print(f"✅ Position embedding: {position_weight.shape}") + + # Token type embeddings + token_type_weight = hf_weights[ + "embeddings.token_type_embeddings.weight" + ].numpy() + keras_model.token_type_embedding.embeddings.assign(token_type_weight) + print(f"✅ Token type embedding: {token_type_weight.shape}") + + # Spatial embeddings and projections + spatial_coords = ["x", "y", "h", "w"] + + for coord in spatial_coords: + # Spatial embedding + spatial_key = f"embeddings.{coord}_position_embeddings.weight" + if spatial_key in hf_weights: + spatial_weight = hf_weights[spatial_key].numpy() + spatial_emb = getattr(keras_model, f"{coord}_position_embedding") + spatial_emb.embeddings.assign(spatial_weight) + print(f"✅ {coord} spatial embedding: {spatial_weight.shape}") + + # Spatial projection + proj_key = f"embeddings.{coord}_position_projection" + if f"{proj_key}.weight" in hf_weights: + proj_weight = hf_weights[f"{proj_key}.weight"].numpy().T + proj_bias = hf_weights[f"{proj_key}.bias"].numpy() + projection_layer = getattr(keras_model, f"{coord}_projection") + projection_layer.kernel.assign(proj_weight) + projection_layer.bias.assign(proj_bias) + print(f"✅ {coord} projection: {proj_weight.shape}") + + # Layer norm and dropout + ln_weight = hf_weights["embeddings.LayerNorm.weight"].numpy() + ln_bias = hf_weights["embeddings.LayerNorm.bias"].numpy() + keras_model.embeddings_layer_norm.gamma.assign(ln_weight) + keras_model.embeddings_layer_norm.beta.assign(ln_bias) + print(f"✅ Embeddings LayerNorm: {ln_weight.shape}") + + # Transformer layers + for i in range(hf_config.num_hidden_layers): + hf_prefix = f"encoder.layer.{i}" + keras_layer = keras_model.transformer_layers[i] + + # Self attention + q_weight = ( + hf_weights[f"{hf_prefix}.attention.self.query.weight"].numpy().T + ) + k_weight = ( + hf_weights[f"{hf_prefix}.attention.self.key.weight"].numpy().T + ) + v_weight = ( + 
hf_weights[f"{hf_prefix}.attention.self.value.weight"].numpy().T + ) + q_bias = hf_weights[f"{hf_prefix}.attention.self.query.bias"].numpy() + k_bias = hf_weights[f"{hf_prefix}.attention.self.key.bias"].numpy() + v_bias = hf_weights[f"{hf_prefix}.attention.self.value.bias"].numpy() + + keras_layer._self_attention_layer._query_dense.kernel.assign(q_weight) + keras_layer._self_attention_layer._key_dense.kernel.assign(k_weight) + keras_layer._self_attention_layer._value_dense.kernel.assign(v_weight) + keras_layer._self_attention_layer._query_dense.bias.assign(q_bias) + keras_layer._self_attention_layer._key_dense.bias.assign(k_bias) + keras_layer._self_attention_layer._value_dense.bias.assign(v_bias) + + # Attention output + attn_out_weight = ( + hf_weights[f"{hf_prefix}.attention.output.dense.weight"].numpy().T + ) + attn_out_bias = hf_weights[ + f"{hf_prefix}.attention.output.dense.bias" + ].numpy() + keras_layer._self_attention_layer._output_dense.kernel.assign( + attn_out_weight + ) + keras_layer._self_attention_layer._output_dense.bias.assign( + attn_out_bias + ) + + # Attention layer norm + attn_ln_weight = hf_weights[ + f"{hf_prefix}.attention.output.LayerNorm.weight" + ].numpy() + attn_ln_bias = hf_weights[ + f"{hf_prefix}.attention.output.LayerNorm.bias" + ].numpy() + keras_layer._self_attention_layernorm.gamma.assign(attn_ln_weight) + keras_layer._self_attention_layernorm.beta.assign(attn_ln_bias) + + # Feed forward + ff1_weight = ( + hf_weights[f"{hf_prefix}.intermediate.dense.weight"].numpy().T + ) + ff1_bias = hf_weights[f"{hf_prefix}.intermediate.dense.bias"].numpy() + keras_layer._feedforward_intermediate_dense.kernel.assign(ff1_weight) + keras_layer._feedforward_intermediate_dense.bias.assign(ff1_bias) + + ff2_weight = hf_weights[f"{hf_prefix}.output.dense.weight"].numpy().T + ff2_bias = hf_weights[f"{hf_prefix}.output.dense.bias"].numpy() + keras_layer._feedforward_output_dense.kernel.assign(ff2_weight) + keras_layer._feedforward_output_dense.bias.assign(ff2_bias) + + # Output layer norm + out_ln_weight = hf_weights[ + f"{hf_prefix}.output.LayerNorm.weight" + ].numpy() + out_ln_bias = hf_weights[f"{hf_prefix}.output.LayerNorm.bias"].numpy() + keras_layer._feedforward_layernorm.gamma.assign(out_ln_weight) + keras_layer._feedforward_layernorm.beta.assign(out_ln_bias) + + print(f"✅ Transformer layer {i}") + + # Save the model + preset_dir = f"layoutlmv3_{model_name.split('/')[-1]}_keras" + os.makedirs(preset_dir, exist_ok=True) + + keras_model.save_preset(preset_dir) + + # Create tokenizer and save + tokenizer = LayoutLMv3Tokenizer( + vocabulary=os.path.join(preset_dir, "vocabulary.json"), + merges=os.path.join(preset_dir, "merges.txt"), + ) + tokenizer.save_preset(preset_dir) + + print(f"✅ Saved preset to {preset_dir}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + default="microsoft/layoutlmv3-base", + help="HuggingFace model name", + ) + + args = parser.parse_args() + convert_checkpoint(args.model_name) + + +if __name__ == "__main__": + main()