-
Notifications
You must be signed in to change notification settings - Fork 309
Added LayoutLMv3 #2178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Added LayoutLMv3 #2178
Changes from 4 commits
ae79d15
737f03a
455a140
d92c8c4
0948f95
3c02f78
4a79d9b
c2fed4c
e828047
063054d
476c0fd
4439fad
ad3c758
885f2fe
5019abb
e1fc266
a32555c
a885afa
ad004f7
6fb0fdc
5aaadab
8c7e989
5a371a5
bcad8d7
ca96183
9c90753
cf4b20b
193496a
4d8604e
e07224c
6187459
00fc976
0d3099d
82b9b93
e40a6a0
7796cbf
ae239c7
6671da2
f1ac61a
c83c124
2ff3157
87359e5
e610073
47167e8
704dad2
4856b47
4159cd6
76fdc13
2e506eb
8ff847e
d5e28f1
8378631
e424b07
e9cbdfb
eae3e40
3d7d4c1
a5be852
ddf2618
08c7090
7e10fab
b3280f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| """LayoutLMv3 document classifier.""" | ||
|
||
|
|
||
| from keras_hub.src.models.layoutlmv3.document_classifier.layoutlmv3_document_classifier import LayoutLMv3DocumentClassifier | ||
| from keras_hub.src.models.layoutlmv3.document_classifier.layoutlmv3_document_classifier_preprocessor import LayoutLMv3DocumentClassifierPreprocessor | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import LayoutLMv3Backbone | ||
|
||
| from keras_hub.src.models.layoutlmv3.layoutlmv3_document_classifier import LayoutLMv3DocumentClassifier | ||
| from keras_hub.src.models.layoutlmv3.layoutlmv3_document_classifier_preprocessor import LayoutLMv3DocumentClassifierPreprocessor | ||
| from keras_hub.src.models.layoutlmv3.layoutlmv3_tokenizer import LayoutLMv3Tokenizer | ||
| from keras_hub.src.models.layoutlmv3.layoutlmv3_transformer import LayoutLMv3Transformer | ||
| from keras_hub.src.models.layoutlmv3.layoutlmv3_presets import layoutlmv3_presets | ||
|
|
||
| __all__ = [ | ||
| "LayoutLMv3Backbone", | ||
| "LayoutLMv3DocumentClassifier", | ||
| "LayoutLMv3DocumentClassifierPreprocessor", | ||
| "LayoutLMv3Tokenizer", | ||
| "LayoutLMv3Transformer", | ||
| "layoutlmv3_presets", | ||
| ] | ||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,158 @@ | ||
| # Copyright 2024 The Keras Hub Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
|
|
||
|
||
| import os | ||
| import numpy as np | ||
| from keras import testing_utils | ||
| from keras import ops | ||
| from keras import backend | ||
| from keras.testing import test_case | ||
| from ..layoutlmv3.layoutlmv3_backbone import LayoutLMv3Backbone | ||
|
||
|
|
||
| class LayoutLMv3BackboneTest(test_case.TestCase): | ||
| def setUp(self): | ||
| super().setUp() | ||
| self.backbone = LayoutLMv3Backbone( | ||
| vocab_size=100, | ||
| hidden_size=64, | ||
| num_hidden_layers=2, | ||
| num_attention_heads=2, | ||
| intermediate_size=128, | ||
| image_size=(112, 112), | ||
| patch_size=16, | ||
| ) | ||
|
|
||
| # Create dummy inputs | ||
| self.batch_size = 2 | ||
| self.seq_length = 16 | ||
| self.input_ids = ops.random.uniform( | ||
| (self.batch_size, self.seq_length), minval=0, maxval=100, dtype="int32" | ||
| ) | ||
| self.bbox = ops.random.uniform( | ||
| (self.batch_size, self.seq_length, 4), minval=0, maxval=100, dtype="int32" | ||
| ) | ||
| self.attention_mask = ops.ones((self.batch_size, self.seq_length), dtype="int32") | ||
| self.image = ops.random.uniform( | ||
| (self.batch_size, 112, 112, 3), minval=0, maxval=1, dtype="float32" | ||
| ) | ||
|
|
||
| self.inputs = { | ||
| "input_ids": self.input_ids, | ||
| "bbox": self.bbox, | ||
| "attention_mask": self.attention_mask, | ||
| "image": self.image, | ||
| } | ||
|
|
||
| def test_valid_call(self): | ||
| """Test the backbone with valid inputs.""" | ||
| outputs = self.backbone(self.inputs) | ||
| self.assertIn("sequence_output", outputs) | ||
| self.assertIn("pooled_output", outputs) | ||
| self.assertEqual(outputs["sequence_output"].shape, (self.batch_size, self.seq_length + 49 + 1, 64)) # text + image patches + cls | ||
| self.assertEqual(outputs["pooled_output"].shape, (self.batch_size, 64)) | ||
|
|
||
| def test_save_and_load(self): | ||
| """Test saving and loading the backbone.""" | ||
| outputs = self.backbone(self.inputs) | ||
| path = self.get_temp_dir() | ||
| self.backbone.save(path) | ||
| restored_backbone = backend.saving.load_model(path) | ||
| restored_outputs = restored_backbone(self.inputs) | ||
| self.assertAllClose(outputs["sequence_output"], restored_outputs["sequence_output"]) | ||
| self.assertAllClose(outputs["pooled_output"], restored_outputs["pooled_output"]) | ||
|
|
||
| def test_from_preset(self): | ||
| """Test creating a backbone from a preset.""" | ||
| backbone = LayoutLMv3Backbone.from_preset("layoutlmv3_base") | ||
| inputs = { | ||
| "input_ids": ops.random.uniform((2, 16), 0, 100, dtype="int32"), | ||
| "bbox": ops.random.uniform((2, 16, 4), 0, 100, dtype="int32"), | ||
| "attention_mask": ops.ones((2, 16), dtype="int32"), | ||
| "image": ops.random.uniform((2, 112, 112, 3), dtype="float32"), | ||
| } | ||
| outputs = backbone(inputs) | ||
| self.assertIn("sequence_output", outputs) | ||
| self.assertIn("pooled_output", outputs) | ||
|
|
||
| def test_backbone_with_different_input_shapes(self): | ||
| """Test the backbone with different input shapes.""" | ||
| # Test with different sequence lengths | ||
| seq_lengths = [32, 128] | ||
| for seq_len in seq_lengths: | ||
| inputs = { | ||
| "input_ids": ops.random.uniform( | ||
| (self.batch_size, seq_len), minval=0, maxval=100, dtype="int32" | ||
| ), | ||
| "bbox": ops.random.uniform( | ||
| (self.batch_size, seq_len, 4), minval=0, maxval=100, dtype="int32" | ||
| ), | ||
| "attention_mask": ops.ones((self.batch_size, seq_len), dtype="int32"), | ||
| "image": self.image, | ||
| } | ||
| outputs = self.backbone(inputs) | ||
| expected_seq_length = seq_len + 49 + 1 | ||
| self.assertEqual(outputs["sequence_output"].shape, (self.batch_size, expected_seq_length, 64)) | ||
|
|
||
| # Test with different batch sizes | ||
| batch_sizes = [1, 4] | ||
| for batch_size in batch_sizes: | ||
| inputs = { | ||
| "input_ids": ops.random.uniform( | ||
| (batch_size, self.seq_length), minval=0, maxval=100, dtype="int32" | ||
| ), | ||
| "bbox": ops.random.uniform( | ||
| (batch_size, self.seq_length, 4), minval=0, maxval=100, dtype="int32" | ||
| ), | ||
| "attention_mask": ops.ones((batch_size, self.seq_length), dtype="int32"), | ||
| "image": ops.random.uniform( | ||
| (batch_size, 112, 112, 3), minval=0, maxval=1, dtype="float32" | ||
| ), | ||
| } | ||
| outputs = self.backbone(inputs) | ||
| expected_seq_length = self.seq_length + 49 + 1 | ||
| self.assertEqual(outputs["sequence_output"].shape, (batch_size, expected_seq_length, 64)) | ||
|
|
||
| def test_backbone_with_attention_mask(self): | ||
| """Test the backbone with different attention masks.""" | ||
| # Create a mask with some padding | ||
| attention_mask = ops.ones((self.batch_size, self.seq_length), dtype="int32") | ||
| indices = ops.array([[0, 32], [1, 48]], dtype="int32") | ||
| updates = ops.array([0, 0], dtype="int32") | ||
| attention_mask = ops.scatter_nd(indices, updates, attention_mask.shape) | ||
|
|
||
| inputs = { | ||
| "input_ids": self.input_ids, | ||
| "bbox": self.bbox, | ||
| "attention_mask": attention_mask, | ||
| "image": self.image, | ||
| } | ||
|
|
||
| outputs = self.backbone(inputs) | ||
| self.assertIsInstance(outputs, dict) | ||
| self.assertIn("sequence_output", outputs) | ||
| self.assertIn("pooled_output", outputs) | ||
|
|
||
| def test_backbone_gradient(self): | ||
| """Test that the backbone produces gradients.""" | ||
| with backend.GradientTape() as tape: | ||
| outputs = self.backbone(self.inputs) | ||
| loss = ops.mean(outputs["pooled_output"]) | ||
|
|
||
| # Check if gradients exist for all trainable variables | ||
| gradients = tape.gradient(loss, self.backbone.trainable_variables) | ||
| for grad in gradients: | ||
| self.assertIsNotNone(grad) | ||
| self.assertFalse(ops.all(ops.isnan(grad))) | ||
| self.assertFalse(ops.all(ops.isinf(grad))) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| """LayoutLMv3 document classifier implementation. | ||
| This module implements a document classification model using the LayoutLMv3 backbone. | ||
| """ | ||
|
|
||
| from typing import Dict, List, Optional, Union | ||
|
||
|
|
||
| from keras import backend, layers, ops | ||
| from keras.saving import register_keras_serializable | ||
| from keras_hub.src.api_export import keras_hub_export | ||
| from keras_hub.src.models.backbone import Backbone | ||
|
|
||
| from .layoutlmv3_backbone import LayoutLMv3Backbone | ||
| from .layoutlmv3_document_classifier_preprocessor import LayoutLMv3DocumentClassifierPreprocessor | ||
|
|
||
| @keras_hub_export("keras_hub.models.LayoutLMv3DocumentClassifier") | ||
| class LayoutLMv3DocumentClassifier(layers.Layer): | ||
| """Document classifier using LayoutLMv3 backbone. | ||
| This model uses the LayoutLMv3 backbone for document classification tasks, | ||
| adding a classification head on top of the backbone's pooled output. | ||
| Args: | ||
| backbone: LayoutLMv3Backbone instance or string preset name. | ||
| num_classes: int, defaults to 2. Number of output classes. | ||
| dropout: float, defaults to 0.1. Dropout rate for the classification head. | ||
| **kwargs: Additional keyword arguments passed to the parent class. | ||
| Example: | ||
| ```python | ||
| # Initialize classifier from preset | ||
| classifier = LayoutLMv3DocumentClassifier.from_preset("layoutlmv3_base") | ||
| # Process document | ||
| outputs = classifier({ | ||
| "input_ids": input_ids, | ||
| "bbox": bbox, | ||
| "attention_mask": attention_mask, | ||
| "image": image | ||
| }) | ||
| ``` | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| backbone, | ||
| num_classes=2, | ||
| dropout=0.1, | ||
| **kwargs, | ||
| ): | ||
| super().__init__(**kwargs) | ||
| self.backbone = backbone | ||
| self.num_classes = num_classes | ||
| self.dropout = dropout | ||
|
|
||
| def call(self, inputs): | ||
| # Get backbone outputs | ||
| backbone_outputs = self.backbone(inputs) | ||
| sequence_output = backbone_outputs["sequence_output"] | ||
| pooled_output = backbone_outputs["pooled_output"] | ||
|
|
||
| # Classification head | ||
| x = layers.Dropout(self.dropout)(pooled_output) | ||
| outputs = layers.Dense( | ||
| self.num_classes, | ||
| activation="softmax", | ||
| name="classifier", | ||
| )(x) | ||
|
|
||
| return outputs | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update({ | ||
| "backbone": self.backbone, | ||
| "num_classes": self.num_classes, | ||
| "dropout": self.dropout, | ||
| }) | ||
| return config | ||
|
|
||
| @classmethod | ||
| def from_preset( | ||
| cls, | ||
| preset, | ||
| num_classes=2, | ||
| dropout=0.1, | ||
| **kwargs, | ||
| ): | ||
| """Create a LayoutLMv3 document classifier from a preset. | ||
| Args: | ||
| preset: string. Must be one of "layoutlmv3_base", "layoutlmv3_large". | ||
| num_classes: int. Number of classes to classify documents into. | ||
| dropout: float. Dropout probability for the classification head. | ||
| **kwargs: Additional keyword arguments. | ||
| Returns: | ||
| A LayoutLMv3DocumentClassifier instance. | ||
| """ | ||
| backbone = LayoutLMv3Backbone.from_preset(preset) | ||
| return cls( | ||
| backbone=backbone, | ||
| num_classes=num_classes, | ||
| dropout=dropout, | ||
| **kwargs, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this directory and file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This still needs to be removed