From e1cff30fce81ce3309f80f26ca39b5f0d4a7da49 Mon Sep 17 00:00:00 2001 From: Przemek Date: Tue, 9 Nov 2021 15:27:08 +0900 Subject: [PATCH 1/3] Upgraded PATTERN_PREFIX regexp: 1. Bert: `bert.` -> optional `bert.` 2. Bart: `model.` -> optional `model.` --- nn_pruning/model_structure.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nn_pruning/model_structure.py b/nn_pruning/model_structure.py index 21c3d968..aaf7e988 100644 --- a/nn_pruning/model_structure.py +++ b/nn_pruning/model_structure.py @@ -56,7 +56,7 @@ def is_layernorm(module_name): return "layernorm" in module_name.lower().replace("_", "") class BertStructure(ModelStructure): - PATTERN_PREFIX = "bert.encoder.layer.[0-9]+." + PATTERN_PREFIX = "(?:bert.)?encoder.layer.[0-9]+." LAYER_PATTERNS = dict( query="attention.self.query", key="attention.self.key", @@ -77,7 +77,7 @@ class BertStructure(ModelStructure): ) class BartStructure(ModelStructure): - PATTERN_PREFIX = "model.(en|de)coder.layers.[0-9]+." + PATTERN_PREFIX = "(?:model.)?(en|de)coder.layers.[0-9]+." 
LAYER_PATTERNS = dict( query="self_attn.q_proj", key="self_attn.k_proj", From f4034775fb7d420ffe8738b8a4de5167e2d6b9f9 Mon Sep 17 00:00:00 2001 From: Przemek Date: Tue, 9 Nov 2021 15:27:53 +0900 Subject: [PATCH 2/3] Test for `BertModel` constructor --- nn_pruning/tests/test_patch.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nn_pruning/tests/test_patch.py b/nn_pruning/tests/test_patch.py index 4541dd27..aeb3c721 100644 --- a/nn_pruning/tests/test_patch.py +++ b/nn_pruning/tests/test_patch.py @@ -1,7 +1,7 @@ import unittest from unittest import TestCase -from transformers import BertConfig, BertForQuestionAnswering +from transformers import BertConfig, BertForQuestionAnswering, BertModel from nn_pruning.model_structure import BertStructure from nn_pruning.modules.masked_nn import ( @@ -25,9 +25,9 @@ def test_base(self): # for regexp, layers in layers.items(): # print(regexp) - def test_patch_module_independent_parameters(self): + def test_patch_module_independent_parameters(self, bert_constructor=BertForQuestionAnswering): config = BertConfig.from_pretrained("bert-base-uncased") - model = BertForQuestionAnswering(config) + model = bert_constructor(config) parameters = LinearPruningArgs( method="topK", @@ -52,6 +52,9 @@ def test_patch_module_independent_parameters(self): self.assertEqual(key_sizes, {"mask": 72}) + def test_patch_module_independent_parameters_bert_model_only(self): + self.test_patch_module_independent_parameters(bert_constructor=BertModel) + def test_patch_module_ampere(self): config = BertConfig.from_pretrained("bert-base-uncased") model = BertForQuestionAnswering(config) From db2ad04521ba05eb8ef0247b7a7c51d58147fc77 Mon Sep 17 00:00:00 2001 From: Przemek Date: Tue, 9 Nov 2021 15:28:10 +0900 Subject: [PATCH 3/3] Test for `AutoModel` constructor --- nn_pruning/tests/test_patch2.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/nn_pruning/tests/test_patch2.py 
b/nn_pruning/tests/test_patch2.py index 7d2dfad7..9e55119e 100644 --- a/nn_pruning/tests/test_patch2.py +++ b/nn_pruning/tests/test_patch2.py @@ -6,14 +6,14 @@ BlockLinearPruningContextModule, SingleDimensionLinearPruningContextModule, ) -from transformers import AutoConfig, AutoModelForQuestionAnswering +from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoModel import copy class TestFun(TestCase): - def helper(self, sparse_args, model_name_or_path): + def helper(self, sparse_args, model_name_or_path, model_constructor=AutoModelForQuestionAnswering): config = AutoConfig.from_pretrained(model_name_or_path) - model = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path) + model = model_constructor.from_pretrained(model_name_or_path) device = "cuda" cache_dir = None @@ -24,7 +24,7 @@ def helper(self, sparse_args, model_name_or_path): return config, model, coordinator - def test_base(self): + def test_base(self, model_constructor=AutoModelForQuestionAnswering): sparse_args = SparseTrainingArguments.hybrid(20.0) sparse_args.layer_norm_patch = True sparse_args.gelu_patch = True @@ -36,7 +36,7 @@ def test_base(self): } for model_name_or_path in ref_stats.keys(): - config, model, coordinator = self.helper(sparse_args, model_name_or_path) + config, model, coordinator = self.helper(sparse_args, model_name_or_path, model_constructor) coordinator.patch_model(model) @@ -48,7 +48,10 @@ def test_base(self): self.assertEqual(stats, ref_stats[model_name_or_path]) - def test_context_module(self): + def test_base_for_auto_model(self): + self.test_base(model_constructor=AutoModel) + + def test_context_module(self, model_constructor=AutoModelForQuestionAnswering): sparse_args = SparseTrainingArguments.hybrid(20.0) sparse_args.layer_norm_patch = True sparse_args.gelu_patch = True @@ -60,7 +63,7 @@ def test_context_module(self): } for model_name_or_path in ref_context_module.keys(): - config, model, coordinator = self.helper(sparse_args, 
model_name_or_path) + config, model, coordinator = self.helper(sparse_args, model_name_or_path, model_constructor) coordinator.patch_model(model) @@ -76,6 +79,9 @@ def test_context_module(self): self.assertEqual(context_module, ref_context_module[model_name_or_path]) + def test_context_module_for_auto_model(self, model_constructor=AutoModelForQuestionAnswering): + self.test_context_module(model_constructor=AutoModel) + if __name__ == "__main__": unittest.main()