From f784e50078561b78a11671cc88be7c5d5fb949c3 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 11 Sep 2025 11:50:12 -0700 Subject: [PATCH 1/8] up --- .../dynamic_activation_lut/api.py | 94 ++--- .../int8_dynamic_activation_lut_tensor.py | 354 +++++++++--------- 2 files changed, 203 insertions(+), 245 deletions(-) diff --git a/torchao/prototype/quantization/dynamic_activation_lut/api.py b/torchao/prototype/quantization/dynamic_activation_lut/api.py index bccbc80a1c..c6b12fc4b9 100644 --- a/torchao/prototype/quantization/dynamic_activation_lut/api.py +++ b/torchao/prototype/quantization/dynamic_activation_lut/api.py @@ -4,80 +4,30 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import Callable - import torch import torch.nn as nn -from torchao.core.config import AOBaseConfig -from torchao.prototype.parq.quant.quant_api import StretchedAffineQuantizedTensor from torchao.prototype.quantization.dynamic_activation_lut.int8_dynamic_activation_lut_tensor import ( - Int8DynamicActivationLutTensor, -) -from torchao.quantization.granularity import Granularity, PerAxis, PerGroup -from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS -from torchao.quantization.transform_module import register_quantize_module_handler - - -@dataclass -class StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig( - AOBaseConfig -): - bit_width: int - granularity: Granularity - - def get_filter_fn(self) -> Callable[[nn.Module, str], bool]: - return lambda m, fqn: isinstance(m, torch.nn.Linear) and isinstance( - m.weight, StretchedAffineQuantizedTensor - ) - - -@register_quantize_module_handler( - StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig + Int8LutTensor, ) -def _( - module: nn.Module, - config: StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig, -) -> nn.Module: - weight = module.weight - bias = module.bias - assert isinstance(weight, StretchedAffineQuantizedTensor) - - b = config.bit_width - granularity = config.granularity - if isinstance(granularity, PerGroup): - group_size = granularity.group_size - elif isinstance(granularity, PerAxis): - assert granularity.axis == 0, ( - f"axis must be 0 with PerAxis, but got {granularity.axis}" - ) - group_size = weight.shape[-1] - else: - raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}") - - int_data, scale, zero_point = weight.tensor_impl.get_plain() - q_min, q_max = _DTYPE_TO_QVALUE_BOUNDS[getattr(torch, f"int{b}")] - - # Construct LUT as 2 * ([q_min, q_max] - 0.5) - assert torch.all(zero_point == -0.5) - lut = torch.arange(q_min, q_max + 1) - lut = 2 * lut + 1 - - # Construct idx values - qval_idx = int_data - q_min - - # Construct scale - scale = scale.reshape(-1).to(torch.float32) - scale = 0.5 * scale # since we multiply LUT values by 2 - - weight_tensor = Int8DynamicActivationLutTensor.from_plain( - qval_idx, - lut, - scale, - group_size, - bias.to(torch.float32) if bias is not None else None, - ) - module.weight = torch.nn.Parameter(weight_tensor, requires_grad=False) - module.bias = None - return module +from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor + + +def convert_model(model): + # Iterate through modules in model and convert IntxUnpackedToInt8Tensor tensors to Int8LutTensor + for name, module in model.named_modules(): + if isinstance(module, 
nn.Linear): + weight = module.weight + if isinstance(weight, IntxUnpackedToInt8Tensor): + try: + new_weight = Int8LutTensor.from_intx_unpacked_to_int8_tensor( + weight, bias=module.bias + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.bias = None + except Exception as e: + print( + f"Failed to convert {name} to Int8LutTensor. Skipping. The exception was: {e}" + ) + continue + return model diff --git a/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py b/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py index a15ea944fd..4274237907 100644 --- a/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py +++ b/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py @@ -3,230 +3,238 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Optional, Tuple import torch -from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS +from torchao.quantization.quant_primitives import ( + _DTYPE_TO_BIT_WIDTH, + _DTYPE_TO_QVALUE_BOUNDS, +) +from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + _is_kernel_library_loaded, +) +from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import ( + IntxUnpackedToInt8Tensor, + IntxUnpackedToInt8TensorActivationQuantization, +) from torchao.utils import TorchAOBaseTensor aten = torch.ops.aten -class Int8DynamicActivationLutTensor(TorchAOBaseTensor): +class Int8LutTensor(TorchAOBaseTensor): """ - Tensor subclass that applies int8 dynamic activation quantization with lookup table quantization - - Args: - original_weight_tensor (torch.Tensor): The weight tensor to be wrapped. - scale (torch.Tensor): The scale tensor to be applied to activation. 
+ Tensor subclass that does int8 dynamic activation quantization with lookup table quantization """ - packed_weight: torch.Tensor + tensor_data_names = ["packed_weights"] + tensor_attribute_names = [ + "bit_width", + "block_size", + "shape", + "dtype", + "packed_weights_has_bias", + ] + + packed_weights: torch.Tensor original_shape: Tuple[int, int] weight_scale_group_size: int bit_width: int def __new__( cls, - packed_weight: torch.Tensor, - original_shape: Tuple[int, int], - weight_scale_group_size: int, - bit_width: int, + packed_weights, + bit_width, + block_size, + shape, + dtype, + packed_weights_has_bias, ): kwargs = {} - kwargs["dtype"] = torch.float32 + kwargs["device"] = packed_weights.device + kwargs["dtype"] = dtype kwargs["requires_grad"] = False - kwargs["device"] = packed_weight.device - return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs) # type: ignore[attr-defined] + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( self, - packed_weight: torch.Tensor, - original_shape: Tuple[int, int], - weight_scale_group_size, - bit_width: int, + packed_weights, + bit_width, + block_size, + shape, + dtype, + packed_weights_has_bias, ): - self.packed_weight = packed_weight - self.original_shape = original_shape - self.weight_scale_group_size = weight_scale_group_size + super().__init__() + assert packed_weights.device == torch.device("cpu") + self.packed_weights = packed_weights self.bit_width = bit_width + self.block_size = block_size + self.shape = shape + self.dtype = dtype + self.packed_weights_has_bias = packed_weights_has_bias + + def _quantization_type(self): + return f"bit_width={self.bit_width}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}" + + def to(self, *args, **kwargs): + raise NotImplementedError("to() is not implemented for IntxOpaqueTensor") + + @classmethod + def _get_lut_params(cls, tensor: IntxUnpackedToInt8Tensor): + assert isinstance(tensor, IntxUnpackedToInt8Tensor) + assert tensor.target_dtype in [torch.int1, torch.int2, torch.int3, torch.int4] + + qdata = tensor.qdata + scale = tensor.scale + zero_point = tensor.zero_point + + if tensor._has_float_zero_point: + # Stretched tensors from PARQ should have -0.5 has zero_point + assert torch.all(zero_point == -0.5) + is_stretched_tensor = True + else: + assert torch.all(zero_point == 0) + is_stretched_tensor = False + + quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[tensor.target_dtype] + lut_indices = qdata - quant_min + lut = torch.arange(quant_min, quant_max + 1) + + # Construct LUT as 2 * ([q_min, q_max] - 0.5) + if is_stretched_tensor: + lut = 2 * lut + 1 + scale = 0.5 * scale + + # Scale must be float32 + 1D + scale = scale.reshape(-1).to(torch.float32) + + return lut, lut_indices, scale @classmethod - def from_plain( + def from_intx_unpacked_to_int8_tensor( cls, - weight_indices: torch.Tensor, - weight_luts: torch.Tensor, - weight_scale: torch.Tensor, - weight_scale_group_size: int, - bias, + tensor: IntxUnpackedToInt8Tensor, + *, + bias: Optional[torch.Tensor] = None, ): - if len(weight_luts.shape) == 1: - weight_luts = weight_luts.unsqueeze(0) - assert len(weight_luts.shape) == 2, ( - "Expected weight_luts to be 2D tensor. 
Each row in the tensor is an LUT" - ) - bit_width = {2**b: b for b in range(1, 5)}[weight_luts.shape[1]] - - int8_min, int8_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8] - assert torch.all(weight_luts >= int8_min) - assert torch.all(weight_luts <= int8_max) - weight_luts = weight_luts.to(torch.int8) - - n, k = weight_indices.shape - # assert n % 8 == 0, f"Expected n to be divisible by 8, but got n={n}" - assert k % 16 == 0, f"Expected k to be divisible by 16, but got k={k}" - assert torch.all(weight_indices >= 0) - assert torch.all(weight_indices < 2**bit_width) - - weight_scale = weight_scale.reshape(-1) - assert k % weight_scale_group_size == 0, ( - f"Expected k to be divisible by weight_scale_group_size, but got k={k} and weight_scale_group_size={weight_scale_group_size}" + """ + Constructs a Int8LutTensor from an IntxUnpackedToInt8Tensor. + If bias is passed, bias is packed into the tensor. + """ + + assert _is_kernel_library_loaded(), "TorchAO kernel library is not loaded" + assert ( + tensor.activation_quantization + == IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN ) - assert weight_scale.shape == (n * (k // weight_scale_group_size),) - if bias is not None: + assert len(tensor.block_size) == 2 + assert tensor.block_size[0] == 1 + scale_group_size = tensor.block_size[1] + + packed_weights_has_bias = bias is not None + if packed_weights_has_bias: + n, k = tensor.shape assert bias.shape == (n,) + bias = bias.to(torch.float32) - packed_weight = getattr( - torch.ops.torchao, f"_pack_8bit_act_{bit_width}bit_weight_with_lut" + lut, lut_indices, scale = cls._get_lut_params(tensor) + packed_weights = getattr( + torch.ops.torchao, f"_pack_8bit_act_{tensor.bit_width}bit_weight_with_lut" )( - weight_indices, - weight_luts, - weight_scale, - weight_scale_group_size, + lut_indices, + lut, + scale, + scale_group_size, bias, None, ) - return cls(packed_weight, (n, k), weight_scale_group_size, bit_width) - - def __repr__(self): - return "Int8DynamicActivationLutTensor" - def __tensor_flatten__(self): - return ["packed_weight"], [ - self.original_shape, - self.weight_scale_group_size, - self.bit_width, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight = tensor_data_dict["packed_weight"] - original_shape, weight_scale_group_size, bitwidth = tensor_attributes - return cls(packed_weight, original_shape, weight_scale_group_size, bitwidth) - - @staticmethod - def _quantized_linear_op( - input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor - ): - def _impl_2d( - input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor - ): - original_dtype = torch.float32 - if input_tensor.dtype != torch.float32: - original_dtype = input_tensor.dtype - input_tensor = input_tensor.to(torch.float32) - - assert input_tensor.dim() == 2 - m, k = input_tensor.shape - n, k_ = weight_tensor.original_shape - assert k == k_, ( - f"Incompatible input shape. 
Expected second dimension to be equal to {k_}, but got {k}" - ) - assert bias is None, ( - "Expected bias to be None because it should be packed with the weight tensor" - ) - out = getattr( - torch.ops.torchao, - f"_linear_8bit_act_{weight_tensor.bit_width}bit_weight", - )( - input_tensor, - weight_tensor.packed_weight, - weight_tensor.weight_scale_group_size, - n, - k, - ) - - if original_dtype != torch.float32: - out = out.to(original_dtype) - return out - - assert input_tensor.dim() >= 2 - if input_tensor.dim() == 2: - res = _impl_2d(input_tensor, weight_tensor, bias) - else: - assert input_tensor.dim() >= 3 - lead_shape = input_tensor.shape[0:-2] - m, k = input_tensor.shape[-2], input_tensor.shape[-1] - res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor, bias) - res = res.reshape(*lead_shape, m, -1) - - return res - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.packed_weight), - self.original_shape, - self.weight_scale_group_size, - self.bit_width, + bit_width = _DTYPE_TO_BIT_WIDTH[tensor.target_dtype] + return cls( + packed_weights, + bit_width, + tensor.block_size, + tensor.shape, + tensor.dtype, + packed_weights_has_bias, ) - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs.pop("device") - return self.__class__( - self.packed_weight.to(device), - self.original_shape, - self.weight_scale_group_size, - self.bit_width, - ) +implements = Int8LutTensor.implements -implements = Int8DynamicActivationLutTensor.implements +def _linear_impl_2d( + input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor +): + assert isinstance(weight_tensor, Int8LutTensor) + assert input_tensor.dim() == 2 + assert weight_tensor.dim() == 2 + assert weight_tensor.block_size[0] == 1 + group_size = weight_tensor.block_size[1] -@implements(torch.nn.functional.linear) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - if isinstance(weight_tensor, Int8DynamicActivationLutTensor): - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + m, k = input_tensor.shape + n, k_ = weight_tensor.shape + assert k_ == k - raise NotImplementedError( - "Int8DynamicActivationLutTensor: No specialized dispatch found for linear op" - ) + packed_weights = weight_tensor.packed_weights + bit_width = weight_tensor.bit_width + if weight_tensor.dtype != torch.float32: + input_tensor = input_tensor.to(torch.float32) -@implements(aten.detach.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + res = getattr( + torch.ops.torchao, + f"_linear_8bit_act_{bit_width}bit_weight", + )( + input_tensor, + packed_weights, + group_size, + n, + k, ) + if weight_tensor.dtype != torch.float32: + res = res.to(weight_tensor.dtype) + return res -@implements(aten.clone.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - -@implements(aten._to_copy.default) +@implements([torch.nn.functional.linear, aten.linear.default]) def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, ) - -# Allow a model with Int8DynamicActivationLutTensor weights to be loaded with 
`weights_only=True` -torch.serialization.add_safe_globals([Int8DynamicActivationLutTensor]) + # TODO: why was this added https://github.com/pytorch/ao/pull/2043 + if input_tensor.numel() == 0: + return input_tensor + + if input_tensor.dim() == 1: + k = input_tensor.shape[0] + input_tensor = input_tensor.reshape(1, k) + res = _linear_impl_2d(input_tensor, weight_tensor) + res = res.reshape(-1) + elif input_tensor.dim() == 2: + res = _linear_impl_2d(input_tensor, weight_tensor) + else: + assert input_tensor.dim() >= 3 + lead_shape = input_tensor.shape[0:-2] + m, k = input_tensor.shape[-2], input_tensor.shape[-1] + n, k_ = weight_tensor.shape + assert k_ == k + res = _linear_impl_2d(input_tensor.reshape(-1, k), weight_tensor) + res = res.reshape(*lead_shape, m, n) + + if bias is not None: + assert not weight_tensor.packed_weights_has_bias + res = res + bias + + return res + + +# Allow a model with Int8LutTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int8LutTensor]) From 05307a3aa62265efd88dfdf232134a5bad6a62cf Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 11 Sep 2025 12:38:30 -0700 Subject: [PATCH 2/8] up --- .github/workflows/regression_test_aarch64.yml | 2 +- ...ivation_lut.py => test_int8_lut_tensor.py} | 40 ++++--------------- .../dynamic_activation_lut/__init__.py | 7 ---- .../quantization/int8_lut_tensor/__init__.py | 7 ++++ .../api.py | 2 +- .../int8_lut_tensor.py} | 31 +++++++------- 6 files changed, 32 insertions(+), 57 deletions(-) rename test/prototype/{test_dynamic_activation_lut.py => test_int8_lut_tensor.py} (71%) delete mode 100644 torchao/prototype/quantization/dynamic_activation_lut/__init__.py create mode 100644 torchao/prototype/quantization/int8_lut_tensor/__init__.py rename torchao/prototype/quantization/{dynamic_activation_lut => int8_lut_tensor}/api.py (92%) rename torchao/prototype/quantization/{dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py => int8_lut_tensor/int8_lut_tensor.py} (90%) diff --git a/.github/workflows/regression_test_aarch64.yml b/.github/workflows/regression_test_aarch64.yml index a3ba86dd8b..10948fa61d 100644 --- a/.github/workflows/regression_test_aarch64.yml +++ b/.github/workflows/regression_test_aarch64.yml @@ -54,7 +54,7 @@ jobs: pytest -s test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py pytest -s test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py pytest -s test/prototype/test_embedding.py - pytest -s test/prototype/test_dynamic_activation_lut.py + pytest -s test/prototype/test_int8_lut_tensor.py pytest -s test/prototype/test_groupwise_lowbit_weight_lut_quantizer.py pytest -s test/prototype/test_parq.py - name: torchao/csrc/cpu - build and run C++ tests diff --git a/test/prototype/test_dynamic_activation_lut.py b/test/prototype/test_int8_lut_tensor.py similarity index 71% rename from test/prototype/test_dynamic_activation_lut.py rename to test/prototype/test_int8_lut_tensor.py index 497de519b5..40dfe4cdef 100644 --- a/test/prototype/test_dynamic_activation_lut.py +++ b/test/prototype/test_int8_lut_tensor.py @@ -15,12 +15,9 @@ StretchedIntxWeightConfig, StretchedUnifTorchaoQuantizer, ) -from torchao.prototype.quantization.dynamic_activation_lut import ( - StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig, -) +from torchao.prototype.quantization.int8_lut_tensor import convert_model from torchao.quantization import quantize_ from torchao.quantization.granularity import PerAxis, 
PerGroup -from torchao.quantization.quant_api import _is_linear from torchao.quantization.utils import compute_error is_arm64_mac = sys.platform == "darwin" and platform.machine() == "arm64" @@ -68,38 +65,22 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim): quant_min=quantizer.quant_min, quant_max=quantizer.quant_max, granularity=granularity, - activation_quantization=None, - version=1, + activation_quantization="int8_asym_per_token", ) parq_model = ToyLinearModel(128, 256, 128, 1).to(dtype) activations = parq_model.example_inputs(lead_dim=lead_dim, dtype=dtype) - parq_model_with_dyn_quant = deepcopy(parq_model) quantize_(parq_model, config) - # Apply dynamic activation to parq model. This will serve as the LUT reference - dyn_act_config = deepcopy(config) - dyn_act_config.activation_quantization = "int8_asym_per_token" - quantize_(parq_model_with_dyn_quant, dyn_act_config, filter_fn=_is_linear) - # Convert PARQ model to lowbit LUT model lut_model = deepcopy(parq_model) - conversion_config = ( - StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig( - config.b, config.granularity - ) - ) - quantize_(lut_model, conversion_config, filter_fn=conversion_config.get_filter_fn()) + convert_model(lut_model) # Run both models and compare parq_out = parq_model(activations) - parq_with_dyn_quant_out = parq_model_with_dyn_quant(activations) lut_out = lut_model(activations) - sqnr = compute_error(parq_out, parq_with_dyn_quant_out).item() - assert sqnr > 20.0, f"sqnr {sqnr} is too low" - - sqnr = compute_error(lut_out, parq_with_dyn_quant_out).item() + sqnr = compute_error(parq_out, lut_out).item() if dtype == torch.float32: assert sqnr > 40.0, f"sqnr {sqnr} is too low" elif dtype == torch.bfloat16: @@ -120,24 +101,17 @@ def test_export(dtype, granularity, bit_width, lead_dim): quant_min=quantizer.quant_min, quant_max=quantizer.quant_max, granularity=granularity, - activation_quantization=None, - version=1, + activation_quantization="int8_asym_per_token", ) parq_model = ToyLinearModel(128, 256, 128, 8).to(dtype) activations = parq_model.example_inputs(lead_dim=lead_dim) quantize_(parq_model, config) - conversion_config = ( - StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig( - config.b, config.granularity - ) - ) - quantize_( - parq_model, conversion_config, filter_fn=conversion_config.get_filter_fn() - ) + convert_model(parq_model) ep = torch.export.export(parq_model, (activations,)) + assert ( f"torch.ops.torchao._linear_8bit_act_{bit_width}bit_weight.default" in ep.graph_module.code diff --git a/torchao/prototype/quantization/dynamic_activation_lut/__init__.py b/torchao/prototype/quantization/dynamic_activation_lut/__init__.py deleted file mode 100644 index 688cb2e836..0000000000 --- a/torchao/prototype/quantization/dynamic_activation_lut/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .api import StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig -from .int8_dynamic_activation_lut_tensor import Int8DynamicActivationLutTensor - -__all__ = [ - "StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig", - "Int8DynamicActivationLutTensor", -] diff --git a/torchao/prototype/quantization/int8_lut_tensor/__init__.py b/torchao/prototype/quantization/int8_lut_tensor/__init__.py new file mode 100644 index 0000000000..aad3693bbd --- /dev/null +++ b/torchao/prototype/quantization/int8_lut_tensor/__init__.py @@ -0,0 +1,7 @@ +from .api import convert_model +from .int8_lut_tensor import Int8LutTensor + +__all__ = [ + 
"Int8LutTensor", + "convert_model", +] diff --git a/torchao/prototype/quantization/dynamic_activation_lut/api.py b/torchao/prototype/quantization/int8_lut_tensor/api.py similarity index 92% rename from torchao/prototype/quantization/dynamic_activation_lut/api.py rename to torchao/prototype/quantization/int8_lut_tensor/api.py index c6b12fc4b9..d49e04531d 100644 --- a/torchao/prototype/quantization/dynamic_activation_lut/api.py +++ b/torchao/prototype/quantization/int8_lut_tensor/api.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn -from torchao.prototype.quantization.dynamic_activation_lut.int8_dynamic_activation_lut_tensor import ( +from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import ( Int8LutTensor, ) from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor diff --git a/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py b/torchao/prototype/quantization/int8_lut_tensor/int8_lut_tensor.py similarity index 90% rename from torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py rename to torchao/prototype/quantization/int8_lut_tensor/int8_lut_tensor.py index 4274237907..2a01098f0f 100644 --- a/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py +++ b/torchao/prototype/quantization/int8_lut_tensor/int8_lut_tensor.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional, Tuple +from typing import Optional import torch @@ -37,11 +37,6 @@ class Int8LutTensor(TorchAOBaseTensor): "packed_weights_has_bias", ] - packed_weights: torch.Tensor - original_shape: Tuple[int, int] - weight_scale_group_size: int - bit_width: int - def __new__( cls, packed_weights, @@ -71,8 +66,6 @@ def __init__( self.packed_weights = packed_weights self.bit_width = bit_width self.block_size = block_size - self.shape = shape - self.dtype = dtype self.packed_weights_has_bias = packed_weights_has_bias def _quantization_type(self): @@ -107,7 +100,10 @@ def _get_lut_params(cls, tensor: IntxUnpackedToInt8Tensor): lut = 2 * lut + 1 scale = 0.5 * scale - # Scale must be float32 + 1D + # LUT must be 2D and int8 + lut = lut.reshape(1, -1).to(torch.int8) + + # Scale must be 1D and float32 scale = scale.reshape(-1).to(torch.float32) return lut, lut_indices, scale @@ -128,6 +124,8 @@ def from_intx_unpacked_to_int8_tensor( assert ( tensor.activation_quantization == IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN + ), ( + "IntxUnpackedToInt8Tensor must have INT8_ASYM_PER_TOKEN activation quantization" ) assert len(tensor.block_size) == 2 @@ -141,8 +139,9 @@ def from_intx_unpacked_to_int8_tensor( bias = bias.to(torch.float32) lut, lut_indices, scale = cls._get_lut_params(tensor) + bit_width = _DTYPE_TO_BIT_WIDTH[tensor.target_dtype] packed_weights = getattr( - torch.ops.torchao, f"_pack_8bit_act_{tensor.bit_width}bit_weight_with_lut" + torch.ops.torchao, f"_pack_8bit_act_{bit_width}bit_weight_with_lut" )( lut_indices, lut, @@ -152,12 +151,14 @@ def from_intx_unpacked_to_int8_tensor( None, ) + block_size = [b for b in tensor.block_size] + shape = tensor.shape bit_width = _DTYPE_TO_BIT_WIDTH[tensor.target_dtype] return cls( packed_weights, bit_width, - tensor.block_size, - tensor.shape, + block_size, + shape, tensor.dtype, packed_weights_has_bias, ) @@ -216,17 +217,17 @@ def _(func, types, args, kwargs): if input_tensor.dim() == 1: k = 
input_tensor.shape[0] input_tensor = input_tensor.reshape(1, k) - res = _linear_impl_2d(input_tensor, weight_tensor) + res = _linear_impl_2d(input_tensor, weight_tensor, bias) res = res.reshape(-1) elif input_tensor.dim() == 2: - res = _linear_impl_2d(input_tensor, weight_tensor) + res = _linear_impl_2d(input_tensor, weight_tensor, bias) else: assert input_tensor.dim() >= 3 lead_shape = input_tensor.shape[0:-2] m, k = input_tensor.shape[-2], input_tensor.shape[-1] n, k_ = weight_tensor.shape assert k_ == k - res = _linear_impl_2d(input_tensor.reshape(-1, k), weight_tensor) + res = _linear_impl_2d(input_tensor.reshape(-1, k), weight_tensor, bias) res = res.reshape(*lead_shape, m, n) if bias is not None: From c3b529cdab8721135a99938e391b6fc4d156d8c2 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:08:33 -0700 Subject: [PATCH 3/8] up --- .../quantization/int8_lut_tensor/api.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/torchao/prototype/quantization/int8_lut_tensor/api.py b/torchao/prototype/quantization/int8_lut_tensor/api.py index d49e04531d..c700b9bd68 100644 --- a/torchao/prototype/quantization/int8_lut_tensor/api.py +++ b/torchao/prototype/quantization/int8_lut_tensor/api.py @@ -18,16 +18,20 @@ def convert_model(model): for name, module in model.named_modules(): if isinstance(module, nn.Linear): weight = module.weight - if isinstance(weight, IntxUnpackedToInt8Tensor): - try: - new_weight = Int8LutTensor.from_intx_unpacked_to_int8_tensor( - weight, bias=module.bias - ) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.bias = None - except Exception as e: - print( - f"Failed to convert {name} to Int8LutTensor. Skipping. The exception was: {e}" - ) - continue + if not isinstance(weight, IntxUnpackedToInt8Tensor): + print( + f"Skipping converting {name} to Int8LutTensor because its weight is not an IntxUnpackedToInt8Tensor" + ) + continue + try: + new_weight = Int8LutTensor.from_intx_unpacked_to_int8_tensor( + weight, bias=module.bias + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.bias = None + except Exception as e: + print( + f"Skipping converting {name} to Int8LutTensor because an error occurred: {e}" + ) + continue return model From 01032ed4957fed855172090e4497ba00cba1a256 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:14:58 -0700 Subject: [PATCH 4/8] up --- test/prototype/test_int8_lut_tensor.py | 21 ++++---- torchao/prototype/conversion/__init__.py | 0 torchao/prototype/conversion/api.py | 48 +++++++++++++++++++ .../prototype/parq/quant/config_torchao.py | 7 +++ .../quantization/int8_lut_tensor/__init__.py | 2 - .../quantization/int8_lut_tensor/api.py | 37 -------------- .../int8_lut_tensor/int8_lut_tensor.py | 2 +- 7 files changed, 68 insertions(+), 49 deletions(-) create mode 100644 torchao/prototype/conversion/__init__.py create mode 100644 torchao/prototype/conversion/api.py delete mode 100644 torchao/prototype/quantization/int8_lut_tensor/api.py diff --git a/test/prototype/test_int8_lut_tensor.py b/test/prototype/test_int8_lut_tensor.py index 40dfe4cdef..ea467cd0ab 100644 --- a/test/prototype/test_int8_lut_tensor.py +++ b/test/prototype/test_int8_lut_tensor.py @@ -4,24 +4,23 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
-import platform -import sys from copy import deepcopy import pytest import torch +from torchao.prototype.conversion.api import _convert_to_optimized_model_for_aarch64 from torchao.prototype.parq.quant import ( StretchedIntxWeightConfig, StretchedUnifTorchaoQuantizer, ) -from torchao.prototype.quantization.int8_lut_tensor import convert_model +from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import ( + _is_kernel_library_loaded, +) from torchao.quantization import quantize_ from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.utils import compute_error -is_arm64_mac = sys.platform == "darwin" and platform.machine() == "arm64" - class ToyLinearModel(torch.nn.Module): def __init__(self, d1=512, d2=256, d3=128, d4=8): @@ -56,7 +55,9 @@ def run_before_and_after_tests(): @pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)]) @pytest.mark.parametrize("bit_width", [1, 2, 3, 4]) @pytest.mark.parametrize("lead_dim", [(5,), (2, 3)]) -@pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac") +@pytest.mark.skipif( + not _is_kernel_library_loaded(), reason="Kernel library is not loaded" +) def test_parq_conversion(dtype, granularity, bit_width, lead_dim): torch.manual_seed(0) quantizer = StretchedUnifTorchaoQuantizer(bit_width) @@ -74,7 +75,7 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim): # Convert PARQ model to lowbit LUT model lut_model = deepcopy(parq_model) - convert_model(lut_model) + _convert_to_optimized_model_for_aarch64(lut_model, tensor_type="int8_lut_tensor") # Run both models and compare parq_out = parq_model(activations) @@ -93,7 +94,9 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim): @pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)]) @pytest.mark.parametrize("bit_width", [1, 2, 3, 4]) @pytest.mark.parametrize("lead_dim", [(5,), (2, 3)]) -@pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac") +@pytest.mark.skipif( + not _is_kernel_library_loaded(), reason="Kernel library is not loaded" +) def test_export(dtype, granularity, bit_width, lead_dim): quantizer = StretchedUnifTorchaoQuantizer(bit_width) config = StretchedIntxWeightConfig( @@ -108,7 +111,7 @@ def test_export(dtype, granularity, bit_width, lead_dim): activations = parq_model.example_inputs(lead_dim=lead_dim) quantize_(parq_model, config) - convert_model(parq_model) + _convert_to_optimized_model_for_aarch64(parq_model) ep = torch.export.export(parq_model, (activations,)) diff --git a/torchao/prototype/conversion/__init__.py b/torchao/prototype/conversion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/conversion/api.py b/torchao/prototype/conversion/api.py new file mode 100644 index 0000000000..9989fe3f58 --- /dev/null +++ b/torchao/prototype/conversion/api.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn as nn + + +def _convert_module_to_int8_lut_tensor(module): + from torchao.prototype.quantization.int8_lut_tensor import Int8LutTensor + + assert isinstance(module, nn.Linear) + weight = module.weight + new_weight = Int8LutTensor.from_intx_unpacked_to_int8_tensor( + weight, bias=module.bias + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.bias = None + + +def _convert_to_optimized_model_for_aarch64( + model, + *, + tensor_type="int8_lut_tensor", +): + from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor + + # Iterate through modules in model and convert IntxUnpackedToInt8Tensor tensors to Int8LutTensor + for name, module in model.named_modules(): + if not isinstance(module, nn.Linear): + print(f"Skipping converting {name} because it is not a linear layer") + continue + + weight = module.weight + if not isinstance(weight, IntxUnpackedToInt8Tensor): + print( + f"Skipping converting {name} to IntxOpaqueTensor because its weight is not an IntxUnpackedToInt8Tensor" + ) + continue + + if tensor_type == "int8_lut_tensor": + _convert_module_to_int8_lut_tensor(module) + else: + raise ValueError(f"Unexpected tensor_type={tensor_type}") + + return model diff --git a/torchao/prototype/parq/quant/config_torchao.py b/torchao/prototype/parq/quant/config_torchao.py index b2eb70b2d4..b546ecb328 100644 --- a/torchao/prototype/parq/quant/config_torchao.py +++ b/torchao/prototype/parq/quant/config_torchao.py @@ -1,3 +1,4 @@ +import types from dataclasses import dataclass from typing import Callable, Optional @@ -17,6 +18,7 @@ IntxWeightOnlyConfig, ModuleFqnToConfig, _int8_asymm_per_token_quant, + _linear_extra_repr, ) from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor from torchao.quantization.transform_module import register_quantize_module_handler @@ -117,7 +119,12 @@ def _int8_dynamic_activation_stretched_intx_transform( weight = to_linear_activation_quantized(weight, _int8_asymm_per_token_quant) elif config.activation_quantization is not None: raise ValueError(f"Unsupported {config.activation_quantization=}") + module.weight = nn.Parameter(weight, requires_grad=False) + + if isinstance(module, nn.Linear): + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module diff --git a/torchao/prototype/quantization/int8_lut_tensor/__init__.py b/torchao/prototype/quantization/int8_lut_tensor/__init__.py index aad3693bbd..dd53868182 100644 --- a/torchao/prototype/quantization/int8_lut_tensor/__init__.py +++ b/torchao/prototype/quantization/int8_lut_tensor/__init__.py @@ -1,7 +1,5 @@ -from .api import convert_model from .int8_lut_tensor import Int8LutTensor __all__ = [ "Int8LutTensor", - "convert_model", ] diff --git a/torchao/prototype/quantization/int8_lut_tensor/api.py b/torchao/prototype/quantization/int8_lut_tensor/api.py deleted file mode 100644 index c700b9bd68..0000000000 --- a/torchao/prototype/quantization/int8_lut_tensor/api.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import torch.nn as nn - -from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import ( - Int8LutTensor, -) -from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor - - -def convert_model(model): - # Iterate through modules in model and convert IntxUnpackedToInt8Tensor tensors to Int8LutTensor - for name, module in model.named_modules(): - if isinstance(module, nn.Linear): - weight = module.weight - if not isinstance(weight, IntxUnpackedToInt8Tensor): - print( - f"Skipping converting {name} to Int8LutTensor because its weight is not an IntxUnpackedToInt8Tensor" - ) - continue - try: - new_weight = Int8LutTensor.from_intx_unpacked_to_int8_tensor( - weight, bias=module.bias - ) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.bias = None - except Exception as e: - print( - f"Skipping converting {name} to Int8LutTensor because an error occurred: {e}" - ) - continue - return model diff --git a/torchao/prototype/quantization/int8_lut_tensor/int8_lut_tensor.py b/torchao/prototype/quantization/int8_lut_tensor/int8_lut_tensor.py index 2a01098f0f..a4feee13aa 100644 --- a/torchao/prototype/quantization/int8_lut_tensor/int8_lut_tensor.py +++ b/torchao/prototype/quantization/int8_lut_tensor/int8_lut_tensor.py @@ -83,7 +83,7 @@ def _get_lut_params(cls, tensor: IntxUnpackedToInt8Tensor): scale = tensor.scale zero_point = tensor.zero_point - if tensor._has_float_zero_point: + if tensor._has_float_zero_point(): # Stretched tensors from PARQ should have -0.5 has zero_point assert torch.all(zero_point == -0.5) is_stretched_tensor = True From 1adb547b69bdf84e98c4065140a4afbe27613005 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 12 Sep 2025 14:45:11 -0700 Subject: [PATCH 5/8] up --- test/prototype/test_int8_lut_tensor.py | 6 +++--- .../prototype/{conversion => tensor_conversion}/__init__.py | 0 torchao/prototype/{conversion => tensor_conversion}/api.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) rename torchao/prototype/{conversion => tensor_conversion}/__init__.py (100%) rename torchao/prototype/{conversion => tensor_conversion}/api.py (90%) diff --git a/test/prototype/test_int8_lut_tensor.py b/test/prototype/test_int8_lut_tensor.py index ea467cd0ab..b5d1a6b0a1 100644 --- a/test/prototype/test_int8_lut_tensor.py +++ b/test/prototype/test_int8_lut_tensor.py @@ -9,7 +9,6 @@ import pytest import torch -from torchao.prototype.conversion.api import _convert_to_optimized_model_for_aarch64 from torchao.prototype.parq.quant import ( StretchedIntxWeightConfig, StretchedUnifTorchaoQuantizer, @@ -17,6 +16,7 @@ from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import ( _is_kernel_library_loaded, ) +from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64 from torchao.quantization import quantize_ from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.utils import compute_error @@ -75,7 +75,7 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim): # Convert PARQ model to lowbit LUT model lut_model = deepcopy(parq_model) - _convert_to_optimized_model_for_aarch64(lut_model, tensor_type="int8_lut_tensor") + _convert_model_for_aarch64(lut_model, tensor_type="int8_lut_tensor") # Run both models and compare parq_out = parq_model(activations) @@ -111,7 +111,7 @@ def test_export(dtype, granularity, bit_width, lead_dim): activations = parq_model.example_inputs(lead_dim=lead_dim) 
quantize_(parq_model, config) - _convert_to_optimized_model_for_aarch64(parq_model) + _convert_model_for_aarch64(parq_model) ep = torch.export.export(parq_model, (activations,)) diff --git a/torchao/prototype/conversion/__init__.py b/torchao/prototype/tensor_conversion/__init__.py similarity index 100% rename from torchao/prototype/conversion/__init__.py rename to torchao/prototype/tensor_conversion/__init__.py diff --git a/torchao/prototype/conversion/api.py b/torchao/prototype/tensor_conversion/api.py similarity index 90% rename from torchao/prototype/conversion/api.py rename to torchao/prototype/tensor_conversion/api.py index 9989fe3f58..7722c57e34 100644 --- a/torchao/prototype/conversion/api.py +++ b/torchao/prototype/tensor_conversion/api.py @@ -8,7 +8,7 @@ import torch.nn as nn -def _convert_module_to_int8_lut_tensor(module): +def _convert_linear_weight_to_int8_lut_tensor(module): from torchao.prototype.quantization.int8_lut_tensor import Int8LutTensor assert isinstance(module, nn.Linear) @@ -20,7 +20,7 @@ def _convert_module_to_int8_lut_tensor(module): module.bias = None -def _convert_to_optimized_model_for_aarch64( +def _convert_model_for_aarch64( model, *, tensor_type="int8_lut_tensor", @@ -41,7 +41,7 @@ def _convert_to_optimized_model_for_aarch64( continue if tensor_type == "int8_lut_tensor": - _convert_module_to_int8_lut_tensor(module) + _convert_linear_weight_to_int8_lut_tensor(module) else: raise ValueError(f"Unexpected tensor_type={tensor_type}") From cd551a602607ab7c4561206ffc7a015f83ebc5ba Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:05:26 -0700 Subject: [PATCH 6/8] init --- test/prototype/test_conversion.py | 180 ++++++++++++++++++ torchao/prototype/tensor_conversion/api.py | 120 +++++++++++- .../workflows/intx/intx_opaque_tensor.py | 29 +++ 3 files changed, 324 insertions(+), 5 deletions(-) create mode 100644 test/prototype/test_conversion.py diff --git a/test/prototype/test_conversion.py b/test/prototype/test_conversion.py new file mode 100644 index 0000000000..2cee9a08ef --- /dev/null +++ b/test/prototype/test_conversion.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import pytest +import torch + +from torchao.prototype.parq.quant import ( + StretchedIntxWeightConfig, + StretchedUnifTorchaoQuantizer, +) +from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import Int8LutTensor +from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64 +from torchao.quantization import MappingType +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + quantize_, +) +from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + IntxOpaqueTensor, + _is_kernel_library_loaded, +) +from torchao.quantization.utils import compute_error + + +class ToyLinearModelWithTiedEmbedding(torch.nn.Module): + def __init__(self, d0=512, d1=512, d2=256, d3=128, d4=32): + super().__init__() + self.embedding1 = torch.nn.Embedding(d0, d1) + self.embedding2 = torch.nn.Embedding(d0, d1) + self.embedding3 = torch.nn.Embedding(d0, d1) + + self.linear1 = torch.nn.Linear(d1, d2, bias=False) + self.linear2 = torch.nn.Linear(d2, d3, bias=True) + self.linear3 = torch.nn.Linear(d3, d4, bias=False) + self.linear4 = torch.nn.Linear(d4, d1, bias=False) + + self.lm_head1 = torch.nn.Linear(d1, d0, bias=False) + self.lm_head2 = torch.nn.Linear(d1, d0, bias=False) + self.lm_head3 = torch.nn.Linear(d1, d0, bias=False) + + # Tie weights + # lm_head1 / lm_head2 form one tied weight group + self.embedding2.weight = self.embedding1.weight + self.lm_head1.weight = self.embedding1.weight + self.lm_head2.weight = self.embedding1.weight + + # lm_head3 forms a separate tied weight group + self.lm_head3.weight = self.embedding3.weight + + def example_inputs( + self, + lead_dim=(1,), + dtype=torch.bfloat16, + ): + return ( + torch.randint( + 0, + self.embedding1.num_embeddings, + size=lead_dim, + dtype=torch.int64, + device="cpu", + ), + ) + + def forward(self, x): + x = self.embedding1(x) + self.embedding2(x) + self.embedding3(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + x = self.lm_head1(x) + self.lm_head2(x) + self.lm_head3(x) + return x + + +@pytest.fixture(autouse=True) +def run_before_and_after_tests(): + yield + torch._dynamo.reset() # reset cache between tests + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)]) +@pytest.mark.parametrize("bit_width", [1, 2, 3, 4]) +@pytest.mark.parametrize( + "lead_dim", + [ + (1,), + (5,), + (7, 2), + ], +) +@pytest.mark.skipif( + not _is_kernel_library_loaded(), reason="Kernel library is not loaded" +) +def test_aarch64_conversion(dtype, granularity, bit_width, lead_dim): + torch.manual_seed(0) + + model = ToyLinearModelWithTiedEmbedding() + model = model.to(dtype) + example_inputs = model.example_inputs(lead_dim, dtype) + + # Quantize linear 2 and 3 with PARQ + quantizer = StretchedUnifTorchaoQuantizer(bit_width) + config = StretchedIntxWeightConfig( + b=bit_width, + quant_min=quantizer.quant_min, + quant_max=quantizer.quant_max, + granularity=granularity, + activation_quantization="int8_asym_per_token", + ) + quantize_(model, config, filter_fn=lambda m, fqn: fqn in ["linear2", "linear3"]) + + # Quantize linear 1 and 4 with int8 dynamic activation + config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=granularity, + weight_mapping_type=MappingType.SYMMETRIC, + ) + quantize_( + model, + config, + filter_fn=lambda m, fqn: fqn 
+ in ["linear1", "linear4", "lm_head1", "lm_head2", "lm_head3"], + ) + + # Quantize embedding 1, 2, and 3 with weight only + config = IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=granularity, + mapping_type=MappingType.SYMMETRIC, + ) + quantize_( + model, + config, + filter_fn=lambda m, fqn: fqn in ["embedding1", "embedding2", "embedding3"], + ) + model_out = model(*example_inputs) + + # Convert to optimized model + _convert_model_for_aarch64(model) + + # Check expected tensor subclass + assert isinstance(model.linear2.weight, Int8LutTensor) + assert isinstance(model.linear3.weight, Int8LutTensor) + assert isinstance(model.linear1.weight, IntxOpaqueTensor) + assert isinstance(model.linear4.weight, IntxOpaqueTensor) + + # Assert tied params + tied_group1_id = id(model.embedding1.weight) + assert id(model.embedding2.weight) == tied_group1_id + assert id(model.lm_head1.weight) == tied_group1_id + assert id(model.lm_head2.weight) == tied_group1_id + + assert id(model.lm_head3.weight) == id(model.embedding3.weight) + assert id(model.lm_head3.weight) != tied_group1_id + + # Compare converted out with original out + converted_out = model(*example_inputs) + sqnr = compute_error(model_out, converted_out) + sqnr_threshold = 30 + assert sqnr > sqnr_threshold, f"sqnr: {sqnr}" + + # Check exported graph for correct ops + ep = torch.export.export(model, example_inputs) + expected_counts = { + "torch.ops.torchao._shared_embedding_": 3, + "torch.ops.torchao._linear_8bit_act_": 7, + "torch.ops.aten.linear.default": 0, + "torch.ops.aten.embedding.default": 0, + } + for line, cnt in expected_counts.items(): + assert ep.graph_module.code.count(line) == cnt, ( + f"expected {cnt} {line} in {ep.graph_module.code}" + ) diff --git a/torchao/prototype/tensor_conversion/api.py b/torchao/prototype/tensor_conversion/api.py index 7722c57e34..63a1bcc2ef 100644 --- a/torchao/prototype/tensor_conversion/api.py +++ b/torchao/prototype/tensor_conversion/api.py @@ -7,6 +7,8 @@ import torch import torch.nn as nn +from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor + def _convert_linear_weight_to_int8_lut_tensor(module): from torchao.prototype.quantization.int8_lut_tensor import Int8LutTensor @@ -20,17 +22,116 @@ def _convert_linear_weight_to_int8_lut_tensor(module): module.bias = None +def _convert_module_weight_to_intx_opaque_tensor(module, intx_packing_format): + from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + IntxOpaqueTensor, + ) + + assert isinstance(module, nn.Linear) or isinstance(module, nn.Embedding) + weight = module.weight + new_weight = IntxOpaqueTensor.from_intx_unpacked_to_int8_tensor( + weight, + bias=module.bias if hasattr(module, "bias") else None, + intx_packing_format=intx_packing_format, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + if hasattr(module, "bias"): + module.bias = None + + +def _find_tied_module_names_for_embedding(embedding_weight, model): + assert isinstance(embedding_weight, IntxUnpackedToInt8Tensor) + tied_names = [] + for name, module in model.named_modules(): + is_linear = isinstance(module, nn.Linear) + is_embedding = isinstance(module, nn.Embedding) + if not (is_linear or is_embedding): + continue + + weight = module.weight + if not isinstance(weight, IntxUnpackedToInt8Tensor): + continue + + # We only have tied kernels for dynamically quantized linears + if is_linear and weight.activation_quantization != "int8_asym_per_token": + continue + + # We only have tied kernels for linear 
layers with no bias + if is_linear and module.bias is not None: + continue + + are_tied = ( + (embedding_weight.shape == weight.shape) + and (embedding_weight.block_size == weight.block_size) + and (embedding_weight.dtype == weight.dtype) + and (embedding_weight.qdata == weight.qdata).all() + and (embedding_weight.scale == weight.scale).all() + and (embedding_weight.zero_point == weight.zero_point).all() + ) + + if are_tied: + tied_names.append(name) + + return tied_names + + +def _find_tied_params(model): + from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + IntxOpaqueTensor, + ) + + module_name_to_tied_param = {} + for name, module in model.named_modules(): + if not isinstance(module, nn.Embedding): + continue + + weight = module.weight + if not isinstance(weight, IntxUnpackedToInt8Tensor): + continue + + tied_module_names = _find_tied_module_names_for_embedding(weight, model) + if not tied_module_names: + continue + + if name in module_name_to_tied_param: + tied_param = module_name_to_tied_param[name] + else: + # Construct a new tied param + # IntxOpaqueTensor requires activation_quantization = int8_asym_per_token + prev = weight.activation_quantization + weight.activation_quantization = "int8_asym_per_token" + tied_param = IntxOpaqueTensor.from_intx_unpacked_to_int8_tensor( + weight, + bias=None, + intx_packing_format="opaque_torchao_lowbit", + ) + weight.activation_quantization = prev + tied_param = nn.Parameter(tied_param, requires_grad=False) + module_name_to_tied_param[name] = tied_param + + for t in tied_module_names: + if t not in module_name_to_tied_param: + module_name_to_tied_param[t] = tied_param + + return module_name_to_tied_param + + def _convert_model_for_aarch64( - model, - *, - tensor_type="int8_lut_tensor", + model, *, tensor_type="auto", intx_packing_format="opaque_torchao_auto" ): - from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor + module_name_to_tied_param = _find_tied_params(model) # Iterate through modules in model and convert IntxUnpackedToInt8Tensor tensors to Int8LutTensor for name, module in model.named_modules(): + if name in module_name_to_tied_param: + module.weight = module_name_to_tied_param[name] + continue + + if isinstance(module, nn.Embedding): + print("Skipping converting nn.Embedding {name} because it is not tied") + continue + if not isinstance(module, nn.Linear): - print(f"Skipping converting {name} because it is not a linear layer") continue weight = module.weight @@ -42,6 +143,15 @@ def _convert_model_for_aarch64( if tensor_type == "int8_lut_tensor": _convert_linear_weight_to_int8_lut_tensor(module) + elif tensor_type == "intx_opaque_tensor": + _convert_module_weight_to_intx_opaque_tensor(module, intx_packing_format) + elif tensor_type == "auto": + if weight._has_float_zero_point() and isinstance(module, nn.Linear): + _convert_linear_weight_to_int8_lut_tensor(module) + else: + _convert_module_weight_to_intx_opaque_tensor( + module, intx_packing_format + ) else: raise ValueError(f"Unexpected tensor_type={tensor_type}") diff --git a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py index 50bee6df25..2c32732b74 100644 --- a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py +++ b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py @@ -335,6 +335,35 @@ def _(func, types, args, kwargs): return res +@implements([torch.nn.functional.embedding, aten.embedding.default]) +def 
_(func, types, args, kwargs): + assert len(args) == 2 + indices, weight_tensor = ( + args[0], + args[1], + ) + assert isinstance(weight_tensor, IntxOpaqueTensor) + assert weight_tensor.intx_packing_format == IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT + packed_weights = weight_tensor.packed_weights + + assert len(weight_tensor.block_size) == 2 + assert weight_tensor.block_size[0] == 1 + group_size = weight_tensor.block_size[1] + + n, k = weight_tensor.shape + bit_width = weight_tensor.bit_width + + shape = indices.shape + out = getattr(torch.ops.torchao, f"_shared_embedding_{bit_width}bit")( + packed_weights, + group_size, + n, + k, + indices.reshape(-1), + ).reshape(*shape, -1) + return out + + IntxOpaqueTensor.__module__ = "torchao.quantization" torch.serialization.add_safe_globals([IntxOpaqueTensor]) From abacce39ea9b3d89dc8ec1a13b9b919327ef73e5 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:08:16 -0700 Subject: [PATCH 7/8] up --- .github/workflows/regression_test_aarch64.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/regression_test_aarch64.yml b/.github/workflows/regression_test_aarch64.yml index 10948fa61d..e7cdea55f0 100644 --- a/.github/workflows/regression_test_aarch64.yml +++ b/.github/workflows/regression_test_aarch64.yml @@ -55,6 +55,7 @@ jobs: pytest -s test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py pytest -s test/prototype/test_embedding.py pytest -s test/prototype/test_int8_lut_tensor.py + pytest -s test/prototype/test_conversion.py pytest -s test/prototype/test_groupwise_lowbit_weight_lut_quantizer.py pytest -s test/prototype/test_parq.py - name: torchao/csrc/cpu - build and run C++ tests From 422bc35ecb1a4b5624ea5aa2ddad27d5c92ebc2d Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 12 Sep 2025 18:12:49 -0700 Subject: [PATCH 8/8] up --- .github/workflows/regression_test_aarch64.yml | 2 +- .../prototype/{test_conversion.py => test_tensor_conversion.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename test/prototype/{test_conversion.py => test_tensor_conversion.py} (100%) diff --git a/.github/workflows/regression_test_aarch64.yml b/.github/workflows/regression_test_aarch64.yml index e7cdea55f0..ff10b661a5 100644 --- a/.github/workflows/regression_test_aarch64.yml +++ b/.github/workflows/regression_test_aarch64.yml @@ -55,7 +55,7 @@ jobs: pytest -s test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py pytest -s test/prototype/test_embedding.py pytest -s test/prototype/test_int8_lut_tensor.py - pytest -s test/prototype/test_conversion.py + pytest -s test/prototype/test_tensor_conversion.py pytest -s test/prototype/test_groupwise_lowbit_weight_lut_quantizer.py pytest -s test/prototype/test_parq.py - name: torchao/csrc/cpu - build and run C++ tests diff --git a/test/prototype/test_conversion.py b/test/prototype/test_tensor_conversion.py similarity index 100% rename from test/prototype/test_conversion.py rename to test/prototype/test_tensor_conversion.py
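
Usage sketch for the conversion flow introduced in this series, condensed from the new tests
(test/prototype/test_int8_lut_tensor.py and test/prototype/test_tensor_conversion.py). The two-layer
toy model below is illustrative only, and the sketch assumes the TorchAO lowbit kernel library is
built and loaded (the tests skip otherwise). With tensor_type="auto", _convert_model_for_aarch64
repacks stretched (float-zero-point) PARQ linear weights as Int8LutTensor, repacks other
IntxUnpackedToInt8Tensor weights as IntxOpaqueTensor, and reuses a single packed tensor for
embedding/lm_head weights that are tied.

from copy import deepcopy

import torch

from torchao.prototype.parq.quant import (
    StretchedIntxWeightConfig,
    StretchedUnifTorchaoQuantizer,
)
from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import (
    _is_kernel_library_loaded,
)
from torchao.quantization.utils import compute_error

assert _is_kernel_library_loaded(), "requires the TorchAO lowbit kernel library"

# Illustrative toy model; any model with nn.Linear layers is handled the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, bias=False),
    torch.nn.Linear(256, 128, bias=True),
)
x = torch.randn(5, 128)

# 1) Quantize weights to IntxUnpackedToInt8Tensor with int8 dynamic activation
#    quantization (required by the LUT linear kernels).
bit_width = 4
quantizer = StretchedUnifTorchaoQuantizer(bit_width)
config = StretchedIntxWeightConfig(
    b=bit_width,
    quant_min=quantizer.quant_min,
    quant_max=quantizer.quant_max,
    granularity=PerGroup(32),
    activation_quantization="int8_asym_per_token",
)
quantize_(model, config)
ref_out = model(x)

# 2) Repack for aarch64; the stretched weights take the Int8LutTensor path here.
converted = deepcopy(model)
_convert_model_for_aarch64(converted, tensor_type="auto")

# 3) The converted model should closely track the unconverted reference.
sqnr = compute_error(ref_out, converted(x)).item()
print(f"SQNR: {sqnr:.1f} dB")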