From c2ffbac1e6cb0eb4803c85b429a65a4b353fda0e Mon Sep 17 00:00:00 2001
From: Fynn Schmitt-Ulms
Date: Tue, 22 Jul 2025 20:10:16 +0000
Subject: [PATCH 1/2] Speed up nvfp4 pack/unpack w/ torch.compile

Signed-off-by: Fynn Schmitt-Ulms
---
 .../compressors/quantized_compressors/nvfp4_quantized.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
index 5f348e91..167f5f95 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
@@ -105,6 +105,7 @@ def decompress_weight(
         return decompressed_weight
 
 
+@torch.compile(fullgraph=True)
 def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     """
     Packs a tensor with values in the fp4 range into uint8.
@@ -127,12 +128,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
 
     # Find closest valid FP4 value index for each element
     abs_x = torch.abs(x)
-    abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
-    for i, val in enumerate(kE2M1):
-        abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)
+    abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1)  # [m, n, 8]
+    abs_indices = torch.argmin(abs_diff_x, dim=-1)  # [m, n]
 
     # Apply sign bit (bit 3) to get final 4-bit representation
-    indices = abs_indices + (torch.signbit(x) << 3).to(torch.long)
+    indices = abs_indices + (torch.signbit(x).to(torch.long) << 3)
 
     # Reshape to prepare for packing pairs of values
     indices = indices.reshape(-1)
@@ -155,6 +155,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
 )
 
 # reference: https://github.com/vllm-project/vllm/pull/16362
+@torch.compile(fullgraph=True)
 def unpack_fp4_from_uint8(
     a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
 ) -> torch.Tensor:

From 6bb69c172c8bf2e16ea6ddfa29c6b480fe2e8b3d Mon Sep 17 00:00:00 2001
From: Fynn Schmitt-Ulms
Date: Wed, 30 Jul 2025 16:00:26 +0000
Subject: [PATCH 2/2] Add `dynamic=True` to torch.compile call in nvfp4
 packing

Signed-off-by: Fynn Schmitt-Ulms
---
 .../compressors/quantized_compressors/nvfp4_quantized.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
index 167f5f95..419d47c6 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
@@ -71,7 +71,6 @@ def compress_weight(
         zero_point: Optional[torch.Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
     ) -> Dict[str, torch.Tensor]:
-
         quantized_weight = quantize(
             x=weight,
             scale=scale,
@@ -91,7 +90,6 @@ def decompress_weight(
         compressed_data: Dict[str, Tensor],
         quantization_args: Optional[QuantizationArgs] = None,
     ) -> torch.Tensor:
-
         weight = compressed_data["weight_packed"]
         scale = compressed_data["weight_scale"]
         global_scale = compressed_data["weight_global_scale"]
@@ -105,7 +103,7 @@ def decompress_weight(
         return decompressed_weight
 
 
-@torch.compile(fullgraph=True)
+@torch.compile(fullgraph=True, dynamic=True)
 def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     """
     Packs a tensor with values in the fp4 range into uint8.
@@ -154,8 +152,9 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
 )
 
+
 # reference: https://github.com/vllm-project/vllm/pull/16362
-@torch.compile(fullgraph=True)
+@torch.compile(fullgraph=True, dynamic=True)
 def unpack_fp4_from_uint8(
     a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
 ) -> torch.Tensor:
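
Reviewer note: the core of PATCH 1 is replacing the eight-pass `isclose`/`where` loop with a single broadcasted nearest-neighbor lookup against the E2M1 table, which torch.compile can fuse into one kernel. A minimal standalone sketch of that lookup follows; `nearest_fp4_indices` is an illustrative helper, not part of the library:

```python
import torch

# 8-entry E2M1 magnitude table, as used by the nvfp4 compressor
kE2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32)

def nearest_fp4_indices(x: torch.Tensor) -> torch.Tensor:
    # Broadcast |x| ([m, n]) against the table ([8]) and argmin over the
    # last dim -- one vectorized pass instead of a Python loop over the table.
    abs_diff = torch.abs(torch.abs(x).unsqueeze(-1) - kE2M1)  # [m, n, 8]
    abs_indices = torch.argmin(abs_diff, dim=-1)              # [m, n]
    # Sign goes into bit 3, mirroring the patched pack_fp4_to_uint8
    return abs_indices + (torch.signbit(x).to(torch.long) << 3)

x = torch.tensor([[0.5, -6.0], [1.5, -0.0]])
print(nearest_fp4_indices(x))
# tensor([[ 1, 15],
#         [ 3,  8]])
```

One behavioral nuance: the old `isclose` loop left any off-grid value at index 0, while `argmin` snaps it to the nearest grid point; for the on-grid outputs of `quantize()` the two agree.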
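
PATCH 2's `dynamic=True` deserves a brief rationale: pack/unpack run once per weight tensor, and layer shapes vary across a model, so a shape-specialized compile would retrace for each new `(m, n)`. Marking shapes dynamic requests a single shape-generic graph up front. A rough sketch of the effect, using a hypothetical `pack_pairs` helper (the nibble order here is illustrative, not necessarily the library's exact layout):

```python
import torch

@torch.compile(fullgraph=True, dynamic=True)
def pack_pairs(indices: torch.Tensor) -> torch.Tensor:
    # Pack consecutive pairs of 4-bit codes into one uint8 each.
    pairs = indices.reshape(-1, 2)
    return (pairs[:, 0] | (pairs[:, 1] << 4)).to(torch.uint8)

# With dynamic=True these calls reuse one compiled graph; under the default
# shape specialization, each new length could trigger a recompile.
for n in (16, 32, 64):
    packed = pack_pairs(torch.randint(0, 16, (n,), dtype=torch.long))
    print(n, packed.shape)  # n, torch.Size([n // 2])
```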