diff --git a/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py b/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py index e8e855232c..b4428d94d8 100644 --- a/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py +++ b/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py @@ -23,8 +23,13 @@ torch_blockwise_scale_act_quant_rhs, torch_blockwise_scale_weight_quant, ) -from torchao.testing.utils import skip_if_rocm -from torchao.utils import is_sm_at_least_90 +from torchao.testing.utils import skip_if_rocm, skip_if_xpu +from torchao.utils import ( + is_sm_at_least_90, + auto_detect_device, +) + +_DEVICE = [auto_detect_device()] BLOCKWISE_SIZE_MNK = [ (128, 128, 128), @@ -37,19 +42,19 @@ (67, 6656, 1408), ] - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.parametrize("device", _DEVICE) +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.skipif( version.parse(triton.__version__) < version.parse("3.3.0"), reason="Triton version < 3.3.0, test skipped", ) @pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype): +@skip_if_xpu("XPU enablement in progress") +def test_blockwise_fp8_gemm_1x128_128x128(device, M, N, K, dtype): # Simulate output = input @ weight.T - A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") - B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) C = A @ B.T A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=dtype) B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype) @@ -60,19 +65,19 @@ def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype): min_sqnr = 28.0 assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}" - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.parametrize("device", _DEVICE) +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.skipif( version.parse(triton.__version__) < version.parse("3.3.0"), reason="Triton version < 3.3.0, test skipped", ) @pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -def test_blockwise_fp8_gemm_1x128_128x1(M, N, K, dtype): +@skip_if_xpu("XPU enablement in progress") +def test_blockwise_fp8_gemm_1x128_128x1(device, M, N, K, dtype): # Simulate grad_weight = grad_output_t @ input - A = torch.randn(K, M, dtype=torch.bfloat16, device="cuda") - B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda") + A = torch.randn(K, M, dtype=torch.bfloat16, device=device) + B = torch.randn(K, N, dtype=torch.bfloat16, device=device) C = A.T @ B A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype) B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=dtype) @@ -86,12 +91,11 @@ def test_blockwise_fp8_gemm_1x128_128x1(M, N, K, dtype): min_sqnr = 28.0 assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}" - +@pytest.mark.parametrize("device", _DEVICE) @skip_if_rocm("ROCm not supported") -@pytest.mark.skipif(not is_sm_at_least_90(), 
reason="Requires CUDA capability >= 9.0") +@pytest.mark.skipif(torch.cuda.is_available and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.parametrize("block_size", [128, 256]) -def test_triton_quantize_fp8_act_quant_lhs(block_size): - device = "cuda" +def test_triton_quantize_fp8_act_quant_lhs(device, block_size): M, K = 4096, 1024 x = torch.randn(M, K, device=device) @@ -133,12 +137,12 @@ def test_triton_quantize_fp8_act_quant_lhs(block_size): msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", ) - +@pytest.mark.parametrize("device", _DEVICE) @skip_if_rocm("ROCm not supported") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.parametrize("block_size", [128, 256]) -def test_triton_quantize_fp8_act_quant_rhs(block_size: int): - device = "cuda" +@skip_if_xpu("XPU enablement in progress") +def test_triton_quantize_fp8_act_quant_rhs(device, block_size: int): M, K = 4096, 1024 x = torch.randn(M, K, device=device) @@ -180,13 +184,13 @@ def test_triton_quantize_fp8_act_quant_rhs(block_size: int): msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", ) - +@pytest.mark.parametrize("device", _DEVICE) @skip_if_rocm("ROCm not supported") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@skip_if_xpu("XPU enablement in progress") +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.parametrize("block_size", [128, 256]) @pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)]) -def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int): - device = "cuda" +def test_triton_quantize_fp8_act_quant_transposed_lhs(device, M, K, block_size: int): x = torch.randn(M, K, device=device) # Set one scaling block to 0s, so if nan guards/EPS are not applied, the @@ -229,13 +233,13 @@ def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int): msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", ) - +@pytest.mark.parametrize("device", _DEVICE) @skip_if_rocm("ROCm not supported") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@skip_if_xpu("XPU enablement in progress") +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.parametrize("block_size", [128, 256]) @pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)]) -def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int): - device = "cuda" +def test_triton_quantize_fp8_weight_quant_rhs(device, M, K, block_size: int): x = torch.randn(M, K, device=device) # Set one scaling block to 0s, so if nan guards/EPS are not applied, the @@ -275,12 +279,12 @@ def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int): msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", ) - +@pytest.mark.parametrize("device", _DEVICE) @skip_if_rocm("ROCm not supported") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@skip_if_xpu("XPU enablement in progress") +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.parametrize("block_size", [128, 256]) -def 
test_triton_quantize_fp8_weight_quant_transposed_rhs(block_size: int): - device = "cuda" +def test_triton_quantize_fp8_weight_quant_transposed_rhs(device, block_size: int): M = 512 K = 2048 x = torch.randn(M, K, device=device) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 181445470e..f44a4e623c 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -15,8 +15,9 @@ from torchao.prototype.awq import AWQConfig, AWQStep from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_ -from torchao.utils import _is_fbgemm_genai_gpu_available +from torchao.utils import _is_fbgemm_genai_gpu_available, auto_detect_device +_DEVICE = auto_detect_device() class ToyLinearModel(torch.nn.Module): def __init__(self, m=512, n=256, k=128): @@ -26,7 +27,7 @@ def __init__(self, m=512, n=256, k=128): self.linear3 = torch.nn.Linear(k, 64, bias=False) def example_inputs( - self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda" + self, batch_size, sequence_length=10, dtype=torch.bfloat16, device=_DEVICE ): return [ torch.randn( @@ -42,7 +43,6 @@ def forward(self, x): return x -@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") @unittest.skipIf( not _is_fbgemm_genai_gpu_available(), reason="need to install fbgemm_gpu_genai package", @@ -62,7 +62,7 @@ def test_awq_config(self): AWQConfig(base_config, step="not_supported") def test_awq_functionality(self): - device = "cuda" + device = _DEVICE dataset_size = 100 l1, l2, l3 = 512, 256, 128 original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs @@ -111,7 +111,7 @@ def test_awq_functionality(self): assert loss_awq < loss_base def test_awq_loading(self): - device = "cuda" + device = _DEVICE dataset_size = 100 l1, l2, l3 = 512, 256, 128 original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs @@ -171,7 +171,7 @@ def test_awq_loading_vllm(self): There is also a slicing op that is ommitted here, overall e2e is tested in tests in vllm repo """ - device = "cuda" + device = _DEVICE dataset_size = 100 l1, l2, l3 = 512, 256, 128 original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs diff --git a/test/prototype/test_blockwise_triton.py b/test/prototype/test_blockwise_triton.py index 1c79ed9b23..19c237c0ca 100644 --- a/test/prototype/test_blockwise_triton.py +++ b/test/prototype/test_blockwise_triton.py @@ -9,6 +9,10 @@ from packaging import version +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() + triton = pytest.importorskip("triton", reason="Triton required to run this test") from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import ( @@ -18,6 +22,7 @@ fp8_blockwise_weight_quant, ) from torchao.utils import is_sm_at_least_89 +from torchao.testing.utils import skip_if_xpu BLOCKWISE_SIZE_MNK = [ (2, 512, 128), @@ -29,7 +34,6 @@ ] -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK) @pytest.mark.parametrize( "dtype", @@ -38,7 +42,7 @@ else [torch.float8_e5m2], ) def test_blockwise_quant_dequant(_, N, K, dtype): - x = torch.randn(N, K).cuda() + x = torch.randn(N, K).to(_DEVICE) qx, s = fp8_blockwise_weight_quant(x, dtype=dtype) x_reconstructed = fp8_blockwise_weight_dequant(qx, s) error = torch.norm(x - x_reconstructed) / torch.norm(x) @@ -47,7 +51,6 @@ def test_blockwise_quant_dequant(_, N, K, dtype): assert error < 0.1, "Quant-Dequant error is too high" -@pytest.mark.skipif(not 
torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( version.parse(triton.__version__) < version.parse("3.3.0"), reason="Triton version < 3.3.0, test skipped", @@ -59,9 +62,10 @@ def test_blockwise_quant_dequant(_, N, K, dtype): if is_sm_at_least_89() else [torch.float8_e5m2], ) +@skip_if_xpu("XPU Enablement in Progress") def test_blockwise_fp8_gemm(M, N, K, dtype): - A = torch.randn(M, K).cuda() - B = torch.randn(N, K).cuda() + A = torch.randn(M, K).to(_DEVICE) + B = torch.randn(N, K).to(_DEVICE) C = A @ B.T A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype) B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype) diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py index a25ce2301d..4507dddbf1 100644 --- a/test/prototype/test_parq.py +++ b/test/prototype/test_parq.py @@ -39,9 +39,11 @@ quantize_, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import check_cpu_version +from torchao.utils import check_cpu_version, auto_detect_device -_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +from torchao.testing.utils import skip_if_xpu + +_DEVICE = auto_detect_device() def split_param_groups(model): @@ -206,6 +208,7 @@ def setUp(self): torch.manual_seed(123) @common_utils.parametrize("group_size", [32, 256]) + @skip_if_xpu("XPU Enablement in Progress") def test_int4_weight_only(self, group_size: int = 32): model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16) model.reset_parameters() diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 836e2c302e..ab2c975449 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -34,11 +34,15 @@ quantize_int8_rowwise, ) from torchao.quantization.quant_api import quantize_ +from torchao.testing.utils import skip_if_xpu if common_utils.SEED is None: common_utils.SEED = 1234 -_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) +from torchao.utils import auto_detect_device, get_available_devices + +_DEVICES = get_available_devices() +_DEVICE = auto_detect_device() def _reset(): @@ -182,12 +186,12 @@ def test_int8_weight_only_training(self, compile, device): ], ) @parametrize("module_swap", [False, True]) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_int8_mixed_precision_training(self, compile, config, module_swap): + @parametrize("device", _DEVICES) + @skip_if_xpu("XPU enablement in progress") + def test_int8_mixed_precision_training(self, compile, config, module_swap, device): _reset() bsize = 64 embed_dim = 64 - device = "cuda" linear = nn.Linear(embed_dim, embed_dim, device=device) linear_int8mp = copy.deepcopy(linear) @@ -219,7 +223,6 @@ def snr(ref, actual): @pytest.mark.skip("Flaky on CI") @parametrize("compile", [False, True]) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_bitnet_training(self, compile): # reference implementation # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf @@ -244,7 +247,7 @@ def forward(self, x): _reset() bsize = 4 embed_dim = 32 - device = "cuda" + device = _DEVICE # only use 1 matmul shape to reduce triton autotune time model_ref = nn.Sequential( @@ -339,7 +342,7 @@ def _run_subtest(self, args): dropout_p=0, ) torch.manual_seed(42) - base_model = Transformer(model_args).cuda() + base_model = Transformer(model_args).to(_DEVICE) fsdp_model = copy.deepcopy(base_model) quantize_(base_model.layers, quantize_fn) 
@@ -359,7 +362,7 @@ def _run_subtest(self, args): torch.manual_seed(42 + self.rank + 1) for iter_idx in range(5): - inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE) fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) fsdp_loss = fsdp_model(inp).sum() fsdp_loss.backward() @@ -390,7 +393,7 @@ def test_precompute_bitnet_scale(self): precompute_bitnet_scale_for_fsdp, ) - model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda() + model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(_DEVICE) model_fsdp = copy.deepcopy(model) quantize_(model_fsdp, bitnet_training()) fully_shard(model_fsdp)
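
Note on the shared helpers: `auto_detect_device()` and `get_available_devices()` are imported from `torchao.utils` throughout this patch. As a rough mental model only (the real torchao implementations may differ), a minimal sketch of the assumed behavior is:

```python
# Hypothetical sketch of the torchao.utils device helpers used in this patch.
# This is an assumption for illustration, not the actual torchao implementation.
import torch


def get_available_devices() -> list[str]:
    # CPU is always usable; append any accelerator backends PyTorch can see.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        devices.append("xpu")
    return devices


def auto_detect_device() -> str:
    # Prefer the last (accelerator) entry; fall back to "cpu" on CPU-only hosts.
    return get_available_devices()[-1]
```

Under this model, `_DEVICES = get_available_devices()` parametrizes a test over every backend present, while `_DEVICE = auto_detect_device()` pins a test to the single preferred device.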