diff --git a/.github/workflows/regression_test_aarch64.yml b/.github/workflows/regression_test_aarch64.yml
index 10948fa61d..ff10b661a5 100644
--- a/.github/workflows/regression_test_aarch64.yml
+++ b/.github/workflows/regression_test_aarch64.yml
@@ -55,6 +55,7 @@ jobs:
           pytest -s test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py
           pytest -s test/prototype/test_embedding.py
           pytest -s test/prototype/test_int8_lut_tensor.py
+          pytest -s test/prototype/test_tensor_conversion.py
           pytest -s test/prototype/test_groupwise_lowbit_weight_lut_quantizer.py
           pytest -s test/prototype/test_parq.py
       - name: torchao/csrc/cpu - build and run C++ tests
diff --git a/test/prototype/test_tensor_conversion.py b/test/prototype/test_tensor_conversion.py
new file mode 100644
index 0000000000..2cee9a08ef
--- /dev/null
+++ b/test/prototype/test_tensor_conversion.py
@@ -0,0 +1,180 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import pytest
+import torch
+
+from torchao.prototype.parq.quant import (
+    StretchedIntxWeightConfig,
+    StretchedUnifTorchaoQuantizer,
+)
+from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import Int8LutTensor
+from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
+from torchao.quantization import MappingType
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
+    quantize_,
+)
+from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import (
+    IntxOpaqueTensor,
+    _is_kernel_library_loaded,
+)
+from torchao.quantization.utils import compute_error
+
+
+class ToyLinearModelWithTiedEmbedding(torch.nn.Module):
+    def __init__(self, d0=512, d1=512, d2=256, d3=128, d4=32):
+        super().__init__()
+        self.embedding1 = torch.nn.Embedding(d0, d1)
+        self.embedding2 = torch.nn.Embedding(d0, d1)
+        self.embedding3 = torch.nn.Embedding(d0, d1)
+
+        self.linear1 = torch.nn.Linear(d1, d2, bias=False)
+        self.linear2 = torch.nn.Linear(d2, d3, bias=True)
+        self.linear3 = torch.nn.Linear(d3, d4, bias=False)
+        self.linear4 = torch.nn.Linear(d4, d1, bias=False)
+
+        self.lm_head1 = torch.nn.Linear(d1, d0, bias=False)
+        self.lm_head2 = torch.nn.Linear(d1, d0, bias=False)
+        self.lm_head3 = torch.nn.Linear(d1, d0, bias=False)
+
+        # Tie weights
+        # embedding1 / embedding2 / lm_head1 / lm_head2 form one tied weight group
+        self.embedding2.weight = self.embedding1.weight
+        self.lm_head1.weight = self.embedding1.weight
+        self.lm_head2.weight = self.embedding1.weight
+
+        # embedding3 / lm_head3 form a separate tied weight group
+        self.lm_head3.weight = self.embedding3.weight
+
+    def example_inputs(
+        self,
+        lead_dim=(1,),
+        dtype=torch.bfloat16,
+    ):
+        return (
+            torch.randint(
+                0,
+                self.embedding1.num_embeddings,
+                size=lead_dim,
+                dtype=torch.int64,
+                device="cpu",
+            ),
+        )
+
+    def forward(self, x):
+        x = self.embedding1(x) + self.embedding2(x) + self.embedding3(x)
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.linear3(x)
+        x = self.linear4(x)
+        x = self.lm_head1(x) + self.lm_head2(x) + self.lm_head3(x)
+        return x
+
+
+@pytest.fixture(autouse=True)
+def run_before_and_after_tests():
+    yield
+    torch._dynamo.reset()  # reset cache between tests
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)])
+@pytest.mark.parametrize("bit_width", [1, 2, 3, 4])
+@pytest.mark.parametrize(
+    "lead_dim",
+    [
+        (1,),
+        (5,),
+        (7, 2),
+    ],
+)
+@pytest.mark.skipif(
+    not _is_kernel_library_loaded(), reason="Kernel library is not loaded"
+)
+def test_aarch64_conversion(dtype, granularity, bit_width, lead_dim):
+    torch.manual_seed(0)
+
+    model = ToyLinearModelWithTiedEmbedding()
+    model = model.to(dtype)
+    example_inputs = model.example_inputs(lead_dim, dtype)
+
+    # Quantize linear 2 and 3 with PARQ
+    quantizer = StretchedUnifTorchaoQuantizer(bit_width)
+    config = StretchedIntxWeightConfig(
+        b=bit_width,
+        quant_min=quantizer.quant_min,
+        quant_max=quantizer.quant_max,
+        granularity=granularity,
+        activation_quantization="int8_asym_per_token",
+    )
+    quantize_(model, config, filter_fn=lambda m, fqn: fqn in ["linear2", "linear3"])
+
+    # Quantize linear 1 and 4 (and the lm_heads) with int8 dynamic activation
+    config = Int8DynamicActivationIntxWeightConfig(
+        weight_dtype=torch.int4,
+        weight_granularity=granularity,
+        weight_mapping_type=MappingType.SYMMETRIC,
+    )
+    quantize_(
+        model,
+        config,
+        filter_fn=lambda m, fqn: fqn
+        in ["linear1", "linear4", "lm_head1", "lm_head2", "lm_head3"],
+    )
+
+    # Quantize embedding 1, 2, and 3 with weight only
+    config = IntxWeightOnlyConfig(
+        weight_dtype=torch.int4,
+        granularity=granularity,
+        mapping_type=MappingType.SYMMETRIC,
+    )
+    quantize_(
+        model,
+        config,
+        filter_fn=lambda m, fqn: fqn in ["embedding1", "embedding2", "embedding3"],
+    )
+    model_out = model(*example_inputs)
+
+    # Convert to optimized model
+    _convert_model_for_aarch64(model)
+
+    # Check expected tensor subclass
+    assert isinstance(model.linear2.weight, Int8LutTensor)
+    assert isinstance(model.linear3.weight, Int8LutTensor)
+    assert isinstance(model.linear1.weight, IntxOpaqueTensor)
+    assert isinstance(model.linear4.weight, IntxOpaqueTensor)
+
+    # Assert tied params
+    tied_group1_id = id(model.embedding1.weight)
+    assert id(model.embedding2.weight) == tied_group1_id
+    assert id(model.lm_head1.weight) == tied_group1_id
+    assert id(model.lm_head2.weight) == tied_group1_id
+
+    assert id(model.lm_head3.weight) == id(model.embedding3.weight)
+    assert id(model.lm_head3.weight) != tied_group1_id
+
+    # Compare converted out with original out
+    converted_out = model(*example_inputs)
+    sqnr = compute_error(model_out, converted_out)
+    sqnr_threshold = 30
+    assert sqnr > sqnr_threshold, f"sqnr: {sqnr}"
+
+    # Check exported graph for correct ops
+    ep = torch.export.export(model, example_inputs)
+    expected_counts = {
+        "torch.ops.torchao._shared_embedding_": 3,
+        "torch.ops.torchao._linear_8bit_act_": 7,
+        "torch.ops.aten.linear.default": 0,
+        "torch.ops.aten.embedding.default": 0,
+    }
+    for line, cnt in expected_counts.items():
+        assert ep.graph_module.code.count(line) == cnt, (
+            f"expected {cnt} {line} in {ep.graph_module.code}"
+        )
diff --git a/torchao/prototype/tensor_conversion/api.py b/torchao/prototype/tensor_conversion/api.py
index 7722c57e34..63a1bcc2ef 100644
--- a/torchao/prototype/tensor_conversion/api.py
+++ b/torchao/prototype/tensor_conversion/api.py
@@ -7,6 +7,8 @@
 import torch
 import torch.nn as nn
 
+from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor
+
 
 def _convert_linear_weight_to_int8_lut_tensor(module):
     from torchao.prototype.quantization.int8_lut_tensor import Int8LutTensor
@@ -20,17 +22,116 @@ def _convert_linear_weight_to_int8_lut_tensor(module):
     module.bias = None
 
 
+def _convert_module_weight_to_intx_opaque_tensor(module, intx_packing_format):
+    from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import (
+        IntxOpaqueTensor,
+    )
+
+    assert isinstance(module, nn.Linear) or isinstance(module, nn.Embedding)
+    weight = module.weight
+    new_weight = IntxOpaqueTensor.from_intx_unpacked_to_int8_tensor(
+        weight,
+        bias=module.bias if hasattr(module, "bias") else None,
+        intx_packing_format=intx_packing_format,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    if hasattr(module, "bias"):
+        module.bias = None
+
+
+def _find_tied_module_names_for_embedding(embedding_weight, model):
+    assert isinstance(embedding_weight, IntxUnpackedToInt8Tensor)
+    tied_names = []
+    for name, module in model.named_modules():
+        is_linear = isinstance(module, nn.Linear)
+        is_embedding = isinstance(module, nn.Embedding)
+        if not (is_linear or is_embedding):
+            continue
+
+        weight = module.weight
+        if not isinstance(weight, IntxUnpackedToInt8Tensor):
+            continue
+
+        # We only have tied kernels for dynamically quantized linears
+        if is_linear and weight.activation_quantization != "int8_asym_per_token":
+            continue
+
+        # We only have tied kernels for linear layers with no bias
+        if is_linear and module.bias is not None:
+            continue
+
+        are_tied = (
+            (embedding_weight.shape == weight.shape)
+            and (embedding_weight.block_size == weight.block_size)
+            and (embedding_weight.dtype == weight.dtype)
+            and (embedding_weight.qdata == weight.qdata).all()
+            and (embedding_weight.scale == weight.scale).all()
+            and (embedding_weight.zero_point == weight.zero_point).all()
+        )
+
+        if are_tied:
+            tied_names.append(name)
+
+    return tied_names
+
+
+def _find_tied_params(model):
+    from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import (
+        IntxOpaqueTensor,
+    )
+
+    module_name_to_tied_param = {}
+    for name, module in model.named_modules():
+        if not isinstance(module, nn.Embedding):
+            continue
+
+        weight = module.weight
+        if not isinstance(weight, IntxUnpackedToInt8Tensor):
+            continue
+
+        tied_module_names = _find_tied_module_names_for_embedding(weight, model)
+        if not tied_module_names:
+            continue
+
+        if name in module_name_to_tied_param:
+            tied_param = module_name_to_tied_param[name]
+        else:
+            # Construct a new tied param
+            # IntxOpaqueTensor requires activation_quantization = int8_asym_per_token
+            prev = weight.activation_quantization
+            weight.activation_quantization = "int8_asym_per_token"
+            tied_param = IntxOpaqueTensor.from_intx_unpacked_to_int8_tensor(
+                weight,
+                bias=None,
+                intx_packing_format="opaque_torchao_lowbit",
+            )
+            weight.activation_quantization = prev
+            tied_param = nn.Parameter(tied_param, requires_grad=False)
+            module_name_to_tied_param[name] = tied_param
+
+        for t in tied_module_names:
+            if t not in module_name_to_tied_param:
+                module_name_to_tied_param[t] = tied_param
+
+    return module_name_to_tied_param
+
+
 def _convert_model_for_aarch64(
-    model,
-    *,
-    tensor_type="int8_lut_tensor",
+    model, *, tensor_type="auto", intx_packing_format="opaque_torchao_auto"
 ):
-    from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor
+    module_name_to_tied_param = _find_tied_params(model)
 
     # Iterate through modules in model and convert IntxUnpackedToInt8Tensor tensors to Int8LutTensor
     for name, module in model.named_modules():
+        if name in module_name_to_tied_param:
+            module.weight = module_name_to_tied_param[name]
+            continue
+
+        if isinstance(module, nn.Embedding):
+            print(f"Skipping converting nn.Embedding {name} because it is not tied")
+            continue
+
         if not isinstance(module, nn.Linear):
print(f"Skipping converting {name} because it is not a linear layer") continue weight = module.weight @@ -42,6 +143,15 @@ def _convert_model_for_aarch64( if tensor_type == "int8_lut_tensor": _convert_linear_weight_to_int8_lut_tensor(module) + elif tensor_type == "intx_opaque_tensor": + _convert_module_weight_to_intx_opaque_tensor(module, intx_packing_format) + elif tensor_type == "auto": + if weight._has_float_zero_point() and isinstance(module, nn.Linear): + _convert_linear_weight_to_int8_lut_tensor(module) + else: + _convert_module_weight_to_intx_opaque_tensor( + module, intx_packing_format + ) else: raise ValueError(f"Unexpected tensor_type={tensor_type}") diff --git a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py index 50bee6df25..2c32732b74 100644 --- a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py +++ b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py @@ -335,6 +335,35 @@ def _(func, types, args, kwargs): return res +@implements([torch.nn.functional.embedding, aten.embedding.default]) +def _(func, types, args, kwargs): + assert len(args) == 2 + indices, weight_tensor = ( + args[0], + args[1], + ) + assert isinstance(weight_tensor, IntxOpaqueTensor) + assert weight_tensor.intx_packing_format == IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT + packed_weights = weight_tensor.packed_weights + + assert len(weight_tensor.block_size) == 2 + assert weight_tensor.block_size[0] == 1 + group_size = weight_tensor.block_size[1] + + n, k = weight_tensor.shape + bit_width = weight_tensor.bit_width + + shape = indices.shape + out = getattr(torch.ops.torchao, f"_shared_embedding_{bit_width}bit")( + packed_weights, + group_size, + n, + k, + indices.reshape(-1), + ).reshape(*shape, -1) + return out + + IntxOpaqueTensor.__module__ = "torchao.quantization" torch.serialization.add_safe_globals([IntxOpaqueTensor])