
Commit db19984

Remove internal usage of all config functions like int4_weight_only
**Summary:** These are now deprecated as of #2994. We should stop using them internally as well.

**Test Plan:** CI

ghstack-source-id: e493e74
Pull Request resolved: #2995
1 parent 64dccf3 commit db19984
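
In practice the migration is mechanical: each deprecated config function is swapped for the corresponding `*Config` class, with the same arguments, and the result is passed to `quantize_` exactly as before. A minimal sketch (the toy model below is hypothetical, not from this PR):

```python
import torch

from torchao.quantization import Int8WeightOnlyConfig, quantize_

# hypothetical toy model, only to show the call site
model = torch.nn.Sequential(torch.nn.Linear(128, 256, dtype=torch.bfloat16))

# before (deprecated as of #2994):
#   quantize_(model, int8_weight_only(), set_inductor_config=False)
# after:
quantize_(model, Int8WeightOnlyConfig(), set_inductor_config=False)
```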

30 files changed, +231 -219 lines


benchmarks/benchmark_aq.py

Lines changed: 6 additions & 6 deletions
@@ -10,10 +10,10 @@
 import torch

 from torchao.quantization.quant_api import (
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationInt8WeightConfig,
+    Int8WeightOnlyConfig,
     _replace_with_custom_fn_if_matches_filter,
-    int4_weight_only,
-    int8_dynamic_activation_int8_weight,
-    int8_weight_only,
     quantize_,
 )
 from torchao.quantization.subclass import (
@@ -23,13 +23,13 @@


 def _int8wo_api(mod, **kwargs):
-    quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False)
+    quantize_(mod, Int8WeightOnlyConfig(**kwargs), set_inductor_config=False)


 def _int8da_int8w_api(mod, **kwargs):
     quantize_(
         mod,
-        int8_dynamic_activation_int8_weight(**kwargs),
+        Int8DynamicActivationInt8WeightConfig(**kwargs),
         set_inductor_config=False,
     )

@@ -39,7 +39,7 @@ def _int4wo_api(mod, **kwargs):
     if "groupsize" in kwargs_copy:
         kwargs_copy["group_size"] = kwargs_copy["groupsize"]
         del kwargs_copy["groupsize"]
-    quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False)
+    quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False)


 class ToyLinearModel(torch.nn.Module):
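
For context, a sketch of how the updated benchmark helper is meant to be called: the legacy `groupsize` kwarg is still accepted and renamed to `group_size` before the config is built. The model below is hypothetical (the benchmark uses its own `ToyLinearModel`), and a CUDA bf16 linear is assumed for the int4 packing:

```python
import torch

from torchao.quantization.quant_api import Int4WeightOnlyConfig, quantize_


def _int4wo_api(mod, **kwargs):
    # mirrors the helper above: normalize the legacy "groupsize" kwarg
    kwargs_copy = kwargs.copy()
    if "groupsize" in kwargs_copy:
        kwargs_copy["group_size"] = kwargs_copy["groupsize"]
        del kwargs_copy["groupsize"]
    quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False)


# hypothetical stand-in model
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128, dtype=torch.bfloat16, device="cuda")
)
_int4wo_api(model, groupsize=32)
```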

benchmarks/quantized_training/pretrain_llama2.py

Lines changed: 2 additions & 2 deletions
@@ -166,8 +166,8 @@ def insert_rmsnorm(module: torch.nn.Module):
     insert_rmsnorm(model.layers)

     # don't apply int8_mixed_precision to LM head, since it can cause convergence issue.
-    # TODO: might want to do the same for int8_weight_only to standardize.
-    if args.quantize == "int8_weight_only":
+    # TODO: might want to do the same for Int8WeightOnlyConfig to standardize.
+    if args.quantize == "Int8WeightOnlyConfig":
         quantize_(
             model, int8_weight_only_quantized_training(), set_inductor_config=False
         )

test/dtypes/test_affine_quantized.py

Lines changed: 17 additions & 19 deletions
@@ -27,15 +27,13 @@
 from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
     FbgemmConfig,
+    Float8WeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
+    Int4DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
+    Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
-    float8_weight_only,
-    int4_dynamic_activation_int4_weight,
-    int4_weight_only,
-    int8_dynamic_activation_int4_weight,
-    int8_dynamic_activation_int8_weight,
-    int8_weight_only,
+    Int8WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
@@ -58,49 +56,49 @@ def get_quantization_functions(
     do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
 ):
     base_functions = [
-        int8_weight_only(),
-        int8_dynamic_activation_int4_weight(),
-        int8_dynamic_activation_int8_weight(),
-        int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
+        Int8WeightOnlyConfig(),
+        Int8DynamicActivationInt4WeightConfig(),
+        Int8DynamicActivationInt8WeightConfig(),
+        Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
         if check_cpu_version(device):
             base_functions.append(
-                int4_weight_only(group_size=32, layout=Int4CPULayout(), version=1)
+                Int4WeightOnlyConfig(group_size=32, layout=Int4CPULayout(), version=1)
             )
         elif check_xpu_version(device):
             base_functions.append(
-                int4_weight_only(group_size=32, layout=Int4XPULayout(), version=1)
+                Int4WeightOnlyConfig(group_size=32, layout=Int4XPULayout(), version=1)
            )
            if int4_zp_int:
                base_functions.append(
-                    int4_weight_only(
+                    Int4WeightOnlyConfig(
                        group_size=32,
                        layout=Int4XPULayout(),
                        zero_point_domain=ZeroPointDomain.INT,
                        version=1,
                    )
                )
        else:
-            base_functions.append(int4_weight_only(group_size=32, version=1))
+            base_functions.append(Int4WeightOnlyConfig(group_size=32, version=1))
        if device == "cuda" and not is_ROCM():
            base_functions.append(
-                int8_dynamic_activation_int4_weight(
+                Int8DynamicActivationInt4WeightConfig(
                    group_size=None,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=CutlassInt4PackedLayout(),
                )
            )
-            base_functions.append(int4_dynamic_activation_int4_weight())
+            base_functions.append(Int4DynamicActivationInt4WeightConfig())

    if do_sparse and device != "xpu":
        base_functions.append(
-            int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
+            Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout())
        )

    if is_sm_at_least_89():
-        base_functions.append(float8_weight_only())
+        base_functions.append(Float8WeightOnlyConfig())

    if is_sm_at_least_90():
        base_functions.append(FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16))
@@ -119,7 +117,7 @@ def test_tensor_core_layout_transpose(self):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         t = linear.weight
         shape = t.shape
-        apply_int4_weight_only_quant = int4_weight_only(group_size=32, version=1)
+        apply_int4_weight_only_quant = Int4WeightOnlyConfig(group_size=32, version=1)
         quantize_(linear, apply_int4_weight_only_quant)
         ql = linear
         aqt = ql.weight
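
`get_quantization_functions` now returns a list of plain config objects rather than deferred function calls; a minimal sketch of consuming such a list (hypothetical shapes, and a CUDA device with bf16 support is assumed):

```python
import torch

from torchao.quantization import Int4WeightOnlyConfig, Int8WeightOnlyConfig, quantize_

# configs are plain objects, so they can be collected first and applied later
configs = [
    Int8WeightOnlyConfig(),
    Int4WeightOnlyConfig(group_size=32, version=1),
]

for config in configs:
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    quantize_(linear, config)
    # smoke test the quantized module
    linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
```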

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 14 additions & 14 deletions
@@ -16,11 +16,11 @@
 )

 from torchao.quantization import (
-    float8_dynamic_activation_float8_weight,
-    float8_weight_only,
-    int4_weight_only,
-    int8_dynamic_activation_int8_weight,
-    int8_weight_only,
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationInt8WeightConfig,
+    Int8WeightOnlyConfig,
 )
 from torchao.quantization.observer import PerRow, PerTensor
 from torchao.quantization.quant_api import quantize_
@@ -42,7 +42,7 @@
 class TestAffineQuantizedTensorParallel(DTensorTestBase):
     """Basic test case for tensor subclasses"""

-    QUANT_METHOD_FN = staticmethod(int8_weight_only)
+    QUANT_METHOD_FN = staticmethod(Int8WeightOnlyConfig)
     QUANT_METHOD_KWARGS = {}

     @staticmethod
@@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
-    QUANT_METHOD_FN = staticmethod(int8_weight_only)
+    QUANT_METHOD_FN = staticmethod(Int8WeightOnlyConfig)
     COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]

     @common_utils.parametrize("dtype", COMMON_DTYPES)
@@ -144,7 +144,7 @@ def test_tp(self, dtype):


 class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
-    QUANT_METHOD_FN = staticmethod(int4_weight_only)
+    QUANT_METHOD_FN = staticmethod(Int4WeightOnlyConfig)
     QUANT_METHOD_KWARGS = {"version": 1}
     COMMON_DTYPES = [torch.bfloat16]

@@ -167,20 +167,20 @@ class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(not has_gemlite, "gemlite not available")
     def test_tp_gemlite(self, dtype):
-        from torchao.quantization import gemlite_uintx_weight_only
+        from torchao.quantization import GemliteUIntXWeightOnlyConfig

         for packing_bitwidth in [32, 8]:
             for bit_width in [4, 8]:
                 for group_size in [64, 32, None] if bit_width == 4 else [None]:
-                    api = lambda: gemlite_uintx_weight_only(
+                    api = lambda: GemliteUIntXWeightOnlyConfig(
                         group_size, bit_width, packing_bitwidth
                     )
                     self.QUANT_METHOD_FN = staticmethod(api)
                     return self._test_tp(dtype)


 class TestInt8dqAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
-    QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight)
+    QUANT_METHOD_FN = staticmethod(Int8DynamicActivationInt8WeightConfig)
     COMMON_DTYPES = [torch.bfloat16]

     @common_utils.parametrize("dtype", COMMON_DTYPES)
@@ -199,7 +199,7 @@ def test_tp(self, dtype):
 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):

     class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
-        QUANT_METHOD_FN = staticmethod(float8_weight_only)
+        QUANT_METHOD_FN = staticmethod(Float8WeightOnlyConfig)
         COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]

         @common_utils.parametrize("dtype", COMMON_DTYPES)
@@ -211,7 +211,7 @@ def test_tp(self, dtype):
     class TestFloat8dqTensorAffineQuantizedTensorParallel(
         TestAffineQuantizedTensorParallel
     ):
-        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
+        QUANT_METHOD_FN = staticmethod(Float8DynamicActivationFloat8WeightConfig)
         QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
         COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]

@@ -224,7 +224,7 @@ def test_tp(self, dtype):
     class TestFloat8dqRowAffineQuantizedTensorParallel(
         TestAffineQuantizedTensorParallel
     ):
-        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
+        QUANT_METHOD_FN = staticmethod(Float8DynamicActivationFloat8WeightConfig)
         QUANT_METHOD_KWARGS = {"granularity": PerRow()}
         COMMON_DTYPES = [torch.bfloat16]

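
Since `QUANT_METHOD_FN` now holds a config class instead of a function, the harness can still instantiate it uniformly with `QUANT_METHOD_KWARGS`. A sketch of that pattern, under the assumption that `_test_tp` (not shown in this diff) combines the two roughly like this:

```python
import torch

from torchao.quantization import Int8WeightOnlyConfig, quantize_


class _Harness:
    # same class-attribute pattern as the tests above
    QUANT_METHOD_FN = staticmethod(Int8WeightOnlyConfig)
    QUANT_METHOD_KWARGS = {}

    def quantize(self, model):
        # assumption: the real _test_tp builds the config roughly like this
        config = self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS)
        quantize_(model, config)
        return model


model = torch.nn.Sequential(torch.nn.Linear(64, 64, dtype=torch.bfloat16))
_Harness().quantize(model)
```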

test/dtypes/test_floatx.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@
     _floatx_unpacked_to_f32,
 )
 from torchao.quantization import (
-    fpx_weight_only,
+    FPXWeightOnlyConfig,
     quantize_,
 )
 from torchao.testing.utils import skip_if_rocm
@@ -118,7 +118,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias, dtype):

         linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype)
         fpx_linear = copy.deepcopy(linear)
-        quantize_(fpx_linear, fpx_weight_only(ebits, mbits))
+        quantize_(fpx_linear, FPXWeightOnlyConfig(ebits, mbits))

         x = torch.randn(N, IC, device=device, dtype=dtype)
         expected = fpx_linear(x)
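
Standalone, the new config is used the same way as the old `fpx_weight_only` call. A minimal sketch with assumed shapes, fp16, and `ebits=3, mbits=2` (fp6) as an example format; a recent CUDA GPU is assumed for the floatx kernel:

```python
import copy

import torch

from torchao.quantization import FPXWeightOnlyConfig, quantize_

linear = torch.nn.Linear(256, 512, bias=False, device="cuda", dtype=torch.half)
fpx_linear = copy.deepcopy(linear)

# quantize the copy to fp6 (3 exponent bits, 2 mantissa bits), weight-only
quantize_(fpx_linear, FPXWeightOnlyConfig(3, 2))

x = torch.randn(4, 256, device="cuda", dtype=torch.half)
out = fpx_linear(x)   # quantized path
ref = linear(x)       # original fp16 path, for comparison
```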

test/dtypes/test_uintx.py

Lines changed: 6 additions & 11 deletions
@@ -7,7 +7,7 @@
 import torch

 from torchao.dtypes.uintx.uintx_layout import to_uintx
-from torchao.quantization.quant_api import quantize_, uintx_weight_only
+from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_
 from torchao.quantization.quant_primitives import (
     MappingType,
     choose_qparams_affine,
@@ -60,7 +60,7 @@ def forward(self, x):
 def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size):
     scale = 512
     fp16_mod_on_cpu = Linear16(scale, "cpu")
-    quantize_(fp16_mod_on_cpu, uintx_weight_only(dtype, group_size=group_size))
+    quantize_(fp16_mod_on_cpu, UIntXWeightOnlyConfig(dtype, group_size=group_size))
     test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu")
     output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu)
     fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda")
@@ -78,7 +78,7 @@ def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size):
 def test_uintx_weight_only_model_quant(dtype, group_size, device):
     scale = 512
     fp16 = Linear16(scale, device)
-    quantize_(fp16, uintx_weight_only(dtype, group_size=group_size))
+    quantize_(fp16, UIntXWeightOnlyConfig(dtype, group_size=group_size))
     uintx = torch.compile(fp16, fullgraph=True)
     test_input = torch.randn(scale * 2, dtype=torch.float16, device=device)
     output = uintx.forward(test_input)
@@ -124,30 +124,25 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
 def test_uintx_target_dtype(dtype):
-    from torchao.quantization.quant_api import uintx_weight_only
-
     linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
     # make sure it runs
-    quantize_(linear, uintx_weight_only(dtype))
+    quantize_(linear, UIntXWeightOnlyConfig(dtype))
     linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))


 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
 def test_uintx_target_dtype_compile(dtype):
-    from torchao.quantization.quant_api import uintx_weight_only
-
     linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
     # make sure it runs
-    quantize_(linear, uintx_weight_only(dtype))
+    quantize_(linear, UIntXWeightOnlyConfig(dtype))
     linear = torch.compile(linear)
     linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))


 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
 def test_uintx_model_size(dtype):
-    from torchao.quantization.quant_api import uintx_weight_only
     from torchao.utils import get_model_size_in_bytes

     # scale size = 1/64 * 2 bytes = 1/32 bytes
@@ -167,6 +162,6 @@ def test_uintx_model_size(dtype):
     )
     bf16_size = get_model_size_in_bytes(linear)
     # make sure it runs
-    quantize_(linear[0], uintx_weight_only(dtype))
+    quantize_(linear[0], UIntXWeightOnlyConfig(dtype))
     quantized_size = get_model_size_in_bytes(linear)
     assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
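
A minimal sketch of the same size check outside the test harness, with hypothetical shapes and `torch.uint4` picked from the dtypes the tests sweep; CUDA is assumed:

```python
import torch

from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_
from torchao.utils import get_model_size_in_bytes

linear = torch.nn.Sequential(
    torch.nn.Linear(512, 512, bias=False, dtype=torch.bfloat16, device="cuda")
)
bf16_size = get_model_size_in_bytes(linear)

# quantize only the inner Linear, as the test does
quantize_(linear[0], UIntXWeightOnlyConfig(torch.uint4, group_size=64))
quantized_size = get_model_size_in_bytes(linear)

print(f"{bf16_size} bytes -> {quantized_size} bytes")
```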

test/hqq/test_hqq_affine.py

Lines changed: 6 additions & 4 deletions
@@ -8,11 +8,11 @@
 import torch

 from torchao.quantization import (
+    Int4WeightOnlyConfig,
     MappingType,
+    UIntXWeightOnlyConfig,
     ZeroPointDomain,
-    int4_weight_only,
     quantize_,
-    uintx_weight_only,
 )
 from torchao.testing.utils import skip_if_rocm

@@ -55,9 +55,11 @@ def _eval_hqq(dtype):
     )
     dummy_linear.weight.data = W
     if dtype == torch.uint4:
-        config = int4_weight_only(group_size=max(block_size), use_hqq=True, version=1)
+        config = Int4WeightOnlyConfig(
+            group_size=max(block_size), use_hqq=True, version=1
+        )
     else:
-        config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)
+        config = UIntXWeightOnlyConfig(dtype, group_size=max(block_size), use_hqq=True)
     quantize_(dummy_linear, config)
     q_tensor_hqq = dummy_linear.weight

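
The dtype dispatch above, extracted for reference; the weight shape and group size here are hypothetical, and a CUDA device is assumed:

```python
import torch

from torchao.quantization import Int4WeightOnlyConfig, UIntXWeightOnlyConfig, quantize_

dtype = torch.uint4  # other uint dtypes take the else-branch below
group_size = 64      # the test derives this from max(block_size)

linear = torch.nn.Linear(512, 512, bias=False, dtype=torch.bfloat16, device="cuda")

# same branch as _eval_hqq: uint4 goes through the int4 config, everything else
# through the generic uintx config, both with HQQ enabled
if dtype == torch.uint4:
    config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=True, version=1)
else:
    config = UIntXWeightOnlyConfig(dtype, group_size=group_size, use_hqq=True)

quantize_(linear, config)
q_weight_hqq = linear.weight
```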
