Improve QAT int4 weight-only numerics #2986
Changes from all commits
```diff
@@ -103,11 +103,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq
 
 
+# TODO: rename this, it also works for plain Int4Tensor
 class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying int4 fake quantization to a weight tensor,
-    targeting the following FBGEMM kernel:
+    targeting the following FBGEMM kernels:
         torch.ops.fbgemm.f8i4bf16_shuffled
+        torch.ops.fbgemm.bf16i4bf16_shuffled
+        torch.ops.fbgemm.bf16i4bf16_rowwise
     """
 
     def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
```
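For context, a minimal usage sketch of this fake quantizer (not part of the diff). The import path and the config constructor arguments (`group_size`, `activation_dtype`) are assumptions; check the actual torchao module for the exact signatures.

```python
# Hypothetical usage sketch: apply int4 fake quantization to a bf16 weight.
# Import paths and the config constructor signature are assumed, not taken
# from this diff.
import torch
from torchao.quantization.qat import (  # assumed import path
    Int4WeightPreshuffledFakeQuantizeConfig,
    Int4WeightPreshuffledFakeQuantizer,
)

config = Int4WeightPreshuffledFakeQuantizeConfig(
    group_size=128,
    activation_dtype=torch.bfloat16,  # or torch.float8_e4m3fn for the fp8 path
)
fake_quantizer = Int4WeightPreshuffledFakeQuantizer(config)

weight = torch.randn(256, 512, dtype=torch.bfloat16)
fake_quantized = fake_quantizer(weight)  # same shape and dtype as `weight`
```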
```diff
@@ -118,11 +121,18 @@ def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
         )
 
     def forward(self, w: torch.Tensor) -> torch.Tensor:
-        """
-        Apply int4 fake quantization to the weight tensor, using the following as a reference:
-        https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L112
-
-        Currently, we expect the activations to always be rowwise float8.
-        """
+        if self.config.activation_dtype == torch.float8_e4m3fn:
+            return self._fp8_activations_forward(w)
+        elif self.config.activation_dtype == torch.bfloat16:
+            return self._bf16_activations_forward(w)
+        else:
+            raise ValueError(f"Unknown activation dtype {self.config.activation_dtype}")
+
+    def _fp8_activations_forward(self, w: torch.Tensor) -> torch.Tensor:
+        """
+        Apply int4 fake quantization to the weight tensor where the input activations
+        are expected to be rowwise fp8, using the following as a reference:
+        https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L136
+        """
         assert w.dim() == 2
         assert self.config.activation_dtype == torch.float8_e4m3fn
```
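Since the fp8 path's docstring says the input activations are expected to be rowwise fp8, here is a minimal sketch of what rowwise float8_e4m3fn quantization of an activation tensor looks like: one scale per row, chosen so the row's max magnitude maps to the fp8 maximum. This is a generic plain-PyTorch illustration, not the FBGEMM implementation and not part of this PR.

```python
# Illustrative only: rowwise fp8 (e4m3) quantization of an activation tensor.
import torch

def quantize_fp8_rowwise(x: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    # one scale per row so that the row's max magnitude maps to fp8_max
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)
    return x_fp8, scale  # dequantize with x_fp8.to(x.dtype) * scale

x = torch.randn(4, 64, dtype=torch.bfloat16)
x_fp8, scale = quantize_fp8_rowwise(x)
x_approx = x_fp8.to(torch.bfloat16) * scale
```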
```diff
@@ -159,6 +169,28 @@ def forward(self, w: torch.Tensor) -> torch.Tensor:
         )
         return fq.to(w.dtype)
 
+    def _bf16_activations_forward(self, w: torch.Tensor) -> torch.Tensor:
+        """
+        Apply int4 fake quantization to the weight tensor where the input activations
+        are expected to be bf16, using the following as a reference:
+        https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L152
+        """
+        assert w.dim() == 2
+        assert self.config.activation_dtype == torch.bfloat16
+
+        eps = 1e-6
+        qmin, qmax = 0, 15
+        fbgemm_symmetric_qmax = 8
+        w_grouped = w.to(torch.float32).view(w.shape[0], -1, self.config.group_size)
+        max_val = torch.amax(w_grouped, dim=-1, keepdim=True)
+        min_val = torch.amin(w_grouped, dim=-1, keepdim=True)
+        scale = torch.clamp(max_val - min_val, min=eps) / qmax
+        zero_point = min_val + scale * fbgemm_symmetric_qmax
+
+        fq = _Round.apply((w_grouped - min_val) / scale).clamp(qmin, qmax)
+        fq = fq - fbgemm_symmetric_qmax
+        fq = fq * scale + zero_point
+        return fq.view(w.shape).to(w.dtype)
 
 
 class IntxFakeQuantizer(FakeQuantizerBase):
     """
```

Review comment on lines +183 to +188:

> why don't we call […]
>
> I guess we could ask fbgemm to add another function to just compute scale/zero_point so we can call it here in the future

Reply:

> yeah and also they cast the quantized values to int8, which we don't want to do here
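Since the PR computes scale and zero_point inline rather than calling into fbgemm, here is a standalone sketch of the same per-group asymmetric int4 numerics in plain PyTorch, useful for inspecting them outside the module. `torch.round` stands in for the `_Round` straight-through-estimator autograd function used in the diff, so this version is for inspecting numerics only, not for training.

```python
# Standalone sketch (not part of the PR) of the per-group asymmetric int4
# fake-quantization numerics implemented by _bf16_activations_forward above.
import torch

def int4_weight_fake_quant_reference(w: torch.Tensor, group_size: int) -> torch.Tensor:
    eps = 1e-6
    qmin, qmax = 0, 15          # unsigned 4-bit range
    fbgemm_symmetric_qmax = 8   # shift to the fbgemm-style "symmetric" range [-8, 7]

    w_grouped = w.to(torch.float32).view(w.shape[0], -1, group_size)
    max_val = torch.amax(w_grouped, dim=-1, keepdim=True)
    min_val = torch.amin(w_grouped, dim=-1, keepdim=True)
    scale = torch.clamp(max_val - min_val, min=eps) / qmax
    zero_point = min_val + scale * fbgemm_symmetric_qmax

    # quantize to [0, 15], shift to [-8, 7], then dequantize back to float
    q = torch.round((w_grouped - min_val) / scale).clamp(qmin, qmax)
    q = q - fbgemm_symmetric_qmax
    fq = q * scale + zero_point
    return fq.view(w.shape).to(w.dtype)

w = torch.randn(16, 256, dtype=torch.bfloat16)
fq = int4_weight_fake_quant_reference(w, group_size=128)
print((w - fq).abs().max())  # error roughly bounded by scale / 2 per group
```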
Review comments:

> I feel it's fine for QAT to only support version 2

> although you may want to cover more int4 packing formats such as TILE_PACKED_TO_4D, the previous tinygemm layout

Reply:

> yeah I think we can drop version 1, but it's BC-breaking, so we can do it separately