diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml deleted file mode 100644 index bae03def7d53..000000000000 --- a/.github/workflows/bazel_cpu_rbe.yml +++ /dev/null @@ -1,46 +0,0 @@ -name: Run Bazel CPU tests (RBE) - -on: - # pull_request: - # branches: - # - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - - 'no' - -jobs: - run_bazel_rbe_cpu_tests: - continue-on-error: true - defaults: - run: - # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. - shell: bash - strategy: - matrix: - runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] - - runs-on: ${{ matrix.runner }} - container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') || - (contains(matrix.runner, 'windows-x86') && null) }} - - env: - JAXCI_CLONE_MAIN_XLA: 1 - JAXCI_HERMETIC_PYTHON_VERSION: "3.12" - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Run Bazel CPU Tests - run: ./ci/run_bazel_test_cpu_rbe.sh diff --git a/.github/workflows/bazel_gpu_non_rbe.yml b/.github/workflows/bazel_gpu_non_rbe.yml deleted file mode 100644 index ba11ac486001..000000000000 --- a/.github/workflows/bazel_gpu_non_rbe.yml +++ /dev/null @@ -1,51 +0,0 @@ -name: Run Bazel GPU tests (non RBE) - -on: - # pull_request: - # branches: - # - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' 
- type: choice - required: true - default: 'no' - options: - - 'yes' - - 'no' - -jobs: - build: - strategy: - matrix: - runner: ["linux-x86-g2-48-l4-4gpu"] - - runs-on: ${{ matrix.runner }} - container: - image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" - - env: - JAXCI_HERMETIC_PYTHON_VERSION: 3.11 - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build jaxlib - env: - JAXCI_CLONE_MAIN_XLA: 1 - run: ./ci/build_artifacts.sh "jaxlib" - - name: Build jax-cuda-plugin - env: - JAXCI_CLONE_MAIN_XLA: 1 - run: ./ci/build_artifacts.sh "jax-cuda-plugin" - - name: Build jax-cuda-pjrt - env: - JAXCI_CLONE_MAIN_XLA: 1 - run: ./ci/build_artifacts.sh "jax-cuda-pjrt" - - name: Run Bazel GPU tests locally - run: ./ci/run_bazel_test_gpu_non_rbe.sh diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml deleted file mode 100644 index b21d09fed91d..000000000000 --- a/.github/workflows/bazel_gpu_rbe.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Run Bazel GPU tests (RBE) - -on: - # pull_request: - # branches: - # - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' 
- type: choice - required: true - default: 'no' - options: - - 'yes' - - 'no' - -jobs: - build: - strategy: - matrix: - runner: ["linux-x86-n2-16"] - - runs-on: ${{ matrix.runner }} - container: - image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" - - env: - JAXCI_CLONE_MAIN_XLA: 1 - JAXCI_HERMETIC_PYTHON_VERSION: 3.12 - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Run Bazel GPU tests using RBE - run: ./ci/run_bazel_test_gpu_rbe.sh diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml deleted file mode 100644 index fc4e2df2bda4..000000000000 --- a/.github/workflows/build_artifacts.yml +++ /dev/null @@ -1,74 +0,0 @@ -name: Build JAX Artifacts - -on: - # pull_request: - # branches: - # - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - - 'no' - workflow_call: - -jobs: - build: - continue-on-error: true - defaults: - run: - # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. - shell: bash - strategy: - matrix: - runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] - artifact: ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"] - python: ["3.10", "3.11", "3.12"] - # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each - # Python version. - exclude: - # Pure Python packages do not need to be built for each Python version. - - artifact: "jax-cuda-pjrt" - python: "3.10" - - artifact: "jax-cuda-pjrt" - python: "3.11" - - artifact: "jax" - python: "3.10" - - artifact: "jax" - python: "3.11" - # jax is a pure Python package so it does not need to be built on multiple platforms. 
- - artifact: "jax" - runner: "windows-x86-n2-64" - - artifact: "jax" - runner: "linux-arm64-t2a-16" - # jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows. - - artifact: "jax-cuda-plugin" - runner: "windows-x86-n2-64" - - artifact: "jax-cuda-pjrt" - runner: "windows-x86-n2-64" - - runs-on: ${{ matrix.runner }} - - container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') || - (contains(matrix.runner, 'windows-x86') && null) }} - - env: - # Do not run Docker container for Linux runners. Linux runners already run in a Docker container. - JAXCI_RUN_DOCKER_CONTAINER: 0 - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build ${{ matrix.artifact }} - env: - JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" - run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml deleted file mode 100644 index db1477ac38b1..000000000000 --- a/.github/workflows/ci-build.yaml +++ /dev/null @@ -1,261 +0,0 @@ -name: CI - -# We test all supported Python versions as follows: -# - 3.10 : Documentation build -# - 3.10 : Part of Matrix with NumPy dispatch -# - 3.10 : Part of Matrix -# - 3.11 : Part of Matrix - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - pull_request: - branches: - - main - -permissions: - contents: read # to fetch code - actions: write # to cancel previous workflows - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -jobs: - lint_and_typecheck: - runs-on: ubuntu-latest - 
timeout-minutes: 5 - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python 3.11 - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: 3.11 - - run: python -m pip install pre-commit - - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 - with: - path: ~/.cache/pre-commit - key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} - - run: pre-commit run --show-diff-on-failure --color=always --all-files - - build: - name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" - runs-on: linux-x86-n2-32 - container: - image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 - timeout-minutes: 60 - strategy: - matrix: - # Test the oldest and newest supported Python versions here. - include: - - name-prefix: "with 3.10" - python-version: "3.10" - enable-x64: 1 - prng-upgrade: 1 - num_generated_cases: 1 - - name-prefix: "with 3.13" - python-version: "3.13" - enable-x64: 0 - prng-upgrade: 0 - num_generated_cases: 1 - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Image Setup - run: | - apt update - apt install -y libssl-dev - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Get pip cache dir - id: pip-cache - run: | - python -m pip install --upgrade pip wheel - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - name: pip cache - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} - - 
name: Install dependencies - run: | - pip install .[minimum-jaxlib] -r build/test-requirements.txt - - - name: Run tests - env: - JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }} - JAX_ENABLE_X64: ${{ matrix.enable-x64 }} - JAX_ENABLE_CUSTOM_PRNG: ${{ matrix.prng-upgrade }} - JAX_THREEFRY_PARTITIONABLE: ${{ matrix.prng-upgrade }} - JAX_ENABLE_CHECKS: true - JAX_SKIP_SLOW_TESTS: true - PY_COLORS: 1 - run: | - pip install -e . - echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" - echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" - echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG" - echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE" - echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" - echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" - pytest -n auto --tb=short --maxfail=20 tests examples - - - documentation: - name: Documentation - test code snippets - runs-on: ubuntu-latest - timeout-minutes: 10 - strategy: - matrix: - python-version: ['3.10'] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Get pip cache dir - id: pip-cache - run: | - python -m pip install --upgrade pip wheel - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - name: pip cache - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} - - name: Install dependencies - run: | - pip install -r docs/requirements.txt - - name: Test documentation - env: - XLA_FLAGS: "--xla_force_host_platform_device_count=8" - JAX_TRACEBACK_FILTERING: "off" - JAX_ARRAY: 1 - PY_COLORS: 1 - run: | - pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure 
--ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst - pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/array_api --ignore=jax/lib/xla_extension.py - - - documentation_render: - name: Documentation - render documentation - runs-on: ubuntu-latest - timeout-minutes: 10 - strategy: - matrix: - python-version: ['3.10'] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Get pip cache dir - id: pip-cache - run: | - python -m pip install --upgrade pip wheel - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - name: pip cache - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} - - name: Install dependencies - run: | - pip install -r docs/requirements.txt - - name: Render documentation - run: | - sphinx-build --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html - - - jax2tf_test: - name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})" - runs-on: ${{ matrix.os }} - timeout-minutes: 30 - strategy: - matrix: - # Test the oldest supported Python version here. 
- include: - - python-version: "3.10" - os: ubuntu-latest - enable-x64: 0 - num_generated_cases: 10 - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Get pip cache dir - id: pip-cache - run: | - python -m pip install --upgrade pip wheel - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - name: pip cache - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} - - name: Install dependencies - run: | - pip install .[minimum-jaxlib] tensorflow -r build/test-requirements.txt - - - name: Run tests - env: - JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }} - JAX_ENABLE_X64: ${{ matrix.enable-x64 }} - JAX_ENABLE_CHECKS: true - JAX_SKIP_SLOW_TESTS: true - PY_COLORS: 1 - run: | - pip install -e . 
- echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" - echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" - echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" - echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" - pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py - - ffi: - name: FFI example - runs-on: linux-x86-g2-16-l4-1gpu - container: - image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12 - timeout-minutes: 30 - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: 3.12 - - name: Get pip cache dir - id: pip-cache - run: | - python -m pip install --upgrade pip wheel - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - name: pip cache - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }} - - name: Install JAX - run: pip install .[cuda12] - - name: Build and install example project - run: python -m pip install -v ./examples/ffi[test] - env: - # We test building using GCC instead of clang. All other JAX builds use - # clang, but it is useful to make sure that FFI users can compile using - # a different toolchain. GCC is the default compiler on the - # 'ubuntu-latest' runner, but we still set this explicitly just to be - # clear. 
- CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON - - name: Run CPU tests - run: python -m pytest examples/ffi/tests - env: - JAX_PLATFORM_NAME: cpu - - name: Run GPU tests - run: python -m pytest examples/ffi/tests diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml deleted file mode 100644 index a5fac5ebdbc3..000000000000 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ /dev/null @@ -1,103 +0,0 @@ -# Cloud TPU CI -# -# This job currently runs once per day. We use self-hosted TPU runners, so we'd -# have to add more runners to run on every commit. -# -# This job's build matrix runs over several TPU architectures using both the -# latest released jaxlib on PyPi ("pypi_latest") and the latest nightly -# jaxlib.("nightly"). It also installs a matching libtpu, either the one pinned -# to the release for "pypi_latest", or the latest nightly.for "nightly". It -# always locally installs jax from github head (already checked out by the -# Github Actions environment). - -name: CI - Cloud TPU (nightly) -on: - schedule: - - cron: "0 14 * * *" # daily at 7am PST - workflow_dispatch: # allows triggering the workflow run manually -# This should also be set to read-only in the project settings, but it's nice to -# document and enforce the permissions here. 
-permissions: - contents: read -jobs: - cloud-tpu-test: - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] - tpu: [ - {type: "v3-8", cores: "4"}, - {type: "v4-8", cores: "4"}, - {type: "v5e-8", cores: "8"} - ] - name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" - env: - LIBTPU_OLDEST_VERSION_DATE: 20240722 - ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }} - runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"] - timeout-minutes: 120 - defaults: - run: - shell: bash -ex {0} - steps: - # https://opensource.google/documentation/reference/github/services#actions - # mandates using a specific commit for non-Google actions. We use - # https://github.com/sethvargo/ratchet to pin specific versions. - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install JAX test requirements - run: | - pip install -U -r build/test-requirements.txt - pip install -U -r build/collect-profile-requirements.txt - - name: Install JAX - run: | - pip uninstall -y jax jaxlib libtpu - if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then - pip install .[tpu] \ - -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - - elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then - pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - pip install --pre libtpu \ - -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install requests - - elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then - pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release. 
- pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ - -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install requests - else - echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}" - exit 1 - fi - - python3 -c 'import sys; print("python version:", sys.version)' - python3 -c 'import jax; print("jax version:", jax.__version__)' - python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' - strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on' - python3 -c 'import jax; print("libtpu version:", - jax.lib.xla_bridge.get_backend().platform_version)' - - name: Run tests - env: - JAX_PLATFORMS: tpu,cpu - PY_COLORS: 1 - run: | - # Run single-accelerator tests in parallel - JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \ - --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ - --maxfail=20 -m "not multiaccelerator" tests examples - # Run Pallas printing tests, which need to run with I/O capturing disabled. - TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \ - tests/pallas/tpu_pallas_test.py::PallasCallPrintTest - # Run multi-accelerator across all chips - python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests - - name: Send chat on failure - # Don't notify when testing the workflow from a branch. 
- if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }} - run: | - curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \ - --header 'Content-Type: application/json' \ - --data-raw "{ - 'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu.type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID' - }" diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml deleted file mode 100644 index ec2dbfa7686b..000000000000 --- a/.github/workflows/pytest_cpu.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: Run Pytest CPU tests - -on: - # pull_request: - # branches: - # - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - - 'no' - -jobs: - build: - continue-on-error: true - defaults: - run: - # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. 
- shell: bash - strategy: - matrix: - runner: ["windows-x86-n2-64", "linux-x86-n2-64", "linux-arm64-t2a-48"] - python: ["3.10"] - - runs-on: ${{ matrix.runner }} - container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') || - (contains(matrix.runner, 'windows-x86') && null) }} - - env: - JAXCI_CLONE_MAIN_XLA: 1 - JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build jaxlib - run: ./ci/build_artifacts.sh "jaxlib" - - name: Install pytest - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install pytest - - name: Install dependencies - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install -r build/requirements.in - - name: Run Pytest CPU tests - run: ./ci/run_pytest_cpu.sh diff --git a/.github/workflows/pytest_cpu_reuse.yml b/.github/workflows/pytest_cpu_reuse.yml deleted file mode 100644 index f09d53d0a96a..000000000000 --- a/.github/workflows/pytest_cpu_reuse.yml +++ /dev/null @@ -1,58 +0,0 @@ -name: Run Pytest CPU tests (resuable workflow) - -on: - # pull_request: - # branches: - # - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - -jobs: - build_jaxlib_artifacts: - uses: ./.github/workflows/build_artifacts.yml - - run_pytest: - needs: build_jaxlib_artifacts - continue-on-error: true - defaults: - run: - # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. 
- shell: bash - strategy: - matrix: - runner: ["windows-x86-n2-64", "linux-x86-n2-64", "linux-arm64-t2a-48"] - python: ["3.10"] - - runs-on: ${{ matrix.runner }} - container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') || - (contains(matrix.runner, 'windows-x86') && null) }} - - env: - JAXCI_CLONE_MAIN_XLA: 1 - JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Install pytest - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install pytest - - name: Install dependencies - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install -r build/requirements.in - - name: Run Pytest CPU tests - run: ./ci/run_pytest.sh "ci/envs/run_tests/pytest_cpu" diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml index ac6d9e39e168..e2241bf4deb7 100644 --- a/.github/workflows/pytest_gpu.yml +++ b/.github/workflows/pytest_gpu.yml @@ -1,9 +1,7 @@ name: Run Pytest GPU tests on: - # pull_request: - # branches: - # - main + pull_request: workflow_dispatch: inputs: halt-for-connection: @@ -21,27 +19,23 @@ jobs: matrix: python: ["3.10"] - runs-on: "linux-x86-g2-48-l4-4gpu" - container: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + runs-on: "mike-x86-g2-48-l4-4gpu" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" name: "Pytest GPU (Build wheels on CUDA 12.3)" env: - JAXCI_CLONE_MAIN_XLA: 1 + JAXCI_CLONE_MAIN_XLA: 0 JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} steps: - uses: actions/checkout@v3 # 
Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Build jaxlib - run: ./ci/build_artifacts.sh "jaxlib" + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib.env" - name: Build jax-cuda-plugin - run: ./ci/build_artifacts.sh "jax-cuda-plugin" + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-plugin.env" - name: Build jax-cuda-pjrt - run: ./ci/build_artifacts.sh "jax-cuda-pjrt" + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-pjrt.env" - name: Install pytest env: JAXCI_PYTHON: python${{ matrix.python }} @@ -50,43 +44,9 @@ jobs: env: JAXCI_PYTHON: python${{ matrix.python }} run: $JAXCI_PYTHON -m pip install -r build/requirements.in + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest GPU tests run: ./ci/run_pytest_gpu.sh - - # run_tests: - # needs: build_artifacts - # strategy: - # matrix: - # test_env: [ - # {cuda_version: "12.3", runner: "linux-x86-g2-48-l4-4gpu", - # image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, - # {cuda_version: "12.1", runner: "linux-x86-g2-48-l4-4gpu", - # image: "gcr.io/tensorflow-testing/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"}, - # ] - # python: ["3.10"] - - # runs-on: ${{ matrix.test_env.runner }} - # container: - # image: ${{ matrix.test_env.image }} - - # name: "Pytest GPU (Test on CUDA ${{ matrix.test_env.cuda_version }})" - # env: - # JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} - - # steps: - # - uses: actions/checkout@v3 - # # Halt for testing - # - name: Wait For Connection - # uses: google-ml-infra/actions/ci_connection@main - # with: - # halt-dispatch-input: ${{ inputs.halt-for-connection }} - # - name: Install pytest - # env: - # JAXCI_PYTHON: python${{ matrix.python 
}} - # run: $JAXCI_PYTHON -m pip install pytest - # - name: Install dependencies - # env: - # JAXCI_PYTHON: python${{ matrix.python }} - # run: $JAXCI_PYTHON -m pip install -r build/requirements.in - # - name: Run Pytest GPU tests - # run: ./ci/run_pytest_gpu.sh diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml deleted file mode 100644 index 2f97aec4196d..000000000000 --- a/.github/workflows/pytest_tpu.yml +++ /dev/null @@ -1,60 +0,0 @@ -name: Run Pytest TPU tests - -on: - # pull_request: - # branches: - # - main - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - - 'no' - -jobs: - run_tests: - strategy: - matrix: - runner: ["linux-x86-ct5lp-224-8tpu"] - tpu_cores: ["8"] - python: ["3.10"] - - runs-on: ${{ matrix.runner }} - container: - image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" - - env: - JAXCI_CLONE_MAIN_XLA: 1 - JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} - - steps: - - uses: actions/checkout@v3 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Build jaxlib - run: ./ci/build_artifacts.sh "jaxlib" - - name: Install pytest - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install pytest - - name: Install Test requirements - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: | - $JAXCI_PYTHON -m pip install -r build/test-requirements.txt - $JAXCI_PYTHON -m pip install -r build/collect-profile-requirements.txt - - name: Install Libtpu - env: - JAXCI_PYTHON: python${{ matrix.python }} - run: $JAXCI_PYTHON -m pip install --pre libtpu-nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html" - - name: Run Pytest TPU tests - env: - JAXCI_TPU_CORES: ${{ 
matrix.tpu_cores }} - run: ./ci/run_pytest_tpu.sh diff --git a/actions/ci_connection/notify_connection.py b/actions/ci_connection/notify_connection.py index 52c7b58f43cf..db851341cc8e 100644 --- a/actions/ci_connection/notify_connection.py +++ b/actions/ci_connection/notify_connection.py @@ -14,6 +14,7 @@ import time import threading +import os import subprocess from multiprocessing.connection import Client @@ -42,6 +43,20 @@ def timer(conn): timer_thread.start() print("Entering interactive bash session") + + # Hard-coded for now for demo purposes. + next_command = "bash ci/build_artifacts.sh" + # Print the "next" commands to be run + # TODO: actually get this data from workflow files + print(f"The next command that would have run is:\n\n{next_command}") + + # Set the hardcoded envs for testing purposes + # TODO: sync env vars + sub_env = os.environ.copy() + sub_env["ENV_FILE"] = "ci/envs/build_artifacts/jaxlib" + sub_env["JAXCI_USE_DOCKER"] = "0" + sub_env["JAXCI_USE_RBE"] = "1" + # Enter interactive bash session subprocess.run(["bash", "-i"]) diff --git a/build/build.py b/build/build.py old mode 100755 new mode 100644 index dd53332613ef..048eac3e393e --- a/build/build.py +++ b/build/build.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# CLI for building JAX wheel packages from source and for updating the -# requirements_lock.txt files +# CLI for building jaxlib, jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, +# jax-rocm-pjrt and for updating the requirements_lock.txt files. import argparse import asyncio @@ -44,14 +44,11 @@ EPILOG = """ From the root directory of the JAX repository, run - `python build/build.py build --wheels=` to build JAX - artifacts. + python build/build.py [jaxlib | jax-cuda-plugin | jax-cuda-pjrt | jax-rocm-plugin | jax-rocm-pjrt] - Multiple wheels can be built with a single invocation of the CLI. - E.g. 
python build/build.py build --wheels=jaxlib,jax-cuda-plugin - - To update the requirements_lock.txt files, run - `python build/build.py requirements_update` + to build one of: jaxlib, jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, jax-rocm-pjrt +or + python build/build.py requirements_update to update the requirements_lock.txt """ # Define the build target for each artifact. @@ -113,13 +110,13 @@ def add_global_arguments(parser: argparse.ArgumentParser): ) bazel_group.add_argument( - "--bazel_options", + "--bazel_build_options", action="append", default=[], help=""" Additional build options to pass to Bazel, can be specified multiple times to pass multiple options. - E.g. --bazel_options='--local_resources=HOST_CPUS' + E.g. --bazel_build_options='--local_resources=HOST_CPUS' """, ) @@ -139,13 +136,13 @@ def add_global_arguments(parser: argparse.ArgumentParser): def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): """Adds all the global arguments that applies to the artifact subcommands.""" parser.add_argument( - "--wheels", + "--wheel_list", type=str, default="jaxlib", help= f""" - A comma separated list of JAX artifacts to build. E.g: --wheels="jaxlib", - --wheels="jaxlib,jax-cuda-plugin", etc. + A comma seprated list of JAX artifacts to build. E.g: --wheel_list="jaxlib", + --wheel_list="jaxlib,jax-cuda-plugin", etc. Valid options are: {','.join(ARTIFACT_BUILD_TARGET_DICT.keys())} """, ) @@ -177,26 +174,22 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): cuda_group.add_argument( "--cuda_version", type=str, - # LINT.IfChange(cuda_version) - default="12.3.2", - # LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc) + default=None, help= """ Hermetic CUDA version to use. Default is to use the version specified - in the .bazelrc (12.3.2). + in the .bazelrc. 
""", ) cuda_group.add_argument( "--cudnn_version", type=str, - # LINT.IfChange(cudnn_version) - default="9.1.1", - # LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc) + default=None, help= """ Hermetic cuDNN version to use. Default is to use the version specified - in the .bazelrc (9.1.1). + in the .bazelrc. """, ) @@ -222,7 +215,8 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): action="store_true", help=""" Should CUDA code be compiled using Clang? The default behavior is to - compile CUDA with NVCC. + compile CUDA with NVCC. Ignored if --use_ci_bazelrc_flags is set, CI + builds always build CUDA with NVCC in CI builds. """, ) @@ -251,13 +245,24 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): # Compile Options compile_group = parser.add_argument_group('Compile Options') + compile_group.add_argument( + "--use_ci_bazelrc_flags", + action="store_true", + help=""" + When set, the CLI will assume the build is being run in CI or CI like + environment and will use the "rbe_/ci_" configs in the .bazelrc. These + configs apply release features and set a custom C++ Clang toolchain. + Only supported for jaxlib and CUDA builds. + """, + ) compile_group.add_argument( "--clang_path", type=str, default="", help=""" - Path to the Clang binary to use. + Path to the Clang binary to use. Ignored if --use_ci_bazelrc_flags, CI + bazelrc flags set a custom Clang toolchain. """, ) @@ -265,7 +270,8 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): "--disable_mkl_dnn", action="store_true", help=""" - Disables MKL-DNN. + Disables MKL-DNN. Ignored if --use_ci_bazelrc_flags is set, CI bazelrc + flags enable MKL-DNN as default. """, ) @@ -279,7 +285,8 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): enables AVX. Native enables -march=native, which generates code targeted to use all features of the current machine. 
Default means don't opt-in to any architectural features and use whatever the C compiler generates - by default. + by default. Ignored if --use_ci_bazelrc_flags is set, CI bazelrc flags + enable release CPU features as default. """, ) @@ -299,17 +306,69 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): """, ) +def apply_compile_flags_non_ci(bazel_command: command.CommandBuilder, wheel: str, clang_path: str, disable_mkl_dnn: bool, build_cuda_with_clang: bool,\ + target_cpu_features: str, os_name: str, arch: str): + clang_path = clang_path or utils.get_clang_path_or_exit() + logging.debug("Using Clang as the compiler, clang path: %s", clang_path) + # Use double quotes around clang path to avoid path issues on Windows. + bazel_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + bazel_command.append(f"--repo_env=CC=\"{clang_path}\"") + bazel_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") + # Do not apply --config=clang on Mac as these settings do not apply to + # Apple Clang. 
+ if os_name != "darwin": + bazel_command.append("--config=clang") + + if not disable_mkl_dnn: + logging.debug("Enabling MKL DNN") + bazel_command.append("--config=mkl_open_source_only") + + if "cuda" in wheel: + bazel_command.append("--config=cuda") + bazel_command.append( + f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" + ) + if build_cuda_with_clang: + logging.debug("Building CUDA with Clang") + bazel_command.append("--config=build_cuda_with_clang") + else: + logging.debug("Building CUDA with NVCC") + bazel_command.append("--config=build_cuda_with_nvcc") + + if target_cpu_features == "release": + logging.debug( + "Using release cpu features: --config=avx_%s", + "windows" if os_name == "windows" else "posix", + ) + if arch in ["x86_64", "AMD64"]: + bazel_command.append( + "--config=avx_windows" + if os_name == "windows" + else "--config=avx_posix" + ) + elif target_cpu_features == "native": + if os_name == "windows": + logger.warning( + "--target_cpu_features=native is not supported on Windows;" + " ignoring." 
+ ) + else: + logging.debug("Using native cpu features: --config=native_arch_posix") + bazel_command.append("--config=native_arch_posix") + else: + logging.debug("Using default cpu features") + async def main(): parser = argparse.ArgumentParser( description=r""" - CLI for building JAX wheel packages from source and for updating the - requirements_lock.txt files + CLI for building JAX wheel packages from source: jaxlib, + jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, jax-rocm-pjrt and for + updating the requirements_lock.txt files """, epilog=EPILOG, - formatter_class=argparse.RawDescriptionHelpFormatter ) - # Create subparsers for build_artifacts and requirements_update + # Create subparsers for jax, jaxlib, plugin, pjrt and requirements_update subparsers = parser.add_subparsers(dest="command", required=True) # requirements_update subcommand @@ -319,14 +378,16 @@ async def main(): add_requirements_nightly_update_argument(requirements_update_parser) add_global_arguments(requirements_update_parser) - # Artifact build subcommand + # Build Artifact subcommand build_artifact_parser = subparsers.add_parser( - "build", help="Builds the jaxlib, plugin, and pjrt artifact" + "build_artifacts", help="Builds the jaxlib, plugin, PJRT artifact" ) add_artifact_subcommand_global_arguments(build_artifact_parser) add_global_arguments(build_artifact_parser) arch = platform.machine() + # Switch to lower case to match the case for the "ci_"/"rbe_" configs in the + # .bazelrc. 
os_name = platform.system().lower() args = parser.parse_args() @@ -369,11 +430,11 @@ async def main(): # Requirements update subcommand execution if args.command == "requirements_update": requirements_command = copy.deepcopy(bazel_command_base) - if args.bazel_options: + if args.bazel_build_options: logging.debug( - "Using additional build options: %s", args.bazel_options + "Using additional build options: %s", args.bazel_build_options ) - for option in args.bazel_options: + for option in args.bazel_build_options: requirements_command.append(option) if args.nightly_update: @@ -402,6 +463,13 @@ async def main(): logging.debug("Local XLA path: %s", args.local_xla_path) bazel_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"") + if args.bazel_build_options: + logging.debug( + "Additional Bazel build options: %s", args.bazel_build_options + ) + for option in args.bazel_build_options: + bazel_command_base.append(option) + if args.target_cpu: logging.debug("Target CPU: %s", args.target_cpu) bazel_command_base.append(f"--cpu={args.target_cpu}") @@ -410,14 +478,11 @@ async def main(): logging.debug("Disabling NCCL") bazel_command_base.append("--config=nonccl") - git_hash = utils.get_githash() - # Wheel build command execution - for wheel in args.wheels.split(","): + for wheel in args.wheel_list.split(","): if wheel not in ARTIFACT_BUILD_TARGET_DICT.keys(): logging.error("Incorrect wheel name provided: %s, valid choices are: %s", wheel, ",".join(ARTIFACT_BUILD_TARGET_DICT.keys())) - sys.exit(1) - + continue wheel_build_command = copy.deepcopy(bazel_command_base) print("\n") logger.info( @@ -425,60 +490,29 @@ async def main(): wheel, os_name, arch, - ) - - clang_path = args.clang_path or utils.get_clang_path_or_exit() - logging.debug("Using Clang as the compiler, clang path: %s", clang_path) - - # Use double quotes around clang path to avoid path issues on Windows. 
- wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") - wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"") - wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") - - # Do not apply --config=clang on Mac as these settings do not apply to - # Apple Clang. - if os_name != "darwin": - wheel_build_command.append("--config=clang") - - if not args.disable_mkl_dnn: - logging.debug("Enabling MKL DNN") - wheel_build_command.append("--config=mkl_open_source_only") - - if args.target_cpu_features == "release": - logging.debug( - "Using release cpu features: --config=avx_%s", - "windows" if os_name == "windows" else "posix", - ) - if arch in ["x86_64", "AMD64"]: - wheel_build_command.append( - "--config=avx_windows" - if os_name == "windows" - else "--config=avx_posix" - ) - elif wheel_build_command == "native": - if os_name == "windows": - logger.warning( - "--target_cpu_features=native is not supported on Windows;" - " ignoring." - ) - else: - logging.debug("Using native cpu features: --config=native_arch_posix") - wheel_build_command.append("--config=native_arch_posix") + ) + # If running in CI, we use the "ci_"/"rbe_" configs in the .bazelrc. + # These set a custom C++ Clang toolchain and the CUDA compiler to NVCC + # When not running in CI, we detect the path to Clang binary and pass it + # to Bazel to use as the C++ compiler. NVCC is used as the CUDA compiler + # unless the user explicitly sets --config=build_cuda_with_clang. 
+ if args.use_ci_bazelrc_flags and "rocm" not in wheel: + bazelrc_config = utils.get_ci_bazelrc_config(os_name, arch.lower(), wheel) + logging.info("--use_ci_bazelrc_flags is set, using --config=%s from .bazelrc", bazelrc_config) + wheel_build_command.append(f"--config={bazelrc_config}") else: - logging.debug("Using default cpu features") + apply_compile_flags_non_ci( + wheel_build_command, + wheel, + args.clang_path, + args.disable_mkl_dnn, + args.build_cuda_with_clang, + args.target_cpu_features, + os_name, + arch, + ) if "cuda" in wheel: - wheel_build_command.append("--config=cuda") - wheel_build_command.append( - f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" - ) - if args.build_cuda_with_clang: - logging.debug("Building CUDA with Clang") - wheel_build_command.append("--config=build_cuda_with_clang") - else: - logging.debug("Building CUDA with NVCC") - wheel_build_command.append("--config=build_cuda_with_nvcc") - if args.cuda_version: logging.debug("Hermetic CUDA version: %s", args.cuda_version) wheel_build_command.append( @@ -509,15 +543,6 @@ async def main(): f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" ) - # Append additional build options at the end to override any options set in - # .bazelrc or above. 
- if args.bazel_options: - logging.debug( - "Additional Bazel build options: %s", args.bazel_options - ) - for option in args.bazel_options: - wheel_build_command.append(option) - if args.configure_only: with open(".jax_configure.bazelrc", "w") as f: jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list()) @@ -547,13 +572,17 @@ async def main(): if "cuda" in wheel: wheel_build_command.append("--enable-cuda=True") - cuda_major_version = args.cuda_version.split(".")[0] + if args.cuda_version: + cuda_major_version = args.cuda_version.split(".")[0] + else: + cuda_major_version = utils.get_cuda_major_version() wheel_build_command.append(f"--platform_version={cuda_major_version}") if "rocm" in wheel: wheel_build_command.append("--enable-rocm=True") wheel_build_command.append(f"--platform_version={args.rocm_version}") + git_hash = utils.get_githash() wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") await executor.run(wheel_build_command.get_command_as_string(), args.dry_run) diff --git a/build/tools/utils.py b/build/tools/utils.py index 8fa29e8d5c7c..aeb0e8bae437 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -189,6 +189,43 @@ def get_clang_path_or_exit(): sys.exit(-1) +def get_cuda_major_version(): + """Extract the CUDA major version from the .bazelrc""" + with open(".bazelrc", "r") as f: + for line in f: + match = re.search(r'HERMETIC_CUDA_VERSION="([^"]+)"', line) + if match: + cuda_version=match.group(1) + return cuda_version.split(".")[0] + return None + + +def get_ci_bazelrc_config(os_name: str, arch: str, artifact: str): + """Returns the bazelrc config for the given architecture and OS. 
+ + Used in CI builds to retrieve either the "ci_"/"rbe_" configs from the + .bazelrc + """ + + bazelrc_config = f"{os_name}_{arch}" + + # If building on Linux x86 or Windows, use the "rbe_" flags otherwise use + # the "ci_" (non-rbe) flags + if (os_name == "linux" and arch == "x86_64") or ( + os_name == "windows" and arch == "amd64" + ): + bazelrc_config = "rbe_" + bazelrc_config + else: + bazelrc_config = "ci_" + bazelrc_config + + # When building jax-cuda-plugin or jax-cuda-pjrt, append "_cuda" to the + # bazelrc config to use the CUDA specific configs. + if "cuda" in artifact: + bazelrc_config = bazelrc_config + "_cuda" + + return bazelrc_config + + def get_jax_configure_bazel_options(bazel_command: list[str]): """Returns the bazel options to be written to .jax_configure.bazelrc.""" # Get the index of the "run" parameter. Build options will come after "run" so @@ -219,3 +256,4 @@ def get_githash(): ).stdout.strip() except OSError: return "" + diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index a0d3a63d1338..bc043d43cc48 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -13,10 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# Build JAX artifacts. -# Usage: ./ci/build_artifacts.sh "" -# Supported artifact values are: jax, jaxlib, jax-cuda-plugin, jax-cuda-pjrt -# E.g: ./ci/build_artifacts.sh "jax" or ./ci/build_artifacts.sh "jaxlib" +# Build JAX artifacts. Requires an env file from the ci/envs/build_artifacts to +# be passed as an argument # # -e: abort script if one command fails # -u: error if undefined variable used @@ -25,65 +23,43 @@ # -o allexport: export all functions and variables to be available to subscripts set -exu -o history -o allexport -artifact="$1" - -# Source default JAXCI environment variables. -source ci/envs/default.env +# If a JAX CI env file has not been passed, exit. 
+if [[ -z "$1" ]]; then + echo "ERROR: No JAX CI env file passed." + echo "build_artifacts.sh requires that a path to a JAX CI env file to be" + echo "passed as an argument when invoking the build scripts." + echo "Pass in a corresponding env file from the ci/envs/build_artifacts" + echo "directory to continue." + exit 1 +fi +# Source JAXCI environment variables. +source "$1" # Set up the build environment. source "ci/utilities/setup_build_environment.sh" -allowed_artifacts=("jax" "jaxlib" "jax-cuda-plugin" "jax-cuda-pjrt") - -os=$(uname -s | awk '{print tolower($0)}') -arch=$(uname -m) - -# Adjust the values when running on Windows x86 to match the config in -# .bazelrc -if [[ $os =~ "msys_nt" ]] && [[ $arch == "x86_64" ]]; then - os="windows" - arch="amd64" +# Build the jax artifact +if [[ "$JAXCI_BUILD_JAX" == 1 ]]; then + python -m build --outdir $JAXCI_OUTPUT_DIR fi -if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then - - # Build the jax artifact - if [[ "$artifact" == "jax" ]]; then - python -m build --outdir $JAXCI_OUTPUT_DIR - else - - # For bazel builds, use the "rbe_" config for Linux x86/Windows and "ci_" for other platforms - bazelrc_config="${os}_${arch}" - if ( [[ "$os" == "linux" ]] && [[ "$arch" == "x86_64" ]] ) || [[ "$os" == "windows" ]]; then - bazelrc_config="rbe_$bazelrc_config" - else - bazelrc_config="ci_$bazelrc_config" - fi - - # Build the jaxlib CPU artifact - if [[ "$artifact" == "jaxlib" ]]; then - python build/build.py build --wheels="jaxlib" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose - fi - - # Build the jax-cuda-plugin artifact - if [[ "$artifact" == "jax-cuda-plugin" ]]; then - python build/build.py build --wheels="jax-cuda-plugin" --bazel_options=--config="${bazelrc_config}_cuda" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose - fi - - # Build the jax-cuda-pjrt artifact - if [[ "$artifact" == "jax-cuda-pjrt" ]]; then - python build/build.py build 
--wheels="jax-cuda-pjrt" --bazel_options=--config="${bazelrc_config}_cuda" --verbose - fi +# Build the jaxlib CPU artifact +if [[ "$JAXCI_BUILD_JAXLIB" == 1 ]]; then + python build/build.py build_artifacts --wheel_list="jaxlib" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose +fi - # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we - # run `auditwheel show` to verify manylinux compliance. - if [[ "$os" == "linux" ]]; then - ./ci/utilities/run_auditwheel.sh - fi +# Build the jax-cuda-plugin artifact +if [[ "$JAXCI_BUILD_PLUGIN" == 1 ]]; then + python build/build.py build_artifacts --wheel_list="jax-cuda-plugin" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose +fi - fi +# Build the jax-cuda-pjrt artifact +if [[ "$JAXCI_BUILD_PJRT" == 1 ]]; then + python build/build.py build_artifacts --wheel_list="jax-cuda-pjrt" --use_ci_bazelrc_flags --verbose +fi -else - echo "Error: Invalid artifact: $artifact. Allowed values are: ${allowed_artifacts[@]}" - exit 1 -fi \ No newline at end of file +# After building `jaxlib`, `jax-cuda-plugin`, and `jax-cuda-pjrt`, we run +# `auditwheel show` to ensure manylinux compliance. +if [[ "$JAXCI_RUN_AUDITWHEEL" == 1 ]]; then + ./ci/utilities/run_auditwheel.sh +fi diff --git a/ci/envs/build_artifacts/jax-cuda-pjrt.env b/ci/envs/build_artifacts/jax-cuda-pjrt.env new file mode 100644 index 000000000000..e515e802e071 --- /dev/null +++ b/ci/envs/build_artifacts/jax-cuda-pjrt.env @@ -0,0 +1,22 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default.env + +# Enable jax-cuda-pjrt build. +export JAXCI_BUILD_PJRT="1" + +# Enable wheel audit to check for manylinux compliance. +export JAXCI_RUN_AUDITWHEEL="1" \ No newline at end of file diff --git a/ci/envs/build_artifacts/jax-cuda-plugin.env b/ci/envs/build_artifacts/jax-cuda-plugin.env new file mode 100644 index 000000000000..46a606fbdf65 --- /dev/null +++ b/ci/envs/build_artifacts/jax-cuda-plugin.env @@ -0,0 +1,22 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default.env + +# Enable jax-cuda-plugin build +export JAXCI_BUILD_PLUGIN="1" + +# Enable wheel audit to check for manylinux compliance. 
+export JAXCI_RUN_AUDITWHEEL="1" \ No newline at end of file diff --git a/ci/envs/build_artifacts/jax.env b/ci/envs/build_artifacts/jax.env new file mode 100644 index 000000000000..ff24fff01094 --- /dev/null +++ b/ci/envs/build_artifacts/jax.env @@ -0,0 +1,19 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default.env + +# Build JAX artifact. +export JAXCI_BUILD_JAX="1" \ No newline at end of file diff --git a/ci/envs/build_artifacts/jaxlib.env b/ci/envs/build_artifacts/jaxlib.env new file mode 100644 index 000000000000..c53d86b39d72 --- /dev/null +++ b/ci/envs/build_artifacts/jaxlib.env @@ -0,0 +1,27 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default.env + +# Enable jaxlib build. +export JAXCI_BUILD_JAXLIB="1" + +os=$(uname -s | awk '{print tolower($0)}') + +# Enable wheel audit for Linux builds to check for manylinux compliance. +if [[ $os == "linux" ]]; then + export JAXCI_RUN_AUDITWHEEL="1" +fi + diff --git a/ci/envs/default.env b/ci/envs/default.env index d6514a132c0e..07d803181553 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -36,11 +36,22 @@ export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} # Controls the location where the artifacts are stored. export JAXCI_OUTPUT_DIR="$(pwd)/dist" +# ############################################################################# +# Artifact build specific environment variables. Used in build_artifacts.sh +# ############################################################################# +# Environment variables that control which artifact to build. +export JAXCI_BUILD_JAX=0 +export JAXCI_BUILD_JAXLIB=0 +export JAXCI_BUILD_PLUGIN=0 +export JAXCI_BUILD_PJRT=0 +export JAXCI_RUN_AUDITWHEEL=0 + # ############################################################################# # Docker specific environment variables. # ############################################################################# # Docker specifc environment variables. 
Used by `run_docker_container.sh` +export JAXCI_RUN_DOCKER_CONTAINER=${JAXCI_RUN_DOCKER_CONTAINER:-1} export JAXCI_DOCKER_WORK_DIR="/jax" export JAXCI_DOCKER_IMAGE="" export JAXCI_DOCKER_ARGS="" diff --git a/ci/envs/docker.env b/ci/envs/docker.env index 77376e4c6578..3d2449d619c4 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -21,6 +21,9 @@ source ci/envs/default.env os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) +export JAXCI_DOCKER_ARGS="" +export JAXCI_DOCKER_IMAGE="" + # TODO: Set GPU Docker args and GPU Docker images # Linux x86 specifc settings if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then diff --git a/ci/run_bazel_test_gpu_rbe.sh b/ci/run_bazel_test_gpu_rbe.sh index 6b51598146bb..4e7844fa11be 100755 --- a/ci/run_bazel_test_gpu_rbe.sh +++ b/ci/run_bazel_test_gpu_rbe.sh @@ -34,10 +34,12 @@ fi # Set up the build environment. source "ci/utilities/setup_build_environment.sh" -# Run Bazel GPU tests with RBE (single accelerator tests with one GPU apiece.) +# Run Bazel GPU tests with RBE (single accelerator tests) +nvidia-smi echo "Running RBE GPU tests..." # Only Linux x86 builds run GPU tests +# Runs single accelerator tests with one GPU apiece. bazel test --config=rbe_linux_x86_64_cuda \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh index 2173271ad806..5f80e3e1093f 100755 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -26,8 +26,8 @@ set -exu -o history -o allexport # Inherit default JAXCI environment variables. source ci/envs/default.env -echo "Installing wheels locally..." -source ./ci/utilities/install_wheels_locally.sh +# Install jaxlib wheel on the system. +export JAXCI_INSTALL_WHEELS_LOCALLY=1 # Set up the build environment. 
source "ci/utilities/setup_build_environment.sh" diff --git a/ci/run_pytest_gpu.sh b/ci/run_pytest_gpu.sh index 0e2bb5f55b84..78182af130d6 100755 --- a/ci/run_pytest_gpu.sh +++ b/ci/run_pytest_gpu.sh @@ -27,12 +27,33 @@ set -exu -o history -o allexport source ci/envs/default.env # Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels on the system. + +export JAXCI_INSTALL_WHEELS_LOCALLY=1 + echo "Installing wheels locally..." source ./ci/utilities/install_wheels_locally.sh # Set up the build environment. source "ci/utilities/setup_build_environment.sh" + +# # Install cuda deps into the current python +# "$JAXCI_PYTHON" -m pip install "nvidia-cublas-cu12>=12.1.3.1" +# "$JAXCI_PYTHON" -m pip install "nvidia-cuda-cupti-cu12>=12.1.105" +# "$JAXCI_PYTHON" -m pip install "nvidia-cuda-nvcc-cu12>=12.1.105" +# "$JAXCI_PYTHON" -m pip install "nvidia-cuda-runtime-cu12>=12.1.105" +# "$JAXCI_PYTHON" -m pip install "nvidia-cudnn-cu12>=9.1,<10.0" +# "$JAXCI_PYTHON" -m pip install "nvidia-cufft-cu12>=11.0.2.54" +# "$JAXCI_PYTHON" -m pip install "nvidia-cusolver-cu12>=11.4.5.107" +# "$JAXCI_PYTHON" -m pip install "nvidia-cusparse-cu12>=12.1.0.106" +# "$JAXCI_PYTHON" -m pip install "nvidia-nccl-cu12>=2.18.1" +# "$JAXCI_PYTHON" -m pip install "nvidia-nvjitlink-cu12>=12.1.105" +# "$JAXCI_PYTHON" -m pip install "tensorrt" + +# "$JAXCI_PYTHON" -m pip install "tensorrt-lean" +# "$JAXCI_PYTHON" -m pip install "tensorrt-dispatch" + + export PY_COLORS=1 export JAX_SKIP_SLOW_TESTS=true @@ -45,9 +66,10 @@ export TF_CPP_MIN_LOG_LEVEL=0 echo "Running GPU tests..." 
export XLA_PYTHON_CLIENT_ALLOCATOR=platform export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 -"$JAXCI_PYTHON" -m pytest -n 8 --tb=short --maxfail=20 \ +"$JAXCI_PYTHON" -m pytest -n 8 --tb=short --maxfail=200 \ tests examples \ --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ --deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \ --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \ ---deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric \ No newline at end of file +--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric \ +--deselect=tests/tests/sparse_nm_test.py::SpmmTest \ No newline at end of file diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index 92a4baa0fbb1..93a5cdac5360 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -28,8 +28,7 @@ source ci/envs/default.env # Install jaxlib wheel on the system. Requires a jaxlib wheel to be present # inside $JAXCI_OUTPUT_DIR (../dist) -echo "Installing wheels locally..." -source ./ci/utilities/install_wheels_locally.sh +export JAXCI_INSTALL_WHEELS_LOCALLY=1 # Set up the build environment. source "ci/utilities/setup_build_environment.sh" diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh index 30b6a3b51865..4e8d007885ca 100755 --- a/ci/utilities/run_auditwheel.sh +++ b/ci/utilities/run_auditwheel.sh @@ -14,11 +14,11 @@ # limitations under the License. # ============================================================================== # -# Runs auditwheel to verify manylinux compatibility. +# Runs auditwheel to ensure manylinux compatibility. # Get a list of all the wheels in the output directory. Only look for wheels # that need to be verified for manylinux compliance. 
-WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*whl" -o -name "*jax*cuda*pjrt*whl" -o -name "*jax*cuda*plugin*whl" \)) +WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) if [[ -z "$WHEELS" ]]; then echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" diff --git a/ci/utilities/run_docker_container.sh b/ci/utilities/run_docker_container.sh index 1cc3199bd5fd..c749294c1d77 100755 --- a/ci/utilities/run_docker_container.sh +++ b/ci/utilities/run_docker_container.sh @@ -49,8 +49,8 @@ if ! docker container inspect jax >/dev/null 2>&1 ; then # Start the container. `user_set_jaxci_envs` is read after `jax_ci_envs` to # allow the user to override any environment variables set by JAXCI_ENV_FILE. docker run $JAXCI_DOCKER_ARGS --name jax \ - -w "$JAXCI_DOCKER_WORK_DIR" -itd --rm \ - -v "$JAXCI_JAX_GIT_DIR:$JAXCI_DOCKER_WORK_DIR" \ + -w /jax -itd --rm \ + -v "$JAXCI_JAX_GIT_DIR:/jax" \ "$JAXCI_DOCKER_IMAGE" \ bash diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index 93e123f223a9..70b238431e4e 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -72,4 +72,10 @@ if [[ $(uname -s) =~ "MSYS_NT" ]]; then echo 'Converting MSYS Linux-like paths to Windows paths (for Docker, Python, etc.)' # Convert all "_DIR" variables to Windows paths. source <(python3 ./ci/utilities/convert_msys_paths_to_win_paths.py) +fi + +# When running Pytests, we need to install the wheels locally. +if [[ "$JAXCI_INSTALL_WHEELS_LOCALLY" == 1 ]]; then + echo "Installing wheels locally..." + source ./ci/utilities/install_wheels_locally.sh fi \ No newline at end of file diff --git a/xla b/xla new file mode 160000 index 000000000000..22004eb92ca2 --- /dev/null +++ b/xla @@ -0,0 +1 @@ +Subproject commit 22004eb92ca2d4e6132749e351fa22b87dcaae5a