Skip to content

Commit ce3e167

Browse files
committed
Add load and run tests for checkpoints that we want to have BC
Summary: Added load-and-run tests to make sure previously saved checkpoints can continue to load and run. This includes FP8, INT4, and INT4 + preshuffled checkpoints, since these might reach a larger audience. Test Plan: python test/integration/test_load_and_run_checkpoint.py Reviewers: Subscribers: Tasks: Tags: stack-info: PR: #2792, branch: jerryzh168/stack/28
1 parent 751d7f6 commit ce3e167

File tree

1 file changed

+51
-10
lines changed

1 file changed

+51
-10
lines changed

test/integration/test_loading_deprecated_checkpoint.py renamed to test/integration/test_load_and_run_checkpoint.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,40 +16,53 @@
1616

1717
from torchao.utils import is_fbcode, is_sm_at_least_89
1818

19-
_MODEL_NAME_AND_VERSIONS = [
20-
("torchao-testing/opt-125m-float8dq-row-v1-0.13-dev", 1),
19+
# Deprecated checkpoints that must keep loading. Each entry is unpacked by the
# deprecated-checkpoint test as (model_name, version, config_name), where
# config_name is the config class name expected in the deprecation warning.
_DEPRECATED_MODEL_INFO = [
    (
        "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
        1,
        "Float8DynamicActivationFloat8WeightConfig",
    ),
]

# Hugging Face Hub repos whose saved single-linear checkpoints are loaded and
# run by the load-and-run test.
_MODEL_NAMES = [
    "torchao-testing/single-linear-FP8-v2-0.13-dev",
    "torchao-testing/single-linear-INT4-preshuffled-v2-0.13-dev",
    "torchao-testing/single-linear-INT4-v2-0.13-dev",
]

# Reference single-linear model. It is built at import time, so the device
# falls back to CPU when CUDA is unavailable; otherwise importing this module
# would raise before the class-level skipIf decorators get a chance to skip.
_MODEL = torch.nn.Sequential(
    torch.nn.Linear(
        32,
        256,
        dtype=torch.bfloat16,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
)
36+
2337

2438
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
2539
@unittest.skipIf(not is_sm_at_least_89(), "Nedd sm89+")
2640
@unittest.skipIf(
2741
is_fbcode(),
2842
"Skipping the test in fbcode for now, not sure how to download from transformers",
2943
)
30-
class TestLoadingDeprecatedCheckpoint(TestCase):
31-
@common_utils.parametrize("model_name_and_version", _MODEL_NAME_AND_VERSIONS)
32-
def test_load_model_and_run(self, model_name_and_version):
44+
class TestLoadAndRunCheckpoint(TestCase):
45+
@common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO)
46+
def test_load_and_run_deprecated_checkpoints(self, model_info):
3347
"""Test that we print correct warning message when loading a deprecated checkpoint
3448
and making sure the deprecated checkpoints can still be loaded
3549
"""
3650
# Load and quantize model
37-
model_name, version = model_name_and_version
51+
model_name, version, config_name = model_info
3852
with warnings.catch_warnings(record=True) as caught_warnings:
3953
quantized_model = AutoModelForCausalLM.from_pretrained(
4054
model_name,
4155
torch_dtype="bfloat16",
42-
device_map="cuda",
56+
device_map="cuda:0",
4357
)
4458
assert any(
4559
"Stored version is not the same as current default version of the config"
4660
in str(w.message)
4761
for w in caught_warnings
4862
), "Didn't get expected warning message for version mismatch"
4963

50-
# TODO: generalize when we test more checkpoints
5164
assert any(
52-
"Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
65+
f"Models quantized with version 1 of {config_name} is deprecated"
5366
in str(w.message)
5467
for w in caught_warnings
5568
), "Didn't get expected warning message for deprecation"
@@ -70,8 +83,36 @@ def test_load_model_and_run(self, model_name_and_version):
7083
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
7184
)
7285

86+
@common_utils.parametrize("model_name", _MODEL_NAMES)
def test_load_and_run_checkpoints(self, model_name):
    """Check that a previously saved checkpoint still loads and runs.

    Downloads ``model.bin``, ``model_inputs.pt`` and ``model_output.pt``
    from the Hugging Face Hub repo ``model_name``, loads the state dict
    into a single-linear model, runs the saved example inputs through it
    and compares the result with the saved reference output.

    Args:
        model_name: Hugging Face Hub repo id of the checkpoint under test.
    """
    from huggingface_hub import hf_hub_download

    def _load_from_hub(filename):
        # Fetch one file of the checkpoint repo and deserialize it.
        local_path = hf_hub_download(model_name, filename=filename)
        with open(local_path, "rb") as f:
            return torch.load(f)

    # Build only the module skeleton on the meta device. No explicit device
    # is passed to Linear, since that would override the meta context and
    # allocate real storage that assign=True discards right away.
    with torch.device("meta"):
        model = torch.nn.Sequential(
            torch.nn.Linear(32, 256, dtype=torch.bfloat16)
        )
    # assign=True swaps the loaded tensors in instead of copying into the
    # placeholders, so the parameters are exactly what the checkpoint holds.
    model.load_state_dict(_load_from_hub("model.bin"), assign=True)

    example_inputs = _load_from_hub("model_inputs.pt")
    ref_output = _load_from_hub("model_output.pt")

    output = model(*example_inputs)
    self.assertTrue(
        torch.allclose(output, ref_output),
        f"Output of {model_name} does not match the saved reference output",
    )
113+
73114

74-
common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)
115+
common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)
75116

76117
if __name__ == "__main__":
77118
run_tests()

0 commit comments

Comments
 (0)