diff --git a/pyproject.toml b/pyproject.toml index c4a06e532..cf74cce98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ dependencies = [ "fsspec[gcs]>=2024.6.0", "gym-aloha>=0.1.1", "imageio>=2.36.1", - "jax[cuda12]==0.5.3", + 'jax[cuda12]==0.5.3; sys_platform == "linux"', + 'jax==0.5.3; sys_platform == "darwin"', "jaxtyping==0.2.36", "lerobot", "ml_collections==1.0.0",