diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 004860e329..0a7a94af1c 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -83,6 +83,7 @@
     dequantize_affine,
     quantize_affine,
 )
+from torchao.quantization.quantize_.workflows import Int4PackingFormat
 from torchao.quantization.unified import (
     TwoStepQuantizer,
 )
@@ -1942,15 +1943,18 @@ def test_quantize_api_fp8_int4(self):
     )
     @unittest.skipIf(is_fbcode(), "cutlass cannot initialize")
     @parametrize("version", [1, 2])
-    def test_quantize_api_int4(self, version: int):
+    @parametrize(
+        "packing_format", [Int4PackingFormat.PLAIN, Int4PackingFormat.PRESHUFFLED]
+    )
+    def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat):
         """
         Test the following:
             quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="prepare"))
             quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="convert"))
         """
         self._test_quantize_api_against_ptq(
-            Int4WeightOnlyConfig(version=version),
-            target_prepare_sqnr=12,
+            Int4WeightOnlyConfig(version=version, int4_packing_format=packing_format),
+            target_prepare_sqnr=45 if version == 2 else 12,
             target_convert_sqnr=float("inf"),
         )
 
@@ -2004,9 +2008,9 @@ def test_infer_int4_weight_only_config(self):
         base_config = Int4WeightOnlyConfig(version=2)
         (act_config, weight_config) = _infer_fake_quantize_configs(base_config)
         self.assertIsNone(act_config)
-        self.assertEqual(weight_config.dtype, torch.int4)
+        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
-        self.assertTrue(weight_config.is_symmetric)
+        self.assertEqual(weight_config.activation_dtype, torch.bfloat16)
 
     @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
     def test_quantize_api_nvfp4(self):
@@ -2094,7 +2098,7 @@ def test_fbgemm_fp8_primitives(self):
         not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
     )
     @unittest.skipIf(is_fbcode(), "triton compilation error")
-    def test_fbgemm_int4_preshuffled_primitives(self):
+    def test_fbgemm_fp8_int4_preshuffled_primitives(self):
         """
         Compare numerics between:
           (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
@@ -2171,6 +2175,71 @@ def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         )
         self.assertGreater(sqnr_q1_q3_preshuffle, 32)
 
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    @unittest.skipIf(is_fbcode(), "triton compilation error")
+    def test_fbgemm_int4_weight_only_primitives(self):
+        """
+        Compare numerics between:
+          (1) fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize_zp
+          (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import (
+            int4_row_quantize_zp,
+            pack_int4,
+            quantize_int4_preshuffle,
+        )
+
+        group_size = 128
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+        x3 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_int4_preshuffle` with dtype="bf16"
+        (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="bf16")
+
+        # (2) Call `int4_row_quantize_zp`, which should be the same as (1)
+        # but without the packing and shuffling
+        (q2, scale2, _) = int4_row_quantize_zp(x2, group_size)
+
+        # (3) Reference implementation for QAT without the dequantize
+        eps = 1e-6
+        qmin, qmax = 0, 15
+        fbgemm_symmetric_qmax = 8
+        w_grouped = x3.to(torch.float32).view(x3.shape[0], -1, group_size)
+        max_val = torch.amax(w_grouped, dim=-1, keepdim=True)
+        min_val = torch.amin(w_grouped, dim=-1, keepdim=True)
+        scale3 = torch.clamp(max_val - min_val, min=eps) / qmax
+        q3 = (w_grouped.sub(min_val).div(scale3)).round().clamp_(qmin, qmax)
+        q3 = q3 - fbgemm_symmetric_qmax
+        q3 = q3.view(x3.shape)
+
+        def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+            t = pack_int4(t.to(torch.int8))
+            return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.bfloat16))[0]
+
+        # First, sanity check that shuffle_and_pack(q2) == q1
+        torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0)
+
+        # Now check q2 vs q3 with and without shuffle
+        torch.testing.assert_close(q2.to(torch.float32), q3, atol=0, rtol=0)
+        torch.testing.assert_close(
+            shuffle_and_pack(q2, scale2).to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+            atol=0,
+            rtol=0,
+        )
+
+        # Now check shuffle_and_pack(q3) vs q1
+        torch.testing.assert_close(
+            q1.to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+            atol=0,
+            rtol=0,
+        )
+
 
 instantiate_parametrized_tests(TestQAT)
 
diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py
index dc86aa919f..892fcd8d8b 100644
--- a/torchao/quantization/qat/fake_quantize_config.py
+++ b/torchao/quantization/qat/fake_quantize_config.py
@@ -31,6 +31,7 @@
     TorchAODType,
     ZeroPointDomain,
 )
+from torchao.quantization.quantize_.workflows import Int4PackingFormat
 from torchao.utils import _is_float8_type
 
 from .utils import _log_deprecation_warning
@@ -77,11 +78,14 @@ def __post_init__(self):
         )
 
 
+# TODO: rename this config, it actually works for both plain and preshuffled
 @dataclass
 class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
     Config for int4 weight fake quantization that targets the numerics in the following preshuffled kernel:
         torch.ops.fbgemm.f8i4bf16_shuffled
+        torch.ops.fbgemm.bf16i4bf16_shuffled
+        torch.ops.fbgemm.bf16i4bf16_rowwise
 
     Currently this only supports float8 input activations. It is expected to be used in
     conjunction with :class:`~torchao.quantization.Float8DynamicActivationInt4WeightConfig`.
     In the future, we may extend
@@ -92,8 +96,10 @@ class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
     activation_dtype: torch.dtype = e4m3_dtype
 
     def __post_init__(self):
-        if self.activation_dtype != e4m3_dtype:
-            raise ValueError(f"Only {e4m3_dtype} activation is supported currently")
+        if self.activation_dtype not in [e4m3_dtype, torch.bfloat16]:
+            raise ValueError(
+                f"Only {e4m3_dtype} or torch.bfloat16 activations are supported"
+            )
 
 
 @dataclass
@@ -379,10 +385,17 @@ def _infer_fake_quantize_configs(
     elif isinstance(base_config, Int4WeightOnlyConfig):
         act_config = None
         if base_config.version == 2:
-            weight_config = IntxFakeQuantizeConfig(
-                dtype=torch.int4,
-                group_size=base_config.group_size,
-                is_symmetric=True,
+            supported_packing_formats = [
+                Int4PackingFormat.PLAIN,
+                Int4PackingFormat.PRESHUFFLED,
+            ]
+            if base_config.int4_packing_format not in supported_packing_formats:
+                raise ValueError(
+                    f"Packing format must be one of {supported_packing_formats}"
+                )
+            weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
+                group_size=128,
+                activation_dtype=torch.bfloat16,
             )
         elif base_config.version == 1:
             # For BC
diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py
index 8a63a0d0ad..8c21ecf5cc 100644
--- a/torchao/quantization/qat/fake_quantizer.py
+++ b/torchao/quantization/qat/fake_quantizer.py
@@ -103,11 +103,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq
 
 
+# TODO: rename this, it also works for plain Int4Tensor
 class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying int4 fake quantization to a weight tensor,
-    targeting the following FBGEMM kernel:
+    targeting the following FBGEMM kernels:
         torch.ops.fbgemm.f8i4bf16_shuffled
+        torch.ops.fbgemm.bf16i4bf16_shuffled
+        torch.ops.fbgemm.bf16i4bf16_rowwise
     """
 
     def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
@@ -118,11 +121,18 @@ def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
         )
 
     def forward(self, w: torch.Tensor) -> torch.Tensor:
-        """
-        Apply int4 fake quantization to the weight tensor, using the following as a reference:
-        https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L112
+        if self.config.activation_dtype == torch.float8_e4m3fn:
+            return self._fp8_activations_forward(w)
+        elif self.config.activation_dtype == torch.bfloat16:
+            return self._bf16_activations_forward(w)
+        else:
+            raise ValueError(f"Unknown activation dtype {self.config.activation_dtype}")
 
-        Currently, we expect the activations to always be rowwise float8.
+    def _fp8_activations_forward(self, w: torch.Tensor) -> torch.Tensor:
+        """
+        Apply int4 fake quantization to the weight tensor where the input activations
+        are expected to be rowwise fp8, using the following as a reference:
+        https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L136
         """
         assert w.dim() == 2
         assert self.config.activation_dtype == torch.float8_e4m3fn
@@ -159,6 +169,28 @@ def forward(self, w: torch.Tensor) -> torch.Tensor:
         )
         return fq.to(w.dtype)
 
+    def _bf16_activations_forward(self, w: torch.Tensor) -> torch.Tensor:
+        """
+        Apply int4 fake quantization to the weight tensor where the input activations
+        are expected to be bf16, using the following as a reference:
+        https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L152
+        """
+        assert w.dim() == 2
+        assert self.config.activation_dtype == torch.bfloat16
+
+        eps = 1e-6
+        qmin, qmax = 0, 15
+        fbgemm_symmetric_qmax = 8
+        w_grouped = w.to(torch.float32).view(w.shape[0], -1, self.config.group_size)
+        max_val = torch.amax(w_grouped, dim=-1, keepdim=True)
+        min_val = torch.amin(w_grouped, dim=-1, keepdim=True)
+        scale = torch.clamp(max_val - min_val, min=eps) / qmax
+        zero_point = min_val + scale * fbgemm_symmetric_qmax
+        fq = _Round.apply((w_grouped - min_val) / scale).clamp(qmin, qmax)
+        fq = fq - fbgemm_symmetric_qmax
+        fq = fq * scale + zero_point
+        return fq.view(w.shape).to(w.dtype)
+
 
 class IntxFakeQuantizer(FakeQuantizerBase):
     """
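
For reviewers who want to exercise the new path end to end, here is a minimal sketch of the bf16 int4 QAT flow covered by test_quantize_api_int4 above. It is illustrative only and not part of the patch: the toy model, the shapes, the use of .cuda(), and the torchao.quantization.qat import location for QATConfig are assumptions based on the tests in this diff.

# Illustrative sketch only -- not part of this patch.
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.qat import QATConfig  # assumed import location
from torchao.quantization.quantize_.workflows import Int4PackingFormat

# Hypothetical bf16 model with a single linear layer
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).cuda()

base_config = Int4WeightOnlyConfig(
    version=2, int4_packing_format=Int4PackingFormat.PRESHUFFLED
)

# Prepare: insert fake quantization into the linear weights
# (handled by Int4WeightPreshuffledFakeQuantizer with bf16 activations)
quantize_(model, QATConfig(base_config, step="prepare"))

# ... fine-tune `model` here ...

# Convert: replace fake quantization with real int4 quantized weights
quantize_(model, QATConfig(base_config, step="convert"))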