diff --git a/test/test_ao_models.py b/test/test_ao_models.py
index 79e4cc3ef5..9b71019f63 100644
--- a/test/test_ao_models.py
+++ b/test/test_ao_models.py
@@ -7,8 +7,10 @@
 import torch
 
 from torchao._models.llama.model import Transformer
+from torchao.utils import get_available_devices
 
-_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+
+_DEVICES = get_available_devices()
 
 
 def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
@@ -17,7 +19,7 @@ def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
     return model.eval()
 
 
-@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
+@pytest.mark.parametrize("device", _DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("is_training", [True, False])
 def test_ao_llama_model_inference_mode(device, batch_size, is_training):