diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 115e6784fb..35870a5e6b 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -733,6 +733,48 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
 
+    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
+    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
+        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8_non_decomposed
+        dequantize_affine_float8 = (
+            torch.ops.torchao.dequantize_affine_float8_non_decomposed
+        )
+        input = torch.randn(10, 10)
+        with torch.no_grad():
+            torch._dynamo.reset()
+            expected_scale = torch.tensor(2.0)
+            expected_quantized = quantize_affine_float8(
+                input,
+                expected_scale,
+                float8_dtype=float8_dtype,
+            )
+            expected_dequantized = dequantize_affine_float8(
+                expected_quantized,
+                expected_scale,
+                output_dtype=hp_dtype,
+            )
+            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
+                torch.compile(quantize_affine_float8),
+                input,
+                expected_scale,
+                float8_dtype=float8_dtype,
+            )
+            torch.testing.FileCheck().check(f"{quantize_affine_float8}.default").run(
+                code_q
+            )
+            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
+                torch.compile(dequantize_affine_float8),
+                test_q,
+                expected_scale,
+                hp_dtype,
+            )
+            torch.testing.FileCheck().check(f"{dequantize_affine_float8}.default").run(
+                code_dq
+            )
+            torch.testing.assert_close(expected_quantized, test_q)
+            torch.testing.assert_close(expected_dequantized, test_dq)
+
     @torch.no_grad()
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 54ad472219..cdfbc00c3a 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -2310,8 +2310,6 @@ def _quantize_affine_float8(
     return _RoundToFloat8.apply(tensor_clamped, float8_dtype)
 
 
-# TODO: don't register as custom op?
-@_register_custom_op(quant_lib, False)
 def _dequantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
@@ -2329,7 +2327,48 @@ def _dequantize_affine_float8(
     return hp_tensor.to(output_dtype)
 
 
-@_register_meta_op(quant_lib, "dequantize_affine_float8")
+@_register_custom_op(quant_lib, False)
+def _quantize_affine_float8_non_decomposed(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+    """
+    return _quantize_affine_float8(
+        tensor=tensor,
+        scale=scale,
+        float8_dtype=float8_dtype,
+    )
+
+
+@_register_meta_op(quant_lib, "quantize_affine_float8_non_decomposed")
+def _quantize_affine_float8_meta(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> torch.Tensor:
+    return torch.empty_like(tensor, dtype=float8_dtype)
+
+
+@_register_custom_op(quant_lib, False)
+def _dequantize_affine_float8_non_decomposed(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to high precision tensor.
+    """
+    return _dequantize_affine_float8(
+        tensor=tensor,
+        scale=scale,
+        output_dtype=output_dtype,
+    )
+
+
+@_register_meta_op(quant_lib, "dequantize_affine_float8_non_decomposed")
 def _dequantize_affine_float8_meta(
     tensor: torch.Tensor,
     scale: torch.Tensor,