
Commit ecb6c4b (1 parent: b10876b)

Remove compute target from intx_opaque_tensor (#2960)

* Remove compute target from intx_opaque_tensor
* up
* up
* up
* up
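For downstream users, this commit collapses the old `packing_format` + `compute_target` pair into a single `intx_packing_format` enum value. A minimal before/after sketch (the model and the `IntxPackingFormat` import path are illustrative assumptions, not taken from this commit):

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.granularity import PerGroup
# Import path assumed: IntxPackingFormat is defined in the quantization
# workflows package and re-exported from torchao.quantization.
from torchao.quantization import IntxPackingFormat

model = torch.nn.Sequential(torch.nn.Linear(512, 256))

# Before this commit (two knobs):
#   Int8DynamicActivationIntxWeightConfig(
#       ...,
#       packing_format=IntxPackingFormat.OPAQUE,
#       compute_target="torchao_auto",
#       version=2,
#   )

# After this commit (one knob):
config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(64),
    intx_packing_format=IntxPackingFormat.OPAQUE_TORCHAO_AUTO,
    version=2,
)
quantize_(model, config)
```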

File tree: 5 files changed (+112, −112 lines)


test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py (19 additions, 29 deletions)

```diff
@@ -35,11 +35,11 @@ def _get_accuracy_test_cases():
     ]

     PACKING_FORMATS = [
-        (IntxPackingFormat.UNPACKED_TO_INT8, None),
-        (IntxPackingFormat.OPAQUE, "aten"),
-        (IntxPackingFormat.OPAQUE, "torchao_auto"),
-        (IntxPackingFormat.OPAQUE, "torchao_lowbit"),
-        (IntxPackingFormat.OPAQUE, "torchao_kleidiai"),
+        IntxPackingFormat.UNPACKED_TO_INT8,
+        IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI,
+        IntxPackingFormat.OPAQUE_TORCHAO_AUTO,
+        IntxPackingFormat.OPAQUE_TORCHAO_KLEIDIAI,
+        IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT,
     ]

     WEIGHT_DTYPES = [
@@ -64,13 +64,12 @@ def _get_accuracy_test_cases():
     def _is_valid_test_combination(
         model_dtype,
         packing_format,
-        compute_target,
         weight_dtype,
         weight_mapping_type,
         weight_granularity,
     ):
         # ATEN restrictions
-        if (packing_format == IntxPackingFormat.OPAQUE) and (compute_target == "aten"):
+        if packing_format == IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI:
             if weight_dtype != torch.int4:
                 return False
             if weight_mapping_type == MappingType.ASYMMETRIC:
@@ -79,9 +78,7 @@ def _is_valid_test_combination(
                 return False

         # TORCHAO_KLEIDIAI restrictions
-        if (packing_format == IntxPackingFormat.OPAQUE) and (
-            compute_target == "torchao_kleidiai"
-        ):
+        if packing_format == IntxPackingFormat.OPAQUE_TORCHAO_KLEIDIAI:
             if weight_dtype != torch.int4:
                 return False
             if weight_mapping_type == MappingType.ASYMMETRIC:
@@ -100,17 +97,16 @@ def _is_valid_test_combination(
         param(
             model_dtype=mdt,
             packing_format=pf,
-            compute_target=ct,
             weight_dtype=dt,
             weight_mapping_type=mt,
             weight_granularity=gr,
         )
         for mdt in MODEL_DTYPES
-        for pf, ct in PACKING_FORMATS
+        for pf in PACKING_FORMATS
         for dt in WEIGHT_DTYPES
         for mt in MAPPING_TYPES
         for gr in GRANULARITIES
-        if _is_valid_test_combination(dt, pf, ct, dt, mt, gr)
+        if _is_valid_test_combination(dt, pf, dt, mt, gr)
     ]

     return test_cases
@@ -126,7 +122,6 @@ def test_accuracy(
         self,
         model_dtype,
         packing_format,
-        compute_target,
         weight_dtype,
         weight_mapping_type,
         weight_granularity,
@@ -149,8 +144,7 @@ def test_accuracy(
                 weight_dtype=weight_dtype,
                 weight_granularity=weight_granularity,
                 weight_mapping_type=weight_mapping_type,
-                packing_format=packing_format,
-                compute_target=compute_target,
+                intx_packing_format=packing_format,
                 version=2,
             ),
         )
@@ -162,8 +156,7 @@ def test_accuracy(
                 weight_dtype=weight_dtype,
                 weight_granularity=weight_granularity,
                 weight_mapping_type=weight_mapping_type,
-                packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
-                compute_target=None,
+                intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
                 version=2,
             ),
         )
@@ -209,8 +202,7 @@ def test_export_compile_aoti(
                 weight_dtype=weight_dtype,
                 weight_granularity=weight_granularity,
                 weight_mapping_type=weight_mapping_type,
-                packing_format=IntxPackingFormat.OPAQUE,
-                compute_target="torchao_auto",
+                intx_packing_format=IntxPackingFormat.OPAQUE_TORCHAO_AUTO,
                 version=2,
             ),
         )
@@ -241,15 +233,15 @@ def test_export_compile_aoti(

     @parameterized.expand(
         [
-            param(packing_format=pf, compute_target=ct)
-            for (pf, ct) in [
-                (IntxPackingFormat.OPAQUE, "torchao_auto"),
-                (IntxPackingFormat.OPAQUE, "aten"),
+            param(packing_format=pf)
+            for pf in [
+                IntxPackingFormat.OPAQUE_TORCHAO_AUTO,
+                IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI,
             ]
         ],
         name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
-    def test_serialization(self, packing_format, compute_target):
+    def test_serialization(self, packing_format):
         layers = [
             torch.nn.Linear(512, 256),
         ]
@@ -262,8 +254,7 @@ def test_serialization(self, packing_format, compute_target):
             Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=torch.int4,
                 weight_granularity=PerGroup(64),
-                packing_format=packing_format,
-                compute_target=compute_target,
+                intx_packing_format=packing_format,
                 version=2,
             ),
         )
@@ -305,8 +296,7 @@ def test_moe_quant_intx(self):
         out = model(x).clone()

         base_config = Int8DynamicActivationIntxWeightConfig(
-            packing_format=IntxPackingFormat.OPAQUE,
-            compute_target="torchao_auto",
+            intx_packing_format=IntxPackingFormat.OPAQUE_TORCHAO_AUTO,
             version=2,
         )
         moe_config = MoEQuantConfig(
```
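The tuples removed from PACKING_FORMATS map one-to-one onto the new enum members. A small reference sketch of that correspondence (the dict itself is illustrative; only the names come from the diff above):

```python
# Old (packing_format, compute_target) pair -> new IntxPackingFormat member,
# as implied by the PACKING_FORMATS change above. Spelled as strings here
# purely for illustration.
OLD_TO_NEW = {
    ("UNPACKED_TO_INT8", None): "UNPACKED_TO_INT8",
    ("OPAQUE", "aten"): "OPAQUE_ATEN_KLEIDIAI",
    ("OPAQUE", "torchao_auto"): "OPAQUE_TORCHAO_AUTO",
    ("OPAQUE", "torchao_kleidiai"): "OPAQUE_TORCHAO_KLEIDIAI",
    ("OPAQUE", "torchao_lowbit"): "OPAQUE_TORCHAO_LOWBIT",
}
```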

test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py (6 additions, 6 deletions)

```diff
@@ -158,7 +158,7 @@ def test_export_int8_dyn_act_intx_weight_config(self):
                 weight_dtype=torch.int4,
                 weight_granularity=PerAxis(0),
                 weight_mapping_type=MappingType.SYMMETRIC,
-                packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
+                intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
                 version=2,
             ),
         )
@@ -194,7 +194,7 @@ def test_export_int8_dyn_act_intx_weight_config_with_unwrap(self):
                 weight_dtype=torch.int4,
                 weight_granularity=PerGroup(64),
                 weight_mapping_type=MappingType.SYMMETRIC,
-                packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
+                intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
                 version=2,
             ),
         )
@@ -232,7 +232,7 @@ def test_serialization_int8_dyn_act_intx_weight_config(self):
             Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=torch.int4,
                 weight_granularity=PerGroup(64),
-                packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
+                intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
                 version=2,
             ),
         )
@@ -262,7 +262,7 @@ def test_serialization_intx_weight_only_config(self):
             IntxWeightOnlyConfig(
                 weight_dtype=torch.int4,
                 granularity=PerGroup(64),
-                packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
+                intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
                 version=2,
             ),
         )
@@ -321,7 +321,7 @@ def test_qat_int8_dyn_act_intx_weight_config(
             weight_granularity=PerGroup(group_size),
             weight_mapping_type=mapping_type,
             weight_scale_dtype=scale_dtype,
-            packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
+            intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
             version=2,
         )

@@ -429,7 +429,7 @@ def test_intx_unpacked_v2_is_close_to_qdq_v1(
                 weight_mapping_type=mapping_type,
                 weight_scale_dtype=scale_dtype,
                 act_mapping_type=act_mapping_type,
-                packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
+                intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8,
                 version=2,
             ),
         )
```

torchao/quantization/quant_api.py (18 additions, 18 deletions)

```diff
@@ -745,10 +745,8 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
     weight_scale_dtype: Optional[torch.dtype] = None
     act_mapping_type: MappingType = MappingType.ASYMMETRIC
     layout: Layout = QDQLayout()
-    packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8
+    intx_packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8

-    # Used with IntxPackingFormat.OPAQUE
-    compute_target: Optional[str] = None
     version: int = 1

     def __post_init__(self):
@@ -804,8 +802,7 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
     weight_scale_dtype = config.weight_scale_dtype
     act_mapping_type = config.act_mapping_type
     layout = config.layout
-    packing_format = config.packing_format
-    compute_target = config.compute_target
+    intx_packing_format = config.intx_packing_format

     assert weight.dim() == 2, (
         f"Int8DynamicActivationIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}"
@@ -826,10 +823,16 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):

     if config.version == 2:
         assert act_mapping_type == MappingType.ASYMMETRIC
-        assert packing_format in [
-            IntxPackingFormat.UNPACKED_TO_INT8,
-            IntxPackingFormat.OPAQUE,
-        ], f"Unsupported packing format: {packing_format}"
+        opaque_formats = [
+            IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI,
+            IntxPackingFormat.OPAQUE_TORCHAO_AUTO,
+            IntxPackingFormat.OPAQUE_TORCHAO_KLEIDIAI,
+            IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT,
+        ]
+        assert (
+            intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8
+            or intx_packing_format in opaque_formats
+        ), f"Unsupported packing format: {intx_packing_format}"
         new_weight = IntxUnpackedToInt8Tensor.from_hp(
             weight,
             block_size,
@@ -845,12 +848,9 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
             new_bias = bias

         # Create packed tensor
-        if packing_format == IntxPackingFormat.OPAQUE:
-            assert compute_target is not None, (
-                "Must specify a compute target for IntxPackingFormat.OPAQUE"
-            )
+        if intx_packing_format in opaque_formats:
             new_weight = IntxOpaqueTensor.from_intx_unpacked_to_int8_tensor(
-                new_weight, bias=new_bias, compute_target=compute_target
+                new_weight, bias=new_bias, intx_packing_format=intx_packing_format
             )
             new_bias = None  # bias is packed with weights

@@ -2113,7 +2113,7 @@ class IntxWeightOnlyConfig(AOBaseConfig):
     mapping_type: MappingType = MappingType.SYMMETRIC
     scale_dtype: Optional[torch.dtype] = None
     layout: Layout = QDQLayout()
-    packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8
+    intx_packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8
     version: int = 1

     def __post_init__(self):
@@ -2142,7 +2142,7 @@ def _intx_weight_only_quantize_tensor(weight, config):
     mapping_type = config.mapping_type
     scale_dtype = config.scale_dtype
     layout = config.layout
-    packing_format = config.packing_format
+    intx_packing_format = config.intx_packing_format

     assert weight.dim() == 2, (
         f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
@@ -2160,7 +2160,7 @@ def _intx_weight_only_quantize_tensor(weight, config):
     block_size = (1, group_size)

     if config.version == 2:
-        if config.packing_format == IntxPackingFormat.UNPACKED_TO_INT8:
+        if config.intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8:
             new_weight = IntxUnpackedToInt8Tensor.from_hp(
                 weight,
                 block_size,
@@ -2174,7 +2174,7 @@ def _intx_weight_only_quantize_tensor(weight, config):

             return new_weight
         else:
-            raise ValueError(f"Unsupported packing format: {intx_packing_format}")
+            raise ValueError(f"Unsupported packing format: {intx_packing_format}")

     # Version 1
     quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]
```
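Distilled, the new version-2 path in quant_api.py becomes a single membership test: UNPACKED_TO_INT8 keeps the IntxUnpackedToInt8Tensor, and any of the four opaque variants is repacked into an IntxOpaqueTensor. A self-contained sketch of that control flow (member names come from the diff; the enum string values and helper are assumptions for illustration):

```python
from enum import Enum

class IntxPackingFormat(str, Enum):
    # Member names come from the diff; the string values are assumed.
    UNPACKED_TO_INT8 = "unpacked_to_int8"
    OPAQUE_ATEN_KLEIDIAI = "opaque_aten_kleidiai"
    OPAQUE_TORCHAO_AUTO = "opaque_torchao_auto"
    OPAQUE_TORCHAO_KLEIDIAI = "opaque_torchao_kleidiai"
    OPAQUE_TORCHAO_LOWBIT = "opaque_torchao_lowbit"

OPAQUE_FORMATS = {
    IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI,
    IntxPackingFormat.OPAQUE_TORCHAO_AUTO,
    IntxPackingFormat.OPAQUE_TORCHAO_KLEIDIAI,
    IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT,
}

def choose_tensor_subclass(fmt: IntxPackingFormat) -> str:
    """Illustrative stand-in for the dispatch in
    _int8_dynamic_activation_intx_weight_quantize_tensor."""
    if fmt == IntxPackingFormat.UNPACKED_TO_INT8:
        return "IntxUnpackedToInt8Tensor"  # weight stays unpacked
    if fmt in OPAQUE_FORMATS:
        return "IntxOpaqueTensor"  # repacked; bias is folded into the weights
    raise ValueError(f"Unsupported packing format: {fmt}")

assert choose_tensor_subclass(IntxPackingFormat.OPAQUE_TORCHAO_AUTO) == "IntxOpaqueTensor"
```

Folding the compute target into the format name also removes an invalid region of the config space (OPAQUE with no compute target) that previously had to be guarded by an explicit assert.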
