16
16
)
17
17
18
18
from torchao .quantization import (
19
- float8_dynamic_activation_float8_weight ,
20
- float8_weight_only ,
21
- int4_weight_only ,
22
- int8_dynamic_activation_int8_weight ,
23
- int8_weight_only ,
19
+ Float8DynamicActivationFloat8WeightConfig ,
20
+ Float8WeightOnlyConfig ,
21
+ Int4WeightOnlyConfig ,
22
+ Int8DynamicActivationInt8WeightConfig ,
23
+ Int8WeightOnlyConfig ,
24
24
)
25
25
from torchao .quantization .observer import PerRow , PerTensor
26
26
from torchao .quantization .quant_api import quantize_
42
42
class TestAffineQuantizedTensorParallel (DTensorTestBase ):
43
43
"""Basic test case for tensor subclasses"""
44
44
45
- QUANT_METHOD_FN = staticmethod (int8_weight_only )
45
+ QUANT_METHOD_FN = staticmethod (Int8WeightOnlyConfig )
46
46
QUANT_METHOD_KWARGS = {}
47
47
48
48
@staticmethod
@@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
133
133
134
134
135
135
class TestInt8woAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
136
- QUANT_METHOD_FN = staticmethod (int8_weight_only )
136
+ QUANT_METHOD_FN = staticmethod (Int8WeightOnlyConfig )
137
137
COMMON_DTYPES = [torch .bfloat16 , torch .float16 , torch .float32 ]
138
138
139
139
@common_utils .parametrize ("dtype" , COMMON_DTYPES )
@@ -144,7 +144,7 @@ def test_tp(self, dtype):
144
144
145
145
146
146
class TestInt4woAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
147
- QUANT_METHOD_FN = staticmethod (int4_weight_only )
147
+ QUANT_METHOD_FN = staticmethod (Int4WeightOnlyConfig )
148
148
QUANT_METHOD_KWARGS = {"version" : 1 }
149
149
COMMON_DTYPES = [torch .bfloat16 ]
150
150
@@ -167,20 +167,20 @@ class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel):
167
167
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
168
168
@unittest .skipIf (not has_gemlite , "gemlite not available" )
169
169
def test_tp_gemlite (self , dtype ):
170
- from torchao .quantization import gemlite_uintx_weight_only
170
+ from torchao .quantization import GemliteUIntXWeightOnlyConfig
171
171
172
172
for packing_bitwidth in [32 , 8 ]:
173
173
for bit_width in [4 , 8 ]:
174
174
for group_size in [64 , 32 , None ] if bit_width == 4 else [None ]:
175
- api = lambda : gemlite_uintx_weight_only (
175
+ api = lambda : GemliteUIntXWeightOnlyConfig (
176
176
group_size , bit_width , packing_bitwidth
177
177
)
178
178
self .QUANT_METHOD_FN = staticmethod (api )
179
179
return self ._test_tp (dtype )
180
180
181
181
182
182
class TestInt8dqAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
183
- QUANT_METHOD_FN = staticmethod (int8_dynamic_activation_int8_weight )
183
+ QUANT_METHOD_FN = staticmethod (Int8DynamicActivationInt8WeightConfig )
184
184
COMMON_DTYPES = [torch .bfloat16 ]
185
185
186
186
@common_utils .parametrize ("dtype" , COMMON_DTYPES )
@@ -199,7 +199,7 @@ def test_tp(self, dtype):
199
199
if torch .cuda .is_available () and torch .cuda .get_device_capability () >= (9 , 0 ):
200
200
201
201
class TestFloat8woAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
202
- QUANT_METHOD_FN = staticmethod (float8_weight_only )
202
+ QUANT_METHOD_FN = staticmethod (Float8WeightOnlyConfig )
203
203
COMMON_DTYPES = [torch .bfloat16 , torch .float16 , torch .float32 ]
204
204
205
205
@common_utils .parametrize ("dtype" , COMMON_DTYPES )
@@ -211,7 +211,7 @@ def test_tp(self, dtype):
211
211
class TestFloat8dqTensorAffineQuantizedTensorParallel (
212
212
TestAffineQuantizedTensorParallel
213
213
):
214
- QUANT_METHOD_FN = staticmethod (float8_dynamic_activation_float8_weight )
214
+ QUANT_METHOD_FN = staticmethod (Float8DynamicActivationFloat8WeightConfig )
215
215
QUANT_METHOD_KWARGS = {"granularity" : PerTensor ()}
216
216
COMMON_DTYPES = [torch .bfloat16 , torch .float16 , torch .float32 ]
217
217
@@ -224,7 +224,7 @@ def test_tp(self, dtype):
224
224
class TestFloat8dqRowAffineQuantizedTensorParallel (
225
225
TestAffineQuantizedTensorParallel
226
226
):
227
- QUANT_METHOD_FN = staticmethod (float8_dynamic_activation_float8_weight )
227
+ QUANT_METHOD_FN = staticmethod (Float8DynamicActivationFloat8WeightConfig )
228
228
QUANT_METHOD_KWARGS = {"granularity" : PerRow ()}
229
229
COMMON_DTYPES = [torch .bfloat16 ]
230
230
0 commit comments