54 changes: 40 additions & 14 deletions test/prototype/test_awq.py
@@ -8,6 +8,7 @@
import unittest

import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
@@ -41,12 +42,14 @@ def forward(self, x):
x = self.linear3(x)
return x

devices = ["cpu"]
if (
torch.cuda.is_available()
and _is_fbgemm_genai_gpu_available()
and TORCH_VERSION_AT_LEAST_2_6
):
devices.append("cuda")
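
For illustration only (not part of this diff), the device list above feeds parameterized.expand in the tests below, generating one test method per available backend. The class and method names here are hypothetical:

import unittest

from parameterized import parameterized

devices = ["cpu", "cuda"]  # assumed result of the detection logic above

class _Example(unittest.TestCase):
    @parameterized.expand([(device,) for device in devices])
    def test_runs_per_device(self, device):
        self.assertIn(device, ("cpu", "cuda"))

# parameterized's default naming yields roughly _Example.test_runs_per_device_0_cpu
# and _Example.test_runs_per_device_1_cuda.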

@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(),
reason="need to install fbgemm_gpu_genai package",
)
class TestAWQ(TestCase):
def test_awq_config(self):
base_config = Int4WeightOnlyConfig()
@@ -61,8 +64,8 @@ def test_awq_config(self):
with self.assertRaisesRegex(ValueError, "is not one of"):
AWQConfig(base_config, step="not_supported")

def test_awq_functionality(self):
device = "cuda"
@parameterized.expand([(device,) for device in devices])
def test_awq_functionality(self, device):
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
@@ -73,7 +76,15 @@ def test_awq_functionality(self):
m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)

# baseline quantization
base_config = Int4WeightOnlyConfig(group_size=group_size)
if device == "cuda":
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, packing_format="opaque", version=2
)
torch.manual_seed(1234)
else:
assert False, "Unsupported device: {}".format(device)
m_baseline = copy.deepcopy(m)
quantize_(m_baseline, base_config)

@@ -102,10 +113,11 @@ def test_awq_functionality(self):

loss_awq = (ref_out - awq_out).pow(2).mean().item()
loss_base = (ref_out - baseline_out).pow(2).mean().item()

assert loss_awq < loss_base

def test_awq_loading(self):
device = "cuda"
@parameterized.expand([(device,) for device in devices])
def test_awq_loading(self, device):
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
@@ -123,7 +135,14 @@ def test_awq_loading(self):
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = Int4WeightOnlyConfig(group_size=group_size)
if device == "cuda":
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, packing_format="opaque", version=2
)
else:
assert False, "Unsupported device: {}".format(device)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

@@ -152,14 +171,14 @@ def test_awq_loading(self):
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)

def test_awq_loading_vllm(self):
@parameterized.expand([(device,) for device in devices])
def test_awq_loading_vllm(self, device):
"""Simulate weight loading in vllm:
* prepare model weight to the same format (awq weight)
* use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint

There is also a slicing op that is omitted here; the overall e2e flow is covered by tests in the vllm repo
"""
device = "cuda"
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
@@ -177,7 +196,14 @@ def test_awq_loading_vllm(self):
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = Int4WeightOnlyConfig(group_size=group_size)
if device == "cuda":
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, packing_format="opaque", version=2
)
else:
assert False, "Unsupported device: {}".format(device)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

@@ -20,6 +20,7 @@
quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.utils import (
torch_version_at_least,
)
@@ -28,7 +29,8 @@
def get_config(group_size):
return Int4WeightOnlyConfig(
group_size=group_size,
int4_packing_format="opaque",
packing_format="opaque",
Review comment (Contributor): this should be int4_packing_format I think

Reply (Contributor Author): Thanks, updated

version=2,
)
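
Per the review exchange above, the helper presumably ends up as follows (an assumption about the follow-up commit, not this diff):

def get_config(group_size):
    return Int4WeightOnlyConfig(
        group_size=group_size,
        int4_packing_format="opaque",
        version=2,
    )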


@@ -75,7 +77,23 @@ def test_module_path(self, dtype):
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Int4OpaqueTensor'>",
)
def test_activation_prescaling(self):
dtype = torch.bfloat16
input = torch.randn(1, 128, dtype=dtype)
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype)
original = linear(input)
quantize_(linear, get_config(group_size=128))
qw = linear.weight
assert isinstance(qw, SupportsActivationPreScaling), (
"Expected int4 tensor supports activation prescaling"
)
assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
_ACT_PRE_SCALE = 2
qw.act_pre_scale = _ACT_PRE_SCALE
quantized = linear(input)

# make sure activation pre-scaling is successfully applied to the activation
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
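
A quick sanity check of the math the test relies on (a standalone sketch, not the torchao kernel): for a bias-free linear layer, scaling the activation is equivalent to scaling the output, which is why the quantized result is compared against original * _ACT_PRE_SCALE.

import torch

torch.manual_seed(0)
x = torch.randn(1, 128, dtype=torch.bfloat16)
fp_linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16)
scale = 2.0

# applying the scale to the activation ...
prescaled_out = fp_linear(x * scale)
# ... matches scaling the output, since the layer is linear and has no bias
assert torch.allclose(prescaled_out, fp_linear(x) * scale, atol=1e-2)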

instantiate_parametrized_tests(TestInt4OpaqueTensor)

101 changes: 79 additions & 22 deletions torchao/prototype/awq/example.py
@@ -16,8 +16,14 @@
)
from torchao.quantization import (
quantize_,
Int4WeightOnlyConfig
)

from torch._inductor import config as inductor_config
import lm_eval
from lm_eval.models.huggingface import HFLM
inductor_config.cpp_wrapper = True
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"
Review comment (Collaborator): This script will also be used for CUDA, so I think Triton is needed here. Or let's simply remove this line to use the defaults.

Reply (Contributor Author): Removed
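
One way to act on this comment (a sketch, not the code in this PR): keep the CPP/ATEN overrides for CPU only and leave the inductor defaults, which include the Triton backend, untouched on CUDA.

from torch._inductor import config as inductor_config

def configure_inductor(device: str) -> None:
    if device == "cpu":
        inductor_config.cpp_wrapper = True
        inductor_config.max_autotune = True
        inductor_config.max_autotune_gemm_backends = "CPP,ATEN"
    # on CUDA, nothing is overridden; the default backends already include Triton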


# adapted from: https://github.com/mit-han-lab/llm-awq/blob/main/awq/entry.py#L255
def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512):
@@ -93,7 +99,9 @@ def wiki2_eval(


# adapted from Hicham Badri (@mobicham)
def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
def benchmark(
model, tokenizer, max_length, tasks=None, evaluation_limit=None, device="cuda"
):
import lm_eval
import numpy as np

@@ -103,7 +111,7 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
lm_eval.tasks.initialize_tasks()
except:
pass
model_eval = lm_eval.models.huggingface.HFLM(pretrained=model, tokenizer=tokenizer)
model_eval = HFLM(pretrained=model, tokenizer=tokenizer)
eval_batch_size = 1 # 8
if tasks is None:
tasks = [
@@ -126,21 +134,33 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
for task in [("truthfulqa_mc2", 0)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])
if "winogrande" in tasks:
for task in [("winogrande", 5)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])
if "arc_challenge" in tasks:
for task in [("arc_challenge", 25)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])

@@ -149,14 +169,22 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
for task in [("hellaswag", 10)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])
if "gsm8k" in tasks:
for task in [("gsm8k", 5)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])
# ############################################
@@ -167,7 +195,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
for task in [("mmlu", 5)]:
tag, fewshot = task
results_mmlu[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results_mmlu[tag])

@@ -188,7 +220,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
for task in [("leaderboard_bbh", 3)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])
results["bbh"] = results[tag]
@@ -202,7 +238,7 @@ def quantize_and_eval(
tasks: list[str],
max_seq_length: int,
calibration_limit: int,
validation_size: int,
evaluation_limit: int,
device: str,
precision: torch.dtype,
compile: bool,
@@ -223,10 +259,15 @@
if quant.startswith("awq-int4wo"):
group_size = int(quant.split("-")[2])
print(f"running {quant} quantization with group size {group_size}")
# TODO: this is temporary, we'll be using Int4WeightOnlyConfig soon
from torchao.quantization import Int4WeightOnlyConfig

base_config = Int4WeightOnlyConfig(group_size=group_size)
if device == "cuda":
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, packing_format="opaque", version=2
)
else:
assert False, "Unsupported device: {}".format(device)

Review comment: I am not very familiar with the concept here, could you explain why CPU needs the opaque packing_format?

Reply (Contributor): It's because packing_format describes a fixed format for how the quantized weight data are laid out in memory, but int4 on CPU has a format that depends on the specific hardware/tensor shapes etc.:

We use AVX512 to compute TINYGEMM on CPU. We can also leverage AVX512_VNNI and AMX instructions with torch.compile and max-autotune.
For data locality, we preshuffle the data in plain layout (N, K/2) to (N/block_n, K, block_n/2), where block_n = 64/32/16.
See https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583 for more details.
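
For intuition, the preshuffle described above can be sketched with plain tensor ops (an illustration of the layout change only, not the actual int4mm kernel; the nibble order and block_n value are assumptions):

import torch

N, K, block_n = 128, 64, 32  # block_n is 64/32/16 depending on hardware; 32 assumed here

# plain layout: two int4 values packed per byte along K -> shape (N, K/2)
packed_plain = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8)

# unpack to individual 4-bit values -> (N, K)
lo = packed_plain & 0x0F
hi = packed_plain >> 4
unpacked = torch.stack([lo, hi], dim=-1).reshape(N, K)

# regroup rows into blocks of block_n and move K to the middle dimension
blocked = unpacked.reshape(N // block_n, block_n, K).permute(0, 2, 1)  # (N/block_n, K, block_n)

# repack pairs of 4-bit values along the last dimension -> (N/block_n, K, block_n/2)
pairs = blocked.reshape(N // block_n, K, block_n // 2, 2)
packed_opaque = (pairs[..., 0] | (pairs[..., 1] << 4)).to(torch.uint8)

assert packed_opaque.shape == (N // block_n, K, block_n // 2)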

print(f"running {quant} prepare and calibrate")
t0 = time.time()
quant_config = AWQConfig(base_config, step="prepare")
@@ -254,14 +295,21 @@ def quantize_and_eval(
quantize_(model, quant_config)
print(f"time for convert: {time.time() - t0:.02f} seconds")
quant_config = AWQConfig(base_config, step="prepare_for_loading")
model.config.quantization_config = TorchAoConfig(quant_config)
#model.config.quantization_config = TorchAoConfig(quant_config)
Review comment (Collaborator): Is this change needed?

Reply (Contributor Author): No, I updated and removed this change.


elif quant.startswith("int4wo"):
group_size = int(quant.split("-")[1])
print(f"running {quant} quantization with group size {group_size}")
# TODO: enable after migration: https://github.com/pytorch/ao/issues/2752
# use_hqq = "hqq" in quant
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
if device == "cuda":
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, packing_format="opaque", version=2
Review comment (Contributor): version=2 can be removed now, it's the default now

Reply (Contributor Author): Thanks, I updated and removed version=2

Review comment (Collaborator): Please remove the version=2 here since it's the default.

Reply (Contributor Author): OK, I removed version=2 on the CPU path

)
else:
assert False, "Unsupported device: {}".format(device)
quantize_(model, base_config)
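
Following the review exchange above, the CPU branch in the final revision presumably drops version=2 (now the default), e.g. (an assumption, not this diff):

base_config = Int4WeightOnlyConfig(group_size=group_size, packing_format="opaque")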

if model_save_path is not None:
@@ -276,8 +324,14 @@ def quantize_and_eval(
if compile:
model = torch.compile(model)

return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device)

return benchmark(
model,
tokenizer,
max_seq_length,
tasks=tasks,
evaluation_limit=evaluation_limit,
device=device,
)
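
Hypothetical call showing the new evaluation_limit argument flowing into benchmark() (values are made up; model and tokenizer are whatever quantize_and_eval built above):

results = benchmark(
    model,
    tokenizer,
    max_length=2048,
    tasks=["hellaswag"],
    evaluation_limit=100,  # forwarded to lm_eval's simple_evaluate(limit=...); None runs the full task
    device="cpu",
)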

if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -295,8 +349,8 @@ def quantize_and_eval(
"--tasks",
nargs="+",
type=str,
help="Task to benchmark model on. Either PPL or QA",
default=["PPL"],
help="Task to benchmark model on. Here is the list of tasks you can use: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md",
default=["hellaswag"],
)
parser.add_argument(
"--calibration_limit",
@@ -305,7 +359,10 @@ def quantize_and_eval(
help="Number of samples to use for calibration. Default is 10.",
)
parser.add_argument(
"--validation_size", type=int, default=1, help="Validation size. Default is 1."
"--evaluation_limit",
type=int,
default=None,
help="Number of samples to use for evaluation. Default is None (all).",
)
parser.add_argument(
"--device",
@@ -353,7 +410,7 @@ def quantize_and_eval(
args.tasks,
args.max_seq_length,
args.calibration_limit,
args.validation_size,
args.evaluation_limit,
args.device,
args.precision,
args.compile,