
Commit f17861a

Improve QAT int4 weight-only numerics
**Summary:** Similar to #2937, this commit improves the prepare vs convert SQNR of int4 weight-only QAT from 12 to 45. This is achieved by mimicking the numerics of the target FBGEMM bf16-int4 kernel more closely. In particular, the FBGEMM kernel:

1. Performs asymmetric [0, 15] quantization first, then recenters by subtracting 8
2. Uses a smaller scale eps of 1e-6 instead of bf16's eps (0.0078125)
3. Quantizes the weights using the per-group min val instead of zero points

**Test Plan:**
```
python test/quantization/test_qat.py -k test_quantize_api_int4
python test/quantization/test_qat.py -k test_fbgemm_int4_weight_only_primitives
```
End-to-end tests TBD
1 parent 10ba659 commit f17861a
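For reference, here is a minimal standalone sketch of the bf16-int4 numerics described above (asymmetric [0, 15] quantization recentered by subtracting 8, the 1e-6 scale eps, and quantizing against the per-group min val). It mirrors the reference implementation added to `fake_quantizer.py` below; the weight shape and group size are illustrative only.

```python
import torch

# Illustrative weight and group size; the numerics below follow the commit's
# reference implementation for the FBGEMM bf16-int4 kernel.
group_size = 128
w = torch.randn(128, 256, dtype=torch.bfloat16)

eps = 1e-6                 # (2) smaller scale eps than bf16's eps (0.0078125)
qmin, qmax = 0, 15         # (1) asymmetric [0, 15] quantization range
fbgemm_symmetric_qmax = 8  # (1) recenter afterwards by subtracting 8

w_grouped = w.to(torch.float32).view(w.shape[0], -1, group_size)
max_val = torch.amax(w_grouped, dim=-1, keepdim=True)
min_val = torch.amin(w_grouped, dim=-1, keepdim=True)
scale = torch.clamp(max_val - min_val, min=eps) / qmax
zero_point = min_val + scale * fbgemm_symmetric_qmax

# (3) quantize relative to the per-group min val rather than a zero point
q = ((w_grouped - min_val) / scale).round().clamp(qmin, qmax) - fbgemm_symmetric_qmax

# fake-quantized (dequantized) weight as seen during QAT training
w_fq = (q * scale + zero_point).view(w.shape).to(w.dtype)
```

The actual `Int4WeightPreshuffledFakeQuantizer` uses `_Round.apply` (a straight-through rounding op) instead of plain `round()` so gradients can flow through the fake quantization.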

File tree

- test/quantization/test_qat.py
- torchao/quantization/qat/fake_quantize_config.py
- torchao/quantization/qat/fake_quantizer.py

3 files changed, +131 −17 lines changed

test/quantization/test_qat.py

Lines changed: 75 additions & 6 deletions
@@ -83,6 +83,7 @@
     dequantize_affine,
     quantize_affine,
 )
+from torchao.quantization.quantize_.workflows import Int4PackingFormat
 from torchao.quantization.unified import (
     TwoStepQuantizer,
 )
@@ -1942,15 +1943,18 @@ def test_quantize_api_fp8_int4(self):
         )
     @unittest.skipIf(is_fbcode(), "cutlass cannot initialize")
     @parametrize("version", [1, 2])
-    def test_quantize_api_int4(self, version: int):
+    @parametrize(
+        "packing_format", [Int4PackingFormat.PLAIN, Int4PackingFormat.PRESHUFFLED]
+    )
+    def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat):
         """
         Test the following:
             quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="prepare"))
             quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="convert"))
         """
         self._test_quantize_api_against_ptq(
-            Int4WeightOnlyConfig(version=version),
-            target_prepare_sqnr=12,
+            Int4WeightOnlyConfig(version=version, int4_packing_format=packing_format),
+            target_prepare_sqnr=45 if version == 2 else 12,
             target_convert_sqnr=float("inf"),
         )

@@ -2004,9 +2008,9 @@ def test_infer_int4_weight_only_config(self):
         base_config = Int4WeightOnlyConfig(version=2)
         (act_config, weight_config) = _infer_fake_quantize_configs(base_config)
         self.assertIsNone(act_config)
-        self.assertEqual(weight_config.dtype, torch.int4)
+        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
-        self.assertTrue(weight_config.is_symmetric)
+        self.assertEqual(weight_config.activation_dtype, torch.bfloat16)

     @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
     def test_quantize_api_nvfp4(self):
@@ -2094,7 +2098,7 @@ def test_fbgemm_fp8_primitives(self):
         not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
     )
     @unittest.skipIf(is_fbcode(), "triton compilation error")
-    def test_fbgemm_int4_preshuffled_primitives(self):
+    def test_fbgemm_fp8_int4_preshuffled_primitives(self):
         """
         Compare numerics between:
           (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
@@ -2171,6 +2175,71 @@ def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         )
         self.assertGreater(sqnr_q1_q3_preshuffle, 32)

+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    @unittest.skipIf(is_fbcode(), "triton compilation error")
+    def test_fbgemm_int4_weight_only_primitives(self):
+        """
+        Compare numerics between:
+          (1) fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize_zp
+          (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import (
+            int4_row_quantize_zp,
+            pack_int4,
+            quantize_int4_preshuffle,
+        )
+
+        group_size = 128
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+        x3 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_int4_preshuffle` with dtype="bf16"
+        (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="bf16")
+
+        # (2) Call `int4_row_quantize_zp`, which should be the same as (1)
+        #     but without the packing and shuffling
+        (q2, scale2, _) = int4_row_quantize_zp(x2, group_size)
+
+        # (3) Reference implementation for QAT without the dequantize
+        eps = 1e-6
+        qmin, qmax = 0, 15
+        fbgemm_symmetric_qmax = 8
+        w_grouped = x3.to(torch.float32).view(x3.shape[0], -1, group_size)
+        max_val = torch.amax(w_grouped, dim=-1, keepdim=True)
+        min_val = torch.amin(w_grouped, dim=-1, keepdim=True)
+        scale3 = torch.clamp(max_val - min_val, min=eps) / qmax
+        q3 = (w_grouped.sub(min_val).div(scale3)).round().clamp_(qmin, qmax)
+        q3 = q3 - fbgemm_symmetric_qmax
+        q3 = q3.view(x3.shape)
+
+        def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+            t = pack_int4(t.to(torch.int8))
+            return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.bfloat16))[0]
+
+        # First, sanity check that shuffle_and_pack(q2) == q1
+        torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0)
+
+        # Now check q2 vs q3 with and without shuffle
+        torch.testing.assert_close(q2.to(torch.float32), q3, atol=0, rtol=0)
+        torch.testing.assert_close(
+            shuffle_and_pack(q2, scale2).to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+            atol=0,
+            rtol=0,
+        )
+
+        # Now check shuffle_and_pack(q3) vs q1
+        torch.testing.assert_close(
+            q1.to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+            atol=0,
+            rtol=0,
+        )
+

 instantiate_parametrized_tests(TestQAT)

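As a rough illustration of what the `target_prepare_sqnr` / `target_convert_sqnr` thresholds above measure, SQNR between two tensors can be computed with a simple helper like the one below (a hypothetical standalone function, not torchao's own utility):

```python
import torch

def sqnr_db(ref: torch.Tensor, other: torch.Tensor) -> float:
    """Signal-to-quantization-noise ratio in dB; higher means closer numerics."""
    noise = (ref - other).float()
    return (20 * torch.log10(ref.float().norm() / noise.norm())).item()

# e.g. compare the prepared (fake-quantized) model's output against the
# converted (actually quantized) model's output on the same input
y_prepare = torch.randn(8, 16)
y_convert = y_prepare + 1e-3 * torch.randn(8, 16)  # stand-in for a small numerics gap
print(sqnr_db(y_convert, y_prepare))
```

If the two outputs match exactly the ratio is infinite, which is why the convert step is checked against `target_convert_sqnr=float("inf")`.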

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 19 additions & 6 deletions
@@ -31,6 +31,7 @@
     TorchAODType,
     ZeroPointDomain,
 )
+from torchao.quantization.quantize_.workflows import Int4PackingFormat
 from torchao.utils import _is_float8_type

 from .utils import _log_deprecation_warning
@@ -77,11 +78,14 @@ def __post_init__(self):
             )


+# TODO: rename this config, it actually works for both plain and preshuffled
 @dataclass
 class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
     Config for int4 weight fake quantization that targets the numerics in the following preshuffled kernel:
         torch.ops.fbgemm.f8i4bf16_shuffled
+        torch.ops.fbgemm.bf16i4bf16_shuffled
+        torch.ops.fbgemm.bf16i4bf16_rowwise

     Currently this only supports float8 input activations. It is expected to be used in conjunction with
     :class:`~torchao.quantization.Float8DynamicActivationInt4WeightConfig`. In the future, we may extend
@@ -92,8 +96,10 @@ class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
     activation_dtype: torch.dtype = e4m3_dtype

     def __post_init__(self):
-        if self.activation_dtype != e4m3_dtype:
-            raise ValueError(f"Only {e4m3_dtype} activation is supported currently")
+        if self.activation_dtype not in [e4m3_dtype, torch.bfloat16]:
+            raise ValueError(
+                f"Only {e4m3_dtype} or torch.bfloat16 activation are supported"
+            )


 @dataclass
@@ -379,10 +385,17 @@ def _infer_fake_quantize_configs(
     elif isinstance(base_config, Int4WeightOnlyConfig):
         act_config = None
         if base_config.version == 2:
-            weight_config = IntxFakeQuantizeConfig(
-                dtype=torch.int4,
-                group_size=base_config.group_size,
-                is_symmetric=True,
+            supported_packing_formats = [
+                Int4PackingFormat.PLAIN,
+                Int4PackingFormat.PRESHUFFLED,
+            ]
+            if base_config.int4_packing_format not in supported_packing_formats:
+                raise ValueError(
+                    f"Packing format must be one of {supported_packing_formats}"
+                )
+            weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
+                group_size=128,
+                activation_dtype=torch.bfloat16,
             )
         elif base_config.version == 1:
             # For BC
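For illustration (not part of the diff; the import path is assumed from this file's location): with these changes the inferred weight config for a version-2 `Int4WeightOnlyConfig` is an `Int4WeightPreshuffledFakeQuantizeConfig` with bf16 activations, and `__post_init__` now accepts bf16 alongside fp8 while still rejecting other dtypes:

```python
import torch
from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightPreshuffledFakeQuantizeConfig,
)

# bf16 activations are now accepted in addition to fp8 (e4m3)
weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
    group_size=128,
    activation_dtype=torch.bfloat16,
)

# any other activation dtype is rejected in __post_init__
try:
    Int4WeightPreshuffledFakeQuantizeConfig(
        group_size=128,
        activation_dtype=torch.float16,
    )
except ValueError as e:
    print(e)
```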

torchao/quantization/qat/fake_quantizer.py

Lines changed: 37 additions & 5 deletions
@@ -103,11 +103,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq


+# TODO: rename this, it also works for plain Int4Tensor
 class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying int4 fake quantization to a weight tensor,
-    targeting the following FBGEMM kernel:
+    targeting the following FBGEMM kernels:
         torch.ops.fbgemm.f8i4bf16_shuffled
+        torch.ops.fbgemm.bf16i4bf16_shuffled
+        torch.ops.fbgemm.bf16i4bf16_rowwise
     """

     def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
@@ -118,11 +121,18 @@ def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
         )

     def forward(self, w: torch.Tensor) -> torch.Tensor:
-        """
-        Apply int4 fake quantization to the weight tensor, using the following as a reference:
-        https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L112
+        if self.config.activation_dtype == torch.float8_e4m3fn:
+            return self._fp8_activations_forward(w)
+        elif self.config.activation_dtype == torch.bfloat16:
+            return self._bf16_activations_forward(w)
+        else:
+            raise ValueError(f"Unknown activation dtype {self.config.activation_dtype}")

-        Currently, we expect the activations to always be rowwise float8.
+    def _fp8_activations_forward(self, w: torch.Tensor) -> torch.Tensor:
+        """
+        Apply int4 fake quantization to the weight tensor where the input activations
+        are expected to be rowwise fp8, using the following as a reference:
+        https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L136
         """
         assert w.dim() == 2
         assert self.config.activation_dtype == torch.float8_e4m3fn
@@ -159,6 +169,28 @@ def forward(self, w: torch.Tensor) -> torch.Tensor:
         )
         return fq.to(w.dtype)

+    def _bf16_activations_forward(self, w: torch.Tensor) -> torch.Tensor:
+        """
+        Apply int4 fake quantization to the weight tensor where the input activations
+        are expected to be bf16, using the following as a reference:
+        https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L152
+        """
+        assert w.dim() == 2
+        assert self.config.activation_dtype == torch.bfloat16
+
+        eps = 1e-6
+        qmin, qmax = 0, 15
+        fbgemm_symmetric_qmax = 8
+        w_grouped = w.to(torch.float32).view(w.shape[0], -1, self.config.group_size)
+        max_val = torch.amax(w_grouped, dim=-1, keepdim=True)
+        min_val = torch.amin(w_grouped, dim=-1, keepdim=True)
+        scale = torch.clamp(max_val - min_val, min=eps) / qmax
+        zero_point = min_val + scale * fbgemm_symmetric_qmax
+        fq = _Round.apply((w_grouped - min_val) / scale).clamp(qmin, qmax)
+        fq = fq - fbgemm_symmetric_qmax
+        fq = fq * scale + zero_point
+        return fq.view(w.shape).to(w.dtype)
+

 class IntxFakeQuantizer(FakeQuantizerBase):
     """

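Finally, a sketch of the end-to-end prepare/convert flow these changes target, following the calls shown in the `test_quantize_api_int4` docstring (the toy model, dtypes, and exact import paths are assumptions, not part of this diff):

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.qat import QATConfig

# toy bf16 model on CUDA, since the target FBGEMM kernels are GPU kernels
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).cuda()
base_config = Int4WeightOnlyConfig(version=2)

# prepare: insert fake quantizers that mimic the FBGEMM bf16-int4 numerics
quantize_(model, QATConfig(base_config, step="prepare"))

# ... QAT fine-tuning loop would run here ...

# convert: swap the fake quantization for real int4 quantized weights
quantize_(model, QATConfig(base_config, step="convert"))
```

With `version=2`, the prepare step now routes through `Int4WeightPreshuffledFakeQuantizer`'s bf16 path above, which is what lifts the prepare vs convert SQNR from 12 to 45.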
0 commit comments