benchmarks/benchmark_aq.py (12 changes: 6 additions & 6 deletions)
@@ -10,10 +10,10 @@
import torch

from torchao.quantization.quant_api import (
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
_replace_with_custom_fn_if_matches_filter,
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
)
from torchao.quantization.subclass import (
@@ -23,13 +23,13 @@


def _int8wo_api(mod, **kwargs):
quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False)
quantize_(mod, Int8WeightOnlyConfig(**kwargs), set_inductor_config=False)


def _int8da_int8w_api(mod, **kwargs):
quantize_(
mod,
int8_dynamic_activation_int8_weight(**kwargs),
Int8DynamicActivationInt8WeightConfig(**kwargs),
set_inductor_config=False,
)

@@ -39,7 +39,7 @@ def _int4wo_api(mod, **kwargs):
if "groupsize" in kwargs_copy:
kwargs_copy["group_size"] = kwargs_copy["groupsize"]
del kwargs_copy["groupsize"]
quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False)
quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False)


class ToyLinearModel(torch.nn.Module):
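For reference, a minimal sketch of the config-object API that replaces the old factory functions in this benchmark; the toy layer shape is an illustrative assumption, not part of the benchmark itself.

```python
# Minimal sketch, assuming torchao is installed; the layer shape is arbitrary.
import torch

from torchao.quantization.quant_api import Int8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(64, 64, dtype=torch.bfloat16))
# A config instance now takes the place of the old int8_weight_only() factory call.
quantize_(model, Int8WeightOnlyConfig(), set_inductor_config=False)
```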
benchmarks/quantized_training/pretrain_llama2.py (4 changes: 2 additions & 2 deletions)
@@ -166,8 +166,8 @@ def insert_rmsnorm(module: torch.nn.Module):
insert_rmsnorm(model.layers)

# don't apply int8_mixed_precision to LM head, since it can cause convergence issue.
# TODO: might want to do the same for int8_weight_only to standardize.
if args.quantize == "int8_weight_only":
# TODO: might want to do the same for Int8WeightOnlyConfig to standardize.
if args.quantize == "Int8WeightOnlyConfig":
quantize_(
model, int8_weight_only_quantized_training(), set_inductor_config=False
)
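As a hedged illustration of the renamed flag value: only the flag name and the compared string come from the hunk above; the parser setup below is assumed for illustration and is not the full training script.

```python
# Hedged sketch of the string comparison on args.quantize shown in the diff above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--quantize", type=str, default=None)
args = parser.parse_args(["--quantize", "Int8WeightOnlyConfig"])

if args.quantize == "Int8WeightOnlyConfig":
    # In the real script this branch applies int8_weight_only_quantized_training().
    print("int8 weight-only quantized training would be enabled")
```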
test/dtypes/test_affine_quantized.py (36 changes: 17 additions & 19 deletions)
@@ -27,15 +27,13 @@
from torchao.float8.config import e4m3_dtype
from torchao.quantization import (
FbgemmConfig,
Float8WeightOnlyConfig,
GemliteUIntXWeightOnlyConfig,
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
float8_weight_only,
int4_dynamic_activation_int4_weight,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
Int8WeightOnlyConfig,
quantize_,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
@@ -58,49 +56,49 @@ def get_quantization_functions(
do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
):
base_functions = [
int8_weight_only(),
int8_dynamic_activation_int4_weight(),
int8_dynamic_activation_int8_weight(),
int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
Int8WeightOnlyConfig(),
Int8DynamicActivationInt4WeightConfig(),
Int8DynamicActivationInt8WeightConfig(),
Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC),
]
if do_int4:
if check_cpu_version(device):
base_functions.append(
int4_weight_only(group_size=32, layout=Int4CPULayout(), version=1)
Int4WeightOnlyConfig(group_size=32, layout=Int4CPULayout(), version=1)
)
elif check_xpu_version(device):
base_functions.append(
int4_weight_only(group_size=32, layout=Int4XPULayout(), version=1)
Int4WeightOnlyConfig(group_size=32, layout=Int4XPULayout(), version=1)
)
if int4_zp_int:
base_functions.append(
int4_weight_only(
Int4WeightOnlyConfig(
group_size=32,
layout=Int4XPULayout(),
zero_point_domain=ZeroPointDomain.INT,
version=1,
)
)
else:
base_functions.append(int4_weight_only(group_size=32, version=1))
base_functions.append(Int4WeightOnlyConfig(group_size=32, version=1))
if device == "cuda" and not is_ROCM():
base_functions.append(
int8_dynamic_activation_int4_weight(
Int8DynamicActivationInt4WeightConfig(
group_size=None,
mapping_type=MappingType.SYMMETRIC,
act_mapping_type=MappingType.SYMMETRIC,
layout=CutlassInt4PackedLayout(),
)
)
base_functions.append(int4_dynamic_activation_int4_weight())
base_functions.append(Int4DynamicActivationInt4WeightConfig())

if do_sparse and device != "xpu":
base_functions.append(
int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout())
)

if is_sm_at_least_89():
base_functions.append(float8_weight_only())
base_functions.append(Float8WeightOnlyConfig())

if is_sm_at_least_90():
base_functions.append(FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16))
@@ -119,7 +117,7 @@ def test_tensor_core_layout_transpose(self):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = linear.weight
shape = t.shape
apply_int4_weight_only_quant = int4_weight_only(group_size=32, version=1)
apply_int4_weight_only_quant = Int4WeightOnlyConfig(group_size=32, version=1)
quantize_(linear, apply_int4_weight_only_quant)
ql = linear
aqt = ql.weight
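A small sketch of how the config objects collected by get_quantization_functions are applied in these tests; the layer shape mirrors test_tensor_core_layout_transpose, and a CUDA device is assumed.

```python
# Sketch only: two of the configs from the list above, applied to a fresh layer each.
import torch

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
)

for config in [Int8WeightOnlyConfig(), Int8DynamicActivationInt8WeightConfig()]:
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    quantize_(linear, config)  # swaps linear.weight for a quantized subclass in place
    _ = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))  # smoke test
```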
test/dtypes/test_affine_quantized_tensor_parallel.py (28 changes: 14 additions & 14 deletions)
@@ -16,11 +16,11 @@
)

from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
)
from torchao.quantization.observer import PerRow, PerTensor
from torchao.quantization.quant_api import quantize_
@@ -42,7 +42,7 @@
class TestAffineQuantizedTensorParallel(DTensorTestBase):
"""Basic test case for tensor subclasses"""

QUANT_METHOD_FN = staticmethod(int8_weight_only)
QUANT_METHOD_FN = staticmethod(Int8WeightOnlyConfig)
QUANT_METHOD_KWARGS = {}

@staticmethod
@@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(int8_weight_only)
QUANT_METHOD_FN = staticmethod(Int8WeightOnlyConfig)
COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]

@common_utils.parametrize("dtype", COMMON_DTYPES)
@@ -144,7 +144,7 @@ def test_tp(self, dtype):


class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(int4_weight_only)
QUANT_METHOD_FN = staticmethod(Int4WeightOnlyConfig)
QUANT_METHOD_KWARGS = {"version": 1}
COMMON_DTYPES = [torch.bfloat16]

@@ -167,20 +167,20 @@ class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not has_gemlite, "gemlite not available")
def test_tp_gemlite(self, dtype):
from torchao.quantization import gemlite_uintx_weight_only
from torchao.quantization import GemliteUIntXWeightOnlyConfig

for packing_bitwidth in [32, 8]:
for bit_width in [4, 8]:
for group_size in [64, 32, None] if bit_width == 4 else [None]:
api = lambda: gemlite_uintx_weight_only(
api = lambda: GemliteUIntXWeightOnlyConfig(
group_size, bit_width, packing_bitwidth
)
self.QUANT_METHOD_FN = staticmethod(api)
return self._test_tp(dtype)


class TestInt8dqAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight)
QUANT_METHOD_FN = staticmethod(Int8DynamicActivationInt8WeightConfig)
COMMON_DTYPES = [torch.bfloat16]

@common_utils.parametrize("dtype", COMMON_DTYPES)
@@ -199,7 +199,7 @@ def test_tp(self, dtype):
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):

class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(float8_weight_only)
QUANT_METHOD_FN = staticmethod(Float8WeightOnlyConfig)
COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]

@common_utils.parametrize("dtype", COMMON_DTYPES)
@@ -211,7 +211,7 @@ def test_tp(self, dtype):
class TestFloat8dqTensorAffineQuantizedTensorParallel(
TestAffineQuantizedTensorParallel
):
QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
QUANT_METHOD_FN = staticmethod(Float8DynamicActivationFloat8WeightConfig)
QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]

@@ -224,7 +224,7 @@ def test_tp(self, dtype):
class TestFloat8dqRowAffineQuantizedTensorParallel(
TestAffineQuantizedTensorParallel
):
QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
QUANT_METHOD_FN = staticmethod(Float8DynamicActivationFloat8WeightConfig)
QUANT_METHOD_KWARGS = {"granularity": PerRow()}
COMMON_DTYPES = [torch.bfloat16]

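For clarity, a sketch of how each test subclass above builds its config from QUANT_METHOD_FN and QUANT_METHOD_KWARGS; the standalone class below is illustrative and not part of the test file.

```python
# Illustrative sketch of the QUANT_METHOD_FN / QUANT_METHOD_KWARGS pattern used above.
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.observer import PerRow


class _ExampleTP:
    # staticmethod keeps the config class from being bound as an instance method
    QUANT_METHOD_FN = staticmethod(Float8DynamicActivationFloat8WeightConfig)
    QUANT_METHOD_KWARGS = {"granularity": PerRow()}


# Each test constructs its config the same way, regardless of which subclass runs:
config = _ExampleTP.QUANT_METHOD_FN(**_ExampleTP.QUANT_METHOD_KWARGS)
```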
test/dtypes/test_floatx.py (4 changes: 2 additions & 2 deletions)
@@ -29,7 +29,7 @@
_floatx_unpacked_to_f32,
)
from torchao.quantization import (
fpx_weight_only,
FPXWeightOnlyConfig,
quantize_,
)
from torchao.testing.utils import skip_if_rocm
@@ -118,7 +118,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias, dtype):

linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype)
fpx_linear = copy.deepcopy(linear)
quantize_(fpx_linear, fpx_weight_only(ebits, mbits))
quantize_(fpx_linear, FPXWeightOnlyConfig(ebits, mbits))

x = torch.randn(N, IC, device=device, dtype=dtype)
expected = fpx_linear(x)
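A minimal sketch of the fpx path exercised above; the (3, 2) exponent/mantissa split and the layer shape are illustrative choices, and a CUDA device is assumed.

```python
# Sketch only: quantize a linear layer to a 6-bit floating-point format
# (1 sign, 3 exponent, 2 mantissa bits); shapes and dtype are arbitrary.
import torch

from torchao.quantization import FPXWeightOnlyConfig, quantize_

linear = torch.nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.half)
quantize_(linear, FPXWeightOnlyConfig(3, 2))
_ = linear(torch.randn(4, 256, device="cuda", dtype=torch.half))
```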
test/dtypes/test_uintx.py (17 changes: 6 additions & 11 deletions)
@@ -7,7 +7,7 @@
import torch

from torchao.dtypes.uintx.uintx_layout import to_uintx
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_
from torchao.quantization.quant_primitives import (
MappingType,
choose_qparams_affine,
@@ -60,7 +60,7 @@ def forward(self, x):
def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size):
scale = 512
fp16_mod_on_cpu = Linear16(scale, "cpu")
quantize_(fp16_mod_on_cpu, uintx_weight_only(dtype, group_size=group_size))
quantize_(fp16_mod_on_cpu, UIntXWeightOnlyConfig(dtype, group_size=group_size))
test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu")
output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu)
fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda")
@@ -78,7 +78,7 @@ def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size):
def test_uintx_weight_only_model_quant(dtype, group_size, device):
scale = 512
fp16 = Linear16(scale, device)
quantize_(fp16, uintx_weight_only(dtype, group_size=group_size))
quantize_(fp16, UIntXWeightOnlyConfig(dtype, group_size=group_size))
uintx = torch.compile(fp16, fullgraph=True)
test_input = torch.randn(scale * 2, dtype=torch.float16, device=device)
output = uintx.forward(test_input)
@@ -124,30 +124,25 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_uintx_target_dtype(dtype):
from torchao.quantization.quant_api import uintx_weight_only

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# make sure it runs
quantize_(linear, uintx_weight_only(dtype))
quantize_(linear, UIntXWeightOnlyConfig(dtype))
linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_uintx_target_dtype_compile(dtype):
from torchao.quantization.quant_api import uintx_weight_only

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# make sure it runs
quantize_(linear, uintx_weight_only(dtype))
quantize_(linear, UIntXWeightOnlyConfig(dtype))
linear = torch.compile(linear)
linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_uintx_model_size(dtype):
from torchao.quantization.quant_api import uintx_weight_only
from torchao.utils import get_model_size_in_bytes

# scale size = 1/64 * 2 bytes = 1/32 bytes
@@ -167,6 +162,6 @@ def test_uintx_model_size(dtype):
)
bf16_size = get_model_size_in_bytes(linear)
# make sure it runs
quantize_(linear[0], uintx_weight_only(dtype))
quantize_(linear[0], UIntXWeightOnlyConfig(dtype))
quantized_size = get_model_size_in_bytes(linear)
assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
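A hedged sketch of the UIntX path used throughout this file; torch.uint4 and the group size are illustrative picks, and a CUDA device is assumed as in the tests.

```python
# Sketch only: weight-only quantization to an unsigned sub-byte dtype.
import torch

from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, UIntXWeightOnlyConfig(torch.uint4, group_size=64))
_ = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
```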
test/hqq/test_hqq_affine.py (10 changes: 6 additions & 4 deletions)
@@ -8,11 +8,11 @@
import torch

from torchao.quantization import (
Int4WeightOnlyConfig,
MappingType,
UIntXWeightOnlyConfig,
ZeroPointDomain,
int4_weight_only,
quantize_,
uintx_weight_only,
)
from torchao.testing.utils import skip_if_rocm

@@ -55,9 +55,11 @@ def _eval_hqq(dtype):
)
dummy_linear.weight.data = W
if dtype == torch.uint4:
config = int4_weight_only(group_size=max(block_size), use_hqq=True, version=1)
config = Int4WeightOnlyConfig(
group_size=max(block_size), use_hqq=True, version=1
)
else:
config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)
config = UIntXWeightOnlyConfig(dtype, group_size=max(block_size), use_hqq=True)
quantize_(dummy_linear, config)
q_tensor_hqq = dummy_linear.weight

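Finally, a hedged sketch of the two HQQ branches shown above; the group size and layer shape are illustrative, and a CUDA device with bfloat16 weights is assumed.

```python
# Sketch only: HQQ-based int4 weight-only quantization via the new config class.
import torch

from torchao.quantization import Int4WeightOnlyConfig, UIntXWeightOnlyConfig, quantize_

linear = torch.nn.Linear(512, 512, bias=False, dtype=torch.bfloat16, device="cuda")
config = Int4WeightOnlyConfig(group_size=64, use_hqq=True, version=1)  # uint4 branch
# For other sub-byte dtypes the tests use, e.g.:
# config = UIntXWeightOnlyConfig(dtype, group_size=64, use_hqq=True)
quantize_(linear, config)
_ = linear(torch.randn(2, 512, dtype=torch.bfloat16, device="cuda"))
```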