16
16
17
17
from torchao .utils import is_fbcode , is_sm_at_least_89
18
18
19
- _MODEL_NAME_AND_VERSIONS = [
20
- ("torchao-testing/opt-125m-float8dq-row-v1-0.13-dev" , 1 ),
19
+ _DEPRECATED_MODEL_INFO = [
20
+ (
21
+ "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev" ,
22
+ 1 ,
23
+ "Float8DynamicActivationFloat8WeightConfig" ,
24
+ ),
21
25
]
22
26
27
+ _MODEL_NAMES = [
28
+ "torchao-testing/single-linear-FP8-v2-0.13-dev" ,
29
+ "torchao-testing/single-linear-INT4-preshuffled-v2-0.13-dev" ,
30
+ "torchao-testing/single-linear-INT4-v2-0.13-dev" ,
31
+ ]
32
+
33
+ _MODEL = torch .nn .Sequential (
34
+ torch .nn .Linear (32 , 256 , dtype = torch .bfloat16 , device = "cuda" )
35
+ )
36
+
23
37
24
38
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
25
39
@unittest .skipIf (not is_sm_at_least_89 (), "Nedd sm89+" )
26
40
@unittest .skipIf (
27
41
is_fbcode (),
28
42
"Skipping the test in fbcode for now, not sure how to download from transformers" ,
29
43
)
30
- class TestLoadingDeprecatedCheckpoint (TestCase ):
31
- @common_utils .parametrize ("model_name_and_version " , _MODEL_NAME_AND_VERSIONS )
32
- def test_load_model_and_run (self , model_name_and_version ):
44
+ class TestLoadAndRunCheckpoint (TestCase ):
45
+ @common_utils .parametrize ("model_info " , _DEPRECATED_MODEL_INFO )
46
+ def test_load_and_run_deprecated_checkpoints (self , model_info ):
33
47
"""Test that we print correct warning message when loading a deprecated checkpoint
34
48
and making sure the deprecated checkpoints can still be loaded
35
49
"""
36
50
# Load and quantize model
37
- model_name , version = model_name_and_version
51
+ model_name , version , config_name = model_info
38
52
with warnings .catch_warnings (record = True ) as caught_warnings :
39
53
quantized_model = AutoModelForCausalLM .from_pretrained (
40
54
model_name ,
41
55
torch_dtype = "bfloat16" ,
42
- device_map = "cuda" ,
56
+ device_map = "cuda:0 " ,
43
57
)
44
58
assert any (
45
59
"Stored version is not the same as current default version of the config"
46
60
in str (w .message )
47
61
for w in caught_warnings
48
62
), "Didn't get expected warning message for version mismatch"
49
63
50
- # TODO: generalize when we test more checkpoints
51
64
assert any (
52
- "Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
65
+ f "Models quantized with version 1 of { config_name } is deprecated"
53
66
in str (w .message )
54
67
for w in caught_warnings
55
68
), "Didn't get expected warning message for deprecation"
@@ -70,8 +83,36 @@ def test_load_model_and_run(self, model_name_and_version):
70
83
generated_ids , skip_special_tokens = True , clean_up_tokenization_spaces = False
71
84
)
72
85
86
+ @common_utils .parametrize ("model_name" , _MODEL_NAMES )
87
+ def test_load_and_run_checkpoints (self , model_name ):
88
+ """Test that we print correct warning message when loading a deprecated checkpoint
89
+ and making sure the deprecated checkpoints can still be loaded
90
+ """
91
+ from huggingface_hub import hf_hub_download
92
+
93
+ downloaded_model = hf_hub_download (model_name , filename = "model.bin" )
94
+ # Load and quantize model
95
+ with torch .device ("meta" ):
96
+ model = torch .nn .Sequential (
97
+ torch .nn .Linear (32 , 256 , dtype = torch .bfloat16 , device = "cuda" )
98
+ )
99
+ with open (downloaded_model , "rb" ) as f :
100
+ model .load_state_dict (torch .load (f ), assign = True )
101
+
102
+ downloaded_example_inputs = hf_hub_download (
103
+ model_name , filename = "model_inputs.pt"
104
+ )
105
+ with open (downloaded_example_inputs , "rb" ) as f :
106
+ example_inputs = torch .load (f )
107
+ downloaded_output = hf_hub_download (model_name , filename = "model_output.pt" )
108
+ with open (downloaded_output , "rb" ) as f :
109
+ ref_output = torch .load (f )
110
+
111
+ output = model (* example_inputs )
112
+ self .assertTrue (torch .allclose (output , ref_output ))
113
+
73
114
74
- common_utils .instantiate_parametrized_tests (TestLoadingDeprecatedCheckpoint )
115
+ common_utils .instantiate_parametrized_tests (TestLoadAndRunCheckpoint )
75
116
76
117
if __name__ == "__main__" :
77
118
run_tests ()
0 commit comments