test/prototype/test_awq.py (26 changes: 11 additions, 15 deletions)
@@ -8,10 +8,7 @@
 import unittest
 
 import torch
-from torch.testing._internal.common_utils import (
-    TestCase,
-    run_tests,
-)
+from torch.testing._internal import common_utils
 
 from torchao.prototype.awq import AWQConfig, AWQStep
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
@@ -47,7 +44,7 @@ def forward(self, x):
     not _is_fbgemm_genai_gpu_available(),
     reason="need to install fbgemm_gpu_genai package",
 )
-class TestAWQ(TestCase):
+class TestAWQ(common_utils.TestCase):
     def test_awq_config(self):
         base_config = Int4WeightOnlyConfig()
         AWQConfig(base_config, step=AWQStep.PREPARE)
@@ -61,19 +58,18 @@ def test_awq_config(self):
         with self.assertRaisesRegex(ValueError, "is not one of"):
             AWQConfig(base_config, step="not_supported")
 
-    def test_awq_functionality(self):
+    @common_utils.parametrize("base_config", [Int4WeightOnlyConfig(group_size=128)])
+    def test_awq_functionality(self, base_config):
         device = "cuda"
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-        group_size = 128
         n_calibration_examples = 10
         sequence_length = 5
 
         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
 
         # baseline quantization
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
         m_baseline = copy.deepcopy(m)
         quantize_(m_baseline, base_config)
 
@@ -104,12 +100,12 @@ def test_awq_functionality(self):
         loss_base = (ref_out - baseline_out).pow(2).mean().item()
         assert loss_awq < loss_base
 
-    def test_awq_loading(self):
+    @common_utils.parametrize("base_config", [Int4WeightOnlyConfig(group_size=128)])
+    def test_awq_loading(self, base_config):
         device = "cuda"
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-        group_size = 128
         n_calibration_examples = 10
         sequence_length = 5
 
@@ -123,7 +119,6 @@ def test_awq_loading(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 
@@ -152,7 +147,8 @@ def test_awq_loading(self):
         assert awq_save_load_out is not None
         assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
 
-    def test_awq_loading_vllm(self):
+    @common_utils.parametrize("base_config", [Int4WeightOnlyConfig(group_size=128)])
+    def test_awq_loading_vllm(self, base_config):
         """Simulate weight loading in vllm:
         * prepare model weight to the same format (awq weight)
         * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint
@@ -163,7 +159,6 @@ def test_awq_loading_vllm(self):
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-        group_size = 128
         n_calibration_examples = 10
         sequence_length = 5
 
@@ -177,7 +172,6 @@ def test_awq_loading_vllm(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 
@@ -212,5 +206,7 @@ def test_awq_loading_vllm(self):
         assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
 
 
+common_utils.instantiate_parametrized_tests(TestAWQ)
+
 if __name__ == "__main__":
-    run_tests()
+    unittest.main()
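
Note: the diff swaps the direct TestCase/run_tests imports for the common_utils module so the three AWQ tests can take base_config as a parameter. A minimal, self-contained sketch of that parametrization pattern follows; the TestDouble class and the "value" parameter are illustrative names, not part of this PR.

import unittest

from torch.testing._internal import common_utils


class TestDouble(common_utils.TestCase):
    @common_utils.parametrize("value", [1, 2, 3])
    def test_double(self, value):
        # Each parametrized value becomes its own generated test method.
        self.assertEqual(value + value, 2 * value)


# Expands the decorated method into one concrete test per value
# (named along the lines of test_double_value_1), so plain unittest
# discovery and unittest.main() pick them up.
common_utils.instantiate_parametrized_tests(TestDouble)

if __name__ == "__main__":
    unittest.main()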
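
The test_awq_loading_vllm docstring above describes the vLLM loading flow: quantize a fresh model into the same (AWQ) weight format, then copy the checkpoint's quantized weights in place via weight.copy_(state_dict["weight"]). A hedged sketch of that in-place copy step, with a hypothetical linear module and state_dict standing in for the real model and checkpoint:

import torch

# Hypothetical stand-ins for the prepared module and the loaded checkpoint;
# in the test, both sides hold already-quantized AWQ tensors of matching layout.
linear = torch.nn.Linear(128, 64, dtype=torch.bfloat16)
state_dict = {"weight": torch.randn(64, 128, dtype=torch.bfloat16)}

with torch.no_grad():
    # copy_ writes into the existing tensor instead of rebinding the
    # parameter, mirroring how vLLM loads checkpoint weights.
    linear.weight.copy_(state_dict["weight"])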