     dequantize_affine,
     quantize_affine,
 )
+from torchao.quantization.quantize_.workflows import Int4PackingFormat
 from torchao.quantization.unified import (
     TwoStepQuantizer,
 )
@@ -1942,15 +1943,18 @@ def test_quantize_api_fp8_int4(self):
     )
     @unittest.skipIf(is_fbcode(), "cutlass cannot initialize")
     @parametrize("version", [1, 2])
-    def test_quantize_api_int4(self, version: int):
+    @parametrize(
+        "packing_format", [Int4PackingFormat.PLAIN, Int4PackingFormat.PRESHUFFLED]
+    )
+    def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat):
         """
         Test the following:
             quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="prepare"))
             quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="convert"))
         """
         self._test_quantize_api_against_ptq(
-            Int4WeightOnlyConfig(version=version),
-            target_prepare_sqnr=12,
+            Int4WeightOnlyConfig(version=version, int4_packing_format=packing_format),
+            target_prepare_sqnr=45 if version == 2 else 12,
             target_convert_sqnr=float("inf"),
         )

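For context, a minimal sketch of the prepare/convert flow that the docstring above describes, now with the new packing format argument. The config names, the int4_packing_format keyword, and the QATConfig(..., step=...) usage are taken from this diff; the toy model, shapes, and the import path for QATConfig are assumptions and may need adjusting to your torchao version.

import torch

from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.qat import QATConfig  # import path assumed
from torchao.quantization.quantize_.workflows import Int4PackingFormat

# Toy bf16 model; the int4 packing kernels generally expect bf16 weights on CUDA.
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).cuda()

base_config = Int4WeightOnlyConfig(
    version=2, int4_packing_format=Int4PackingFormat.PRESHUFFLED
)

# Step 1: swap weights for fake-quantized versions used during QAT fine-tuning.
quantize_(model, QATConfig(base_config, step="prepare"))
# ... fine-tune the model here ...

# Step 2: replace fake quantization with real int4 quantization for inference.
quantize_(model, QATConfig(base_config, step="convert"))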
@@ -2004,9 +2008,9 @@ def test_infer_int4_weight_only_config(self):
         base_config = Int4WeightOnlyConfig(version=2)
         (act_config, weight_config) = _infer_fake_quantize_configs(base_config)
         self.assertIsNone(act_config)
-        self.assertEqual(weight_config.dtype, torch.int4)
+        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
-        self.assertTrue(weight_config.is_symmetric)
+        self.assertEqual(weight_config.activation_dtype, torch.bfloat16)

     @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
     def test_quantize_api_nvfp4(self):
@@ -2094,7 +2098,7 @@ def test_fbgemm_fp8_primitives(self):
         not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
     )
     @unittest.skipIf(is_fbcode(), "triton compilation error")
-    def test_fbgemm_int4_preshuffled_primitives(self):
+    def test_fbgemm_fp8_int4_preshuffled_primitives(self):
         """
         Compare numerics between:
             (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
@@ -2171,6 +2175,71 @@ def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         )
         self.assertGreater(sqnr_q1_q3_preshuffle, 32)

+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    @unittest.skipIf(is_fbcode(), "triton compilation error")
+    def test_fbgemm_int4_weight_only_primitives(self):
+        """
+        Compare numerics between:
+            (1) fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize_zp
+            (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import (
+            int4_row_quantize_zp,
+            pack_int4,
+            quantize_int4_preshuffle,
+        )
+
+        group_size = 128
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+        x3 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_int4_preshuffle` with dtype="bf16"
+        (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="bf16")
+
+        # (2) Call `int4_row_quantize_zp`, which should be the same as (1)
+        # but without the packing and shuffling
+        (q2, scale2, _) = int4_row_quantize_zp(x2, group_size)
+
+        # (3) Reference implementation for QAT without the dequantize
+        eps = 1e-6
+        qmin, qmax = 0, 15
+        fbgemm_symmetric_qmax = 8
+        w_grouped = x3.to(torch.float32).view(x3.shape[0], -1, group_size)
+        max_val = torch.amax(w_grouped, dim=-1, keepdim=True)
+        min_val = torch.amin(w_grouped, dim=-1, keepdim=True)
+        scale3 = torch.clamp(max_val - min_val, min=eps) / qmax
+        q3 = (w_grouped.sub(min_val).div(scale3)).round().clamp_(qmin, qmax)
+        q3 = q3 - fbgemm_symmetric_qmax
+        q3 = q3.view(x3.shape)
+
+        def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+            t = pack_int4(t.to(torch.int8))
+            return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.bfloat16))[0]
+
+        # First, sanity check that shuffle_and_pack(q2) == q1
+        torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0)
+
+        # Now check q2 vs q3 with and without shuffle
+        torch.testing.assert_close(q2.to(torch.float32), q3, atol=0, rtol=0)
+        torch.testing.assert_close(
+            shuffle_and_pack(q2, scale2).to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+            atol=0,
+            rtol=0,
+        )
+
+        # Now check shuffle_and_pack(q3) vs q1
+        torch.testing.assert_close(
+            q1.to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+            atol=0,
+            rtol=0,
+        )
+

 instantiate_parametrized_tests(TestQAT)

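A closing note on the new test: branch (3) is plain asymmetric per-group int4 quantization, re-centered by 8 so the codes land in [-8, 7] instead of [0, 15], matching fbgemm's zero-point convention. The sketch below reproduces just that math together with its dequantize on a small CPU tensor, with no fbgemm dependency; the names int4_row_quantize_ref and int4_row_dequantize_ref are illustrative and are not part of torchao or fbgemm.

import torch


def int4_row_quantize_ref(w: torch.Tensor, group_size: int = 128):
    """Asymmetric per-group int4 quantization mirroring reference branch (3) above."""
    eps, qmax = 1e-6, 15
    grouped = w.to(torch.float32).view(w.shape[0], -1, group_size)
    max_val = grouped.amax(dim=-1, keepdim=True)
    min_val = grouped.amin(dim=-1, keepdim=True)
    # 16 levels span [min, max]; eps guards against all-constant groups.
    scale = torch.clamp(max_val - min_val, min=eps) / qmax
    # Map to [0, 15], then shift by 8 to get the fbgemm-style centered codes.
    q = (grouped - min_val).div(scale).round().clamp_(0, qmax) - 8
    return q.view(w.shape), scale, min_val


def int4_row_dequantize_ref(q, scale, min_val, group_size: int = 128):
    """Invert the mapping: undo the -8 shift, rescale, and add back the group min."""
    q_grouped = q.view(q.shape[0], -1, group_size).to(torch.float32)
    return ((q_grouped + 8) * scale + min_val).view(q.shape)


w = torch.randn(4, 256, dtype=torch.bfloat16)
q, scale, min_val = int4_row_quantize_ref(w)
w_hat = int4_row_dequantize_ref(q, scale, min_val)

# Round-to-nearest quantization error should stay within half a step per group.
max_err = (w.to(torch.float32) - w_hat).abs().max()
assert max_err <= scale.max() / 2 + 1e-4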