81 changes: 75 additions & 6 deletions test/quantization/test_qat.py
@@ -83,6 +83,7 @@
dequantize_affine,
quantize_affine,
)
from torchao.quantization.quantize_.workflows import Int4PackingFormat
from torchao.quantization.unified import (
TwoStepQuantizer,
)
@@ -1942,15 +1943,18 @@ def test_quantize_api_fp8_int4(self):
)
@unittest.skipIf(is_fbcode(), "cutlass cannot initialize")
@parametrize("version", [1, 2])
def test_quantize_api_int4(self, version: int):
@parametrize(
"packing_format", [Int4PackingFormat.PLAIN, Int4PackingFormat.PRESHUFFLED]
)
def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat):
"""
Test the following:
quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="prepare"))
quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="convert"))
"""
self._test_quantize_api_against_ptq(
Int4WeightOnlyConfig(version=version),
target_prepare_sqnr=12,
Int4WeightOnlyConfig(version=version, int4_packing_format=packing_format),
Contributor: I feel it's fine for QAT to only support version 2.
Contributor: Although you may want to cover more int4 packing formats, such as TILE_PACKED_TO_4D, the previous tinygemm layout.
Contributor (author): Yeah, I think we can drop version 1, but it's BC-breaking, so we can do it separately.

target_prepare_sqnr=45 if version == 2 else 12,
target_convert_sqnr=float("inf"),
)
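The docstring above exercises torchao's two-step QAT flow (prepare, then convert). As a point of reference, here is a minimal sketch of that flow outside the test harness. The toy model and the training step are placeholders; the quantize_, QATConfig, Int4WeightOnlyConfig, and Int4PackingFormat names are the ones used in this diff.

```python
import torch

from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.qat import QATConfig
from torchao.quantization.quantize_.workflows import Int4PackingFormat

# Placeholder model: any nn.Module containing Linear layers is handled the same way
model = torch.nn.Sequential(torch.nn.Linear(256, 512)).to(torch.bfloat16).cuda()

base_config = Int4WeightOnlyConfig(
    version=2,
    int4_packing_format=Int4PackingFormat.PRESHUFFLED,
)

# Step 1: "prepare" swaps in fake-quantized modules so training sees int4 rounding error
quantize_(model, QATConfig(base_config, step="prepare"))
# ... fine-tune the model here ...

# Step 2: "convert" replaces the fake quantization with real int4 quantized weights
quantize_(model, QATConfig(base_config, step="convert"))
```

The test then compares both steps against straight PTQ and checks the SQNR targets shown above (45 dB at prepare for version 2, lossless at convert).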

@@ -2004,9 +2008,9 @@ def test_infer_int4_weight_only_config(self):
base_config = Int4WeightOnlyConfig(version=2)
(act_config, weight_config) = _infer_fake_quantize_configs(base_config)
self.assertIsNone(act_config)
self.assertEqual(weight_config.dtype, torch.int4)
self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
self.assertEqual(weight_config.group_size, 128)
self.assertTrue(weight_config.is_symmetric)
self.assertEqual(weight_config.activation_dtype, torch.bfloat16)

@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
def test_quantize_api_nvfp4(self):
@@ -2094,7 +2098,7 @@ def test_fbgemm_fp8_primitives(self):
not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
)
@unittest.skipIf(is_fbcode(), "triton compilation error")
def test_fbgemm_int4_preshuffled_primitives(self):
def test_fbgemm_fp8_int4_preshuffled_primitives(self):
"""
Compare numerics between:
(1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
@@ -2171,6 +2175,71 @@ def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
)
self.assertGreater(sqnr_q1_q3_preshuffle, 32)

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
)
@unittest.skipIf(is_fbcode(), "triton compilation error")
def test_fbgemm_int4_weight_only_primitives(self):
"""
Compare numerics between:
(1) fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize_zp
(2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
"""
from fbgemm_gpu.experimental.gen_ai.quantize import (
int4_row_quantize_zp,
pack_int4,
quantize_int4_preshuffle,
)

group_size = 128
x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
x2 = copy.deepcopy(x1)
x3 = copy.deepcopy(x1)

# (1) Just call `quantize_int4_preshuffle` with dtype="bf16"
(q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="bf16")

# (2) Call `int4_row_quantize_zp`, which should be the same as (1)
# but without the packing and shuffling
(q2, scale2, _) = int4_row_quantize_zp(x2, group_size)

# (3) Reference implementation for QAT without the dequantize
eps = 1e-6
qmin, qmax = 0, 15
fbgemm_symmetric_qmax = 8
w_grouped = x3.to(torch.float32).view(x3.shape[0], -1, group_size)
max_val = torch.amax(w_grouped, dim=-1, keepdim=True)
min_val = torch.amin(w_grouped, dim=-1, keepdim=True)
scale3 = torch.clamp(max_val - min_val, min=eps) / qmax
q3 = (w_grouped.sub(min_val).div(scale3)).round().clamp_(qmin, qmax)
q3 = q3 - fbgemm_symmetric_qmax
q3 = q3.view(x3.shape)

def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
t = pack_int4(t.to(torch.int8))
return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.bfloat16))[0]

# First, sanity check that shuffle_and_pack(q2) == q1
torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0)

# Now check q2 vs q3 with and without shuffle
torch.testing.assert_close(q2.to(torch.float32), q3, atol=0, rtol=0)
torch.testing.assert_close(
shuffle_and_pack(q2, scale2).to(torch.float32),
shuffle_and_pack(q3, scale3).to(torch.float32),
atol=0,
rtol=0,
)

# Now check shuffle_and_pack(q3) vs q1
torch.testing.assert_close(
q1.to(torch.float32),
shuffle_and_pack(q3, scale3).to(torch.float32),
atol=0,
rtol=0,
)
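A note on the reference implementation in (3): it quantizes asymmetrically into [0, 15] and then shifts by fbgemm_symmetric_qmax = 8, while the fake quantizer introduced later in this PR folds the same shift into a zero point (zero_point = min_val + scale * 8). A small standalone sketch, not part of the PR, showing that the two formulations dequantize to the same values:

```python
import torch

torch.manual_seed(0)
group_size = 128
fbgemm_symmetric_qmax = 8

w = torch.randn(4, group_size, dtype=torch.float32)
max_val = w.amax(dim=-1, keepdim=True)
min_val = w.amin(dim=-1, keepdim=True)
scale = torch.clamp(max_val - min_val, min=1e-6) / 15

# Asymmetric formulation used in the test above: quantize into [0, 15], then shift by 8
q_u = ((w - min_val) / scale).round().clamp(0, 15)
q = q_u - fbgemm_symmetric_qmax

# Zero-point formulation used in _bf16_activations_forward below
zero_point = min_val + scale * fbgemm_symmetric_qmax

# (q_u - 8) * scale + (min_val + 8 * scale) == q_u * scale + min_val, i.e. the same dequant
torch.testing.assert_close(q * scale + zero_point, q_u * scale + min_val)
```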


instantiate_parametrized_tests(TestQAT)

25 changes: 19 additions & 6 deletions torchao/quantization/qat/fake_quantize_config.py
@@ -31,6 +31,7 @@
TorchAODType,
ZeroPointDomain,
)
from torchao.quantization.quantize_.workflows import Int4PackingFormat
from torchao.utils import _is_float8_type

from .utils import _log_deprecation_warning
@@ -77,11 +78,14 @@ def __post_init__(self):
)


# TODO: rename this config, it actually works for both plain and preshuffled
@dataclass
class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
"""
Config for int4 weight fake quantization that targets the numerics in the following kernels:
torch.ops.fbgemm.f8i4bf16_shuffled
torch.ops.fbgemm.bf16i4bf16_shuffled
torch.ops.fbgemm.bf16i4bf16_rowwise

Currently this only supports float8 input activations. It is expected to be used in conjunction with
:class:`~torchao.quantization.Float8DynamicActivationInt4WeightConfig`. In the future, we may extend
@@ -92,8 +96,10 @@ class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
activation_dtype: torch.dtype = e4m3_dtype

def __post_init__(self):
if self.activation_dtype != e4m3_dtype:
raise ValueError(f"Only {e4m3_dtype} activation is supported currently")
if self.activation_dtype not in [e4m3_dtype, torch.bfloat16]:
raise ValueError(
f"Only {e4m3_dtype} or torch.bfloat16 activation are supported"
)


@dataclass
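For context on how the relaxed __post_init__ check above gets exercised, here is a hedged sketch (not from the PR) that constructs the config for bf16 activations and runs the matching fake quantizer on a weight tensor. The import paths follow the two files touched in this diff; the tensor shapes are arbitrary.

```python
import torch

from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightPreshuffledFakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import Int4WeightPreshuffledFakeQuantizer

# bf16 activations now pass the __post_init__ check; e4m3 fp8 is still accepted
config = Int4WeightPreshuffledFakeQuantizeConfig(
    group_size=128,
    activation_dtype=torch.bfloat16,
)
fake_quantizer = Int4WeightPreshuffledFakeQuantizer(config)

w = torch.randn(256, 512, dtype=torch.bfloat16)
w_fq = fake_quantizer(w)  # same shape and dtype as w, with int4 fake quantization applied
```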
@@ -379,10 +385,17 @@ def _infer_fake_quantize_configs(
elif isinstance(base_config, Int4WeightOnlyConfig):
act_config = None
if base_config.version == 2:
weight_config = IntxFakeQuantizeConfig(
dtype=torch.int4,
group_size=base_config.group_size,
is_symmetric=True,
supported_packing_formats = [
Int4PackingFormat.PLAIN,
Int4PackingFormat.PRESHUFFLED,
]
if base_config.int4_packing_format not in supported_packing_formats:
raise ValueError(
f"Packing format must be one of {supported_packing_formats}"
)
weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
group_size=128,
activation_dtype=torch.bfloat16,
)
elif base_config.version == 1:
# For BC
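A hedged illustration of the new packing-format guard in _infer_fake_quantize_configs. TILE_PACKED_TO_4D is the tinygemm layout mentioned in the review thread above; treating it as a constructible Int4PackingFormat member here is an assumption.

```python
import torch

from torchao.quantization import Int4WeightOnlyConfig
from torchao.quantization.qat.fake_quantize_config import _infer_fake_quantize_configs
from torchao.quantization.quantize_.workflows import Int4PackingFormat

# Supported formats map to Int4WeightPreshuffledFakeQuantizeConfig with bf16 activations
base = Int4WeightOnlyConfig(version=2, int4_packing_format=Int4PackingFormat.PRESHUFFLED)
act_config, weight_config = _infer_fake_quantize_configs(base)
assert act_config is None
assert weight_config.group_size == 128
assert weight_config.activation_dtype == torch.bfloat16

# Other packing formats now raise (TILE_PACKED_TO_4D is assumed to be a valid member)
base = Int4WeightOnlyConfig(version=2, int4_packing_format=Int4PackingFormat.TILE_PACKED_TO_4D)
try:
    _infer_fake_quantize_configs(base)
except ValueError as e:
    print(e)  # "Packing format must be one of [...]"
```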
42 changes: 37 additions & 5 deletions torchao/quantization/qat/fake_quantizer.py
@@ -103,11 +103,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return dq


# TODO: rename this, it also works for plain Int4Tensor
class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
"""
Generic module for applying int4 fake quantization to a weight tensor,
targeting the following FBGEMM kernel:
targeting the following FBGEMM kernels:
torch.ops.fbgemm.f8i4bf16_shuffled
torch.ops.fbgemm.bf16i4bf16_shuffled
torch.ops.fbgemm.bf16i4bf16_rowwise
"""

def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
@@ -118,11 +121,18 @@ def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
)

def forward(self, w: torch.Tensor) -> torch.Tensor:
"""
Apply int4 fake quantization to the weight tensor, using the following as a reference:
https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L112
if self.config.activation_dtype == torch.float8_e4m3fn:
return self._fp8_activations_forward(w)
elif self.config.activation_dtype == torch.bfloat16:
return self._bf16_activations_forward(w)
else:
raise ValueError(f"Unknown activation dtype {self.config.activation_dtype}")

Currently, we expect the activations to always be rowwise float8.
def _fp8_activations_forward(self, w: torch.Tensor) -> torch.Tensor:
"""
Apply int4 fake quantization to the weight tensor where the input activations
are expected to be rowwise fp8, using the following as a reference:
https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L136
"""
assert w.dim() == 2
assert self.config.activation_dtype == torch.float8_e4m3fn
@@ -159,6 +169,28 @@ def forward(self, w: torch.Tensor) -> torch.Tensor:
)
return fq.to(w.dtype)

def _bf16_activations_forward(self, w: torch.Tensor) -> torch.Tensor:
"""
Apply int4 fake quantization to the weight tensor where the input activations
are expected to be bf16, using the following as a reference:
https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L152
"""
assert w.dim() == 2
assert self.config.activation_dtype == torch.bfloat16

eps = 1e-6
qmin, qmax = 0, 15
fbgemm_symmetric_qmax = 8
w_grouped = w.to(torch.float32).view(w.shape[0], -1, self.config.group_size)
max_val = torch.amax(w_grouped, dim=-1, keepdim=True)
min_val = torch.amin(w_grouped, dim=-1, keepdim=True)
scale = torch.clamp(max_val - min_val, min=eps) / qmax
zero_point = min_val + scale * fbgemm_symmetric_qmax
Comment on lines +183 to +188
jerryzh168 (Contributor), Sep 11, 2025: Why don't we call int4_row_quantize_zp and get the scale/zero_point from there? Is it because of performance concerns? I guess we could ask fbgemm to add another function to just compute scale/zero_point so we can call it here in the future.
Contributor (author): Yeah, and also they cast the quantized values to int8, which we don't want to do here.

fq = _Round.apply((w_grouped - min_val) / scale).clamp(qmin, qmax)
fq = fq - fbgemm_symmetric_qmax
fq = fq * scale + zero_point
return fq.view(w.shape).to(w.dtype)


class IntxFakeQuantizer(FakeQuantizerBase):
"""