Parametrize quantization APIs for AWQ unittest #2930
@@ -8,10 +8,7 @@
 import unittest
 
 import torch
-from torch.testing._internal.common_utils import (
-    TestCase,
-    run_tests,
-)
+from torch.testing._internal import common_utils
 
 from torchao.prototype.awq import AWQConfig, AWQStep
 from torchao.quantization import Int4WeightOnlyConfig, quantize_

@@ -47,9 +44,9 @@ def forward(self, x):
     not _is_fbgemm_genai_gpu_available(),
     reason="need to install fbgemm_gpu_genai package",
 )
-class TestAWQ(TestCase):
-    def test_awq_config(self):
-        base_config = Int4WeightOnlyConfig()
+class TestAWQ(common_utils.TestCase):
+    @common_utils.parametrize("base_config", [Int4WeightOnlyConfig()])
+    def test_awq_config(self, base_config):
         AWQConfig(base_config, step=AWQStep.PREPARE)
         AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
         AWQConfig(base_config, step=AWQStep.CONVERT)

@@ -61,19 +58,20 @@ def test_awq_config(self):
         with self.assertRaisesRegex(ValueError, "is not one of"):
             AWQConfig(base_config, step="not_supported")
 
-    def test_awq_functionality(self):
+    @common_utils.parametrize(
+        "base_config", [Int4WeightOnlyConfig(group_size=128, version=2)]
+    )
+    def test_awq_functionality(self, base_config):
         device = "cuda"
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-        group_size = 128
         n_calibration_examples = 10
         sequence_length = 5
 
         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
 
         # baseline quantization
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
         m_baseline = copy.deepcopy(m)
         quantize_(m_baseline, base_config)
 

@@ -104,12 +102,14 @@ def test_awq_functionality(self):
         loss_base = (ref_out - baseline_out).pow(2).mean().item()
         assert loss_awq < loss_base
 
-    def test_awq_loading(self):
+    @common_utils.parametrize(
+        "base_config", [Int4WeightOnlyConfig(group_size=128, version=2)]
+    )
+    def test_awq_loading(self, base_config):
         device = "cuda"
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-        group_size = 128
         n_calibration_examples = 10
         sequence_length = 5
 

@@ -123,7 +123,6 @@ def test_awq_loading(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 

@@ -152,7 +151,10 @@ def test_awq_loading(self):
         assert awq_save_load_out is not None
         assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
 
-    def test_awq_loading_vllm(self):
+    @common_utils.parametrize(
+        "base_config", [Int4WeightOnlyConfig(group_size=128, version=2)]
+    )
+    def test_awq_loading_vllm(self, base_config):
         """Simulate weight loading in vllm:
         * prepare model weight to the same format (awq weight)
         * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint

@@ -163,7 +165,6 @@ def test_awq_loading_vllm(self):
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-        group_size = 128
         n_calibration_examples = 10
         sequence_length = 5
 

@@ -177,7 +178,6 @@ def test_awq_loading_vllm(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 

@@ -212,5 +212,7 @@ def test_awq_loading_vllm(self):
         assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
 
 
+common_utils.instantiate_parametrized_tests(TestAWQ)
+
 if __name__ == "__main__":
-    run_tests()
+    unittest.main()
This one is fine not to parametrize, I think.
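For context, below is a minimal self-contained sketch of the parametrization pattern this PR adopts from torch.testing._internal.common_utils. The test class, parameter name, and values are illustrative placeholders, not part of the AWQ test: @common_utils.parametrize records a list of values for a test method, and instantiate_parametrized_tests expands each decorated method into one concrete unittest test per value, so additional base configs can later be covered by extending the list instead of duplicating test bodies.

# Sketch of the parametrization mechanism used above; the class, parameter
# name, and values here are placeholders, not part of the AWQ test itself.
import unittest

from torch.testing._internal import common_utils


class ToyParametrizedTest(common_utils.TestCase):
    # One concrete test is generated per entry in the value list; the
    # generated test name encodes the parameter value.
    @common_utils.parametrize("group_size", [64, 128])
    def test_group_size_is_multiple_of_32(self, group_size):
        self.assertEqual(group_size % 32, 0)


# Expands every @parametrize-decorated method of the class into real
# unittest test methods, mirroring instantiate_parametrized_tests(TestAWQ).
common_utils.instantiate_parametrized_tests(ToyParametrizedTest)

if __name__ == "__main__":
    unittest.main()

The expansion happens at import time, so the generated variants show up under both plain unittest and pytest, and a single configuration can be selected with pytest's -k name filter.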