Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions .github/workflows/cloud-tpu-presubmit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Cloud TPU CI
name: Cloud TPU Presubmit
# Run on pull_request that is labeled as "optional_ci_tpu" or workflow dispatch
on:
pull_request:
branches:
- main
types: [labeled, synchronize]
workflow_dispatch:
# Cancel any previous iterations if a new commit is pushed
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
cloud-tpu-test:
# TODO: confirm final naming for optional label
if: contains(github.event.pull_request.labels.*.name, 'optional_ci_tpu')
name: "TPU v5e x 8 Presubmit"
env:
ENABLE_PJRT_COMPATIBILITY: 1
# TODO: Needs final runs-on value
runs-on: arc-linux-x86-ct5lp-224-8tpu
container:
# TODO: Needs newer, light weight image
image: index.docker.io/tensorflow/build@sha256:7fb38f0319bda36393cad7f40670aa22352b44421bb906f5cf34d543acd8e1d2 # ratchet:tensorflow/build:latest-python3.11
timeout-minutes: 45
defaults:
run:
shell: bash -ex {0}
steps:
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Install JAX test requirements
run: |
pip install -U -r build/test-requirements.txt
- name: DEBUG HALT
run: |
echo "Halting"
sleep 180m
# TODO: build jax should be done on a step prior or we should just bazel test
- name: Build JAX
run: |
pip uninstall -y jaxlib
python3 build/build.py --use_clang
pip install -e .
ls -la dist/*.whl
pip install dist/*.whl
# Note the version it installs! Should be today's date
pip install -U --no-index --pre libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Everything being built in this step downgrades numpy, reupgrade it
pip install "numpy>=2.0.0"
python3 -c 'import sys; print("python version:", sys.version)'
python3 -c 'import jax; print("jax version:", jax.__version__)'
python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
# strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on'
python3 -c 'import jax; print("libtpu version:",
jax.lib.xla_bridge.get_backend().platform_version)'
- name: Run tests
env:
JAX_PLATFORMS: tpu,cpu
PY_COLORS: 1
NUM_TESTS: 8
JAX_NUM_GENERATED_CASES: 25
run: |
# Run single-accelerator tests in parallel
mkdir results
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=$NUM_TESTS --tb=short \
--junitxml=results/singlejunit.xml --maxfail=20 -m "not multiaccelerator" tests examples
# Run multi-accelerator across all chips
python3 -m pytest --tb=short --junitxml=results/multijunit.xml \
--maxfail=20 -m "multiaccelerator" tests
- name: 'Upload Artifact'
if: success() || failure()
uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # ratchet:actions/upload-artifact@v4
with:
name: junit
path: |
results/singlejunit.xml
results/multijunit.xml
retention-days: 1