From f2372d0c428ae29bc4c4fefa69a6d71667f18790 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 13 Nov 2024 19:10:58 +0000 Subject: [PATCH 001/205] debug bazel cpu rbe presubmit for windows --- .github/workflows/bazel_cpu_rbe.yml | 8 ++++---- ci/run_bazel_test_cpu_rbe.sh | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index bae03def7d53..29a6b2086384 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -1,9 +1,9 @@ name: Run Bazel CPU tests (RBE) on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -24,7 +24,7 @@ jobs: shell: bash strategy: matrix: - runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] + runner: ["windows-x86-n2-64"] #, "linux-x86-n2-16", "linux-arm64-t2a-16"] runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12') || diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index 2748e82ec60e..a2048bea467f 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -36,6 +36,13 @@ source "ci/utilities/setup_build_environment.sh" os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) +# Adjust the values when running on Windows x86 to match the config in +# .bazelrc +if [[ $os =~ "msys_nt" ]] && [[ $arch == "x86_64" ]]; then + os="windows" + arch="amd64" +fi + # When running on Mac or Linux Aarch64, we only build the test targets and # not run them. These platforms do not have native RBE support so we # RBE cross-compile them on remote Linux x86 machines. As the tests still From fc82f13d90b2026f0b7500bc224b4b856a777200 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 13 Nov 2024 20:29:09 +0000 Subject: [PATCH 002/205] enable workflows --- .github/workflows/bazel_gpu_non_rbe.yml | 6 +++--- .github/workflows/pytest_cpu.yml | 6 +++--- .github/workflows/pytest_gpu.yml | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index ba11ac486001..346af78459a3 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -1,9 +1,9 @@ name: Run Bazel GPU tests (non RBE) on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index ec2dbfa7686b..760c8882423a 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -1,9 +1,9 @@ name: Run Pytest CPU tests on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index ac6d9e39e168..94ef41723d8c 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -1,9 +1,9 @@ name: Run Pytest GPU tests on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: From 08ef29ab16fa7c52334957541fb4a5e7c307e850 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 14 Nov 2024 20:24:00 +0000 Subject: [PATCH 003/205] update linux arm64 container to ml-build --- .github/workflows/build_artifacts.yml | 8 ++++---- .github/workflows/pytest_cpu.yml | 2 +- ci/envs/docker.env | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index fc4e2df2bda4..a428c3e79512 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -1,9 +1,9 @@ name: Build JAX Artifacts on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -54,7 +54,7 @@ jobs: runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(matrix.runner, 'windows-x86') && null) }} env: diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 760c8882423a..79153889a391 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -29,7 +29,7 @@ jobs: runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(matrix.runner, 'windows-x86') && null) }} env: diff --git a/ci/envs/docker.env b/ci/envs/docker.env index 77376e4c6578..3832c095c85d 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -29,7 +29,7 @@ fi # Linux Aarch64 specifc settings if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest" + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest" fi # Windows specific settings From b77573d1495148efefc972e23d1b1f855fb73add Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 14 Nov 2024 20:24:29 +0000 Subject: [PATCH 004/205] upload and download artifacts from gcs bucket --- .github/workflows/pytest_gpu.yml | 70 +++++++++++++++----------------- 1 file changed, 32 insertions(+), 38 deletions(-) diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index 94ef41723d8c..cd0b2ccfeb5c 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -42,6 +42,38 @@ jobs: run: ./ci/build_artifacts.sh "jax-cuda-plugin" - name: Build jax-cuda-pjrt run: ./ci/build_artifacts.sh "jax-cuda-pjrt" + - name: Upload artifacts to GCS bucket + run: gsutil -m rsync -d -r gs://general-ml-ci-transient/jax-github-actions/${{ github.workflow }}/${{ github.run_id }}/${{ github.run_attempt }} + + run_tests: + needs: build_artifacts + strategy: + matrix: + test_env: [ + {cuda_version: "12.3", runner: "linux-x86-g2-48-l4-4gpu", + image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, + {cuda_version: "12.1", runner: "linux-x86-g2-48-l4-4gpu", + image: "gcr.io/tensorflow-testing/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, + ] + python: ["3.10"] + + runs-on: ${{ matrix.test_env.runner }} + container: + image: ${{ matrix.test_env.image }} + + name: "Pytest GPU (Test on CUDA ${{ matrix.test_env.cuda_version }})" + env: + JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Download the artifacts built in the "build_artifacts" job + run: gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/${{ github.workflow }}/${{ github.run_id }}/${{ github.run_attempt }} $(pwd)/dist - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} @@ -52,41 +84,3 @@ jobs: run: $JAXCI_PYTHON -m pip install -r build/requirements.in - name: Run Pytest GPU tests run: ./ci/run_pytest_gpu.sh - - # run_tests: - # needs: build_artifacts - # strategy: - # matrix: - # test_env: [ - # {cuda_version: "12.3", runner: "linux-x86-g2-48-l4-4gpu", - # image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, - # {cuda_version: "12.1", runner: "linux-x86-g2-48-l4-4gpu", - # image: "gcr.io/tensorflow-testing/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, - # ] - # python: ["3.10"] - - # runs-on: ${{ matrix.test_env.runner }} - # container: - # image: ${{ matrix.test_env.image }} - - # name: "Pytest GPU (Test on CUDA ${{ matrix.test_env.cuda_version }})" - # env: - # JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} - - # steps: - # - uses: actions/checkout@v3 - # # Halt for testing - # - name: Wait For Connection - # uses: google-ml-infra/actions/ci_connection@main - # with: - # halt-dispatch-input: ${{ inputs.halt-for-connection }} - # - name: Install pytest - # env: - # JAXCI_PYTHON: python${{ matrix.python }} - # run: $JAXCI_PYTHON -m pip install pytest - # - name: Install dependencies - # env: - # JAXCI_PYTHON: python${{ matrix.python }} - # run: $JAXCI_PYTHON -m pip install -r build/requirements.in - # - name: Run Pytest GPU tests - # run: ./ci/run_pytest_gpu.sh From 26744b68f797d2de65fc5f3a33dc2b2d52bd8b86 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 14 Nov 2024 21:48:35 +0000 Subject: [PATCH 005/205] update container images to one that has gcloud tools --- .github/workflows/pytest_gpu.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index cd0b2ccfeb5c..efd72185d257 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -22,7 +22,7 @@ jobs: python: ["3.10"] runs-on: "linux-x86-g2-48-l4-4gpu" - container: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" name: "Pytest GPU (Build wheels on CUDA 12.3)" env: @@ -51,9 +51,9 @@ jobs: matrix: test_env: [ {cuda_version: "12.3", runner: "linux-x86-g2-48-l4-4gpu", - image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, + image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, {cuda_version: "12.1", runner: "linux-x86-g2-48-l4-4gpu", - image: "gcr.io/tensorflow-testing/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, + image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, ] python: ["3.10"] From afcee123deb134bf8552cc7c4859e2cdc0ca319e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 14 Nov 2024 22:01:36 +0000 Subject: [PATCH 006/205] update image and runner type --- .github/workflows/pytest_gpu.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index efd72185d257..cc626ba1912d 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -21,8 +21,8 @@ jobs: matrix: python: ["3.10"] - runs-on: "linux-x86-g2-48-l4-4gpu" - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + runs-on: "linux-x86-n2-16" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" name: "Pytest GPU (Build wheels on CUDA 12.3)" env: @@ -43,7 +43,7 @@ jobs: - name: Build jax-cuda-pjrt run: ./ci/build_artifacts.sh "jax-cuda-pjrt" - name: Upload artifacts to GCS bucket - run: gsutil -m rsync -d -r gs://general-ml-ci-transient/jax-github-actions/${{ github.workflow }}/${{ github.run_id }}/${{ github.run_attempt }} + run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m rsync -d -r gs://general-ml-ci-transient/jax-github-actions/${{ github.workflow }}/${{ github.run_id }}/${{ github.run_attempt }} run_tests: needs: build_artifacts From a2c78a1dfb5d601b25af93e801271e2b9e788f4c Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 14 Nov 2024 22:15:02 +0000 Subject: [PATCH 007/205] change upload path --- .github/workflows/pytest_gpu.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index cc626ba1912d..2f9a6cfbc343 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -43,7 +43,7 @@ jobs: - name: Build jax-cuda-pjrt run: ./ci/build_artifacts.sh "jax-cuda-pjrt" - name: Upload artifacts to GCS bucket - run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m rsync -d -r gs://general-ml-ci-transient/jax-github-actions/${{ github.workflow }}/${{ github.run_id }}/${{ github.run_attempt }} + run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m rsync -d -r gs://general-ml-ci-transient/jax-github-actions/${{ github.run_number }}/${{ github.run_attempt }} run_tests: needs: build_artifacts @@ -73,7 +73,7 @@ jobs: with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Download the artifacts built in the "build_artifacts" job - run: gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/${{ github.workflow }}/${{ github.run_id }}/${{ github.run_attempt }} $(pwd)/dist + run: gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/${{ github.run_number }}/${{ github.run_attempt }} $(pwd)/dist - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} From 1c27adb897f62c5b1986a22f6d8fc61d471cbb40 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 14 Nov 2024 22:25:12 +0000 Subject: [PATCH 008/205] adjust upload and download command --- .github/workflows/pytest_gpu.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index 2f9a6cfbc343..b96966522f57 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -43,7 +43,7 @@ jobs: - name: Build jax-cuda-pjrt run: ./ci/build_artifacts.sh "jax-cuda-pjrt" - name: Upload artifacts to GCS bucket - run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m rsync -d -r gs://general-ml-ci-transient/jax-github-actions/${{ github.run_number }}/${{ github.run_attempt }} + run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }} run_tests: needs: build_artifacts @@ -73,7 +73,7 @@ jobs: with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Download the artifacts built in the "build_artifacts" job - run: gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/${{ github.run_number }}/${{ github.run_attempt }} $(pwd)/dist + run: gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }} $(pwd)/dist - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} From 84f072ebab75d2d8152ac2a13829b2f4584592a8 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 14 Nov 2024 22:42:54 +0000 Subject: [PATCH 009/205] create dist folder before downloading artifacts --- .github/workflows/pytest_gpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index b96966522f57..679c34e84f37 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -73,7 +73,7 @@ jobs: with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Download the artifacts built in the "build_artifacts" job - run: gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }} $(pwd)/dist + run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }} $(pwd)/dist - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} From b4f7c8e55a3447697c5baabf8369fe87e51e8180 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 04:01:00 +0000 Subject: [PATCH 010/205] experiment with reusable workflows --- .github/workflows/build_artifacts.yml | 62 +++++++++++++++++++++++--- .github/workflows/pytest_cpu_reuse.yml | 28 ++++++++---- 2 files changed, 74 insertions(+), 16 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index a428c3e79512..c5639b52cded 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -15,6 +15,42 @@ on: - 'yes' - 'no' workflow_call: + inputs: + build_jax: + description: "Should the jax artifact be built?" + required: true + default: true + type: boolean + build_jaxlib: + description: "Should the jaxlib artifact be built?" + required: true + default: true + type: boolean + build_jax_cuda_plugin: + description: "Should the jax-cuda-plugin artifact be built?" + required: true + default: true + type: boolean + build_jax_cuda_pjrt: + description: "Should the jax-cuda-pjrt artifact be built?" + required: true + default: true + type: boolean + clone_main_xla: + description: "Should latest XLA be used? (1 to enable, 0 to disable)" + type: string + required: true + default: "0" + upload_artifacts: + description: "Should the artifacts be uploaded to a GCS bucket?" + required: true + default: false + type: boolean + upload_destination: + description: "GCS location to where the artifacts should be uploaded" + required: true + default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string jobs: build: @@ -27,7 +63,7 @@ jobs: matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] artifact: ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"] - python: ["3.10", "3.11", "3.12"] + python: ["3.10"] #, "3.11", "3.12"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. exclude: @@ -58,8 +94,8 @@ jobs: (contains(matrix.runner, 'windows-x86') && null) }} env: - # Do not run Docker container for Linux runners. Linux runners already run in a Docker container. - JAXCI_RUN_DOCKER_CONTAINER: 0 + JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" + JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" steps: - uses: actions/checkout@v3 @@ -68,7 +104,19 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build ${{ matrix.artifact }} - env: - JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" - run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" + - name: Build jax + if: inputs.build_jax && matrix.artifact == 'jax' + run: ./ci/build_artifacts.sh "jax" + - name: Build jaxlib + if: inputs.build_jaxlib && matrix.artifact == 'jaxlib' + run: ./ci/build_artifacts.sh "jaxlib" + - name: Build jax-cuda-plugin + if: inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin' + run: ./ci/build_artifacts.sh "jax-cuda-plugin" + - name: Build jax-cuda-pjrt + if: inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt' + run: ./ci/build_artifacts.sh "jax-cuda-pjrt" + - name: Upload artifacts to GCS bucket + if: inputs.upload_artifacts + run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}" + diff --git a/.github/workflows/pytest_cpu_reuse.yml b/.github/workflows/pytest_cpu_reuse.yml index f09d53d0a96a..d7d86454c224 100644 --- a/.github/workflows/pytest_cpu_reuse.yml +++ b/.github/workflows/pytest_cpu_reuse.yml @@ -1,9 +1,9 @@ name: Run Pytest CPU tests (resuable workflow) on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -15,11 +15,21 @@ on: - 'yes' jobs: - build_jaxlib_artifacts: + build_jaxlib_artifact: + name: "Build the jaxlib aritfact using latest XLA" uses: ./.github/workflows/build_artifacts.yml + with: + build_jax: false + build_jaxlib: true + build_jax_cuda_plugin: false + build_jax_cuda_pjrt: false + clone_main_xla: 1 + upload_artifacts: true + upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' run_pytest: - needs: build_jaxlib_artifacts + name: "Run CPU tests with Pytest" + needs: build_jaxlib_artifact continue-on-error: true defaults: run: @@ -27,16 +37,14 @@ jobs: shell: bash strategy: matrix: - runner: ["windows-x86-n2-64", "linux-x86-n2-64", "linux-arm64-t2a-48"] + runner: ["linux-x86-n2-64", "linux-arm64-t2a-48"] python: ["3.10"] runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') || - (contains(matrix.runner, 'windows-x86') && null) }} + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }} env: - JAXCI_CLONE_MAIN_XLA: 1 JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} steps: @@ -46,6 +54,8 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Download the artifacts built in the "build_artifacts" job + run: mkdir -p $(pwd)/dist && ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }} $(pwd)/dist - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} From e123c8e1d38a4f8a9fcc15a99e36ccd3183dce78 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 04:48:29 +0000 Subject: [PATCH 011/205] Refine if conditions and update the upload destination --- .github/workflows/build_artifacts.yml | 31 +++++++++++++------------- .github/workflows/pytest_cpu_reuse.yml | 7 +++++- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index c5639b52cded..7e0343a10d1d 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -53,7 +53,7 @@ on: type: string jobs: - build: + build_artifacts: continue-on-error: true defaults: run: @@ -104,19 +104,20 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build jax - if: inputs.build_jax && matrix.artifact == 'jax' - run: ./ci/build_artifacts.sh "jax" - - name: Build jaxlib - if: inputs.build_jaxlib && matrix.artifact == 'jaxlib' - run: ./ci/build_artifacts.sh "jaxlib" - - name: Build jax-cuda-plugin - if: inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin' - run: ./ci/build_artifacts.sh "jax-cuda-plugin" - - name: Build jax-cuda-pjrt - if: inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt' - run: ./ci/build_artifacts.sh "jax-cuda-pjrt" + - name: Build ${{ matrix.artifact }} + if: >- + (inputs.build_jax && matrix.artifact == 'jax') || + (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || + (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || + (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') + run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" + - name: Set Platform + run: | + echo "PLATFORM=$(uname)_$(uname -m)" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket - if: inputs.upload_artifacts - run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}" + # Upload if requested and one of the artifacts was built + if: >- + inputs.upload_artifacts && + (inputs.build_jax || inputs.build_jaxlib || inputs.build_jax_cuda_plugin || inputs.build_jax_cuda_pjrt) + run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM diff --git a/.github/workflows/pytest_cpu_reuse.yml b/.github/workflows/pytest_cpu_reuse.yml index d7d86454c224..2bc15721498d 100644 --- a/.github/workflows/pytest_cpu_reuse.yml +++ b/.github/workflows/pytest_cpu_reuse.yml @@ -54,8 +54,13 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Set Platform + run: | + echo "PLATFORM=$(uname)_$(uname -m)" >> $GITHUB_ENV - name: Download the artifacts built in the "build_artifacts" job - run: mkdir -p $(pwd)/dist && ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }} $(pwd)/dist + run: >- + mkdir -p $(pwd)/dist && + ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM $(pwd)/dist - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} From 0a54f9787b0be1a19af0ab3cfd01f28990ca1cb7 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 04:53:49 +0000 Subject: [PATCH 012/205] refine if condition --- .github/workflows/build_artifacts.yml | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 7e0343a10d1d..a1d401388694 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -97,6 +97,7 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" + if: inputs.build_jax || inputs.build_jaxlib || inputs.build_jax_cuda_plugin || inputs.build_jax_cuda_pjrt steps: - uses: actions/checkout@v3 # Halt for testing @@ -105,19 +106,12 @@ jobs: with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Build ${{ matrix.artifact }} - if: >- - (inputs.build_jax && matrix.artifact == 'jax') || - (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || - (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || - (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" - name: Set Platform run: | echo "PLATFORM=$(uname)_$(uname -m)" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket # Upload if requested and one of the artifacts was built - if: >- - inputs.upload_artifacts && - (inputs.build_jax || inputs.build_jaxlib || inputs.build_jax_cuda_plugin || inputs.build_jax_cuda_pjrt) + if: inputs.upload_artifacts run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM From 2f122fd131d124a72d2ae675d3bf9165bf55658f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 05:13:18 +0000 Subject: [PATCH 013/205] refine if condition --- .github/workflows/build_artifacts.yml | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index a1d401388694..fd154a8d6c85 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -100,18 +100,43 @@ jobs: if: inputs.build_jax || inputs.build_jaxlib || inputs.build_jax_cuda_plugin || inputs.build_jax_cuda_pjrt steps: - uses: actions/checkout@v3 + if: >- + (inputs.build_jax && matrix.artifact == 'jax') || + (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || + (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || + (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} + if: >- + (inputs.build_jax && matrix.artifact == 'jax') || + (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || + (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || + (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') - name: Build ${{ matrix.artifact }} + if: >- + (inputs.build_jax && matrix.artifact == 'jax') || + (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || + (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || + (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" - name: Set Platform + if: >- + (inputs.build_jax && matrix.artifact == 'jax') || + (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || + (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || + (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') run: | echo "PLATFORM=$(uname)_$(uname -m)" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket # Upload if requested and one of the artifacts was built - if: inputs.upload_artifacts + if: >- + inputs.upload_artifacts && + (inputs.build_jax && matrix.artifact == 'jax') || + (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || + (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || + (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM From 866d149ce16430ab72347c284f5ffb7e54d1538b Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 05:14:06 +0000 Subject: [PATCH 014/205] fix file name --- .github/workflows/pytest_cpu_reuse.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest_cpu_reuse.yml b/.github/workflows/pytest_cpu_reuse.yml index 2bc15721498d..d731e0706932 100644 --- a/.github/workflows/pytest_cpu_reuse.yml +++ b/.github/workflows/pytest_cpu_reuse.yml @@ -70,4 +70,4 @@ jobs: JAXCI_PYTHON: python${{ matrix.python }} run: $JAXCI_PYTHON -m pip install -r build/requirements.in - name: Run Pytest CPU tests - run: ./ci/run_pytest.sh "ci/envs/run_tests/pytest_cpu" + run: ./ci/run_pytest_cpu.sh From 9ecfc007ffed739d2e6b1b297c41a59cfe406804 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 05:51:33 +0000 Subject: [PATCH 015/205] create artifact list dynamically --- .github/workflows/build_artifacts.yml | 34 +++++---------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index fd154a8d6c85..2bc2e65539e9 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -62,7 +62,11 @@ jobs: strategy: matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] - artifact: ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"] + artifact: >- + ${{ fromJSON((inputs.build_jax && '["jax"]') || '[]') }} + + ${{ fromJSON((inputs.build_jaxlib && '["jaxlib"]') || '[]') }} + + ${{ fromJSON((inputs.build_jax_cuda_pjrt && '["jax-cuda-pjrt"]') || '[]') }} + + ${{ fromJSON((inputs.build_jax_cuda_plugin && '["jax-cuda-plugin"]') || '[]') }} python: ["3.10"] #, "3.11", "3.12"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. @@ -97,46 +101,20 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" - if: inputs.build_jax || inputs.build_jaxlib || inputs.build_jax_cuda_plugin || inputs.build_jax_cuda_pjrt steps: - uses: actions/checkout@v3 - if: >- - (inputs.build_jax && matrix.artifact == 'jax') || - (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || - (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || - (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - if: >- - (inputs.build_jax && matrix.artifact == 'jax') || - (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || - (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || - (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') - name: Build ${{ matrix.artifact }} - if: >- - (inputs.build_jax && matrix.artifact == 'jax') || - (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || - (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || - (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" - name: Set Platform - if: >- - (inputs.build_jax && matrix.artifact == 'jax') || - (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || - (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || - (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') run: | echo "PLATFORM=$(uname)_$(uname -m)" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket # Upload if requested and one of the artifacts was built - if: >- - inputs.upload_artifacts && - (inputs.build_jax && matrix.artifact == 'jax') || - (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || - (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || - (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') + if: inputs.upload_artifacts run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM From 0f5783fba7b9bf9da16d66feef1e9ad439b83804 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 06:25:04 +0000 Subject: [PATCH 016/205] try a fix for the dynamic artifact list --- .github/workflows/build_artifacts.yml | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 2bc2e65539e9..1687621e3b72 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -63,10 +63,13 @@ jobs: matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] artifact: >- - ${{ fromJSON((inputs.build_jax && '["jax"]') || '[]') }} + - ${{ fromJSON((inputs.build_jaxlib && '["jaxlib"]') || '[]') }} + - ${{ fromJSON((inputs.build_jax_cuda_pjrt && '["jax-cuda-pjrt"]') || '[]') }} + - ${{ fromJSON((inputs.build_jax_cuda_plugin && '["jax-cuda-plugin"]') || '[]') }} + ${{ fromJSON( + inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && inputs.build_jax_cuda_plugin && '["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]' || + inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && '["jax", "jaxlib", "jax-cuda-pjrt"]' || + inputs.build_jax && inputs.build_jaxlib && '["jax", "jaxlib"]' || + inputs.build_jax && '["jax"]' || + '[]' + ) }} python: ["3.10"] #, "3.11", "3.12"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. From e210fd23249e743be3de6f3e555cb31b29efa059 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 06:30:38 +0000 Subject: [PATCH 017/205] try a fix for the dynamic artifact list (2) --- .github/workflows/build_artifacts.yml | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 1687621e3b72..fbdf489e78e8 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -63,13 +63,11 @@ jobs: matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] artifact: >- - ${{ fromJSON( - inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && inputs.build_jax_cuda_plugin && '["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]' || - inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && '["jax", "jaxlib", "jax-cuda-pjrt"]' || - inputs.build_jax && inputs.build_jaxlib && '["jax", "jaxlib"]' || - inputs.build_jax && '["jax"]' || - '[]' - ) }} + ${{ inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && inputs.build_jax_cuda_plugin && '["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]' }} || + ${{ inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && '["jax", "jaxlib", "jax-cuda-pjrt"]' }} || + ${{ inputs.build_jax && inputs.build_jaxlib && '["jax", "jaxlib"]' }} || + ${{ inputs.build_jax && '["jax"]' }} || + ${{ '[]' }} python: ["3.10"] #, "3.11", "3.12"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. From f9a90b9b0b352f90e8cca8924b347289015c21b1 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 06:32:16 +0000 Subject: [PATCH 018/205] try a fix for the dynamic artifact list (3) --- .github/workflows/build_artifacts.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index fbdf489e78e8..8a0c46f9472e 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -63,10 +63,10 @@ jobs: matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] artifact: >- - ${{ inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && inputs.build_jax_cuda_plugin && '["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]' }} || - ${{ inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && '["jax", "jaxlib", "jax-cuda-pjrt"]' }} || - ${{ inputs.build_jax && inputs.build_jaxlib && '["jax", "jaxlib"]' }} || - ${{ inputs.build_jax && '["jax"]' }} || + ${{ inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && inputs.build_jax_cuda_plugin && fromJSON('["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]') }} || + ${{ inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && fromJSON('["jax", "jaxlib", "jax-cuda-pjrt"]') }} || + ${{ inputs.build_jax && inputs.build_jaxlib && fromJSON('["jax", "jaxlib"]') }} || + ${{ inputs.build_jax && fromJSON('["jax"]') }} || ${{ '[]' }} python: ["3.10"] #, "3.11", "3.12"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each From 22b90763eea135ea8a6c2a03d8814e37c798b6d2 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 06:33:03 +0000 Subject: [PATCH 019/205] disable workflows --- .github/workflows/bazel_cpu_rbe.yml | 6 +++--- .github/workflows/bazel_gpu_non_rbe.yml | 6 +++--- .github/workflows/pytest_cpu.yml | 6 +++--- .github/workflows/pytest_gpu.yml | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index 29a6b2086384..6d1f4580f931 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -1,9 +1,9 @@ name: Run Bazel CPU tests (RBE) on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 346af78459a3..ba11ac486001 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -1,9 +1,9 @@ name: Run Bazel GPU tests (non RBE) on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 79153889a391..71463f57d974 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -1,9 +1,9 @@ name: Run Pytest CPU tests on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index 679c34e84f37..f08601dd1a85 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -1,9 +1,9 @@ name: Run Pytest GPU tests on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: From 09c576bf142fe114be7074134dd80fa33031256e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 06:34:01 +0000 Subject: [PATCH 020/205] try a fix for the dynamic artifact list (4) --- .github/workflows/build_artifacts.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 8a0c46f9472e..eaec833b470a 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -67,7 +67,7 @@ jobs: ${{ inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && fromJSON('["jax", "jaxlib", "jax-cuda-pjrt"]') }} || ${{ inputs.build_jax && inputs.build_jaxlib && fromJSON('["jax", "jaxlib"]') }} || ${{ inputs.build_jax && fromJSON('["jax"]') }} || - ${{ '[]' }} + ${{ fromJSON('[]') }} python: ["3.10"] #, "3.11", "3.12"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. From 117186bce70c34bec9b5eeb7033fb1d5d3fce5fe Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 15:48:29 +0000 Subject: [PATCH 021/205] switch to 2 job strategy --- .github/workflows/build_artifacts.yml | 30 +++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index eaec833b470a..e0020c4dbdc8 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -53,6 +53,29 @@ on: type: string jobs: + determine_matrix: + runs-on: linux-x86-n2-16 + container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') + outputs: + build_matrix: ${{ steps.set-matrix.outputs.build_matrix }} + steps: + - id: set-matrix + run: | + matrix='[]' + if ${{ inputs.build_jax }}; then + matrix='["jax"]' + if ${{ inputs.build_jaxlib }}; then + matrix='["jax", "jaxlib"]' + if ${{ inputs.build_jax_cuda_pjrt }}; then + matrix='["jax", "jaxlib", "jax-cuda-pjrt"]' + if ${{ inputs.build_jax_cuda_plugin }}; then + matrix='["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]' + fi + fi + fi + fi + echo "build_matrix=${matrix}" >> $GITHUB_OUTPUT + build_artifacts: continue-on-error: true defaults: @@ -62,12 +85,7 @@ jobs: strategy: matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] - artifact: >- - ${{ inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && inputs.build_jax_cuda_plugin && fromJSON('["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]') }} || - ${{ inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && fromJSON('["jax", "jaxlib", "jax-cuda-pjrt"]') }} || - ${{ inputs.build_jax && inputs.build_jaxlib && fromJSON('["jax", "jaxlib"]') }} || - ${{ inputs.build_jax && fromJSON('["jax"]') }} || - ${{ fromJSON('[]') }} + artifact: ${{ fromJSON(needs.determine_matrix.outputs.build_matrix) }} python: ["3.10"] #, "3.11", "3.12"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. From 6aa801f6f1908dfdb95f8049506874291e6e2a89 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 15:49:38 +0000 Subject: [PATCH 022/205] fix syntax error --- .github/workflows/build_artifacts.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index e0020c4dbdc8..b36d86eb55ce 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -54,8 +54,8 @@ on: jobs: determine_matrix: - runs-on: linux-x86-n2-16 - container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') + runs-on: "linux-x86-n2-16" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" outputs: build_matrix: ${{ steps.set-matrix.outputs.build_matrix }} steps: From 98742b286d632a419ddd301922036c2817c319e7 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 15:50:24 +0000 Subject: [PATCH 023/205] add missing job dep --- .github/workflows/build_artifacts.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index b36d86eb55ce..f8296d5a27d6 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -77,6 +77,7 @@ jobs: echo "build_matrix=${matrix}" >> $GITHUB_OUTPUT build_artifacts: + needs: determine_matrix continue-on-error: true defaults: run: From e4b45d11e687528c9f5ae8ea557596bdfebcd5d1 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 16:02:10 +0000 Subject: [PATCH 024/205] Change input type to be able to parse when generating matrix --- .github/workflows/build_artifacts.yml | 32 +++++++++++++------------- .github/workflows/pytest_cpu_reuse.yml | 8 +++---- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index f8296d5a27d6..4377c24dcdd6 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -17,25 +17,25 @@ on: workflow_call: inputs: build_jax: - description: "Should the jax artifact be built?" + description: "Should the jax artifact be built? (1 to enable, 0 to disable)" + type: string required: true - default: true - type: boolean + default: "1" build_jaxlib: - description: "Should the jaxlib artifact be built?" + description: "Should the jaxlib artifact be built? (1 to enable, 0 to disable)" + type: string required: true - default: true - type: boolean + default: "1" build_jax_cuda_plugin: - description: "Should the jax-cuda-plugin artifact be built?" + description: "Should the jax-cuda-plugin artifact be built? (1 to enable, 0 to disable)" + type: string required: true - default: true - type: boolean + default: "1" build_jax_cuda_pjrt: - description: "Should the jax-cuda-pjrt artifact be built?" + description: "Should the jax-cuda-pjrt artifact be built? (1 to enable, 0 to disable)" + type: string required: true - default: true - type: boolean + default: "1" clone_main_xla: description: "Should latest XLA be used? (1 to enable, 0 to disable)" type: string @@ -62,13 +62,13 @@ jobs: - id: set-matrix run: | matrix='[]' - if ${{ inputs.build_jax }}; then + if [[ ${{ inputs.build_jax }} == "1" ]]; then matrix='["jax"]' - if ${{ inputs.build_jaxlib }}; then + if [[ ${{ inputs.build_jaxlib }} == "1" ]]; then matrix='["jax", "jaxlib"]' - if ${{ inputs.build_jax_cuda_pjrt }}; then + if [[ ${{ inputs.build_jax_cuda_pjrt }} == "1" ]]; then matrix='["jax", "jaxlib", "jax-cuda-pjrt"]' - if ${{ inputs.build_jax_cuda_plugin }}; then + if [[ ${{ inputs.build_jax_cuda_plugin }} == "1" ]]; then matrix='["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]' fi fi diff --git a/.github/workflows/pytest_cpu_reuse.yml b/.github/workflows/pytest_cpu_reuse.yml index d731e0706932..6254548d76f4 100644 --- a/.github/workflows/pytest_cpu_reuse.yml +++ b/.github/workflows/pytest_cpu_reuse.yml @@ -19,10 +19,10 @@ jobs: name: "Build the jaxlib aritfact using latest XLA" uses: ./.github/workflows/build_artifacts.yml with: - build_jax: false - build_jaxlib: true - build_jax_cuda_plugin: false - build_jax_cuda_pjrt: false + build_jax: 0 + build_jaxlib: 1 + build_jax_cuda_plugin: 0 + build_jax_cuda_pjrt: 0 clone_main_xla: 1 upload_artifacts: true upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' From 6eca13c3b45df03060cdc59dd7dfab6bc374e074 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 17:01:52 +0000 Subject: [PATCH 025/205] change how artifact matrix is constructed --- .github/workflows/build_artifacts.yml | 33 ++++++++++++++++----------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 4377c24dcdd6..fe9baab3c9d6 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -57,24 +57,31 @@ jobs: runs-on: "linux-x86-n2-16" container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" outputs: - build_matrix: ${{ steps.set-matrix.outputs.build_matrix }} + artifact_matrix: ${{ steps.set-matrix.outputs.artifact_matrix }} steps: - id: set-matrix run: | - matrix='[]' - if [[ ${{ inputs.build_jax }} == "1" ]]; then - matrix='["jax"]' - if [[ ${{ inputs.build_jaxlib }} == "1" ]]; then - matrix='["jax", "jaxlib"]' + artifacts=() + if [[ ${{ github.event }} == "pull_request" ]]; + artifacts = ("jax" ", jaxlib" ", jax-cuda-pjrt" ", jax-cuda-plugin") + else + if [[ ${{ inputs.build_jax }} == "1" ]]; then + artifacts+="jax" + fi + + if [[ ${{ inputs.build_jaxlib }} == "1" ]]; then + artifacts+=", jaxlib" + fi + if [[ ${{ inputs.build_jax_cuda_pjrt }} == "1" ]]; then - matrix='["jax", "jaxlib", "jax-cuda-pjrt"]' - if [[ ${{ inputs.build_jax_cuda_plugin }} == "1" ]]; then - matrix='["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]' - fi + artifacts+=", jax-cuda-pjrt" + fi + + if [[ ${{ inputs.build_jax_cuda_plugin }} == "1" ]]; then + artifacts+=", jax-cuda-plugin" fi fi - fi - echo "build_matrix=${matrix}" >> $GITHUB_OUTPUT + echo "artifact_matrix='[${artifacts[@]}]'" >> $GITHUB_OUTPUT build_artifacts: needs: determine_matrix @@ -86,7 +93,7 @@ jobs: strategy: matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] - artifact: ${{ fromJSON(needs.determine_matrix.outputs.build_matrix) }} + artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} python: ["3.10"] #, "3.11", "3.12"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. From c35651395acd0f516c9131e4c6d87b293b325719 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 17:09:00 +0000 Subject: [PATCH 026/205] fix syntax issue --- .github/workflows/build_artifacts.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index fe9baab3c9d6..793f50b15b57 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -62,7 +62,7 @@ jobs: - id: set-matrix run: | artifacts=() - if [[ ${{ github.event }} == "pull_request" ]]; + if [[ ${{ github.event_name }} == "pull_request" ]]; artifacts = ("jax" ", jaxlib" ", jax-cuda-pjrt" ", jax-cuda-plugin") else if [[ ${{ inputs.build_jax }} == "1" ]]; then From 196c3deabaa884d7e1ea4b6dda7e8278fe4ec9b6 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 15 Nov 2024 17:10:24 +0000 Subject: [PATCH 027/205] fix syntax issue --- .github/workflows/build_artifacts.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 793f50b15b57..bd924f8d694e 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -63,7 +63,7 @@ jobs: run: | artifacts=() if [[ ${{ github.event_name }} == "pull_request" ]]; - artifacts = ("jax" ", jaxlib" ", jax-cuda-pjrt" ", jax-cuda-plugin") + artifacts=("jax" ", jaxlib" ", jax-cuda-pjrt" ", jax-cuda-plugin") else if [[ ${{ inputs.build_jax }} == "1" ]]; then artifacts+="jax" From 42253cc8988832141de2db7b9df2a08c45beff81 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 18:55:23 +0000 Subject: [PATCH 028/205] update build.py --- build/build.py | 49 +++++++++++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/build/build.py b/build/build.py index dd53332613ef..95469eb124b3 100755 --- a/build/build.py +++ b/build/build.py @@ -54,8 +54,8 @@ `python build/build.py requirements_update` """ -# Define the build target for each artifact. -ARTIFACT_BUILD_TARGET_DICT = { +# Define the build target for each wheel. +WHEEL_BUILD_TARGET_DICT = { "jaxlib": "//jaxlib/tools:build_wheel", "jax-cuda-plugin": "//jaxlib/tools:build_gpu_kernels_wheel", "jax-cuda-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", @@ -143,10 +143,11 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): type=str, default="jaxlib", help= - f""" - A comma separated list of JAX artifacts to build. E.g: --wheels="jaxlib", + """ + A comma separated list of JAX wheels to build. E.g: --wheels="jaxlib", --wheels="jaxlib,jax-cuda-plugin", etc. - Valid options are: {','.join(ARTIFACT_BUILD_TARGET_DICT.keys())} + Valid options are: jaxlib, jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt, + jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt """, ) @@ -362,8 +363,7 @@ async def main(): f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}" ) - # Enable color in the Bazel output and verbose failures. - bazel_command_base.append("--color=yes") + # Enable verbose failures. bazel_command_base.append("--verbose_failures=true") # Requirements update subcommand execution @@ -377,7 +377,7 @@ async def main(): requirements_command.append(option) if args.nightly_update: - logging.debug( + logging.info( "--nightly_update is set. Bazel will run" " //build:requirements_nightly.update" ) @@ -414,8 +414,16 @@ async def main(): # Wheel build command execution for wheel in args.wheels.split(","): - if wheel not in ARTIFACT_BUILD_TARGET_DICT.keys(): - logging.error("Incorrect wheel name provided: %s, valid choices are: %s", wheel, ",".join(ARTIFACT_BUILD_TARGET_DICT.keys())) + # Allow CUDA/ROCm wheels without the "jax-" prefix. + if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: + wheel = "jax-" + wheel + + if wheel not in WHEEL_BUILD_TARGET_DICT.keys(): + logging.error( + "Incorrect wheel name provided, valid choices are jaxlib," + " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," + " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt" + ) sys.exit(1) wheel_build_command = copy.deepcopy(bazel_command_base) @@ -518,18 +526,19 @@ async def main(): for option in args.bazel_options: wheel_build_command.append(option) + with open(".jax_configure.bazelrc", "w") as f: + jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list()) + if not jax_configure_options: + logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") + sys.exit(1) + f.write(jax_configure_options) + logging.info("Bazel options written to .jax_configure.bazelrc") + if args.configure_only: - with open(".jax_configure.bazelrc", "w") as f: - jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list()) - if not jax_configure_options: - logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") - sys.exit(1) - f.write(jax_configure_options) - logging.info("Bazel options written to .jax_configure.bazelrc") - logging.info("--configure_only is set so not running any Bazel commands.") + logging.info("--configure_only is set so not running any Bazel commands.") else: # Append the build target to the Bazel command. - build_target = ARTIFACT_BUILD_TARGET_DICT[wheel] + build_target = WHEEL_BUILD_TARGET_DICT[wheel] wheel_build_command.append(build_target) wheel_build_command.append("--") @@ -538,7 +547,7 @@ async def main(): logger.debug("Artifacts output directory: %s", output_path) if args.editable: - logger.debug("Building an editable build") + logger.info("Building an editable build") output_path = os.path.join(output_path, wheel) wheel_build_command.append("--editable") From 979366118ba61c77a213a39b63db37a9cf86fbf2 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 19:01:01 +0000 Subject: [PATCH 029/205] set default values if inputs context is unavailable --- .github/workflows/build_artifacts.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index bd924f8d694e..1eacfe0f61b9 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -65,19 +65,19 @@ jobs: if [[ ${{ github.event_name }} == "pull_request" ]]; artifacts=("jax" ", jaxlib" ", jax-cuda-pjrt" ", jax-cuda-plugin") else - if [[ ${{ inputs.build_jax }} == "1" ]]; then + if [[ ${${{ inputs.build_jax }}:-0} == "1" ]]; then artifacts+="jax" fi - if [[ ${{ inputs.build_jaxlib }} == "1" ]]; then + if [[ ${${{ inputs.build_jaxlib }}:-0} == "1" ]]; then artifacts+=", jaxlib" fi - if [[ ${{ inputs.build_jax_cuda_pjrt }} == "1" ]]; then + if [[ ${${{ inputs.build_jax_cuda_pjrt }}:-0} == "1" ]]; then artifacts+=", jax-cuda-pjrt" fi - if [[ ${{ inputs.build_jax_cuda_plugin }} == "1" ]]; then + if [[ ${${{ inputs.build_jax_cuda_plugin }}:-0} == "1" ]]; then artifacts+=", jax-cuda-plugin" fi fi From cce45ca08609103001988f1104d423cea51e1995 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 19:06:38 +0000 Subject: [PATCH 030/205] set default shell to bash when determining the artifact matrix --- .github/workflows/build_artifacts.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 1eacfe0f61b9..1adca1c5142b 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -58,6 +58,9 @@ jobs: container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" outputs: artifact_matrix: ${{ steps.set-matrix.outputs.artifact_matrix }} + defaults: + run: + shell: bash steps: - id: set-matrix run: | From 6471a9663621ca47321093a64276f696cb02dd3f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 19:15:25 +0000 Subject: [PATCH 031/205] add wait for connect step to debug issue --- .github/workflows/build_artifacts.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 1adca1c5142b..9b6cd3deb4b7 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -62,6 +62,11 @@ jobs: run: shell: bash steps: + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} - id: set-matrix run: | artifacts=() From 03f8d24a9f550cd066242e3a8064288a27619f04 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 19:33:14 +0000 Subject: [PATCH 032/205] fix syntax error --- .github/workflows/build_artifacts.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 9b6cd3deb4b7..34ea7fc494b7 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -70,7 +70,7 @@ jobs: - id: set-matrix run: | artifacts=() - if [[ ${{ github.event_name }} == "pull_request" ]]; + if [[ ${{ github.event_name }} == "pull_request" ]]; then artifacts=("jax" ", jaxlib" ", jax-cuda-pjrt" ", jax-cuda-plugin") else if [[ ${${{ inputs.build_jax }}:-0} == "1" ]]; then From ab1ae97d6d3e05512f2d7604e55d471e652150a9 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 19:58:13 +0000 Subject: [PATCH 033/205] try fix --- .github/workflows/build_artifacts.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 34ea7fc494b7..33e449e0ad2f 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -74,22 +74,22 @@ jobs: artifacts=("jax" ", jaxlib" ", jax-cuda-pjrt" ", jax-cuda-plugin") else if [[ ${${{ inputs.build_jax }}:-0} == "1" ]]; then - artifacts+="jax" + artifacts+="'jax'" fi if [[ ${${{ inputs.build_jaxlib }}:-0} == "1" ]]; then - artifacts+=", jaxlib" + artifacts+=", 'jaxlib'" fi if [[ ${${{ inputs.build_jax_cuda_pjrt }}:-0} == "1" ]]; then - artifacts+=", jax-cuda-pjrt" + artifacts+=", 'jax-cuda-pjrt'" fi if [[ ${${{ inputs.build_jax_cuda_plugin }}:-0} == "1" ]]; then - artifacts+=", jax-cuda-plugin" + artifacts+=", 'jax-cuda-plugin'" fi fi - echo "artifact_matrix='[${artifacts[@]}]'" >> $GITHUB_OUTPUT + echo "artifact_matrix=[${artifacts[@]}]" >> $GITHUB_OUTPUT build_artifacts: needs: determine_matrix @@ -101,7 +101,7 @@ jobs: strategy: matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] - artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} + artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} python: ["3.10"] #, "3.11", "3.12"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. From 8b434828b8c0d66cb40172a211a0e1e26d4d357b Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 20:01:57 +0000 Subject: [PATCH 034/205] try fix --- .github/workflows/build_artifacts.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 33e449e0ad2f..1183d12d1bd8 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -71,7 +71,7 @@ jobs: run: | artifacts=() if [[ ${{ github.event_name }} == "pull_request" ]]; then - artifacts=("jax" ", jaxlib" ", jax-cuda-pjrt" ", jax-cuda-plugin") + artifacts=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") else if [[ ${${{ inputs.build_jax }}:-0} == "1" ]]; then artifacts+="'jax'" From 354bc831cfb66b2d972e7d91ebfd40d0018d9985 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 20:26:32 +0000 Subject: [PATCH 035/205] Add variable to know if a workflow call was made, enable all package builds, change gsutil path --- .github/workflows/build_artifacts.yml | 12 +++++++++--- .github/workflows/pytest_cpu_reuse.yml | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 1183d12d1bd8..edcda9e85410 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -51,6 +51,11 @@ on: required: true default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string + is_workflow_call: + description: "Metadata variable to know whether a workflow call was made" + type: string + required: true + default: "1" jobs: determine_matrix: @@ -70,7 +75,8 @@ jobs: - id: set-matrix run: | artifacts=() - if [[ ${{ github.event_name }} == "pull_request" ]]; then + # Build every package if not a workflow call + if [[ ${${{ inputs.is_workflow_call }}:-0} == "0" ]]; then artifacts=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") else if [[ ${${{ inputs.build_jax }}:-0} == "1" ]]; then @@ -102,7 +108,7 @@ jobs: matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} - python: ["3.10"] #, "3.11", "3.12"] + python: ["3.10", "3.11", "3.12"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. exclude: @@ -151,5 +157,5 @@ jobs: - name: Upload artifacts to GCS bucket # Upload if requested and one of the artifacts was built if: inputs.upload_artifacts - run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM + run: gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM diff --git a/.github/workflows/pytest_cpu_reuse.yml b/.github/workflows/pytest_cpu_reuse.yml index 6254548d76f4..6f69add49e6b 100644 --- a/.github/workflows/pytest_cpu_reuse.yml +++ b/.github/workflows/pytest_cpu_reuse.yml @@ -60,7 +60,7 @@ jobs: - name: Download the artifacts built in the "build_artifacts" job run: >- mkdir -p $(pwd)/dist && - ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM $(pwd)/dist + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM $(pwd)/dist - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} From d1ca99764ec122de7cd2a289e98ab034f051571e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 20:54:04 +0000 Subject: [PATCH 036/205] try fix --- .github/workflows/build_artifacts.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index edcda9e85410..7f17b5d1e03d 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -75,8 +75,9 @@ jobs: - id: set-matrix run: | artifacts=() + is_workflow_call=${{ inputs.is_workflow_call }} # Build every package if not a workflow call - if [[ ${${{ inputs.is_workflow_call }}:-0} == "0" ]]; then + if [[ ${is_workflow_call:-0} == "0" ]]; then artifacts=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") else if [[ ${${{ inputs.build_jax }}:-0} == "1" ]]; then From aaa2a70ebeebbf98755e960873079eff34257518 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 20:58:20 +0000 Subject: [PATCH 037/205] make is_workflow_call metadata input as not required --- .github/workflows/build_artifacts.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 7f17b5d1e03d..3faa41a56ae1 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -54,7 +54,7 @@ on: is_workflow_call: description: "Metadata variable to know whether a workflow call was made" type: string - required: true + required: false default: "1" jobs: From 55a5d0369ca1b32a5ee5468397271260114198e9 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 21:06:05 +0000 Subject: [PATCH 038/205] try fix --- .github/workflows/build_artifacts.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 3faa41a56ae1..93345462ae65 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -81,19 +81,19 @@ jobs: artifacts=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") else if [[ ${${{ inputs.build_jax }}:-0} == "1" ]]; then - artifacts+="'jax'" + artifacts+="'jax'," fi if [[ ${${{ inputs.build_jaxlib }}:-0} == "1" ]]; then - artifacts+=", 'jaxlib'" + artifacts+="'jaxlib'," fi if [[ ${${{ inputs.build_jax_cuda_pjrt }}:-0} == "1" ]]; then - artifacts+=", 'jax-cuda-pjrt'" + artifacts+="'jax-cuda-pjrt'," fi if [[ ${${{ inputs.build_jax_cuda_plugin }}:-0} == "1" ]]; then - artifacts+=", 'jax-cuda-plugin'" + artifacts+="'jax-cuda-plugin'" fi fi echo "artifact_matrix=[${artifacts[@]}]" >> $GITHUB_OUTPUT From bf48bb3bc6ed3339fca671404ad61721c73710b9 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 21:10:02 +0000 Subject: [PATCH 039/205] define inputs as bash variables --- .github/workflows/build_artifacts.yml | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 93345462ae65..13e35d10f677 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -75,24 +75,31 @@ jobs: - id: set-matrix run: | artifacts=() + # Define inputs as bash variables to be able to parse them in + # if conditions is_workflow_call=${{ inputs.is_workflow_call }} + build_jax=${{ inputs.build_jax }} + build_jaxlib=${{ inputs.build_jaxlib }} + build_jax_cuda_pjrt=${{ inputs.build_jax_cuda_pjrt }} + build_jax_cuda_plugin=${{ inputs.build_jax_cuda_plugin }} + # Build every package if not a workflow call - if [[ ${is_workflow_call:-0} == "0" ]]; then + if [[ ${is_workflow_call:-"0"} == "0" ]]; then artifacts=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") else - if [[ ${${{ inputs.build_jax }}:-0} == "1" ]]; then + if [[ ${build_jax:-"0"} == "1" ]]; then artifacts+="'jax'," fi - if [[ ${${{ inputs.build_jaxlib }}:-0} == "1" ]]; then + if [[ ${build_jaxlib:-"0"} == "1" ]]; then artifacts+="'jaxlib'," fi - if [[ ${${{ inputs.build_jax_cuda_pjrt }}:-0} == "1" ]]; then + if [[ ${build_jax_cuda_pjrt:-"0"} == "1" ]]; then artifacts+="'jax-cuda-pjrt'," fi - if [[ ${${{ inputs.build_jax_cuda_plugin }}:-0} == "1" ]]; then + if [[ ${build_jax_cuda_plugin:-"0"} == "1" ]]; then artifacts+="'jax-cuda-plugin'" fi fi From c8714df377b342badf75fb541f228fd6028dfd58 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 22:45:29 +0000 Subject: [PATCH 040/205] Change determine_matrix job name --- .github/workflows/build_artifacts.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 13e35d10f677..5059c58c410d 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -58,7 +58,7 @@ on: default: "1" jobs: - determine_matrix: + determine_artifact_matrix: runs-on: "linux-x86-n2-16" container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" outputs: @@ -106,7 +106,7 @@ jobs: echo "artifact_matrix=[${artifacts[@]}]" >> $GITHUB_OUTPUT build_artifacts: - needs: determine_matrix + needs: determine_artifact_matrix continue-on-error: true defaults: run: @@ -114,8 +114,8 @@ jobs: shell: bash strategy: matrix: - runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] - artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} + runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-48"] + artifact: ${{ fromJSON(needs.determine_artifact_matrix.outputs.artifact_matrix) }} python: ["3.10", "3.11", "3.12"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. From ddd84f7a3a752b70b32bf1f5b21d19394d83788f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 22:45:41 +0000 Subject: [PATCH 041/205] update build.py --- build/build.py | 55 +++++++++++++++++++++++++++----------------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/build/build.py b/build/build.py index 95469eb124b3..00624f15fb9d 100755 --- a/build/build.py +++ b/build/build.py @@ -64,18 +64,6 @@ } -def add_requirements_nightly_update_argument(parser: argparse.ArgumentParser): - parser.add_argument( - "--nightly_update", - action="store_true", - help=""" - If true, updates requirements_lock.txt for a corresponding version of - Python and will consider dev, nightly and pre-release versions of - packages. - """, - ) - - def add_global_arguments(parser: argparse.ArgumentParser): """Adds all the global arguments that applies to all the CLI subcommands.""" parser.add_argument( @@ -136,8 +124,8 @@ def add_global_arguments(parser: argparse.ArgumentParser): ) -def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): - """Adds all the global arguments that applies to the artifact subcommands.""" +def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): + """Adds all the arguments that applies to the artifact subcommands.""" parser.add_argument( "--wheels", type=str, @@ -178,26 +166,32 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): cuda_group.add_argument( "--cuda_version", type=str, - # LINT.IfChange(cuda_version) - default="12.3.2", - # LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc) help= """ Hermetic CUDA version to use. Default is to use the version specified - in the .bazelrc (12.3.2). + in the .bazelrc. + """, + ) + + cuda_group.add_argument( + "--cuda_major_version", + type=str, + default="12", + help= + """ + Which CUDA major version should the wheel be tagged as? Auto-detected if + --cuda_version is set. When --cuda_version is not set, the default is to + set the major version to 12 to match the default in .bazelrc. """, ) cuda_group.add_argument( "--cudnn_version", type=str, - # LINT.IfChange(cudnn_version) - default="9.1.1", - # LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc) help= """ Hermetic cuDNN version to use. Default is to use the version specified - in the .bazelrc (9.1.1). + in the .bazelrc. """, ) @@ -317,14 +311,22 @@ async def main(): requirements_update_parser = subparsers.add_parser( "requirements_update", help="Updates the requirements_lock.txt files" ) - add_requirements_nightly_update_argument(requirements_update_parser) + requirements_update_parser.add_argument( + "--nightly_update", + action="store_true", + help=""" + If true, updates requirements_lock.txt for a corresponding version of + Python and will consider dev, nightly and pre-release versions of + packages. + """, + ) add_global_arguments(requirements_update_parser) # Artifact build subcommand build_artifact_parser = subparsers.add_parser( "build", help="Builds the jaxlib, plugin, and pjrt artifact" ) - add_artifact_subcommand_global_arguments(build_artifact_parser) + add_artifact_subcommand_arguments(build_artifact_parser) add_global_arguments(build_artifact_parser) arch = platform.machine() @@ -556,7 +558,10 @@ async def main(): if "cuda" in wheel: wheel_build_command.append("--enable-cuda=True") - cuda_major_version = args.cuda_version.split(".")[0] + if args.cuda_version: + cuda_major_version = args.cuda_version.split(".")[0] + else: + cuda_major_version = args.cuda_major_version wheel_build_command.append(f"--platform_version={cuda_major_version}") if "rocm" in wheel: From 1767756241d0433ae7052b79cc2c82b064861657 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 23:39:53 +0000 Subject: [PATCH 042/205] fix exclude filters --- .github/workflows/build_artifacts.yml | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 5059c58c410d..f882a6fbe347 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -116,24 +116,31 @@ jobs: matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-48"] artifact: ${{ fromJSON(needs.determine_artifact_matrix.outputs.artifact_matrix) }} - python: ["3.10", "3.11", "3.12"] + python: ["3.10", "3.11", "3.12", "3.13"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. exclude: - # Pure Python packages do not need to be built for each Python version. + # jax-cuda-pjrt and jax are Python version independent and do not need to be built for + # each Python version + # Exclude jax-cuda-pjrt - artifact: "jax-cuda-pjrt" - python: "3.10" + python: "3.11" + - artifact: "jax-cuda-pjrt" + python: "3.12" - artifact: "jax-cuda-pjrt" + python: "3.13" + # Exclude jax + - artifact: "jax" python: "3.11" - artifact: "jax" - python: "3.10" + python: "3.12" - artifact: "jax" - python: "3.11" - # jax is a pure Python package so it does not need to be built on multiple platforms. + python: "3.13" + # jax also only needs to be built once per runner - artifact: "jax" runner: "windows-x86-n2-64" - artifact: "jax" - runner: "linux-arm64-t2a-16" + runner: "linux-arm64-t2a-48" # jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows. - artifact: "jax-cuda-plugin" runner: "windows-x86-n2-64" From 6c0e0fb26f92f62d6ccb4191e603674e4c134eba Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 23:53:07 +0000 Subject: [PATCH 043/205] introduce rbe env var to decide when to use rbe flags --- .github/workflows/build_artifacts.yml | 8 ++++++++ ci/build_artifacts.sh | 23 +++++++---------------- ci/envs/default.env | 3 +++ 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index f882a6fbe347..96ce1cc614ae 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -159,6 +159,14 @@ jobs: steps: - uses: actions/checkout@v3 + - name: Enable RBE on platforms where its supported + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + if [[ ($os == "linux" || $os =~ "msys_nt" ) && $arch == "x86_64" ]]; then + echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV + fi # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index a0d3a63d1338..c72d78f6537b 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -40,7 +40,7 @@ arch=$(uname -m) # Adjust the values when running on Windows x86 to match the config in # .bazelrc -if [[ $os =~ "msys_nt" ]] && [[ $arch == "x86_64" ]]; then +if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then os="windows" arch="amd64" fi @@ -54,26 +54,17 @@ if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then # For bazel builds, use the "rbe_" config for Linux x86/Windows and "ci_" for other platforms bazelrc_config="${os}_${arch}" - if ( [[ "$os" == "linux" ]] && [[ "$arch" == "x86_64" ]] ) || [[ "$os" == "windows" ]]; then - bazelrc_config="rbe_$bazelrc_config" + if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then + bazelrc_config="rbe_${bazelrc_config}" else - bazelrc_config="ci_$bazelrc_config" + bazelrc_config="ci_${bazelrc_config}" fi - # Build the jaxlib CPU artifact - if [[ "$artifact" == "jaxlib" ]]; then - python build/build.py build --wheels="jaxlib" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose + if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then + bazelrc_config="${bazelrc_config}_cuda" fi - # Build the jax-cuda-plugin artifact - if [[ "$artifact" == "jax-cuda-plugin" ]]; then - python build/build.py build --wheels="jax-cuda-plugin" --bazel_options=--config="${bazelrc_config}_cuda" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose - fi - - # Build the jax-cuda-pjrt artifact - if [[ "$artifact" == "jax-cuda-pjrt" ]]; then - python build/build.py build --wheels="jax-cuda-pjrt" --bazel_options=--config="${bazelrc_config}_cuda" --verbose - fi + python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. diff --git a/ci/envs/default.env b/ci/envs/default.env index d6514a132c0e..d5f011fe1c6f 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -36,6 +36,9 @@ export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} # Controls the location where the artifacts are stored. export JAXCI_OUTPUT_DIR="$(pwd)/dist" +# When enabled, artifacts will be built with RBE. +export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} + # ############################################################################# # Docker specific environment variables. # ############################################################################# From ea7b99ffef59bd48fb9635cb095f24054e8e22d6 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 19 Nov 2024 03:28:59 +0000 Subject: [PATCH 044/205] update scripts to match upstream --- build/build.py | 2 +- ci/build_artifacts.sh | 10 +++++++--- ci/envs/default.env | 4 +++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/build/build.py b/build/build.py index 00624f15fb9d..10bb96aecc12 100755 --- a/build/build.py +++ b/build/build.py @@ -245,7 +245,7 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): ) # Compile Options - compile_group = parser.add_argument_group('Compile Options') + compile_group = parser.add_argument_group('Compile Options', ) compile_group.add_argument( "--clang_path", diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index c72d78f6537b..6f246f890b30 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -45,26 +45,30 @@ if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then arch="amd64" fi -if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then +if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then # Build the jax artifact if [[ "$artifact" == "jax" ]]; then python -m build --outdir $JAXCI_OUTPUT_DIR else - # For bazel builds, use the "rbe_" config for Linux x86/Windows and "ci_" for other platforms + # Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_" + # flags in the .bazelrc depending upon the platform we are building for. bazelrc_config="${os}_${arch}" + if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then bazelrc_config="rbe_${bazelrc_config}" else bazelrc_config="ci_${bazelrc_config}" fi + # Use the "_cuda" configs when building the CUDA artifacts. if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then bazelrc_config="${bazelrc_config}_cuda" fi - python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose + # Build the artifact. + python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. diff --git a/ci/envs/default.env b/ci/envs/default.env index d5f011fe1c6f..0b48e86935e3 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -36,7 +36,9 @@ export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} # Controls the location where the artifacts are stored. export JAXCI_OUTPUT_DIR="$(pwd)/dist" -# When enabled, artifacts will be built with RBE. +# When enabled, artifacts will be built with RBE. Requires gcloud authentication +# and only certain platforms support RBE. Therefore, this flag is enabled only +# for CI builds where RBE is supported. export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} # ############################################################################# From dd464ffe0f8259f030a9747af70fddf6f9cfd50e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 19 Nov 2024 03:35:31 +0000 Subject: [PATCH 045/205] change upload destination and only enable linux x86 builds --- .github/workflows/build_artifacts.yml | 10 ++++++---- .github/workflows/pytest_cpu_reuse.yml | 8 +++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 96ce1cc614ae..fa17e5821eae 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -114,7 +114,7 @@ jobs: shell: bash strategy: matrix: - runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-48"] + runner: ["linux-x86-n2-16"] #, "linux-arm64-t2a-48", "windows-x86-n2-64"] artifact: ${{ fromJSON(needs.determine_artifact_matrix.outputs.artifact_matrix) }} python: ["3.10", "3.11", "3.12", "3.13"] # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each @@ -174,11 +174,13 @@ jobs: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Build ${{ matrix.artifact }} run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" - - name: Set Platform + - name: Set PLATFORM env var for use in upload destination run: | - echo "PLATFORM=$(uname)_$(uname -m)" >> $GITHUB_ENV + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket # Upload if requested and one of the artifacts was built if: inputs.upload_artifacts - run: gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM + run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ diff --git a/.github/workflows/pytest_cpu_reuse.yml b/.github/workflows/pytest_cpu_reuse.yml index 6f69add49e6b..d5330e2185e2 100644 --- a/.github/workflows/pytest_cpu_reuse.yml +++ b/.github/workflows/pytest_cpu_reuse.yml @@ -37,7 +37,7 @@ jobs: shell: bash strategy: matrix: - runner: ["linux-x86-n2-64", "linux-arm64-t2a-48"] + runner: ["linux-x86-n2-64"] #, "linux-arm64-t2a-48"] python: ["3.10"] runs-on: ${{ matrix.runner }} @@ -56,11 +56,13 @@ jobs: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Set Platform run: | - echo "PLATFORM=$(uname)_$(uname -m)" >> $GITHUB_ENV + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Download the artifacts built in the "build_artifacts" job run: >- mkdir -p $(pwd)/dist && - gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM $(pwd)/dist + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/ - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} From 6f1995d4ce0b905bc78fdc13fc44c2e8a80759d2 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 19 Nov 2024 16:50:44 +0000 Subject: [PATCH 046/205] make workflow calls simpler --- .github/workflows/build_artifacts.yml | 22 +++++++++++----------- .github/workflows/pytest_cpu_reuse.yml | 3 --- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index fa17e5821eae..2a7c191e088c 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -19,36 +19,36 @@ on: build_jax: description: "Should the jax artifact be built? (1 to enable, 0 to disable)" type: string - required: true - default: "1" + required: false + default: "0" build_jaxlib: description: "Should the jaxlib artifact be built? (1 to enable, 0 to disable)" type: string - required: true - default: "1" + required: false + default: "0" build_jax_cuda_plugin: description: "Should the jax-cuda-plugin artifact be built? (1 to enable, 0 to disable)" type: string - required: true - default: "1" + required: false + default: "0" build_jax_cuda_pjrt: description: "Should the jax-cuda-pjrt artifact be built? (1 to enable, 0 to disable)" type: string - required: true - default: "1" + required: false + default: "0" clone_main_xla: description: "Should latest XLA be used? (1 to enable, 0 to disable)" type: string - required: true + required: false default: "0" upload_artifacts: description: "Should the artifacts be uploaded to a GCS bucket?" - required: true + required: false default: false type: boolean upload_destination: description: "GCS location to where the artifacts should be uploaded" - required: true + required: false default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string is_workflow_call: diff --git a/.github/workflows/pytest_cpu_reuse.yml b/.github/workflows/pytest_cpu_reuse.yml index d5330e2185e2..5c7ebfee80c9 100644 --- a/.github/workflows/pytest_cpu_reuse.yml +++ b/.github/workflows/pytest_cpu_reuse.yml @@ -19,10 +19,7 @@ jobs: name: "Build the jaxlib aritfact using latest XLA" uses: ./.github/workflows/build_artifacts.yml with: - build_jax: 0 build_jaxlib: 1 - build_jax_cuda_plugin: 0 - build_jax_cuda_pjrt: 0 clone_main_xla: 1 upload_artifacts: true upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' From 28a26524882694417be700776237bdf2a4480104 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 00:26:37 +0000 Subject: [PATCH 047/205] simplify workflow calls --- .github/workflows/build_artifacts.yml | 100 +++++++++++++++---------- .github/workflows/pytest_cpu_reuse.yml | 6 +- 2 files changed, 65 insertions(+), 41 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 2a7c191e088c..3140bf60c866 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -16,26 +16,21 @@ on: - 'no' workflow_call: inputs: - build_jax: - description: "Should the jax artifact be built? (1 to enable, 0 to disable)" + wheel_list: + description: "A comma separated list of JAX wheels to build. E.g: jaxlib or jaxlib,jax-cuda-pjrt" type: string required: false - default: "0" - build_jaxlib: - description: "Should the jaxlib artifact be built? (1 to enable, 0 to disable)" + default: "" + python_list: + description: "A comma separated list of Python versions to build for. E.g: 3.10 or 3.11,3.12" type: string required: false - default: "0" - build_jax_cuda_plugin: - description: "Should the jax-cuda-plugin artifact be built? (1 to enable, 0 to disable)" + default: "" + platform_list: + description: "A comma separated list of platforms to build for. E.g: linux_x86 or linux_x86,linux_arm64,windows_x86" type: string required: false - default: "0" - build_jax_cuda_pjrt: - description: "Should the jax-cuda-pjrt artifact be built? (1 to enable, 0 to disable)" - type: string - required: false - default: "0" + default: "" clone_main_xla: description: "Should latest XLA be used? (1 to enable, 0 to disable)" type: string @@ -58,11 +53,13 @@ on: default: "1" jobs: - determine_artifact_matrix: + determine_matrix: runs-on: "linux-x86-n2-16" container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" outputs: artifact_matrix: ${{ steps.set-matrix.outputs.artifact_matrix }} + python_matrix: ${{ steps.set-matrix.outputs.python_matrix }} + platform_matrix: ${{ steps.set-matrix.outputs.platform_matrix }} defaults: run: shell: bash @@ -74,39 +71,64 @@ jobs: halt-dispatch-input: ${{ inputs.halt-for-connection }} - id: set-matrix run: | - artifacts=() # Define inputs as bash variables to be able to parse them in # if conditions is_workflow_call=${{ inputs.is_workflow_call }} - build_jax=${{ inputs.build_jax }} - build_jaxlib=${{ inputs.build_jaxlib }} - build_jax_cuda_pjrt=${{ inputs.build_jax_cuda_pjrt }} - build_jax_cuda_plugin=${{ inputs.build_jax_cuda_plugin }} + wheel_list=${{ inputs.wheel_list }} + python_list=${{ inputs.python_list }} + platform_list=${{ inputs.platform_list }} + + # Initialize the arrays + wheels=() + python_versions=() + platforms=() - # Build every package if not a workflow call + # Build every package for every Python version on every platform if not a workflow call + # Packages that are not supported on a platform won't be built. E.g. CUDA packages won't be + # built for Windows if [[ ${is_workflow_call:-"0"} == "0" ]]; then - artifacts=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") + wheels=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") + python_versions=("'3.10'" ", '3.11'" ", '3.12'", ", '3.13'") + platforms=("'linux-x86-n2-16'" ", 'linux-arm64-t2a-48'" ", 'windows-x86-n2-64'") else - if [[ ${build_jax:-"0"} == "1" ]]; then - artifacts+="'jax'," - fi + # Set the Internal Field Separator to be comma + IFS=, - if [[ ${build_jaxlib:-"0"} == "1" ]]; then - artifacts+="'jaxlib'," - fi + # Wheels + for wheel in $wheel_list; do + wheels+="'$wheel'," + done - if [[ ${build_jax_cuda_pjrt:-"0"} == "1" ]]; then - artifacts+="'jax-cuda-pjrt'," - fi + # Python versions + for python_version in $python_list; do + python_versions+="'$python_version'," + done - if [[ ${build_jax_cuda_plugin:-"0"} == "1" ]]; then - artifacts+="'jax-cuda-plugin'" - fi + # Platforms + for platform in $platform_list; do + if [[ $platform == "linux_x86" ]]; then + platforms+="'linux-x86-n2-16'," + elif [[ $platform == "linux_arm64" ]]; then + platforms+="'linux-arm64-t2a-48'," + elif [[ $platform == "windows_x86" ]]; then + platforms+="'windows-x86-n2-64'," + else + echo "Incorrect platform provided. Valid options are: linux_x86, linux_arm64, windows_x86" + exit 1 + fi + done fi - echo "artifact_matrix=[${artifacts[@]}]" >> $GITHUB_OUTPUT + + echo "artifact_matrix=[${wheels[@]}]" >> $GITHUB_OUTPUT + echo "python_matrix=[${python_versions[@]}]" >> $GITHUB_OUTPUT + echo "platform_matrix=[${platforms[@]}]" >> $GITHUB_OUTPUT + + echo "Artifacts: $artifact_matrix" + echo "Python versions: $python_matrix" + echo "Platforms: $platform_matrix" build_artifacts: - needs: determine_artifact_matrix + needs: determine_matrix continue-on-error: true defaults: run: @@ -114,9 +136,9 @@ jobs: shell: bash strategy: matrix: - runner: ["linux-x86-n2-16"] #, "linux-arm64-t2a-48", "windows-x86-n2-64"] - artifact: ${{ fromJSON(needs.determine_artifact_matrix.outputs.artifact_matrix) }} - python: ["3.10", "3.11", "3.12", "3.13"] + runner: ${{ fromJSON(needs.determine_matrix.outputs.platform_matrix) }} + artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} + python: ${{ fromJSON(needs.determine_matrix.outputs.python_matrix) }} # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. exclude: diff --git a/.github/workflows/pytest_cpu_reuse.yml b/.github/workflows/pytest_cpu_reuse.yml index 5c7ebfee80c9..e6ebfb9a6b48 100644 --- a/.github/workflows/pytest_cpu_reuse.yml +++ b/.github/workflows/pytest_cpu_reuse.yml @@ -19,7 +19,9 @@ jobs: name: "Build the jaxlib aritfact using latest XLA" uses: ./.github/workflows/build_artifacts.yml with: - build_jaxlib: 1 + wheel_list: "jaxlib" + python_list: "3.10" + platform_list: "linux_x86,linux_arm64" clone_main_xla: 1 upload_artifacts: true upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' @@ -34,7 +36,7 @@ jobs: shell: bash strategy: matrix: - runner: ["linux-x86-n2-64"] #, "linux-arm64-t2a-48"] + runner: ["linux-x86-n2-64", "linux-arm64-t2a-48"] python: ["3.10"] runs-on: ${{ matrix.runner }} From e72f19a2e6b12deedea8555076644dea9bc78de5 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 00:37:26 +0000 Subject: [PATCH 048/205] fix print statemetns and fix comments --- .github/workflows/build_artifacts.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 3140bf60c866..d3a1f71abf98 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -69,7 +69,8 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - id: set-matrix + - name: "Determine the matrix" + id: set-matrix run: | # Define inputs as bash variables to be able to parse them in # if conditions @@ -123,9 +124,9 @@ jobs: echo "python_matrix=[${python_versions[@]}]" >> $GITHUB_OUTPUT echo "platform_matrix=[${platforms[@]}]" >> $GITHUB_OUTPUT - echo "Artifacts: $artifact_matrix" - echo "Python versions: $python_matrix" - echo "Platforms: $platform_matrix" + echo "Artifacts: ${wheels[@]}" + echo "Python versions:${python_versions[@]}" + echo "Platforms: ${platforms[@]}" build_artifacts: needs: determine_matrix @@ -202,7 +203,6 @@ jobs: arch=$(uname -m) echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket - # Upload if requested and one of the artifacts was built if: inputs.upload_artifacts run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ From 4498912851f24bdf4c5ceaa2f26a98526c98ab11 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 00:42:30 +0000 Subject: [PATCH 049/205] enable resuable workflow for pytest gpu and disable build artifact workflows for pr triggers --- .github/workflows/build_artifacts.yml | 6 ++-- .github/workflows/pytest_gpu.yml | 49 ++++++++++----------------- 2 files changed, 21 insertions(+), 34 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index d3a1f71abf98..f61c78805ed5 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -1,9 +1,9 @@ name: Build JAX Artifacts on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index f08601dd1a85..b779f2009300 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -1,9 +1,9 @@ name: Run Pytest GPU tests on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -17,34 +17,16 @@ on: jobs: build_artifacts: - strategy: - matrix: - python: ["3.10"] - - runs-on: "linux-x86-n2-16" - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" - - name: "Pytest GPU (Build wheels on CUDA 12.3)" - env: - JAXCI_CLONE_MAIN_XLA: 1 - JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} + name: "Build the jaxlib and CUDA plugins using latest XLA" + uses: ./.github/workflows/build_artifacts.yml + with: + wheel_list: "jaxlib,jax-cuda-plugin,jax-cuda-pjrt" + python_list: "3.10" + platform_list: "linux_x86" + clone_main_xla: 1 + upload_artifacts: true + upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build jaxlib - run: ./ci/build_artifacts.sh "jaxlib" - - name: Build jax-cuda-plugin - run: ./ci/build_artifacts.sh "jax-cuda-plugin" - - name: Build jax-cuda-pjrt - run: ./ci/build_artifacts.sh "jax-cuda-pjrt" - - name: Upload artifacts to GCS bucket - run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }} - run_tests: needs: build_artifacts strategy: @@ -72,8 +54,13 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Set Platform + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Download the artifacts built in the "build_artifacts" job - run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }} $(pwd)/dist + run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/ - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} From 960796a082e3340f5f1ba55958dd931c4d4b068b Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 01:02:21 +0000 Subject: [PATCH 050/205] Set number of processes to be a multiple of the gpu count --- ci/run_pytest_gpu.sh | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/ci/run_pytest_gpu.sh b/ci/run_pytest_gpu.sh index 0e2bb5f55b84..36925de1e2da 100755 --- a/ci/run_pytest_gpu.sh +++ b/ci/run_pytest_gpu.sh @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# Runs Pyest CPU tests. Requires all jaxlib, jax-cuda-plugin, and jax-cuda-pjrt +# Runs Pyest CPU tests. Requires the jaxlib, jax-cuda-plugin, and jax-cuda-pjrt # wheels to be present inside $JAXCI_OUTPUT_DIR (../dist) # # -e: abort script if one command fails @@ -23,10 +23,11 @@ # -o allexport: export all functions and variables to be available to subscripts set -exu -o history -o allexport -# Inherit default JAXCI environment variables. +# Source default JAXCI environment variables. source ci/envs/default.env -# Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels on the system. +# Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels inside the +# $JAXCI_OUTPUT_DIR directory on the system. echo "Installing wheels locally..." source ./ci/utilities/install_wheels_locally.sh @@ -42,10 +43,14 @@ nvidia-smi export NCCL_DEBUG=WARN export TF_CPP_MIN_LOG_LEVEL=0 +# Set the number of processes to run to be 4x the number of GPUs. +export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +export num_processes=`expr 4 \* $gpu_count` + echo "Running GPU tests..." export XLA_PYTHON_CLIENT_ALLOCATOR=platform export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 -"$JAXCI_PYTHON" -m pytest -n 8 --tb=short --maxfail=20 \ +"$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \ tests examples \ --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ --deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \ From d14a23697922772f7239b4939f39b5dc2fc704c9 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 01:16:11 +0000 Subject: [PATCH 051/205] Set test environment variables in a single section --- ci/run_pytest_cpu.sh | 13 ++++++++----- ci/run_pytest_gpu.sh | 11 +++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh index 2173271ad806..2b19ca5ddaa5 100755 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== # Runs Pyest CPU tests. Requires a jaxlib wheel to be present -# inside $JAXCI_OUTPUT_DIR (../dist) +# inside the $JAXCI_OUTPUT_DIR (../dist) # # -e: abort script if one command fails # -u: error if undefined variable used @@ -23,20 +23,23 @@ # -o allexport: export all functions and variables to be available to subscripts set -exu -o history -o allexport -# Inherit default JAXCI environment variables. +# Source default JAXCI environment variables. source ci/envs/default.env +# Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system. echo "Installing wheels locally..." source ./ci/utilities/install_wheels_locally.sh # Set up the build environment. source "ci/utilities/setup_build_environment.sh" -export PY_COLORS=1 -export JAX_SKIP_SLOW_TESTS=true - "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" +# Set up all test environment variables +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true export TF_CPP_MIN_LOG_LEVEL=0 +# End of test environment variable setup + echo "Running CPU tests..." "$JAXCI_PYTHON" -m pytest -n auto --tb=short --maxfail=20 tests examples \ No newline at end of file diff --git a/ci/run_pytest_gpu.sh b/ci/run_pytest_gpu.sh index 36925de1e2da..7bc2492781b2 100755 --- a/ci/run_pytest_gpu.sh +++ b/ci/run_pytest_gpu.sh @@ -34,12 +34,13 @@ source ./ci/utilities/install_wheels_locally.sh # Set up the build environment. source "ci/utilities/setup_build_environment.sh" -export PY_COLORS=1 -export JAX_SKIP_SLOW_TESTS=true - "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" nvidia-smi + +# Set up all test environment variables +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true export NCCL_DEBUG=WARN export TF_CPP_MIN_LOG_LEVEL=0 @@ -47,9 +48,11 @@ export TF_CPP_MIN_LOG_LEVEL=0 export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) export num_processes=`expr 4 \* $gpu_count` -echo "Running GPU tests..." export XLA_PYTHON_CLIENT_ALLOCATOR=platform export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 +# End of test environment variable setup + +echo "Running GPU tests..." "$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \ tests examples \ --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ From aa01d868665d09e593156372b6b1c579b5113373 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 01:33:17 +0000 Subject: [PATCH 052/205] rework tpu job and scripts to match upstream --- .github/workflows/pytest_tpu.yml | 67 ++++++++++++++++++++------------ ci/run_pytest_tpu.sh | 23 ++++------- 2 files changed, 50 insertions(+), 40 deletions(-) diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 2f97aec4196d..4eb840114bd0 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -1,9 +1,9 @@ name: Run Pytest TPU tests on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -16,19 +16,26 @@ on: - 'no' jobs: - run_tests: + run_tpu_tests: strategy: + fail-fast: false matrix: - runner: ["linux-x86-ct5lp-224-8tpu"] - tpu_cores: ["8"] + jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] + tpu: [ + # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available + {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + ] python: ["3.10"] - runs-on: ${{ matrix.runner }} - container: - image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + runs-on: ${{ matrix.tpu.runner }} + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + + name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: - JAXCI_CLONE_MAIN_XLA: 1 + LIBTPU_OLDEST_VERSION_DATE: 20240722 + ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }} JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} steps: @@ -38,23 +45,33 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build jaxlib - run: ./ci/build_artifacts.sh "jaxlib" - - name: Install pytest - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install pytest - - name: Install Test requirements - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: | - $JAXCI_PYTHON -m pip install -r build/test-requirements.txt - $JAXCI_PYTHON -m pip install -r build/collect-profile-requirements.txt - - name: Install Libtpu + - name: Install JAX test requirements env: JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install --pre libtpu-nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html" + run: | + pip install -U -r build/test-requirements.txt + pip install -U -r build/collect-profile-requirements.txt + - name: Install JAX + run: | + if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then + pip install .[tpu] \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then + pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + pip install --pre libtpu \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + pip install requests + elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then + pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release. + pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + pip install requests + else + echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}" + exit 1 + fi - name: Run Pytest TPU tests env: - JAXCI_TPU_CORES: ${{ matrix.tpu_cores }} + JAXCI_TPU_CORES: ${{ matrix.tpu.cores }} run: ./ci/run_pytest_tpu.sh diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index 92a4baa0fbb1..16ca469c3ea4 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -26,37 +26,30 @@ set -exu -o history -o allexport # Inherit default JAXCI environment variables. source ci/envs/default.env -# Install jaxlib wheel on the system. Requires a jaxlib wheel to be present -# inside $JAXCI_OUTPUT_DIR (../dist) -echo "Installing wheels locally..." -source ./ci/utilities/install_wheels_locally.sh - # Set up the build environment. source "ci/utilities/setup_build_environment.sh" -export PY_COLORS=1 -export JAX_SKIP_SLOW_TESTS=true - "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" - "$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)' "$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)' "$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' strings /usr/local/lib/"$JAXCI_PYTHON"/site-packages/libtpu/libtpu.so | grep 'Built on' "$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)' -echo "Running TPU tests..." +# Set up common test environment variables +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true export JAX_PLATFORMS=tpu,cpu -# Run single-accelerator tests in parallel -export JAX_ENABLE_TPU_XDIST=true +# End of common test environment variable setup -"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ +echo "Running TPU tests..." +# Run single-accelerator tests in parallel +JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ --maxfail=20 -m "not multiaccelerator" tests examples # Run Pallas printing tests, which need to run with I/O capturing disabled. -export TPU_STDERR_LOG_LEVEL=0 -"$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest +TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest # Run multi-accelerator across all chips "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests \ No newline at end of file From 6b7720f4e86a1fffc8e261015cf4880010756020 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 01:41:21 +0000 Subject: [PATCH 053/205] Make the resuable pytest cpu workflow as default --- .github/workflows/pytest_cpu.yml | 39 ++++++++++---- .github/workflows/pytest_cpu_reuse.yml | 74 -------------------------- 2 files changed, 28 insertions(+), 85 deletions(-) delete mode 100644 .github/workflows/pytest_cpu_reuse.yml diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 71463f57d974..6384f6d49714 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -1,9 +1,9 @@ name: Run Pytest CPU tests on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -13,10 +13,22 @@ on: default: 'no' options: - 'yes' - - 'no' jobs: - build: + build_jaxlib_artifact: + name: "Build the jaxlib aritfact using latest XLA" + uses: ./.github/workflows/build_artifacts.yml + with: + wheel_list: "jaxlib" + python_list: "3.10" + platform_list: "linux_x86,linux_arm64" + clone_main_xla: 1 + upload_artifacts: true + upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + + run_pytest: + name: "Run CPU tests with Pytest" + needs: build_jaxlib_artifact continue-on-error: true defaults: run: @@ -24,16 +36,14 @@ jobs: shell: bash strategy: matrix: - runner: ["windows-x86-n2-64", "linux-x86-n2-64", "linux-arm64-t2a-48"] + runner: ["linux-x86-n2-64", "linux-arm64-t2a-48"] python: ["3.10"] runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || - (contains(matrix.runner, 'windows-x86') && null) }} + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }} env: - JAXCI_CLONE_MAIN_XLA: 1 JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} steps: @@ -43,8 +53,15 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build jaxlib - run: ./ci/build_artifacts.sh "jaxlib" + - name: Set Platform + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + - name: Download the artifacts built in the "build_artifacts" job + run: >- + mkdir -p $(pwd)/dist && + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/ - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} diff --git a/.github/workflows/pytest_cpu_reuse.yml b/.github/workflows/pytest_cpu_reuse.yml deleted file mode 100644 index e6ebfb9a6b48..000000000000 --- a/.github/workflows/pytest_cpu_reuse.yml +++ /dev/null @@ -1,74 +0,0 @@ -name: Run Pytest CPU tests (resuable workflow) - -on: - pull_request: - branches: - - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - -jobs: - build_jaxlib_artifact: - name: "Build the jaxlib aritfact using latest XLA" - uses: ./.github/workflows/build_artifacts.yml - with: - wheel_list: "jaxlib" - python_list: "3.10" - platform_list: "linux_x86,linux_arm64" - clone_main_xla: 1 - upload_artifacts: true - upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - - run_pytest: - name: "Run CPU tests with Pytest" - needs: build_jaxlib_artifact - continue-on-error: true - defaults: - run: - # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. - shell: bash - strategy: - matrix: - runner: ["linux-x86-n2-64", "linux-arm64-t2a-48"] - python: ["3.10"] - - runs-on: ${{ matrix.runner }} - container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }} - - env: - JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Set Platform - run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) - echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Download the artifacts built in the "build_artifacts" job - run: >- - mkdir -p $(pwd)/dist && - gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/ - - name: Install pytest - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install pytest - - name: Install dependencies - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install -r build/requirements.in - - name: Run Pytest CPU tests - run: ./ci/run_pytest_cpu.sh From e43a85e5d93102371c8fd132bd57fe341d024c6c Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 01:41:55 +0000 Subject: [PATCH 054/205] Reuse the Build artifact workflow --- .github/workflows/bazel_gpu_non_rbe.yml | 49 +++++++++++++------------ 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index ba11ac486001..500c512d889c 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -1,9 +1,9 @@ name: Run Bazel GPU tests (non RBE) on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -16,14 +16,22 @@ on: - 'no' jobs: - build: - strategy: - matrix: - runner: ["linux-x86-g2-48-l4-4gpu"] - - runs-on: ${{ matrix.runner }} - container: - image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + build_artifacts: + name: "Build the jaxlib and CUDA plugins using latest XLA" + uses: ./.github/workflows/build_artifacts.yml + with: + wheel_list: "jaxlib,jax-cuda-plugin,jax-cuda-pjrt" + python_list: "3.11" + platform_list: "linux_x86" + clone_main_xla: 1 + upload_artifacts: true + upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + + run_bazel_tests: + name: "Run Bazel GPU tests (single accelerator and multi-accelerator tests, non-RBE)" + needs: build_artifacts + runs-on: "linux-x86-g2-48-l4-4gpu" + container: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" env: JAXCI_HERMETIC_PYTHON_VERSION: 3.11 @@ -35,17 +43,12 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build jaxlib - env: - JAXCI_CLONE_MAIN_XLA: 1 - run: ./ci/build_artifacts.sh "jaxlib" - - name: Build jax-cuda-plugin - env: - JAXCI_CLONE_MAIN_XLA: 1 - run: ./ci/build_artifacts.sh "jax-cuda-plugin" - - name: Build jax-cuda-pjrt - env: - JAXCI_CLONE_MAIN_XLA: 1 - run: ./ci/build_artifacts.sh "jax-cuda-pjrt" + - name: Set Platform + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + - name: Download the artifacts built in the "build_artifacts" job + run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/ - name: Run Bazel GPU tests locally run: ./ci/run_bazel_test_gpu_non_rbe.sh From b61691c5573ef7f7b663c635e4edefff76436edb Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 01:55:37 +0000 Subject: [PATCH 055/205] stop building jax and remove python exclude filter for pjrt --- .github/workflows/build_artifacts.yml | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index f61c78805ed5..2bff2e1c3ab1 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -88,7 +88,7 @@ jobs: # Packages that are not supported on a platform won't be built. E.g. CUDA packages won't be # built for Windows if [[ ${is_workflow_call:-"0"} == "0" ]]; then - wheels=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") + wheels=("'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") python_versions=("'3.10'" ", '3.11'" ", '3.12'", ", '3.13'") platforms=("'linux-x86-n2-16'" ", 'linux-arm64-t2a-48'" ", 'windows-x86-n2-64'") else @@ -143,27 +143,10 @@ jobs: # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each # Python version. exclude: - # jax-cuda-pjrt and jax are Python version independent and do not need to be built for - # each Python version - # Exclude jax-cuda-pjrt - - artifact: "jax-cuda-pjrt" - python: "3.11" - - artifact: "jax-cuda-pjrt" - python: "3.12" - - artifact: "jax-cuda-pjrt" - python: "3.13" - # Exclude jax - - artifact: "jax" - python: "3.11" - - artifact: "jax" - python: "3.12" - - artifact: "jax" - python: "3.13" - # jax also only needs to be built once per runner - - artifact: "jax" - runner: "windows-x86-n2-64" - - artifact: "jax" - runner: "linux-arm64-t2a-48" + # jax-cuda-pjrt does not need to be built for every Python but excluding it here for + # every but one Python version causes issues when a workflow call is made to this file + # requesting a build for an exlcuded Python version (see pytest_gpu.yaml) + # # jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows. - artifact: "jax-cuda-plugin" runner: "windows-x86-n2-64" From 2af5f0cbc0cf685890c6312a2eb1bf2db0e5246a Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 02:09:18 +0000 Subject: [PATCH 056/205] enable windows runners and run for all python versions --- .github/workflows/pytest_cpu.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 6384f6d49714..2b7208e74a3e 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -20,8 +20,8 @@ jobs: uses: ./.github/workflows/build_artifacts.yml with: wheel_list: "jaxlib" - python_list: "3.10" - platform_list: "linux_x86,linux_arm64" + python_list: "3.10,3.11,3.12,3.13" + platform_list: "linux_x86,linux_arm64,windows_x86" clone_main_xla: 1 upload_artifacts: true upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' @@ -36,12 +36,13 @@ jobs: shell: bash strategy: matrix: - runner: ["linux-x86-n2-64", "linux-arm64-t2a-48"] + runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] python: ["3.10"] runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }} + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(matrix.runner, 'windows-x86') && null) }} env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} From 75695a0b1adc1b6078cd8a3bd1ae2bcaf8bfd1ce Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 02:15:26 +0000 Subject: [PATCH 057/205] Add all Python versions to run_pytest job --- .github/workflows/pytest_cpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 2b7208e74a3e..2c8aae496139 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -37,7 +37,7 @@ jobs: strategy: matrix: runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] - python: ["3.10"] + python: ["3.10", "3.11", "3.12", "3.13"] runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || From bd34770fef4044095d3a9e6991926589d56b1343 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 02:17:09 +0000 Subject: [PATCH 058/205] update docker image to one with gsutil --- .github/workflows/bazel_gpu_non_rbe.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 500c512d889c..daa093435316 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -31,7 +31,7 @@ jobs: name: "Run Bazel GPU tests (single accelerator and multi-accelerator tests, non-RBE)" needs: build_artifacts runs-on: "linux-x86-g2-48-l4-4gpu" - container: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" env: JAXCI_HERMETIC_PYTHON_VERSION: 3.11 From 68ea9eb508cfee4647198fd1b759c438b00c2055 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 16:52:52 +0000 Subject: [PATCH 059/205] Use the correct Python binary to install deps --- .github/workflows/pytest_tpu.yml | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 4eb840114bd0..6c98487f227c 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -37,40 +37,39 @@ jobs: LIBTPU_OLDEST_VERSION_DATE: 20240722 ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }} JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} + JAXCI_PYTHON: python${{ matrix.python }} steps: - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Install JAX test requirements - env: - JAXCI_PYTHON: python${{ matrix.python }} run: | - pip install -U -r build/test-requirements.txt - pip install -U -r build/collect-profile-requirements.txt + $JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt + $JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt - name: Install JAX run: | if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then - pip install .[tpu] \ + $JAXCI_PYTHON -m pip install .[tpu] \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then - pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - pip install --pre libtpu \ + $JAXCI_PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + $JAXCI_PYTHON -m pip install --pre libtpu \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install requests + $JAXCI_PYTHON -m pip install requests elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then - pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + $JAXCI_PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release. - pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ + $JAXCI_PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install requests + $JAXCI_PYTHON -m pip install requests else echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}" exit 1 fi + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest TPU tests env: JAXCI_TPU_CORES: ${{ matrix.tpu.cores }} From bb1130c57c5345e8f9b85dc703ffc90db0512e81 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 17:26:27 +0000 Subject: [PATCH 060/205] update site-packages to dist-packages --- ci/run_pytest_tpu.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index 16ca469c3ea4..8642c662043b 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -33,7 +33,7 @@ source "ci/utilities/setup_build_environment.sh" "$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)' "$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)' "$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' -strings /usr/local/lib/"$JAXCI_PYTHON"/site-packages/libtpu/libtpu.so | grep 'Built on' +strings /usr/local/lib/"$JAXCI_PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on' "$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)' # Set up common test environment variables From 3ad80f19dabf508428816972ca2df39f5fa3744d Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 17:41:44 +0000 Subject: [PATCH 061/205] change input name to upload_destination_prefix --- .github/workflows/bazel_gpu_non_rbe.yml | 2 +- .github/workflows/build_artifacts.yml | 6 +++--- .github/workflows/pytest_cpu.yml | 2 +- .github/workflows/pytest_gpu.yml | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index daa093435316..3757daf693f8 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -25,7 +25,7 @@ jobs: platform_list: "linux_x86" clone_main_xla: 1 upload_artifacts: true - upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' run_bazel_tests: name: "Run Bazel GPU tests (single accelerator and multi-accelerator tests, non-RBE)" diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 2bff2e1c3ab1..a2d0f8f3828f 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -41,8 +41,8 @@ on: required: false default: false type: boolean - upload_destination: - description: "GCS location to where the artifacts should be uploaded" + upload_destination_prefix: + description: "GCS location prefix to where the artifacts should be uploaded" required: false default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string @@ -187,5 +187,5 @@ jobs: echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket if: inputs.upload_artifacts - run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ + run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 2c8aae496139..7c2e734d3c8a 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -24,7 +24,7 @@ jobs: platform_list: "linux_x86,linux_arm64,windows_x86" clone_main_xla: 1 upload_artifacts: true - upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' run_pytest: name: "Run CPU tests with Pytest" diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index b779f2009300..25118637eb39 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -25,7 +25,7 @@ jobs: platform_list: "linux_x86" clone_main_xla: 1 upload_artifacts: true - upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' run_tests: needs: build_artifacts From 58ca848eb622855122b67c57c132fbd5b21c4dde Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 17:45:48 +0000 Subject: [PATCH 062/205] remove unused env var --- ci/envs/default.env | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ci/envs/default.env b/ci/envs/default.env index 0b48e86935e3..b62773266c37 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -44,7 +44,6 @@ export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} # ############################################################################# # Docker specific environment variables. # ############################################################################# - # Docker specifc environment variables. Used by `run_docker_container.sh` export JAXCI_DOCKER_WORK_DIR="/jax" export JAXCI_DOCKER_IMAGE="" @@ -57,9 +56,6 @@ export JAXCI_DOCKER_ARGS="" # defined in the TPU GitHub Actions workflow. export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} -# Set when running tests that depend on wheels installed on the system. -export JAXCI_INSTALL_WHEELS_LOCALLY=0 - # JAXCI_PYTHON is used to install the wheels locally. It needs to match the # version of the hermetic Python used by Bazel. export JAXCI_PYTHON=python${JAXCI_HERMETIC_PYTHON_VERSION} From 6fef417c1135b00e31f83a7dc47e8dc2a469afbc Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 18:42:55 +0000 Subject: [PATCH 063/205] adjust os name when running on windows runners --- .github/workflows/bazel_gpu_non_rbe.yml | 6 ++++++ .github/workflows/build_artifacts.yml | 9 +++++++-- .github/workflows/pytest_cpu.yml | 6 ++++++ .github/workflows/pytest_gpu.yml | 7 +++++++ 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 3757daf693f8..8b7635aef3aa 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -47,6 +47,12 @@ jobs: run: | os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) + + # Adjust name for Windows + if [[ $os =~ "msys_nt" ]]; then + os="windows" + fi + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Download the artifacts built in the "build_artifacts" job run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/ diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index a2d0f8f3828f..bee081100fbf 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -140,8 +140,6 @@ jobs: runner: ${{ fromJSON(needs.determine_matrix.outputs.platform_matrix) }} artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} python: ${{ fromJSON(needs.determine_matrix.outputs.python_matrix) }} - # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each - # Python version. exclude: # jax-cuda-pjrt does not need to be built for every Python but excluding it here for # every but one Python version causes issues when a workflow call is made to this file @@ -184,6 +182,13 @@ jobs: run: | os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) + + # Adjust name for Windows + if [[ $os =~ "msys_nt" ]]; then + os="windows" + + fi + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket if: inputs.upload_artifacts diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 7c2e734d3c8a..4a52e68c9f2a 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -58,6 +58,12 @@ jobs: run: | os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) + + # Adjust name for Windows + if [[ $os =~ "msys_nt" ]]; then + os="windows" + fi + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Download the artifacts built in the "build_artifacts" job run: >- diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index 25118637eb39..99a270910d80 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -58,6 +58,13 @@ jobs: run: | os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) + + # Adjust name for Windows + if [[ $os =~ "msys_nt" ]]; then + os="windows" + + fi + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Download the artifacts built in the "build_artifacts" job run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/ From 191389b9b60f80b3a71e2fc5faf5d8cadf813bd7 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 20:24:48 +0000 Subject: [PATCH 064/205] add concurrency settings and remove condition that disables running on forks --- .github/workflows/bazel_cpu_rbe.yml | 13 ++++++++----- .github/workflows/bazel_gpu_non_rbe.yml | 4 ++++ .github/workflows/bazel_gpu_rbe.yml | 7 +++---- .github/workflows/pytest_cpu.yml | 4 ++++ .github/workflows/pytest_gpu.yml | 4 ++++ .github/workflows/pytest_tpu.yml | 4 ++++ 6 files changed, 27 insertions(+), 9 deletions(-) diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index 877798797754..c5f8fbf07d7d 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -1,9 +1,9 @@ name: CI - Bazel CPU tests (RBE) on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -21,10 +21,13 @@ concurrency: jobs: run_tests: - if: github.event.repository.fork == false + defaults: + run: + # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. + shell: bash strategy: matrix: - runner: ["windows-x86-n2-64"] #, "linux-x86-n2-16", "linux-arm64-t2a-16"] + runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 8b7635aef3aa..ce432ffadb6a 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -15,6 +15,10 @@ on: - 'yes' - 'no' +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + jobs: build_artifacts: name: "Build the jaxlib and CUDA plugins using latest XLA" diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml index 93ce8492f245..b8bc20e24e97 100644 --- a/.github/workflows/bazel_gpu_rbe.yml +++ b/.github/workflows/bazel_gpu_rbe.yml @@ -1,9 +1,9 @@ name: CI - Bazel GPU tests (RBE) on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -21,7 +21,6 @@ concurrency: jobs: run_tests: - if: github.event.repository.fork == false strategy: matrix: runner: ["linux-x86-n2-16"] diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 4a52e68c9f2a..c37a34676307 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -14,6 +14,10 @@ on: options: - 'yes' +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + jobs: build_jaxlib_artifact: name: "Build the jaxlib aritfact using latest XLA" diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index 99a270910d80..a5152045d654 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -15,6 +15,10 @@ on: - 'yes' - 'no' +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + jobs: build_artifacts: name: "Build the jaxlib and CUDA plugins using latest XLA" diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 6c98487f227c..7162e3eefae5 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -15,6 +15,10 @@ on: - 'yes' - 'no' +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + jobs: run_tpu_tests: strategy: From f9165b1347119d50135fd22b98324fb3b897c41c Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 20 Nov 2024 20:49:14 +0000 Subject: [PATCH 065/205] replace continue-on-error with fail-fast --- .github/workflows/bazel_cpu_rbe.yml | 1 + .github/workflows/build_artifacts.yml | 2 +- .github/workflows/pytest_cpu.yml | 2 +- .github/workflows/pytest_gpu.yml | 1 + .github/workflows/pytest_tpu.yml | 2 +- 5 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index c5f8fbf07d7d..2b7648745c5c 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -26,6 +26,7 @@ jobs: # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. shell: bash strategy: + fail-fast: false # don't cancel all jobs on failure matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index bee081100fbf..182c74d427dd 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -130,12 +130,12 @@ jobs: build_artifacts: needs: determine_matrix - continue-on-error: true defaults: run: # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. shell: bash strategy: + fail-fast: false # don't cancel all jobs on failure matrix: runner: ${{ fromJSON(needs.determine_matrix.outputs.platform_matrix) }} artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index c37a34676307..d730c50d1008 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -33,12 +33,12 @@ jobs: run_pytest: name: "Run CPU tests with Pytest" needs: build_jaxlib_artifact - continue-on-error: true defaults: run: # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. shell: bash strategy: + fail-fast: false # don't cancel all jobs on failure matrix: runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] python: ["3.10", "3.11", "3.12", "3.13"] diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index a5152045d654..3e909e4d380b 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -34,6 +34,7 @@ jobs: run_tests: needs: build_artifacts strategy: + fail-fast: false # don't cancel all jobs on failure matrix: test_env: [ {cuda_version: "12.3", runner: "linux-x86-g2-48-l4-4gpu", diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 7162e3eefae5..88825d52aaef 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -22,7 +22,7 @@ concurrency: jobs: run_tpu_tests: strategy: - fail-fast: false + fail-fast: false # don't cancel all jobs on failure matrix: jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ From dc9dec4ba26b2453998bb073b34b56e4359ec51f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 22 Nov 2024 21:14:28 +0000 Subject: [PATCH 066/205] move log_stream outside of class --- build/tools/command.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/build/tools/command.py b/build/tools/command.py index 22136a23a54b..48a9bfc1c0d6 100644 --- a/build/tools/command.py +++ b/build/tools/command.py @@ -50,6 +50,18 @@ class CommandResult: ) end_time: Optional[datetime.datetime] = None + +async def _process_log_stream(stream, result: CommandResult): + """Logs the output of a subprocess stream.""" + while True: + line_bytes = await stream.readline() + if not line_bytes: + break + line = line_bytes.decode().rstrip() + result.logs += line + logger.info("%s", line) + + class SubprocessExecutor: """ Manages execution of subprocess commands with reusable environment and logging. @@ -89,17 +101,8 @@ async def run(self, cmd: str, dry_run: bool = False) -> CommandResult: env=self.environment, ) - async def log_stream(stream, result: CommandResult): - while True: - line_bytes = await stream.readline() - if not line_bytes: - break - line = line_bytes.decode().rstrip() - result.logs += line - logger.info("%s", line) - await asyncio.gather( - log_stream(process.stdout, result), log_stream(process.stderr, result) + _process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result) ) result.return_code = await process.wait() From 8857b65ce82110f0219cc0a4b65b775c1e924ca4 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 25 Nov 2024 19:11:20 +0000 Subject: [PATCH 067/205] add the use_clang flag --- build/build.py | 50 ++++++++++++++++++++++++++++++++------------ build/tools/utils.py | 23 ++++++++++++++++++++ 2 files changed, 60 insertions(+), 13 deletions(-) diff --git a/build/build.py b/build/build.py index 10bb96aecc12..e637c2223c56 100755 --- a/build/build.py +++ b/build/build.py @@ -245,7 +245,19 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): ) # Compile Options - compile_group = parser.add_argument_group('Compile Options', ) + compile_group = parser.add_argument_group('Compile Options') + + compile_group.add_argument( + "--use_clang", + type=utils._parse_string_as_bool, + default="true", + const=True, + nargs="?", + help=""" + Whether to use Clang as the compiler. Not recommended to set this to + False as JAX uses Clang as the default compiler. + """, + ) compile_group.add_argument( "--clang_path", @@ -304,7 +316,7 @@ async def main(): formatter_class=argparse.RawDescriptionHelpFormatter ) - # Create subparsers for build_artifacts and requirements_update + # Create subparsers for build and requirements_update subparsers = parser.add_subparsers(dest="command", required=True) # requirements_update subcommand @@ -437,13 +449,22 @@ async def main(): arch, ) - clang_path = args.clang_path or utils.get_clang_path_or_exit() - logging.debug("Using Clang as the compiler, clang path: %s", clang_path) + clang_path = "" + if args.use_clang: + clang_path = args.clang_path or utils.get_clang_path_or_exit() + clang_major_version = utils.get_clang_major_version(clang_path) + logging.debug( + "Using Clang as the compiler, clang path: %s, clang version: %s", + clang_path, + clang_major_version, + ) - # Use double quotes around clang path to avoid path issues on Windows. - wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") - wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"") - wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") + # Use double quotes around clang path to avoid path issues on Windows. + wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") + else: + logging.debug("Use Clang: False") # Do not apply --config=clang on Mac as these settings do not apply to # Apple Clang. @@ -455,11 +476,11 @@ async def main(): wheel_build_command.append("--config=mkl_open_source_only") if args.target_cpu_features == "release": - logging.debug( - "Using release cpu features: --config=avx_%s", - "windows" if os_name == "windows" else "posix", - ) if arch in ["x86_64", "AMD64"]: + logging.debug( + "Using release cpu features: --config=avx_%s", + "windows" if os_name == "windows" else "posix", + ) wheel_build_command.append( "--config=avx_windows" if os_name == "windows" @@ -509,7 +530,10 @@ async def main(): ) if "rocm" in wheel: - wheel_build_command.append("--config=rocm") + wheel_build_command.append("--config=rocm_base") + if args.use_clang: + wheel_build_command.append("--config=rocm") + wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") if args.rocm_path: logging.debug("ROCm tookit path: %s", args.rocm_path) wheel_build_command.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") diff --git a/build/tools/utils.py b/build/tools/utils.py index 8fa29e8d5c7c..db4541d0ee96 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -189,6 +189,19 @@ def get_clang_path_or_exit(): sys.exit(-1) +def get_clang_major_version(clang_path): + clang_version_proc = subprocess.run( + [clang_path, "-E", "-P", "-"], + input="__clang_major__", + check=True, + capture_output=True, + text=True, + ) + major_version = int(clang_version_proc.stdout) + + return major_version + + def get_jax_configure_bazel_options(bazel_command: list[str]): """Returns the bazel options to be written to .jax_configure.bazelrc.""" # Get the index of the "run" parameter. Build options will come after "run" so @@ -219,3 +232,13 @@ def get_githash(): ).stdout.strip() except OSError: return "" + +def _parse_string_as_bool(s): + """Parses a string as a boolean value.""" + lower = s.lower() + if lower == "true": + return True + elif lower == "false": + return False + else: + raise ValueError(f"Expected either 'true' or 'false'; got {s}") \ No newline at end of file From a2632fd9803e8f3bcf7986e7c77874d7f2999f14 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 25 Nov 2024 19:28:51 +0000 Subject: [PATCH 068/205] enable the build artifacts workflow --- .github/workflows/build_artifacts.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 182c74d427dd..d891e1ca7e73 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -1,9 +1,9 @@ name: Build JAX Artifacts on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: From 482635c64238c96109d467f174c7a7c24376bfcc Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 17:15:11 +0000 Subject: [PATCH 069/205] debug asan workflow --- .github/workflows/asan.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index ea87d4e29e40..698022f98177 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -59,13 +59,17 @@ jobs: source ${GITHUB_WORKSPACE}/venv/bin/activate cd jax pip install -r build/test-requirements.txt + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main - name: Build and install JAX env: ASAN_OPTIONS: detect_leaks=0 run: | source ${GITHUB_WORKSPACE}/venv/bin/activate cd jax - python build/build.py \ + git config --global --add safe.directory jax + python build/build.py build --wheels=jaxlib --verbose \ --bazel_options=--color=yes \ --bazel_options=--copt=-fsanitize=address \ --clang_path=/usr/bin/clang-18 From a7fa98f13ff8a8ec98769371ee7bb128f47473fa Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 20:51:10 +0000 Subject: [PATCH 070/205] add a check for return code --- build/build.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/build/build.py b/build/build.py index e637c2223c56..799b172fc7cd 100755 --- a/build/build.py +++ b/build/build.py @@ -293,7 +293,7 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): compile_group.add_argument( "--target_cpu", default=None, - help="CPU platform to target. Default is the same as the host machine. ", + help="CPU platform to target. Default is the same as the host machine.", ) compile_group.add_argument( @@ -399,8 +399,11 @@ async def main(): else: requirements_command.append("//build:requirements.update") - await executor.run(requirements_command.get_command_as_string(), args.dry_run) - sys.exit(0) + result = await executor.run(requirements_command.get_command_as_string(), args.dry_run) + if result.return_code != 0: + raise RuntimeError(f"Command failed with return code {result.return_code}") + else: + sys.exit(0) wheel_cpus = { "darwin_arm64": "arm64", @@ -594,8 +597,12 @@ async def main(): wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") - await executor.run(wheel_build_command.get_command_as_string(), args.dry_run) + result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run) + if result.return_code != 0: + raise RuntimeError(f"Command failed with return code {result.return_code}") + else: + sys.exit(0) if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(main()) \ No newline at end of file From c41f1c26b06d8db981a683d42e7f4c60c683bb04 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 21:27:45 +0000 Subject: [PATCH 071/205] add a ci workflow to debug --- .github/workflows/ci_duplicate.yml | 41 ++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 .github/workflows/ci_duplicate.yml diff --git a/.github/workflows/ci_duplicate.yml b/.github/workflows/ci_duplicate.yml new file mode 100644 index 000000000000..b71b120ebcdc --- /dev/null +++ b/.github/workflows/ci_duplicate.yml @@ -0,0 +1,41 @@ +name: CI (duplicate) + +on: + pull_request: + branches: + - main + +permissions: + contents: read # to fetch code + actions: write # to cancel previous workflows + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + lint_and_typecheck: + runs-on: linux-x86-n2-16 + container: + image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 + timeout-minutes: 5 + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Install dependencies for setting up Python + env: + DEBIAN_FRONTEND: noninteractive + run: | + apt update && apt install libssl-dev libsqlite3-dev git -y + - name: Set up Python 3.11 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: 3.11 + - run: python -m pip install pre-commit + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + with: + path: ~/.cache/pre-commit + key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} + - run: pre-commit run --show-diff-on-failure --color=always --all-files \ No newline at end of file From 0603e2524c474f62e6d05305fab40594fec41fe8 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 21:47:07 +0000 Subject: [PATCH 072/205] add a python3 binary --- .github/workflows/asan.yaml | 42 +++++++++++++++++------------- .github/workflows/ci_duplicate.yml | 2 ++ 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 698022f98177..33727c3ab881 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -2,7 +2,7 @@ name: CI - Address Sanitizer (nightly) concurrency: group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + #cancel-in-progress: true on: schedule: @@ -42,23 +42,29 @@ jobs: zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ libffi-dev liblzma-dev - - name: Build CPython with ASAN enabled - env: - ASAN_OPTIONS: detect_leaks=0 - run: | - cd cpython - mkdir ${GITHUB_WORKSPACE}/cpythonasan - CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpythonasan --with-address-sanitizer --without-pymalloc - make -j64 - make install - ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv - - name: Install JAX test requirements - env: - ASAN_OPTIONS: detect_leaks=0 - run: | - source ${GITHUB_WORKSPACE}/venv/bin/activate - cd jax - pip install -r build/test-requirements.txt + # - name: Build CPython with ASAN enabled + # env: + # ASAN_OPTIONS: detect_leaks=0 + # run: | + # cd cpython + # mkdir ${GITHUB_WORKSPACE}/cpythonasan + # CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpythonasan --with-address-sanitizer --without-pymalloc + # make -j64 + # make install + # ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv + # - name: Install JAX test requirements + # env: + # ASAN_OPTIONS: detect_leaks=0 + # run: | + # source ${GITHUB_WORKSPACE}/venv/bin/activate + # cd jax + # pip install -r build/test-requirements.txt + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Create python3 symlink + run: ln -s /__w/_tool/Python/3.11.10/x64/bin/python3 /usr/local/bin/python3 # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main diff --git a/.github/workflows/ci_duplicate.yml b/.github/workflows/ci_duplicate.yml index b71b120ebcdc..ae619855f5d6 100644 --- a/.github/workflows/ci_duplicate.yml +++ b/.github/workflows/ci_duplicate.yml @@ -30,6 +30,8 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: 3.11 + - name: Create python3 symlink + run: ln -s /__w/_tool/Python/3.11.10/x64/bin/python3 /usr/local/bin/python3 - run: python -m pip install pre-commit # Halt for testing - name: Wait For Connection From c9f5b862e10c09383d451840ab16c532a939f1fa Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 21:57:30 +0000 Subject: [PATCH 073/205] add python binary to path --- .github/workflows/asan.yaml | 4 ++-- .github/workflows/ci_duplicate.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 33727c3ab881..e98e0be04e67 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -63,8 +63,8 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.11" - - name: Create python3 symlink - run: ln -s /__w/_tool/Python/3.11.10/x64/bin/python3 /usr/local/bin/python3 + - name: Add to PATH + run: export PATH=$PATH:/__w/_tool/Python/3.11.10/x64/bin/ # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main diff --git a/.github/workflows/ci_duplicate.yml b/.github/workflows/ci_duplicate.yml index ae619855f5d6..a052da737cc6 100644 --- a/.github/workflows/ci_duplicate.yml +++ b/.github/workflows/ci_duplicate.yml @@ -30,9 +30,9 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: 3.11 - - name: Create python3 symlink - run: ln -s /__w/_tool/Python/3.11.10/x64/bin/python3 /usr/local/bin/python3 - run: python -m pip install pre-commit + - name: Add to PATH + run: export PATH=$PATH:/__w/_tool/Python/3.11.10/x64/bin/ # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main From eae1124aeb5828eb0210af24e302c213d8e354bc Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 22:13:27 +0000 Subject: [PATCH 074/205] change to v4 checkout --- .github/workflows/asan.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index e98e0be04e67..38ef63ddacdb 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -25,7 +25,7 @@ jobs: run: shell: bash -l {0} steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@v4 # v4.2.2 with: path: jax - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 From 9625f1c2223a6bf2234ec955ebe7e14d94d5e49f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 22:20:37 +0000 Subject: [PATCH 075/205] change the fetch depth --- .github/workflows/asan.yaml | 5 +---- .github/workflows/ci_duplicate.yml | 4 ---- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 38ef63ddacdb..2d855c04183e 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -1,9 +1,5 @@ name: CI - Address Sanitizer (nightly) -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - #cancel-in-progress: true - on: schedule: - cron: "0 12 * * *" # Daily at 12:00 UTC @@ -27,6 +23,7 @@ jobs: steps: - uses: actions/checkout@v4 # v4.2.2 with: + fetch-depth: 0 path: jax - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: diff --git a/.github/workflows/ci_duplicate.yml b/.github/workflows/ci_duplicate.yml index a052da737cc6..3015408c8b5f 100644 --- a/.github/workflows/ci_duplicate.yml +++ b/.github/workflows/ci_duplicate.yml @@ -9,10 +9,6 @@ permissions: contents: read # to fetch code actions: write # to cancel previous workflows -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - jobs: lint_and_typecheck: runs-on: linux-x86-n2-16 From 93b724a5c1f3e076f4310cfa9860f02930f01f9a Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 22:31:34 +0000 Subject: [PATCH 076/205] debug --- .github/workflows/asan.yaml | 5 +---- .github/workflows/ci_duplicate.yml | 9 ++------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 2d855c04183e..60882b8f965a 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -21,10 +21,7 @@ jobs: run: shell: bash -l {0} steps: - - uses: actions/checkout@v4 # v4.2.2 - with: - fetch-depth: 0 - path: jax + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: python/cpython diff --git a/.github/workflows/ci_duplicate.yml b/.github/workflows/ci_duplicate.yml index 3015408c8b5f..4fba0fd0d435 100644 --- a/.github/workflows/ci_duplicate.yml +++ b/.github/workflows/ci_duplicate.yml @@ -17,15 +17,10 @@ jobs: timeout-minutes: 5 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install dependencies for setting up Python - env: - DEBIAN_FRONTEND: noninteractive - run: | - apt update && apt install libssl-dev libsqlite3-dev git -y - name: Set up Python 3.11 - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + uses: actions/setup-python@v5 with: - python-version: 3.11 + python-version: "3.11" - run: python -m pip install pre-commit - name: Add to PATH run: export PATH=$PATH:/__w/_tool/Python/3.11.10/x64/bin/ From 69d963eae7f2b8d27038e792074ce46741649179 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 22:32:29 +0000 Subject: [PATCH 077/205] disable workflows --- .github/workflows/bazel_cpu_rbe.yml | 6 +++--- .github/workflows/bazel_gpu_non_rbe.yml | 6 +++--- .github/workflows/bazel_gpu_rbe.yml | 6 +++--- .github/workflows/pytest_cpu.yml | 6 +++--- .github/workflows/pytest_gpu.yml | 6 +++--- .github/workflows/pytest_tpu.yml | 6 +++--- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index 2b7648745c5c..3186997d40cd 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -1,9 +1,9 @@ name: CI - Bazel CPU tests (RBE) on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index ce432ffadb6a..89b69483626b 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -1,9 +1,9 @@ name: Run Bazel GPU tests (non RBE) on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml index b8bc20e24e97..8c14c4eab7da 100644 --- a/.github/workflows/bazel_gpu_rbe.yml +++ b/.github/workflows/bazel_gpu_rbe.yml @@ -1,9 +1,9 @@ name: CI - Bazel GPU tests (RBE) on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index d730c50d1008..2dd5e0ed2f21 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -1,9 +1,9 @@ name: Run Pytest CPU tests on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index 3e909e4d380b..15f82fd97c8e 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -1,9 +1,9 @@ name: Run Pytest GPU tests on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 88825d52aaef..d5b4816bc5f8 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -1,9 +1,9 @@ name: Run Pytest TPU tests on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: From 3a885465f9cd1cfc4a2b02dc5cc27af73a40e707 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 22:41:44 +0000 Subject: [PATCH 078/205] use ml build image --- .github/workflows/asan.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 60882b8f965a..4280a78b06f3 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -14,7 +14,7 @@ jobs: asan: runs-on: linux-x86-n2-64 container: - image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 + image: us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest # ratchet:ubuntu:24.04 strategy: fail-fast: false defaults: @@ -22,6 +22,8 @@ jobs: shell: bash -l {0} steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: jax - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: python/cpython From dc3070a4eb7b62572c1ee87babbbbec2d1fe03d3 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 22:42:28 +0000 Subject: [PATCH 079/205] add deps --- .github/workflows/ci_duplicate.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/ci_duplicate.yml b/.github/workflows/ci_duplicate.yml index 4fba0fd0d435..60d74fde376b 100644 --- a/.github/workflows/ci_duplicate.yml +++ b/.github/workflows/ci_duplicate.yml @@ -17,6 +17,11 @@ jobs: timeout-minutes: 5 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Install dependencies for setting up Python + env: + DEBIAN_FRONTEND: noninteractive + run: | + apt update && apt install libssl-dev libsqlite3-dev git -y - name: Set up Python 3.11 uses: actions/setup-python@v5 with: From 643996ab1713ca8afce9b442b6607f24e865eb1f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 22:45:34 +0000 Subject: [PATCH 080/205] install git before actions/checkout --- .github/workflows/asan.yaml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 4280a78b06f3..aefc37fd10bb 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -14,21 +14,13 @@ jobs: asan: runs-on: linux-x86-n2-64 container: - image: us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest # ratchet:ubuntu:24.04 + image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 strategy: fail-fast: false defaults: run: shell: bash -l {0} steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - path: jax - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: python/cpython - path: cpython - ref: v3.13.0 - name: Install clang 18 env: DEBIAN_FRONTEND: noninteractive @@ -38,6 +30,14 @@ jobs: zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ libffi-dev liblzma-dev + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: jax + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: python/cpython + path: cpython + ref: v3.13.0 # - name: Build CPython with ASAN enabled # env: # ASAN_OPTIONS: detect_leaks=0 From 1b81bc380f86d0ce590c11aaa46c202e2b0679d2 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 23:08:41 +0000 Subject: [PATCH 081/205] revert other changes --- .github/workflows/asan.yaml | 43 +++++++++++++++---------------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index aefc37fd10bb..8b4451c4c417 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -38,32 +38,23 @@ jobs: repository: python/cpython path: cpython ref: v3.13.0 - # - name: Build CPython with ASAN enabled - # env: - # ASAN_OPTIONS: detect_leaks=0 - # run: | - # cd cpython - # mkdir ${GITHUB_WORKSPACE}/cpythonasan - # CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpythonasan --with-address-sanitizer --without-pymalloc - # make -j64 - # make install - # ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv - # - name: Install JAX test requirements - # env: - # ASAN_OPTIONS: detect_leaks=0 - # run: | - # source ${GITHUB_WORKSPACE}/venv/bin/activate - # cd jax - # pip install -r build/test-requirements.txt - - name: Set up Python 3.11 - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - name: Add to PATH - run: export PATH=$PATH:/__w/_tool/Python/3.11.10/x64/bin/ - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + - name: Build CPython with ASAN enabled + env: + ASAN_OPTIONS: detect_leaks=0 + run: | + cd cpython + mkdir ${GITHUB_WORKSPACE}/cpythonasan + CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpythonasan --with-address-sanitizer --without-pymalloc + make -j64 + make install + ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv + - name: Install JAX test requirements + env: + ASAN_OPTIONS: detect_leaks=0 + run: | + source ${GITHUB_WORKSPACE}/venv/bin/activate + cd jax + pip install -r build/test-requirements.txt - name: Build and install JAX env: ASAN_OPTIONS: detect_leaks=0 From d0a2664c8c0fa785f81cba45b44ade3ea048507c Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 23:13:53 +0000 Subject: [PATCH 082/205] try a fix for path --- .github/workflows/ci_duplicate.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_duplicate.yml b/.github/workflows/ci_duplicate.yml index 60d74fde376b..bc8329f14ecf 100644 --- a/.github/workflows/ci_duplicate.yml +++ b/.github/workflows/ci_duplicate.yml @@ -28,7 +28,7 @@ jobs: python-version: "3.11" - run: python -m pip install pre-commit - name: Add to PATH - run: export PATH=$PATH:/__w/_tool/Python/3.11.10/x64/bin/ + run: export PATH=$PATH:/__w/_tool/Python/3.11.10/x64/bin # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main From 204c7cf641bffa6d69819092c8ea2296e5889320 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 23:21:18 +0000 Subject: [PATCH 083/205] try a fix for path --- .github/workflows/ci_duplicate.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci_duplicate.yml b/.github/workflows/ci_duplicate.yml index bc8329f14ecf..9b704470c019 100644 --- a/.github/workflows/ci_duplicate.yml +++ b/.github/workflows/ci_duplicate.yml @@ -29,8 +29,12 @@ jobs: - run: python -m pip install pre-commit - name: Add to PATH run: export PATH=$PATH:/__w/_tool/Python/3.11.10/x64/bin + - name: echo PATH + run: echo $PATH # Halt for testing - name: Wait For Connection + env: + PATH: '$PATH:/__w/_tool/Python/3.11.10/x64/bin' uses: google-ml-infra/actions/ci_connection@main - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: From 630a1dfe6f2b6a8c697ca9d375a5ea8d8b9be41a Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 23:25:49 +0000 Subject: [PATCH 084/205] try a fix for path --- .github/workflows/ci_duplicate.yml | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci_duplicate.yml b/.github/workflows/ci_duplicate.yml index 9b704470c019..aca9d0aecce2 100644 --- a/.github/workflows/ci_duplicate.yml +++ b/.github/workflows/ci_duplicate.yml @@ -16,12 +16,12 @@ jobs: image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 timeout-minutes: 5 steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install dependencies for setting up Python env: - DEBIAN_FRONTEND: noninteractive + DEBIAN_FRONTEND: noninteractive run: | - apt update && apt install libssl-dev libsqlite3-dev git -y + apt update && apt install libssl-dev libsqlite3-dev git -y + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python 3.11 uses: actions/setup-python@v5 with: @@ -29,12 +29,10 @@ jobs: - run: python -m pip install pre-commit - name: Add to PATH run: export PATH=$PATH:/__w/_tool/Python/3.11.10/x64/bin - - name: echo PATH - run: echo $PATH + - name: Create python3 symlink + run: ln -s /__w/_tool/Python/3.11.10/x64/bin/python3 /usr/local/bin/python3 # Halt for testing - name: Wait For Connection - env: - PATH: '$PATH:/__w/_tool/Python/3.11.10/x64/bin' uses: google-ml-infra/actions/ci_connection@main - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: From 68bf70f67467afe4c2514d3e1642ee4a502402cf Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 00:23:55 +0000 Subject: [PATCH 085/205] try a fix for path --- .github/workflows/asan.yaml | 16 ++++++++-------- .github/workflows/ci_duplicate.yml | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 8b4451c4c417..88202277eb32 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -1,14 +1,14 @@ name: CI - Address Sanitizer (nightly) on: - schedule: - - cron: "0 12 * * *" # Daily at 12:00 UTC - workflow_dispatch: # allows triggering the workflow run manually - pull_request: # Automatically trigger on pull requests affecting this file - branches: - - main - paths: - - '**/workflows/asan.yaml' +# schedule: +# - cron: "0 12 * * *" # Daily at 12:00 UTC +# workflow_dispatch: # allows triggering the workflow run manually +# pull_request: # Automatically trigger on pull requests affecting this file +# branches: +# - main +# paths: +# - '**/workflows/asan.yaml' jobs: asan: diff --git a/.github/workflows/ci_duplicate.yml b/.github/workflows/ci_duplicate.yml index aca9d0aecce2..87c8992cff7e 100644 --- a/.github/workflows/ci_duplicate.yml +++ b/.github/workflows/ci_duplicate.yml @@ -29,8 +29,8 @@ jobs: - run: python -m pip install pre-commit - name: Add to PATH run: export PATH=$PATH:/__w/_tool/Python/3.11.10/x64/bin - - name: Create python3 symlink - run: ln -s /__w/_tool/Python/3.11.10/x64/bin/python3 /usr/local/bin/python3 + - name: Set python3 alias + run: echo "alias python3=/__w/_tool/Python/3.11.10/x64/bin/python3" >> ~/.bashrc # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main From 53b7d02ff3a2bcd3b417e51f840d502f4f1cf333 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 00:27:27 +0000 Subject: [PATCH 086/205] try a fix for path --- .github/workflows/ci_duplicate.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci_duplicate.yml b/.github/workflows/ci_duplicate.yml index 87c8992cff7e..e808753fcfec 100644 --- a/.github/workflows/ci_duplicate.yml +++ b/.github/workflows/ci_duplicate.yml @@ -13,7 +13,7 @@ jobs: lint_and_typecheck: runs-on: linux-x86-n2-16 container: - image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 + image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" # ratchet:ubuntu:20.04 timeout-minutes: 5 steps: - name: Install dependencies for setting up Python @@ -27,10 +27,10 @@ jobs: with: python-version: "3.11" - run: python -m pip install pre-commit - - name: Add to PATH - run: export PATH=$PATH:/__w/_tool/Python/3.11.10/x64/bin - - name: Set python3 alias - run: echo "alias python3=/__w/_tool/Python/3.11.10/x64/bin/python3" >> ~/.bashrc + # - name: Add to PATH + # run: export PATH=$PATH:/__w/_tool/Python/3.11.10/x64/bin + # - name: Set python3 alias + # run: echo "alias python3=/__w/_tool/Python/3.11.10/x64/bin/python3" >> ~/.bashrc # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main From 2afda3dc2c01677c055196b33149a3841c6aa91d Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 03:29:38 +0000 Subject: [PATCH 087/205] update workflows after sync to upstream --- .github/workflows/asan.yaml | 3 +++ .github/workflows/bazel_cpu_rbe.yml | 25 ++++++++++++++----------- .github/workflows/bazel_gpu_non_rbe.yml | 9 +++++++-- .github/workflows/bazel_gpu_rbe.yml | 12 ++++++++---- .github/workflows/build_artifacts.yml | 1 + 5 files changed, 33 insertions(+), 17 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index c43279fcfdfb..b1b1cc8d5bdc 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -50,6 +50,9 @@ jobs: make -j64 make install ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main - name: Install JAX test requirements env: ASAN_OPTIONS: detect_leaks=0 diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index 3186997d40cd..34b3a057aa0e 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -1,9 +1,9 @@ name: CI - Bazel CPU tests (RBE) on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -29,21 +29,24 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] + enable-x_64: [1, 0] runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(matrix.runner, 'windows-x86') && null) }} - env: JAXCI_HERMETIC_PYTHON_VERSION: "3.12" + JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} + name: "Run Bazel CPU tests (x64=${{ matrix.enable-x_64 }}" + steps: - - uses: actions/checkout@v3 - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Run Bazel CPU Tests with RBE - run: ./ci/run_bazel_test_cpu_rbe.sh + - uses: actions/checkout@v3 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel CPU Tests with RBE + run: ./ci/run_bazel_test_cpu_rbe.sh diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 89b69483626b..c5595959adab 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -32,13 +32,18 @@ jobs: upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' run_bazel_tests: - name: "Run Bazel GPU tests (single accelerator and multi-accelerator tests, non-RBE)" needs: build_artifacts runs-on: "linux-x86-g2-48-l4-4gpu" container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" - + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + enable-x_64: [1, 0] env: JAXCI_HERMETIC_PYTHON_VERSION: 3.11 + JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} + + name: "Run Bazel GPU tests (single accelerator and multi-accelerator tests, non-RBE), x64=${{ matrix.enable-x_64 }}" steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml index 8c14c4eab7da..91d2d65b40d0 100644 --- a/.github/workflows/bazel_gpu_rbe.yml +++ b/.github/workflows/bazel_gpu_rbe.yml @@ -1,9 +1,9 @@ name: CI - Bazel GPU tests (RBE) on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -24,13 +24,17 @@ jobs: strategy: matrix: runner: ["linux-x86-n2-16"] + enable-x_64: [1, 0] runs-on: ${{ matrix.runner }} container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest' env: JAXCI_HERMETIC_PYTHON_VERSION: "3.12" - + JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} + + name: "Run Bazel GPU tests (x64=${{ matrix.enable-x_64 }}" + steps: - uses: actions/checkout@v3 - name: Wait For Connection diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index d891e1ca7e73..1a7b8cf2bab9 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -159,6 +159,7 @@ jobs: env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" + JAXCI_BUILD_ARTIFACT_WITH_RBE: 1 JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" steps: From 6a528990ae1e916c6bf832674659448b6bcef138 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 03:38:08 +0000 Subject: [PATCH 088/205] fix workflow names --- .github/workflows/asan.yaml | 76 ++++++++++++------------- .github/workflows/bazel_cpu_rbe.yml | 2 +- .github/workflows/bazel_gpu_non_rbe.yml | 2 +- .github/workflows/bazel_gpu_rbe.yml | 2 +- 4 files changed, 41 insertions(+), 41 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index b1b1cc8d5bdc..1a9520a52161 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -1,55 +1,55 @@ name: CI - Address Sanitizer (nightly) on: -# schedule: -# - cron: "0 12 * * *" # Daily at 12:00 UTC -# workflow_dispatch: # allows triggering the workflow run manually -# pull_request: # Automatically trigger on pull requests affecting this file -# branches: -# - main -# paths: -# - '**/workflows/asan.yaml' + schedule: + - cron: "0 12 * * *" # Daily at 12:00 UTC + workflow_dispatch: # allows triggering the workflow run manually + pull_request: # Automatically trigger on pull requests affecting this file + branches: + - main + paths: + - '**/workflows/asan.yaml' jobs: asan: runs-on: linux-x86-n2-64 container: - image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 + image: us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest # ratchet:ubuntu:24.04 strategy: fail-fast: false defaults: run: shell: bash -l {0} steps: - # Install git before actions/checkout as otherwise it will download the code with the GitHub - # REST API and therefore any subsequent git commands will fail. - - name: Install clang 18 - env: - DEBIAN_FRONTEND: noninteractive - run: | - apt update - apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \ - zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ - libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ - libffi-dev liblzma-dev - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - path: jax - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: python/cpython - path: cpython - ref: v3.13.0 - - name: Build CPython with ASAN enabled - env: - ASAN_OPTIONS: detect_leaks=0 - run: | - cd cpython - mkdir ${GITHUB_WORKSPACE}/cpythonasan - CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpythonasan --with-address-sanitizer --without-pymalloc - make -j64 - make install - ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv + # # Install git before actions/checkout as otherwise it will download the code with the GitHub + # # REST API and therefore any subsequent git commands will fail. + # - name: Install clang 18 + # env: + # DEBIAN_FRONTEND: noninteractive + # run: | + # apt update + # apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \ + # zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ + # libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ + # libffi-dev liblzma-dev + # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # with: + # path: jax + # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # with: + # repository: python/cpython + # path: cpython + # ref: v3.13.0 + # - name: Build CPython with ASAN enabled + # env: + # ASAN_OPTIONS: detect_leaks=0 + # run: | + # cd cpython + # mkdir ${GITHUB_WORKSPACE}/cpythonasan + # CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpythonasan --with-address-sanitizer --without-pymalloc + # make -j64 + # make install + # ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index 34b3a057aa0e..429e9fcaa987 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -40,7 +40,7 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "3.12" JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} - name: "Run Bazel CPU tests (x64=${{ matrix.enable-x_64 }}" + name: "Bazel CPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})" steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index c5595959adab..70970e977096 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -43,7 +43,7 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: 3.11 JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} - name: "Run Bazel GPU tests (single accelerator and multi-accelerator tests, non-RBE), x64=${{ matrix.enable-x_64 }}" + name: "Bazel single accelerator and multi-accelerator GPU tests (Non RBE, linux-x86-g2-48-l4-4gpu, Python 3.11, x64=${{ matrix.enable-x_64 }})" steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml index 91d2d65b40d0..b0a8148d9484 100644 --- a/.github/workflows/bazel_gpu_rbe.yml +++ b/.github/workflows/bazel_gpu_rbe.yml @@ -33,7 +33,7 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "3.12" JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} - name: "Run Bazel GPU tests (x64=${{ matrix.enable-x_64 }}" + name: "Bazel single accelerator GPU tests (RBE, ${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})" steps: - uses: actions/checkout@v3 From 1611e3055253ca36a72f6ed9e2d2cbece8b68392 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 17:47:57 +0000 Subject: [PATCH 089/205] sync to upstream --- .github/workflows/asan.yaml | 61 ++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 1a9520a52161..5adff2c5186e 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -21,38 +21,35 @@ jobs: run: shell: bash -l {0} steps: - # # Install git before actions/checkout as otherwise it will download the code with the GitHub - # # REST API and therefore any subsequent git commands will fail. - # - name: Install clang 18 - # env: - # DEBIAN_FRONTEND: noninteractive - # run: | - # apt update - # apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \ - # zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ - # libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ - # libffi-dev liblzma-dev - # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - # with: - # path: jax - # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - # with: - # repository: python/cpython - # path: cpython - # ref: v3.13.0 - # - name: Build CPython with ASAN enabled - # env: - # ASAN_OPTIONS: detect_leaks=0 - # run: | - # cd cpython - # mkdir ${GITHUB_WORKSPACE}/cpythonasan - # CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpythonasan --with-address-sanitizer --without-pymalloc - # make -j64 - # make install - # ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + # Install git before actions/checkout as otherwise it will download the code with the GitHub + # REST API and therefore any subsequent git commands will fail. + - name: Install clang 18 + env: + DEBIAN_FRONTEND: noninteractive + run: | + apt update + apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \ + zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ + libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ + libffi-dev liblzma-dev + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: jax + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: python/cpython + path: cpython + ref: v3.13.0 + - name: Build CPython with ASAN enabled + env: + ASAN_OPTIONS: detect_leaks=0 + run: | + cd cpython + mkdir ${GITHUB_WORKSPACE}/cpythonasan + CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpythonasan --with-address-sanitizer --without-pymalloc + make -j64 + make install + ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv - name: Install JAX test requirements env: ASAN_OPTIONS: detect_leaks=0 From 5788853bea579a95f12d48ff3ae3ff30299a2fea Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 18:32:54 +0000 Subject: [PATCH 090/205] fix rbe setting and change msys conversion script --- .github/workflows/build_artifacts.yml | 2 +- .../convert_msys_paths_to_win_paths.py | 28 ++++++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 1a7b8cf2bab9..521ae48ca373 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -159,7 +159,6 @@ jobs: env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" - JAXCI_BUILD_ARTIFACT_WITH_RBE: 1 JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" steps: @@ -169,6 +168,7 @@ jobs: os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) + # Enable RBE if building on Linux x86 or Windows x86 if [[ ($os == "linux" || $os =~ "msys_nt" ) && $arch == "x86_64" ]]; then echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV fi diff --git a/ci/utilities/convert_msys_paths_to_win_paths.py b/ci/utilities/convert_msys_paths_to_win_paths.py index 78f551fff029..a09a78e1f8bb 100644 --- a/ci/utilities/convert_msys_paths_to_win_paths.py +++ b/ci/utilities/convert_msys_paths_to_win_paths.py @@ -43,15 +43,22 @@ def msys_to_windows_path(msys_path): print(f"Error converting path: {e}") return None +def should_convert(var_name: str, + exclude: list[str] | None): + """Check the variable name against allow/deny lists.""" + if exclude and var_name in exclude: + return False + return True + def main(parsed_args: argparse.Namespace): converted_paths = {} for var, value in os.environ.items(): - if (parsed_args.blacklist and var in parsed_args.blacklist) or not value: + if not value or not should_convert(var, + parsed_args.exclude): continue - if "_DIR" in var or (args.whitelist and var in parsed_args.whitelist): - converted_path = msys_to_windows_path(value) - converted_paths[var] = converted_path + converted_path = msys_to_windows_path(value) + converted_paths[var] = converted_path var_str = '\n'.join(f'export {k}="{v}"' for k, v in converted_paths.items()) @@ -63,12 +70,13 @@ def main(parsed_args: argparse.Namespace): if __name__ == '__main__': parser = argparse.ArgumentParser(description=( 'Convert MSYS paths in environment variables to Windows paths.')) - parser.add_argument('--blacklist', - nargs='*', - help='List of variables to ignore') - parser.add_argument('--whitelist', - nargs='*', - help='List of variables to include') + parser.add_argument('--convert', + nargs='+', + required=True, + help='List of variables to convert') + parser.add_argument('--exclude', + nargs='*', + help='Optional list of variables to exclude') args = parser.parse_args() main(args) From 2135bb7ae46c0d0f5aaaa15148d3395f3e4172c3 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 18:34:22 +0000 Subject: [PATCH 091/205] change msys conversion script call to match its new usage --- ci/utilities/setup_build_environment.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index b4e8c5c6114e..fbf4a8a5d116 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -74,5 +74,5 @@ fi if [[ $(uname -s) =~ "MSYS_NT" ]]; then echo 'Converting MSYS Linux-like paths to Windows paths (for Docker, Python, etc.)' # Convert all "_DIR" variables to Windows paths. - source <(python3 ./ci/utilities/convert_msys_paths_to_win_paths.py) + source <(python3 ./ci/utilities/convert_msys_paths_to_win_paths.py --convert $(env | grep "JAXCI.*DIR")) fi \ No newline at end of file From 6328485352d5ca43d0274b4cfb37aafb2d49fd7f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 18:34:58 +0000 Subject: [PATCH 092/205] fix comment --- ci/utilities/convert_msys_paths_to_win_paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/utilities/convert_msys_paths_to_win_paths.py b/ci/utilities/convert_msys_paths_to_win_paths.py index a09a78e1f8bb..7695567930d6 100644 --- a/ci/utilities/convert_msys_paths_to_win_paths.py +++ b/ci/utilities/convert_msys_paths_to_win_paths.py @@ -45,7 +45,7 @@ def msys_to_windows_path(msys_path): def should_convert(var_name: str, exclude: list[str] | None): - """Check the variable name against allow/deny lists.""" + """Check the variable name against exclude list""" if exclude and var_name in exclude: return False return True From 140173d45a42dc650a639622438084c247d270ec Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 18:49:19 +0000 Subject: [PATCH 093/205] fix conversion script --- ci/utilities/convert_msys_paths_to_win_paths.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ci/utilities/convert_msys_paths_to_win_paths.py b/ci/utilities/convert_msys_paths_to_win_paths.py index 7695567930d6..35b610cef95c 100644 --- a/ci/utilities/convert_msys_paths_to_win_paths.py +++ b/ci/utilities/convert_msys_paths_to_win_paths.py @@ -48,12 +48,12 @@ def should_convert(var_name: str, """Check the variable name against exclude list""" if exclude and var_name in exclude: return False - return True def main(parsed_args: argparse.Namespace): converted_paths = {} - for var, value in os.environ.items(): + for var_value in parsed_args.convert: + var, value = var_value.split("=") if not value or not should_convert(var, parsed_args.exclude): continue @@ -73,7 +73,7 @@ def main(parsed_args: argparse.Namespace): parser.add_argument('--convert', nargs='+', required=True, - help='List of variables to convert') + help='Space separated list of variables and values to convert. E.g: --convert foo=/path/to/bar') parser.add_argument('--exclude', nargs='*', help='Optional list of variables to exclude') From e49d60d58ce4d4c0503c38a56aa6dbf7644cbdc3 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 18:55:25 +0000 Subject: [PATCH 094/205] fix conversion script --- ci/utilities/convert_msys_paths_to_win_paths.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/ci/utilities/convert_msys_paths_to_win_paths.py b/ci/utilities/convert_msys_paths_to_win_paths.py index 35b610cef95c..a52bf63e027a 100644 --- a/ci/utilities/convert_msys_paths_to_win_paths.py +++ b/ci/utilities/convert_msys_paths_to_win_paths.py @@ -43,18 +43,24 @@ def msys_to_windows_path(msys_path): print(f"Error converting path: {e}") return None -def should_convert(var_name: str, +def should_convert(var: str, + convert: list[str] | None, exclude: list[str] | None): """Check the variable name against exclude list""" - if exclude and var_name in exclude: + if exclude and var in exclude: return False + + if var in convert: + return True + + return False def main(parsed_args: argparse.Namespace): converted_paths = {} - for var_value in parsed_args.convert: - var, value = var_value.split("=") + for var, value in os.environ.items(): if not value or not should_convert(var, + parsed_args.convert, parsed_args.exclude): continue converted_path = msys_to_windows_path(value) From 2a50a63643e840ea0f47b7736174729fbfa74f5c Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 18:55:53 +0000 Subject: [PATCH 095/205] fix comment --- ci/utilities/convert_msys_paths_to_win_paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/utilities/convert_msys_paths_to_win_paths.py b/ci/utilities/convert_msys_paths_to_win_paths.py index a52bf63e027a..edc83a175179 100644 --- a/ci/utilities/convert_msys_paths_to_win_paths.py +++ b/ci/utilities/convert_msys_paths_to_win_paths.py @@ -46,7 +46,7 @@ def msys_to_windows_path(msys_path): def should_convert(var: str, convert: list[str] | None, exclude: list[str] | None): - """Check the variable name against exclude list""" + """Check the variable name against convert/exclude list""" if exclude and var in exclude: return False From 645d3ccbeb11dbab1fe6d44132bb4c89f84720e3 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 22:03:05 +0000 Subject: [PATCH 096/205] fix conversion script --- ci/utilities/convert_msys_paths_to_win_paths.py | 4 ++-- ci/utilities/setup_build_environment.sh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/utilities/convert_msys_paths_to_win_paths.py b/ci/utilities/convert_msys_paths_to_win_paths.py index edc83a175179..bee5c3d8d9cc 100644 --- a/ci/utilities/convert_msys_paths_to_win_paths.py +++ b/ci/utilities/convert_msys_paths_to_win_paths.py @@ -79,10 +79,10 @@ def main(parsed_args: argparse.Namespace): parser.add_argument('--convert', nargs='+', required=True, - help='Space separated list of variables and values to convert. E.g: --convert foo=/path/to/bar') + help='Space separated list of environment variables to convert. E.g: --convert env_var1 env_var2') parser.add_argument('--exclude', nargs='*', - help='Optional list of variables to exclude') + help='Optional space separated list of environment variables to exclude') args = parser.parse_args() main(args) diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index fbf4a8a5d116..6a761506e392 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -73,6 +73,6 @@ fi # For Windows, convert MSYS Linux-like paths to Windows paths. if [[ $(uname -s) =~ "MSYS_NT" ]]; then echo 'Converting MSYS Linux-like paths to Windows paths (for Docker, Python, etc.)' - # Convert all "_DIR" variables to Windows paths. - source <(python3 ./ci/utilities/convert_msys_paths_to_win_paths.py --convert $(env | grep "JAXCI.*DIR")) + # Convert all "JAXCI.*DIR" variables + source <(python3 ./ci/utilities/convert_msys_paths_to_win_paths.py --convert $(env | grep "JAXCI.*DIR" | awk -F= '{print $1}')) fi \ No newline at end of file From fdb6ef9e9aec693e499555278e7c7a1310d82634 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 23:00:20 +0000 Subject: [PATCH 097/205] update msys conversion script --- .../convert_msys_paths_to_win_paths.py | 20 ++++++------------- ci/utilities/setup_build_environment.sh | 2 +- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/ci/utilities/convert_msys_paths_to_win_paths.py b/ci/utilities/convert_msys_paths_to_win_paths.py index bee5c3d8d9cc..6164e6a5e29d 100644 --- a/ci/utilities/convert_msys_paths_to_win_paths.py +++ b/ci/utilities/convert_msys_paths_to_win_paths.py @@ -16,7 +16,7 @@ Converts MSYS Linux-like paths stored in env variables to Windows paths. This is necessary on Windows, because some applications do not understand/handle -Linux-like paths MSYS uses, for example, Docker. +Linux-like paths MSYS uses, for example, Bazel. """ import argparse import os @@ -44,24 +44,19 @@ def msys_to_windows_path(msys_path): return None def should_convert(var: str, - convert: list[str] | None, - exclude: list[str] | None): - """Check the variable name against convert/exclude list""" - if exclude and var in exclude: - return False - + convert: list[str] | None): + """Check the variable name against convert list""" if var in convert: return True - - return False + else: + return False def main(parsed_args: argparse.Namespace): converted_paths = {} for var, value in os.environ.items(): if not value or not should_convert(var, - parsed_args.convert, - parsed_args.exclude): + parsed_args.convert): continue converted_path = msys_to_windows_path(value) converted_paths[var] = converted_path @@ -80,9 +75,6 @@ def main(parsed_args: argparse.Namespace): nargs='+', required=True, help='Space separated list of environment variables to convert. E.g: --convert env_var1 env_var2') - parser.add_argument('--exclude', - nargs='*', - help='Optional space separated list of environment variables to exclude') args = parser.parse_args() main(args) diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index 6a761506e392..727732536d9d 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -72,7 +72,7 @@ fi # For Windows, convert MSYS Linux-like paths to Windows paths. if [[ $(uname -s) =~ "MSYS_NT" ]]; then - echo 'Converting MSYS Linux-like paths to Windows paths (for Docker, Python, etc.)' + echo 'Converting MSYS Linux-like paths to Windows paths (for Bazel, Python, etc.)' # Convert all "JAXCI.*DIR" variables source <(python3 ./ci/utilities/convert_msys_paths_to_win_paths.py --convert $(env | grep "JAXCI.*DIR" | awk -F= '{print $1}')) fi \ No newline at end of file From c9b5bc3472348ea80d38ede9c65ed450a8df43be Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 28 Nov 2024 16:31:19 +0000 Subject: [PATCH 098/205] Change to ml build container --- .github/workflows/bazel_gpu_non_rbe.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 70970e977096..180d6eedc8f4 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -34,7 +34,7 @@ jobs: run_bazel_tests: needs: build_artifacts runs-on: "linux-x86-g2-48-l4-4gpu" - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" strategy: fail-fast: false # don't cancel all jobs on failure matrix: From d47d72d2b7b11b2415751432ea56a46cbd3e619d Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 28 Nov 2024 17:36:04 +0000 Subject: [PATCH 099/205] update docker scripts --- ci/envs/default.env | 10 +------- ci/envs/docker.env | 35 +++++++++++++++++++++------- ci/utilities/run_docker_container.sh | 27 +++++++++++++++------ 3 files changed, 47 insertions(+), 25 deletions(-) diff --git a/ci/envs/default.env b/ci/envs/default.env index 4e2f734352f2..c27065934cc5 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -66,12 +66,4 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} # JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels # on the system. By default, it is set to match the version of the hermetic # Python used by Bazel for building the wheels. -export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} - -# ############################################################################# -# Docker specific environment variables. -# ############################################################################# -# Docker specifc environment variables. Used by `run_docker_container.sh` -export JAXCI_DOCKER_WORK_DIR="/jax" -export JAXCI_DOCKER_IMAGE="" -export JAXCI_DOCKER_ARGS="" \ No newline at end of file +export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} \ No newline at end of file diff --git a/ci/envs/docker.env b/ci/envs/docker.env index 3832c095c85d..c943d9e7f479 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -15,24 +15,41 @@ # This file contains all the docker specifc envs that are needed by the # ci/utilities/run_docker_container.sh script. -# Inherit default JAXCI environment variables. -source ci/envs/default.env - os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) -# TODO: Set GPU Docker args and GPU Docker images -# Linux x86 specifc settings +# The path to the JAX git repository. +export JAXCI_JAX_GIT_DIR=$(pwd) + +export JAXCI_DOCKER_WORK_DIR="/jax" +export JAXCI_DOCKER_ARGS="" + +# Linux x86 image for building JAX artifacts, running Pytests CPU/TPU tests, and Bazel tests if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" fi -# Linux Aarch64 specifc settings +# Linux Aarch64 image for building JAX artifacts, running Pytests CPU tests, and Bazel tests if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest" fi -# Windows specific settings +# Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel tests if [[ $os =~ "msys_nt" ]]; then - export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc" -fi \ No newline at end of file + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows:latest" +fi + +# Uncomment the following lines if you want to run the GPU tests with Pytest. +# Note that GPU Pytests, as a prequisite, require that the following JAX artifacts be +# present in the $JAXCI_OUTPUT_DIR: jaxlib, jax-cuda-plugin, jax-cuda-pjrt. If you don't +# have these wheels stored there, either build them from source via ci/build_artifacts.sh or +# download them from PyPI into that folder. +# +# Linux x86 image for running Pytest GPU tests +# if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then +# # Choose one of: 12.3, 12.1 +# export JAXCI_DOCKER_CUDA_VERSION=${JAXCI_DOCKER_CUDA_VERSION:-12.3} +# export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda${JAXCI_DOCKER_CUDA_VERSION}-cudnn9.1-ubuntu20.04-manylinux2014-multipython" +# +# export JAXCI_DOCKER_ARGS="--gpus all --shm-size=16g" +# fi \ No newline at end of file diff --git a/ci/utilities/run_docker_container.sh b/ci/utilities/run_docker_container.sh index 1cc3199bd5fd..d785103312f4 100755 --- a/ci/utilities/run_docker_container.sh +++ b/ci/utilities/run_docker_container.sh @@ -13,7 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# Set up the Docker container and start it for JAX CI jobs. +# Sets up a Docker container for JAX CI. + +# This script creates and starts a Docker container named "jax" for internal +# JAX CI jobs. These jobs primarily handle building and publishing JAX artifacts +# to PyPI and/or GCS. + +# Note: GitHub Actions workflows do not utilize this script, as they leverage +# built-in containerization features to run jobs within a container. However, +# they use the same Docker image to maintain consistency. This script also helps +# ensure that local build environments mirror the behavior of CI builds. +# Usage: +# source ci/envs/docker.env && ./ci/utilities/run_docker_container.sh +# docker exec -it jax +# E.g: docker exec -it jax ./ci/build_artifacts.sh jaxlib # # -e: abort script if one command fails # -u: error if undefined variable used @@ -46,13 +59,13 @@ if ! docker container inspect jax >/dev/null 2>&1 ; then JAXCI_DOCKER_ARGS="$JAXCI_DOCKER_ARGS -v $HOME/.config/gcloud:/root/.config/gcloud" fi - # Start the container. `user_set_jaxci_envs` is read after `jax_ci_envs` to - # allow the user to override any environment variables set by JAXCI_ENV_FILE. + # Start the container. docker run $JAXCI_DOCKER_ARGS --name jax \ - -w "$JAXCI_DOCKER_WORK_DIR" -itd --rm \ - -v "$JAXCI_JAX_GIT_DIR:$JAXCI_DOCKER_WORK_DIR" \ - "$JAXCI_DOCKER_IMAGE" \ - bash + --env-file <(env | grep JAXCI_) \ + -w "$JAXCI_DOCKER_WORK_DIR" -itd --rm \ + -v "$JAXCI_JAX_GIT_DIR:$JAXCI_DOCKER_WORK_DIR" \ + "$JAXCI_DOCKER_IMAGE" \ + bash if [[ "$(uname -s)" =~ "MSYS_NT" ]]; then # Allow requests from the container. From 0ea43ca9faf253b1f48bca03dc99cc4d6113c3b0 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 28 Nov 2024 17:36:45 +0000 Subject: [PATCH 100/205] Enable workflows --- .github/workflows/bazel_gpu_non_rbe.yml | 6 +++--- .github/workflows/pytest_cpu.yml | 6 +++--- .github/workflows/pytest_gpu.yml | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 180d6eedc8f4..e72a5c66d2ba 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -1,9 +1,9 @@ name: Run Bazel GPU tests (non RBE) on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 2dd5e0ed2f21..d730c50d1008 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -1,9 +1,9 @@ name: Run Pytest CPU tests on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index 15f82fd97c8e..3e909e4d380b 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -1,9 +1,9 @@ name: Run Pytest GPU tests on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: From 1a33626e441b070006e1d3fe3fc6adb23c4549a4 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 28 Nov 2024 18:24:04 +0000 Subject: [PATCH 101/205] fix gsutil download link --- .github/workflows/bazel_gpu_non_rbe.yml | 2 +- .github/workflows/pytest_cpu.yml | 2 +- .github/workflows/pytest_gpu.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index e72a5c66d2ba..63e3aa6df24f 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -64,6 +64,6 @@ jobs: echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Download the artifacts built in the "build_artifacts" job - run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/ + run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - name: Run Bazel GPU tests locally run: ./ci/run_bazel_test_gpu_non_rbe.sh diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index d730c50d1008..1fe2cb826df0 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -72,7 +72,7 @@ jobs: - name: Download the artifacts built in the "build_artifacts" job run: >- mkdir -p $(pwd)/dist && - gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/ + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index 3e909e4d380b..c7b2f3c2b98c 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -72,7 +72,7 @@ jobs: echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Download the artifacts built in the "build_artifacts" job - run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/ + run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} From 2dc0af2500bb6d2726fd391c94341c3a489c92ed Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 28 Nov 2024 19:03:20 +0000 Subject: [PATCH 102/205] add a workflow to test building mulitple artifacts with build CLI --- .../build_artifacts_single_invocation.yml | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 .github/workflows/build_artifacts_single_invocation.yml diff --git a/.github/workflows/build_artifacts_single_invocation.yml b/.github/workflows/build_artifacts_single_invocation.yml new file mode 100644 index 000000000000..28ee0a996e6e --- /dev/null +++ b/.github/workflows/build_artifacts_single_invocation.yml @@ -0,0 +1,35 @@ +name: Build multiple artifacts with build CLI + +on: + pull_request: + branches: + - main + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + +jobs: + build_artifacts: + + runs-on: "linux-x86-n2-16" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: 3.12 + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Build all artifacts with a single invocation of build CLI + run: python build/build.py build --wheels=jax-cuda-plugin,jax-cuda-pjrt,jaxlib --bazel_options=--config=rbe_linux_x86_64_cuda --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file From b308da63cb066c873783237a4e634d673570dc01 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 28 Nov 2024 19:09:25 +0000 Subject: [PATCH 103/205] mark github repo as safe --- .github/workflows/build_artifacts_single_invocation.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/build_artifacts_single_invocation.yml b/.github/workflows/build_artifacts_single_invocation.yml index 28ee0a996e6e..563312172d30 100644 --- a/.github/workflows/build_artifacts_single_invocation.yml +++ b/.github/workflows/build_artifacts_single_invocation.yml @@ -31,5 +31,7 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Mark GitHub repo as safe + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" - name: Build all artifacts with a single invocation of build CLI run: python build/build.py build --wheels=jax-cuda-plugin,jax-cuda-pjrt,jaxlib --bazel_options=--config=rbe_linux_x86_64_cuda --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file From 3c41e5639be253d4b4adbbfdf62c8ed4f17da9d6 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 3 Dec 2024 23:42:42 +0000 Subject: [PATCH 104/205] switch to 16 core windows runner --- .github/workflows/build_artifacts.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 521ae48ca373..2ee3e848f610 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -90,7 +90,7 @@ jobs: if [[ ${is_workflow_call:-"0"} == "0" ]]; then wheels=("'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") python_versions=("'3.10'" ", '3.11'" ", '3.12'", ", '3.13'") - platforms=("'linux-x86-n2-16'" ", 'linux-arm64-t2a-48'" ", 'windows-x86-n2-64'") + platforms=("'linux-x86-n2-16'" ", 'linux-arm64-t2a-48'" ", 'windows-x86-n2-16'") else # Set the Internal Field Separator to be comma IFS=, @@ -112,7 +112,7 @@ jobs: elif [[ $platform == "linux_arm64" ]]; then platforms+="'linux-arm64-t2a-48'," elif [[ $platform == "windows_x86" ]]; then - platforms+="'windows-x86-n2-64'," + platforms+="'windows-x86-n2-16'," else echo "Incorrect platform provided. Valid options are: linux_x86, linux_arm64, windows_x86" exit 1 From 9531ed2acf151e81675ead913ecd12ba7762e121 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 4 Dec 2024 21:47:49 +0000 Subject: [PATCH 105/205] add flag to enable detailed timestamped logging --- build/build.py | 15 ++++++++++++--- build/tools/command.py | 13 +++++++------ ci/build_artifacts.sh | 2 +- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/build/build.py b/build/build.py index 79ec8d05fc9f..a6c1a7922b0e 100755 --- a/build/build.py +++ b/build/build.py @@ -123,6 +123,15 @@ def add_global_arguments(parser: argparse.ArgumentParser): help="Produce verbose output for debugging.", ) + parser.add_argument( + "--detailed_timestamped_log", + action="store_true", + help=""" + Enable detailed logging of the Bazel command with timestamps. The logs + will be stored and can be accessed as artifacts. + """, + ) + def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): """Adds all the arguments that applies to the artifact subcommands.""" @@ -399,7 +408,7 @@ async def main(): else: requirements_command.append("//build:requirements.update") - result = await executor.run(requirements_command.get_command_as_string(), args.dry_run) + result = await executor.run(requirements_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) if result.return_code != 0: raise RuntimeError(f"Command failed with return code {result.return_code}") else: @@ -597,7 +606,7 @@ async def main(): wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") - result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run) + result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) # Exit with error if any wheel build fails. if result.return_code != 0: raise RuntimeError(f"Command failed with return code {result.return_code}") @@ -607,4 +616,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/build/tools/command.py b/build/tools/command.py index 48a9bfc1c0d6..cc95d7eea4af 100644 --- a/build/tools/command.py +++ b/build/tools/command.py @@ -75,7 +75,7 @@ def __init__(self, environment: Dict[str, str] = None): """ self.environment = environment or dict(os.environ) - async def run(self, cmd: str, dry_run: bool = False) -> CommandResult: + async def run(self, cmd: str, dry_run: bool = False, detailed_timestamped_log: bool = False) -> CommandResult: """ Executes a subprocess command. @@ -96,14 +96,15 @@ async def run(self, cmd: str, dry_run: bool = False) -> CommandResult: process = await asyncio.create_subprocess_shell( cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE if detailed_timestamped_log else None, + stderr=asyncio.subprocess.PIPE if detailed_timestamped_log else None, env=self.environment, ) - await asyncio.gather( - _process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result) - ) + if detailed_timestamped_log: + await asyncio.gather( + _process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result) + ) result.return_code = await process.wait() result.end_time = datetime.datetime.now() diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 9f8d54401691..698de38418b7 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -69,7 +69,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then fi # Build the artifact. - python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose + python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose --detailed_timestamped_log # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. From 58e6611d6d71d32fb13ad1503af099ea6a2537ec Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 4 Dec 2024 21:49:56 +0000 Subject: [PATCH 106/205] test non ci build --- .github/workflows/build_artifacts_single_invocation.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_artifacts_single_invocation.yml b/.github/workflows/build_artifacts_single_invocation.yml index 563312172d30..32f463b4eadf 100644 --- a/.github/workflows/build_artifacts_single_invocation.yml +++ b/.github/workflows/build_artifacts_single_invocation.yml @@ -18,7 +18,7 @@ on: jobs: build_artifacts: - runs-on: "linux-x86-n2-16" + runs-on: "linux-x86-n2-64" container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" env: @@ -34,4 +34,4 @@ jobs: - name: Mark GitHub repo as safe run: git config --global --add safe.directory "$GITHUB_WORKSPACE" - name: Build all artifacts with a single invocation of build CLI - run: python build/build.py build --wheels=jax-cuda-plugin,jax-cuda-pjrt,jaxlib --bazel_options=--config=rbe_linux_x86_64_cuda --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file + run: python build/build.py build --wheels=jax-cuda-plugin,jax-cuda-pjrt,jaxlib --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file From 844afaaf5655cc71455d3409c697c56cc63bc40c Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 4 Dec 2024 21:56:54 +0000 Subject: [PATCH 107/205] revert change --- .github/workflows/build_artifacts_single_invocation.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_artifacts_single_invocation.yml b/.github/workflows/build_artifacts_single_invocation.yml index 32f463b4eadf..563312172d30 100644 --- a/.github/workflows/build_artifacts_single_invocation.yml +++ b/.github/workflows/build_artifacts_single_invocation.yml @@ -18,7 +18,7 @@ on: jobs: build_artifacts: - runs-on: "linux-x86-n2-64" + runs-on: "linux-x86-n2-16" container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" env: @@ -34,4 +34,4 @@ jobs: - name: Mark GitHub repo as safe run: git config --global --add safe.directory "$GITHUB_WORKSPACE" - name: Build all artifacts with a single invocation of build CLI - run: python build/build.py build --wheels=jax-cuda-plugin,jax-cuda-pjrt,jaxlib --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file + run: python build/build.py build --wheels=jax-cuda-plugin,jax-cuda-pjrt,jaxlib --bazel_options=--config=rbe_linux_x86_64_cuda --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file From cf02fd3d0d764d29a3edd3aa91c7ad53d0bd12c8 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 4 Dec 2024 23:10:50 +0000 Subject: [PATCH 108/205] fix exclude filters --- .github/workflows/build_artifacts.yml | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 2ee3e848f610..fcb79a19b061 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -141,15 +141,18 @@ jobs: artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} python: ${{ fromJSON(needs.determine_matrix.outputs.python_matrix) }} exclude: - # jax-cuda-pjrt does not need to be built for every Python but excluding it here for - # every but one Python version causes issues when a workflow call is made to this file - # requesting a build for an exlcuded Python version (see pytest_gpu.yaml) - # - # jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows. - - artifact: "jax-cuda-plugin" - runner: "windows-x86-n2-64" + # Windows doesn't support CUDA artifacts + - runner: "windows-x86-n2-16" + artifact: "jax-cuda-pjrt" + - runner: "windows-x86-n2-16" + artifact: "jax-cuda-plugin" + # cuda-pjrt is a pure Python package, so it doesn't need to be built for all Python versions. - artifact: "jax-cuda-pjrt" - runner: "windows-x86-n2-64" + python: 3.10 + - artifact: "jax-cuda-pjrt" + python: 3.11 + - artifact: "jax-cuda-pjrt" + python: 3.12 runs-on: ${{ matrix.runner }} From 97a2b1e9b12479d1a4382cda2da9fa6bec5ff118 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 5 Dec 2024 18:14:04 +0000 Subject: [PATCH 109/205] replace t2a with c4a runners --- .github/workflows/bazel_cpu_rbe.yml | 2 +- .github/workflows/build_artifacts.yml | 4 ++-- .github/workflows/pytest_cpu.yml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index 429e9fcaa987..49dc1d873c14 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -28,7 +28,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] + runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-c4a-16"] enable-x_64: [1, 0] runs-on: ${{ matrix.runner }} diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index fcb79a19b061..eeb516ece92c 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -90,7 +90,7 @@ jobs: if [[ ${is_workflow_call:-"0"} == "0" ]]; then wheels=("'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") python_versions=("'3.10'" ", '3.11'" ", '3.12'", ", '3.13'") - platforms=("'linux-x86-n2-16'" ", 'linux-arm64-t2a-48'" ", 'windows-x86-n2-16'") + platforms=("'linux-x86-n2-16'" ", 'linux-arm64-c4a-64'" ", 'windows-x86-n2-16'") else # Set the Internal Field Separator to be comma IFS=, @@ -110,7 +110,7 @@ jobs: if [[ $platform == "linux_x86" ]]; then platforms+="'linux-x86-n2-16'," elif [[ $platform == "linux_arm64" ]]; then - platforms+="'linux-arm64-t2a-48'," + platforms+="'linux-arm64-c4a-64'," elif [[ $platform == "windows_x86" ]]; then platforms+="'windows-x86-n2-16'," else diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 1fe2cb826df0..54f265223f22 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -40,7 +40,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] + runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] python: ["3.10", "3.11", "3.12", "3.13"] runs-on: ${{ matrix.runner }} From cbf4548ab841fe7170140b1c7681dc572f308e5d Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 5 Dec 2024 19:56:42 +0000 Subject: [PATCH 110/205] test GPU workflow with nosla image --- .github/workflows/bazel_gpu_non_rbe.yml | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 63e3aa6df24f..f704bd291787 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -34,16 +34,20 @@ jobs: run_bazel_tests: needs: build_artifacts runs-on: "linux-x86-g2-48-l4-4gpu" - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" strategy: fail-fast: false # don't cancel all jobs on failure matrix: - enable-x_64: [1, 0] + enable-x_64: [1] #, 0] + container-list: [ + {name: "nosla", image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython"}, + {name: "ml-build", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"} + ] + container: ${{ matrix.container-list.image }} env: JAXCI_HERMETIC_PYTHON_VERSION: 3.11 JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} - name: "Bazel single accelerator and multi-accelerator GPU tests (Non RBE, linux-x86-g2-48-l4-4gpu, Python 3.11, x64=${{ matrix.enable-x_64 }})" + name: "Bazel single accelerator and multi-accelerator GPU tests (Non RBE, linux-x86-g2-48-l4-4gpu, Python 3.11, x64=${{ matrix.enable-x_64 }}, image=${{ matrix.container-list.name }})" steps: - uses: actions/checkout@v3 From 5981613f6a5b6b07c81fc5058e950b944bc4370b Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 5 Dec 2024 21:01:17 +0000 Subject: [PATCH 111/205] remove exclude filters on cuda pjrt --- .github/workflows/build_artifacts.yml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index eeb516ece92c..35868a409a0e 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -146,13 +146,6 @@ jobs: artifact: "jax-cuda-pjrt" - runner: "windows-x86-n2-16" artifact: "jax-cuda-plugin" - # cuda-pjrt is a pure Python package, so it doesn't need to be built for all Python versions. - - artifact: "jax-cuda-pjrt" - python: 3.10 - - artifact: "jax-cuda-pjrt" - python: 3.11 - - artifact: "jax-cuda-pjrt" - python: 3.12 runs-on: ${{ matrix.runner }} From d7704a06e3cc5a107ee01320970f85a3024815a3 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 5 Dec 2024 21:35:25 +0000 Subject: [PATCH 112/205] change image to one with gsutil --- .github/workflows/bazel_gpu_non_rbe.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index f704bd291787..7b397c2e5f62 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -39,7 +39,7 @@ jobs: matrix: enable-x_64: [1] #, 0] container-list: [ - {name: "nosla", image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython"}, + {name: "nosla", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, {name: "ml-build", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"} ] container: ${{ matrix.container-list.image }} From ce247f90cede91de88392233b5e4625da9ee36a6 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 6 Dec 2024 20:03:20 +0000 Subject: [PATCH 113/205] disable tpu workflow --- .github/workflows/cloud-tpu-ci-nightly.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index fe879617c8a7..505feb170ca6 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -12,13 +12,13 @@ name: CI - Cloud TPU (nightly) on: - schedule: - - cron: "0 */2 * * *" # Run every 2 hours - workflow_dispatch: # allows triggering the workflow run manually -# This should also be set to read-only in the project settings, but it's nice to -# document and enforce the permissions here. -permissions: - contents: read +# schedule: +# - cron: "0 */2 * * *" # Run every 2 hours +# workflow_dispatch: # allows triggering the workflow run manually +# # This should also be set to read-only in the project settings, but it's nice to +# # document and enforce the permissions here. +# permissions: +# contents: read jobs: cloud-tpu-test: strategy: From e23e7fb64d9a7b7d29638b36753aff8789ba02d9 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 6 Dec 2024 20:04:17 +0000 Subject: [PATCH 114/205] try adding gpus=all option to debug gpu failures --- .github/workflows/bazel_gpu_non_rbe.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 7b397c2e5f62..64ab5a8638d1 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -42,7 +42,9 @@ jobs: {name: "nosla", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, {name: "ml-build", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"} ] - container: ${{ matrix.container-list.image }} + container: + image: ${{ matrix.container-list.image }} + options: --gpus all env: JAXCI_HERMETIC_PYTHON_VERSION: 3.11 JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} From 4d44155c9eee3bb80f1a24ff61ef9f2341fd5ec9 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 6 Dec 2024 20:45:20 +0000 Subject: [PATCH 115/205] add a workflow to debug matrix in resuable workflows --- .github/workflows/pytest_cpu_matrix_debug.yml | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 .github/workflows/pytest_cpu_matrix_debug.yml diff --git a/.github/workflows/pytest_cpu_matrix_debug.yml b/.github/workflows/pytest_cpu_matrix_debug.yml new file mode 100644 index 000000000000..620c6b949c33 --- /dev/null +++ b/.github/workflows/pytest_cpu_matrix_debug.yml @@ -0,0 +1,87 @@ +name: Run Pytest CPU tests + +on: + pull_request: + branches: + - main + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + build_jaxlib_artifact: + name: "Build the jaxlib aritfact using latest XLA" + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] + python: ["3.10", "3.11", "3.12", "3.13"] + with: + clone_main_xla: 1 + upload_artifacts: true + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + + run_pytest: + name: "Run CPU tests with Pytest" + needs: build_jaxlib_artifact + defaults: + run: + # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. + shell: bash + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] + python: ["3.10", "3.11", "3.12", "3.13"] + + runs-on: ${{ matrix.runner }} + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(matrix.runner, 'windows-x86') && null) }} + + env: + JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Set Platform + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Adjust name for Windows + if [[ $os =~ "msys_nt" ]]; then + os="windows" + fi + + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + - name: Download the artifacts built in the "build_artifacts" job + run: >- + mkdir -p $(pwd)/dist && + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ + - name: Install pytest + env: + JAXCI_PYTHON: python${{ matrix.python }} + run: $JAXCI_PYTHON -m pip install pytest + - name: Install dependencies + env: + JAXCI_PYTHON: python${{ matrix.python }} + run: $JAXCI_PYTHON -m pip install -r build/requirements.in + - name: Run Pytest CPU tests + run: ./ci/run_pytest_cpu.sh From c871fd17ffedafe2d801ca9cf8935e058b81fe23 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 10 Dec 2024 06:10:06 +0000 Subject: [PATCH 116/205] add workflows to debug matrix strategy for reusable workflows --- .../build_artifacts_matrix_debug.yml | 91 +++++++++++++++++++ .github/workflows/pytest_cpu_matrix_debug.yml | 8 +- 2 files changed, 95 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/build_artifacts_matrix_debug.yml diff --git a/.github/workflows/build_artifacts_matrix_debug.yml b/.github/workflows/build_artifacts_matrix_debug.yml new file mode 100644 index 000000000000..a5667d862287 --- /dev/null +++ b/.github/workflows/build_artifacts_matrix_debug.yml @@ -0,0 +1,91 @@ +name: Build JAX Artifacts (matrix debug) + +on: + # pull_request: + # branches: + # - main + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + workflow_call: + inputs: + clone_main_xla: + description: "Should latest XLA be used? (1 to enable, 0 to disable)" + type: string + required: false + default: "0" + upload_artifacts: + description: "Should the artifacts be uploaded to a GCS bucket?" + required: false + default: false + type: boolean + upload_destination_prefix: + description: "GCS location prefix to where the artifacts should be uploaded" + required: false + default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + +jobs: + build_artifacts: + defaults: + run: + # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. + shell: bash + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["windows-x86-n2-16", "linux-x86-n2-16", "linux-arm64-c4a-64"] + artifact: ["jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"] + python: ["3.10", "3.11", "3.12", "3.13"] + exclude: + # Windows doesn't support CUDA artifacts + - runner: "windows-x86-n2-16" + artifact: "jax-cuda-pjrt" + - runner: "windows-x86-n2-16" + artifact: "jax-cuda-plugin" + + runs-on: ${{ matrix.runner }} + + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(matrix.runner, 'windows-x86') && null) }} + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" + JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" + + steps: + - uses: actions/checkout@v3 + - name: Enable RBE if building on Linux x86 or Windows x86 + if: contains(matrix.runner, 'linux-x86') || contains(matrix.runner, 'windows-x86') + run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Build ${{ matrix.artifact }} + run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" + - name: Set PLATFORM env var for use in upload destination + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Adjust name for Windows + if [[ $os =~ "msys_nt" ]]; then + os="windows" + + fi + + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + - name: Upload artifacts to GCS bucket + if: inputs.upload_artifacts + run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ + diff --git a/.github/workflows/pytest_cpu_matrix_debug.yml b/.github/workflows/pytest_cpu_matrix_debug.yml index 620c6b949c33..69de98e180a0 100644 --- a/.github/workflows/pytest_cpu_matrix_debug.yml +++ b/.github/workflows/pytest_cpu_matrix_debug.yml @@ -1,4 +1,4 @@ -name: Run Pytest CPU tests +name: Run Pytest CPU tests (matrix debug) on: pull_request: @@ -21,12 +21,12 @@ concurrency: jobs: build_jaxlib_artifact: name: "Build the jaxlib aritfact using latest XLA" - uses: ./.github/workflows/build_artifacts.yml + uses: ./.github/workflows/build_artifacts_matrix_debug.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] - python: ["3.10", "3.11", "3.12", "3.13"] + python: ["3.10", "3.11", "3.12",] with: clone_main_xla: 1 upload_artifacts: true @@ -43,7 +43,7 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] - python: ["3.10", "3.11", "3.12", "3.13"] + python: ["3.10", "3.11", "3.12",] runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || From 0d47fe159449b5001d32665110ef68651d9e8088 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 10 Dec 2024 06:28:08 +0000 Subject: [PATCH 117/205] try 2 --- .../build_artifacts_matrix_debug.yml | 28 +++++++++++++++++-- .github/workflows/pytest_cpu_matrix_debug.yml | 2 +- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts_matrix_debug.yml b/.github/workflows/build_artifacts_matrix_debug.yml index a5667d862287..b8317b88a192 100644 --- a/.github/workflows/build_artifacts_matrix_debug.yml +++ b/.github/workflows/build_artifacts_matrix_debug.yml @@ -33,7 +33,29 @@ on: type: string jobs: + determine_matrix: + runs-on: "linux-x86-n2-16" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + outputs: + artifact_matrix: ${{ steps.set-matrix.outputs.artifact_matrix }} + python_matrix: ${{ steps.set-matrix.outputs.python_matrix }} + platform_matrix: ${{ steps.set-matrix.outputs.platform_matrix }} + defaults: + run: + shell: bash + steps: + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: "Determine the matrix" + id: set-matrix + run: | + echo ${{ matrix.workflow_call_runner }} + build_artifacts: + needs: determine_matrix defaults: run: # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. @@ -41,9 +63,9 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - runner: ["windows-x86-n2-16", "linux-x86-n2-16", "linux-arm64-c4a-64"] - artifact: ["jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"] - python: ["3.10", "3.11", "3.12", "3.13"] + runner: ${{ fromJSON(needs.determine_matrix.outputs.platform_matrix) }} + artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} + python: ${{ fromJSON(needs.determine_matrix.outputs.python_matrix) }} exclude: # Windows doesn't support CUDA artifacts - runner: "windows-x86-n2-16" diff --git a/.github/workflows/pytest_cpu_matrix_debug.yml b/.github/workflows/pytest_cpu_matrix_debug.yml index 69de98e180a0..e1cd2cfa2c12 100644 --- a/.github/workflows/pytest_cpu_matrix_debug.yml +++ b/.github/workflows/pytest_cpu_matrix_debug.yml @@ -25,7 +25,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] + workflow_call_runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] python: ["3.10", "3.11", "3.12",] with: clone_main_xla: 1 From 040566938ea627fcaeb9e10ef34e3a6276631dc5 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 10 Dec 2024 06:42:12 +0000 Subject: [PATCH 118/205] use inputs to define matrix --- .../build_artifacts_matrix_debug.yml | 41 ++++++++----------- .github/workflows/pytest_cpu_matrix_debug.yml | 15 +++++-- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/.github/workflows/build_artifacts_matrix_debug.yml b/.github/workflows/build_artifacts_matrix_debug.yml index b8317b88a192..0297cf1aaa29 100644 --- a/.github/workflows/build_artifacts_matrix_debug.yml +++ b/.github/workflows/build_artifacts_matrix_debug.yml @@ -16,6 +16,21 @@ on: - 'no' workflow_call: inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-n2-16" + artifact: + description: "Which JAX artifact to build?" + type: string + required: true + default: "jaxlib" + python-version: + description: "Which python version should the artifact be built for?" + type: string + required: true + default: "3.12" clone_main_xla: description: "Should latest XLA be used? (1 to enable, 0 to disable)" type: string @@ -33,29 +48,7 @@ on: type: string jobs: - determine_matrix: - runs-on: "linux-x86-n2-16" - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" - outputs: - artifact_matrix: ${{ steps.set-matrix.outputs.artifact_matrix }} - python_matrix: ${{ steps.set-matrix.outputs.python_matrix }} - platform_matrix: ${{ steps.set-matrix.outputs.platform_matrix }} - defaults: - run: - shell: bash - steps: - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: "Determine the matrix" - id: set-matrix - run: | - echo ${{ matrix.workflow_call_runner }} - build_artifacts: - needs: determine_matrix defaults: run: # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. @@ -63,7 +56,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - runner: ${{ fromJSON(needs.determine_matrix.outputs.platform_matrix) }} + runner: ${{ }} artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} python: ${{ fromJSON(needs.determine_matrix.outputs.python_matrix) }} exclude: @@ -73,7 +66,7 @@ jobs: - runner: "windows-x86-n2-16" artifact: "jax-cuda-plugin" - runs-on: ${{ matrix.runner }} + runs-on: ${{ inputs.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || diff --git a/.github/workflows/pytest_cpu_matrix_debug.yml b/.github/workflows/pytest_cpu_matrix_debug.yml index e1cd2cfa2c12..5a1839a86a92 100644 --- a/.github/workflows/pytest_cpu_matrix_debug.yml +++ b/.github/workflows/pytest_cpu_matrix_debug.yml @@ -25,9 +25,18 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - workflow_call_runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] - python: ["3.10", "3.11", "3.12",] + runner: ["linux-x86-n2-16", "linux-arm64-c4a-64", "windows-x86-n2-16"] + artifact: ["jaxlib"] + python: ["3.10", "3.11", "3.12", "3.13"] + exclude: + # Windows doesn't support CUDA artifacts + - runner: "windows-x86-n2-16" + artifact: "jax-cuda-pjrt" + - runner: "windows-x86-n2-16" + artifact: "jax-cuda-plugin" with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} clone_main_xla: 1 upload_artifacts: true upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' @@ -43,7 +52,7 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] - python: ["3.10", "3.11", "3.12",] + python: ["3.10", "3.11", "3.12", "3.13"] runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || From dd8b3d4d9c7228f86481412cf3dd61954610042e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 10 Dec 2024 06:52:53 +0000 Subject: [PATCH 119/205] use inputs to define matrix --- .../build_artifacts_matrix_debug.yml | 28 ++++++------------- .github/workflows/pytest_cpu_matrix_debug.yml | 6 ---- 2 files changed, 8 insertions(+), 26 deletions(-) diff --git a/.github/workflows/build_artifacts_matrix_debug.yml b/.github/workflows/build_artifacts_matrix_debug.yml index 0297cf1aaa29..9870ed8ae493 100644 --- a/.github/workflows/build_artifacts_matrix_debug.yml +++ b/.github/workflows/build_artifacts_matrix_debug.yml @@ -26,7 +26,7 @@ on: type: string required: true default: "jaxlib" - python-version: + python: description: "Which python version should the artifact be built for?" type: string required: true @@ -53,41 +53,29 @@ jobs: run: # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. shell: bash - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - runner: ${{ }} - artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} - python: ${{ fromJSON(needs.determine_matrix.outputs.python_matrix) }} - exclude: - # Windows doesn't support CUDA artifacts - - runner: "windows-x86-n2-16" - artifact: "jax-cuda-pjrt" - - runner: "windows-x86-n2-16" - artifact: "jax-cuda-plugin" runs-on: ${{ inputs.runner }} - container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || - (contains(matrix.runner, 'windows-x86') && null) }} + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(inputs.runner, 'windows-x86') && null) }} env: - JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" steps: - uses: actions/checkout@v3 - name: Enable RBE if building on Linux x86 or Windows x86 - if: contains(matrix.runner, 'linux-x86') || contains(matrix.runner, 'windows-x86') + if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build ${{ matrix.artifact }} - run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" + - name: Build ${{ inputs.artifact }} + run: ./ci/build_artifacts.sh "${{ inputs.artifact }}" - name: Set PLATFORM env var for use in upload destination run: | os=$(uname -s | awk '{print tolower($0)}') diff --git a/.github/workflows/pytest_cpu_matrix_debug.yml b/.github/workflows/pytest_cpu_matrix_debug.yml index 5a1839a86a92..c004580228a3 100644 --- a/.github/workflows/pytest_cpu_matrix_debug.yml +++ b/.github/workflows/pytest_cpu_matrix_debug.yml @@ -28,12 +28,6 @@ jobs: runner: ["linux-x86-n2-16", "linux-arm64-c4a-64", "windows-x86-n2-16"] artifact: ["jaxlib"] python: ["3.10", "3.11", "3.12", "3.13"] - exclude: - # Windows doesn't support CUDA artifacts - - runner: "windows-x86-n2-16" - artifact: "jax-cuda-pjrt" - - runner: "windows-x86-n2-16" - artifact: "jax-cuda-plugin" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} From 2bd68cbb00738bdad1f5c0838e43049a50c34ecc Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 10 Dec 2024 06:53:46 +0000 Subject: [PATCH 120/205] pass artifact as input --- .github/workflows/pytest_cpu_matrix_debug.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pytest_cpu_matrix_debug.yml b/.github/workflows/pytest_cpu_matrix_debug.yml index c004580228a3..7a49962c79bb 100644 --- a/.github/workflows/pytest_cpu_matrix_debug.yml +++ b/.github/workflows/pytest_cpu_matrix_debug.yml @@ -30,6 +30,7 @@ jobs: python: ["3.10", "3.11", "3.12", "3.13"] with: runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} python: ${{ matrix.python }} clone_main_xla: 1 upload_artifacts: true From f8e1ce6ff8044490634021b94057e2473c21adc4 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 10 Dec 2024 17:02:38 +0000 Subject: [PATCH 121/205] run only jaxlib 3.10 build --- .github/workflows/build_artifacts_matrix_debug.yml | 2 ++ .github/workflows/pytest_cpu_matrix_debug.yml | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts_matrix_debug.yml b/.github/workflows/build_artifacts_matrix_debug.yml index 9870ed8ae493..e115bd7f2766 100644 --- a/.github/workflows/build_artifacts_matrix_debug.yml +++ b/.github/workflows/build_artifacts_matrix_debug.yml @@ -64,6 +64,8 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" + name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, Clone Main XLA=${{ inputs.clone_main_xla }}) + steps: - uses: actions/checkout@v3 - name: Enable RBE if building on Linux x86 or Windows x86 diff --git a/.github/workflows/pytest_cpu_matrix_debug.yml b/.github/workflows/pytest_cpu_matrix_debug.yml index 7a49962c79bb..0cbced40549e 100644 --- a/.github/workflows/pytest_cpu_matrix_debug.yml +++ b/.github/workflows/pytest_cpu_matrix_debug.yml @@ -25,9 +25,9 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - runner: ["linux-x86-n2-16", "linux-arm64-c4a-64", "windows-x86-n2-16"] + runner: ["linux-x86-n2-16",] # "linux-arm64-c4a-64", "windows-x86-n2-16"] artifact: ["jaxlib"] - python: ["3.10", "3.11", "3.12", "3.13"] + python: ["3.10",] # "3.11", "3.12", "3.13"] with: runner: ${{ matrix.runner }} artifact: ${{ matrix.artifact }} @@ -46,8 +46,8 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] - python: ["3.10", "3.11", "3.12", "3.13"] + runner: ["linux-x86-n2-64",] # "linux-arm64-c4a-64", "windows-x86-n2-64"] + python: ["3.10",] # "3.11", "3.12", "3.13"] runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || From bad1a5142e0c928d1b7cd58de7725cfba9930a5e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 11 Dec 2024 00:00:43 +0000 Subject: [PATCH 122/205] update reusable workflows to new matrix strategy --- .github/workflows/bazel_cpu_rbe.yml | 6 +- .github/workflows/bazel_gpu_non_rbe.yml | 6 +- .github/workflows/bazel_gpu_rbe.yml | 6 +- .github/workflows/build_artifacts.yml | 172 +++++------------- .../build_artifacts_matrix_debug.yml | 96 ---------- .github/workflows/pytest_cpu.yml | 28 +-- .github/workflows/pytest_gpu.yml | 27 +-- ci/run_pytest_cpu.sh | 1 + ci/run_pytest_gpu.sh | 1 + 9 files changed, 93 insertions(+), 250 deletions(-) delete mode 100644 .github/workflows/build_artifacts_matrix_debug.yml diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index 49dc1d873c14..181da30356fa 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -1,9 +1,9 @@ name: CI - Bazel CPU tests (RBE) on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 64ab5a8638d1..5984eb513e56 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -1,9 +1,9 @@ name: Run Bazel GPU tests (non RBE) on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml index b0a8148d9484..5f3e29c59476 100644 --- a/.github/workflows/bazel_gpu_rbe.yml +++ b/.github/workflows/bazel_gpu_rbe.yml @@ -1,9 +1,9 @@ name: CI - Bazel GPU tests (RBE) on: - pull_request: - branches: - - main + # pull_request: + # branches: + # - main workflow_dispatch: inputs: halt-for-connection: diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 35868a409a0e..225f84442d5e 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -6,6 +6,26 @@ on: - main workflow_dispatch: inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-n2-16" + artifact: + description: "Which JAX artifact to build?" + type: string + required: true + default: "jaxlib" + python: + description: "Which python version should the artifact be built for?" + type: string + required: true + default: "3.12" + clone_main_xla: + description: "Should latest XLA be used? (1 to enable, 0 to disable)" + type: string + required: false + default: "0" halt-for-connection: description: 'Should this workflow run wait for a remote connection?' type: choice @@ -16,21 +36,21 @@ on: - 'no' workflow_call: inputs: - wheel_list: - description: "A comma separated list of JAX wheels to build. E.g: jaxlib or jaxlib,jax-cuda-pjrt" + runner: + description: "Which runner should the workflow run on?" type: string - required: false - default: "" - python_list: - description: "A comma separated list of Python versions to build for. E.g: 3.10 or 3.11,3.12" + required: true + default: "linux-x86-n2-16" + artifact: + description: "Which JAX artifact to build?" type: string - required: false - default: "" - platform_list: - description: "A comma separated list of platforms to build for. E.g: linux_x86 or linux_x86,linux_arm64,windows_x86" + required: true + default: "jaxlib" + python: + description: "Which python version should the artifact be built for?" type: string - required: false - default: "" + required: true + default: "3.12" clone_main_xla: description: "Should latest XLA be used? (1 to enable, 0 to disable)" type: string @@ -46,135 +66,42 @@ on: required: false default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string - is_workflow_call: - description: "Metadata variable to know whether a workflow call was made" - type: string - required: false - default: "1" -jobs: - determine_matrix: - runs-on: "linux-x86-n2-16" - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" - outputs: - artifact_matrix: ${{ steps.set-matrix.outputs.artifact_matrix }} - python_matrix: ${{ steps.set-matrix.outputs.python_matrix }} - platform_matrix: ${{ steps.set-matrix.outputs.platform_matrix }} - defaults: - run: - shell: bash - steps: - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: "Determine the matrix" - id: set-matrix - run: | - # Define inputs as bash variables to be able to parse them in - # if conditions - is_workflow_call=${{ inputs.is_workflow_call }} - wheel_list=${{ inputs.wheel_list }} - python_list=${{ inputs.python_list }} - platform_list=${{ inputs.platform_list }} - - # Initialize the arrays - wheels=() - python_versions=() - platforms=() - - # Build every package for every Python version on every platform if not a workflow call - # Packages that are not supported on a platform won't be built. E.g. CUDA packages won't be - # built for Windows - if [[ ${is_workflow_call:-"0"} == "0" ]]; then - wheels=("'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") - python_versions=("'3.10'" ", '3.11'" ", '3.12'", ", '3.13'") - platforms=("'linux-x86-n2-16'" ", 'linux-arm64-c4a-64'" ", 'windows-x86-n2-16'") - else - # Set the Internal Field Separator to be comma - IFS=, - - # Wheels - for wheel in $wheel_list; do - wheels+="'$wheel'," - done - - # Python versions - for python_version in $python_list; do - python_versions+="'$python_version'," - done - - # Platforms - for platform in $platform_list; do - if [[ $platform == "linux_x86" ]]; then - platforms+="'linux-x86-n2-16'," - elif [[ $platform == "linux_arm64" ]]; then - platforms+="'linux-arm64-c4a-64'," - elif [[ $platform == "windows_x86" ]]; then - platforms+="'windows-x86-n2-16'," - else - echo "Incorrect platform provided. Valid options are: linux_x86, linux_arm64, windows_x86" - exit 1 - fi - done - fi - - echo "artifact_matrix=[${wheels[@]}]" >> $GITHUB_OUTPUT - echo "python_matrix=[${python_versions[@]}]" >> $GITHUB_OUTPUT - echo "platform_matrix=[${platforms[@]}]" >> $GITHUB_OUTPUT - - echo "Artifacts: ${wheels[@]}" - echo "Python versions:${python_versions[@]}" - echo "Platforms: ${platforms[@]}" +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true +jobs: build_artifacts: - needs: determine_matrix defaults: run: # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. shell: bash - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - runner: ${{ fromJSON(needs.determine_matrix.outputs.platform_matrix) }} - artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} - python: ${{ fromJSON(needs.determine_matrix.outputs.python_matrix) }} - exclude: - # Windows doesn't support CUDA artifacts - - runner: "windows-x86-n2-16" - artifact: "jax-cuda-pjrt" - - runner: "windows-x86-n2-16" - artifact: "jax-cuda-plugin" - runs-on: ${{ matrix.runner }} + runs-on: ${{ inputs.runner }} - container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || - (contains(matrix.runner, 'windows-x86') && null) }} + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(inputs.runner, 'windows-x86') && null) }} env: - JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" - steps: - - uses: actions/checkout@v3 - - name: Enable RBE on platforms where its supported - run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) + name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, Clone main XLA=${{ inputs.clone_main_xla }}) - # Enable RBE if building on Linux x86 or Windows x86 - if [[ ($os == "linux" || $os =~ "msys_nt" ) && $arch == "x86_64" ]]; then - echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV - fi + steps: + - uses: actions/checkout@v4 + - name: Enable RBE if building on Linux x86 or Windows x86 + if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') + run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build ${{ matrix.artifact }} - run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" + - name: Build ${{ inputs.artifact }} + run: ./ci/build_artifacts.sh "${{ inputs.artifact }}" - name: Set PLATFORM env var for use in upload destination run: | os=$(uname -s | awk '{print tolower($0)}') @@ -183,7 +110,6 @@ jobs: # Adjust name for Windows if [[ $os =~ "msys_nt" ]]; then os="windows" - fi echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV diff --git a/.github/workflows/build_artifacts_matrix_debug.yml b/.github/workflows/build_artifacts_matrix_debug.yml deleted file mode 100644 index e115bd7f2766..000000000000 --- a/.github/workflows/build_artifacts_matrix_debug.yml +++ /dev/null @@ -1,96 +0,0 @@ -name: Build JAX Artifacts (matrix debug) - -on: - # pull_request: - # branches: - # - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - - 'no' - workflow_call: - inputs: - runner: - description: "Which runner should the workflow run on?" - type: string - required: true - default: "linux-x86-n2-16" - artifact: - description: "Which JAX artifact to build?" - type: string - required: true - default: "jaxlib" - python: - description: "Which python version should the artifact be built for?" - type: string - required: true - default: "3.12" - clone_main_xla: - description: "Should latest XLA be used? (1 to enable, 0 to disable)" - type: string - required: false - default: "0" - upload_artifacts: - description: "Should the artifacts be uploaded to a GCS bucket?" - required: false - default: false - type: boolean - upload_destination_prefix: - description: "GCS location prefix to where the artifacts should be uploaded" - required: false - default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - type: string - -jobs: - build_artifacts: - defaults: - run: - # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. - shell: bash - - runs-on: ${{ inputs.runner }} - - container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || - (contains(inputs.runner, 'windows-x86') && null) }} - - env: - JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" - JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" - - name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, Clone Main XLA=${{ inputs.clone_main_xla }}) - - steps: - - uses: actions/checkout@v3 - - name: Enable RBE if building on Linux x86 or Windows x86 - if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') - run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build ${{ inputs.artifact }} - run: ./ci/build_artifacts.sh "${{ inputs.artifact }}" - - name: Set PLATFORM env var for use in upload destination - run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) - - # Adjust name for Windows - if [[ $os =~ "msys_nt" ]]; then - os="windows" - - fi - - echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Upload artifacts to GCS bucket - if: inputs.upload_artifacts - run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ - diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 54f265223f22..cb32304ccb91 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -21,17 +21,22 @@ concurrency: jobs: build_jaxlib_artifact: name: "Build the jaxlib aritfact using latest XLA" - uses: ./.github/workflows/build_artifacts.yml + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-n2-16", "linux-arm64-c4a-64", "windows-x86-n2-16"] + artifact: ["jaxlib"] + python: ["3.10",] # "3.11", "3.12", "3.13"] with: - wheel_list: "jaxlib" - python_list: "3.10,3.11,3.12,3.13" - platform_list: "linux_x86,linux_arm64,windows_x86" + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} clone_main_xla: 1 upload_artifacts: true upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' run_pytest: - name: "Run CPU tests with Pytest" needs: build_jaxlib_artifact defaults: run: @@ -41,16 +46,21 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] - python: ["3.10", "3.11", "3.12", "3.13"] + python: ["3.10",] # "3.11", "3.12", "3.13"] + enable-x_64: [1, 0] runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(matrix.runner, 'windows-x86') && null) }} + name: "Pytest CPU (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" + env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} + JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} + steps: - uses: actions/checkout@v3 # Halt for testing @@ -69,14 +79,10 @@ jobs: fi echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Download the artifacts built in the "build_artifacts" job + - name: Download artifacts built in the "build_artifacts" job run: >- mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - - name: Install pytest - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install pytest - name: Install dependencies env: JAXCI_PYTHON: python${{ matrix.python }} diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index c7b2f3c2b98c..88133accecd8 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -21,12 +21,18 @@ concurrency: jobs: build_artifacts: - name: "Build the jaxlib and CUDA plugins using latest XLA" - uses: ./.github/workflows/build_artifacts.yml + name: "Build the jaxlib and CUDA aritfacts with latest XLA" + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-n2-16"] + artifact: ["jaxlib", "jax-cuda-plugin", "jax-cuda-pjrt"] + python: ["3.10",] # "3.11", "3.12", "3.13"] with: - wheel_list: "jaxlib,jax-cuda-plugin,jax-cuda-pjrt" - python_list: "3.10" - platform_list: "linux_x86" + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} clone_main_xla: 1 upload_artifacts: true upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' @@ -43,14 +49,17 @@ jobs: image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, ] python: ["3.10"] + enable-x_64: [1, 0] runs-on: ${{ matrix.test_env.runner }} container: image: ${{ matrix.test_env.image }} - name: "Pytest GPU (Test on CUDA ${{ matrix.test_env.cuda_version }})" + name: "Pytest GPU (${{ matrix.test_env.runner }}, CUDA ${{ matrix.test_env.cuda_version }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" + env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} + JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} steps: - uses: actions/checkout@v3 @@ -71,12 +80,8 @@ jobs: fi echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Download the artifacts built in the "build_artifacts" job + - name: Download artifacts built in the "build_artifacts" job run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - - name: Install pytest - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install pytest - name: Install dependencies env: JAXCI_PYTHON: python${{ matrix.python }} diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh index 2b19ca5ddaa5..ef96f80aec35 100755 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -39,6 +39,7 @@ source "ci/utilities/setup_build_environment.sh" export PY_COLORS=1 export JAX_SKIP_SLOW_TESTS=true export TF_CPP_MIN_LOG_LEVEL=0 +export JAX_ENABLE_X64="$JAXCI_ENABLE_X64" # End of test environment variable setup echo "Running CPU tests..." diff --git a/ci/run_pytest_gpu.sh b/ci/run_pytest_gpu.sh index 7bc2492781b2..42cffc54947a 100755 --- a/ci/run_pytest_gpu.sh +++ b/ci/run_pytest_gpu.sh @@ -43,6 +43,7 @@ export PY_COLORS=1 export JAX_SKIP_SLOW_TESTS=true export NCCL_DEBUG=WARN export TF_CPP_MIN_LOG_LEVEL=0 +export JAX_ENABLE_X64="$JAXCI_ENABLE_X64" # Set the number of processes to run to be 4x the number of GPUs. export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) From 43f698b30bf6f6642e778c85c59785ccaeb631a3 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 11 Dec 2024 00:02:34 +0000 Subject: [PATCH 123/205] remove concurrency settings --- .github/workflows/build_artifacts.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 225f84442d5e..c9c0106be4e1 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -67,10 +67,6 @@ on: default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - jobs: build_artifacts: defaults: From 24f0d1eeda08ff3e8745966d2fc83a33383b34f7 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 11 Dec 2024 17:14:04 +0000 Subject: [PATCH 124/205] adjust upload/download logic for windows --- .github/workflows/build_artifacts.yml | 9 +++++++-- .github/workflows/pytest_cpu.yml | 11 +++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index c9c0106be4e1..9fcb1538142c 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -109,7 +109,12 @@ jobs: fi echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Upload artifacts to GCS bucket - if: inputs.upload_artifacts + - name: Upload artifacts to GCS bucket (non-Windows) + if: inputs.upload_artifacts && !contains(inputs.runner, 'windows-x86') run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ + - name: Upload artifacts to GCS bucket (Windows) + if: inputs.upload_artifacts && contains(inputs.runner, 'windows-x86') + shell: cmd + run: gsutil -m cp -r dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ + diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index cb32304ccb91..3f872092b7a1 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -79,10 +79,17 @@ jobs: fi echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Download artifacts built in the "build_artifacts" job + - name: Download artifacts built in the "build_artifacts" job (non-Windows runs) + if: ! contains(matrix.runner, 'windows-x86') run: >- - mkdir -p $(pwd)/dist && + mkdir -p $(pwd)/dist gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ + - name: Download artifacts built in the "build_artifacts" job (Windows runs) + if: contains(matrix.runner, 'windows-x86') + shell: cmd + run: >- + mkdir dist + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl dist/ - name: Install dependencies env: JAXCI_PYTHON: python${{ matrix.python }} From 0ad69cc96d14c57c5aa6deb9a7b6a2d31c87d873 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 11 Dec 2024 17:18:52 +0000 Subject: [PATCH 125/205] fix syntax issue --- .github/workflows/build_artifacts.yml | 4 ++-- .github/workflows/pytest_cpu.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 9fcb1538142c..d8e4be76b9c3 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -110,10 +110,10 @@ jobs: echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket (non-Windows) - if: inputs.upload_artifacts && !contains(inputs.runner, 'windows-x86') + if: inputs.upload_artifacts && ${{ !contains(inputs.runner, 'windows-x86') }} run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ - name: Upload artifacts to GCS bucket (Windows) - if: inputs.upload_artifacts && contains(inputs.runner, 'windows-x86') + if: inputs.upload_artifacts && ${{ contains(inputs.runner, 'windows-x86') }} shell: cmd run: gsutil -m cp -r dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 3f872092b7a1..19600e485894 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -80,12 +80,12 @@ jobs: echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Download artifacts built in the "build_artifacts" job (non-Windows runs) - if: ! contains(matrix.runner, 'windows-x86') + if: ${{ !contains(matrix.runner, 'windows-x86') }} run: >- mkdir -p $(pwd)/dist gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - name: Download artifacts built in the "build_artifacts" job (Windows runs) - if: contains(matrix.runner, 'windows-x86') + if: ${{ contains(matrix.runner, 'windows-x86') }} shell: cmd run: >- mkdir dist From a20be400de5d362cd1776033a3da9a2f2d8aff91 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 11 Dec 2024 17:36:02 +0000 Subject: [PATCH 126/205] add pipe to enable multi-line if condition --- .github/workflows/build_artifacts.yml | 6 ++++-- .github/workflows/pytest_cpu.yml | 1 - .github/workflows/pytest_gpu.yml | 1 - 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index d8e4be76b9c3..60c57d2e49cb 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -110,10 +110,12 @@ jobs: echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket (non-Windows) - if: inputs.upload_artifacts && ${{ !contains(inputs.runner, 'windows-x86') }} + if: | + inputs.upload_artifacts && ${{ !contains(inputs.runner, 'windows-x86') }} run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ - name: Upload artifacts to GCS bucket (Windows) - if: inputs.upload_artifacts && ${{ contains(inputs.runner, 'windows-x86') }} + if: | + inputs.upload_artifacts && ${{ contains(inputs.runner, 'windows-x86') }} shell: cmd run: gsutil -m cp -r dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 19600e485894..763053106d2c 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -20,7 +20,6 @@ concurrency: jobs: build_jaxlib_artifact: - name: "Build the jaxlib aritfact using latest XLA" uses: ./.github/workflows/build_artifacts.yml strategy: fail-fast: false # don't cancel all jobs on failure diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index 88133accecd8..7a37fd20cccf 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -21,7 +21,6 @@ concurrency: jobs: build_artifacts: - name: "Build the jaxlib and CUDA aritfacts with latest XLA" uses: ./.github/workflows/build_artifacts.yml strategy: fail-fast: false # don't cancel all jobs on failure From cfaf3d0a66238e5f28430241ca1b71832c438c30 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 11 Dec 2024 17:45:25 +0000 Subject: [PATCH 127/205] Fix syntax issue --- .github/workflows/build_artifacts.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 60c57d2e49cb..a5291e631060 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -111,11 +111,11 @@ jobs: echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket (non-Windows) if: | - inputs.upload_artifacts && ${{ !contains(inputs.runner, 'windows-x86') }} + ${{ inputs.upload_artifacts && !contains(inputs.runner, 'windows-x86') }} run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ - name: Upload artifacts to GCS bucket (Windows) if: | - inputs.upload_artifacts && ${{ contains(inputs.runner, 'windows-x86') }} + ${{ inputs.upload_artifacts && contains(inputs.runner, 'windows-x86') }} shell: cmd run: gsutil -m cp -r dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ From fb875753d8a14aec72b8f704f0c1d4e93c97a8c2 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 11 Dec 2024 17:58:08 +0000 Subject: [PATCH 128/205] try fix for if condition evaluation --- .github/workflows/build_artifacts.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index a5291e631060..08b6c55debc5 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -110,11 +110,11 @@ jobs: echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket (non-Windows) - if: | + if: >- ${{ inputs.upload_artifacts && !contains(inputs.runner, 'windows-x86') }} run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ - name: Upload artifacts to GCS bucket (Windows) - if: | + if: >- ${{ inputs.upload_artifacts && contains(inputs.runner, 'windows-x86') }} shell: cmd run: gsutil -m cp -r dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ From 6278a734017e8bcad2fa720d51fd36be784eb1bb Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 11 Dec 2024 18:07:47 +0000 Subject: [PATCH 129/205] combine command with and operator --- .github/workflows/pytest_cpu.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 763053106d2c..20221f3f1af1 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -81,13 +81,13 @@ jobs: - name: Download artifacts built in the "build_artifacts" job (non-Windows runs) if: ${{ !contains(matrix.runner, 'windows-x86') }} run: >- - mkdir -p $(pwd)/dist + mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - name: Download artifacts built in the "build_artifacts" job (Windows runs) if: ${{ contains(matrix.runner, 'windows-x86') }} shell: cmd run: >- - mkdir dist + mkdir dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl dist/ - name: Install dependencies env: From 97c4c70c175db4f9c298a2dad9fad1ff272e371c Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 11 Dec 2024 19:43:54 +0000 Subject: [PATCH 130/205] adjust paths when on windows --- ci/utilities/install_wheels_locally.sh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index 181256b90804..c43fa815af9b 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -26,7 +26,12 @@ fi echo "Installing the following wheels:" echo "${WHEELS[@]}" -"$JAXCI_PYTHON" -m pip install "${WHEELS[@]}" +# On Windows, convert MSYS Linux-like paths to Windows paths. +if [[ $(uname -s) =~ "MSYS_NT" ]]; then + "$JAXCI_PYTHON" -m pip install $(cygpath -w "${WHEELS[@]}") +else + "$JAXCI_PYTHON" -m pip install "${WHEELS[@]}" +fi echo "Installing the JAX package in editable mode at the current commit..." # Install JAX package at the current commit. From 774e5dc0824b59671a80c97a4e93fdfd687018b3 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 11 Dec 2024 19:50:30 +0000 Subject: [PATCH 131/205] merge changes from upstream --- .github/workflows/bazel_gpu_rbe.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml index 6c9526077c89..6c58c0d4b8f7 100644 --- a/.github/workflows/bazel_gpu_rbe.yml +++ b/.github/workflows/bazel_gpu_rbe.yml @@ -28,7 +28,6 @@ jobs: matrix: runner: ["linux-x86-n2-16"] enable-x_64: [1, 0] - enable-x_64: [1, 0] runs-on: ${{ matrix.runner }} container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest' From c905115445e89c54962443b4fb013e96edc2b3af Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 12 Dec 2024 20:12:41 +0000 Subject: [PATCH 132/205] enable the bazel cpu rbe workflow --- .github/workflows/bazel_cpu_rbe.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index 98ed2907409c..ae1bb377a70b 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -1,9 +1,9 @@ name: CI - Bazel CPU tests (RBE) on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: From c25c78fdf3dd84a3443ec37d33841b3c81fc3170 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 12 Dec 2024 20:14:53 +0000 Subject: [PATCH 133/205] fix syntax issues --- .github/workflows/bazel_cpu_rbe.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index ae1bb377a70b..d3a4e9bca1ac 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -14,9 +14,6 @@ on: options: - 'yes' - 'no' - pull_request: - branches: - - main concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} From a7b493370e26c7f270a361dfbc7b315d4a20d127 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 16 Dec 2024 18:57:35 +0000 Subject: [PATCH 134/205] Utilize cache better by keeping the set of build options same when building jaxlib+plugin artifacts together --- build/build.py | 272 +++++++++++++++++++++++++------------------------ 1 file changed, 141 insertions(+), 131 deletions(-) diff --git a/build/build.py b/build/build.py index a6c1a7922b0e..5b1c90024f6f 100755 --- a/build/build.py +++ b/build/build.py @@ -414,6 +414,8 @@ async def main(): else: sys.exit(0) + wheel_build_command_base = copy.deepcopy(bazel_command_base) + wheel_cpus = { "darwin_arm64": "arm64", "darwin_x86_64": "x86_64", @@ -426,164 +428,172 @@ async def main(): if args.local_xla_path: logging.debug("Local XLA path: %s", args.local_xla_path) - bazel_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"") + wheel_build_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"") if args.target_cpu: logging.debug("Target CPU: %s", args.target_cpu) - bazel_command_base.append(f"--cpu={args.target_cpu}") + wheel_build_command_base.append(f"--cpu={args.target_cpu}") if args.disable_nccl: logging.debug("Disabling NCCL") - bazel_command_base.append("--config=nonccl") + wheel_build_command_base.append("--config=nonccl") git_hash = utils.get_githash() - # Wheel build command execution - for wheel in args.wheels.split(","): - # Allow CUDA/ROCm wheels without the "jax-" prefix. - if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: - wheel = "jax-" + wheel - - if wheel not in WHEEL_BUILD_TARGET_DICT.keys(): - logging.error( - "Incorrect wheel name provided, valid choices are jaxlib," - " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," - " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt" - ) - sys.exit(1) - - wheel_build_command = copy.deepcopy(bazel_command_base) - print("\n") - logger.info( - "Building %s for %s %s...", - wheel, - os_name, - arch, + clang_path = "" + if args.use_clang: + clang_path = args.clang_path or utils.get_clang_path_or_exit() + clang_major_version = utils.get_clang_major_version(clang_path) + logging.debug( + "Using Clang as the compiler, clang path: %s, clang version: %s", + clang_path, + clang_major_version, ) - clang_path = "" - if args.use_clang: - clang_path = args.clang_path or utils.get_clang_path_or_exit() - clang_major_version = utils.get_clang_major_version(clang_path) + # Use double quotes around clang path to avoid path issues on Windows. + wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") + else: + logging.debug("Use Clang: False") + + # Do not apply --config=clang on Mac as these settings do not apply to + # Apple Clang. + if os_name != "darwin": + wheel_build_command_base.append("--config=clang") + + if not args.disable_mkl_dnn: + logging.debug("Enabling MKL DNN") + if target_cpu == "aarch64": + wheel_build_command_base.append("--config=mkl_aarch64_threadpool") + else: + wheel_build_command_base.append("--config=mkl_open_source_only") + + if args.target_cpu_features == "release": + if arch in ["x86_64", "AMD64"]: logging.debug( - "Using Clang as the compiler, clang path: %s, clang version: %s", - clang_path, - clang_major_version, + "Using release cpu features: --config=avx_%s", + "windows" if os_name == "windows" else "posix", + ) + wheel_build_command_base.append( + "--config=avx_windows" + if os_name == "windows" + else "--config=avx_posix" + ) + elif args.target_cpu_features == "native": + if os_name == "windows": + logger.warning( + "--target_cpu_features=native is not supported on Windows;" + " ignoring." ) - - # Use double quotes around clang path to avoid path issues on Windows. - wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") - wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"") - wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") else: - logging.debug("Use Clang: False") - - # Do not apply --config=clang on Mac as these settings do not apply to - # Apple Clang. - if os_name != "darwin": - wheel_build_command.append("--config=clang") - - if not args.disable_mkl_dnn: - logging.debug("Enabling MKL DNN") - wheel_build_command.append("--config=mkl_open_source_only") - - if args.target_cpu_features == "release": - if arch in ["x86_64", "AMD64"]: - logging.debug( - "Using release cpu features: --config=avx_%s", - "windows" if os_name == "windows" else "posix", - ) - wheel_build_command.append( - "--config=avx_windows" - if os_name == "windows" - else "--config=avx_posix" - ) - elif wheel_build_command == "native": - if os_name == "windows": - logger.warning( - "--target_cpu_features=native is not supported on Windows;" - " ignoring." - ) - else: - logging.debug("Using native cpu features: --config=native_arch_posix") - wheel_build_command.append("--config=native_arch_posix") + logging.debug("Using native cpu features: --config=native_arch_posix") + wheel_build_command_base.append("--config=native_arch_posix") + else: + logging.debug("Using default cpu features") + + if "cuda" in args.wheels and "rocm" in args.wheels: + logging.error("CUDA and ROCm cannot be enabled at the same time.") + sys.exit(1) + + if "cuda" in args.wheels: + wheel_build_command_base.append("--config=cuda") + wheel_build_command_base.append( + f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" + ) + if args.build_cuda_with_clang: + logging.debug("Building CUDA with Clang") + wheel_build_command_base.append("--config=build_cuda_with_clang") else: - logging.debug("Using default cpu features") - - if "cuda" in wheel: - wheel_build_command.append("--config=cuda") - wheel_build_command.append( - f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" - ) - if args.build_cuda_with_clang: - logging.debug("Building CUDA with Clang") - wheel_build_command.append("--config=build_cuda_with_clang") - else: - logging.debug("Building CUDA with NVCC") - wheel_build_command.append("--config=build_cuda_with_nvcc") - - if args.cuda_version: - logging.debug("Hermetic CUDA version: %s", args.cuda_version) - wheel_build_command.append( - f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}" - ) - if args.cudnn_version: - logging.debug("Hermetic cuDNN version: %s", args.cudnn_version) - wheel_build_command.append( - f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}" - ) - if args.cuda_compute_capabilities: - logging.debug( - "Hermetic CUDA compute capabilities: %s", - args.cuda_compute_capabilities, - ) - wheel_build_command.append( - f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}" - ) - - if "rocm" in wheel: - wheel_build_command.append("--config=rocm_base") - if args.use_clang: - wheel_build_command.append("--config=rocm") - wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") - if args.rocm_path: - logging.debug("ROCm tookit path: %s", args.rocm_path) - wheel_build_command.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") - if args.rocm_amdgpu_targets: - logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) - wheel_build_command.append( - f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" - ) + logging.debug("Building CUDA with NVCC") + wheel_build_command_base.append("--config=build_cuda_with_nvcc") - # Append additional build options at the end to override any options set in - # .bazelrc or above. - if args.bazel_options: + if args.cuda_version: + logging.debug("Hermetic CUDA version: %s", args.cuda_version) + wheel_build_command_base.append( + f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}" + ) + if args.cudnn_version: + logging.debug("Hermetic cuDNN version: %s", args.cudnn_version) + wheel_build_command_base.append( + f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}" + ) + if args.cuda_compute_capabilities: logging.debug( - "Additional Bazel build options: %s", args.bazel_options + "Hermetic CUDA compute capabilities: %s", + args.cuda_compute_capabilities, + ) + wheel_build_command_base.append( + f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}" + ) + + if "rocm" in args.wheels: + wheel_build_command_base.append("--config=rocm_base") + if args.use_clang: + wheel_build_command_base.append("--config=rocm") + wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + if args.rocm_path: + logging.debug("ROCm tookit path: %s", args.rocm_path) + wheel_build_command_base.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") + if args.rocm_amdgpu_targets: + logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) + wheel_build_command_base.append( + f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" ) - for option in args.bazel_options: - wheel_build_command.append(option) - with open(".jax_configure.bazelrc", "w") as f: - jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list()) - if not jax_configure_options: - logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") + # Append additional build options at the end to override any options set in + # .bazelrc or above. + if args.bazel_options: + logging.debug( + "Additional Bazel build options: %s", args.bazel_options + ) + for option in args.bazel_options: + wheel_build_command_base.append(option) + + with open(".jax_configure.bazelrc", "w") as f: + jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list()) + if not jax_configure_options: + logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") + sys.exit(1) + f.write(jax_configure_options) + logging.info("Bazel options written to .jax_configure.bazelrc") + + if args.configure_only: + logging.info("--configure_only is set so not running any Bazel commands.") + else: + output_path = args.output_path + logger.debug("Artifacts output directory: %s", output_path) + + # Wheel build command execution + for wheel in args.wheels.split(","): + wheel_build_command = copy.deepcopy(wheel_build_command_base) + + # Allow CUDA/ROCm wheels without the "jax-" prefix. + if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: + wheel = "jax-" + wheel + + if wheel not in WHEEL_BUILD_TARGET_DICT.keys(): + logging.error( + "Incorrect wheel name provided, valid choices are jaxlib," + " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," + " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt" + ) sys.exit(1) - f.write(jax_configure_options) - logging.info("Bazel options written to .jax_configure.bazelrc") - if args.configure_only: - logging.info("--configure_only is set so not running any Bazel commands.") - else: + print("\n") + logger.info( + "Building %s for %s %s...", + wheel, + os_name, + arch, + ) + # Append the build target to the Bazel command. build_target = WHEEL_BUILD_TARGET_DICT[wheel] wheel_build_command.append(build_target) wheel_build_command.append("--") - output_path = args.output_path - logger.debug("Artifacts output directory: %s", output_path) - if args.editable: logger.info("Building an editable build") output_path = os.path.join(output_path, wheel) From 1615c23f5f0d6a11a2be0f8e620407f0aadddba5 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 16 Dec 2024 18:59:35 +0000 Subject: [PATCH 135/205] Fix syntax issue --- .github/workflows/pytest_cpu.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 971bcfacdca9..fb986ab0e482 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -18,10 +18,6 @@ concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - jobs: build_jaxlib_artifact: uses: ./.github/workflows/build_artifacts.yml From 16f74bb68117d49fbff736db628f66554a04f8d6 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 16 Dec 2024 19:00:41 +0000 Subject: [PATCH 136/205] remove rbe --- .github/workflows/build_artifacts_single_invocation.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_artifacts_single_invocation.yml b/.github/workflows/build_artifacts_single_invocation.yml index 563312172d30..0610e733a004 100644 --- a/.github/workflows/build_artifacts_single_invocation.yml +++ b/.github/workflows/build_artifacts_single_invocation.yml @@ -18,7 +18,7 @@ on: jobs: build_artifacts: - runs-on: "linux-x86-n2-16" + runs-on: "linux-x86-n2-64" container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" env: @@ -34,4 +34,4 @@ jobs: - name: Mark GitHub repo as safe run: git config --global --add safe.directory "$GITHUB_WORKSPACE" - name: Build all artifacts with a single invocation of build CLI - run: python build/build.py build --wheels=jax-cuda-plugin,jax-cuda-pjrt,jaxlib --bazel_options=--config=rbe_linux_x86_64_cuda --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file + run: python build/build.py build --wheels=jax-cuda-plugin,jax-cuda-pjrt,jaxlib --bazel_options=--config=ci_linux_x86_64_cuda --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file From 8fa2800238fc2af387bfdcff884c02bbc5a5817f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 16 Dec 2024 19:03:00 +0000 Subject: [PATCH 137/205] build jaxlib first --- .github/workflows/build_artifacts_single_invocation.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts_single_invocation.yml b/.github/workflows/build_artifacts_single_invocation.yml index 0610e733a004..1f16e8ca3cd9 100644 --- a/.github/workflows/build_artifacts_single_invocation.yml +++ b/.github/workflows/build_artifacts_single_invocation.yml @@ -34,4 +34,4 @@ jobs: - name: Mark GitHub repo as safe run: git config --global --add safe.directory "$GITHUB_WORKSPACE" - name: Build all artifacts with a single invocation of build CLI - run: python build/build.py build --wheels=jax-cuda-plugin,jax-cuda-pjrt,jaxlib --bazel_options=--config=ci_linux_x86_64_cuda --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file + run: python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --bazel_options=--config=ci_linux_x86_64_cuda --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file From 8d77de1cd057a99d027b2178bae96a2b7f240949 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 16 Dec 2024 19:04:04 +0000 Subject: [PATCH 138/205] update .bazelrc to upstream --- .bazelrc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.bazelrc b/.bazelrc index 8b53bd475e5b..864daf76feed 100644 --- a/.bazelrc +++ b/.bazelrc @@ -96,6 +96,11 @@ build:avx_windows --copt=/arch:AVX build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 +# Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL). +build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true +build:mkl_aarch64_threadpool --@compute_library//:openmp=false +build:mkl_aarch64_threadpool -c opt + # Disable clang extention that rejects type definitions within offsetof. # This was added in clang-16 by https://reviews.llvm.org/D133574. # Can be removed once upb is updated, since a type definition is used within From 17928f592bdfc14035a2fb4fd7fd29e2f6e1d1c7 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 16 Dec 2024 19:30:45 +0000 Subject: [PATCH 139/205] update build.py to upstream --- build/build.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/build/build.py b/build/build.py index 5b1c90024f6f..6df3e2673fbc 100755 --- a/build/build.py +++ b/build/build.py @@ -454,14 +454,14 @@ async def main(): wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") + + # Do not apply --config=clang on Mac as these settings do not apply to + # Apple Clang. + if os_name != "darwin": + wheel_build_command_base.append("--config=clang") else: logging.debug("Use Clang: False") - # Do not apply --config=clang on Mac as these settings do not apply to - # Apple Clang. - if os_name != "darwin": - wheel_build_command_base.append("--config=clang") - if not args.disable_mkl_dnn: logging.debug("Enabling MKL DNN") if target_cpu == "aarch64": @@ -498,9 +498,10 @@ async def main(): if "cuda" in args.wheels: wheel_build_command_base.append("--config=cuda") - wheel_build_command_base.append( - f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" - ) + if args.use_clang: + wheel_build_command_base.append( + f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" + ) if args.build_cuda_with_clang: logging.debug("Building CUDA with Clang") wheel_build_command_base.append("--config=build_cuda_with_clang") From ac02e1af8fd331646cdfc7e503c65c7756beec77 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 17 Dec 2024 04:43:21 +0000 Subject: [PATCH 140/205] remove ci config flag --- .github/workflows/build_artifacts_single_invocation.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts_single_invocation.yml b/.github/workflows/build_artifacts_single_invocation.yml index 1f16e8ca3cd9..7e2c4f41e6e2 100644 --- a/.github/workflows/build_artifacts_single_invocation.yml +++ b/.github/workflows/build_artifacts_single_invocation.yml @@ -34,4 +34,4 @@ jobs: - name: Mark GitHub repo as safe run: git config --global --add safe.directory "$GITHUB_WORKSPACE" - name: Build all artifacts with a single invocation of build CLI - run: python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --bazel_options=--config=ci_linux_x86_64_cuda --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file + run: python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file From b162517b9fa16363f22689909005ae1484d8eece Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 17 Dec 2024 16:54:51 +0000 Subject: [PATCH 141/205] apply --config=clang on all platforms --- build/build.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/build/build.py b/build/build.py index 6df3e2673fbc..3f9ce34ce982 100755 --- a/build/build.py +++ b/build/build.py @@ -454,11 +454,7 @@ async def main(): wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") - - # Do not apply --config=clang on Mac as these settings do not apply to - # Apple Clang. - if os_name != "darwin": - wheel_build_command_base.append("--config=clang") + wheel_build_command_base.append("--config=clang") else: logging.debug("Use Clang: False") From b9ccb7fe17c02b09e4b025fcd0190f50a0d49dcb Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 17 Dec 2024 22:31:28 +0000 Subject: [PATCH 142/205] add a matrix strtegy for building with the ml-build rbe images --- .github/workflows/bazel_gpu_non_rbe.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 5984eb513e56..c13b2b52867d 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -1,9 +1,9 @@ name: Run Bazel GPU tests (non RBE) on: - # pull_request: - # branches: - # - main + pull_request: + branches: + - main workflow_dispatch: inputs: halt-for-connection: @@ -40,7 +40,8 @@ jobs: enable-x_64: [1] #, 0] container-list: [ {name: "nosla", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, - {name: "ml-build", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"} + {name: "ml-build", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"}, + {name: "ml-build-rbe", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-rbe:latest"} ] container: image: ${{ matrix.container-list.image }} From 87cfd2d2c7be0e2ec0fbe589cebd9cf6302bcf87 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 17 Dec 2024 22:40:23 +0000 Subject: [PATCH 143/205] update workflow call --- .github/workflows/bazel_gpu_non_rbe.yml | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index c13b2b52867d..54ac8efd6810 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -21,15 +21,20 @@ concurrency: jobs: build_artifacts: - name: "Build the jaxlib and CUDA plugins using latest XLA" - uses: ./.github/workflows/build_artifacts.yml - with: - wheel_list: "jaxlib,jax-cuda-plugin,jax-cuda-pjrt" - python_list: "3.11" - platform_list: "linux_x86" - clone_main_xla: 1 - upload_artifacts: true - upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-n2-16"] + artifact: ["jaxlib", "jax-cuda-plugin", "jax-cuda-pjrt"] + python: ["3.10",] # "3.11", "3.12", "3.13"] + with: + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + clone_main_xla: 1 + upload_artifacts: true + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' run_bazel_tests: needs: build_artifacts From 501ec7ef6e0e1245c0430f0f9299b1f725498d31 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 17 Dec 2024 23:12:22 +0000 Subject: [PATCH 144/205] update python verison to match workflow call --- .github/workflows/bazel_gpu_non_rbe.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 54ac8efd6810..85bcd1b5f4a2 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -52,7 +52,7 @@ jobs: image: ${{ matrix.container-list.image }} options: --gpus all env: - JAXCI_HERMETIC_PYTHON_VERSION: 3.11 + JAXCI_HERMETIC_PYTHON_VERSION: 3.10 JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} name: "Bazel single accelerator and multi-accelerator GPU tests (Non RBE, linux-x86-g2-48-l4-4gpu, Python 3.11, x64=${{ matrix.enable-x_64 }}, image=${{ matrix.container-list.name }})" From 6123ecde21b4cd5804dcc2250c0c099c4e8171a1 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 17 Dec 2024 23:57:09 +0000 Subject: [PATCH 145/205] update python verison to match workflow call --- .github/workflows/bazel_gpu_non_rbe.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml index 85bcd1b5f4a2..b6bbc333c083 100644 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ b/.github/workflows/bazel_gpu_non_rbe.yml @@ -52,7 +52,7 @@ jobs: image: ${{ matrix.container-list.image }} options: --gpus all env: - JAXCI_HERMETIC_PYTHON_VERSION: 3.10 + JAXCI_HERMETIC_PYTHON_VERSION: "3.10" JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} name: "Bazel single accelerator and multi-accelerator GPU tests (Non RBE, linux-x86-g2-48-l4-4gpu, Python 3.11, x64=${{ matrix.enable-x_64 }}, image=${{ matrix.container-list.name }})" From 2eb10ccce7e52d3cf176cd73aa7106a22d1595a5 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 18 Dec 2024 23:07:12 +0000 Subject: [PATCH 146/205] add new workflows for running wheel tests --- .github/workflows/pytest_cpu_wt.yml | 95 +++++++++++++++++++++++++++++ .github/workflows/pytest_gpu_wt.yml | 78 +++++++++++++++++++++++ .github/workflows/wheel_tests.yml | 85 ++++++++++++++++++++++++++ 3 files changed, 258 insertions(+) create mode 100644 .github/workflows/pytest_cpu_wt.yml create mode 100644 .github/workflows/pytest_gpu_wt.yml create mode 100644 .github/workflows/wheel_tests.yml diff --git a/.github/workflows/pytest_cpu_wt.yml b/.github/workflows/pytest_cpu_wt.yml new file mode 100644 index 000000000000..c6d36d555319 --- /dev/null +++ b/.github/workflows/pytest_cpu_wt.yml @@ -0,0 +1,95 @@ +name: Run Pytest CPU tests + +on: + pull_request: + branches: + - main + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-n2-16" + python: + description: "Which python version should the artifact be built for?" + type: string + required: true + default: "3.12" + download_url_prefix: + description: "GCS location prefix from where the artifacts should be downloaded" + required: false + default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + run-tests: + defaults: + run: + # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. + shell: bash + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + enable-x_64: [1, 0] + + runs-on: ${{ inputs.runner }} + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(inputs.runner, 'windows-x86') && null) }} + + name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ matrix.enable-x_64 }})" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_PYTHON: "python${{ inputs.python }}" + JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} + + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Set PLATFORM env var for use in artifact download URL + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Adjust name for Windows + if [[ $os =~ "msys_nt" ]]; then + os="windows" + fi + + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + - name: Download artifacts built in the "build_artifacts" job (non-Windows runs) + if: ${{ !contains(inputs.runner, 'windows-x86') }} + run: >- + mkdir -p $(pwd)/dist && + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ + - name: Download artifacts built in the "build_artifacts" job (Windows runs) + if: ${{ contains(inputs.runner, 'windows-x86') }} + shell: cmd + run: >- + mkdir dist && + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl dist/ + - name: Install dependencies + run: $JAXCI_PYTHON -m pip install -r build/requirements.in + - name: Run Pytest CPU tests + run: ./ci/run_pytest_cpu.sh \ No newline at end of file diff --git a/.github/workflows/pytest_gpu_wt.yml b/.github/workflows/pytest_gpu_wt.yml new file mode 100644 index 000000000000..aa5c2dd6bd6c --- /dev/null +++ b/.github/workflows/pytest_gpu_wt.yml @@ -0,0 +1,78 @@ +name: Run Pytest GPU tests + +on: + pull_request: + branches: + - main + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-n2-16" + python: + description: "Which python version should the artifact be built for?" + type: string + required: true + default: "3.12" + download_url_prefix: + description: "GCS location prefix from where the artifacts should be downloaded" + required: false + default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + run-tests: + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + test_env: [ + {cuda_version: "12.3", + image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, + {cuda_version: "12.1", + image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, + ] + enable-x_64: [1, 0] + + runs-on: ${{ inputs.runner }} + container: ${{ matrix.test_env.image }} + + name: "Pytest GPU (${{ inputs.runner }}, CUDA ${{ matrix.test_env.cuda_version }}, Python ${{ inputs.python }}, x64=${{ matrix.enable-x_64 }})" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_PYTHON: "python${{ inputs.python }}" + JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Set PLATFORM env var for use in artifact download URL + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + - name: Download artifacts built in the "build_artifacts" job + run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ + - name: Install dependencies + run: $JAXCI_PYTHON -m pip install -r build/requirements.in + - name: Run Pytest GPU tests + run: ./ci/run_pytest_gpu.sh diff --git a/.github/workflows/wheel_tests.yml b/.github/workflows/wheel_tests.yml new file mode 100644 index 000000000000..2b8304fc9a12 --- /dev/null +++ b/.github/workflows/wheel_tests.yml @@ -0,0 +1,85 @@ +name: CI - Wheel Tests + +on: + pull_request: + branches: + - main + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + build_jaxlib_artifact: + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Runner and Python values need to match the matrix stategy in the CPU tests job + # Enable Windows after we have fixed the runner issue + runner: ["linux-x86-n2-16", "linux-arm64-c4a-64",] # "windows-x86-n2-16"] + artifact: ["jaxlib"] + python: ["3.10"] + with: + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + clone_main_xla: 1 + upload_artifacts: true + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + + build_cuda_artifacts: + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Python values need to match the matrix stategy in the GPU tests job below + runner: ["linux-x86-n2-16"] + artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"] + python: ["3.10",] + with: + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + clone_main_xla: 1 + upload_artifacts: true + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + + run_pytest_cpu: + needs: build_jaxlib_artifact + uses: ./.github/workflows/pytest_cpu_wt.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Runner OS and Python values need to match the matrix stategy in the build_jaxlib_artifact job above + runner: ["linux-x86-n2-64", "linux-arm64-c4a-64",] + python: ["3.10",] + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + + + run_pytest_gpu: + needs: [build_jaxlib_artifact, build_cuda_artifacts] + uses: ./.github/workflows/pytest_gpu_wt.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Python values need to match the matrix stategy in the build_jaxlib_artifact job above + runner: ["linux-x86-g2-48-l4-4gpu",] + python: ["3.10",] + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' \ No newline at end of file From a8da0ae49f17022691caef6e4d7b13c5c87e4399 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 18 Dec 2024 23:08:18 +0000 Subject: [PATCH 147/205] fix indent --- .github/workflows/wheel_tests.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/wheel_tests.yml b/.github/workflows/wheel_tests.yml index 2b8304fc9a12..612fd8f03cf2 100644 --- a/.github/workflows/wheel_tests.yml +++ b/.github/workflows/wheel_tests.yml @@ -25,11 +25,11 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - # Runner and Python values need to match the matrix stategy in the CPU tests job - # Enable Windows after we have fixed the runner issue - runner: ["linux-x86-n2-16", "linux-arm64-c4a-64",] # "windows-x86-n2-16"] - artifact: ["jaxlib"] - python: ["3.10"] + # Runner and Python values need to match the matrix stategy in the CPU tests job + # Enable Windows after we have fixed the runner issue + runner: ["linux-x86-n2-16", "linux-arm64-c4a-64",] # "windows-x86-n2-16"] + artifact: ["jaxlib"] + python: ["3.10"] with: runner: ${{ matrix.runner }} artifact: ${{ matrix.artifact }} From bb2a90c014e439e736d2f0df5a32221b9510a1c8 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 18 Dec 2024 23:08:59 +0000 Subject: [PATCH 148/205] fix indent --- .github/workflows/wheel_tests.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/wheel_tests.yml b/.github/workflows/wheel_tests.yml index 612fd8f03cf2..ef1351c51648 100644 --- a/.github/workflows/wheel_tests.yml +++ b/.github/workflows/wheel_tests.yml @@ -31,12 +31,12 @@ jobs: artifact: ["jaxlib"] python: ["3.10"] with: - runner: ${{ matrix.runner }} - artifact: ${{ matrix.artifact }} - python: ${{ matrix.python }} - clone_main_xla: 1 - upload_artifacts: true - upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + clone_main_xla: 1 + upload_artifacts: true + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' build_cuda_artifacts: uses: ./.github/workflows/build_artifacts.yml From fc484e653eabe227cf86b0278c20ddf9436312e0 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 18 Dec 2024 23:10:09 +0000 Subject: [PATCH 149/205] add missing type to iputs --- .github/workflows/pytest_gpu_wt.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pytest_gpu_wt.yml b/.github/workflows/pytest_gpu_wt.yml index aa5c2dd6bd6c..51c8b748ce1e 100644 --- a/.github/workflows/pytest_gpu_wt.yml +++ b/.github/workflows/pytest_gpu_wt.yml @@ -30,6 +30,7 @@ on: description: "GCS location prefix from where the artifacts should be downloaded" required: false default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} From 148a6a0e3e2d1f3bafa1dccf21c90d0c93fd200b Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 19 Dec 2024 00:27:24 +0000 Subject: [PATCH 150/205] move matrix stategy to calling workflow --- .github/workflows/pytest_cpu_wt.yml | 13 ++++++------ .github/workflows/pytest_gpu_wt.yml | 31 +++++++++++++++-------------- .github/workflows/wheel_tests.yml | 6 ++++++ 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/.github/workflows/pytest_cpu_wt.yml b/.github/workflows/pytest_cpu_wt.yml index c6d36d555319..6d7e075f5c79 100644 --- a/.github/workflows/pytest_cpu_wt.yml +++ b/.github/workflows/pytest_cpu_wt.yml @@ -26,6 +26,11 @@ on: type: string required: true default: "3.12" + enable-x64: + description: "Should x64 mode be enabled?" + type: string + required: true + default: "0" download_url_prefix: description: "GCS location prefix from where the artifacts should be downloaded" required: false @@ -42,22 +47,18 @@ jobs: run: # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. shell: bash - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - enable-x_64: [1, 0] runs-on: ${{ inputs.runner }} container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(inputs.runner, 'windows-x86') && null) }} - name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ matrix.enable-x_64 }})" + name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" - JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} + JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} steps: diff --git a/.github/workflows/pytest_gpu_wt.yml b/.github/workflows/pytest_gpu_wt.yml index 51c8b748ce1e..2ddac60acaaf 100644 --- a/.github/workflows/pytest_gpu_wt.yml +++ b/.github/workflows/pytest_gpu_wt.yml @@ -22,13 +22,23 @@ on: required: true default: "linux-x86-n2-16" python: - description: "Which python version should the artifact be built for?" + description: "Which python version to test?" type: string required: true default: "3.12" + cuda: + description: "Which CUDA version to test?" + type: string + required: true + default: "12.3" + enable-x64: + description: "Should x64 mode be enabled?" + type: string + required: true + default: "0" download_url_prefix: description: "GCS location prefix from where the artifacts should be downloaded" - required: false + required: true default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string @@ -38,26 +48,17 @@ concurrency: jobs: run-tests: - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - test_env: [ - {cuda_version: "12.3", - image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, - {cuda_version: "12.1", - image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, - ] - enable-x_64: [1, 0] runs-on: ${{ inputs.runner }} - container: ${{ matrix.test_env.image }} + container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') || + (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') }} - name: "Pytest GPU (${{ inputs.runner }}, CUDA ${{ matrix.test_env.cuda_version }}, Python ${{ inputs.python }}, x64=${{ matrix.enable-x_64 }})" + name: "Pytest GPU (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" - JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} + JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/wheel_tests.yml b/.github/workflows/wheel_tests.yml index ef1351c51648..6f8939961f28 100644 --- a/.github/workflows/wheel_tests.yml +++ b/.github/workflows/wheel_tests.yml @@ -64,9 +64,11 @@ jobs: # Runner OS and Python values need to match the matrix stategy in the build_jaxlib_artifact job above runner: ["linux-x86-n2-64", "linux-arm64-c4a-64",] python: ["3.10",] + enable-x64: [1, 0] with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} + enable-x64: ${{ matrix.enable-x64 }} download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' @@ -79,7 +81,11 @@ jobs: # Python values need to match the matrix stategy in the build_jaxlib_artifact job above runner: ["linux-x86-g2-48-l4-4gpu",] python: ["3.10",] + cuda: ["12.3", "12.1"] + enable-x64: [1, 0] with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} + cuda: ${{ matrix.cuda }} + enable-x64: ${{ matrix.enable-x64 }} download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' \ No newline at end of file From 18947b3cf7a1b1fb4a34514497cae36e8f7692ac Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 19 Dec 2024 02:09:10 +0000 Subject: [PATCH 151/205] remove concurrency --- .github/workflows/pytest_cpu_wt.yml | 4 ---- .github/workflows/pytest_gpu_wt.yml | 4 ---- 2 files changed, 8 deletions(-) diff --git a/.github/workflows/pytest_cpu_wt.yml b/.github/workflows/pytest_cpu_wt.yml index 6d7e075f5c79..534ae165176c 100644 --- a/.github/workflows/pytest_cpu_wt.yml +++ b/.github/workflows/pytest_cpu_wt.yml @@ -37,10 +37,6 @@ on: default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - jobs: run-tests: defaults: diff --git a/.github/workflows/pytest_gpu_wt.yml b/.github/workflows/pytest_gpu_wt.yml index 2ddac60acaaf..c322ada351af 100644 --- a/.github/workflows/pytest_gpu_wt.yml +++ b/.github/workflows/pytest_gpu_wt.yml @@ -42,10 +42,6 @@ on: default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - jobs: run-tests: From 3c1eda75211c854fadf11a5a744f208fc75f3873 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 19 Dec 2024 13:32:21 +0000 Subject: [PATCH 152/205] add bazel gpu to wheel tests wokrflow --- .github/workflows/bazel_cuda_non_rbe_wt.yml | 68 +++++++++++++++++++ .github/workflows/pytest_cpu_wt.yml | 10 +-- .../{pytest_gpu_wt.yml => pytest_cuda_wt.yml} | 4 +- .github/workflows/wheel_tests.yml | 28 ++++++-- 4 files changed, 97 insertions(+), 13 deletions(-) create mode 100644 .github/workflows/bazel_cuda_non_rbe_wt.yml rename .github/workflows/{pytest_gpu_wt.yml => pytest_cuda_wt.yml} (96%) diff --git a/.github/workflows/bazel_cuda_non_rbe_wt.yml b/.github/workflows/bazel_cuda_non_rbe_wt.yml new file mode 100644 index 000000000000..d680e151e2aa --- /dev/null +++ b/.github/workflows/bazel_cuda_non_rbe_wt.yml @@ -0,0 +1,68 @@ +name: CI - Bazel CUDA (non-RBE) + +on: + pull_request: + branches: + - main + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-n2-16" + python: + description: "Which python version to test?" + type: string + required: true + default: "3.12" + enable-x64: + description: "Should x64 mode be enabled?" + type: string + required: true + default: "0" + download_url_prefix: + description: "GCS location prefix from where the artifacts should be downloaded" + required: true + default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + + +jobs: + run-tests: + runs-on: ${{ inputs.runner }} + + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-rbe:latest" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}" + + name: "Bazel single accelerator and multi-accelerator GPU tests (Non RBE, ${{ inputs.runner }}, Python 3.11, x64=${{ inputs.enable-x64 }})" + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Set PLATFORM env var for use in artifact download URL + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + - name: Download the wheel artifacts from GCS + run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ + - name: Run Bazel tests + run: ./ci/run_bazel_test_gpu_non_rbe.sh diff --git a/.github/workflows/pytest_cpu_wt.yml b/.github/workflows/pytest_cpu_wt.yml index 534ae165176c..c59b50f8e9c1 100644 --- a/.github/workflows/pytest_cpu_wt.yml +++ b/.github/workflows/pytest_cpu_wt.yml @@ -1,4 +1,4 @@ -name: Run Pytest CPU tests +name: CI - Pytest CPU on: pull_request: @@ -75,17 +75,17 @@ jobs: fi echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Download artifacts built in the "build_artifacts" job (non-Windows runs) + - name: Download the jaxlib wheel from GCS (non-Windows runs) if: ${{ !contains(inputs.runner, 'windows-x86') }} run: >- mkdir -p $(pwd)/dist && - gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - - name: Download artifacts built in the "build_artifacts" job (Windows runs) + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/jaxlib*.whl $(pwd)/dist/ + - name: Download the jaxlib wheel from GCS (Windows runs) if: ${{ contains(inputs.runner, 'windows-x86') }} shell: cmd run: >- mkdir dist && - gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl dist/ + gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/jaxlib*.whl dist/ - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in - name: Run Pytest CPU tests diff --git a/.github/workflows/pytest_gpu_wt.yml b/.github/workflows/pytest_cuda_wt.yml similarity index 96% rename from .github/workflows/pytest_gpu_wt.yml rename to .github/workflows/pytest_cuda_wt.yml index c322ada351af..385b1c653d39 100644 --- a/.github/workflows/pytest_gpu_wt.yml +++ b/.github/workflows/pytest_cuda_wt.yml @@ -1,4 +1,4 @@ -name: Run Pytest GPU tests +name: CI - Pytest CUDA on: pull_request: @@ -68,7 +68,7 @@ jobs: os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Download artifacts built in the "build_artifacts" job + - name: Download the wheel artifacts from GCS run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in diff --git a/.github/workflows/wheel_tests.yml b/.github/workflows/wheel_tests.yml index 6f8939961f28..0d6bce1010a8 100644 --- a/.github/workflows/wheel_tests.yml +++ b/.github/workflows/wheel_tests.yml @@ -36,7 +36,7 @@ jobs: python: ${{ matrix.python }} clone_main_xla: 1 upload_artifacts: true - upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}' build_cuda_artifacts: uses: ./.github/workflows/build_artifacts.yml @@ -53,7 +53,7 @@ jobs: python: ${{ matrix.python }} clone_main_xla: 1 upload_artifacts: true - upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}' run_pytest_cpu: needs: build_jaxlib_artifact @@ -69,16 +69,16 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}' run_pytest_gpu: needs: [build_jaxlib_artifact, build_cuda_artifacts] - uses: ./.github/workflows/pytest_gpu_wt.yml + uses: ./.github/workflows/pytest_cuda_wt.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: - # Python values need to match the matrix stategy in the build_jaxlib_artifact job above + # Python values need to match the matrix stategy in the build artifacts job above runner: ["linux-x86-g2-48-l4-4gpu",] python: ["3.10",] cuda: ["12.3", "12.1"] @@ -88,4 +88,20 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} - download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' \ No newline at end of file + download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}' + + run_bazel_test_gpu: + needs: [build_jaxlib_artifact, build_cuda_artifacts] + uses: ./.github/workflows/bazel_cuda_non_rbe_wt.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Python values need to match the matrix stategy in the build artifacts job above + runner: ["linux-x86-g2-48-l4-4gpu",] + python: ["3.10",] + enable-x64: [1, 0] + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + enable-x64: ${{ matrix.enable-x64 }} + download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}' \ No newline at end of file From 1c834657f50c3eb5670798ee7c78ae0ef6f40e59 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 19 Dec 2024 16:54:12 +0000 Subject: [PATCH 153/205] switch to t2a to debug crashes --- .github/workflows/pytest_cpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index fb986ab0e482..fd4f884c18dc 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -44,7 +44,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - runner: ["linux-x86-n2-64", "linux-arm64-c4a-64", "windows-x86-n2-64"] + runner: ["linux-x86-n2-64", "linux-x86-t2a-48-dev", "windows-x86-n2-64"] python: ["3.10",] # "3.11", "3.12", "3.13"] enable-x_64: [1, 0] From db08947782a90a5bf5cf556db9d2d4855ad9bbc6 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 19 Dec 2024 17:00:31 +0000 Subject: [PATCH 154/205] disable windows --- .github/workflows/pytest_cpu.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index fd4f884c18dc..70a084912712 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -24,7 +24,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - runner: ["linux-x86-n2-16", "linux-arm64-c4a-64", "windows-x86-n2-16"] + runner: ["linux-x86-n2-16", "linux-arm64-c4a-64",] # "windows-x86-n2-16"] artifact: ["jaxlib"] python: ["3.10",] # "3.11", "3.12", "3.13"] with: @@ -44,7 +44,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - runner: ["linux-x86-n2-64", "linux-x86-t2a-48-dev", "windows-x86-n2-64"] + runner: ["linux-x86-n2-64", "linux-x86-t2a-48-dev",] # "windows-x86-n2-64"] python: ["3.10",] # "3.11", "3.12", "3.13"] enable-x_64: [1, 0] From a6fdb206fbe58ee4b46d75b5838767b14694db91 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 19 Dec 2024 18:18:54 +0000 Subject: [PATCH 155/205] update string match --- .github/workflows/pytest_cpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 70a084912712..8586a2c0392d 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -50,7 +50,7 @@ jobs: runs-on: ${{ matrix.runner }} container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(matrix.runner, 't2a') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(matrix.runner, 'windows-x86') && null) }} name: "Pytest CPU (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" From c0c7ec8ab2711e706ac2626d6e8ddd4eb6832c03 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 19 Dec 2024 21:20:14 +0000 Subject: [PATCH 156/205] update string match --- .github/workflows/pytest_cpu.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 8586a2c0392d..35687d39cceb 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -49,8 +49,8 @@ jobs: enable-x_64: [1, 0] runs-on: ${{ matrix.runner }} - container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 't2a') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + container: ${{ (contains(matrix.runner, 'linux-x86-n2-64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(matrix.runner, 'linux-x86-t2a-48-dev') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(matrix.runner, 'windows-x86') && null) }} name: "Pytest CPU (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" From 9232ee358d89945e9e16f1e3747861e89b8d575f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 19 Dec 2024 21:20:43 +0000 Subject: [PATCH 157/205] update build.py to upstream --- build/build.py | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/build/build.py b/build/build.py index 35b56eadbb3e..3d9bca3fa5c0 100755 --- a/build/build.py +++ b/build/build.py @@ -277,6 +277,15 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): """, ) + compile_group.add_argument( + "--gcc_path", + type=str, + default="", + help=""" + Path to the GCC binary to use. + """, + ) + compile_group.add_argument( "--disable_mkl_dnn", action="store_true", @@ -454,12 +463,19 @@ async def main(): wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") + if clang_major_version >= 16: # Enable clang settings that are needed for the build to work with newer # versions of Clang. wheel_build_command_base.append("--config=clang") else: - logging.debug("Use Clang: False") + gcc_path = args.gcc_path or utils.get_gcc_path_or_exit() + logging.debug( + "Using GCC as the compiler, gcc path: %s", + gcc_path, + ) + wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"") + wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"") if not args.disable_mkl_dnn: logging.debug("Enabling MKL DNN") @@ -501,9 +517,12 @@ async def main(): wheel_build_command_base.append( f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" ) - if args.build_cuda_with_clang: - logging.debug("Building CUDA with Clang") - wheel_build_command_base.append("--config=build_cuda_with_clang") + if args.build_cuda_with_clang: + logging.debug("Building CUDA with Clang") + wheel_build_command_base.append("--config=build_cuda_with_clang") + else: + logging.debug("Building CUDA with NVCC") + wheel_build_command_base.append("--config=build_cuda_with_nvcc") else: logging.debug("Building CUDA with NVCC") wheel_build_command_base.append("--config=build_cuda_with_nvcc") @@ -527,20 +546,6 @@ async def main(): f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}" ) - if "rocm" in args.wheels: - wheel_build_command_base.append("--config=rocm_base") - if args.use_clang: - wheel_build_command_base.append("--config=rocm") - wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") - if args.rocm_path: - logging.debug("ROCm tookit path: %s", args.rocm_path) - wheel_build_command_base.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") - if args.rocm_amdgpu_targets: - logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) - wheel_build_command_base.append( - f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" - ) - # Append additional build options at the end to override any options set in # .bazelrc or above. if args.bazel_options: From 08e708d29f599a6e8dd24a7cadf7b7349f92dbe8 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 19 Dec 2024 21:24:40 +0000 Subject: [PATCH 158/205] update utils.py to upstream --- build/tools/utils.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/build/tools/utils.py b/build/tools/utils.py index 1a4b26fd3d4a..16450ab9eae5 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -170,19 +170,25 @@ def get_bazel_version(bazel_path): return None return tuple(int(x) for x in match.group(1).split(".")) -def get_clang_path_or_exit(): - which_clang_output = shutil.which("clang") - if which_clang_output: - # If we've found a clang on the path, need to get the fully resolved path +def get_compiler_path_or_exit(compiler_path_flag, compiler_name): + which_compiler_output = shutil.which(compiler_name) + if which_compiler_output: + # If we've found a compiler on the path, need to get the fully resolved path # to ensure that system headers are found. - return str(pathlib.Path(which_clang_output).resolve()) + return str(pathlib.Path(which_compiler_output).resolve()) else: print( - "--clang_path is unset and clang cannot be found" - " on the PATH. Please pass --clang_path directly." + f"--{compiler_path_flag} is unset and {compiler_name} cannot be found" + " on the PATH. Please pass --{compiler_path_flag} directly." ) sys.exit(-1) +def get_gcc_path_or_exit(): + return get_compiler_path_or_exit("gcc_path", "gcc") + +def get_clang_path_or_exit(): + return get_compiler_path_or_exit("clang_path", "clang") + def get_clang_major_version(clang_path): clang_version_proc = subprocess.run( [clang_path, "-E", "-P", "-"], From 71699a2d63accdcd5c87da830c0f80f077690766 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 19 Dec 2024 21:57:14 +0000 Subject: [PATCH 159/205] update to upstream --- build/build.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/build/build.py b/build/build.py index 3d9bca3fa5c0..0b640dd123e4 100755 --- a/build/build.py +++ b/build/build.py @@ -515,8 +515,8 @@ async def main(): wheel_build_command_base.append("--config=cuda") if args.use_clang: wheel_build_command_base.append( - f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" - ) + f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" + ) if args.build_cuda_with_clang: logging.debug("Building CUDA with Clang") wheel_build_command_base.append("--config=build_cuda_with_clang") @@ -546,6 +546,20 @@ async def main(): f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}" ) + if "rocm" in args.wheels: + wheel_build_command_base.append("--config=rocm_base") + if args.use_clang: + wheel_build_command_base.append("--config=rocm") + wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + if args.rocm_path: + logging.debug("ROCm tookit path: %s", args.rocm_path) + wheel_build_command_base.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") + if args.rocm_amdgpu_targets: + logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) + wheel_build_command_base.append( + f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" + ) + # Append additional build options at the end to override any options set in # .bazelrc or above. if args.bazel_options: @@ -571,8 +585,6 @@ async def main(): # Wheel build command execution for wheel in args.wheels.split(","): - wheel_build_command = copy.deepcopy(wheel_build_command_base) - # Allow CUDA/ROCm wheels without the "jax-" prefix. if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: wheel = "jax-" + wheel @@ -585,6 +597,7 @@ async def main(): ) sys.exit(1) + wheel_build_command = copy.deepcopy(wheel_build_command_base) print("\n") logger.info( "Building %s for %s %s...", From c7df7a8295ecf9d788484df1545c3a13c05da3ff Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 20 Dec 2024 22:18:34 +0000 Subject: [PATCH 160/205] switch to t2a machines, delete old workflows --- .github/workflows/bazel_gpu_non_rbe.yml | 81 ---------------- .github/workflows/build_artifacts.yml | 2 +- .github/workflows/pytest_cpu.yml | 97 ------------------- .github/workflows/pytest_cpu_matrix_debug.yml | 91 ----------------- .github/workflows/pytest_cpu_wt.yml | 6 +- .github/workflows/pytest_cuda_wt.yml | 2 +- .github/workflows/pytest_gpu.yml | 89 ----------------- .github/workflows/wheel_tests.yml | 12 +-- 8 files changed, 11 insertions(+), 369 deletions(-) delete mode 100644 .github/workflows/bazel_gpu_non_rbe.yml delete mode 100644 .github/workflows/pytest_cpu.yml delete mode 100644 .github/workflows/pytest_cpu_matrix_debug.yml delete mode 100644 .github/workflows/pytest_gpu.yml diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml deleted file mode 100644 index b6bbc333c083..000000000000 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ /dev/null @@ -1,81 +0,0 @@ -name: Run Bazel GPU tests (non RBE) - -on: - pull_request: - branches: - - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - - 'no' - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -jobs: - build_artifacts: - uses: ./.github/workflows/build_artifacts.yml - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - runner: ["linux-x86-n2-16"] - artifact: ["jaxlib", "jax-cuda-plugin", "jax-cuda-pjrt"] - python: ["3.10",] # "3.11", "3.12", "3.13"] - with: - runner: ${{ matrix.runner }} - artifact: ${{ matrix.artifact }} - python: ${{ matrix.python }} - clone_main_xla: 1 - upload_artifacts: true - upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - - run_bazel_tests: - needs: build_artifacts - runs-on: "linux-x86-g2-48-l4-4gpu" - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - enable-x_64: [1] #, 0] - container-list: [ - {name: "nosla", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, - {name: "ml-build", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"}, - {name: "ml-build-rbe", image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-rbe:latest"} - ] - container: - image: ${{ matrix.container-list.image }} - options: --gpus all - env: - JAXCI_HERMETIC_PYTHON_VERSION: "3.10" - JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} - - name: "Bazel single accelerator and multi-accelerator GPU tests (Non RBE, linux-x86-g2-48-l4-4gpu, Python 3.11, x64=${{ matrix.enable-x_64 }}, image=${{ matrix.container-list.name }})" - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Set Platform - run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) - - # Adjust name for Windows - if [[ $os =~ "msys_nt" ]]; then - os="windows" - fi - - echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Download the artifacts built in the "build_artifacts" job - run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - - name: Run Bazel GPU tests locally - run: ./ci/run_bazel_test_gpu_non_rbe.sh diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index e1bcf3e02960..20ada8d19681 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -61,7 +61,7 @@ on: upload_destination_prefix: description: "GCS location prefix to where the artifacts should be uploaded" required: false - default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + default: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string jobs: diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml deleted file mode 100644 index 35687d39cceb..000000000000 --- a/.github/workflows/pytest_cpu.yml +++ /dev/null @@ -1,97 +0,0 @@ -name: Run Pytest CPU tests - -on: - pull_request: - branches: - - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -jobs: - build_jaxlib_artifact: - uses: ./.github/workflows/build_artifacts.yml - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - runner: ["linux-x86-n2-16", "linux-arm64-c4a-64",] # "windows-x86-n2-16"] - artifact: ["jaxlib"] - python: ["3.10",] # "3.11", "3.12", "3.13"] - with: - runner: ${{ matrix.runner }} - artifact: ${{ matrix.artifact }} - python: ${{ matrix.python }} - clone_main_xla: 1 - upload_artifacts: true - upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - - run_pytest: - needs: build_jaxlib_artifact - defaults: - run: - # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. - shell: bash - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - runner: ["linux-x86-n2-64", "linux-x86-t2a-48-dev",] # "windows-x86-n2-64"] - python: ["3.10",] # "3.11", "3.12", "3.13"] - enable-x_64: [1, 0] - - runs-on: ${{ matrix.runner }} - container: ${{ (contains(matrix.runner, 'linux-x86-n2-64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-x86-t2a-48-dev') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || - (contains(matrix.runner, 'windows-x86') && null) }} - - name: "Pytest CPU (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" - - env: - JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} - JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} - - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Set Platform - run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) - - # Adjust name for Windows - if [[ $os =~ "msys_nt" ]]; then - os="windows" - fi - - echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Download artifacts built in the "build_artifacts" job (non-Windows runs) - if: ${{ !contains(matrix.runner, 'windows-x86') }} - run: >- - mkdir -p $(pwd)/dist && - gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - - name: Download artifacts built in the "build_artifacts" job (Windows runs) - if: ${{ contains(matrix.runner, 'windows-x86') }} - shell: cmd - run: >- - mkdir dist && - gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl dist/ - - name: Install dependencies - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install -r build/requirements.in - - name: Run Pytest CPU tests - run: ./ci/run_pytest_cpu.sh \ No newline at end of file diff --git a/.github/workflows/pytest_cpu_matrix_debug.yml b/.github/workflows/pytest_cpu_matrix_debug.yml deleted file mode 100644 index 0cbced40549e..000000000000 --- a/.github/workflows/pytest_cpu_matrix_debug.yml +++ /dev/null @@ -1,91 +0,0 @@ -name: Run Pytest CPU tests (matrix debug) - -on: - pull_request: - branches: - - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -jobs: - build_jaxlib_artifact: - name: "Build the jaxlib aritfact using latest XLA" - uses: ./.github/workflows/build_artifacts_matrix_debug.yml - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - runner: ["linux-x86-n2-16",] # "linux-arm64-c4a-64", "windows-x86-n2-16"] - artifact: ["jaxlib"] - python: ["3.10",] # "3.11", "3.12", "3.13"] - with: - runner: ${{ matrix.runner }} - artifact: ${{ matrix.artifact }} - python: ${{ matrix.python }} - clone_main_xla: 1 - upload_artifacts: true - upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - - run_pytest: - name: "Run CPU tests with Pytest" - needs: build_jaxlib_artifact - defaults: - run: - # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. - shell: bash - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - runner: ["linux-x86-n2-64",] # "linux-arm64-c4a-64", "windows-x86-n2-64"] - python: ["3.10",] # "3.11", "3.12", "3.13"] - - runs-on: ${{ matrix.runner }} - container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || - (contains(matrix.runner, 'windows-x86') && null) }} - - env: - JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Set Platform - run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) - - # Adjust name for Windows - if [[ $os =~ "msys_nt" ]]; then - os="windows" - fi - - echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Download the artifacts built in the "build_artifacts" job - run: >- - mkdir -p $(pwd)/dist && - gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - - name: Install pytest - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install pytest - - name: Install dependencies - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install -r build/requirements.in - - name: Run Pytest CPU tests - run: ./ci/run_pytest_cpu.sh diff --git a/.github/workflows/pytest_cpu_wt.yml b/.github/workflows/pytest_cpu_wt.yml index c59b50f8e9c1..5088d885048f 100644 --- a/.github/workflows/pytest_cpu_wt.yml +++ b/.github/workflows/pytest_cpu_wt.yml @@ -34,7 +34,7 @@ on: download_url_prefix: description: "GCS location prefix from where the artifacts should be downloaded" required: false - default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + default: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string jobs: @@ -45,8 +45,8 @@ jobs: shell: bash runs-on: ${{ inputs.runner }} - container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + container: ${{ (contains(inputs.runner, 'linux-x86-n2') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(inputs.runner, 'linux-x86-t2a') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(inputs.runner, 'windows-x86') && null) }} name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" diff --git a/.github/workflows/pytest_cuda_wt.yml b/.github/workflows/pytest_cuda_wt.yml index 385b1c653d39..334b0b727c79 100644 --- a/.github/workflows/pytest_cuda_wt.yml +++ b/.github/workflows/pytest_cuda_wt.yml @@ -39,7 +39,7 @@ on: download_url_prefix: description: "GCS location prefix from where the artifacts should be downloaded" required: true - default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + default: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string jobs: diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml deleted file mode 100644 index 7a37fd20cccf..000000000000 --- a/.github/workflows/pytest_gpu.yml +++ /dev/null @@ -1,89 +0,0 @@ -name: Run Pytest GPU tests - -on: - pull_request: - branches: - - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - - 'no' - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -jobs: - build_artifacts: - uses: ./.github/workflows/build_artifacts.yml - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - runner: ["linux-x86-n2-16"] - artifact: ["jaxlib", "jax-cuda-plugin", "jax-cuda-pjrt"] - python: ["3.10",] # "3.11", "3.12", "3.13"] - with: - runner: ${{ matrix.runner }} - artifact: ${{ matrix.artifact }} - python: ${{ matrix.python }} - clone_main_xla: 1 - upload_artifacts: true - upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - - run_tests: - needs: build_artifacts - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - test_env: [ - {cuda_version: "12.3", runner: "linux-x86-g2-48-l4-4gpu", - image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, - {cuda_version: "12.1", runner: "linux-x86-g2-48-l4-4gpu", - image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, - ] - python: ["3.10"] - enable-x_64: [1, 0] - - runs-on: ${{ matrix.test_env.runner }} - container: - image: ${{ matrix.test_env.image }} - - name: "Pytest GPU (${{ matrix.test_env.runner }}, CUDA ${{ matrix.test_env.cuda_version }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" - - env: - JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} - JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Set Platform - run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) - - # Adjust name for Windows - if [[ $os =~ "msys_nt" ]]; then - os="windows" - - fi - - echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - - name: Download artifacts built in the "build_artifacts" job - run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ - - name: Install dependencies - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install -r build/requirements.in - - name: Run Pytest GPU tests - run: ./ci/run_pytest_gpu.sh diff --git a/.github/workflows/wheel_tests.yml b/.github/workflows/wheel_tests.yml index 0d6bce1010a8..00d59194acdf 100644 --- a/.github/workflows/wheel_tests.yml +++ b/.github/workflows/wheel_tests.yml @@ -27,7 +27,7 @@ jobs: matrix: # Runner and Python values need to match the matrix stategy in the CPU tests job # Enable Windows after we have fixed the runner issue - runner: ["linux-x86-n2-16", "linux-arm64-c4a-64",] # "windows-x86-n2-16"] + runner: ["linux-x86-n2-16", "linux-x86-t2a-48-dev",] # "windows-x86-n2-16"] artifact: ["jaxlib"] python: ["3.10"] with: @@ -36,7 +36,7 @@ jobs: python: ${{ matrix.python }} clone_main_xla: 1 upload_artifacts: true - upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}' + upload_destination_prefix: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}' build_cuda_artifacts: uses: ./.github/workflows/build_artifacts.yml @@ -53,7 +53,7 @@ jobs: python: ${{ matrix.python }} clone_main_xla: 1 upload_artifacts: true - upload_destination_prefix: '${{ github.workflow }}/${{ github.run_number }}' + upload_destination_prefix: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}' run_pytest_cpu: needs: build_jaxlib_artifact @@ -69,7 +69,7 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}' + download_url_prefix: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}' run_pytest_gpu: @@ -88,7 +88,7 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} - download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}' + download_url_prefix: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}' run_bazel_test_gpu: needs: [build_jaxlib_artifact, build_cuda_artifacts] @@ -104,4 +104,4 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - download_url_prefix: '${{ github.workflow }}/${{ github.run_number }}' \ No newline at end of file + download_url_prefix: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}' \ No newline at end of file From 6d23c1a991ff074b2a802e968b68a3b5d92e8178 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 20 Dec 2024 22:56:38 +0000 Subject: [PATCH 161/205] use the t2a machine for running tests --- .github/workflows/wheel_tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/wheel_tests.yml b/.github/workflows/wheel_tests.yml index 00d59194acdf..dc17869bfb49 100644 --- a/.github/workflows/wheel_tests.yml +++ b/.github/workflows/wheel_tests.yml @@ -27,7 +27,7 @@ jobs: matrix: # Runner and Python values need to match the matrix stategy in the CPU tests job # Enable Windows after we have fixed the runner issue - runner: ["linux-x86-n2-16", "linux-x86-t2a-48-dev",] # "windows-x86-n2-16"] + runner: ["linux-x86-n2-16", "linux-arm64-c4a-64",] # "windows-x86-n2-16"] artifact: ["jaxlib"] python: ["3.10"] with: @@ -62,7 +62,7 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: # Runner OS and Python values need to match the matrix stategy in the build_jaxlib_artifact job above - runner: ["linux-x86-n2-64", "linux-arm64-c4a-64",] + runner: ["linux-x86-n2-64", "linux-x86-t2a-48-dev",] python: ["3.10",] enable-x64: [1, 0] with: From ae0d99b92b62ced3080894936c94e78b39d0a2eb Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Sat, 21 Dec 2024 01:21:50 +0000 Subject: [PATCH 162/205] move output_path to be inside for loop --- build/build.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/build/build.py b/build/build.py index 8fce10fb8437..c0434fa80238 100755 --- a/build/build.py +++ b/build/build.py @@ -580,11 +580,11 @@ async def main(): if args.configure_only: logging.info("--configure_only is set so not running any Bazel commands.") else: - output_path = args.output_path - logger.debug("Artifacts output directory: %s", output_path) - # Wheel build command execution for wheel in args.wheels.split(","): + output_path = args.output_path + logger.debug("Artifacts output directory: %s", output_path) + # Allow CUDA/ROCm wheels without the "jax-" prefix. if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: wheel = "jax-" + wheel From 7da8aca769f48ae5d93337dd01204491c4b2a4af Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Sat, 21 Dec 2024 01:22:31 +0000 Subject: [PATCH 163/205] change to building an editable wheel --- .github/workflows/build_artifacts_single_invocation.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts_single_invocation.yml b/.github/workflows/build_artifacts_single_invocation.yml index 7e2c4f41e6e2..e62ef120f366 100644 --- a/.github/workflows/build_artifacts_single_invocation.yml +++ b/.github/workflows/build_artifacts_single_invocation.yml @@ -34,4 +34,4 @@ jobs: - name: Mark GitHub repo as safe run: git config --global --add safe.directory "$GITHUB_WORKSPACE" - name: Build all artifacts with a single invocation of build CLI - run: python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file + run: python build/build.py build --editable --use_clang --output_path=$(pwd)/dist --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose \ No newline at end of file From 740a92c38efc09276609ae337d7587ac3e1ffdfb Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Sat, 21 Dec 2024 01:23:15 +0000 Subject: [PATCH 164/205] sync to upstream --- .github/workflows/cloud-tpu-ci-nightly.yml | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 31b8cf6c2a3d..8f1a01b3bf4a 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -12,15 +12,6 @@ name: CI - Cloud TPU (nightly) on: -<<<<<<< HEAD -# schedule: -# - cron: "0 */2 * * *" # Run every 2 hours -# workflow_dispatch: # allows triggering the workflow run manually -# # This should also be set to read-only in the project settings, but it's nice to -# # document and enforce the permissions here. -# permissions: -# contents: read -======= schedule: - cron: "0 2,14 * * *" # Run at 7am and 7pm PST workflow_dispatch: # allows triggering the workflow run manually @@ -28,7 +19,6 @@ on: # document and enforce the permissions here. permissions: contents: read ->>>>>>> 402814b10b8f0df268910e1628b72cbe252aa045 jobs: cloud-tpu-test: strategy: From 47c7afab9be998c209423031ce36f7bef7a3b181 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Sat, 21 Dec 2024 01:24:51 +0000 Subject: [PATCH 165/205] create a pytest cpu workflow to debug arm64 crashes --- .github/workflows/pytest_cpu_crash_debug.yml | 57 ++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 .github/workflows/pytest_cpu_crash_debug.yml diff --git a/.github/workflows/pytest_cpu_crash_debug.yml b/.github/workflows/pytest_cpu_crash_debug.yml new file mode 100644 index 000000000000..6a05131bdc0b --- /dev/null +++ b/.github/workflows/pytest_cpu_crash_debug.yml @@ -0,0 +1,57 @@ +name: CI - Pytest CPU + +on: + pull_request: + branches: + - main + +jobs: + run-tests: + defaults: + run: + # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. + shell: bash + + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-t2a-48-dev",] + python: ["3.10",] + enable-x64: [1, 0] + + runs-on: ${{ matrix.runner }} + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest" + + name: "Pytest CPU (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x64 }})" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" + JAXCI_PYTHON: "python${{ matrix.python }}" + JAXCI_ENABLE_X64: ${{ matrix.enable-x64 }} + + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + # Checkout XLA at head, if we're building jaxlib at head. + - name: Checkout XLA at head + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: openxla/xla + path: xla + # We need to mark the GitHub workspace as safe as otherwise git commands will fail. + - name: Mark GitHub workspace as safe + run: | + git config --global --add safe.directory "$GITHUB_WORKSPACE" + - name: Set PLATFORM env var for use in artifact download URL + run: | + $JAXCI_PYTHON build/build.py build --wheels=jaxlib \ + --bazel_options=--config=rbe_linux_x86_64 \ + --local_xla_path="$(pwd)/xla" \ + --verbose + - name: Install dependencies + run: $JAXCI_PYTHON -m pip install -r build/requirements.in + - name: Run Pytest CPU tests + run: ./ci/run_pytest_cpu.sh \ No newline at end of file From be0e54cabf14bce7842b46a5bbcd147ba427b77e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Sat, 21 Dec 2024 01:27:55 +0000 Subject: [PATCH 166/205] Change name and build config --- .github/workflows/pytest_cpu_crash_debug.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pytest_cpu_crash_debug.yml b/.github/workflows/pytest_cpu_crash_debug.yml index 6a05131bdc0b..43e5dcf065ec 100644 --- a/.github/workflows/pytest_cpu_crash_debug.yml +++ b/.github/workflows/pytest_cpu_crash_debug.yml @@ -1,4 +1,4 @@ -name: CI - Pytest CPU +name: CI - Pytest CPU (crash debug) on: pull_request: @@ -45,10 +45,10 @@ jobs: - name: Mark GitHub workspace as safe run: | git config --global --add safe.directory "$GITHUB_WORKSPACE" - - name: Set PLATFORM env var for use in artifact download URL + - name: Build jaxlib run: | $JAXCI_PYTHON build/build.py build --wheels=jaxlib \ - --bazel_options=--config=rbe_linux_x86_64 \ + --bazel_options=--config=ci_linux_aarch64 \ --local_xla_path="$(pwd)/xla" \ --verbose - name: Install dependencies From 903910a151f1c64199bc893cc261580845115f8c Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 23 Dec 2024 17:15:17 +0000 Subject: [PATCH 167/205] upload artifacts to a single location in gcs to match internal ci --- .github/workflows/bazel_cuda_non_rbe_wt.yml | 8 +++--- .github/workflows/build_artifacts.yml | 21 ++++------------ .github/workflows/pytest_cpu_wt.yml | 27 +++++++++++++-------- .github/workflows/pytest_cuda_wt.yml | 21 ++++++++++------ .github/workflows/wheel_tests.yml | 10 ++++---- 5 files changed, 45 insertions(+), 42 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe_wt.yml b/.github/workflows/bazel_cuda_non_rbe_wt.yml index d680e151e2aa..55b770307613 100644 --- a/.github/workflows/bazel_cuda_non_rbe_wt.yml +++ b/.github/workflows/bazel_cuda_non_rbe_wt.yml @@ -31,10 +31,10 @@ on: type: string required: true default: "0" - download_url_prefix: - description: "GCS location prefix from where the artifacts should be downloaded" + gcs_download_uri: + description: "GCS location URI from where the artifacts should be downloaded" required: true - default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' type: string @@ -63,6 +63,6 @@ jobs: arch=$(uname -m) echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Download the wheel artifacts from GCS - run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ + run: mkdir -p $(pwd)/dist && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ - name: Run Bazel tests run: ./ci/run_bazel_test_gpu_non_rbe.sh diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 20ada8d19681..d929ee870f9a 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -58,10 +58,10 @@ on: required: false default: false type: boolean - upload_destination_prefix: - description: "GCS location prefix to where the artifacts should be uploaded" + gcs_upload_uri: + description: "GCS location URI to where the artifacts should be uploaded" required: false - default: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + default: ' gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' type: string jobs: @@ -95,23 +95,12 @@ jobs: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Build ${{ inputs.artifact }} run: ./ci/build_artifacts.sh "${{ inputs.artifact }}" - - name: Set PLATFORM env var for use in upload destination - run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) - - # Adjust name for Windows - if [[ $os =~ "msys_nt" ]]; then - os="windows" - fi - - echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV - name: Upload artifacts to GCS bucket (non-Windows) if: >- ${{ inputs.upload_artifacts && !contains(inputs.runner, 'windows-x86') }} - run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ + run: gsutil -m cp -r $(pwd)/dist/*.whl "${{ inputs.gcs_upload_uri }}"/ - name: Upload artifacts to GCS bucket (Windows) if: >- ${{ inputs.upload_artifacts && contains(inputs.runner, 'windows-x86') }} shell: cmd - run: gsutil -m cp -r dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ \ No newline at end of file + run: gsutil -m cp -r dist/*.whl "${{ inputs.gcs_upload_uri }}"/ \ No newline at end of file diff --git a/.github/workflows/pytest_cpu_wt.yml b/.github/workflows/pytest_cpu_wt.yml index 5088d885048f..c0c08d33b1d4 100644 --- a/.github/workflows/pytest_cpu_wt.yml +++ b/.github/workflows/pytest_cpu_wt.yml @@ -31,10 +31,10 @@ on: type: string required: true default: "0" - download_url_prefix: + gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: false - default: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' type: string jobs: @@ -58,34 +58,41 @@ jobs: steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Set PLATFORM env var for use in artifact download URL + - name: Set env vars for use in artifact download URL run: | os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) - # Adjust name for Windows - if [[ $os =~ "msys_nt" ]]; then - os="windows" + # Adjust os and arch for Windows x86 + if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then + os="win" + arch="amd64" fi + + # Get the major and minor version of Python. + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + python_major_minor="${JAXCI_HERMETIC_PYTHON_VERSION//./}" - echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download the jaxlib wheel from GCS (non-Windows runs) if: ${{ !contains(inputs.runner, 'windows-x86') }} run: >- mkdir -p $(pwd)/dist && - gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/jaxlib*.whl $(pwd)/dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ - name: Download the jaxlib wheel from GCS (Windows runs) if: ${{ contains(inputs.runner, 'windows-x86') }} shell: cmd run: >- mkdir dist && - gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/jaxlib*.whl dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl dist/ - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in - name: Run Pytest CPU tests diff --git a/.github/workflows/pytest_cuda_wt.yml b/.github/workflows/pytest_cuda_wt.yml index 334b0b727c79..559cb418e399 100644 --- a/.github/workflows/pytest_cuda_wt.yml +++ b/.github/workflows/pytest_cuda_wt.yml @@ -36,10 +36,10 @@ on: type: string required: true default: "0" - download_url_prefix: - description: "GCS location prefix from where the artifacts should be downloaded" + gcs_download_uri: + description: "GCS location URI from where the artifacts should be downloaded" required: true - default: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' type: string jobs: @@ -57,19 +57,26 @@ jobs: JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Set PLATFORM env var for use in artifact download URL + - name: Set env vars for use in artifact download URL run: | os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) - echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + + # Get the major and minor version of Python. + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + python_major_minor="${JAXCI_HERMETIC_PYTHON_VERSION//./}" + + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download the wheel artifacts from GCS - run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.download_url_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/*.whl $(pwd)/dist/ + run: mkdir -p $(pwd)/dist && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in - name: Run Pytest GPU tests diff --git a/.github/workflows/wheel_tests.yml b/.github/workflows/wheel_tests.yml index dc17869bfb49..e9b12786aed3 100644 --- a/.github/workflows/wheel_tests.yml +++ b/.github/workflows/wheel_tests.yml @@ -36,7 +36,7 @@ jobs: python: ${{ matrix.python }} clone_main_xla: 1 upload_artifacts: true - upload_destination_prefix: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}' + gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' build_cuda_artifacts: uses: ./.github/workflows/build_artifacts.yml @@ -53,7 +53,7 @@ jobs: python: ${{ matrix.python }} clone_main_xla: 1 upload_artifacts: true - upload_destination_prefix: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}' + gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' run_pytest_cpu: needs: build_jaxlib_artifact @@ -69,7 +69,7 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - download_url_prefix: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}' + gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' run_pytest_gpu: @@ -88,7 +88,7 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} - download_url_prefix: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}' + gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' run_bazel_test_gpu: needs: [build_jaxlib_artifact, build_cuda_artifacts] @@ -104,4 +104,4 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - download_url_prefix: 'jax-fork/${{ github.workflow }}/${{ github.run_number }}' \ No newline at end of file + gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' \ No newline at end of file From 2ac88b7e980c31c2fc211afd5f922223004a4f17 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 23 Dec 2024 17:43:53 +0000 Subject: [PATCH 168/205] debug error --- .github/workflows/bazel_cuda_non_rbe_wt.yml | 10 +++++++++- .github/workflows/pytest_cuda_wt.yml | 3 ++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe_wt.yml b/.github/workflows/bazel_cuda_non_rbe_wt.yml index 55b770307613..de117ec3262e 100644 --- a/.github/workflows/bazel_cuda_non_rbe_wt.yml +++ b/.github/workflows/bazel_cuda_non_rbe_wt.yml @@ -61,7 +61,15 @@ jobs: run: | os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) - echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV + + # Get the major and minor version of Python. + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + echo + python_major_minor="${JAXCI_HERMETIC_PYTHON_VERSION//./}" + + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download the wheel artifacts from GCS run: mkdir -p $(pwd)/dist && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ - name: Run Bazel tests diff --git a/.github/workflows/pytest_cuda_wt.yml b/.github/workflows/pytest_cuda_wt.yml index 559cb418e399..718828838eac 100644 --- a/.github/workflows/pytest_cuda_wt.yml +++ b/.github/workflows/pytest_cuda_wt.yml @@ -67,9 +67,10 @@ jobs: run: | os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) - + # Get the major and minor version of Python. # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + echo $os $arch "${JAXCI_HERMETIC_PYTHON_VERSION}" python_major_minor="${JAXCI_HERMETIC_PYTHON_VERSION//./}" echo "OS=${os}" >> $GITHUB_ENV From e76aab29de13170af65945fc1dadbbcf8d4cbb14 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 23 Dec 2024 18:16:57 +0000 Subject: [PATCH 169/205] copy paste command from other workflow that works --- .github/workflows/pytest_cuda_wt.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pytest_cuda_wt.yml b/.github/workflows/pytest_cuda_wt.yml index 718828838eac..cb42c918be49 100644 --- a/.github/workflows/pytest_cuda_wt.yml +++ b/.github/workflows/pytest_cuda_wt.yml @@ -67,10 +67,15 @@ jobs: run: | os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) + + # Adjust os and arch for Windows x86 + if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then + os="win" + arch="amd64" + fi # Get the major and minor version of Python. # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 - echo $os $arch "${JAXCI_HERMETIC_PYTHON_VERSION}" python_major_minor="${JAXCI_HERMETIC_PYTHON_VERSION//./}" echo "OS=${os}" >> $GITHUB_ENV From 4878b6e6a14f7a7efb80a73b5c09c36fa49928f8 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 23 Dec 2024 18:36:52 +0000 Subject: [PATCH 170/205] add a workflow to debug env var issue --- .github/workflows/env_var_debug.yml | 41 +++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 .github/workflows/env_var_debug.yml diff --git a/.github/workflows/env_var_debug.yml b/.github/workflows/env_var_debug.yml new file mode 100644 index 000000000000..01fb34ba8336 --- /dev/null +++ b/.github/workflows/env_var_debug.yml @@ -0,0 +1,41 @@ +name: CI - Debug + +on: + pull_request: + branches: + - main + +jobs: + run-tests: + + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Python values need to match the matrix stategy in the build artifacts job above + runner: ["linux-x86-n2-16", "linux-x86-g2-48-l4-4gpu",] + python: ["3.10",] + + runs-on: ${{ matrix.runner }} + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + + name: "Debug env var setting" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" + JAXCI_PYTHON: "python${{ matrix.python }}" + + steps: + - uses: actions/checkout@v4 + - name: Set env vars for use in artifact download URL + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Get the major and minor version of Python. + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + echo $os $arch $JAXCI_HERMETIC_PYTHON_VERSION + python_major_minor="${JAXCI_HERMETIC_PYTHON_VERSION//./}" + + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV From 2ed7af38394041cf3c40195b9f6f333169727bd1 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 23 Dec 2024 18:42:28 +0000 Subject: [PATCH 171/205] switch out image --- .github/workflows/env_var_debug.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/env_var_debug.yml b/.github/workflows/env_var_debug.yml index 01fb34ba8336..032318b1d604 100644 --- a/.github/workflows/env_var_debug.yml +++ b/.github/workflows/env_var_debug.yml @@ -12,11 +12,11 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: # Python values need to match the matrix stategy in the build artifacts job above - runner: ["linux-x86-n2-16", "linux-x86-g2-48-l4-4gpu",] + runner: ["linux-x86-g2-48-l4-4gpu",] python: ["3.10",] runs-on: ${{ matrix.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" name: "Debug env var setting" From df59b2de82d3b29c15338177e35b086c5e70f443 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 23 Dec 2024 18:44:32 +0000 Subject: [PATCH 172/205] set env var to 3.11 --- .github/workflows/env_var_debug.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/env_var_debug.yml b/.github/workflows/env_var_debug.yml index 032318b1d604..d55fd1b28b11 100644 --- a/.github/workflows/env_var_debug.yml +++ b/.github/workflows/env_var_debug.yml @@ -13,7 +13,7 @@ jobs: matrix: # Python values need to match the matrix stategy in the build artifacts job above runner: ["linux-x86-g2-48-l4-4gpu",] - python: ["3.10",] + python: ["3.11",] runs-on: ${{ matrix.runner }} container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" From 729cbb585e300f504d34d38c1455718612a72aef Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 23 Dec 2024 18:48:08 +0000 Subject: [PATCH 173/205] use tr to delete periods --- .github/workflows/env_var_debug.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/env_var_debug.yml b/.github/workflows/env_var_debug.yml index d55fd1b28b11..4108cb67e0bd 100644 --- a/.github/workflows/env_var_debug.yml +++ b/.github/workflows/env_var_debug.yml @@ -34,7 +34,12 @@ jobs: # Get the major and minor version of Python. # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 echo $os $arch $JAXCI_HERMETIC_PYTHON_VERSION - python_major_minor="${JAXCI_HERMETIC_PYTHON_VERSION//./}" + + python_major_minor_1=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') + echo $os $arch $JAXCI_HERMETIC_PYTHON_VERSION $python_major_minor_1 + + python_major_minor_2="${JAXCI_HERMETIC_PYTHON_VERSION//./}" + echo $os $arch $JAXCI_HERMETIC_PYTHON_VERSION $python_major_minor_1 $python_major_minor_2 echo "OS=${os}" >> $GITHUB_ENV echo "ARCH=${arch}" >> $GITHUB_ENV From dfb12e30bc41c352a5473c7cdb0e54c001614a93 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 23 Dec 2024 18:56:26 +0000 Subject: [PATCH 174/205] use tr to get python major minor versions --- .github/workflows/bazel_cuda_non_rbe_wt.yml | 3 +-- .github/workflows/pytest_cpu_wt.yml | 4 ++-- .github/workflows/pytest_cuda_wt.yml | 24 ++++++++------------- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe_wt.yml b/.github/workflows/bazel_cuda_non_rbe_wt.yml index de117ec3262e..7d0529702e7a 100644 --- a/.github/workflows/bazel_cuda_non_rbe_wt.yml +++ b/.github/workflows/bazel_cuda_non_rbe_wt.yml @@ -64,8 +64,7 @@ jobs: # Get the major and minor version of Python. # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 - echo - python_major_minor="${JAXCI_HERMETIC_PYTHON_VERSION//./}" + python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') echo "OS=${os}" >> $GITHUB_ENV echo "ARCH=${arch}" >> $GITHUB_ENV diff --git a/.github/workflows/pytest_cpu_wt.yml b/.github/workflows/pytest_cpu_wt.yml index c0c08d33b1d4..92b6d726ee1f 100644 --- a/.github/workflows/pytest_cpu_wt.yml +++ b/.github/workflows/pytest_cpu_wt.yml @@ -32,7 +32,7 @@ on: required: true default: "0" gcs_download_uri: - description: "GCS location prefix from where the artifacts should be downloaded" + description: "GCS location URI from where the artifacts should be downloaded" required: false default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' type: string @@ -77,7 +77,7 @@ jobs: # Get the major and minor version of Python. # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 - python_major_minor="${JAXCI_HERMETIC_PYTHON_VERSION//./}" + python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') echo "OS=${os}" >> $GITHUB_ENV echo "ARCH=${arch}" >> $GITHUB_ENV diff --git a/.github/workflows/pytest_cuda_wt.yml b/.github/workflows/pytest_cuda_wt.yml index cb42c918be49..d779dd86042c 100644 --- a/.github/workflows/pytest_cuda_wt.yml +++ b/.github/workflows/pytest_cuda_wt.yml @@ -65,22 +65,16 @@ jobs: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Set env vars for use in artifact download URL run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Get the major and minor version of Python. + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + python_major_minor="$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')" - # Adjust os and arch for Windows x86 - if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then - os="win" - arch="amd64" - fi - - # Get the major and minor version of Python. - # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 - python_major_minor="${JAXCI_HERMETIC_PYTHON_VERSION//./}" - - echo "OS=${os}" >> $GITHUB_ENV - echo "ARCH=${arch}" >> $GITHUB_ENV - echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download the wheel artifacts from GCS run: mkdir -p $(pwd)/dist && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ - name: Install dependencies From 266746d29545ad0338ec48c3f5e36c9b62a5cccf Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 23 Dec 2024 19:18:11 +0000 Subject: [PATCH 175/205] modify download wheel commands --- .github/workflows/bazel_cuda_non_rbe_wt.yml | 6 +++++- .github/workflows/pytest_cuda_wt.yml | 8 ++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe_wt.yml b/.github/workflows/bazel_cuda_non_rbe_wt.yml index 7d0529702e7a..32c94e05cbe5 100644 --- a/.github/workflows/bazel_cuda_non_rbe_wt.yml +++ b/.github/workflows/bazel_cuda_non_rbe_wt.yml @@ -70,6 +70,10 @@ jobs: echo "ARCH=${arch}" >> $GITHUB_ENV echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download the wheel artifacts from GCS - run: mkdir -p $(pwd)/dist && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ + run: >- + mkdir -p $(pwd)/dist && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*pjrt*${OS}*${ARCH}*.whl $(pwd)/dist/ - name: Run Bazel tests run: ./ci/run_bazel_test_gpu_non_rbe.sh diff --git a/.github/workflows/pytest_cuda_wt.yml b/.github/workflows/pytest_cuda_wt.yml index d779dd86042c..1ccb8d358fd9 100644 --- a/.github/workflows/pytest_cuda_wt.yml +++ b/.github/workflows/pytest_cuda_wt.yml @@ -70,13 +70,17 @@ jobs: # Get the major and minor version of Python. # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 - python_major_minor="$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')" + python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') echo "OS=${os}" >> $GITHUB_ENV echo "ARCH=${arch}" >> $GITHUB_ENV echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download the wheel artifacts from GCS - run: mkdir -p $(pwd)/dist && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ + run: >- + mkdir -p $(pwd)/dist && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*pjrt*${OS}*${ARCH}*.whl $(pwd)/dist/ - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in - name: Run Pytest GPU tests From c89b47b1ab2bf419c29940a5e942f1b8ea64b7e7 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 2 Jan 2025 20:36:51 +0000 Subject: [PATCH 176/205] add an entrypoint workflow for nightly/release testing --- ...l_tests.yml => wheel_tests_continuous.yml} | 2 +- .../workflows/wheel_tests_nightly_release.yml | 51 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) rename .github/workflows/{wheel_tests.yml => wheel_tests_continuous.yml} (99%) create mode 100644 .github/workflows/wheel_tests_nightly_release.yml diff --git a/.github/workflows/wheel_tests.yml b/.github/workflows/wheel_tests_continuous.yml similarity index 99% rename from .github/workflows/wheel_tests.yml rename to .github/workflows/wheel_tests_continuous.yml index e9b12786aed3..4d839233cf32 100644 --- a/.github/workflows/wheel_tests.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -1,4 +1,4 @@ -name: CI - Wheel Tests +name: CI - Wheel Tests (Continuous) on: pull_request: diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml new file mode 100644 index 000000000000..0ae827d281f4 --- /dev/null +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -0,0 +1,51 @@ +name: CI - Wheel Tests (Nightly/Release) + +on: + workflow_call: + inputs: + gcs_download_uri: + description: "GCS location URI from where the artifacts should be downloaded" + required: true + default: 'gs://jax-nightly-release-transient/nightly/latest' + type: string + + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + run_pytest_cpu: + uses: ./.github/workflows/pytest_cpu_wt.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Runner OS and Python values need to match the matrix stategy in the build_jaxlib_artifact job above + runner: ["linux-x86-n2-64", "linux-x86-t2a-48-dev", "windows-x86-n2-64"] + python: ["3.10","3.11", "3.12", "3.13"] + enable-x64: [0] + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + enable-x64: ${{ matrix.enable-x64 }} + gcs_download_uri: ${{inputs.gcs_download_uri}} + + + run_pytest_gpu: + uses: ./.github/workflows/pytest_cuda_wt.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Python values need to match the matrix stategy in the build artifacts job above + # TODO(srnitin): Add the h100 runner when we switch to using the jax-ml/jax + # repo. + runner: ["linux-x86-g2-48-l4-4gpu",] + python: ["3.10","3.11", "3.12", "3.13"] + cuda: ["12.3", "12.1"] + enable-x64: [0] + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + cuda: ${{ matrix.cuda }} + enable-x64: ${{ matrix.enable-x64 }} + gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file From ada4e63ddc73536f3dbde0c8c5edc77ff1d1262c Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 2 Jan 2025 21:32:06 +0000 Subject: [PATCH 177/205] change the name of pytest and bazel non-rbe workflows --- .../{bazel_cuda_non_rbe_wt.yml => bazel_cuda_non_rbe.yml} | 0 .github/workflows/{pytest_cpu_wt.yml => pytest_cpu.yml} | 0 .github/workflows/{pytest_cuda_wt.yml => pytest_cuda.yml} | 0 .github/workflows/wheel_tests_continuous.yml | 6 +++--- .github/workflows/wheel_tests_nightly_release.yml | 4 ++-- 5 files changed, 5 insertions(+), 5 deletions(-) rename .github/workflows/{bazel_cuda_non_rbe_wt.yml => bazel_cuda_non_rbe.yml} (100%) rename .github/workflows/{pytest_cpu_wt.yml => pytest_cpu.yml} (100%) rename .github/workflows/{pytest_cuda_wt.yml => pytest_cuda.yml} (100%) diff --git a/.github/workflows/bazel_cuda_non_rbe_wt.yml b/.github/workflows/bazel_cuda_non_rbe.yml similarity index 100% rename from .github/workflows/bazel_cuda_non_rbe_wt.yml rename to .github/workflows/bazel_cuda_non_rbe.yml diff --git a/.github/workflows/pytest_cpu_wt.yml b/.github/workflows/pytest_cpu.yml similarity index 100% rename from .github/workflows/pytest_cpu_wt.yml rename to .github/workflows/pytest_cpu.yml diff --git a/.github/workflows/pytest_cuda_wt.yml b/.github/workflows/pytest_cuda.yml similarity index 100% rename from .github/workflows/pytest_cuda_wt.yml rename to .github/workflows/pytest_cuda.yml diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 4d839233cf32..7625c79caa5d 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -57,7 +57,7 @@ jobs: run_pytest_cpu: needs: build_jaxlib_artifact - uses: ./.github/workflows/pytest_cpu_wt.yml + uses: ./.github/workflows/pytest_cpu.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: @@ -74,7 +74,7 @@ jobs: run_pytest_gpu: needs: [build_jaxlib_artifact, build_cuda_artifacts] - uses: ./.github/workflows/pytest_cuda_wt.yml + uses: ./.github/workflows/pytest_cuda.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: @@ -92,7 +92,7 @@ jobs: run_bazel_test_gpu: needs: [build_jaxlib_artifact, build_cuda_artifacts] - uses: ./.github/workflows/bazel_cuda_non_rbe_wt.yml + uses: ./.github/workflows/bazel_cuda_non_rbe.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 0ae827d281f4..c0ea51ca233b 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -16,7 +16,7 @@ concurrency: jobs: run_pytest_cpu: - uses: ./.github/workflows/pytest_cpu_wt.yml + uses: ./.github/workflows/pytest_cpu.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: @@ -32,7 +32,7 @@ jobs: run_pytest_gpu: - uses: ./.github/workflows/pytest_cuda_wt.yml + uses: ./.github/workflows/pytest_cuda.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: From 2f52a39a7f9f724f7f9c3e24362cba029dade941 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 3 Jan 2025 04:13:25 +0000 Subject: [PATCH 178/205] Add env var to decide if to install jax at current commit --- .github/workflows/pytest_cpu.yml | 21 ++++++++++++++++++++- .github/workflows/pytest_cuda.yml | 14 ++++++++++++++ ci/envs/default.env | 3 +++ ci/utilities/install_wheels_locally.sh | 8 +++++--- 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 92b6d726ee1f..75e5b9270bb6 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -31,6 +31,14 @@ on: type: string required: true default: "0" + # If enabled, install "jax" at the current commit in editable mode (used in + # continuous test jobs). If disabled, installs "jax" from a GCS bucket (used in + # nightly/release test jobs) + install_jax_current_commit: + description: "Should the 'jax' package be installed at the current commit?" + type: string + required: true + default: "1" gcs_download_uri: description: "GCS location URI from where the artifacts should be downloaded" required: false @@ -55,6 +63,7 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} + JAXCI_INSTALL_JAX_CURRENT_COMMIT: ${{ inputs.install_jax_current_commit }} steps: @@ -82,17 +91,27 @@ jobs: echo "OS=${os}" >> $GITHUB_ENV echo "ARCH=${arch}" >> $GITHUB_ENV echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - - name: Download the jaxlib wheel from GCS (non-Windows runs) + - name: Download wheel artifacts from GCS (non-Windows runs) if: ${{ !contains(inputs.runner, 'windows-x86') }} run: >- mkdir -p $(pwd)/dist && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ + + # Download the "jax" wheel from GCS if inputs.install_latest_jax is not set to 1 + if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/ + fi - name: Download the jaxlib wheel from GCS (Windows runs) if: ${{ contains(inputs.runner, 'windows-x86') }} shell: cmd run: >- mkdir dist && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl dist/ + + # Download the "jax" wheel from GCS if inputs.install_latest_jax is not set to 1 + if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/ + fi - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in - name: Run Pytest CPU tests diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 1ccb8d358fd9..327d0aa9d001 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -36,6 +36,14 @@ on: type: string required: true default: "0" + # If enabled, install "jax" at the current commit in editable mode (used in + # continuous test jobs). If disabled, installs "jax" from a GCS bucket (used in + # nightly/release test jobs) + install_jax_current_commit: + description: "Should the 'jax' package be installed at the current commit?" + type: string + required: true + default: "1" gcs_download_uri: description: "GCS location URI from where the artifacts should be downloaded" required: true @@ -55,6 +63,7 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} + JAXCI_INSTALL_JAX_CURRENT_COMMIT: ${{ inputs.install_jax_current_commit }} steps: - uses: actions/checkout@v4 @@ -81,6 +90,11 @@ jobs: gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*pjrt*${OS}*${ARCH}*.whl $(pwd)/dist/ + + # Download the "jax" wheel from GCS if + if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/ + fi - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in - name: Run Pytest GPU tests diff --git a/ci/envs/default.env b/ci/envs/default.env index c27065934cc5..d647e7439749 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -59,6 +59,9 @@ export JAXCI_MAX_TESTS_PER_GPU=${JAXCI_MAX_TESTS_PER_GPU:-12} export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0} # Pytest specific environment variables below. Used in run_pytest_*.sh scripts. +# Installing the JAX package in editable mode at the current commit +export JAXCI_INSTALL_JAX_CURRENT_COMMIT=${JAXCI_INSTALL_JAX_CURRENT_COMMIT:-} + # Sets the number of TPU cores for the TPU machine type. These values are # defined in the TPU GitHub Actions workflow. export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index c43fa815af9b..78c85400f194 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -33,6 +33,8 @@ else "$JAXCI_PYTHON" -m pip install "${WHEELS[@]}" fi -echo "Installing the JAX package in editable mode at the current commit..." -# Install JAX package at the current commit. -"$JAXCI_PYTHON" -m pip install -U -e . +if [[ $JAXCI_INSTALL_JAX_CURRENT_COMMIT == "1" ]]; then + echo "Installing the JAX package in editable mode at the current commit..." + # Install JAX package at the current commit. + "$JAXCI_PYTHON" -m pip install -U -e . +fi \ No newline at end of file From c2a4cd02e6f1fab65a0fb8165bcdacef5953e8f0 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 3 Jan 2025 04:20:24 +0000 Subject: [PATCH 179/205] Pass the install_jax_current_commit env var as an input to workflow calls --- .github/workflows/wheel_tests_continuous.yml | 2 ++ .github/workflows/wheel_tests_nightly_release.yml | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 7625c79caa5d..c80e7a19ff5e 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -70,6 +70,7 @@ jobs: python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' + install_jax_current_commit: "1" run_pytest_gpu: @@ -89,6 +90,7 @@ jobs: cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' + install_jax_current_commit: "1" run_bazel_test_gpu: needs: [build_jaxlib_artifact, build_cuda_artifacts] diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index c0ea51ca233b..fc8b4b52ff7d 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -29,6 +29,7 @@ jobs: python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} gcs_download_uri: ${{inputs.gcs_download_uri}} + install_jax_current_commit: "0" run_pytest_gpu: @@ -48,4 +49,5 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} - gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file + gcs_download_uri: ${{inputs.gcs_download_uri}} + install_jax_current_commit: "0" \ No newline at end of file From 806a161cde98280b07393297893a59dcf9ed01c3 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 3 Jan 2025 04:35:14 +0000 Subject: [PATCH 180/205] Fix syntax error --- .github/workflows/pytest_cpu.yml | 4 ++-- .github/workflows/pytest_cuda.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 813341356639..f0ca79c4c061 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -90,13 +90,13 @@ jobs: echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download wheel artifacts from GCS (non-Windows runs) if: ${{ !contains(inputs.runner, 'windows-x86') }} - run: >- + run: | mkdir -p $(pwd)/dist && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ # Download the "jax" wheel from GCS if inputs.install_latest_jax is not set to 1 if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any*.whl $(pwd)/dist/ fi - name: Download the jaxlib wheel from GCS (Windows runs) if: ${{ contains(inputs.runner, 'windows-x86') }} diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 977a8369a88e..1039d527847d 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -82,7 +82,7 @@ jobs: echo "ARCH=${arch}" >> $GITHUB_ENV echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download the wheel artifacts from GCS - run: >- + run: | mkdir -p $(pwd)/dist && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && @@ -90,7 +90,7 @@ jobs: # Download the "jax" wheel from GCS if if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any*.whl $(pwd)/dist/ fi - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in From 50ef46ead21c39cdfd0c1e2da8630b616d2d7ffb Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 10 Jan 2025 17:34:17 +0000 Subject: [PATCH 181/205] Remove windows specific download steps --- .github/workflows/bazel_cuda_non_rbe.yml | 8 +++--- .github/workflows/build_artifacts.yml | 19 +++++++------- .github/workflows/pytest_cpu.yml | 26 ++++++-------------- .github/workflows/pytest_cuda.yml | 10 ++++---- .github/workflows/wheel_tests_continuous.yml | 15 +++++------ 5 files changed, 33 insertions(+), 45 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 6be0b002b273..7dc5d668af55 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -31,7 +31,7 @@ on: gcs_download_uri: description: "GCS location URI from where the artifacts should be downloaded" required: true - default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' + default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string @@ -69,8 +69,8 @@ jobs: - name: Download the wheel artifacts from GCS run: >- mkdir -p $(pwd)/dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*pjrt*${OS}*${ARCH}*.whl $(pwd)/dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.wh"l" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ - name: Run Bazel tests run: ./ci/run_bazel_test_gpu_non_rbe.sh diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index d929ee870f9a..05c1a9b2c1ed 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -53,7 +53,7 @@ on: type: string required: false default: "0" - upload_artifacts: + upload_artifacts_to_gcs: description: "Should the artifacts be uploaded to a GCS bucket?" required: false default: false @@ -61,8 +61,12 @@ on: gcs_upload_uri: description: "GCS location URI to where the artifacts should be uploaded" required: false - default: ' gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' + default: ' gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string + outputs: + gcs_upload_uri: + description: "GCS location prefix to where the artifacts were uploaded" + value: ${{ inputs.gcs_upload_uri }} jobs: build_artifacts: @@ -95,12 +99,7 @@ jobs: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Build ${{ inputs.artifact }} run: ./ci/build_artifacts.sh "${{ inputs.artifact }}" - - name: Upload artifacts to GCS bucket (non-Windows) - if: >- - ${{ inputs.upload_artifacts && !contains(inputs.runner, 'windows-x86') }} - run: gsutil -m cp -r $(pwd)/dist/*.whl "${{ inputs.gcs_upload_uri }}"/ - - name: Upload artifacts to GCS bucket (Windows) + - name: Upload artifacts to GCS bucket if: >- - ${{ inputs.upload_artifacts && contains(inputs.runner, 'windows-x86') }} - shell: cmd - run: gsutil -m cp -r dist/*.whl "${{ inputs.gcs_upload_uri }}"/ \ No newline at end of file + ${{ inputs.upload_artifacts_to_gcs }} + run: gsutil -m cp -r "dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ \ No newline at end of file diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index f0ca79c4c061..2fb3019b950f 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -39,7 +39,7 @@ on: gcs_download_uri: description: "GCS location URI from where the artifacts should be downloaded" required: false - default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' + default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string jobs: @@ -50,8 +50,8 @@ jobs: shell: bash runs-on: ${{ inputs.runner }} - container: ${{ (contains(inputs.runner, 'linux-x86-n2') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(inputs.runner, 'linux-x86-t2a') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || (contains(inputs.runner, 'windows-x86') && null) }} name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" @@ -88,26 +88,14 @@ jobs: echo "OS=${os}" >> $GITHUB_ENV echo "ARCH=${arch}" >> $GITHUB_ENV echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - - name: Download wheel artifacts from GCS (non-Windows runs) - if: ${{ !contains(inputs.runner, 'windows-x86') }} + - name: Download wheel artifacts from GCS run: | - mkdir -p $(pwd)/dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ + mkdir -p dist && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" dist/ # Download the "jax" wheel from GCS if inputs.install_latest_jax is not set to 1 if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any*.whl $(pwd)/dist/ - fi - - name: Download the jaxlib wheel from GCS (Windows runs) - if: ${{ contains(inputs.runner, 'windows-x86') }} - shell: cmd - run: >- - mkdir dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl dist/ - - # Download the "jax" wheel from GCS if inputs.install_latest_jax is not set to 1 - if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*py3*none*any*.whl" dist/ fi - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 1039d527847d..ed86b20d134e 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -44,7 +44,7 @@ on: gcs_download_uri: description: "GCS location URI from where the artifacts should be downloaded" required: true - default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' + default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string jobs: @@ -84,13 +84,13 @@ jobs: - name: Download the wheel artifacts from GCS run: | mkdir -p $(pwd)/dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*pjrt*${OS}*${ARCH}*.whl $(pwd)/dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ # Download the "jax" wheel from GCS if if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any*.whl $(pwd)/dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*py3*none*any*.whl" $(pwd)/dist/ fi - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index b020c5880c92..b89fb7b5e5a0 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -35,8 +35,8 @@ jobs: artifact: ${{ matrix.artifact }} python: ${{ matrix.python }} clone_main_xla: 1 - upload_artifacts: true - gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' + upload_artifacts_to_gcs: true + gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' build_cuda_artifacts: uses: ./.github/workflows/build_artifacts.yml @@ -52,8 +52,8 @@ jobs: artifact: ${{ matrix.artifact }} python: ${{ matrix.python }} clone_main_xla: 1 - upload_artifacts: true - gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' + upload_artifacts_to_gcs: true + gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' run_pytest_cpu: needs: build_jaxlib_artifact @@ -69,7 +69,7 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' + gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }} install_jax_current_commit: "1" @@ -89,7 +89,8 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} - gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' + # GCS upload URI is the same for both artifact build jobs + gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }} install_jax_current_commit: "1" run_bazel_test_gpu: @@ -106,4 +107,4 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' \ No newline at end of file + gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }} \ No newline at end of file From 34774fca6268488b0603ae43fdc8602fb412a7ca Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 13 Jan 2025 22:44:01 +0000 Subject: [PATCH 182/205] Enable remote cache support for platforms where rbe is not supported --- .bazelrc | 3 +++ .github/workflows/build_artifacts.yml | 3 +++ ci/build_artifacts.sh | 17 +++++++++++++++-- ci/envs/default.env | 6 ++++++ 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/.bazelrc b/.bazelrc index 864daf76feed..12b8a1e33c52 100644 --- a/.bazelrc +++ b/.bazelrc @@ -181,6 +181,9 @@ build:macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-maco # Cache pushes are limited to JAX's CI system. build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials +build:public_cache --remote_cache="https://storage.googleapis.com/jax-bazel-cache/" --remote_upload_local_results=false +build:public_cache_push --config=public_cache --remote_upload_local_results=true --google_default_credentials + # ############################################################################# # CI Build config options below. # JAX uses these configs in CI builds for building artifacts and when running diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 05c1a9b2c1ed..ee99c2e84a19 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -92,6 +92,9 @@ jobs: - name: Enable RBE if building on Linux x86 or Windows x86 if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV + - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 + if: contains(inputs.runner, 'linux-arm64') + run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 698de38418b7..e739522d5468 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -56,11 +56,21 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then # flags in the .bazelrc depending upon the platform we are building for. bazelrc_config="${os}_${arch}" - # TODO(b/379903748): Add remote cache options for Linux and Windows. + # Set remote_cache_flag to be empty by default to avoid unbound variable errors + remote_cache_flag="" + if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then bazelrc_config="rbe_${bazelrc_config}" else bazelrc_config="ci_${bazelrc_config}" + + # Bazel remote cache can be used on platforms with no RBE support. Pushes to + # the cache bucket is limited to JAX's CI system. + if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1; ]]; then + remote_cache_flag="--bazel_options=--config=public_cache_push" + else + remote_cache_flag="--bazel_options=--config=public_cache" + fi fi # Use the "_cuda" configs when building the CUDA artifacts. @@ -69,7 +79,10 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then fi # Build the artifact. - python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose --detailed_timestamped_log + python build/build.py build --wheels="$artifact" \ + --bazel_options=--config="$bazelrc_config" $remote_cache_flag \ + --bazel_options=--config --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ + --verbose --detailed_timestamped_log # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. diff --git a/ci/envs/default.env b/ci/envs/default.env index d647e7439749..a01461ad9962 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -44,6 +44,12 @@ export JAXCI_OUTPUT_DIR="$(pwd)/dist" # for CI builds where RBE is supported. export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} +# On platforms where RBE is not supported, we use Bazel remote cache to speed up +# builds. When this flag is enabled, the build will also try to push new cache entries +# to the bucket. Since writes to the bucket require authentication, this flag is enabled +# only for CI builds. +export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0} + # ############################################################################# # Test script specific environment variables. # ############################################################################# From 937f95179f149b0aa8d3fd3e269440bb6e8cb1d7 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 13 Jan 2025 22:48:11 +0000 Subject: [PATCH 183/205] fix syntax error --- ci/build_artifacts.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index e739522d5468..8677c7b3d20b 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -66,7 +66,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then # Bazel remote cache can be used on platforms with no RBE support. Pushes to # the cache bucket is limited to JAX's CI system. - if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1; ]]; then + if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then remote_cache_flag="--bazel_options=--config=public_cache_push" else remote_cache_flag="--bazel_options=--config=public_cache" From a5c8f2c19b646303a22443f1c2bb0c694a3656ec Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 13 Jan 2025 22:49:27 +0000 Subject: [PATCH 184/205] fix syntax error --- ci/build_artifacts.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 8677c7b3d20b..a48edd1c815f 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -81,7 +81,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then # Build the artifact. python build/build.py build --wheels="$artifact" \ --bazel_options=--config="$bazelrc_config" $remote_cache_flag \ - --bazel_options=--config --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ + --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we From 4c37d90b4b063191fe5701a9feeb25d67b649392 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 13 Jan 2025 23:25:18 +0000 Subject: [PATCH 185/205] dummy change --- .github/workflows/build_artifacts.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index ee99c2e84a19..7485817a3ae9 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -95,7 +95,7 @@ jobs: - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 if: contains(inputs.runner, 'linux-arm64') run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV - # Halt for testing + # Halt for testing, - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main with: From a07fed476bdb954fe2b38db8dfb717c5a54f186d Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 14 Jan 2025 00:25:59 +0000 Subject: [PATCH 186/205] update checkout version --- .github/workflows/build_artifacts.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 7485817a3ae9..f96c9b07eb5c 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -88,7 +88,7 @@ jobs: name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, Clone main XLA=${{ inputs.clone_main_xla }}) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Enable RBE if building on Linux x86 or Windows x86 if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV From 3fddb21364cc5b2f94d755e62c67eed513c9528e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 14 Jan 2025 06:14:41 +0000 Subject: [PATCH 187/205] remove spaces in upload/download uri --- .github/workflows/bazel_cuda_non_rbe.yml | 12 ++++++++---- .github/workflows/build_artifacts.yml | 13 +++++++++++-- .github/workflows/pytest_cpu.yml | 13 +++++++++++-- .github/workflows/pytest_cuda.yml | 12 ++++++++---- 4 files changed, 38 insertions(+), 12 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 7dc5d668af55..d75c794c4bd3 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -67,10 +67,14 @@ jobs: echo "ARCH=${arch}" >> $GITHUB_ENV echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download the wheel artifacts from GCS - run: >- + run: | + download_uri="${{ inputs.gcs_download_uri }}" + # Replace spaces with underscore + download_uri=${download_uri// /_} + mkdir -p $(pwd)/dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.wh"l" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + gsutil -m cp -r "${download_uri}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${download_uri}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.wh"l" $(pwd)/dist/ && + gsutil -m cp -r "${download_uri}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ - name: Run Bazel tests run: ./ci/run_bazel_test_gpu_non_rbe.sh diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index f96c9b07eb5c..581ad48ecc84 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -61,7 +61,7 @@ on: gcs_upload_uri: description: "GCS location URI to where the artifacts should be uploaded" required: false - default: ' gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string outputs: gcs_upload_uri: @@ -105,4 +105,13 @@ jobs: - name: Upload artifacts to GCS bucket if: >- ${{ inputs.upload_artifacts_to_gcs }} - run: gsutil -m cp -r "dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ \ No newline at end of file + run: | + # Remove after Dockerfile is updated + if [[ $(uname -s) =~ "MSYS_NT" ]]; then + alias gsutil="/c/Program\ Files/google-cloud-sdk/bin/gsutil.cmd" + fi + upload_uri="${{ inputs.gcs_upload_uri }}/" + # Replace spaces with underscore + upload_uri=${upload_uri// /_} + + gsutil -m cp -r "dist/*.whl" "${upload_uri}" \ No newline at end of file diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 2fb3019b950f..e334e17a035f 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -90,12 +90,21 @@ jobs: echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download wheel artifacts from GCS run: | + # Remove after Dockerfile is updated + if [[ $(uname -s) =~ "MSYS_NT" ]]; then + alias gsutil="/c/Program\ Files/google-cloud-sdk/bin/gsutil.cmd" + fi + + download_uri="${{ inputs.gcs_download_uri }}" + # Replace spaces with underscore + download_uri=${download_uri// /_} + mkdir -p dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" dist/ + gsutil -m cp -r "${download_uri}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" dist/ # Download the "jax" wheel from GCS if inputs.install_latest_jax is not set to 1 if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*py3*none*any*.whl" dist/ + gsutil -m cp -r "${download_uri}/jax*py3*none*any*.whl" dist/ fi - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index ed86b20d134e..dbf5161cf302 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -83,14 +83,18 @@ jobs: echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download the wheel artifacts from GCS run: | + download_uri="${{ inputs.gcs_download_uri }}" + # Replace spaces with underscore + download_uri=${download_uri// /_} + mkdir -p $(pwd)/dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + gsutil -m cp -r "${download_uri}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${download_uri}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${download_uri}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ # Download the "jax" wheel from GCS if if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*py3*none*any*.whl" $(pwd)/dist/ + gsutil -m cp -r "${download_uri}/jax*py3*none*any*.whl" $(pwd)/dist/ fi - name: Install dependencies run: $JAXCI_PYTHON -m pip install -r build/requirements.in From 2b362ccaa9a2547ee8ba49b3ed2ceb3a0145b095 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 14 Jan 2025 06:42:48 +0000 Subject: [PATCH 188/205] update alias --- .github/workflows/build_artifacts.yml | 2 +- .github/workflows/pytest_cpu.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 581ad48ecc84..cdeb66354b37 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -108,7 +108,7 @@ jobs: run: | # Remove after Dockerfile is updated if [[ $(uname -s) =~ "MSYS_NT" ]]; then - alias gsutil="/c/Program\ Files/google-cloud-sdk/bin/gsutil.cmd" + alias gsutil='"C:\Program Files\google-cloud-sdk\bin\gsutil.cmd"' fi upload_uri="${{ inputs.gcs_upload_uri }}/" # Replace spaces with underscore diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index e334e17a035f..87eb949cab6d 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -92,7 +92,7 @@ jobs: run: | # Remove after Dockerfile is updated if [[ $(uname -s) =~ "MSYS_NT" ]]; then - alias gsutil="/c/Program\ Files/google-cloud-sdk/bin/gsutil.cmd" + alias gsutil='"C:\Program Files\google-cloud-sdk\bin\gsutil.cmd"' fi download_uri="${{ inputs.gcs_download_uri }}" From 82e8d9d507724fe9d27f3b72c5bfd75412094335 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 14 Jan 2025 06:56:01 +0000 Subject: [PATCH 189/205] update shell --- .github/workflows/build_artifacts.yml | 2 +- .github/workflows/pytest_cpu.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index cdeb66354b37..e5f557f7b875 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -73,7 +73,7 @@ jobs: defaults: run: # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. - shell: bash + shell: bash -e -o pipefail {0} runs-on: ${{ inputs.runner }} diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 87eb949cab6d..9749330a5ec9 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -47,7 +47,7 @@ jobs: defaults: run: # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. - shell: bash + shell: bash -e -o pipefail {0} runs-on: ${{ inputs.runner }} container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || From 1eb8871db6e6c4141433abe43f3519e4ab84c914 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 14 Jan 2025 07:07:07 +0000 Subject: [PATCH 190/205] update alias --- .github/workflows/build_artifacts.yml | 4 +++- .github/workflows/pytest_cpu.yml | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index e5f557f7b875..616248094c5f 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -108,7 +108,9 @@ jobs: run: | # Remove after Dockerfile is updated if [[ $(uname -s) =~ "MSYS_NT" ]]; then - alias gsutil='"C:\Program Files\google-cloud-sdk\bin\gsutil.cmd"' + alias gcloud=gcloud.cmd + alias gsutil=gsutil.cmd + alias bq=bq.cmd fi upload_uri="${{ inputs.gcs_upload_uri }}/" # Replace spaces with underscore diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 9749330a5ec9..e3fbcbbe8110 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -92,7 +92,9 @@ jobs: run: | # Remove after Dockerfile is updated if [[ $(uname -s) =~ "MSYS_NT" ]]; then - alias gsutil='"C:\Program Files\google-cloud-sdk\bin\gsutil.cmd"' + alias gcloud=gcloud.cmd + alias gsutil=gsutil.cmd + alias bq=bq.cmd fi download_uri="${{ inputs.gcs_download_uri }}" From 157a2937b737dfcc89d0ca44728bdaa0fd67aeb3 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 14 Jan 2025 16:21:07 +0000 Subject: [PATCH 191/205] remove gsutil aliases now that dockerfile is updated --- .github/workflows/build_artifacts.yml | 6 ------ .github/workflows/pytest_cpu.yml | 7 ------- 2 files changed, 13 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 616248094c5f..9a95aea95e0f 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -106,12 +106,6 @@ jobs: if: >- ${{ inputs.upload_artifacts_to_gcs }} run: | - # Remove after Dockerfile is updated - if [[ $(uname -s) =~ "MSYS_NT" ]]; then - alias gcloud=gcloud.cmd - alias gsutil=gsutil.cmd - alias bq=bq.cmd - fi upload_uri="${{ inputs.gcs_upload_uri }}/" # Replace spaces with underscore upload_uri=${upload_uri// /_} diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index e3fbcbbe8110..ca61c40ef6ee 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -90,13 +90,6 @@ jobs: echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download wheel artifacts from GCS run: | - # Remove after Dockerfile is updated - if [[ $(uname -s) =~ "MSYS_NT" ]]; then - alias gcloud=gcloud.cmd - alias gsutil=gsutil.cmd - alias bq=bq.cmd - fi - download_uri="${{ inputs.gcs_download_uri }}" # Replace spaces with underscore download_uri=${download_uri// /_} From c943c254fe54b6f86297276754236175f89555eb Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 31 Jan 2025 16:21:55 +0000 Subject: [PATCH 192/205] update to new ml build gpu containers --- .github/workflows/bazel_cuda_non_rbe.yml | 2 +- .github/workflows/pytest_cuda.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index d75c794c4bd3..e31d02c4d196 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -39,7 +39,7 @@ jobs: run-tests: runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-rbe:latest" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cudnn:720686788" env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index dbf5161cf302..474b40affbea 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -51,7 +51,7 @@ jobs: run-tests: runs-on: ${{ inputs.runner }} - container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') || + container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cudnn:720686788') || (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') }} name: "Pytest GPU (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" From 9fa26dd9c05078d5720e8a1df510fb2d4b8a98aa Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 31 Jan 2025 16:34:51 +0000 Subject: [PATCH 193/205] make continuous wokrkflow run on presubmits --- .github/workflows/wheel_tests_continuous.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index e9074efe1162..017f306db1a5 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -20,6 +20,9 @@ name: CI - Wheel Tests (Continuous) on: schedule: - cron: "0 */2 * * *" # Run once every 2 hours + pull_request: + branches: + - main concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -112,8 +115,4 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} -<<<<<<< HEAD gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }} -======= - gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' ->>>>>>> 14246134e7f2ec1ae4eaba8a70315a0fc2e73cd4 From 2014eb03c9a52da7dafe09a3b0baf9a081a4e16e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 31 Jan 2025 16:42:03 +0000 Subject: [PATCH 194/205] fix syntax issue --- .github/workflows/wheel_tests_continuous.yml | 57 +++++++++++--------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 017f306db1a5..372b7510e3e7 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -22,14 +22,14 @@ on: - cron: "0 */2 * * *" # Run once every 2 hours pull_request: branches: - - main + - main concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} jobs: - build_jaxlib_artifact: + build-jaxlib-artifact: uses: ./.github/workflows/build_artifacts.yml strategy: fail-fast: false # don't cancel all jobs on failure @@ -46,12 +46,12 @@ jobs: upload_artifacts_to_gcs: true gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - build_cuda_artifacts: - uses: ./.github/workflows/build_artifacts.yml + build-cuda-artifacts: + uses: ./.github/workflows/build_artifacts.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: - # Python values need to match the matrix stategy in the GPU tests job below + # Python values need to match the matrix stategy in the CUDA tests job below runner: ["linux-x86-n2-16"] artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"] python: ["3.10",] @@ -62,14 +62,15 @@ jobs: clone_main_xla: 1 upload_artifacts_to_gcs: true gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - - run_pytest_cpu: - needs: build_jaxlib_artifact - uses: ./.github/workflows/pytest_cpu.yml + + run-pytest-cpu: + needs: build-jaxlib-artifact + uses: ./.github/workflows/pytest_cpu.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: - # Runner OS and Python values need to match the matrix stategy in the build_jaxlib_artifact job above + # Runner OS and Python values need to match the matrix stategy in the + # build_jaxlib_artifact job above runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] python: ["3.10",] enable-x64: [1, 0] @@ -77,33 +78,38 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }} - install_jax_current_commit: "1" - + gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} - run_pytest_gpu: - needs: [build_jaxlib_artifact, build_cuda_artifacts] - uses: ./.github/workflows/pytest_cuda.yml + run-pytest-cuda: + needs: [build-jaxlib-artifact, build-cuda-artifacts] + uses: ./.github/workflows/pytest_cuda.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: - # Python values need to match the matrix stategy in the build artifacts job above - runner: ["linux-x86-g2-48-l4-4gpu",] + # Python values need to match the matrix stategy in the artifact build jobs above + runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"] python: ["3.10",] cuda: ["12.3", "12.1"] enable-x64: [1, 0] + exclude: + # Run only a single configuration on H100 to save resources + - runner: "linux-x86-a3-8g-h100-8gpu" + python: "3.10" + cuda: "12.1" + - runner: "linux-x86-a3-8g-h100-8gpu" + python: "3.10" + enable-x64: 0 with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} - # GCS upload URI is the same for both artifact build jobs - gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }} - install_jax_current_commit: "1" + # GCS upload URI is the same for both artifact build jobs + gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} - run_bazel_test_gpu: - needs: [build_jaxlib_artifact, build_cuda_artifacts] - uses: ./.github/workflows/bazel_cuda_non_rbe.yml + run-bazel-test-cuda: + needs: [build-jaxlib-artifact, build-cuda-artifacts] + uses: ./.github/workflows/bazel_cuda_non_rbe.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: @@ -115,4 +121,5 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }} + # GCS upload URI is the same for both artifact build jobs + gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} \ No newline at end of file From 3fd003d23903c1a4aa406050ac89f491c04ce82a Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 31 Jan 2025 16:43:34 +0000 Subject: [PATCH 195/205] fix workflow files --- .github/workflows/bazel_cuda_non_rbe.yml | 27 ++++++------------------ .github/workflows/pytest_cuda.yml | 2 +- 2 files changed, 7 insertions(+), 22 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index af51b106f15e..77f21b8a5451 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -34,17 +34,16 @@ on: required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string - + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: boolean + required: false + default: false jobs: run-tests: runs-on: ${{ inputs.runner }} -<<<<<<< HEAD - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cudnn:720686788" -======= - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" ->>>>>>> 14246134e7f2ec1ae4eaba8a70315a0fc2e73cd4 env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} @@ -69,19 +68,6 @@ jobs: echo "ARCH=${arch}" >> $GITHUB_ENV echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - name: Download the wheel artifacts from GCS -<<<<<<< HEAD - run: | - download_uri="${{ inputs.gcs_download_uri }}" - # Replace spaces with underscore - download_uri=${download_uri// /_} - - mkdir -p $(pwd)/dist && - gsutil -m cp -r "${download_uri}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${download_uri}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.wh"l" $(pwd)/dist/ && - gsutil -m cp -r "${download_uri}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ - - name: Run Bazel tests - run: ./ci/run_bazel_test_gpu_non_rbe.sh -======= run: >- mkdir -p $(pwd)/dist && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && @@ -94,5 +80,4 @@ jobs: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Bazel CUDA tests (Non-RBE) timeout-minutes: 60 - run: ./ci/run_bazel_test_cuda_non_rbe.sh ->>>>>>> 14246134e7f2ec1ae4eaba8a70315a0fc2e73cd4 + run: ./ci/run_bazel_test_cuda_non_rbe.sh \ No newline at end of file diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 0595df15855c..0a0c9c832c3f 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -44,7 +44,7 @@ jobs: run-tests: runs-on: ${{ inputs.runner }} - container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') || + container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cudnn:720686788') || (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') }} name: "Pytest GPU (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" From 60b0a1143cc1ebc7929e5ab0180e969d69ea6be6 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 31 Jan 2025 16:47:04 +0000 Subject: [PATCH 196/205] fix workflow file --- .github/workflows/pytest_cpu.yml | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 70b82cfa2a99..e43d236936c9 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -29,19 +29,16 @@ on: type: string required: true default: "0" - # If enabled, install "jax" at the current commit in editable mode (used in - # continuous test jobs). If disabled, installs "jax" from a GCS bucket (used in - # nightly/release test jobs) - install_jax_current_commit: - description: "Should the 'jax' package be installed at the current commit?" - type: string - required: true - default: "1" gcs_download_uri: - description: "GCS location URI from where the artifacts should be downloaded" - required: false + description: "GCS location prefix from where the artifacts should be downloaded" + required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: boolean + required: false + default: false jobs: run-tests: @@ -101,4 +98,4 @@ jobs: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest CPU tests timeout-minutes: 60 - run: ./ci/run_pytest_cpu.sh + run: ./ci/run_pytest_cpu.sh \ No newline at end of file From be3fca80e3f3d29cbbab19ce2b2b2b786683a5f1 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 31 Jan 2025 16:54:45 +0000 Subject: [PATCH 197/205] fix merge --- ci/envs/default.env | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/ci/envs/default.env b/ci/envs/default.env index 4f56d3104cfb..e2bcfc26b31c 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -45,15 +45,9 @@ export JAXCI_OUTPUT_DIR="$(pwd)/dist" export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} # On platforms where RBE is not supported, we use Bazel remote cache to speed up -<<<<<<< HEAD -# builds. When this flag is enabled, the build will also try to push new cache entries -# to the bucket. Since writes to the bucket require authentication, this flag is enabled -# only for CI builds. -======= # builds. When this flag is enabled, Bazel will also try to push new cache # entries to the bucket. Since writes to the bucket require authentication, this # flag is enabled only for CI builds. ->>>>>>> 14246134e7f2ec1ae4eaba8a70315a0fc2e73cd4 export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0} # ############################################################################# @@ -71,9 +65,6 @@ export JAXCI_MAX_TESTS_PER_GPU=${JAXCI_MAX_TESTS_PER_GPU:-12} export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0} # Pytest specific environment variables below. Used in run_pytest_*.sh scripts. -# Installing the JAX package in editable mode at the current commit -export JAXCI_INSTALL_JAX_CURRENT_COMMIT=${JAXCI_INSTALL_JAX_CURRENT_COMMIT:-} - # Sets the number of TPU cores for the TPU machine type. These values are # defined in the TPU GitHub Actions workflow. export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} @@ -86,4 +77,4 @@ export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} # Installs the JAX package in editable mode at the current commit. Enabled by # default. Nightly/Release builds disable this flag in the Github action # workflow files. -export JAXCI_INSTALL_JAX_CURRENT_COMMIT=${JAXCI_INSTALL_JAX_CURRENT_COMMIT:-"1"} +export JAXCI_INSTALL_JAX_CURRENT_COMMIT=${JAXCI_INSTALL_JAX_CURRENT_COMMIT:-"1"} \ No newline at end of file From 492b10453b2425c17cf01687ec0a0c4facc13514 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 31 Jan 2025 16:57:48 +0000 Subject: [PATCH 198/205] fix merge --- ci/build_artifacts.sh | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 260cdfc18430..3cc1fa0c5e10 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -56,35 +56,21 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then # flags in the .bazelrc depending upon the platform we are building for. bazelrc_config="${os}_${arch}" -<<<<<<< HEAD - # Set remote_cache_flag to be empty by default to avoid unbound variable errors - remote_cache_flag="" -======= # On platforms with no RBE support, we can use the Bazel remote cache. Set # it to be empty by default to avoid unbound variable errors. bazel_remote_cache="" ->>>>>>> 14246134e7f2ec1ae4eaba8a70315a0fc2e73cd4 if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then bazelrc_config="rbe_${bazelrc_config}" else bazelrc_config="ci_${bazelrc_config}" -<<<<<<< HEAD - # Bazel remote cache can be used on platforms with no RBE support. Pushes to - # the cache bucket is limited to JAX's CI system. - if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then - remote_cache_flag="--bazel_options=--config=public_cache_push" - else - remote_cache_flag="--bazel_options=--config=public_cache" -======= # Set remote cache flags. Pushes to the cache bucket is limited to JAX's # CI system. if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then bazel_remote_cache="--bazel_options=--config=public_cache_push" else bazel_remote_cache="--bazel_options=--config=public_cache" ->>>>>>> 14246134e7f2ec1ae4eaba8a70315a0fc2e73cd4 fi fi @@ -95,15 +81,9 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then # Build the artifact. python build/build.py build --wheels="$artifact" \ -<<<<<<< HEAD - --bazel_options=--config="$bazelrc_config" $remote_cache_flag \ - --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ - --verbose --detailed_timestamped_log -======= --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log ->>>>>>> 14246134e7f2ec1ae4eaba8a70315a0fc2e73cd4 # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. From 45479bd0fb714665d7e3b5208a73ae091c64b475 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 31 Jan 2025 17:07:51 +0000 Subject: [PATCH 199/205] fix merge --- ci/run_pytest_cuda.sh | 2 +- ci/utilities/install_wheels_locally.sh | 10 +--------- ci/utilities/setup_build_environment.sh | 2 +- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/ci/run_pytest_cuda.sh b/ci/run_pytest_cuda.sh index eb815de144f0..70f855df059c 100755 --- a/ci/run_pytest_cuda.sh +++ b/ci/run_pytest_cuda.sh @@ -58,4 +58,4 @@ echo "Running CUDA tests..." tests examples \ --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \ ---deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric +--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric \ No newline at end of file diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index 8fa7c2f462c5..fbeafe22db01 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -34,16 +34,8 @@ else "$JAXCI_PYTHON" -m pip install "${WHEELS[@]}" fi -<<<<<<< HEAD -if [[ $JAXCI_INSTALL_JAX_CURRENT_COMMIT == "1" ]]; then - echo "Installing the JAX package in editable mode at the current commit..." - # Install JAX package at the current commit. - "$JAXCI_PYTHON" -m pip install -U -e . -fi -======= if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then echo "Installing the JAX package in editable mode at the current commit..." # Install JAX package at the current commit. "$JAXCI_PYTHON" -m pip install -U -e . -fi ->>>>>>> 14246134e7f2ec1ae4eaba8a70315a0fc2e73cd4 +fi \ No newline at end of file diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index 8e403941402b..d6948249cc8b 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -97,4 +97,4 @@ function retry { } # Retry "bazel --version" 3 times to avoid flakiness when downloading bazel. -retry "bazel --version" +retry "bazel --version" \ No newline at end of file From 78195892abee3ef584965eb1da3707fed78e4913 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 3 Feb 2025 21:29:17 +0000 Subject: [PATCH 200/205] sync to main --- .github/workflows/asan.yaml | 6 +- .github/workflows/bazel_cpu_rbe.yml | 11 ++- .github/workflows/bazel_cuda_non_rbe.yml | 4 +- .github/workflows/bazel_cuda_rbe.yml | 9 +-- .github/workflows/build_artifacts.yml | 25 ++++--- .github/workflows/pytest_cpu.yml | 2 +- .github/workflows/pytest_cuda.yml | 18 +++-- .github/workflows/tsan-suppressions.txt | 5 ++ .github/workflows/tsan.yaml | 74 ++++++++++++++++++- .github/workflows/wheel_tests_continuous.yml | 3 - .../workflows/wheel_tests_nightly_release.yml | 4 +- build/build.py | 2 +- build/tools/utils.py | 2 +- ci/README.md | 2 +- ci/envs/docker.env | 20 +---- ci/run_bazel_test_cpu_rbe.sh | 7 -- ci/run_bazel_test_cuda_rbe.sh | 2 +- ci/run_pytest_cuda.sh | 2 +- .../convert_msys_paths_to_win_paths.py | 2 +- ci/utilities/setup_build_environment.sh | 2 +- 20 files changed, 131 insertions(+), 71 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 838128f82f39..facc473473e0 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -1,5 +1,9 @@ name: CI - Address Sanitizer (nightly) +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + on: schedule: - cron: "0 12 * * *" # Daily at 12:00 UTC @@ -16,7 +20,7 @@ jobs: if: github.repository == 'jax-ml/jax' runs-on: linux-x86-n2-64 container: - image: us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest # ratchet:ubuntu:24.04 + image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 strategy: fail-fast: false defaults: diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index 24b23832b38e..e45df4e0dbbd 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -1,9 +1,6 @@ name: CI - Bazel CPU tests (RBE) on: - pull_request: - branches: - - main workflow_dispatch: inputs: halt-for-connection: @@ -14,6 +11,9 @@ on: options: - 'yes' - 'no' + pull_request: + branches: + - main concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -30,9 +30,8 @@ jobs: JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} # Begin Presubmit Naming Check - name modification requires internal check to be updated strategy: - fail-fast: false # don't cancel all jobs on failure matrix: - runner: ["windows-x86-n2-16", "linux-x86-n2-16", "linux-arm64-t2a-16"] + runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"] enable-x_64: [1, 0] name: "Bazel CPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})" # End Presubmit Naming Check github-cpu-presubmits @@ -43,4 +42,4 @@ jobs: with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Bazel CPU Tests with RBE - run: ./ci/run_bazel_test_cpu_rbe.sh + run: ./ci/run_bazel_test_cpu_rbe.sh \ No newline at end of file diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 77f21b8a5451..5f846099eefd 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -43,7 +43,7 @@ on: jobs: run-tests: runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cudnn:720686788" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} @@ -80,4 +80,4 @@ jobs: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Bazel CUDA tests (Non-RBE) timeout-minutes: 60 - run: ./ci/run_bazel_test_cuda_non_rbe.sh \ No newline at end of file + run: ./ci/run_bazel_test_cuda_non_rbe.sh diff --git a/.github/workflows/bazel_cuda_rbe.yml b/.github/workflows/bazel_cuda_rbe.yml index 1a806bcc6c98..cbeeecb69f48 100644 --- a/.github/workflows/bazel_cuda_rbe.yml +++ b/.github/workflows/bazel_cuda_rbe.yml @@ -1,9 +1,6 @@ name: CI - Bazel CUDA tests (RBE) on: - pull_request: - branches: - - main workflow_dispatch: inputs: halt-for-connection: @@ -14,6 +11,9 @@ on: options: - 'yes' - 'no' + pull_request: + branches: + - main concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -21,7 +21,6 @@ concurrency: jobs: run_tests: - if: github.event.repository.fork == false runs-on: ${{ matrix.runner }} container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest' env: @@ -41,4 +40,4 @@ jobs: with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Bazel CUDA Tests with RBE - run: ./ci/run_bazel_test_cuda_rbe.sh + run: ./ci/run_bazel_test_cuda_rbe.sh \ No newline at end of file diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 54cf52fe65a7..73f121725cbf 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -69,34 +69,37 @@ on: python: description: "Which python version should the artifact be built for?" type: string - required: true + required: false default: "3.12" clone_main_xla: - description: "Should latest XLA be used? (1 to enable, 0 to disable)" + description: "Should latest XLA be used?" type: string required: false default: "0" upload_artifacts_to_gcs: description: "Should the artifacts be uploaded to a GCS bucket?" - required: false - default: false + required: true + default: true type: boolean gcs_upload_uri: - description: "GCS location URI to where the artifacts should be uploaded" - required: false + description: "GCS location prefix to where the artifacts should be uploaded" + required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string outputs: gcs_upload_uri: description: "GCS location prefix to where the artifacts were uploaded" - value: ${{ inputs.gcs_upload_uri }} + value: ${{ jobs.build-artifacts.outputs.gcs_upload_uri }} + +permissions: + contents: read jobs: - build_artifacts: + build-artifacts: defaults: run: - # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. - shell: bash -e -o pipefail {0} + # Explicitly set the shell to bash to override Windows's default (cmd) + shell: bash runs-on: ${{ inputs.runner }} @@ -143,4 +146,4 @@ jobs: - name: Store the GCS upload URI as an output id: store-gcs-upload-uri if: ${{ inputs.upload_artifacts_to_gcs }} - run: echo "gcs_upload_uri=${{ inputs.gcs_upload_uri }}" >> "$GITHUB_OUTPUT" + run: echo "gcs_upload_uri=${{ inputs.gcs_upload_uri }}" >> "$GITHUB_OUTPUT" \ No newline at end of file diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index e43d236936c9..ef3e57f51d2e 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -98,4 +98,4 @@ jobs: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest CPU tests timeout-minutes: 60 - run: ./ci/run_pytest_cpu.sh \ No newline at end of file + run: ./ci/run_pytest_cpu.sh diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 0a0c9c832c3f..6114383b5422 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -35,19 +35,23 @@ on: required: true default: "0" gcs_download_uri: - description: "GCS location URI from where the artifacts should be downloaded" + description: "GCS location prefix from where the artifacts should be downloaded" required: true - default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' + default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: boolean + required: false + default: false jobs: run-tests: - runs-on: ${{ inputs.runner }} - container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cudnn:720686788') || + # TODO: Update to the generic ML ecosystem test containers when they are ready. + container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') || (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') }} - - name: "Pytest GPU (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" + name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" @@ -83,4 +87,4 @@ jobs: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest CUDA tests timeout-minutes: 60 - run: ./ci/run_pytest_cuda.sh + run: ./ci/run_pytest_cuda.sh \ No newline at end of file diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions.txt index 8ff7996de5b5..836a990c2aad 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions.txt @@ -34,6 +34,11 @@ race:py_digest_by_name # https://github.com/python/cpython/issues/128714 race:func_get_annotations +race:type_get_annotations + +# https://github.com/python/cpython/issues/129533 +race:PyGC_Disable +race:PyGC_Enable # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. race:heevd_ffi diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index b56fa68ad501..f580e3fe748e 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -44,6 +44,11 @@ jobs: repository: python/cpython path: cpython ref: "3.13" + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: numpy/numpy + path: numpy + submodules: true - name: Restore cached CPython with TSAN id: cache-cpython-tsan-restore @@ -67,7 +72,7 @@ jobs: # Create archive to be used with bazel as hermetic python: cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan - - name: Save CPython with TSAN + - name: Save TSAN CPython id: cache-cpython-tsan-save if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 @@ -76,6 +81,73 @@ jobs: ./python-tsan.tgz key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} + - name: Get year & week number + id: get-date + run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT + shell: bash -l {0} + + - name: Restore cached TSAN Numpy + id: cache-numpy-tsan-restore + uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + with: + path: | + ./wheelhouse + key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + + - name: Build TSAN Numpy wheel + if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' + run: | + cd numpy + + # If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz + if [ ! -d ${GITHUB_WORKSPACE}/cpython-tsan/bin/ ]; then + echo "Extract cpython from python-tsan.tgz" + pushd . + ls ${GITHUB_WORKSPACE}/python-tsan.tgz + cd ${GITHUB_WORKSPACE} && tar -xzf python-tsan.tgz + ls ${GITHUB_WORKSPACE}/cpython-tsan/bin/ + popd + fi + + export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH + + python3 -m pip install -r requirements/build_requirements.txt + # Make sure to install a compatible Cython version (master branch is best for now) + python3 -m pip install -U git+https://github.com/cython/cython + + CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized + + # Create simple index and copy the wheel + mkdir -p ${GITHUB_WORKSPACE}/wheelhouse/numpy + + numpy_whl_name=($(cd dist && ls numpy*.whl)) + if [ -z "${numpy_whl_name}" ]; then exit 1; fi + + echo "Built TSAN Numpy wheel: ${numpy_whl_name}" + + cp dist/${numpy_whl_name} ${GITHUB_WORKSPACE}/wheelhouse/numpy + + cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/index.html + + numpy>
+ + EOF + + cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/numpy/index.html + + ${numpy_whl_name}
+ + EOF + + - name: Save TSAN Numpy wheel + id: cache-numpy-tsan-save + if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' + uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + with: + path: | + ./wheelhouse + key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + - name: Build Jax and run tests timeout-minutes: 120 env: diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 372b7510e3e7..c0939112a2f3 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -20,9 +20,6 @@ name: CI - Wheel Tests (Continuous) on: schedule: - cron: "0 */2 * * *" # Run once every 2 hours - pull_request: - branches: - - main concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 3ddff54ff7a7..7aecd25bd8bd 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -43,7 +43,6 @@ jobs: python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} gcs_download_uri: ${{inputs.gcs_download_uri}} - install_jax_current_commit: "0" run-pytest-cuda: uses: ./.github/workflows/pytest_cuda.yml @@ -61,5 +60,4 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} - gcs_download_uri: ${{inputs.gcs_download_uri}} - install_jax_current_commit: "0" \ No newline at end of file + gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file diff --git a/build/build.py b/build/build.py index ef9e10b19d32..c2933416d510 100755 --- a/build/build.py +++ b/build/build.py @@ -664,4 +664,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/build/tools/utils.py b/build/tools/utils.py index 52303e53e210..03a762ac3940 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -251,4 +251,4 @@ def _parse_string_as_bool(s): elif lower == "false": return False else: - raise ValueError(f"Expected either 'true' or 'false'; got {s}") \ No newline at end of file + raise ValueError(f"Expected either 'true' or 'false'; got {s}") diff --git a/ci/README.md b/ci/README.md index 48362070855a..ea867df52f97 100644 --- a/ci/README.md +++ b/ci/README.md @@ -7,4 +7,4 @@ > directory are stable and appropriate documentation around its usage is in > place. -******************************************************************************** +******************************************************************************** \ No newline at end of file diff --git a/ci/envs/docker.env b/ci/envs/docker.env index 402bced11d75..a0f558520d45 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -38,22 +38,8 @@ if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest" fi -# Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel tests +# Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel +# tests if [[ $os =~ "msys_nt" ]]; then export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows:latest" -fi - -# Uncomment the following lines if you want to run the GPU tests with Pytest. -# Note that GPU Pytests, as a prequisite, require that the following JAX artifacts be -# present in the $JAXCI_OUTPUT_DIR: jaxlib, jax-cuda-plugin, jax-cuda-pjrt. If you don't -# have these wheels stored there, either build them from source via ci/build_artifacts.sh or -# download them from PyPI into that folder. -# -# Linux x86 image for running Pytest GPU tests -# if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then -# # Choose one of: 12.3, 12.1 -# export JAXCI_DOCKER_CUDA_VERSION=${JAXCI_DOCKER_CUDA_VERSION:-12.3} -# export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda${JAXCI_DOCKER_CUDA_VERSION}-cudnn9.1-ubuntu20.04-manylinux2014-multipython" -# -# export JAXCI_DOCKER_ARGS="--gpus all --shm-size=16g" -# fi +fi \ No newline at end of file diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index ad83ffd53d1d..248111e0247a 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -37,13 +37,6 @@ source "ci/utilities/setup_build_environment.sh" os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) -# Adjust the values when running on Windows x86 to match the config in -# .bazelrc -if [[ $os =~ "msys_nt" ]] && [[ $arch == "x86_64" ]]; then - os="windows" - arch="amd64" -fi - # When running on Mac or Linux Aarch64, we only build the test targets and # not run them. These platforms do not have native RBE support so we # RBE cross-compile them on remote Linux x86 machines. As the tests still diff --git a/ci/run_bazel_test_cuda_rbe.sh b/ci/run_bazel_test_cuda_rbe.sh index 5ce69f0dbf05..17bd8d9db4f8 100755 --- a/ci/run_bazel_test_cuda_rbe.sh +++ b/ci/run_bazel_test_cuda_rbe.sh @@ -48,4 +48,4 @@ bazel test --config=rbe_linux_x86_64_cuda \ --test_env=JAX_SKIP_SLOW_TESTS=true \ --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --color=yes \ - //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests + //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file diff --git a/ci/run_pytest_cuda.sh b/ci/run_pytest_cuda.sh index 70f855df059c..eb815de144f0 100755 --- a/ci/run_pytest_cuda.sh +++ b/ci/run_pytest_cuda.sh @@ -58,4 +58,4 @@ echo "Running CUDA tests..." tests examples \ --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \ ---deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric \ No newline at end of file +--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric diff --git a/ci/utilities/convert_msys_paths_to_win_paths.py b/ci/utilities/convert_msys_paths_to_win_paths.py index b11bb724bd31..6164e6a5e29d 100644 --- a/ci/utilities/convert_msys_paths_to_win_paths.py +++ b/ci/utilities/convert_msys_paths_to_win_paths.py @@ -77,4 +77,4 @@ def main(parsed_args: argparse.Namespace): help='Space separated list of environment variables to convert. E.g: --convert env_var1 env_var2') args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index d6948249cc8b..8e403941402b 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -97,4 +97,4 @@ function retry { } # Retry "bazel --version" 3 times to avoid flakiness when downloading bazel. -retry "bazel --version" \ No newline at end of file +retry "bazel --version" From 88582f16d5636dc13ee478847e66ba76565e3c8e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 13 Feb 2025 05:19:54 +0000 Subject: [PATCH 201/205] use new mlbuild cuda images --- .github/workflows/bazel_cuda_non_rbe.yml | 2 +- .github/workflows/pytest_cuda.yml | 4 ++-- .github/workflows/wheel_tests_continuous.yml | 13 ++++--------- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 5f846099eefd..89906c338d1b 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -43,7 +43,7 @@ on: jobs: run-tests: runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cudnn9.1-cuda12.3:720686788" env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 6114383b5422..179a66a36d19 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -49,8 +49,8 @@ jobs: run-tests: runs-on: ${{ inputs.runner }} # TODO: Update to the generic ML ecosystem test containers when they are ready. - container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') || - (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') }} + container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cudnn9.1-cuda12.3:720686788') || + (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cudnn9.1-cuda12.1:720686788') }} name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" env: diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index c0939112a2f3..503c574445ce 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -20,6 +20,9 @@ name: CI - Wheel Tests (Continuous) on: schedule: - cron: "0 */2 * * *" # Run once every 2 hours + pull_request: + branches: + - main concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -84,18 +87,10 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: # Python values need to match the matrix stategy in the artifact build jobs above - runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"] + runner: ["linux-x86-g2-48-l4-4gpu"] python: ["3.10",] cuda: ["12.3", "12.1"] enable-x64: [1, 0] - exclude: - # Run only a single configuration on H100 to save resources - - runner: "linux-x86-a3-8g-h100-8gpu" - python: "3.10" - cuda: "12.1" - - runner: "linux-x86-a3-8g-h100-8gpu" - python: "3.10" - enable-x64: 0 with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} From 1b20f8c6617912e386b2e6571d56291da792f025 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 28 Feb 2025 17:40:45 +0000 Subject: [PATCH 202/205] Get XLA via actions/checkout --- .github/workflows/build_artifacts.yml | 10 ++++++++-- .github/workflows/wheel_tests_continuous.yml | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 73f121725cbf..dfd6e3f95b10 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -109,7 +109,6 @@ jobs: env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" - JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }}) @@ -118,7 +117,14 @@ jobs: gcs_upload_uri: ${{ steps.store-gcs-upload-uri.outputs.gcs_upload_uri }} steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Checkout JAX repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Checkout XLA repository + if: contains("${{ inputs.clone_main_xla }}", '1') + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: openxla/xla + path: jax/xla - name: Enable RBE if building on Linux x86 or Windows x86 if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 503c574445ce..f6a3fad422d5 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -42,7 +42,7 @@ jobs: runner: ${{ matrix.runner }} artifact: ${{ matrix.artifact }} python: ${{ matrix.python }} - clone_main_xla: 1 + clone_main_xla: 0 upload_artifacts_to_gcs: true gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' @@ -59,7 +59,7 @@ jobs: runner: ${{ matrix.runner }} artifact: ${{ matrix.artifact }} python: ${{ matrix.python }} - clone_main_xla: 1 + clone_main_xla: 0 upload_artifacts_to_gcs: true gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' From 80628d63ce073c1b1e62f0bd5036ad0b9017ecf4 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 10 Apr 2025 14:35:25 +0000 Subject: [PATCH 203/205] remove obsolete files --- gsutil.cmd | 146 -------------------------------------------- gsutil.py | 175 ----------------------------------------------------- t.sh | 7 --- test.py | 62 ------------------- 4 files changed, 390 deletions(-) delete mode 100644 gsutil.cmd delete mode 100644 gsutil.py delete mode 100644 t.sh delete mode 100644 test.py diff --git a/gsutil.cmd b/gsutil.cmd deleted file mode 100644 index 0b9136d4ca0f..000000000000 --- a/gsutil.cmd +++ /dev/null @@ -1,146 +0,0 @@ -@echo on -rem Copyright 2013 Google Inc. All Rights Reserved. - -SETLOCAL - -SET CLOUDSDK_PYTHON_SITEPACKAGES=1 - -rem -rem -rem CLOUDSDK_ROOT_DIR (a) installation root dir -rem CLOUDSDK_PYTHON (u) python interpreter path -rem CLOUDSDK_GSUTIL_PYTHON (u) python interpreter path for gsutil -rem CLOUDSDK_PYTHON_ARGS (u) python interpreter arguments -rem CLOUDSDK_PYTHON_SITEPACKAGES (u) use python site packages -rem CLOUDSDK_ENCODING (u) python io encoding for gcloud -rem -rem (a) always defined by the preamble -rem (u) user definition overrides preamble - -rem This command lives in google-cloud-sdk\bin or google-cloud-sdk\ so it or its -rem parent directory is the root. Don't enable DelayedExpansion yet or we -rem destroy PATHs that have exclamation marks. Quotes needed to support -rem ampersands. -IF "%CLOUDSDK_ROOT_DIR%"=="" ( - rem If sourced in install.bat ROOT_DIR is already defined. This case handles - rem setting the ROOT_DIR for every other usecase. - SET "CLOUDSDK_ROOT_DIR=%~dp0.." -) -SET "PATH=%CLOUDSDK_ROOT_DIR%\bin\sdk;%PATH%" - -rem %PYTHONHOME% can interfere with gcloud. Users should use -rem CLOUDSDK_PYTHON to configure which python gcloud uses. -SET PYTHONHOME= - - -SETLOCAL EnableDelayedExpansion - -IF "!CLOUDSDK_PYTHON!"=="" ( - SET BUNDLED_PYTHON=!CLOUDSDK_ROOT_DIR!\platform\bundledpython\python.exe - IF EXIST "!BUNDLED_PYTHON!" ( - SET CLOUDSDK_PYTHON=!BUNDLED_PYTHON! - ) -) - -for %%X in (where.exe) do (set WHERE_FOUND=%%~$PATH:X) - -IF defined WHERE_FOUND ( - IF "!CLOUDSDK_PYTHON!"=="" ( - where /q python - IF NOT ERRORLEVEL 1 ( - FOR /F "tokens=* USEBACKQ" %%F IN (`where python`) DO ( - SET PYTHON_CANDIDATE_PATH=%%F - "!PYTHON_CANDIDATE_PATH!" -c "import sys; print(sys.version)" > tmpfile - set PYTHON_CANDIDATE_VERSION= - set /p PYTHON_CANDIDATE_VERSION= < tmpfile - del tmpfile - IF "!PYTHON_CANDIDATE_VERSION:~0,1!"=="3" ( - SET CLOUDSDK_PYTHON=%%F - SET CLOUDSDK_PYTHON_VERSION="!PYTHON_CANDIDATE_VERSION!" - ) - ) - ) - ) - - IF "!CLOUDSDK_PYTHON!"=="" ( - where /q python3 - IF NOT ERRORLEVEL 1 ( - FOR /F "tokens=* USEBACKQ" %%F IN (`where python3`) DO ( - SET PYTHON_CANDIDATE_PATH=%%F - "!PYTHON_CANDIDATE_PATH!" -c "import sys; print(sys.version)" > tmpfile - set PYTHON_CANDIDATE_VERSION= - set /p PYTHON_CANDIDATE_VERSION= < tmpfile - del tmpfile - IF "!PYTHON_CANDIDATE_VERSION:~0,1!"=="3" ( - SET CLOUDSDK_PYTHON=%%F - SET CLOUDSDK_PYTHON_VERSION="!PYTHON_CANDIDATE_VERSION!" - ) - ) - ) - ) - - IF "!CLOUDSDK_PYTHON!"=="" ( - where /q python - IF NOT ERRORLEVEL 1 ( - FOR /F "tokens=* USEBACKQ" %%F IN (`where python`) DO ( - SET PYTHON_CANDIDATE_PATH=%%F - "!PYTHON_CANDIDATE_PATH!" -c "import sys; print(sys.version)" > tmpfile - set PYTHON_CANDIDATE_VERSION= - set /p PYTHON_CANDIDATE_VERSION= < tmpfile - del tmpfile - IF "!PYTHON_CANDIDATE_VERSION:~0,1!"=="2" ( - SET CLOUDSDK_PYTHON=%%F - SET CLOUDSDK_PYTHON_VERSION="!PYTHON_CANDIDATE_VERSION!" - ) - ) - ) - ) -) - -IF "!CLOUDSDK_PYTHON!"=="" ( - SET CLOUDSDK_PYTHON="python.exe" -) - - -SET NO_WORKING_PYTHON_FOUND="false" -rem We run sys.version to ensure it's not the Windows Store python.exe -"!CLOUDSDK_PYTHON!" -c "import sys; print(sys.version)" >nul 2>&1 -IF NOT %ERRORLEVEL%==0 ( - SET NO_WORKING_PYTHON_FOUND="true" -) - -IF "%CLOUDSDK_PYTHON_SITEPACKAGES%" == "" ( - IF "!VIRTUAL_ENV!" == "" ( - SET CLOUDSDK_PYTHON_SITEPACKAGES= - ) ELSE ( - SET CLOUDSDK_PYTHON_SITEPACKAGES=1 - ) -) -SET CLOUDSDK_PYTHON_ARGS_NO_S=!CLOUDSDK_PYTHON_ARGS:-S=! -IF "%CLOUDSDK_PYTHON_SITEPACKAGES%" == "" ( - IF "!CLOUDSDK_PYTHON_ARGS!" == "" ( - SET CLOUDSDK_PYTHON_ARGS=-S - ) ELSE ( - SET CLOUDSDK_PYTHON_ARGS=!CLOUDSDK_PYTHON_ARGS_NO_S! -S - ) -) ELSE IF "!CLOUDSDK_PYTHON_ARGS!" == "" ( - SET CLOUDSDK_PYTHON_ARGS= -) ELSE ( - SET CLOUDSDK_PYTHON_ARGS=!CLOUDSDK_PYTHON_ARGS_NO_S! -) - -IF "%CLOUDSDK_GSUTIL_PYTHON%" == "" ( - SET CLOUDSDK_GSUTIL_PYTHON=!CLOUDSDK_PYTHON! -) - -IF NOT "%CLOUDSDK_ENCODING%" == "" ( - SET PYTHONIOENCODING=!CLOUDSDK_ENCODING! -) - -SETLOCAL DisableDelayedExpansion - -rem - -"%CLOUDSDK_GSUTIL_PYTHON%" %CLOUDSDK_PYTHON_ARGS% "%CLOUDSDK_ROOT_DIR%\bin\bootstrapping\gsutil.py" "%*" - -"%COMSPEC%" /C exit %ERRORLEVEL% \ No newline at end of file diff --git a/gsutil.py b/gsutil.py deleted file mode 100644 index 31fa97e078e7..000000000000 --- a/gsutil.py +++ /dev/null @@ -1,175 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2013 Google Inc. All Rights Reserved. -# - -"""A convenience wrapper for starting gsutil.""" - -from __future__ import absolute_import -from __future__ import unicode_literals - -import json -import os - - -import bootstrapping -from googlecloudsdk.calliope import exceptions -from googlecloudsdk.core import config -from googlecloudsdk.core import context_aware -from googlecloudsdk.core import log -from googlecloudsdk.core import metrics -from googlecloudsdk.core import properties -from googlecloudsdk.core.credentials import gce as c_gce -from googlecloudsdk.core.util import encoding -from googlecloudsdk.core.util import files - - -def _MaybeAddBotoOption(args, section, name, value): - if value is None: - return - args.append('-o') - args.append('{section}:{name}={value}'.format( - section=section, name=name, value=value)) - - -def _GetCertProviderCommand(context_config): - """Returns the cert provider command from the context config.""" - # TODO(b/190102217) - Cleanup code that handles both version of context_config - if hasattr(context_config, 'cert_provider_command'): - return context_config.cert_provider_command - - try: - contents = files.ReadFileContents(context_config.config_path) - json_out = json.loads(contents) - if 'cert_provider_command' in json_out: - return json_out['cert_provider_command'] - except files.Error as e: - log.debug('context aware settings discovery file %s - %s', - context_config.config_path, e) - - -def _AddContextAwareOptions(args): - """Adds device certificate settings for mTLS.""" - context_config = context_aware.Config() - # Enterprise certificate is not yet supported for gsutil. - if ( - context_config - and context_config.config_type - == context_aware.ConfigType.ENTERPRISE_CERTIFICATE - ): - return - - # TODO(b/190102217) - Cleanup code that handles both version of context_config - use_client_certificate = ( - context_config and - getattr(context_config, 'use_client_certificate', True)) - _MaybeAddBotoOption(args, 'Credentials', 'use_client_certificate', - use_client_certificate) - if context_config: - cert_provider_command = _GetCertProviderCommand(context_config) - if isinstance(cert_provider_command, list): - # e.g. cert_provider_command = ['*/apihelper', '--print_certificate'] - cert_provider_command = ' '.join(cert_provider_command) - # Don't need to pass mTLS data if gsutil shouldn't be using it. - _MaybeAddBotoOption(args, 'Credentials', 'cert_provider_command', - cert_provider_command) - - -def main(): - """Launches gsutil.""" - - args = [] - - project, account = bootstrapping.GetActiveProjectAndAccount() - pass_credentials = ( - properties.VALUES.core.pass_credentials_to_gsutil.GetBool() and - not properties.VALUES.auth.disable_credentials.GetBool()) - - _MaybeAddBotoOption(args, 'GSUtil', 'default_project_id', project) - - if pass_credentials: - # Allow gsutil to only check for the '1' string value, as is done - # with regard to the 'CLOUDSDK_WRAPPER' environment variable. - encoding.SetEncodedValue( - os.environ, 'CLOUDSDK_CORE_PASS_CREDENTIALS_TO_GSUTIL', '1') - - if account in c_gce.Metadata().Accounts(): - # Tell gsutil that it should obtain credentials from the GCE metadata - # server for the instance's configured service account. - _MaybeAddBotoOption(args, 'GoogleCompute', 'service_account', 'default') - # For auth'n debugging purposes, allow gsutil to reason about whether the - # configured service account was set in a boto file or passed from here. - encoding.SetEncodedValue( - os.environ, 'CLOUDSDK_PASSED_GCE_SERVICE_ACCOUNT_TO_GSUTIL', '1') - else: - legacy_config_path = config.Paths().LegacyCredentialsGSUtilPath(account) - # We construct a BOTO_PATH that tacks the config containing our - # credentials options onto the end of the list of config paths. We ensure - # the other credential options are loaded first so that ours will take - # precedence and overwrite them. - boto_config = encoding.GetEncodedValue(os.environ, 'BOTO_CONFIG', '') - boto_path = encoding.GetEncodedValue(os.environ, 'BOTO_PATH', '') - if boto_config: - boto_path = os.pathsep.join([boto_config, legacy_config_path]) - elif boto_path: - boto_path = os.pathsep.join([boto_path, legacy_config_path]) - else: - path_parts = ['/etc/boto.cfg', - os.path.expanduser(os.path.join('~', '.boto')), - legacy_config_path] - boto_path = os.pathsep.join(path_parts) - - encoding.SetEncodedValue(os.environ, 'BOTO_CONFIG', None) - encoding.SetEncodedValue(os.environ, 'BOTO_PATH', boto_path) - - # Tell gsutil whether gcloud analytics collection is enabled. - encoding.SetEncodedValue( - os.environ, 'GA_CID', metrics.GetCIDIfMetricsEnabled()) - - # Set proxy settings. Note that if these proxy settings are configured in a - # boto config file, the options here will be loaded afterward, overriding - # them. - proxy_params = properties.VALUES.proxy - proxy_address = proxy_params.address.Get() - if proxy_address: - _MaybeAddBotoOption(args, 'Boto', 'proxy', proxy_address) - _MaybeAddBotoOption(args, 'Boto', 'proxy_port', proxy_params.port.Get()) - _MaybeAddBotoOption(args, 'Boto', 'proxy_rdns', proxy_params.rdns.GetBool()) - _MaybeAddBotoOption(args, 'Boto', 'proxy_user', proxy_params.username.Get()) - _MaybeAddBotoOption(args, 'Boto', 'proxy_pass', proxy_params.password.Get()) - - # Set SSL-related settings. - disable_ssl = properties.VALUES.auth.disable_ssl_validation.GetBool() - _MaybeAddBotoOption(args, 'Boto', 'https_validate_certificates', - None if disable_ssl is None else not disable_ssl) - _MaybeAddBotoOption(args, 'Boto', 'ca_certificates_file', - properties.VALUES.core.custom_ca_certs_file.Get()) - - # Sync device certificate settings for mTLS. - _AddContextAwareOptions(args) - - # Note that the original args to gsutil will be appended after the args we've - # supplied here. - bootstrapping.ExecutePythonTool('platform/gsutil', 'gsutil', *args) - - -if __name__ == '__main__': - print("\n\n\nI'm here\n\n\n") - try: - version = bootstrapping.ReadFileContents('platform/gsutil', 'VERSION') - bootstrapping.CommandStart('gsutil', version=version) - - blocked_commands = { - 'update': 'To update, run: gcloud components update', - } - - argv = bootstrapping.GetDecodedArgv() - bootstrapping.WarnAndExitOnBlockedCommand(argv, blocked_commands) - - # Don't call bootstrapping.PreRunChecks because anonymous access is - # supported for some endpoints. gsutil will output the appropriate - # error message upon receiving an authentication error. - bootstrapping.CheckUpdates('gsutil') - main() - except Exception as e: # pylint: disable=broad-except - exceptions.HandleError(e, 'gsutil') diff --git a/t.sh b/t.sh deleted file mode 100644 index 93c4bc5c048a..000000000000 --- a/t.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -#if [[ -f $(pwd)/dist/jaxlib-0.4.36.dev20241207-cp310-cp310-manylinux2014_x86_64.whl ]]; then -file=$(find dist/ -type f -name "jaxlib*.whl" -print -quit 2>/dev/null) -if [[ -f $file ]]; then - echo "file found" -fi diff --git a/test.py b/test.py deleted file mode 100644 index 4f4da480fd27..000000000000 --- a/test.py +++ /dev/null @@ -1,62 +0,0 @@ -import collections -import datetime -import pathlib -import re -import time -from typing import Dict, List, Optional, Sequence, Tuple, Union - -from absl import app -from absl import flags -import immutabledict - -from google3.pyglib import gfile - -ROOT_BIGSTORE_RELEASES_PATH = pathlib.Path('/bigstore/jax-releases') - -def find_wheels( - patterns: Sequence[str], is_nightly: bool, bucket: str -) -> List[str]: - """Finds the wheels that match any of the input patterns.""" - wheels = [] - if is_nightly: - wheel_version_keyword = datetime.date.today().strftime('%Y%m%d') - else: - wheel_version_keyword = f'-{jax.jaxlib.version._version}-' # pylint: disable=protected-access - - bucket_path = ROOT_BIGSTORE_RELEASES_PATH / bucket - files = gfile.ListDir(bucket_path) - print("files:", files) -# for file in files: -# if wheel_version_keyword not in file: -# continue -# for pattern in patterns: -# if re.fullmatch(pattern, file): -# wheels.append(str(bucket_path / file)) -# break - - return wheels - -def _move_cuda_plugin_wheels_to_release_bucket(): - """Move cuda plugin wheels from temp_wheels bucket to release bucket.""" - wheels = find_wheels( - patterns=['.*jax_cuda.*[pjrt|plugin].*'], is_nightly=False, - bucket='temp_wheels' - ) - print("wheels:", wheels) - # for wheel_path in wheels: - # # TODO(jieying): make the bucket renaming here more flexible for future - # # cuda versions. - # if 'cuda12' in wheel_path: - # release_bucket = 'cuda12_plugin' - # release_file_path = wheel_path.replace('temp_wheels', release_bucket) - # gfile.Rename(wheel_path, release_file_path, overwrite=True) - # else: - # print(f'Unexpected wheel found: {wheel_path}') - - -def main(argv: Sequence[str]) -> None: - _move_cuda_plugin_wheels_to_release_bucket() - - -if __name__ == '__main__': - app.run(main) From f03298ee39bd746c7d1dc6e2b11485694ceecf14 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 10 Apr 2025 15:02:51 +0000 Subject: [PATCH 204/205] Switch the CUDA workflow to use a non-cuda docker image to test jax install breakage --- .github/workflows/pytest_cuda.yml | 3 ++- .github/workflows/wheel_tests_continuous.yml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 28e11be8b437..57a4a1674b51 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -58,7 +58,8 @@ jobs: runs-on: ${{ inputs.runner }} # Test the oldest and newest supported CUDA versions. container: ${{ (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.8:latest') || - (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }} + (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') || + (contains(inputs.cuda, 'CUDA pip') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') }} name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 654436aac25a..ec6a9212af9e 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -116,7 +116,7 @@ jobs: # See exlusions for what is fully tested runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] python: ["3.10",] - cuda: ["12.1", "12.8"] + cuda: ["CUDA pip",] #"12.1", "12.8"] enable-x64: [1, 0] exclude: # H100 runs only a single config, CUDA 12.8 Enable x64 1 From 919155fa33fbd76c090a1cf7dc97eaf2a656780d Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 11 Apr 2025 04:23:36 +0000 Subject: [PATCH 205/205] Revert cuda changes done to test wheel breakage --- .github/workflows/pytest_cuda.yml | 3 +-- .github/workflows/wheel_tests_continuous.yml | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index 57a4a1674b51..f885e579f6b9 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -58,8 +58,7 @@ jobs: runs-on: ${{ inputs.runner }} # Test the oldest and newest supported CUDA versions. container: ${{ (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.8:latest') || - (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') || - (contains(inputs.cuda, 'CUDA pip') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') }} + (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest')}} name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 62a366822563..0df096e1e019 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -117,7 +117,7 @@ jobs: # Comment out H100 and B200 on jax-fork to save on resources runner: ["linux-x86-g2-48-l4-4gpu"] #, "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] python: ["3.10",] - cuda: ["CUDA pip",] #"12.1", "12.8"] + cuda: ["12.1", "12.8"] enable-x64: [1, 0] # exclude: # # H100 runs only a single config, CUDA 12.8 Enable x64 1