Support Int4OpaqueTensor for AWQ #2997
@@ -20,6 +20,7 @@
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.utils import (
    torch_version_at_least,
)

@@ -28,7 +29,8 @@
def get_config(group_size):
    return Int4WeightOnlyConfig(
        group_size=group_size,
        int4_packing_format="opaque",
        packing_format="opaque",
        version=2,
    )

@@ -75,7 +77,23 @@ def test_module_path(self, dtype):
            str(type(state_dict["weight"])),
            "<class 'torchao.quantization.Int4OpaqueTensor'>",
        )

    def test_activation_prescaling(self):
        dtype = torch.bfloat16
        input = torch.randn(1, 128, dtype=dtype)
        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype)
        original = linear(input)
        quantize_(linear, get_config(group_size=128))
        qw = linear.weight
        assert isinstance(qw, SupportsActivationPreScaling), (
            "Expected int4 tensor supports activation prescaling"
        )
        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
        _ACT_PRE_SCALE = 2
        qw.act_pre_scale = _ACT_PRE_SCALE
        quantized = linear(input)

        # making sure activation pre scaling is successfully applied to the activation
        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)


instantiate_parametrized_tests(TestInt4OpaqueTensor)
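For context, a minimal sketch of what the new test above checks, under the assumption that a weight tensor implementing SupportsActivationPreScaling simply multiplies the activation by `act_pre_scale` before the quantized matmul; the helper below is illustrative, not torchao's implementation:

    import torch
    import torch.nn.functional as F

    def linear_with_act_pre_scale(x, weight, act_pre_scale=None):
        # illustrative only: scale the activation first, as `qw.act_pre_scale` does above
        if act_pre_scale is not None:
            x = x * act_pre_scale
        return F.linear(x, weight)

    x = torch.randn(1, 128)
    w = torch.randn(256, 128)
    # with act_pre_scale = 2 the output tracks 2 * original, which is what
    # compute_error(original * _ACT_PRE_SCALE, quantized) > 20 asserts (up to int4 error)
    assert torch.allclose(
        linear_with_act_pre_scale(x, w, act_pre_scale=2), 2 * F.linear(x, w), rtol=1e-4
    )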
@@ -16,8 +16,14 @@
)
from torchao.quantization import (
    quantize_,
    Int4WeightOnlyConfig
)

from torch._inductor import config as inductor_config
import lm_eval
from lm_eval.models.huggingface import HFLM

inductor_config.cpp_wrapper = True
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"


# adapted from: https://github.com/mit-han-lab/llm-awq/blob/main/awq/entry.py#L255
def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512):
@@ -93,7 +99,9 @@ def wiki2_eval(


# adapted from Hicham Badri (@mobicham)
def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
def benchmark(
    model, tokenizer, max_length, tasks=None, evaluation_limit=None, device="cuda"
):
    import lm_eval
    import numpy as np

@@ -103,7 +111,7 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
        lm_eval.tasks.initialize_tasks()
    except:
        pass
    model_eval = lm_eval.models.huggingface.HFLM(pretrained=model, tokenizer=tokenizer)
    model_eval = HFLM(pretrained=model, tokenizer=tokenizer)
    eval_batch_size = 1  # 8
    if tasks is None:
        tasks = [
@@ -126,21 +134,33 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
        for task in [("truthfulqa_mc2", 0)]:
            tag, fewshot = task
            results[tag] = lm_eval.evaluator.simple_evaluate(
                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
                model_eval,
                tasks=[tag],
                num_fewshot=fewshot,
                batch_size=eval_batch_size,
                limit=evaluation_limit,
            )["results"]
            print(tag, results[tag])
    if "winogrande" in tasks:
        for task in [("winogrande", 5)]:
            tag, fewshot = task
            results[tag] = lm_eval.evaluator.simple_evaluate(
                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
                model_eval,
                tasks=[tag],
                num_fewshot=fewshot,
                batch_size=eval_batch_size,
                limit=evaluation_limit,
            )["results"]
            print(tag, results[tag])
    if "arc_challenge" in tasks:
        for task in [("arc_challenge", 25)]:
            tag, fewshot = task
            results[tag] = lm_eval.evaluator.simple_evaluate(
                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
                model_eval,
                tasks=[tag],
                num_fewshot=fewshot,
                batch_size=eval_batch_size,
                limit=evaluation_limit,
            )["results"]
            print(tag, results[tag])
@@ -149,14 +169,22 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
        for task in [("hellaswag", 10)]:
            tag, fewshot = task
            results[tag] = lm_eval.evaluator.simple_evaluate(
                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
                model_eval,
                tasks=[tag],
                num_fewshot=fewshot,
                batch_size=eval_batch_size,
                limit=evaluation_limit,
            )["results"]
            print(tag, results[tag])
    if "gsm8k" in tasks:
        for task in [("gsm8k", 5)]:
            tag, fewshot = task
            results[tag] = lm_eval.evaluator.simple_evaluate(
                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
                model_eval,
                tasks=[tag],
                num_fewshot=fewshot,
                batch_size=eval_batch_size,
                limit=evaluation_limit,
            )["results"]
            print(tag, results[tag])
    # ############################################
@@ -167,7 +195,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
        for task in [("mmlu", 5)]:
            tag, fewshot = task
            results_mmlu[tag] = lm_eval.evaluator.simple_evaluate(
                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
                model_eval,
                tasks=[tag],
                num_fewshot=fewshot,
                batch_size=eval_batch_size,
                limit=evaluation_limit,
            )["results"]
            print(tag, results_mmlu[tag])
@@ -188,7 +220,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
        for task in [("leaderboard_bbh", 3)]:
            tag, fewshot = task
            results[tag] = lm_eval.evaluator.simple_evaluate(
                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
                model_eval,
                tasks=[tag],
                num_fewshot=fewshot,
                batch_size=eval_batch_size,
                limit=evaluation_limit,
            )["results"]
            print(tag, results[tag])
        results["bbh"] = results[tag]
@@ -202,7 +238,7 @@ def quantize_and_eval(
    tasks: list[str],
    max_seq_length: int,
    calibration_limit: int,
    validation_size: int,
    evaluation_limit: int,
    device: str,
    precision: torch.dtype,
    compile: bool,
@@ -223,10 +259,15 @@ def quantize_and_eval(
    if quant.startswith("awq-int4wo"):
        group_size = int(quant.split("-")[2])
        print(f"running {quant} quantization with group size {group_size}")
        # TODO: this is temporary, we'll be using Int4WeightOnlyConfig soon
        from torchao.quantization import Int4WeightOnlyConfig

        base_config = Int4WeightOnlyConfig(group_size=group_size)
        if device == "cuda":
            base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
        elif device == "cpu":
            base_config = Int4WeightOnlyConfig(
                group_size=group_size, packing_format="opaque", version=2
            )
        else:
            assert False, "Unsupported device: {}".format(device)
Reviewer comment: I am not very familiar with the concept here, could you explain why CPU needs "opaque"?

Author reply: It's because packing_format describes a fixed format for how the quantized weight data are laid out in memory, but the int4 CPU layout depends on the specific hardware, tensor shapes, etc.
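To make that flow concrete, here is a hedged sketch of how the device-dependent base config from this diff plugs into the AWQ prepare/calibrate/convert steps it uses; the toy model, group size, calibration loop, and the torchao.prototype.awq import path are illustrative assumptions, not part of this PR:

    import torch
    from torchao.quantization import Int4WeightOnlyConfig, quantize_
    from torchao.prototype.awq import AWQConfig  # assumed import path

    device = "cpu"
    if device == "cuda":
        base_config = Int4WeightOnlyConfig(group_size=128, version=2)
    elif device == "cpu":
        # CPU int4 weights use a hardware/shape-dependent layout, hence "opaque"
        base_config = Int4WeightOnlyConfig(
            group_size=128, packing_format="opaque", version=2
        )

    model = torch.nn.Sequential(torch.nn.Linear(128, 128, bias=False)).to(torch.bfloat16)

    # step 1: wrap linears with AWQ observers
    quantize_(model, AWQConfig(base_config, step="prepare"))
    # step 2: run calibration data through the model to record activation statistics
    for _ in range(4):
        model(torch.randn(2, 128, dtype=torch.bfloat16))
    # step 3: compute AWQ scales and swap in the int4 weight
    # (an Int4OpaqueTensor when packing_format="opaque" on CPU)
    quantize_(model, AWQConfig(base_config, step="convert"))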
        print(f"running {quant} prepare and calibrate")
        t0 = time.time()
        quant_config = AWQConfig(base_config, step="prepare")
@@ -254,14 +295,21 @@ def quantize_and_eval(
        quantize_(model, quant_config)
        print(f"time for convert: {time.time() - t0:.02f} seconds")
        quant_config = AWQConfig(base_config, step="prepare_for_loading")
        model.config.quantization_config = TorchAoConfig(quant_config)
        #model.config.quantization_config = TorchAoConfig(quant_config)

    elif quant.startswith("int4wo"):
        group_size = int(quant.split("-")[1])
        print(f"running {quant} quantization with group size {group_size}")
        # TODO: enable after migration: https://github.com/pytorch/ao/issues/2752
        # use_hqq = "hqq" in quant
        base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
        if device == "cuda":
            base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
        elif device == "cpu":
            base_config = Int4WeightOnlyConfig(
                group_size=group_size, packing_format="opaque", version=2
            )
        else:
            assert False, "Unsupported device: {}".format(device)
        quantize_(model, base_config)

    if model_save_path is not None:
@@ -276,8 +324,14 @@ def quantize_and_eval(
    if compile:
        model = torch.compile(model)

    return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device)
    return benchmark(
        model,
        tokenizer,
        max_seq_length,
        tasks=tasks,
        evaluation_limit=evaluation_limit,
        device=device,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
@@ -295,8 +349,8 @@ def quantize_and_eval(
        "--tasks",
        nargs="+",
        type=str,
        help="Task to benchmark model on. Either PPL or QA",
        default=["PPL"],
        help="Task to benchmark model on. Here is the list of tasks you can use: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md",
        default=["hellaswag"],
    )
    parser.add_argument(
        "--calibration_limit",
@@ -305,7 +359,10 @@ def quantize_and_eval(
        help="Number of samples to use for calibration. Default is 10.",
    )
    parser.add_argument(
        "--validation_size", type=int, default=1, help="Validation size. Default is 1."
        "--evaluation_limit",
        type=int,
        default=None,
        help="Number of samples to use for evaluation. Default is None (all).",
    )
    parser.add_argument(
        "--device",
@@ -353,7 +410,7 @@ def quantize_and_eval(
        args.tasks,
        args.max_seq_length,
        args.calibration_limit,
        args.validation_size,
        args.evaluation_limit,
        args.device,
        args.precision,
        args.compile,