Skip to content

Commit 80009d9

Browse files
committed
Add load-and-run tests for checkpoints that must remain backward compatible (BC)
Summary: Added load-and-run tests to make sure previously saved checkpoints can continue to load and run. Includes FP8, INT4, and INT4 + preshuffled checkpoints, since these might reach a larger audience. Test Plan: python test/integration/test_load_and_run_checkpoint.py Reviewers: Subscribers: Tasks: Tags:
1 parent 751d7f6 commit 80009d9

File tree

1 file changed

+46
-10
lines changed

1 file changed

+46
-10
lines changed

test/integration/test_loading_deprecated_checkpoint.py renamed to test/integration/test_load_and_run_checkpoint.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,18 @@
1616

1717
from torchao.utils import is_fbcode, is_sm_at_least_89
1818

19-
_MODEL_NAME_AND_VERSIONS = [
20-
("torchao-testing/opt-125m-float8dq-row-v1-0.13-dev", 1),
19+
_DEPRECATED_MODEL_INFO = [
20+
(
21+
"torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
22+
1,
23+
"Float8DynamicActivationFloat8WeightConfig",
24+
),
25+
]
26+
27+
_MODEL_INFO = [
28+
("torchao-testing/opt-125m-FP8-v2-0.13-dev", 2),
29+
("torchao-testing/opt-125m-INT4-preshuffled-v2-0.13-dev", 2),
30+
("torchao-testing/opt-125m-INT4-v2-0.13-dev", 2),
2131
]
2232

2333

@@ -27,29 +37,28 @@
2737
is_fbcode(),
2838
"Skipping the test in fbcode for now, not sure how to download from transformers",
2939
)
30-
class TestLoadingDeprecatedCheckpoint(TestCase):
31-
@common_utils.parametrize("model_name_and_version", _MODEL_NAME_AND_VERSIONS)
32-
def test_load_model_and_run(self, model_name_and_version):
40+
class TestLoadAndRunCheckpoint(TestCase):
41+
@common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO)
42+
def test_load_and_run_deprecated_checkpoints(self, model_info):
3343
"""Test that we print correct warning message when loading a deprecated checkpoint
3444
and making sure the deprecated checkpoints can still be loaded
3545
"""
3646
# Load and quantize model
37-
model_name, version = model_name_and_version
47+
model_name, version, config_name = model_info
3848
with warnings.catch_warnings(record=True) as caught_warnings:
3949
quantized_model = AutoModelForCausalLM.from_pretrained(
4050
model_name,
4151
torch_dtype="bfloat16",
42-
device_map="cuda",
52+
device_map="cuda:0",
4353
)
4454
assert any(
4555
"Stored version is not the same as current default version of the config"
4656
in str(w.message)
4757
for w in caught_warnings
4858
), "Didn't get expected warning message for version mismatch"
4959

50-
# TODO: generalize when we test more checkpoints
5160
assert any(
52-
"Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
61+
f"Models quantized with version 1 of {config_name} is deprecated"
5362
in str(w.message)
5463
for w in caught_warnings
5564
), "Didn't get expected warning message for deprecation"
@@ -70,8 +79,35 @@ def test_load_model_and_run(self, model_name_and_version):
7079
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
7180
)
7281

82+
@common_utils.parametrize("model_info", _MODEL_INFO)
83+
def test_load_and_run_checkpoints(self, model_info):
84+
"""Test that current (non-deprecated) checkpoints load with the expected config version
85+
and that the loaded model can still run generation
86+
"""
87+
model_name, version = model_info
88+
# Load and quantize model
89+
quantized_model = AutoModelForCausalLM.from_pretrained(
90+
model_name,
91+
torch_dtype="bfloat16",
92+
device_map="cuda:0",
93+
)
94+
assert isinstance(quantized_model.config.quantization_config, TorchAoConfig)
95+
assert quantized_model.config.quantization_config.quant_type.version == version
96+
97+
tokenizer = AutoTokenizer.from_pretrained(model_name)
98+
prompt = ("Hello, my name is",)
99+
inputs = tokenizer(
100+
prompt,
101+
return_tensors="pt",
102+
).to("cuda")
103+
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
104+
# make sure it runs
105+
_ = tokenizer.batch_decode(
106+
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
107+
)
108+
73109

74-
common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)
110+
common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)
75111

76112
if __name__ == "__main__":
77113
run_tests()

0 commit comments

Comments
 (0)