72 changes: 38 additions & 34 deletions test/prototype/blockwise_fp8_training/test_blockwise_kernels.py
@@ -23,8 +23,13 @@
torch_blockwise_scale_act_quant_rhs,
torch_blockwise_scale_weight_quant,
)
from torchao.testing.utils import skip_if_rocm
from torchao.utils import is_sm_at_least_90
from torchao.testing.utils import skip_if_rocm, skip_if_xpu
from torchao.utils import (
    auto_detect_device,
    is_sm_at_least_90,
)

_DEVICE = [auto_detect_device()]

BLOCKWISE_SIZE_MNK = [
(128, 128, 128),
@@ -37,19 +42,19 @@
(67, 6656, 1408),
]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("device", _DEVICE)
@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.skipif(
version.parse(triton.__version__) < version.parse("3.3.0"),
reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype):
@skip_if_xpu("XPU enablement in progress")
def test_blockwise_fp8_gemm_1x128_128x128(device, M, N, K, dtype):
# Simulate output = input @ weight.T
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
A = torch.randn(M, K, dtype=torch.bfloat16, device=device)
B = torch.randn(N, K, dtype=torch.bfloat16, device=device)
C = A @ B.T
A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=dtype)
B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype)
@@ -60,19 +65,19 @@ def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype):
min_sqnr = 28.0
assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("device", _DEVICE)
@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.skipif(
version.parse(triton.__version__) < version.parse("3.3.0"),
reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_fp8_gemm_1x128_128x1(M, N, K, dtype):
@skip_if_xpu("XPU enablement in progress")
def test_blockwise_fp8_gemm_1x128_128x1(device, M, N, K, dtype):
# Simulate grad_weight = grad_output_t @ input
A = torch.randn(K, M, dtype=torch.bfloat16, device="cuda")
B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")
A = torch.randn(K, M, dtype=torch.bfloat16, device=device)
B = torch.randn(K, N, dtype=torch.bfloat16, device=device)
C = A.T @ B
A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype)
B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=dtype)
@@ -86,12 +91,11 @@ def test_blockwise_fp8_gemm_1x128_128x1(M, N, K, dtype):
min_sqnr = 28.0
assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}"


@pytest.mark.parametrize("device", _DEVICE)
@skip_if_rocm("ROCm not supported")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("block_size", [128, 256])
def test_triton_quantize_fp8_act_quant_lhs(block_size):
device = "cuda"
def test_triton_quantize_fp8_act_quant_lhs(device, block_size):
M, K = 4096, 1024
x = torch.randn(M, K, device=device)

@@ -133,12 +137,12 @@ def test_triton_quantize_fp8_act_quant_lhs(block_size):
msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}",
)


@pytest.mark.parametrize("device", _DEVICE)
@skip_if_rocm("ROCm not supported")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("block_size", [128, 256])
def test_triton_quantize_fp8_act_quant_rhs(block_size: int):
device = "cuda"
@skip_if_xpu("XPU enablement in progress")
def test_triton_quantize_fp8_act_quant_rhs(device, block_size: int):
M, K = 4096, 1024
x = torch.randn(M, K, device=device)

@@ -180,13 +184,13 @@ def test_triton_quantize_fp8_act_quant_rhs(block_size: int):
msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}",
)


@pytest.mark.parametrize("device", _DEVICE)
@skip_if_rocm("ROCm not supported")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@skip_if_xpu("XPU enablement in progress")
@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("block_size", [128, 256])
@pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)])
def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int):
device = "cuda"
def test_triton_quantize_fp8_act_quant_transposed_lhs(device, M, K, block_size: int):
x = torch.randn(M, K, device=device)

# Set one scaling block to 0s, so if nan guards/EPS are not applied, the
@@ -229,13 +233,13 @@ def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int):
msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}",
)


@pytest.mark.parametrize("device", _DEVICE)
@skip_if_rocm("ROCm not supported")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@skip_if_xpu("XPU enablement in progress")
@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("block_size", [128, 256])
@pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)])
def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int):
device = "cuda"
def test_triton_quantize_fp8_weight_quant_rhs(device, M, K, block_size: int):
x = torch.randn(M, K, device=device)

# Set one scaling block to 0s, so if nan guards/EPS are not applied, the
@@ -275,12 +279,12 @@ def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int):
msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}",
)


@pytest.mark.parametrize("device", _DEVICE)
@skip_if_rocm("ROCm not supported")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@skip_if_xpu("XPU enablement in progress")
@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("block_size", [128, 256])
def test_triton_quantize_fp8_weight_quant_transposed_rhs(block_size: int):
device = "cuda"
def test_triton_quantize_fp8_weight_quant_transposed_rhs(device, block_size: int):
M = 512
K = 2048
x = torch.randn(M, K, device=device)
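A note on the helpers introduced above: auto_detect_device() (from torchao.utils) and skip_if_xpu (from torchao.testing.utils) are not defined in this diff. A minimal sketch of how a device auto-detection helper like this could behave, assuming it simply prefers CUDA, then XPU, then CPU (the real torchao implementation may differ):

# Hypothetical sketch of auto_detect_device(); the actual helper in
# torchao.utils may detect backends differently.
import torch

def auto_detect_device() -> str:
    # Prefer CUDA, then XPU, then fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

# Wrapping the result in a list makes it usable directly with
# @pytest.mark.parametrize("device", _DEVICE), as done in the tests above.
_DEVICE = [auto_detect_device()]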
12 changes: 6 additions & 6 deletions test/prototype/test_awq.py
@@ -15,8 +15,9 @@

from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_
from torchao.utils import _is_fbgemm_genai_gpu_available
from torchao.utils import _is_fbgemm_genai_gpu_available, auto_detect_device

_DEVICE = auto_detect_device()

class ToyLinearModel(torch.nn.Module):
def __init__(self, m=512, n=256, k=128):
@@ -26,7 +27,7 @@ def __init__(self, m=512, n=256, k=128):
self.linear3 = torch.nn.Linear(k, 64, bias=False)

def example_inputs(
self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
self, batch_size, sequence_length=10, dtype=torch.bfloat16, device=_DEVICE
):
return [
torch.randn(
@@ -42,7 +43,6 @@ def forward(self, x):
return x


@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(),
reason="need to install fbgemm_gpu_genai package",
@@ -62,7 +62,7 @@ def test_awq_config(self):
AWQConfig(base_config, step="not_supported")

def test_awq_functionality(self):
device = "cuda"
device = _DEVICE
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
@@ -111,7 +111,7 @@ def test_awq_functionality(self):
assert loss_awq < loss_base

def test_awq_loading(self):
device = "cuda"
device = _DEVICE
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
@@ -171,7 +171,7 @@ def test_awq_loading_vllm(self):

There is also a slicing op that is omitted here; overall e2e behavior is tested in the vllm repo tests
"""
device = "cuda"
device = _DEVICE
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
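The skip_if_xpu decorator used in these tests comes from torchao.testing.utils and is not shown in this diff. A plausible sketch, assuming it mirrors skip_if_rocm and skips a test with the given reason whenever an XPU device is detected:

# Hypothetical sketch of skip_if_xpu; the real torchao.testing.utils
# decorator may be implemented differently.
import functools
import unittest

import torch


def skip_if_xpu(reason: str):
    xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if xpu_available:
                # SkipTest is honored by both unittest and pytest runners.
                raise unittest.SkipTest(reason)
            return func(*args, **kwargs)

        return wrapper

    return decorator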
14 changes: 9 additions & 5 deletions test/prototype/test_blockwise_triton.py
@@ -9,6 +9,10 @@

from packaging import version

from torchao.utils import auto_detect_device

_DEVICE = auto_detect_device()

triton = pytest.importorskip("triton", reason="Triton required to run this test")

from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
@@ -18,6 +22,7 @@
fp8_blockwise_weight_quant,
)
from torchao.utils import is_sm_at_least_89
from torchao.testing.utils import skip_if_xpu

BLOCKWISE_SIZE_MNK = [
(2, 512, 128),
@@ -29,7 +34,6 @@
]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize(
"dtype",
@@ -38,7 +42,7 @@
else [torch.float8_e5m2],
)
def test_blockwise_quant_dequant(_, N, K, dtype):
x = torch.randn(N, K).cuda()
x = torch.randn(N, K).to(_DEVICE)
qx, s = fp8_blockwise_weight_quant(x, dtype=dtype)
x_reconstructed = fp8_blockwise_weight_dequant(qx, s)
error = torch.norm(x - x_reconstructed) / torch.norm(x)
@@ -47,7 +51,6 @@ def test_blockwise_quant_dequant(_, N, K, dtype):
assert error < 0.1, "Quant-Dequant error is too high"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
version.parse(triton.__version__) < version.parse("3.3.0"),
reason="Triton version < 3.3.0, test skipped",
@@ -59,9 +62,10 @@ def test_blockwise_quant_dequant(_, N, K, dtype):
if is_sm_at_least_89()
else [torch.float8_e5m2],
)
@skip_if_xpu("XPU Enablement in Progress")
def test_blockwise_fp8_gemm(M, N, K, dtype):
A = torch.randn(M, K).cuda()
B = torch.randn(N, K).cuda()
A = torch.randn(M, K).to(_DEVICE)
B = torch.randn(N, K).to(_DEVICE)
C = A @ B.T
A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
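As a quick illustration of the device-agnostic pattern these tests now follow, here is a minimal quant/dequant round trip on the auto-detected device. It is only a sketch: it assumes the same public functions exercised by the test above, uses torch.float8_e5m2 (the dtype the tests fall back to below SM 8.9), and, like the tests, assumes a GPU-class device since the kernels are Triton-based.

# Minimal sketch of a blockwise fp8 round trip on the detected device;
# relies only on functions already imported by the test above.
import torch

from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
    fp8_blockwise_weight_dequant,
    fp8_blockwise_weight_quant,
)
from torchao.utils import auto_detect_device

device = auto_detect_device()
x = torch.randn(512, 128).to(device)
qx, scales = fp8_blockwise_weight_quant(x, dtype=torch.float8_e5m2)
x_reconstructed = fp8_blockwise_weight_dequant(qx, scales)
relative_error = torch.norm(x - x_reconstructed) / torch.norm(x)
print(f"relative reconstruction error: {relative_error.item():.4f}")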
7 changes: 5 additions & 2 deletions test/prototype/test_parq.py
@@ -39,9 +39,11 @@
quantize_,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import check_cpu_version
from torchao.utils import auto_detect_device, check_cpu_version

_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchao.testing.utils import skip_if_xpu

_DEVICE = auto_detect_device()


def split_param_groups(model):
@@ -206,6 +208,7 @@ def setUp(self):
torch.manual_seed(123)

@common_utils.parametrize("group_size", [32, 256])
@skip_if_xpu("XPU Enablement in Progress")
def test_int4_weight_only(self, group_size: int = 32):
model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
model.reset_parameters()
21 changes: 12 additions & 9 deletions test/prototype/test_quantized_training.py
@@ -34,11 +34,14 @@
quantize_int8_rowwise,
)
from torchao.quantization.quant_api import quantize_
from torchao.testing.utils import skip_if_xpu

if common_utils.SEED is None:
common_utils.SEED = 1234

_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
from torchao.utils import auto_detect_device, get_available_devices

_DEVICE = auto_detect_device()
_DEVICES = get_available_devices()


def _reset():
@@ -182,12 +185,13 @@ def test_int8_weight_only_training(self, compile, device):
],
)
@parametrize("module_swap", [False, True])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_int8_mixed_precision_training(self, compile, config, module_swap):
@parametrize("device", _DEVICES)
@skip_if_xpu("XPU enablement in progress")
def test_int8_mixed_precision_training(self, compile, config, module_swap, device):
_reset()
bsize = 64
embed_dim = 64
device = "cuda"

linear = nn.Linear(embed_dim, embed_dim, device=device)
linear_int8mp = copy.deepcopy(linear)
@@ -219,7 +223,6 @@ def snr(ref, actual):

@pytest.mark.skip("Flaky on CI")
@parametrize("compile", [False, True])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_bitnet_training(self, compile):
# reference implementation
# https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
@@ -244,7 +247,7 @@ def forward(self, x):
_reset()
bsize = 4
embed_dim = 32
device = "cuda"
device = _DEVICE

# only use 1 matmul shape to reduce triton autotune time
model_ref = nn.Sequential(
@@ -339,7 +342,7 @@ def _run_subtest(self, args):
dropout_p=0,
)
torch.manual_seed(42)
base_model = Transformer(model_args).cuda()
base_model = Transformer(model_args).to(_DEVICE)
fsdp_model = copy.deepcopy(base_model)

quantize_(base_model.layers, quantize_fn)
@@ -359,7 +362,7 @@

torch.manual_seed(42 + self.rank + 1)
for iter_idx in range(5):
inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE)
fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = fsdp_model(inp).sum()
fsdp_loss.backward()
@@ -390,7 +393,7 @@ def test_precompute_bitnet_scale(self):
precompute_bitnet_scale_for_fsdp,
)

model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(_DEVICE)
model_fsdp = copy.deepcopy(model)
quantize_(model_fsdp, bitnet_training())
fully_shard(model_fsdp)
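get_available_devices() replaces the previous hard-coded list ["cpu"] + (["cuda"] if torch.cuda.is_available() else []); its implementation lives in torchao.utils and is not part of this diff. A sketch of what such a helper might return, assuming it enumerates CPU plus any detected accelerators:

# Hypothetical sketch of get_available_devices(); the actual torchao.utils
# helper may enumerate backends differently.
import torch


def get_available_devices() -> list[str]:
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        devices.append("xpu")
    return devices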