Skip to content

Commit 1a4a232

Browse files
committed
Use uv to increase speed of pip installs
1 parent 72fede0 commit 1a4a232

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

.github/workflows/pytorch-version-tests.yml

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ on:
66
# Run at 00:00 UTC Every Day
77
- cron: "0 0 * * *"
88
workflow_dispatch:
9+
push:
10+
branches:
11+
- master
12+
pull_request:
13+
branches:
14+
- master
915

1016
jobs:
1117
build:
@@ -15,14 +21,13 @@ jobs:
1521
max-parallel: 5
1622
fail-fast: false
1723
matrix:
18-
python-version: [3.9, "3.10", "3.11"]
24+
python-version: ["3.9", "3.10", "3.11"]
1925
pytorch-version: [2.5.1, 2.4.1, 2.3.1, 2.2.2, 1.13.1, 1.12.1, 1.10.0]
2026
exclude:
2127
- pytorch-version: 1.10.0
2228
python-version: "3.10"
2329
- pytorch-version: 1.10.0
2430
python-version: "3.11"
25-
2631
- pytorch-version: 1.11.0
2732
python-version: "3.10"
2833
- pytorch-version: 1.11.0
@@ -68,22 +73,24 @@ jobs:
6873
- name: Install dependencies
6974
shell: bash -l {0}
7075
run: |
71-
conda install pytorch=${{ matrix.pytorch-version }} torchvision cpuonly python=${{ matrix.python-version }} -c pytorch
7276
77+
conda install pytorch=${{ matrix.pytorch-version }} torchvision cpuonly python=${{ matrix.python-version }} -c pytorch -y
78+
pip install uv
79+
7380
# We should install numpy<2.0 for pytorch<2.3
7481
numpy_one_pth_version=$(python -c "import torch; print(float('.'.join(torch.__version__.split('.')[:2])) < 2.3)")
7582
if [ "${numpy_one_pth_version}" == "True" ]; then
76-
pip install -U "numpy<2.0"
83+
uv pip install "numpy<2.0"
7784
fi
7885

79-
pip install -r requirements-dev.txt
80-
pip install .
86+
uv pip install -r requirements-dev.txt
87+
uv pip install .
8188

8289
# pytorch>=1.9.0,<1.11.0 is using "from setuptools import distutils; distutils.version.LooseVersion" anti-pattern
8390
# which raises the error: AttributeError: module 'distutils' has no attribute 'version' for setuptools>59
8491
bad_pth_version=$(python -c "import torch; print('.'.join(torch.__version__.split('.')[:2]) in ['1.9', '1.10'])")
8592
if [ "${bad_pth_version}" == "True" ]; then
86-
pip install --upgrade "setuptools<59"
93+
uv pip install "setuptools<59"
8794
python -c "from setuptools import distutils; distutils.version.LooseVersion"
8895
fi
8996

0 commit comments

Comments
 (0)