From cbf82c71b11c05152bf32f5b82dd7f3a2b07afc9 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 12 Sep 2025 14:01:03 -0700 Subject: [PATCH 1/3] Remove internal usage of all config functions like `int4_weight_only` **Summary:** These are now deprecated as of #2994. We should stop using them internally as well. **Test Plan:** CI [ghstack-poisoned] --- benchmarks/benchmark_aq.py | 12 +- .../quantized_training/pretrain_llama2.py | 4 +- test/dtypes/test_affine_quantized.py | 36 +++-- .../test_affine_quantized_tensor_parallel.py | 28 ++-- test/dtypes/test_floatx.py | 4 +- test/dtypes/test_uintx.py | 17 +-- test/hqq/test_hqq_affine.py | 10 +- test/integration/test_integration.py | 37 ++--- test/prototype/test_parq.py | 6 +- .../pt2e/test_x86inductor_fusion.py | 2 +- test/quantization/test_marlin_qqq.py | 6 +- test/quantization/test_quant_api.py | 129 +++++++++--------- test/sparsity/test_marlin.py | 14 +- test/sparsity/test_sparse_api.py | 18 +-- torchao/_models/llama/eval.py | 33 +++-- torchao/_models/llama/generate.py | 55 ++++---- torchao/_models/sam/eval_combo.py | 16 ++- torchao/dtypes/floatx/README.md | 4 +- torchao/prototype/autoround/README.md | 2 +- torchao/prototype/autoround/eval_autoround.py | 11 +- torchao/prototype/hqq/example.py | 4 +- .../mixed_precision/scripts/naive_intNwo.py | 4 +- torchao/prototype/quantized_training/int8.py | 2 +- torchao/quantization/README.md | 2 +- .../quantization/pt2e/inductor_passes/x86.py | 2 +- .../test_reference_representation_rewrite.py | 8 +- torchao/quantization/quant_api.py | 20 +-- torchao/sparsity/sparse_api.py | 2 +- torchao/testing/utils.py | 4 +- tutorials/quantize_vit/run_vit_b_quant.py | 4 +- 30 files changed, 254 insertions(+), 242 deletions(-) diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py index 34be7f3005..8eb6ddde11 100644 --- a/benchmarks/benchmark_aq.py +++ b/benchmarks/benchmark_aq.py @@ -10,10 +10,10 @@ import torch from torchao.quantization.quant_api import ( + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, _replace_with_custom_fn_if_matches_filter, - int4_weight_only, - int8_dynamic_activation_int8_weight, - int8_weight_only, quantize_, ) from torchao.quantization.subclass import ( @@ -23,13 +23,13 @@ def _int8wo_api(mod, **kwargs): - quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False) + quantize_(mod, Int8WeightOnlyConfig(**kwargs), set_inductor_config=False) def _int8da_int8w_api(mod, **kwargs): quantize_( mod, - int8_dynamic_activation_int8_weight(**kwargs), + Int8DynamicActivationInt8WeightConfig(**kwargs), set_inductor_config=False, ) @@ -39,7 +39,7 @@ def _int4wo_api(mod, **kwargs): if "groupsize" in kwargs_copy: kwargs_copy["group_size"] = kwargs_copy["groupsize"] del kwargs_copy["groupsize"] - quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False) + quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False) class ToyLinearModel(torch.nn.Module): diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 6a1f4e8efb..2e1243d1d9 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -166,8 +166,8 @@ def insert_rmsnorm(module: torch.nn.Module): insert_rmsnorm(model.layers) # don't apply int8_mixed_precision to LM head, since it can cause convergence issue. - # TODO: might want to do the same for int8_weight_only to standardize. 
- if args.quantize == "int8_weight_only": + # TODO: might want to do the same for Int8WeightOnlyConfig to standardize. + if args.quantize == "Int8WeightOnlyConfig": quantize_( model, int8_weight_only_quantized_training(), set_inductor_config=False ) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 220b9c4455..56f42a8043 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -27,15 +27,13 @@ from torchao.float8.config import e4m3_dtype from torchao.quantization import ( FbgemmConfig, + Float8WeightOnlyConfig, GemliteUIntXWeightOnlyConfig, + Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, - float8_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, + Int8WeightOnlyConfig, quantize_, ) from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain @@ -58,23 +56,23 @@ def get_quantization_functions( do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False ): base_functions = [ - int8_weight_only(), - int8_dynamic_activation_int4_weight(), - int8_dynamic_activation_int8_weight(), - int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC), + Int8WeightOnlyConfig(), + Int8DynamicActivationInt4WeightConfig(), + Int8DynamicActivationInt8WeightConfig(), + Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC), ] if do_int4: if check_cpu_version(device): base_functions.append( - int4_weight_only(group_size=32, layout=Int4CPULayout(), version=1) + Int4WeightOnlyConfig(group_size=32, layout=Int4CPULayout(), version=1) ) elif check_xpu_version(device): base_functions.append( - int4_weight_only(group_size=32, layout=Int4XPULayout(), version=1) + Int4WeightOnlyConfig(group_size=32, layout=Int4XPULayout(), version=1) ) if int4_zp_int: base_functions.append( - int4_weight_only( + Int4WeightOnlyConfig( group_size=32, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT, @@ -82,25 +80,25 @@ def get_quantization_functions( ) ) else: - base_functions.append(int4_weight_only(group_size=32, version=1)) + base_functions.append(Int4WeightOnlyConfig(group_size=32, version=1)) if device == "cuda" and not is_ROCM(): base_functions.append( - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=None, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, layout=CutlassInt4PackedLayout(), ) ) - base_functions.append(int4_dynamic_activation_int4_weight()) + base_functions.append(Int4DynamicActivationInt4WeightConfig()) if do_sparse and device != "xpu": base_functions.append( - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) + Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()) ) if is_sm_at_least_89(): - base_functions.append(float8_weight_only()) + base_functions.append(Float8WeightOnlyConfig()) if is_sm_at_least_90(): base_functions.append(FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16)) @@ -119,7 +117,7 @@ def test_tensor_core_layout_transpose(self): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") t = linear.weight shape = t.shape - apply_int4_weight_only_quant = int4_weight_only(group_size=32, version=1) + apply_int4_weight_only_quant = Int4WeightOnlyConfig(group_size=32, version=1) quantize_(linear, 
apply_int4_weight_only_quant) ql = linear aqt = ql.weight diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 49471d3ad1..983f701849 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -16,11 +16,11 @@ ) from torchao.quantization import ( - float8_dynamic_activation_float8_weight, - float8_weight_only, - int4_weight_only, - int8_dynamic_activation_int8_weight, - int8_weight_only, + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, ) from torchao.quantization.observer import PerRow, PerTensor from torchao.quantization.quant_api import quantize_ @@ -42,7 +42,7 @@ class TestAffineQuantizedTensorParallel(DTensorTestBase): """Basic test case for tensor subclasses""" - QUANT_METHOD_FN = staticmethod(int8_weight_only) + QUANT_METHOD_FN = staticmethod(Int8WeightOnlyConfig) QUANT_METHOD_KWARGS = {} @staticmethod @@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): - QUANT_METHOD_FN = staticmethod(int8_weight_only) + QUANT_METHOD_FN = staticmethod(Int8WeightOnlyConfig) COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -144,7 +144,7 @@ def test_tp(self, dtype): class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): - QUANT_METHOD_FN = staticmethod(int4_weight_only) + QUANT_METHOD_FN = staticmethod(Int4WeightOnlyConfig) QUANT_METHOD_KWARGS = {"version": 1} COMMON_DTYPES = [torch.bfloat16] @@ -167,12 +167,12 @@ class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not has_gemlite, "gemlite not available") def test_tp_gemlite(self, dtype): - from torchao.quantization import gemlite_uintx_weight_only + from torchao.quantization import GemliteUIntXWeightOnlyConfig for packing_bitwidth in [32, 8]: for bit_width in [4, 8]: for group_size in [64, 32, None] if bit_width == 4 else [None]: - api = lambda: gemlite_uintx_weight_only( + api = lambda: GemliteUIntXWeightOnlyConfig( group_size, bit_width, packing_bitwidth ) self.QUANT_METHOD_FN = staticmethod(api) @@ -180,7 +180,7 @@ def test_tp_gemlite(self, dtype): class TestInt8dqAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): - QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight) + QUANT_METHOD_FN = staticmethod(Int8DynamicActivationInt8WeightConfig) COMMON_DTYPES = [torch.bfloat16] @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -199,7 +199,7 @@ def test_tp(self, dtype): if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): - QUANT_METHOD_FN = staticmethod(float8_weight_only) + QUANT_METHOD_FN = staticmethod(Float8WeightOnlyConfig) COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -211,7 +211,7 @@ def test_tp(self, dtype): class TestFloat8dqTensorAffineQuantizedTensorParallel( TestAffineQuantizedTensorParallel ): - QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) + QUANT_METHOD_FN = staticmethod(Float8DynamicActivationFloat8WeightConfig) QUANT_METHOD_KWARGS = {"granularity": 
PerTensor()} COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] @@ -224,7 +224,7 @@ def test_tp(self, dtype): class TestFloat8dqRowAffineQuantizedTensorParallel( TestAffineQuantizedTensorParallel ): - QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) + QUANT_METHOD_FN = staticmethod(Float8DynamicActivationFloat8WeightConfig) QUANT_METHOD_KWARGS = {"granularity": PerRow()} COMMON_DTYPES = [torch.bfloat16] diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 9a99ba0802..ab4a13d24c 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -29,7 +29,7 @@ _floatx_unpacked_to_f32, ) from torchao.quantization import ( - fpx_weight_only, + FPXWeightOnlyConfig, quantize_, ) from torchao.testing.utils import skip_if_rocm @@ -118,7 +118,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias, dtype): linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype) fpx_linear = copy.deepcopy(linear) - quantize_(fpx_linear, fpx_weight_only(ebits, mbits)) + quantize_(fpx_linear, FPXWeightOnlyConfig(ebits, mbits)) x = torch.randn(N, IC, device=device, dtype=dtype) expected = fpx_linear(x) diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index dbc69b8ee9..cb0c88b21c 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -7,7 +7,7 @@ import torch from torchao.dtypes.uintx.uintx_layout import to_uintx -from torchao.quantization.quant_api import quantize_, uintx_weight_only +from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_ from torchao.quantization.quant_primitives import ( MappingType, choose_qparams_affine, @@ -60,7 +60,7 @@ def forward(self, x): def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size): scale = 512 fp16_mod_on_cpu = Linear16(scale, "cpu") - quantize_(fp16_mod_on_cpu, uintx_weight_only(dtype, group_size=group_size)) + quantize_(fp16_mod_on_cpu, UIntXWeightOnlyConfig(dtype, group_size=group_size)) test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu") output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu) fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda") @@ -78,7 +78,7 @@ def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size): def test_uintx_weight_only_model_quant(dtype, group_size, device): scale = 512 fp16 = Linear16(scale, device) - quantize_(fp16, uintx_weight_only(dtype, group_size=group_size)) + quantize_(fp16, UIntXWeightOnlyConfig(dtype, group_size=group_size)) uintx = torch.compile(fp16, fullgraph=True) test_input = torch.randn(scale * 2, dtype=torch.float16, device=device) output = uintx.forward(test_input) @@ -124,22 +124,18 @@ def test_uintx_weight_only_quant(dtype, group_size, device): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") def test_uintx_target_dtype(dtype): - from torchao.quantization.quant_api import uintx_weight_only - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - quantize_(linear, uintx_weight_only(dtype)) + quantize_(linear, UIntXWeightOnlyConfig(dtype)) linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") def test_uintx_target_dtype_compile(dtype): - from torchao.quantization.quant_api import uintx_weight_only - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - quantize_(linear, 
uintx_weight_only(dtype)) + quantize_(linear, UIntXWeightOnlyConfig(dtype)) linear = torch.compile(linear) linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) @@ -147,7 +143,6 @@ def test_uintx_target_dtype_compile(dtype): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") def test_uintx_model_size(dtype): - from torchao.quantization.quant_api import uintx_weight_only from torchao.utils import get_model_size_in_bytes # scale size = 1/64 * 2 bytes = 1/32 bytes @@ -167,6 +162,6 @@ def test_uintx_model_size(dtype): ) bf16_size = get_model_size_in_bytes(linear) # make sure it runs - quantize_(linear[0], uintx_weight_only(dtype)) + quantize_(linear[0], UIntXWeightOnlyConfig(dtype)) quantized_size = get_model_size_in_bytes(linear) assert bf16_size * _dtype_to_ratio[dtype] == quantized_size diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index d237eec53a..09bdfa8e61 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -8,11 +8,11 @@ import torch from torchao.quantization import ( + Int4WeightOnlyConfig, MappingType, + UIntXWeightOnlyConfig, ZeroPointDomain, - int4_weight_only, quantize_, - uintx_weight_only, ) from torchao.testing.utils import skip_if_rocm @@ -55,9 +55,11 @@ def _eval_hqq(dtype): ) dummy_linear.weight.data = W if dtype == torch.uint4: - config = int4_weight_only(group_size=max(block_size), use_hqq=True, version=1) + config = Int4WeightOnlyConfig( + group_size=max(block_size), use_hqq=True, version=1 + ) else: - config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True) + config = UIntXWeightOnlyConfig(dtype, group_size=max(block_size), use_hqq=True) quantize_(dummy_linear, config) q_tensor_hqq = dummy_linear.weight diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 0269b8a223..f99cf4a1b4 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -39,11 +39,11 @@ # APIs to be deprecated (used for torch 2.2.2 and 2.3) from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, _replace_with_custom_fn_if_matches_filter, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, quantize_, ) from torchao.quantization.quant_primitives import ( @@ -109,12 +109,14 @@ def _int8wo_api(mod): - quantize_(mod, int8_weight_only(set_inductor_config=False)) + quantize_(mod, Int8WeightOnlyConfig(set_inductor_config=False)) def _int8wo_groupwise_api(mod): group_size = 32 - quantize_(mod, int8_weight_only(group_size=group_size, set_inductor_config=False)) + quantize_( + mod, Int8WeightOnlyConfig(group_size=group_size, set_inductor_config=False) + ) def _int8da_int8w_api( @@ -123,7 +125,7 @@ def _int8da_int8w_api( ): quantize_( mod, - int8_dynamic_activation_int8_weight( + Int8DynamicActivationInt8WeightConfig( act_mapping_type=act_mapping_type, set_inductor_config=False, ), @@ -134,7 +136,7 @@ def _int4wo_api(mod, use_hqq=False): if check_cpu_version(next(mod.parameters()).device): quantize_( mod, - int4_weight_only( + Int4WeightOnlyConfig( layout=Int4CPULayout(), use_hqq=use_hqq, set_inductor_config=False, @@ -145,17 +147,17 @@ def _int4wo_api(mod, use_hqq=False): elif check_xpu_version(next(mod.parameters()).device): quantize_( mod, - int4_weight_only( + 
Int4WeightOnlyConfig( layout=Int4XPULayout(), set_inductor_config=False, version=1 ), ) unwrap_tensor_subclass(mod) else: - quantize_(mod, int4_weight_only(set_inductor_config=False, version=1)) + quantize_(mod, Int4WeightOnlyConfig(set_inductor_config=False, version=1)) def _int8da_int4w_api(mod): - quantize_(mod, int8_dynamic_activation_int4_weight(set_inductor_config=False)) + quantize_(mod, Int8DynamicActivationInt4WeightConfig(set_inductor_config=False)) # TODO: use this to reduce the number of tests @@ -1030,9 +1032,10 @@ def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not has_gemlite, "gemlite not available") def test_gemlite_layout(self, device, dtype): + from torchao.quantization import GemliteUIntXWeightOnlyConfig + if dtype != torch.float16: self.skipTest("gemlite only works for fp16 dtype") - from torchao.quantization import gemlite_uintx_weight_only if device == "cpu": self.skipTest(f"gemlite is for cuda, not {device}") @@ -1041,7 +1044,7 @@ def test_gemlite_layout(self, device, dtype): for group_size in [64, 32, None] if bit_width == 4 else [None]: api = lambda mod: quantize_( mod, - gemlite_uintx_weight_only( + GemliteUIntXWeightOnlyConfig( group_size, bit_width, packing_bitwidth ), ) @@ -1063,7 +1066,7 @@ def test_gemlite_layout(self, device, dtype): # test that shapes with non divisible by 128 shapes aren't causing errors self._test_lin_weight_subclass_api_impl( - lambda mod: quantize_(mod, gemlite_uintx_weight_only(None, 4, 32)), + lambda mod: quantize_(mod, GemliteUIntXWeightOnlyConfig(None, 4, 32)), device, 15, test_shape=[1, 1025, 513], @@ -1094,7 +1097,7 @@ def api(mod): kwargs_copy = kwargs.copy() kwargs_copy["group_size"] = groupsize del kwargs_copy["groupsize"] - quantize_(mod, int4_weight_only(**kwargs_copy)) + quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy)) self._test_lin_weight_subclass_api_impl( api, @@ -1112,7 +1115,7 @@ def test_dynamic_quant(self): m = nn.Sequential(nn.Linear(K, N)) y_ref = m(x) - quantize_(m, int8_dynamic_activation_int8_weight()) + quantize_(m, Int8DynamicActivationInt8WeightConfig()) y_test = m(x) sqnr = compute_error(y_ref, y_test) @@ -1152,7 +1155,7 @@ def test_weight_only_groupwise_embedding_quant(self): quantize_( m, - int8_weight_only(group_size=group_size), + Int8WeightOnlyConfig(group_size=group_size), filter_fn=lambda x, *args: isinstance(x, nn.Embedding), ) y_q = m(input) diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py index 50e296aa30..4f3108c2a2 100644 --- a/test/prototype/test_parq.py +++ b/test/prototype/test_parq.py @@ -32,10 +32,10 @@ from torchao.quantization.granularity import PerGroup from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig from torchao.quantization.quant_api import ( + Int4WeightOnlyConfig, Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig, _is_linear, - int4_weight_only, quantize_, ) from torchao.quantization.quant_primitives import MappingType @@ -211,7 +211,7 @@ def test_int4_weight_only(self, group_size: int = 32): model.reset_parameters() m_ref = copy.deepcopy(model).eval().to(_DEVICE) - config = int4_weight_only(group_size=group_size, version=1) + config = Int4WeightOnlyConfig(group_size=group_size, version=1) if check_cpu_version(_DEVICE): config.layout = Int4CPULayout() quantize_(m_ref, config) @@ -244,7 +244,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32): model.reset_parameters() m_ref = copy.deepcopy(model).eval().to(_DEVICE) - config = 
int4_weight_only(group_size=group_size, version=1) + config = Int4WeightOnlyConfig(group_size=group_size, version=1) if check_cpu_version(_DEVICE): config.layout = Int4CPULayout() quantize_(m_ref, config) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 6e3772c76a..804b7c975a 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -2336,7 +2336,7 @@ def test_da8w8_sym_act_sym_wgt_with_int_mm( self, has_bias, dtype, dynamic, reshape_a, M, inplace_add, expand_a_scale ): r""" - This testcase check if we can match the int8_dynamic_activation_int8_weight int8 linear pattern from torchao, + This testcase check if we can match the Int8DynamicActivationInt8WeightConfig int8 linear pattern from torchao, when activation is symmetrically quantized dynamically & weights are symmetrically quantized (statically) The pattern is: (no bias) _int_mm -> convert_element_type -> ([expand_a] -> mul) -> mul diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index 56b309b948..e0733520ff 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -16,7 +16,7 @@ unpack_from_marlin_qqq, ) from torchao.quantization.quant_api import ( - int8_dynamic_activation_int4_weight, + Int8DynamicActivationInt4WeightConfig, quantize_, ) from torchao.quantization.quant_primitives import ( @@ -53,7 +53,7 @@ def test_marlin_qqq(self): modelq = copy.deepcopy(self.model) quantize_( modelq, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=group_size, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, @@ -77,7 +77,7 @@ def test_marlin_qqq_compile(self): modelq = copy.deepcopy(self.model) quantize_( modelq, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=group_size, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index cfa29154ab..6380b568be 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -39,25 +39,22 @@ PerGroup, ) from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, + Float8StaticActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + FPXWeightOnlyConfig, + GemliteUIntXWeightOnlyConfig, + Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, IntxWeightOnlyConfig, ModuleFqnToConfig, Quantizer, TwoStepQuantizer, + UIntXWeightOnlyConfig, _replace_with_custom_fn_if_matches_filter, - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - gemlite_uintx_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, - uintx_weight_only, ) from torchao.quantization.quant_primitives import MappingType from torchao.quantization.subclass import ( @@ -124,7 +121,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module: class TorchCompileDynamicQuantizer(Quantizer): def quantize(self, model: torch.nn.Module) -> torch.nn.Module: - quantize_(model, int8_dynamic_activation_int8_weight()) + quantize_(model, Int8DynamicActivationInt8WeightConfig()) 
return model @@ -184,7 +181,7 @@ class TestQuantFlow(TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() example_inputs = m.example_inputs() - quantize_(m, int8_dynamic_activation_int8_weight()) + quantize_(m, Int8DynamicActivationInt8WeightConfig()) m(*example_inputs) # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64 # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {}) @@ -227,7 +224,7 @@ def test_int4_wo_quant_save_load(self): m = ToyLinearModel().eval().cpu() def api(model): - quantize_(model, int4_weight_only(layout=Int4XPULayout(), version=1)) + quantize_(model, Int4WeightOnlyConfig(layout=Int4XPULayout(), version=1)) unwrap_tensor_subclass(model) api(m) @@ -254,7 +251,7 @@ def test_int8_wo_quant_save_load(self): m = ToyLinearModel().eval().cpu() def api(model): - quantize_(model, int8_weight_only()) + quantize_(model, Int8WeightOnlyConfig()) unwrap_tensor_subclass(model) api(m) @@ -418,7 +415,7 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type): example_inputs = m.example_inputs() quantize_( m, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=group_size, mapping_type=mapping_type ), ) @@ -459,12 +456,12 @@ def test_quantized_tensor_subclass_int4(self): if device == "xpu": quantize_( m, - int4_weight_only( + Int4WeightOnlyConfig( group_size=group_size, layout=Int4XPULayout(), version=1 ), ) else: - quantize_(m, int4_weight_only(group_size=group_size, version=1)) + quantize_(m, Int4WeightOnlyConfig(group_size=group_size, version=1)) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -482,7 +479,7 @@ def test_quantized_tensor_subclass_int8_wo(self): m_copy = copy.deepcopy(m) example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs())) - quantize_(m, int8_weight_only()) + quantize_(m, Int8WeightOnlyConfig()) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -501,7 +498,7 @@ def test_quantized_tensor_subclass_save_load(self): m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16) - quantize_(m, int8_weight_only()) + quantize_(m, Int8WeightOnlyConfig()) ref = m(*example_inputs) with tempfile.NamedTemporaryFile() as f: torch.save(m.state_dict(), f) @@ -518,7 +515,7 @@ def test_int8wo_quantized_model_to_device(self): m = ToyLinearModel().eval().to(torch.bfloat16) example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu") - quantize_(m, int8_weight_only()) + quantize_(m, Int8WeightOnlyConfig()) ref = m(*example_inputs) example_inputs_cuda = (example_inputs[0].to("cuda"),) @@ -531,7 +528,7 @@ def test_quantized_tensor_subclass_save_load_map_location(self): m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda") example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda") - quantize_(m, int8_weight_only()) + quantize_(m, Int8WeightOnlyConfig()) ref = m(*example_inputs) with tempfile.NamedTemporaryFile() as f: torch.save(m.state_dict(), f) @@ -556,13 +553,13 @@ def reset_memory(): reset_memory() m = ToyLinearModel() - quantize_(m.to(device="cuda"), int8_weight_only()) + quantize_(m.to(device="cuda"), Int8WeightOnlyConfig()) memory_baseline = torch.cuda.max_memory_allocated() del m reset_memory() m = 
ToyLinearModel() - quantize_(m, int8_weight_only(), device="cuda") + quantize_(m, Int8WeightOnlyConfig(), device="cuda") memory_streaming = torch.cuda.max_memory_allocated() for param in m.parameters(): @@ -582,7 +579,7 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): with torch.no_grad(): quantize_( m, - int4_weight_only( + Int4WeightOnlyConfig( group_size=32, layout=Int4CPULayout(), use_hqq=use_hqq, version=1 ), ) @@ -599,51 +596,51 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): @common_utils.parametrize( "config", [ - int4_weight_only(version=1), - float8_weight_only(), - float8_dynamic_activation_float8_weight(), - float8_static_activation_float8_weight(scale=torch.tensor([1.0])), - int4_dynamic_activation_int4_weight(), - int8_dynamic_activation_int8_weight(), - int8_dynamic_activation_int4_weight(), - int8_weight_only(), - fpx_weight_only(ebits=4, mbits=3), - gemlite_uintx_weight_only(), - uintx_weight_only(dtype=torch.uint4), + Int4WeightOnlyConfig(version=1), + Float8WeightOnlyConfig(), + Float8DynamicActivationFloat8WeightConfig(), + Float8StaticActivationFloat8WeightConfig(scale=torch.tensor([1.0])), + Int4DynamicActivationInt4WeightConfig(), + Int8DynamicActivationInt8WeightConfig(), + Int8DynamicActivationInt4WeightConfig(), + Int8WeightOnlyConfig(), + FPXWeightOnlyConfig(ebits=4, mbits=3), + GemliteUIntXWeightOnlyConfig(), + UIntXWeightOnlyConfig(dtype=torch.uint4), ], ) @skip_if_rocm("ROCm enablement in progress") def test_workflow_e2e_numerics(self, config): """ - Simple test of e2e int4_weight_only workflow, comparing numerics + Simple test of e2e Int4WeightOnlyConfig workflow, comparing numerics to a bfloat16 baseline. """ if ( isinstance( config, ( - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, + Float8DynamicActivationFloat8WeightConfig, + Float8StaticActivationFloat8WeightConfig, ), ) and not is_sm_at_least_89() ): return unittest.skip("requires CUDA capability 8.9 or greater") elif ( - isinstance(config, int4_dynamic_activation_int4_weight) + isinstance(config, Int4DynamicActivationInt4WeightConfig) and is_sm_at_least_90() ): return unittest.skip("only supported on CUDA capability 8.9, not greater") - elif isinstance(config, gemlite_uintx_weight_only) and not has_gemlite: + elif isinstance(config, GemliteUIntXWeightOnlyConfig) and not has_gemlite: return unittest.skip("gemlite not available") # scale has to be moved to cuda here because the parametrization init # code happens before gating for cuda availability - if isinstance(config, float8_static_activation_float8_weight): + if isinstance(config, Float8StaticActivationFloat8WeightConfig): config.scale = config.scale.to("cuda") dtype = torch.bfloat16 - if isinstance(config, gemlite_uintx_weight_only): + if isinstance(config, GemliteUIntXWeightOnlyConfig): dtype = torch.float16 # set up inputs @@ -755,20 +752,20 @@ def test_int4wo_cuda_serialization(self): def test_config_deprecation(self): """ - Test that old config functions like `int4_weight_only` trigger deprecation warnings. + Test that old config functions like `Int4WeightOnlyConfig` trigger deprecation warnings. 
""" from torchao.quantization import ( - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - gemlite_uintx_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, - uintx_weight_only, + Float8DynamicActivationFloat8WeightConfig, + Float8StaticActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + FPXWeightOnlyConfig, + GemliteUIntXWeightOnlyConfig, + Int4DynamicActivationInt4WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + UIntXWeightOnlyConfig, ) # Reset deprecation warning state, otherwise we won't log warnings here @@ -776,17 +773,17 @@ def test_config_deprecation(self): # Map from deprecated API to the args needed to instantiate it deprecated_apis_to_args = { - float8_dynamic_activation_float8_weight: (), - float8_static_activation_float8_weight: (torch.randn(3)), - float8_weight_only: (), - fpx_weight_only: (3, 2), - gemlite_uintx_weight_only: (), - int4_dynamic_activation_int4_weight: (), - int4_weight_only: (), - int8_dynamic_activation_int4_weight: (), - int8_dynamic_activation_int8_weight: (), - int8_weight_only: (), - uintx_weight_only: (torch.uint4,), + Float8DynamicActivationFloat8WeightConfig: (), + Float8StaticActivationFloat8WeightConfig: (torch.randn(3)), + Float8WeightOnlyConfig: (), + FPXWeightOnlyConfig: (3, 2), + GemliteUIntXWeightOnlyConfig: (), + Int4DynamicActivationInt4WeightConfig: (), + Int4WeightOnlyConfig: (), + Int8DynamicActivationInt4WeightConfig: (), + Int8DynamicActivationInt8WeightConfig: (), + Int8WeightOnlyConfig: (), + UIntXWeightOnlyConfig: (torch.uint4,), } with warnings.catch_warnings(record=True) as _warnings: diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index d193ae9db2..e602210ee5 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -11,7 +11,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests from torchao.dtypes import MarlinSparseLayout -from torchao.quantization.quant_api import int4_weight_only, quantize_ +from torchao.quantization.quant_api import Int4WeightOnlyConfig, quantize_ from torchao.quantization.quant_primitives import ( MappingType, choose_qparams_affine, @@ -47,11 +47,13 @@ def test_quant_sparse_marlin_layout_eager(self): model_copy = copy.deepcopy(self.model) # Quantized - quantize_(model_copy.bfloat16(), int4_weight_only(version=1)) + quantize_(model_copy.bfloat16(), Int4WeightOnlyConfig(version=1)) dense_result = model_copy(self.input.bfloat16()).half() # Sparse + quantized - quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout(), version=1)) + quantize_( + self.model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1) + ) sparse_result = self.model(self.input) assert torch.allclose(dense_result, sparse_result, atol=3e-1), ( "Results are not close" @@ -64,12 +66,14 @@ def test_quant_sparse_marlin_layout_compile(self): model_copy = copy.deepcopy(self.model) # Quantized - quantize_(model_copy.bfloat16(), int4_weight_only(version=1)) + quantize_(model_copy.bfloat16(), Int4WeightOnlyConfig(version=1)) model_copy.foward = torch.compile(model_copy.forward, fullgraph=True) dense_result = model_copy(self.input.bfloat16()).half() # Sparse + quantized - quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout(), version=1)) + quantize_( + self.model, 
Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1) + ) self.model.forward = torch.compile(self.model.forward, fullgraph=True) sparse_result = self.model(self.input) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 0bf0fe4d8c..ab00ee5e55 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -13,8 +13,8 @@ from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout from torchao.quantization.quant_api import ( - int4_weight_only, - int8_dynamic_activation_int8_weight, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, quantize_, ) from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ @@ -76,12 +76,12 @@ def test_quant_semi_sparse(self, compile): ) apply_fake_sparsity(model) model_copy = copy.deepcopy(model) - quantize_(model_copy, int8_dynamic_activation_int8_weight()) + quantize_(model_copy, Int8DynamicActivationInt8WeightConfig()) dense_result = model_copy(input) quantize_( model, - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()), ) if compile: model = torch.compile(model) @@ -110,11 +110,11 @@ def test_sparse_marlin(self, compile): model_copy = copy.deepcopy(model) # Quantized - quantize_(model_copy.bfloat16(), int4_weight_only(version=1)) + quantize_(model_copy.bfloat16(), Int4WeightOnlyConfig(version=1)) dense_result = model_copy(input.bfloat16()).half() # Sparse + quantized - quantize_(model, int4_weight_only(layout=MarlinSparseLayout(), version=1)) + quantize_(model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1)) if compile: model = torch.compile(model) sparse_result = model(input) @@ -182,14 +182,16 @@ def test_sparse(self, compile): model_copy = copy.deepcopy(model) - quantize_(model_copy, int8_dynamic_activation_int8_weight()) + quantize_(model_copy, Int8DynamicActivationInt8WeightConfig()) reference = model_copy(input) from torchao.dtypes import BlockSparseLayout quantize_( model, - int8_dynamic_activation_int8_weight(layout=BlockSparseLayout(blocksize=64)), + Int8DynamicActivationInt8WeightConfig( + layout=BlockSparseLayout(blocksize=64) + ), ) if compile: model = torch.compile(model) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index c53cbdd5cd..fdd9792cb4 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -17,16 +17,16 @@ import torchao from torchao._models.llama.model import prepare_inputs_for_model from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + FPXWeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, PerRow, PerTensor, - float8_dynamic_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - int4_weight_only, - int8_dynamic_activation_int8_weight, - int8_weight_only, + UIntXWeightOnlyConfig, quantize_, - uintx_weight_only, ) @@ -73,11 +73,11 @@ def run_evaluation( apply_spinquant(model) if "int8wo" in quantization: - quantize_(model, int8_weight_only()) + quantize_(model, Int8WeightOnlyConfig()) if "int8dq" in quantization: - quantize_(model, int8_dynamic_activation_int8_weight()) + quantize_(model, Int8DynamicActivationInt8WeightConfig()) if "fp6" in quantization: - quantize_(model, fpx_weight_only(3, 2)) + quantize_(model, FPXWeightOnlyConfig(3, 2)) if "int4wo" in quantization and not "gptq" in quantization: if "hqq" in quantization: use_hqq = True @@ -89,7 +89,7 @@ def 
run_evaluation( ) quantize_( model.to(device), - int4_weight_only(group_size=groupsize, use_hqq=use_hqq, version=1), + Int4WeightOnlyConfig(group_size=groupsize, use_hqq=use_hqq, version=1), ) if "uintx" in quantization: # uintx-nbits-groupsize @@ -112,11 +112,13 @@ def run_evaluation( } dtype = _NBITS_TO_DTYPE[nbits] group_size = int(_quant_args[2]) - quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) + quantize_(model, UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq)) if "marlin" in quantization: from torchao.dtypes import MarlinSparseLayout - quantize_(model, int4_weight_only(layout=MarlinSparseLayout(), version=1)) + quantize_( + model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1) + ) if "int4wo" in quantization and "gptq" in quantization: # avoid circular imports from torchao._models._eval import LMEvalInputRecorder @@ -151,7 +153,7 @@ def run_evaluation( quantizer.quantize(model, *inputs) model = model.to(device) if "float8wo" in quantization: - quantize_(model, float8_weight_only()) + quantize_(model, Float8WeightOnlyConfig()) if "float8dq" in quantization: granularity = str(quantization.split("-")[-1]) if granularity == "tensor": @@ -164,7 +166,8 @@ def run_evaluation( else: raise ValueError(f"Unknown granularity {granularity}") quantize_( - model, float8_dynamic_activation_float8_weight(granularity=granularity) + model, + Float8DynamicActivationFloat8WeightConfig(granularity=granularity), ) if "autoround" in quantization: from transformers import AutoTokenizer diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 68d2f98548..c9662366f3 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -340,18 +340,18 @@ def ffn_or_attn_only(mod, fqn): if quantization: from torchao.quantization import ( Float8DynamicActivationFloat8SemiSparseWeightConfig, + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + FPXWeightOnlyConfig, + GemliteUIntXWeightOnlyConfig, + Int4DynamicActivationInt4WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + UIntXWeightOnlyConfig, autoquant, - float8_dynamic_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - gemlite_uintx_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, quantize_, - uintx_weight_only, ) from torchao.quantization.granularity import PerRow, PerTensor @@ -375,7 +375,7 @@ def ffn_or_attn_only(mod, fqn): quantize_( model, - gemlite_uintx_weight_only( + GemliteUIntXWeightOnlyConfig( bit_width=bit_width, group_size=group_size, mode=mode ), ) @@ -395,25 +395,28 @@ def ffn_or_attn_only(mod, fqn): gemlite.cache_config(config_file) if "int8wo" in quantization: - quantize_(model, int8_weight_only()) + quantize_(model, Int8WeightOnlyConfig()) if "int8dq" in quantization: if sparsity and "semi" in sparsity: from torchao.dtypes import SemiSparseLayout quantize_( model, - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()), filter_fn=ffn_only, ) quantize_( - model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only + model, + Int8DynamicActivationInt8WeightConfig(), + filter_fn=not_ffn_only, ) elif "int8dq_prefill_wo_decode" in quantization: quantize_( - model, 
int8_dynamic_activation_int8_weight(weight_only_decode=True) + model, + Int8DynamicActivationInt8WeightConfig(weight_only_decode=True), ) else: - quantize_(model, int8_dynamic_activation_int8_weight()) + quantize_(model, Int8DynamicActivationInt8WeightConfig()) if "int4wo" in quantization: use_hqq = False if "hqq" in quantization: @@ -429,7 +432,7 @@ def ffn_or_attn_only(mod, fqn): ) quantize_( model, - int4_weight_only(group_size=group_size, use_hqq=use_hqq, version=1), + Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, version=1), ) elif "fbgemm" in quantization and "int4" in quantization: from torchao.quantization import FbgemmConfig @@ -458,7 +461,7 @@ def ffn_or_attn_only(mod, fqn): if nbits == 4: quantize_( model, - int4_dynamic_activation_int4_weight( + Int4DynamicActivationInt4WeightConfig( mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, layout=CutlassInt4PackedLayout(), @@ -467,7 +470,7 @@ def ffn_or_attn_only(mod, fqn): elif nbits == 8: quantize_( model, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=None, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, @@ -480,7 +483,7 @@ def ffn_or_attn_only(mod, fqn): quantize_( model, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=128, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, @@ -492,15 +495,15 @@ def ffn_or_attn_only(mod, fqn): quantize_( model, - int4_weight_only(layout=MarlinSparseLayout(), version=1), + Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1), filter_fn=ffn_or_attn_only, ) if "fp6" in quantization: - quantize_(model, fpx_weight_only(3, 2)) + quantize_(model, FPXWeightOnlyConfig(3, 2)) elif "embed-int8wo" in quantization: quantize_( model, - int8_weight_only(group_size=64), + Int8WeightOnlyConfig(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding), ) elif quantization.startswith("awq"): @@ -560,7 +563,7 @@ def ffn_or_attn_only(mod, fqn): } dtype = _NBITS_TO_DTYPE[nbits] group_size = int(_quant_args[2]) - quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) + quantize_(model, UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq)) elif "int8_dynamic_activation_intx_weight" in quantization: assert precision == torch.float32, ( "int8_dynamic_activation_intx_weight requires using precision=torch.float32" @@ -591,7 +594,7 @@ def ffn_or_attn_only(mod, fqn): ), ) elif "float8wo" in quantization: - quantize_(model, float8_weight_only()) + quantize_(model, Float8WeightOnlyConfig()) elif "float8dq" in quantization: if sparsity and "semi" in sparsity: quantize_( @@ -609,7 +612,7 @@ def ffn_or_attn_only(mod, fqn): granularity = PerTensor() quantize_( model, - float8_dynamic_activation_float8_weight(granularity=granularity), + Float8DynamicActivationFloat8WeightConfig(granularity=granularity), ) elif "autoquant_v2" in quantization: from torchao._models._eval import LMEvalInputRecorder diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py index 68880d5aed..467e24a9b6 100644 --- a/torchao/_models/sam/eval_combo.py +++ b/torchao/_models/sam/eval_combo.py @@ -22,9 +22,9 @@ from torchao.dtypes import SemiSparseLayout from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 from torchao.quantization import ( + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, autoquant, - int4_weight_only, - int8_dynamic_activation_int8_weight, quantize_, ) from 
torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ @@ -362,7 +362,9 @@ def mlp_only(mod, name): return isinstance(mod, torch.nn.Linear) and "mlp" in name if compress == "int8_dynamic_quant": - quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight()) + quantize_( + predictor.model.image_encoder, Int8DynamicActivationInt8WeightConfig() + ) elif compress == "sparse_mlp_only": def mlp_only(mod, name): @@ -381,12 +383,12 @@ def mlp_only(mod, name): quantize_( predictor.model.image_encoder, - int8_dynamic_activation_int8_weight(), + Int8DynamicActivationInt8WeightConfig(), attn_only, ) quantize_( predictor.model.image_encoder, - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()), mlp_lin1_only, ) sparsify_(predictor.model.image_encoder, semi_sparse_weight(), mlp_lin2_only) @@ -397,12 +399,12 @@ def mlp_only(mod, name): quantize_( predictor.model.image_encoder, - int8_dynamic_activation_int8_weight(), + Int8DynamicActivationInt8WeightConfig(), attn_only, ) quantize_( predictor.model.image_encoder, - int4_weight_only(layout=MarlinSparseLayout(), version=1), + Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1), mlp_lin1_only, ) sparsify_(predictor.model.image_encoder, semi_sparse_weight(), mlp_lin2_only) diff --git a/torchao/dtypes/floatx/README.md b/torchao/dtypes/floatx/README.md index 16aec8362b..092ef01233 100644 --- a/torchao/dtypes/floatx/README.md +++ b/torchao/dtypes/floatx/README.md @@ -9,7 +9,7 @@ This kernel was originally designed for FP16, but was extended to work for BF16 ```python from torchao.quantization import ( quantize_, - fpx_weight_only, + FPXWeightOnlyConfig, ) model = ... @@ -17,7 +17,7 @@ model = ... # for generic Floatx EyMz where x = 1 + y + z # fp6 with ebits = 3 and mbits = 2 -quantize_(model, fpx_weight_only(3, 2)) +quantize_(model, FPXWeightOnlyConfig(3, 2)) # fully compatible with torch.compile() model.compile(mode="max-autotune", fullgraph=True) diff --git a/torchao/prototype/autoround/README.md b/torchao/prototype/autoround/README.md index 396cde9461..a67b3be9f0 100644 --- a/torchao/prototype/autoround/README.md +++ b/torchao/prototype/autoround/README.md @@ -114,7 +114,7 @@ quantize_(model, apply_auto_round(), is_target_module) | autoround-4bit* | 0.6338 | 0.4566 | 0.7661 | 0.6646 | 0.5688 | 0.7130 | > [!NOTE] -> - `torchao-int4wo` quantizes the model to 4 bits with a group size of 128 (`int4_weight_only(group_size=128, version=1)`) while leaving the `lm-head` unquantized.
+> - `torchao-int4wo` quantizes the model to 4 bits with a group size of 128 (`Int4WeightOnlyConfig(group_size=128, version=1)`) while leaving the `lm-head` unquantized.
 > - `auto-round-4bit` uses the default configuration from [quick start](#quick-start).
 > - `auto-round-4bit*` follows the same settings as `auto-round-4bit`, but with `gradient_accumulate_steps=2` and `batch_size=4`, accumulating two batches (4 samples per batch) before performing the backward pass.
> - To reproduce results, run `eval_autoround.py` with `AO_USE_DETERMINISTIC_ALGORITHMS=1`. diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py index caebf85a2f..62cc9c43d5 100644 --- a/torchao/prototype/autoround/eval_autoround.py +++ b/torchao/prototype/autoround/eval_autoround.py @@ -101,25 +101,28 @@ def main(args): # Evaluate the quantized model if args.woq_int4: msg += " (int4wo)" - from torchao.quantization import int4_weight_only, quantize_ + from torchao.quantization import Int4WeightOnlyConfig, quantize_ quantize_( model, - int4_weight_only(group_size=args.group_size, version=1), + Int4WeightOnlyConfig(group_size=args.group_size, version=1), filter_fn=filter_fn, device=model_device, ) elif args.uintx: msg += f" (uintx {args.bits} bits)" from torchao.dtypes.uintx.uintx import _BIT_WIDTH_TO_DTYPE - from torchao.quantization.quant_api import quantize_, uintx_weight_only + from torchao.quantization.quant_api import ( + UIntXWeightOnlyConfig, + quantize_, + ) bits = args.bits assert bits in _BIT_WIDTH_TO_DTYPE, f"Invalid bits: {bits}" dtype = _BIT_WIDTH_TO_DTYPE[bits] quantize_( model, - uintx_weight_only(dtype=dtype, group_size=args.group_size), + UIntXWeightOnlyConfig(dtype=dtype, group_size=args.group_size), filter_fn=filter_fn, device=model_device, ) diff --git a/torchao/prototype/hqq/example.py b/torchao/prototype/hqq/example.py index cca8a42eb3..cda96f6b3c 100644 --- a/torchao/prototype/hqq/example.py +++ b/torchao/prototype/hqq/example.py @@ -108,14 +108,14 @@ print("Quant API example") print("-------------------------------------------------------------------") -from torchao.quantization.quant_api import int4_weight_only +from torchao.quantization.quant_api import Int4WeightOnlyConfig nbits = 4 target_dtype = torch.int32 inner_k_tiles = 8 _layout = TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles) -int4_weight_only_patch_fct = int4_weight_only( +int4_weight_only_patch_fct = Int4WeightOnlyConfig( group_size=group_size, inner_k_tiles=inner_k_tiles, version=1 ) linear_layer_default = torch.nn.Linear( diff --git a/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py b/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py index 016b6c9eef..2174e7683a 100644 --- a/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py +++ b/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py @@ -101,11 +101,11 @@ def apply_intN_weight_only_quant_sym(weight): assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]" if n == 8: raise AssertionError( - "Someone needs to refactor this code to handle int8_weight_only again" + "Someone needs to refactor this code to handle Int8WeightOnlyConfig again" ) elif n == 4: raise AssertionError( - "Someone needs to refactor this code to handle int4_weight_only again" + "Someone needs to refactor this code to handle Int4WeightOnlyConfig again" ) else: if symmetric: diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 6b438ca787..1eaaacd1db 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -29,7 +29,7 @@ def quantize_int8_rowwise( probability of rounding up is equal to x - ⌊x⌋, which indicates how close the value is to the next integer value. Thus, stochastic rounding also approximates the floating point value exactly. 
- Currently this function differs from AQT's `int8_weight_only()` in the following way: + Currently this function differs from AQT's `Int8WeightOnlyConfig()` in the following way: 1. Precision: AQT keeps original dtype when doing quantization, while this function upcasts input to FP32 before quantization. Output scale maintains the original input dtype. 2. Calculate scale: AQT uses `input.abs().amax() / 127.5`, while `input.abs().amax() / 127` is diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index e1e4c20a31..d54caa420e 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -376,7 +376,7 @@ Marlin QQQ is an optimized GPU kernel that supports W4A8 mixed precision GEMM. F | | w4a8-g128 | 187.62 | 640.32 | 4.82 | 3.41 | ### Gemlite Triton -Int4 and Int8 quantization using the [Gemlite Triton](https://github.com/mobiusml/gemlite) kernels. You can try it out with the `quantize_` api as above alongside the constructor `gemlite_uintx_weight_only`. An example can be found in `torchao/_models/llama/generate.py`. +Int4 and Int8 quantization using the [Gemlite Triton](https://github.com/mobiusml/gemlite) kernels. You can try it out with the `quantize_` api as above alongside the constructor `GemliteUIntXWeightOnlyConfig`. An example can be found in `torchao/_models/llama/generate.py`. Note: we test on gemlite 0.4.1, but should be able to use any version after that, we'd recommend to use the latest release to get the most recent performance improvements. diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 4ccb2a1f31..ccc0c64650 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -1729,7 +1729,7 @@ def _with_outer_reshape(pattern): KeywordArg("out_shape_with_bias"), ) - # The following patterns are for torchao int8_dynamic_activation_int8_weight linear, + # The following patterns are for torchao Int8DynamicActivationInt8WeightConfig linear, # when both activation and weights are symmetrically quantized. # In practice, though, they may also match smooth-quant pattern when a 2D input shape would be used. # Since add is not currently being used as a oneDNN post-op, but is unfused, we don't need these patterns with bias. 
diff --git a/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py b/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py index 91b13144b5..5161e130a0 100644 --- a/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py +++ b/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py @@ -10,7 +10,7 @@ import torch import torch.nn as nn -from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_ +from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_ from torchao.quantization.pt2e.reference_representation_rewrite import ( _qdq_dynamic_quantized_linear_4bit_groupwise, _reference_dynamic_quantized_linear_4bit_groupwise, @@ -313,7 +313,7 @@ def test_export_and_rewrite_workflow(self): example_input = torch.randn(1, 64) # Apply 8da4w quantization - quantize_(model, int8_dynamic_activation_int4_weight(group_size=32)) + quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32)) # Unwrap tensor subclasses for export compatibility model = unwrap_tensor_subclass(model) @@ -360,7 +360,7 @@ def test_different_group_sizes_rewrite(self): # Apply quantization with specific group size quantize_( - model, int8_dynamic_activation_int4_weight(group_size=group_size) + model, Int8DynamicActivationInt4WeightConfig(group_size=group_size) ) # Unwrap tensor subclasses for export compatibility @@ -402,7 +402,7 @@ def test_model_without_bias_rewrite(self): example_input = torch.randn(1, 32) # Apply quantization - quantize_(model, int8_dynamic_activation_int4_weight(group_size=16)) + quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=16)) # Unwrap tensor subclasses for export compatibility model = unwrap_tensor_subclass(model) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a59f9c7069..5d2c8d863b 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -513,14 +513,14 @@ def quantize_( # optimized execution paths or kernels (e.g. 
int4 tinygemm kernel) # also customizable with arguments # currently options are - # int8_dynamic_activation_int4_weight (for executorch) - # int8_dynamic_activation_int8_weight (optimized with int8 mm op and torch.compile) - # int4_weight_only (optimized with int4 tinygemm kernel and torch.compile) - # int8_weight_only (optimized with int8 mm op and torch.compile + # Int8DynamicActivationInt4WeightConfig (for executorch) + # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile) + # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile) + # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile from torchao.quantization.quant_api import int4_weight_only m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - quantize_(m, int4_weight_only(group_size=32, version=1)) + quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1)) """ torch._C._log_api_usage_once("torchao.quantization.quantize_") @@ -1492,7 +1492,7 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): # int8 dynamic quantization only has benefit when in_feature > 16 if in_features <= 16: logger.info( - f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}" + f"Skipping applying Int8DynamicActivationInt8WeightConfig to weight of shape {weight.shape}" f" because `in_feature` is <= 16: {in_features}" ) return weight @@ -1557,13 +1557,13 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): quantization + 2:4 sparsity to linear layers. """ warnings.warn( - """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead. + """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in Int8DynamicActivationInt8WeightConfig instead. 
from torchao.dtypes import SemiSparseLayout - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""" + Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()""" ) - return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) + return Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()) @dataclass @@ -2047,7 +2047,7 @@ def _uintx_weight_only_transform( if use_hqq: if dtype == torch.uint4: logger.warning( - "Recommended to use `int4_weight_only(group_size, use_hqq=True, version=1)` for the best performance" + "Recommended to use `Int4WeightOnlyConfig(group_size, use_hqq=True, version=1)` for the best performance" ) quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype] dtype = torch.uint8 diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index f0d3183e35..9214f8b1ef 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -129,7 +129,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: # for int8 dynamic quantization + 2:4 sparsity from torchao.dtypes import SemiSparseLayout - m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn) + m = quantize_(m, Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout), filter_fn) """ torch._C._log_api_usage_once("torchao.sparsity.sparsify_") handler = _QUANTIZE_CONFIG_HANDLER[type(config)] diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 762fb31b30..bb9c2ca8dc 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -17,7 +17,7 @@ import torchao from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx -from torchao.quantization import int8_weight_only, quantize_ +from torchao.quantization import Int8WeightOnlyConfig, quantize_ from torchao.quantization.quant_primitives import MappingType from torchao.quantization.transform_module import ( _QUANTIZE_CONFIG_HANDLER, @@ -331,7 +331,7 @@ class TorchAOTensorParallelTestCase(DTensorTestBase): COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] TENSOR_SUBCLASS = AffineQuantizedTensor - QUANT_METHOD_FN = staticmethod(int8_weight_only) + QUANT_METHOD_FN = staticmethod(Int8WeightOnlyConfig) QUANT_METHOD_KWARGS = {} @staticmethod diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py index c326828219..bc999b49d4 100644 --- a/tutorials/quantize_vit/run_vit_b_quant.py +++ b/tutorials/quantize_vit/run_vit_b_quant.py @@ -24,11 +24,11 @@ # for torch 2.4+ from torchao.quantization.quant_api import ( - int8_dynamic_activation_int8_weight, + Int8DynamicActivationInt8WeightConfig, quantize_, ) -quantize_(model, int8_dynamic_activation_int8_weight()) +quantize_(model, Int8DynamicActivationInt8WeightConfig()) ## Quantization code - end ## compilation configs From 593e3bc1c756a397dd94902cc7cfc65e402a3acf Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 12 Sep 2025 14:13:24 -0700 Subject: [PATCH 2/3] Update on "Remove internal usage of all config functions like `int4_weight_only`" **Summary:** These are now deprecated as of #2994. We should stop using them internally as well. 
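For reference, a minimal sketch of the migration pattern this stack applies (illustrative only; the module path, config name, and arguments below are copied from existing call sites, not a new API):

```python
# Sketch of the old-function -> config-class migration, assuming the same
# toy model used in the quantize_ docstring.
import torch.nn as nn

from torchao.quantization.quant_api import (
    Int4WeightOnlyConfig,
    quantize_,
)

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

# Before (deprecated config function):
#   from torchao.quantization.quant_api import int4_weight_only
#   quantize_(m, int4_weight_only(group_size=32, version=1))

# After (config class passed directly to quantize_):
quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
```
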
**Test Plan:** CI [ghstack-poisoned] --- test/quantization/test_quant_api.py | 47 +++++++++++++++-------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 6380b568be..a40a93c836 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -750,22 +750,23 @@ def test_int4wo_cuda_serialization(self): # load state_dict in cuda model.load_state_dict(sd, assign=True) + def test_config_deprecation(self): """ - Test that old config functions like `Int4WeightOnlyConfig` trigger deprecation warnings. + Test that old config functions like `int4_weight_only` trigger deprecation warnings. """ from torchao.quantization import ( - Float8DynamicActivationFloat8WeightConfig, - Float8StaticActivationFloat8WeightConfig, - Float8WeightOnlyConfig, - FPXWeightOnlyConfig, - GemliteUIntXWeightOnlyConfig, - Int4DynamicActivationInt4WeightConfig, - Int4WeightOnlyConfig, - Int8DynamicActivationInt4WeightConfig, - Int8DynamicActivationInt8WeightConfig, - Int8WeightOnlyConfig, - UIntXWeightOnlyConfig, + float8_dynamic_activation_float8_weight, + float8_static_activation_float8_weight, + float8_weight_only, + fpx_weight_only, + gemlite_uintx_weight_only, + int4_dynamic_activation_int4_weight, + int4_weight_only, + int8_dynamic_activation_int4_weight, + int8_dynamic_activation_int8_weight, + int8_weight_only, + uintx_weight_only, ) # Reset deprecation warning state, otherwise we won't log warnings here @@ -773,17 +774,17 @@ def test_config_deprecation(self): # Map from deprecated API to the args needed to instantiate it deprecated_apis_to_args = { - Float8DynamicActivationFloat8WeightConfig: (), - Float8StaticActivationFloat8WeightConfig: (torch.randn(3)), - Float8WeightOnlyConfig: (), - FPXWeightOnlyConfig: (3, 2), - GemliteUIntXWeightOnlyConfig: (), - Int4DynamicActivationInt4WeightConfig: (), - Int4WeightOnlyConfig: (), - Int8DynamicActivationInt4WeightConfig: (), - Int8DynamicActivationInt8WeightConfig: (), - Int8WeightOnlyConfig: (), - UIntXWeightOnlyConfig: (torch.uint4,), + float8_dynamic_activation_float8_weight: (), + float8_static_activation_float8_weight: (torch.randn(3)), + float8_weight_only: (), + fpx_weight_only: (3, 2), + gemlite_uintx_weight_only: (), + int4_dynamic_activation_int4_weight: (), + int4_weight_only: (), + int8_dynamic_activation_int4_weight: (), + int8_dynamic_activation_int8_weight: (), + int8_weight_only: (), + uintx_weight_only: (torch.uint4,), } with warnings.catch_warnings(record=True) as _warnings: From b933861ba1b2b4d7d4879304476f0e12d4dd10a4 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 12 Sep 2025 14:16:18 -0700 Subject: [PATCH 3/3] Update on "Remove internal usage of all config functions like `int4_weight_only`" **Summary:** These are now deprecated as of #2994. We should stop using them internally as well. **Test Plan:** CI [ghstack-poisoned] --- test/quantization/test_quant_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index a40a93c836..1164b1cffb 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -750,7 +750,6 @@ def test_int4wo_cuda_serialization(self): # load state_dict in cuda model.load_state_dict(sd, assign=True) - def test_config_deprecation(self): """ Test that old config functions like `int4_weight_only` trigger deprecation warnings.