From acd1d9c7c1aa1485dba1f8a1197c9cc109243e5e Mon Sep 17 00:00:00 2001 From: "Sun, Diwei" Date: Tue, 19 Aug 2025 08:14:36 +0000 Subject: [PATCH] xpu ut enabling: test_ao_models --- test/test_ao_models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_ao_models.py b/test/test_ao_models.py index 79e4cc3ef5..9b71019f63 100644 --- a/test/test_ao_models.py +++ b/test/test_ao_models.py @@ -7,8 +7,10 @@ import torch from torchao._models.llama.model import Transformer +from torchao.utils import get_available_devices -_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + +_DEVICES = get_available_devices() def init_model(name="stories15M", device="cpu", precision=torch.bfloat16): @@ -17,7 +19,7 @@ def init_model(name="stories15M", device="cpu", precision=torch.bfloat16): return model.eval() -@pytest.mark.parametrize("device", _AVAILABLE_DEVICES) +@pytest.mark.parametrize("device", _DEVICES) @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("is_training", [True, False]) def test_ao_llama_model_inference_mode(device, batch_size, is_training):