36 changes: 19 additions & 17 deletions test/prototype/test_awq.py
@@ -8,10 +8,7 @@
 import unittest
 
 import torch
-from torch.testing._internal.common_utils import (
-    TestCase,
-    run_tests,
-)
+from torch.testing._internal import common_utils
 
 from torchao.prototype.awq import AWQConfig, AWQStep
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
@@ -47,9 +44,9 @@ def forward(self, x):
     not _is_fbgemm_genai_gpu_available(),
     reason="need to install fbgemm_gpu_genai package",
 )
-class TestAWQ(TestCase):
-    def test_awq_config(self):
-        base_config = Int4WeightOnlyConfig()
+class TestAWQ(common_utils.TestCase):
+    @common_utils.parametrize("base_config", [Int4WeightOnlyConfig()])
+    def test_awq_config(self, base_config):

Contributor: this one is fine not to parametrize I think

         AWQConfig(base_config, step=AWQStep.PREPARE)
         AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
         AWQConfig(base_config, step=AWQStep.CONVERT)
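
For context, a minimal standalone sketch of the parametrization pattern this PR adopts (illustrative only; the test class and values here are made up, not part of the diff):

    import unittest

    from torch.testing._internal import common_utils


    class ExampleTest(common_utils.TestCase):
        # Each value in the list becomes a separately named test case.
        @common_utils.parametrize("value", [1, 2, 3])
        def test_double(self, value):
            self.assertEqual(value * 2, value + value)


    # Expands the @parametrize decorators into concrete test methods;
    # without this call the parametrized tests are never generated.
    common_utils.instantiate_parametrized_tests(ExampleTest)

    if __name__ == "__main__":
        unittest.main()
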
@@ -61,19 +58,20 @@ def test_awq_config(self):
         with self.assertRaisesRegex(ValueError, "is not one of"):
             AWQConfig(base_config, step="not_supported")
 
-    def test_awq_functionality(self):
+    @common_utils.parametrize(
+        "base_config", [Int4WeightOnlyConfig(group_size=128, version=2)]

Contributor: nit: version=2 can be removed now

+    )
+    def test_awq_functionality(self, base_config):
         device = "cuda"
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-        group_size = 128
         n_calibration_examples = 10
         sequence_length = 5
 
         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
 
         # baseline quantization
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
         m_baseline = copy.deepcopy(m)
         quantize_(m_baseline, base_config)
 
@@ -104,12 +102,14 @@ def test_awq_functionality(self):
         loss_base = (ref_out - baseline_out).pow(2).mean().item()
         assert loss_awq < loss_base
 
-    def test_awq_loading(self):
+    @common_utils.parametrize(
+        "base_config", [Int4WeightOnlyConfig(group_size=128, version=2)]

Contributor: same here, version=2 can be removed now

+    )
+    def test_awq_loading(self, base_config):
         device = "cuda"
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-        group_size = 128
         n_calibration_examples = 10
         sequence_length = 5
 
@@ -123,7 +123,6 @@ def test_awq_loading(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 
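
As context for the calibrate step above, a hedged sketch of the three-step AWQ flow these tests exercise, assuming the torchao prototype API shown in this diff (the model, shapes, and batch counts are illustrative):

    import torch
    from torchao.prototype.awq import AWQConfig, AWQStep
    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    model = torch.nn.Sequential(torch.nn.Linear(512, 256)).to(torch.bfloat16).to("cuda")
    base_config = Int4WeightOnlyConfig(group_size=128)

    # Step 1 (PREPARE): wrap the weights with observers so calibration
    # runs can record activation statistics.
    quantize_(model, AWQConfig(base_config, step=AWQStep.PREPARE))
    for _ in range(10):  # feed calibration batches
        model(torch.randn(5, 512, dtype=torch.bfloat16, device="cuda"))

    # Step 2 (CONVERT): derive the equalization scales from the recorded
    # statistics and quantize the weights with the base config.
    quantize_(model, AWQConfig(base_config, step=AWQStep.CONVERT))
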
@@ -152,7 +151,10 @@ def test_awq_loading(self):
         assert awq_save_load_out is not None
         assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
 
-    def test_awq_loading_vllm(self):
+    @common_utils.parametrize(
+        "base_config", [Int4WeightOnlyConfig(group_size=128, version=2)]

Contributor: same here

+    )
+    def test_awq_loading_vllm(self, base_config):
         """Simulate weight loading in vllm:
         * prepare model weight to the same format (awq weight)
         * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint
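
A hedged sketch of the vllm-style loading flow this docstring describes (make_prepared_model and checkpoint_path are hypothetical stand-ins, not part of this PR):

    import torch

    def load_like_vllm(make_prepared_model, checkpoint_path):
        # Quantized state dict saved from a calibrated and converted model.
        state_dict = torch.load(checkpoint_path)
        # Model whose weights were already prepared in the same (awq) format.
        model = make_prepared_model()
        with torch.no_grad():
            for name, param in model.named_parameters():
                # In-place copy writes checkpoint values into the
                # pre-formatted quantized tensor, as the docstring describes.
                param.copy_(state_dict[name])
        return model
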
@@ -163,7 +165,6 @@ def test_awq_loading_vllm(self):
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-        group_size = 128
         n_calibration_examples = 10
         sequence_length = 5
 
@@ -177,7 +178,6 @@ def test_awq_loading_vllm(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 
@@ -212,5 +212,7 @@ def test_awq_loading_vllm(self):
         assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
 
 
+common_utils.instantiate_parametrized_tests(TestAWQ)
+
 if __name__ == "__main__":
-    run_tests()
+    unittest.main()