22 changes: 14 additions & 8 deletions neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -37,6 +37,7 @@ def _is_auto_round_available():
from auto_round.export.export_to_itrex.export import pack_model # pylint: disable=E0401
from auto_round.mllm import lmms_eval, mllm_eval
from auto_round.mllm.template import Template, get_template
from auto_round.schemes import QuantizationScheme

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import get_accelerator, logger
@@ -53,7 +54,7 @@ def __init__(
enable_full_range: bool = False, ##for symmetric, TODO support later
batch_size: int = 8,
amp: bool = True,
device: str = None,
device_map: str = None,
lr_scheduler=None,
dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
enable_quanted_input: bool = True,
@@ -91,6 +92,8 @@ def __init__(
processor=None,
template: Union[str, Template] = None,
truncation: bool = False,
# v0.7
scheme: Union[str, dict, QuantizationScheme] = "W4A16",
**kwargs,
):
"""Init a AutQRoundQuantizer object.
Expand Down Expand Up @@ -122,7 +125,7 @@ def __init__(
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True).
device: The device to be used for tuning (default is "auto").
device_map: The device to be used for tuning (default is None).
lr_scheduler: The learning rate scheduler to be used.
dataset (str): The default dataset name (default is "NeelNanda/pile-10k").
enable_quanted_input (bool): Whether to use the output of the previous quantized block as
@@ -161,6 +164,7 @@ def __init__(
image_processor (Processor): Image processor for special model like llava.
template (Template): The template to specify process for different mllms.
truncation (bool): Activates truncation to cut input sequences longer than `max_length` to `max_length`.
scheme (str | dict | QuantizationScheme): A preset scheme that defines the quantization configurations.

Returns:
The quantized model.
Expand Down Expand Up @@ -205,6 +209,8 @@ def __init__(
self.image_processor = image_processor
self.template = template
self.truncation = truncation
self.scheme = scheme
self.device_map = device_map
self.enable_w4afp8 = self._is_w4afp8()

def _is_w4afp8(self):
@@ -237,12 +243,13 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
rounder = AutoRoundMLLM(
model,
tokenizer=self.tokenizer,
scheme=self.scheme,
processor=self.processor,
image_processor=self.image_processor,
layer_config=self.quant_config,
batch_size=self.batch_size,
amp=self.amp,
device=self.device,
device_map=self.device_map,
lr_scheduler=self.lr_scheduler,
dataset=dataloader,
extra_data_dir=self.extra_data_dir,
@@ -278,12 +285,13 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
rounder = AutoRound(
model=model,
tokenizer=self.tokenizer,
scheme=self.scheme,
dataset=dataloader,
layer_config=self.quant_config or {},
enable_full_range=self.enable_full_range,
batch_size=self.batch_size,
amp=self.amp,
device=self.device,
device_map=self.device_map,
lr_scheduler=self.lr_scheduler,
enable_quanted_input=self.enable_quanted_input,
enable_minmax_tuning=self.enable_minmax_tuning,
@@ -317,7 +325,7 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
elif "itrex" in self.export_format:
model = pack_model(model, weight_config, device=self.device, inplace=True)
else: # pragma: no cover
model = rounder.save_quantized(output_dir=None, format=self.export_format, device=self.device, inplace=True)
model = rounder.save_quantized(output_dir="temp_auto_round", format=self.export_format, inplace=True)

return model

@@ -341,9 +349,7 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42
"""
from auto_round.calib_dataset import get_dataloader # pylint: disable=E0401

dataloader = get_dataloader(
tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, nsamples=nsamples
)
dataloader = get_dataloader(tokenizer, seqlen, dataset_name=dataset_name, seed=seed, bs=bs, nsamples=nsamples)
return dataloader


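The last hunk above fixes get_dataloader so that the caller's dataset_name is actually forwarded to auto_round's calibration loader instead of being hard-coded to "NeelNanda/pile-10k". A minimal usage sketch under that assumption follows; the tokenizer checkpoint and the argument values are illustrative only and not part of this PR.

from transformers import AutoTokenizer

from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader

# Any causal-LM tokenizer works for calibration; this checkpoint is illustrative.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")

# With the fix, a non-default dataset_name is honored rather than silently
# replaced by "NeelNanda/pile-10k".
dataloader = get_dataloader(
    tokenizer,
    seqlen=128,
    dataset_name="NeelNanda/pile-10k",  # swap in a custom calibration dataset here
    seed=42,
    bs=8,
    nsamples=128,
)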
4 changes: 4 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -629,6 +629,8 @@ def autoround_quantize_entry(
image_processor = quant_config.image_processor
template = quant_config.template
truncation = quant_config.truncation
scheme = quant_config.scheme
device_map = quant_config.device_map

kwargs.pop("example_inputs")
quantizer = get_quantizer(
@@ -666,6 +668,8 @@ def autoround_quantize_entry(
image_processor=image_processor,
template=template,
truncation=truncation,
scheme=scheme,
device_map=device_map,
)
model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
model.qconfig = configs_mapping
7 changes: 7 additions & 0 deletions neural_compressor/torch/quantization/config.py
@@ -971,6 +971,9 @@ def __init__(
# v0.4
enable_norm_bias_tuning: bool = False,
enable_torch_compile: bool = None,
# v0.7
scheme: str | dict = "W4A16",
device_map: str = None,
# mllm
is_mllm: bool = False,
quant_nontext_module: bool = False,
@@ -1029,6 +1032,8 @@ def __init__(
image_processor (Processor): Image processor for special model like llava.
template (Template): The template to specify process for different mllms.
truncation (bool): Activates truncation to cut input sequences longer than `max_length` to `max_length`.
device_map: The device to be used for tuning.
scheme (str | dict | QuantizationScheme): A preset scheme that defines the quantization configurations.
white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): White list of operator names or module types.
Default is DEFAULT_WHITE_LIST.
"""
@@ -1073,6 +1078,8 @@ def __init__(
self.image_processor = image_processor
self.template = template
self.truncation = truncation
self.scheme = scheme
self.device_map = device_map
self._post_init()

@classmethod
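For context, a hedged sketch of driving the new scheme and device_map options through the 3.x prepare/convert flow is shown below; it mirrors the test_scheme test added further down, and the model checkpoint, the device_map value, and the calibration loop are assumptions rather than part of this PR. The llm_compressor export format also relies on the compressed-tensors package added to the test requirements below.

import torch
import transformers

from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"  # illustrative checkpoint
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

# New in this PR: a preset quantization scheme and device_map (replacing device).
quant_config = AutoRoundConfig(
    nsamples=32,
    seqlen=10,
    iters=10,
    amp=False,
    scale_dtype="fp16",
    scheme="MXFP4",  # also "NVFP4", "W4A16" (default), a dict, or a QuantizationScheme
    device_map=None,  # e.g. "cpu"; forwarded to AutoRound as device_map
    export_format="llm_compressor",
)

model = prepare(model=model, quant_config=quant_config)

# Feed a few calibration batches through the prepared model.
dataloader = get_dataloader(tokenizer, seqlen=10, nsamples=32)
with torch.no_grad():
    for data in dataloader:
        if isinstance(data, dict):
            model(**data)
        else:
            model(data)

q_model = convert(model)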
38 changes: 35 additions & 3 deletions test/3x/torch/quantization/weight_only/test_autoround.py
@@ -4,7 +4,7 @@
import pytest
import torch
import transformers
from packaging.version import Version
from packaging.version import Version, parse
import os
from functools import lru_cache

@@ -57,6 +57,13 @@ def set_hpu_torch_compile_envs():
auto_round_installed = True
except ImportError:
auto_round_installed = False

try:
import compressed_tensors

ct_installed = True
except ImportError:
ct_installed = False


@torch.no_grad()
@@ -247,7 +254,7 @@ def test_mllm(self):
seed=42,
nsamples=1,
gradient_accumulate_steps=1,
quant_nontext_module=False,
quant_nontext_module=True,
processor=processor,
)
quant_config = AutoRoundConfig(
@@ -258,7 +265,7 @@
batch_size=batch_size,
iters=1,
seqlen=seqlen,
quant_nontext_module=False,
quant_nontext_module=True,
truncation=truncation,
gradient_accumulate_steps=gradient_accumulate_steps,
)
@@ -283,6 +290,31 @@ def test_mllm(self):
# q_model.save(output_dir="saved_results_tiny-random-GPTJForCausalLM", format="huggingface")
# loaded_model = load("saved_results_tiny-random-GPTJForCausalLM", format="huggingface", trust_remote_code=True)

@pytest.mark.skipif(parse(auto_round.__version__) <= parse("0.7.0"),
reason="Export with llm_compressor format does not return a model.")
@pytest.mark.skipif(not ct_installed, reason="The compressed-tensors module is not installed.")
@pytest.mark.parametrize("scheme", ["MXFP4", "NVFP4"])
def test_scheme(self, scheme):
fp32_model = copy.deepcopy(self.gptj)
quant_config = AutoRoundConfig(
nsamples=32,
seqlen=10,
iters=10,
amp=False,
scale_dtype="fp16",
scheme=scheme,
export_format="llm_compressor",
)
logger.info(f"Test AutoRound with config {quant_config}")

# quantizer execute
model = prepare(model=fp32_model, quant_config=quant_config)
run_fn(model, self.dataloader)
q_model = convert(model)
out = q_model(self.inp)[0]
assert q_model is not None, "Quantization failed!"
assert q_model.transformer.h[0].attn.k_proj.bits == 4
assert torch.allclose(out, self.label, atol=1e-1)

@pytest.mark.skipif(not is_habana_framework_installed(), reason="Habana framework is not installed")
@pytest.mark.skipif(os.getenv("PT_HPU_LAZY_MODE", "0") == "1", reason="Lazy mode is enabled")
1 change: 1 addition & 0 deletions test/3x/torch/requirements.txt
@@ -1,4 +1,5 @@
auto_round
compressed-tensors
datasets
deepspeed @ git+https://github.com/HabanaAI/[email protected]
expecttest