from torchao.utils import is_fbcode, is_sm_at_least_89

-_MODEL_NAME_AND_VERSIONS = [
-    ("torchao-testing/opt-125m-float8dq-row-v1-0.13-dev", 1),
+_DEPRECATED_MODEL_INFO = [
+    (
+        "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
+        1,
+        "Float8DynamicActivationFloat8WeightConfig",
+    ),
+]
+
+_MODEL_INFO = [
+    ("torchao-testing/opt-125m-FP8-v2-0.13-dev", 2),
+    ("torchao-testing/opt-125m-INT4-preshuffled-v2-0.13-dev", 2),
+    ("torchao-testing/opt-125m-INT4-v2-0.13-dev", 2),
]

    is_fbcode(),
    "Skipping the test in fbcode for now, not sure how to download from transformers",
)
-class TestLoadingDeprecatedCheckpoint(TestCase):
-    @common_utils.parametrize("model_name_and_version", _MODEL_NAME_AND_VERSIONS)
-    def test_load_model_and_run(self, model_name_and_version):
+class TestLoadAndRunCheckpoint(TestCase):
+    @common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO)
+    def test_load_and_run_deprecated_checkpoints(self, model_info):
        """Test that we print the correct warning message when loading a deprecated checkpoint
        and make sure the deprecated checkpoints can still be loaded
        """
        # Load the pre-quantized model
-        model_name, version = model_name_and_version
+        model_name, version, config_name = model_info
        with warnings.catch_warnings(record=True) as caught_warnings:
            quantized_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="bfloat16",
-                device_map="cuda",
+                device_map="cuda:0",
            )
        assert any(
            "Stored version is not the same as current default version of the config"
            in str(w.message)
            for w in caught_warnings
        ), "Didn't get expected warning message for version mismatch"

-        # TODO: generalize when we test more checkpoints
        assert any(
-            "Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
+            f"Models quantized with version 1 of {config_name} is deprecated"
            in str(w.message)
            for w in caught_warnings
        ), "Didn't get expected warning message for deprecation"
@@ -70,8 +79,35 @@ def test_load_model_and_run(self, model_name_and_version):
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

+    @common_utils.parametrize("model_info", _MODEL_INFO)
+    def test_load_and_run_checkpoints(self, model_info):
+        """Test that checkpoints quantized with the current (non-deprecated) config
+        version can be loaded and run
+        """
+        model_name, version = model_info
+        # Load the pre-quantized model
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype="bfloat16",
+            device_map="cuda:0",
+        )
+        assert isinstance(quantized_model.config.quantization_config, TorchAoConfig)
+        assert quantized_model.config.quantization_config.quant_type.version == version
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        prompt = ("Hello, my name is",)
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+        ).to("cuda")
+        generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
+        # make sure it runs
+        _ = tokenizer.batch_decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+

-common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)
+common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)

if __name__ == "__main__":
    run_tests()
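
For context, a minimal sketch (not part of this PR) of how a v2 FP8 test checkpoint like "torchao-testing/opt-125m-FP8-v2-0.13-dev" might be produced via transformers' TorchAoConfig integration; the base model id, output directory, and serialization flag below are illustrative assumptions, not the exact recipe used for these checkpoints:

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

# Quantize weights to FP8 (with dynamic FP8 activations) at load time.
quant_config = TorchAoConfig(
    quant_type=Float8DynamicActivationFloat8WeightConfig()
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # assumed base checkpoint
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
    quantization_config=quant_config,
)
# torchao tensor subclasses are not safetensors-compatible, so the quantized
# model is saved with safe_serialization=False before uploading to the Hub.
model.save_pretrained("opt-125m-FP8", safe_serialization=False)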