diff --git a/.bazelrc b/.bazelrc deleted file mode 100644 index 458ce69fae8b..000000000000 --- a/.bazelrc +++ /dev/null @@ -1,328 +0,0 @@ -############################################################################ -# All default build options below. - -# Required by OpenXLA -# https://github.com/openxla/xla/issues/1323 -build --nocheck_visibility - -# Sets the default Apple platform to macOS. -build --apple_platform_type=macos -build --macos_minimum_os=10.14 - -# Make Bazel print out all options from rc files. -build --announce_rc - -build --define open_source_build=true - -build --spawn_strategy=standalone - -build --enable_platform_specific_config - -build --experimental_cc_shared_library - -# Disable enabled-by-default TensorFlow features that we don't care about. -build --define=no_aws_support=true -build --define=no_gcp_support=true -build --define=no_hdfs_support=true -build --define=no_kafka_support=true -build --define=no_ignite_support=true - -build --define=grpc_no_ares=true - -build --define=tsl_link_protobuf=true - -build -c opt - -build --config=short_logs - -build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. - -########################################################################### - -build:posix --copt=-fvisibility=hidden -build:posix --copt=-Wno-sign-compare -build:posix --cxxopt=-std=c++17 -build:posix --host_cxxopt=-std=c++17 - -build:avx_posix --copt=-mavx -build:avx_posix --host_copt=-mavx - -build:avx_windows --copt=/arch=AVX - -build:avx_linux --copt=-mavx -build:avx_linux --host_copt=-mavx - -build:native_arch_posix --copt=-march=native -build:native_arch_posix --host_copt=-march=native - -build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 - -build:clang --action_env=CC="/usr/lib/llvm-18/bin/clang" -# Disable clang extention that rejects type definitions within offsetof. -# This was added in clang-16 by https://reviews.llvm.org/D133574. -# Can be removed once upb is updated, since a type definition is used within -# offset of in the current version of ubp. -# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. -build:clang --copt=-Wno-gnu-offsetof-extensions -# Disable clang extention that rejects unknown arguments. -build:clang --copt=-Qunused-arguments - -build:cuda --repo_env TF_NEED_CUDA=1 -build:cuda --repo_env TF_NCCL_USE_STUB=1 -# "sm" means we emit only cubin, which is forward compatible within a GPU generation. -# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. -build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" -build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain -build:cuda --@local_config_cuda//:enable_cuda -build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true -# Default hermetic CUDA and CUDNN versions. -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" -# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, -# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to -# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA -# packages. -# This has pros and cons: -# * pro: we'll ignore other CUDA installations, which has frequently confused -# users in the past. By setting RPATH, we'll always use the NVIDIA pip -# packages if they are installed. 
-# * con: the user cannot override the CUDA installation location -# via LD_LIBRARY_PATH, if the nvidia-... pip packages are installed. This is -# acceptable, because the workaround is "remove the nvidia-..." pip packages. -# The list of CUDA pip packages that JAX depends on are present in setup.py. -build:cuda --linkopt=-Wl,--disable-new-dtags -build:cuda --@local_config_cuda//:cuda_compiler=clang -build:cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" - -# This flag is needed to include CUDA libraries for bazel tests. -test:cuda --@local_config_cuda//cuda:include_cuda_libs=true - -# Build with NVCC for CUDA -build:cuda_nvcc --config=cuda -build:cuda_nvcc --config=clang -build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc -build:cuda_nvcc --action_env=TF_NVCC_CLANG="1" -build:cuda_nvcc --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" - -build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain -build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true -build:rocm --repo_env TF_NEED_ROCM=1 -build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030" - -build:nonccl --define=no_nccl_support=true - -# Requires MSVC and LLVM to be installed -build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl -build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl -build:win_clang --compiler=clang-cl - -# Windows has a relatively short command line limit, which JAX has begun to hit. -# See https://docs.bazel.build/versions/main/windows.html -build:windows --features=compiler_param_file -build:windows --features=archive_param_file - -# Tensorflow uses M_* math constants that only get defined by MSVC headers if -# _USE_MATH_DEFINES is defined. -build:windows --copt=/D_USE_MATH_DEFINES -build:windows --host_copt=/D_USE_MATH_DEFINES -# Make sure to include as little of windows.h as possible -build:windows --copt=-DWIN32_LEAN_AND_MEAN -build:windows --host_copt=-DWIN32_LEAN_AND_MEAN -build:windows --copt=-DNOGDI -build:windows --host_copt=-DNOGDI -# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/ -# otherwise, there will be some compiling error due to preprocessing. -build:windows --copt=/Zc:preprocessor -build:windows --cxxopt=/std:c++17 -build:windows --host_cxxopt=/std:c++17 -# Generate PDB files, to generate useful PDBs, in opt compilation_mode -# --copt /Z7 is needed. -build:windows --linkopt=/DEBUG -build:windows --host_linkopt=/DEBUG -build:windows --linkopt=/OPT:REF -build:windows --host_linkopt=/OPT:REF -build:windows --linkopt=/OPT:ICF -build:windows --host_linkopt=/OPT:ICF -build:windows --incompatible_strict_action_env=true - -build:linux --config=posix -build:linux --copt=-Wno-unknown-warning-option -# Workaround for gcc 10+ warnings related to upb. -# See https://github.com/tensorflow/tensorflow/issues/39467 -build:linux --copt=-Wno-stringop-truncation -build:linux --copt=-Wno-array-parameter - -build:macos --config=posix - -# Public cache for macOS builds. The "oct2023" in the URL is just the -# date when the bucket was created and can be disregarded. It still contains the -# latest cache that is being used. -build:macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false -# Cache pushes are limited to Jax's CI system. 
-build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials - -# Suppress all warning messages. -build:short_logs --output_filter=DONT_MATCH_ANYTHING - -######################################################################### -# RBE config options below. -# Flag to enable remote config -common --experimental_repo_remote_exec - -build:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1 -build:rbe --google_default_credentials -build:rbe --bes_backend=buildeventservice.googleapis.com -build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" -build:rbe --bes_timeout=600s -build:rbe --define=EXECUTOR=remote -build:rbe --flaky_test_attempts=3 -build:rbe --jobs=200 -build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com -build:rbe --remote_timeout=3600 -build:rbe --spawn_strategy=remote,worker,standalone,local -test:rbe --test_env=USER=anon -# Attempt to minimize the amount of data transfer between bazel and the remote -# workers: -build:rbe --remote_download_toplevel - -build:rbe_linux --config=rbe -build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" -build:rbe_linux --host_javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8 -build:rbe_linux --javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8 -build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 -build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 - -# Non-rbe settings we should include because we do not run configure -build:rbe_linux --config=avx_linux -build:rbe_linux --linkopt=-lrt -build:rbe_linux --host_linkopt=-lrt -build:rbe_linux --linkopt=-lm -build:rbe_linux --host_linkopt=-lm - -# Use the GPU toolchain until the CPU one is ready. 
-# https://github.com/bazelbuild/bazel/issues/13623 -build:rbe_cpu_linux_base --config=rbe_linux -build:rbe_cpu_linux_base --config=clang -build:rbe_cpu_linux_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain" -build:rbe_cpu_linux_base --crosstool_top="@local_config_cuda//crosstool:toolchain" -build:rbe_cpu_linux_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_cpu_linux_base --repo_env=TF_SYSROOT="/dt9" -build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" - -build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base -build:rbe_cpu_linux_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" -build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base -build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" -build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base -build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" -build:rbe_cpu_linux_py3.13 --config=rbe_cpu_linux_base -build:rbe_cpu_linux_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13" - -build:rbe_linux_cuda_base --config=rbe_linux -build:rbe_linux_cuda_base --config=cuda -build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1 - -build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda12.3_nvcc_base --config=cuda_nvcc -build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" -build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@local_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_SYSROOT="/dt9" -build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base -build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" -build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base -build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" -build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base -build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" -build:rbe_linux_cuda12.3_nvcc_py3.13 --config=rbe_linux_cuda12.3_nvcc_base -build:rbe_linux_cuda12.3_nvcc_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13" - -# These you may need to change for your own GCP project. 
-build:tensorflow_testing_rbe --project_id=tensorflow-testing -common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance -build:tensorflow_testing_rbe_linux --config=tensorflow_testing_rbe - -# START CROSS-COMPILE CONFIGS - -# Set execution platform to Linux x86 -# Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" -# flags seem to be actually used to specify the execution platform details. It -# seems it is this way because these flags are old and predate the distinction -# between host and execution platform. -build:cross_compile_base --host_cpu=k8 -build:cross_compile_base --host_crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -build:cross_compile_base --extra_execution_platforms=@xla//tools/toolchains/cross_compile/config:linux_x86_64 - -# START LINUX AARCH64 CROSS-COMPILE CONFIGS -build:cross_compile_linux_arm64 --config=cross_compile_base - -# Set the target CPU to Aarch64 -build:cross_compile_linux_arm64 --platforms=@xla//tools/toolchains/cross_compile/config:linux_aarch64 -build:cross_compile_linux_arm64 --cpu=aarch64 -build:cross_compile_linux_arm64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite - -build:rbe_cross_compile_base --config=rbe - -# RBE cross-compile configs for Linux Aarch64 -build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 -build:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base -# END LINUX AARCH64 CROSS-COMPILE CONFIGS - -# START MACOS CROSS-COMPILE CONFIGS -build:cross_compile_macos_x86 --config=cross_compile_base -build:cross_compile_macos_x86 --config=nonccl -# Target Catalina (10.15) as the minimum supported OS -build:cross_compile_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 - -# Set the target CPU to Darwin x86 -build:cross_compile_macos_x86 --platforms=@xla//tools/toolchains/cross_compile/config:darwin_x86_64 -build:cross_compile_macos_x86 --cpu=darwin -build:cross_compile_macos_x86 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -# When RBE cross-compiling for macOS, we need to explicitly register the -# toolchain. Otherwise, oddly, RBE complains that a "docker container must be -# specified". -build:cross_compile_macos_x86 --extra_toolchains=@xla//tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain -# Map --platforms=darwin_x86_64 to --cpu=darwin and vice-versa to make selects() -# and transistions that use these flags work. The flag --platform_mappings needs -# to be set to a file that exists relative to the package path roots. -build:cross_compile_macos_x86 --platform_mappings=platform_mappings - -# RBE cross-compile configs for Darwin x86 -build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 -build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base -# END MACOS CROSS-COMPILE CONFIGS - -# END CROSS-COMPILE CONFIGS - -############################################################################# - -############################################################################# -# Some configs to make getting some forms of debug builds. In general, the -# codebase is only regularly built with optimizations. Use 'debug_symbols' to -# just get symbols for the parts of XLA/PJRT that jaxlib uses. -# Or try 'debug' to get a build with assertions enabled and minimal -# optimizations. 
-# Include these in a local .bazelrc.user file as: -# build --config=debug_symbols -# Or: -# build --config=debug -# -# Additional files can be opted in for debug symbols by adding patterns -# to a per_file_copt similar to below. -############################################################################# - -build:debug_symbols --strip=never --per_file_copt="xla/pjrt|xla/python@-g3" -build:debug --config debug_symbols -c fastbuild - -# Load `.jax_configure.bazelrc` file written by build.py -try-import %workspace%/.jax_configure.bazelrc - -# Load rc file with user-specific options. -try-import %workspace%/.bazelrc.user diff --git a/.github/workflows/bazel_cpu.yml b/.github/workflows/bazel_cpu.yml new file mode 100644 index 000000000000..9614fbd6e81a --- /dev/null +++ b/.github/workflows/bazel_cpu.yml @@ -0,0 +1,31 @@ +name: Run Bazel CPU tests with RBE + +on: + pull_request: + branches: + - main + +jobs: + build: + continue-on-error: true + strategy: + matrix: + runner: ["linux-x86-n2-64", "linux-arm64-t2a-48"] + + runs-on: ${{ matrix.runner }} + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') }} + + env: + # Disable running `run_docker_container.sh`. JAX self-hosted runners already run in a Docker + # container. + JAXCI_RUN_DOCKER_CONTAINER: 0 + JAXCI_HERMETIC_PYTHON_VERSION: "3.12" + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: ./actions/ci_connection/ + - name: Run Bazel CPU Tests + run: ./ci/run_bazel_test.sh "ci/envs/run_tests/bazel_cpu" diff --git a/.github/workflows/bazel_gpu_local.yml b/.github/workflows/bazel_gpu_local.yml new file mode 100644 index 000000000000..c93c6203ee3a --- /dev/null +++ b/.github/workflows/bazel_gpu_local.yml @@ -0,0 +1,41 @@ +name: Run Bazel GPU tests locally + +on: + pull_request: + branches: + - main + +jobs: + build: + strategy: + matrix: + runner: ["mike-x86-g2-48-l4-4gpu"] + + runs-on: ${{ matrix.runner }} + container: + image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + + env: + # GitHub actions run in Docker by defaut. Disable running the `setup_docker.sh` script. + JAXCI_RUN_DOCKER_CONTAINER: 0 + # Use RBE to build the artifacts. + JAXCI_BUILD_ARTIFACT_WITH_RBE: 1 + # Setup the test environment (disable x64 mode and clone XLA at HEAD) + JAXCI_SETUP_TEST_ENVIRONMENT: 1 + JAXCI_HERMETIC_PYTHON_VERSION: 3.11 + + steps: + - uses: actions/checkout@v3 + - name: Build jaxlib + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib" + - name: Build jax-cuda-plugin + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-plugin" + - name: Build jax-cuda-pjrt + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-pjrt" + # Halt for testing + - name: Wait For Connection + uses: ./actions/ci_connection/ + - name: Run Bazel GPU tests locally + run: ./ci/run_bazel_test.sh "ci/envs/run_tests/bazel_gpu_local" + - name: Test if step runs + run: echo "This step ran!" 
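Both workflows above funnel into `ci/run_bazel_test.sh` with an env file from `ci/envs/run_tests/`. As a rough local reproduction of the CPU job (a sketch only, assuming the script picks up the same `JAXCI_*` variables from the environment when invoked outside of CI):

  # Skip the Docker wrapper and pin the hermetic Python version, then run the
  # same entry point the workflow calls.
  export JAXCI_RUN_DOCKER_CONTAINER=0
  export JAXCI_HERMETIC_PYTHON_VERSION=3.12
  ./ci/run_bazel_test.sh "ci/envs/run_tests/bazel_cpu"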
diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml new file mode 100644 index 000000000000..ada95af1b58e --- /dev/null +++ b/.github/workflows/bazel_gpu_rbe.yml @@ -0,0 +1,29 @@ +name: Run Bazel GPU tests using RBE + +on: + pull_request: + branches: + - main + +jobs: + build: + strategy: + matrix: + runner: ["linux-x86-g2-48-l4-4gpu"] + + runs-on: ${{ matrix.runner }} + container: + image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + + env: + # Do not run Docker container for Linux runners. Linux runners already run in a Docker container. + JAXCI_RUN_DOCKER_CONTAINER: 0 + JAXCI_HERMETIC_PYTHON_VERSION: 3.12 + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: ./actions/ci_connection/ + - name: Run Bazel GPU tests using RBE + run: ./ci/run_bazel_test.sh "ci/envs/run_tests/bazel_gpu_rbe" diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml new file mode 100644 index 000000000000..f19b96c1c07b --- /dev/null +++ b/.github/workflows/build_artifacts.yml @@ -0,0 +1,66 @@ +name: Build JAX Artifacts + +# on: +# pull_request: +# branches: +# - main +# workflow_dispatch: + +jobs: + build: + defaults: + run: + # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. + shell: bash + strategy: + matrix: + runner: ["windows-x86-n2-64", "linux-x86-n2-64", "linux-arm64-t2a-48"] + artifact: ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"] + python: ["3.10", "3.11", "3.12"] + # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each + # Python version. + exclude: + # Pure Python packages do not need to be built for each Python version. + - artifact: "jax-cuda-pjrt" + python: "3.10" + - artifact: "jax-cuda-pjrt" + python: "3.11" + - artifact: "jax" + python: "3.10" + - artifact: "jax" + python: "3.11" + # jax is a pure Python package so it does not need to be built on multiple platforms. + - artifact: "jax" + runner: "windows-x86-n2-64" + - artifact: "jax" + runner: "linux-arm64-t2a-48" + # jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows. + - artifact: "jax-cuda-plugin" + runner: "windows-x86-n2-64" + - artifact: "jax-cuda-pjrt" + runner: "windows-x86-n2-64" + + runs-on: ${{ matrix.runner }} + + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.9') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:jax-') || + (contains(matrix.runner, 'windows-x86') && null) }} + + env: + # Do not run Docker container for Linux runners. Linux runners already run in a Docker container. + JAXCI_RUN_DOCKER_CONTAINER: 0 + # Use RBE to build the artifacts where possibl (Linux x86 and Windows). + JAXCI_BUILD_ARTIFACT_WITH_RBE: 1 + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: ./actions/ci_connection/ + - name: Build ${{ matrix.artifact }} + # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. 
+ shell: bash + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" + ENV_FILE: "ci/envs/build_artifacts/${{ matrix.artifact }}" + run: ./ci/build_artifacts.sh \ No newline at end of file diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml new file mode 100644 index 000000000000..0b32f58e356c --- /dev/null +++ b/.github/workflows/pytest_cpu.yml @@ -0,0 +1,51 @@ +name: Run Pytest CPU tests + +on: + pull_request: + branches: + - main + +jobs: + build: + continue-on-error: true + defaults: + run: + # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. + shell: bash + strategy: + matrix: + runner: ["windows-x86-n2-64", "linux-x86-n2-64", "linux-arm64-t2a-48"] + python: ["3.10"] + + runs-on: ${{ matrix.runner }} + container: ${{ (contains(matrix.runner, 'linux-x86') && 'gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') || + (contains(matrix.runner, 'windows-x86') && null) }} + + env: + # Disable running `run_docker_container.sh`. JAX self-hosted runners already run in a Docker + # container. + JAXCI_RUN_DOCKER_CONTAINER: 0 + # Use RBE to build the artifacts where possible (Linux x86 and Windows). + JAXCI_BUILD_ARTIFACT_WITH_RBE: 1 + # Setup the test environment (disable x64 mode and clone XLA at HEAD) + JAXCI_SETUP_TEST_ENVIRONMENT: 1 + JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: ./actions/ci_connection/ + - name: Build jaxlib + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib" + - name: Install pytest + env: + JAXCI_PYTHON: python${{ matrix.python }} + run: $JAXCI_PYTHON -m pip install pytest + - name: Install dependencies + env: + JAXCI_PYTHON: python${{ matrix.python }} + run: $JAXCI_PYTHON -m pip install --upgrade numpy=="2.0.0" scipy=="1.13.1" wheel build pytest-xdist absl-py opt-einsum colorama portpicker matplotlib 'importlib_metadata>=4.6' hypothesis flatbuffers filelock + - name: Run Pytest CPU tests + run: ./ci/run_pytest.sh "ci/envs/run_tests/pytest_cpu" diff --git a/.github/workflows/pytest_gpu.yml b/.github/workflows/pytest_gpu.yml new file mode 100644 index 000000000000..58def8af275c --- /dev/null +++ b/.github/workflows/pytest_gpu.yml @@ -0,0 +1,48 @@ +name: Run Pytest GPU tests + +on: + pull_request: + branches: + - main + +jobs: + build: + strategy: + matrix: + runner: ["mike-x86-g2-48-l4-4gpu"] + python: ["3.10"] + + runs-on: ${{ matrix.runner }} + container: + image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + + env: + # GitHub actions run in Docker by defaut. Disable running the `setup_docker.sh` script. + JAXCI_RUN_DOCKER_CONTAINER: 0 + # Use RBE to build the artifacts. 
+ JAXCI_BUILD_ARTIFACT_WITH_RBE: 1 + # Setup the test environment (disable x64 mode and clone XLA at HEAD) + JAXCI_SETUP_TEST_ENVIRONMENT: 1 + JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: ./actions/ci_connection/ + - name: Build jaxlib + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib" + - name: Build jax-cuda-plugin + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-plugin" + - name: Build jax-cuda-pjrt + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-pjrt" + - name: Install pytest + env: + JAXCI_PYTHON: python${{ matrix.python }} + run: $JAXCI_PYTHON -m pip install pytest + - name: Install dependencies + env: + JAXCI_PYTHON: python${{ matrix.python }} + run: $JAXCI_PYTHON -m pip install -r build/test-requirements.txt + - name: Run Pytest GPU tests + run: ./ci/run_pytest.sh "ci/envs/run_tests/pytest_gpu" diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml new file mode 100644 index 000000000000..94b0806b108b --- /dev/null +++ b/.github/workflows/pytest_tpu.yml @@ -0,0 +1,50 @@ +name: Run Pytest TPU tests + +on: + pull_request: + branches: + - main + +jobs: + build: + strategy: + matrix: + runner: ["linux-x86-ct5lp-224-8tpu"] + python: ["3.10"] + + runs-on: ${{ matrix.runner }} + container: + image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest" + + env: + # GitHub actions run in Docker by defaut. Disable running the `setup_docker.sh` script. + JAXCI_RUN_DOCKER_CONTAINER: 0 + # Use RBE to build the artifacts. + JAXCI_BUILD_ARTIFACT_WITH_RBE: 1 + # Setup the test environment (disable x64 mode and clone XLA at HEAD) + JAXCI_SETUP_TEST_ENVIRONMENT: 1 + JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: ./actions/ci_connection/ + - name: Build jaxlib + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib" + - name: Install pytest + env: + JAXCI_PYTHON: python${{ matrix.python }} + run: $JAXCI_PYTHON -m pip install pytest + - name: Install Test requirements + env: + JAXCI_PYTHON: python${{ matrix.python }} + run: | + $JAXCI_PYTHON -m pip install -r build/test-requirements.txt + $JAXCI_PYTHON -m pip install -r build/collect-profile-requirements.txt + - name: Install Libtpu + env: + JAXCI_PYTHON: python${{ matrix.python }} + run: $JAXCI_PYTHON -m pip install --pre libtpu-nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html" + - name: Run Pytest TPU tests + run: ./ci/run_pytest.sh "ci/envs/run_tests/pytest_tpu" diff --git a/actions/ci_connection/notify_connection.py b/actions/ci_connection/notify_connection.py index 52c7b58f43cf..db851341cc8e 100644 --- a/actions/ci_connection/notify_connection.py +++ b/actions/ci_connection/notify_connection.py @@ -14,6 +14,7 @@ import time import threading +import os import subprocess from multiprocessing.connection import Client @@ -42,6 +43,20 @@ def timer(conn): timer_thread.start() print("Entering interactive bash session") + + # Hard-coded for now for demo purposes. 
+ next_command = "bash ci/build_artifacts.sh" + # Print the "next" commands to be run + # TODO: actually get this data from workflow files + print(f"The next command that would have run is:\n\n{next_command}") + + # Set the hardcoded envs for testing purposes + # TODO: sync env vars + sub_env = os.environ.copy() + sub_env["ENV_FILE"] = "ci/envs/build_artifacts/jaxlib" + sub_env["JAXCI_USE_DOCKER"] = "0" + sub_env["JAXCI_USE_RBE"] = "1" + # Enter interactive bash session subprocess.run(["bash", "-i"]) diff --git a/ci/.bazelrc b/ci/.bazelrc new file mode 100644 index 000000000000..ac7691dc1249 --- /dev/null +++ b/ci/.bazelrc @@ -0,0 +1,472 @@ +# ############################################################################# +# All default build options below. These apply to all build commands. +# ############################################################################# +# Make Bazel print out all options from rc files. +build --announce_rc + +# Required by OpenXLA +# https://github.com/openxla/xla/issues/1323 +build --nocheck_visibility + +# By default, execute all actions locally. +build --spawn_strategy=local + +# Enable host OS specific configs. For instance, "build:linux" will be used +# automatically when building on Linux. +build --enable_platform_specific_config + +build --experimental_cc_shared_library + +# Disable enabled-by-default TensorFlow features that we don't care about. +build --define=no_gcp_support=true + +# Do not use C-Ares when building gRPC. +build --define=grpc_no_ares=true + +build --define=tsl_link_protobuf=true + +# Enable optimization. +build -c opt + +# Suppress all warning messages. +build --output_filter=DONT_MATCH_ANYTHING + +build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. + +build --verbose_failures=true + +# ############################################################################# +# Platform Specific configs below. These are automatically picked up by Bazel +# depending on the platform that is running the build. If you would like to +# disable this behavior, pass in `--noenable_platform_specific_config` +# ############################################################################# +build:linux --config=posix +build:linux --copt=-Wno-unknown-warning-option + +# Workaround for gcc 10+ warnings related to upb. +# See https://github.com/tensorflow/tensorflow/issues/39467 +build:linux --copt=-Wno-stringop-truncation +build:linux --copt=-Wno-array-parameter + +build:macos --config=posix +build:macos --apple_platform_type=macos + +# Windows has a relatively short command line limit, which JAX has begun to hit. +# See https://docs.bazel.build/versions/main/windows.html +build:windows --features=compiler_param_file +build:windows --features=archive_param_file + +# Tensorflow uses M_* math constants that only get defined by MSVC headers if +# _USE_MATH_DEFINES is defined. +build:windows --copt=/D_USE_MATH_DEFINES +build:windows --host_copt=/D_USE_MATH_DEFINES +# Make sure to include as little of windows.h as possible +build:windows --copt=-DWIN32_LEAN_AND_MEAN +build:windows --host_copt=-DWIN32_LEAN_AND_MEAN +build:windows --copt=-DNOGDI +build:windows --host_copt=-DNOGDI +# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/ +# otherwise, there will be some compiling error due to preprocessing. 
+build:windows --copt=/Zc:preprocessor
+build:windows --cxxopt=/std:c++17
+build:windows --host_cxxopt=/std:c++17
+# Generate PDB files, to generate useful PDBs, in opt compilation_mode
+# --copt /Z7 is needed.
+build:windows --linkopt=/DEBUG
+build:windows --host_linkopt=/DEBUG
+build:windows --linkopt=/OPT:REF
+build:windows --host_linkopt=/OPT:REF
+build:windows --linkopt=/OPT:ICF
+build:windows --host_linkopt=/OPT:ICF
+build:windows --incompatible_strict_action_env=true
+
+# #############################################################################
+# Feature-specific configurations. These are used by the Local and CI configs
+# below depending on the type of build. E.g. `local_linux_x86_64` inherits the
+# Linux x86 configs such as `avx_linux` and `mkl_open_source_only`,
+# `local_cuda_base` inherits `cuda` and `build_cuda_with_nvcc`, etc.
+# #############################################################################
+build:nonccl --define=no_nccl_support=true
+
+build:posix --copt=-fvisibility=hidden
+build:posix --copt=-Wno-sign-compare
+build:posix --cxxopt=-std=c++17
+build:posix --host_cxxopt=-std=c++17
+
+build:avx_posix --copt=-mavx
+build:avx_posix --host_copt=-mavx
+
+build:native_arch_posix --copt=-march=native
+build:native_arch_posix --host_copt=-march=native
+
+build:avx_linux --copt=-mavx
+build:avx_linux --host_copt=-mavx
+
+build:avx_windows --copt=/arch:AVX
+
+build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
+
+# Disable the Clang extension that rejects type definitions within offsetof.
+# This was added in clang-16 by https://reviews.llvm.org/D133574.
+# Can be removed once upb is updated, since a type definition is used within
+# offsetof in the current version of upb.
+# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
+build:clang --copt=-Wno-gnu-offsetof-extensions
+# Disable the Clang extension that rejects unknown arguments.
+build:clang --copt=-Qunused-arguments
+
+# Configs for CUDA
+build:cuda --repo_env TF_NEED_CUDA=1
+build:cuda --repo_env TF_NCCL_USE_STUB=1
+# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
+# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
+build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
+build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
+build:cuda --@local_config_cuda//:enable_cuda
+build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
+
+# Default hermetic CUDA and CUDNN versions.
+build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
+build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
+
+# This flag is needed to include CUDA libraries for bazel tests.
+test:cuda --@local_config_cuda//cuda:include_cuda_libs=true
+
+# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
+# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
+# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
+# packages.
+# This has pros and cons:
+# * pro: we'll ignore other CUDA installations, which has frequently confused
+#   users in the past. By setting RPATH, we'll always use the NVIDIA pip
+#   packages if they are installed.
+# * con: the user cannot override the CUDA installation location
+#   via LD_LIBRARY_PATH, if the nvidia-... pip packages are installed. This is
+#   acceptable, because the workaround is "remove the nvidia-..." pip packages.
+# The list of CUDA pip packages that JAX depends on are present in setup.py.
+build:cuda --linkopt=-Wl,--disable-new-dtags
+
+# Build CUDA and other C++ targets with Clang
+build:build_cuda_with_clang --@local_config_cuda//:cuda_compiler=clang
+
+# Build CUDA with NVCC and other C++ targets with Clang
+build:build_cuda_with_nvcc --action_env=TF_NVCC_CLANG="1"
+build:build_cuda_with_nvcc --@local_config_cuda//:cuda_compiler=nvcc
+
+# Requires MSVC and LLVM to be installed
+build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
+build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
+build:win_clang --compiler=clang-cl
+
+# Configs for building ROCm
+build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
+build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
+build:rocm --repo_env TF_NEED_ROCM=1
+build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"
+
+# #############################################################################
+# Cache options below.
+# #############################################################################
+# Public read-only cache for macOS builds. The "oct2023" in the URL is just the
+# date when the bucket was created and can be disregarded. It still contains the
+# latest cache that is being used.
+build:macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false
+# Cache pushes are limited to Jax's CI system.
+build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials
+
+# #############################################################################
+# Local Build config options below. Use these configs to build JAX locally.
+# #############################################################################
+# Set base CUDA configs. These are inherited by the Linux x86 and Linux Aarch64
+# CUDA configs.
+build:local_cuda_base --config=cuda
+
+# JAX uses NVCC to build CUDA targets. If you would like to build CUDA targets
+# with Clang, change this to `--config=build_cuda_with_clang`
+build:local_cuda_base --config=build_cuda_with_nvcc
+
+# Linux x86 Local configs
+build:local_linux_x86_64 --config=avx_linux
+build:local_linux_x86_64 --config=avx_posix
+build:local_linux_x86_64 --config=mkl_open_source_only
+
+build:local_linux_x86_64_cuda --config=local_linux_x86_64
+build:local_linux_x86_64_cuda --config=local_cuda_base
+
+# Linux Aarch64 Local configs
+# No custom config for Linux Aarch64. If building for CPU, run
+# `bazel build|test //path/to:target`. If building for CUDA, run
+# `bazel build|test --config=local_cuda_base //path/to:target`.
+build:local_linux_aarch64_cuda --config=local_cuda_base
+
+# Mac x86 Local configs
+# For Mac x86, we target compatibility with macOS 10.14.
+build:local_darwin_x86_64 --macos_minimum_os=10.14
+# Read-only cache to boost build times.
+build:local_darwin_x86_64 --config=macos_cache
+
+# Mac Arm64 Local configs
+# For Mac Arm64, we target compatibility with macOS 12.
+build:local_darwin_arm64 --macos_minimum_os=12.0
+# Read-only cache to boost build times.
+build:local_darwin_arm64 --config=macos_cache
+
+# Windows x86 Local configs
+build:local_windows_amd64 --config=avx_windows
+
+# #############################################################################
+# CI Build config options below.
+# JAX uses these configs in CI builds for building artifacts and when running
+# Bazel tests.
+#
+# These configs are pretty much the same as the local build configs above. The
+# difference is that, in CI, we build with Clang and pass in a custom
+# non-hermetic toolchain to ensure manylinux compliance for Linux builds and
+# to support RBE on Windows. Because the toolchain is non-hermetic, it requires
+# specific versions of the compiler and other tools to be present on the system
+# in specific locations, which is why the Linux and Windows builds are run in a
+# Docker container.
+# #############################################################################
+
+# Linux x86 CI configs
+# Inherit the local Linux x86 configs.
+build:ci_linux_x86_64 --config=local_linux_x86_64
+
+# CI builds use Clang as the default compiler, so we inherit the Clang-specific
+# configs.
+build:ci_linux_x86_64 --config=clang
+
+# TODO(b/356695103): We do not have a CPU only toolchain so we use the CUDA
+# toolchain for both CPU and GPU builds.
+build:ci_linux_x86_64 --host_crosstool_top="@local_config_cuda//crosstool:toolchain"
+build:ci_linux_x86_64 --crosstool_top="@local_config_cuda//crosstool:toolchain"
+build:ci_linux_x86_64 --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64"
+build:ci_linux_x86_64 --repo_env=TF_SYSROOT="/dt9"
+
+# Clang path needs to be set for the remote toolchain to be configured correctly.
+build:ci_linux_x86_64 --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
+
+# The toolchain in `--config=cuda` needs to be read before the toolchain in
+# `--config=ci_linux_x86_64`. Otherwise, we run into issues with manylinux
+# compliance.
+build:ci_linux_x86_64_cuda --config=local_cuda_base
+build:ci_linux_x86_64_cuda --config=ci_linux_x86_64
+
+# Linux Aarch64 CI configs
+build:ci_linux_aarch64_base --config=clang
+build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10"
+
+build:ci_linux_aarch64 --config=ci_linux_aarch64_base
+build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain"
+build:ci_linux_aarch64 --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain"
+
+# CUDA configs for Linux Aarch64 do not pass in the crosstool top flag from
+# above because the Aarch64 toolchain rule does not support building with NVCC.
+# Instead, we use `@local_config_cuda//crosstool:toolchain` from --config=cuda
+# and set `CLANG_CUDA_COMPILER_PATH` to define the toolchain so that we can
+# use Clang for the C++ targets and NVCC to build CUDA targets.
+build:ci_linux_aarch64_cuda --config=ci_linux_aarch64_base
+build:ci_linux_aarch64_cuda --config=local_cuda_base
+build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
+
+# Mac x86 CI configs
+build:ci_darwin_x86_64 --config=local_darwin_x86_64
+# Mac CI builds read and push cache to/from GCS bucket.
+build:ci_darwin_x86_64 --config=macos_cache_push
+
+# Mac Arm64 CI configs
+build:ci_darwin_arm64 --config=local_darwin_arm64
+# CI builds read and push cache to/from GCS bucket.
+build:ci_darwin_arm64 --config=macos_cache_push + +# Windows x86 CI configs +build:ci_windows_amd64 --config=local_windows_amd64 +build:ci_windows_amd64 --config=clang +# Set the toolchains +build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" +build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" +build:ci_windows_amd64 --compiler=clang-cl +build:ci_windows_amd64 --linkopt=/FORCE:MULTIPLE +build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE + +# ############################################################################# +# RBE config options below. These inherit the CI configs above and set the +# remote execution backend and authentication options required to run builds +# with RBE. Linux x86 and Windows builds use RBE. +# ############################################################################# +# Flag to enable remote config +common --experimental_repo_remote_exec + +# Allow creation of resultstore URLs for any bazel invocation +build:resultstore --google_default_credentials +build:resultstore --bes_backend=buildeventservice.googleapis.com +build:resultstore --bes_instance_name="tensorflow-testing" +build:resultstore --bes_results_url="https://source.cloud.google.com/results/invocations" +build:resultstore --bes_timeout=600s + +build:rbe --config=resultstore +build:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1 +build:rbe --define=EXECUTOR=remote +build:rbe --flaky_test_attempts=3 +build:rbe --jobs=200 +build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:rbe --remote_timeout=3600 +build:rbe --spawn_strategy=remote,worker,standalone,local +# Attempt to minimize the amount of data transfer between bazel and the remote +# workers: +build:rbe --remote_download_toplevel +test:rbe --test_env=USER=anon + +# RBE configs for Linux x86 +# Set the remote worker pool +common:rbe_linux_x86_64_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance + +build:rbe_linux_x86_64_base --config=rbe +build:rbe_linux_x86_64_base --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +build:rbe_linux_x86_64_base --linkopt=-lrt +build:rbe_linux_x86_64_base --host_linkopt=-lrt +build:rbe_linux_x86_64_base --linkopt=-lm +build:rbe_linux_x86_64_base --host_linkopt=-lm + +# Set the host, execution, and target platform +build:rbe_linux_x86_64_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +build:rbe_linux_x86_64_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +build:rbe_linux_x86_64_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" + +# Python config is the same across all containers because the binary is the same +build:rbe_linux_x86_64_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python" + +build:rbe_linux_x86_64 --config=rbe_linux_x86_64_base +build:rbe_linux_x86_64 --config=ci_linux_x86_64 + +build:rbe_linux_x86_64_cuda --config=rbe_linux_x86_64_base +build:rbe_linux_x86_64_cuda --config=ci_linux_x86_64_cuda +build:rbe_linux_x86_64_cuda --repo_env=REMOTE_GPU_TESTING=1 + +# RBE configs for Windows +# Set the remote worker pool +common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/instances/windows + +build:rbe_windows_amd64 --config=rbe + +# Set the host, execution, and 
target platform +build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" + +build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe +build:rbe_windows_amd64 --enable_runfiles +build:rbe_windows_amd64 --define=override_eigen_strong_inline=true + +# Don't build the python zip archive in the RBE build. +build:rbe_windows_amd64 --nobuild_python_zip + +build:rbe_windows_amd64 --config=ci_windows_amd64 + +# ############################################################################# +# Cross-compile config options below. Native RBE support does not exist for +# Linux Aarch64 and Mac x86. So, we use the cross-compile toolchain to build +# targets for Linux Aarch64 and Mac x86 on the Linux x86 RBE pool. +# ############################################################################# +# Set execution platform to Linux x86 +# Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" +# flags seem to be actually used to specify the execution platform details. It +# seems it is this way because these flags are old and predate the distinction +# between host and execution platform. +build:cross_compile_base --host_cpu=k8 +build:cross_compile_base --host_crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_base --extra_execution_platforms=@xla//tools/toolchains/cross_compile/config:linux_x86_64 + +# Linux Aarch64 +build:cross_compile_linux_aarch64 --config=cross_compile_base + +# Set the target CPU to Aarch64 +build:cross_compile_linux_aarch64 --platforms=@xla//tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_aarch64 --cpu=aarch64 +build:cross_compile_linux_aarch64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite + +build:rbe_cross_compile_base --config=rbe +build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance + +# RBE cross-compile configs for Linux Aarch64 +build:rbe_cross_compile_linux_aarch64 --config=cross_compile_linux_aarch64 +build:rbe_cross_compile_linux_aarch64 --config=rbe_cross_compile_base + +# Mac x86 +build:cross_compile_darwin_x86_64 --config=cross_compile_base +build:cross_compile_darwin_x86_64 --config=nonccl +# Target Catalina (10.15) as the minimum supported OS +build:cross_compile_darwin_x86_64 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + +# Set the target CPU to Darwin x86 +build:cross_compile_darwin_x86_64 --platforms=@xla//tools/toolchains/cross_compile/config:darwin_x86_64 +build:cross_compile_darwin_x86_64 --cpu=darwin +build:cross_compile_darwin_x86_64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +# When RBE cross-compiling for macOS, we need to explicitly register the +# toolchain. Otherwise, oddly, RBE complains that a "docker container must be +# specified". +build:cross_compile_darwin_x86_64 --extra_toolchains=@xla//tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain +# Map --platforms=darwin_x86_64 to --cpu=darwin and vice-versa to make selects() +# and transistions that use these flags work. The flag --platform_mappings needs +# to be set to a file that exists relative to the package path roots. 
+build:cross_compile_darwin_x86_64 --platform_mappings=platform_mappings
+
+# RBE cross-compile configs for Darwin x86
+build:rbe_cross_compile_darwin_x86_64 --config=cross_compile_darwin_x86_64
+build:rbe_cross_compile_darwin_x86_64 --config=rbe_cross_compile_base
+
+# #############################################################################
+# Test specific config options below. These are used when `bazel test` is run.
+# #############################################################################
+test --test_output=errors
+
+# Common configs for running GPU tests.
+test:gpu --test_env=TF_CPP_MIN_LOG_LEVEL=0 --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
+
+# Non-multiaccelerator tests with one GPU apiece. These tests are run on RBE
+# and locally.
+test:non_multiaccelerator --config=gpu
+test:non_multiaccelerator --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow
+test:non_multiaccelerator --test_tag_filters=-multiaccelerator
+
+# Configs for running non-multiaccelerator tests locally
+test:non_multiaccelerator_local --config=non_multiaccelerator
+# Disable building jaxlib. Instead we depend on the local wheel.
+test:non_multiaccelerator_local --//jax:build_jaxlib=false
+
+# `JAX_ACCELERATOR_COUNT` needs to match the number of GPUs in the VM.
+test:non_multiaccelerator_local --test_env=JAX_TESTS_PER_ACCELERATOR=12 --test_env=JAX_ACCELERATOR_COUNT=4
+
+# The product of `JAX_ACCELERATOR_COUNT` and `JAX_TESTS_PER_ACCELERATOR`
+# should match the VM's CPU core count (set in `--local_test_jobs`).
+test:non_multiaccelerator_local --local_test_jobs=48
+
+# Multiaccelerator tests with all GPUs. These tests are only run locally.
+# Disable building jaxlib. Instead we depend on the local wheel.
+test:multiaccelerator_local --config=gpu
+test:multiaccelerator_local --//jax:build_jaxlib=false
+test:multiaccelerator_local --jobs=8 --test_tag_filters=multiaccelerator
+
+#############################################################################
+# Some configs for getting various forms of debug builds. In general, the
+# codebase is only regularly built with optimizations. Use 'debug_symbols' to
+# just get symbols for the parts of XLA/PJRT that jaxlib uses.
+# Or try 'debug' to get a build with assertions enabled and minimal
+# optimizations.
+# Include these in a local .bazelrc.user file as:
+# build --config=debug_symbols
+# Or:
+# build --config=debug
+#
+# Additional files can be opted in for debug symbols by adding patterns
+# to a per_file_copt similar to below.
+#############################################################################
+
+build:debug_symbols --strip=never --per_file_copt="xla/pjrt|xla/python@-g3"
+build:debug --config debug_symbols -c fastbuild
+
+# Load `.jax_configure.bazelrc` file written by build.py
+try-import %workspace%/.jax_configure.bazelrc
+
+# Load rc file with user-specific options.
+try-import %workspace%/.bazelrc.user
\ No newline at end of file
diff --git a/ci/README.md b/ci/README.md
new file mode 100644
index 000000000000..914b1e2a8283
--- /dev/null
+++ b/ci/README.md
@@ -0,0 +1,9 @@
+# JAX continuous integration.
+
+> **Warning** This folder is still under construction. It is part of an ongoing
+> effort to improve the structure of CI and build related files within the
+> JAX repo. This warning will be removed when the contents of this
+> directory are stable and appropriate documentation around its usage is in
+> place.
+
+********************************************************************************
diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh
new file mode 100755
index 000000000000..13e991518e1e
--- /dev/null
+++ b/ci/build_artifacts.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Source JAXCI environment variables.
+source "ci/utilities/setup_envs.sh" "$1"
+# Set up the build environment.
+source "ci/utilities/setup_build_environment.sh"
+
+# Build the jax artifact
+if [[ "$JAXCI_BUILD_JAX" == 1 ]]; then
+  check_if_to_run_in_docker python3 -m build --outdir $JAXCI_OUTPUT_DIR
+fi
+
+# Build the jaxlib CPU artifact
+if [[ "$JAXCI_BUILD_JAXLIB" == 1 ]]; then
+  check_if_to_run_in_docker python3 ci/cli/build.py jaxlib --mode=$JAXCI_CLI_BUILD_MODE --python_version=$JAXCI_HERMETIC_PYTHON_VERSION
+fi
+
+# Build the jax-cuda-plugin artifact
+if [[ "$JAXCI_BUILD_PLUGIN" == 1 ]]; then
+  check_if_to_run_in_docker python3 ci/cli/build.py jax-cuda-plugin --mode=$JAXCI_CLI_BUILD_MODE --python_version=$JAXCI_HERMETIC_PYTHON_VERSION
+fi
+
+# Build the jax-cuda-pjrt artifact
+
+if [[ "$JAXCI_BUILD_PJRT" == 1 ]]; then
+  check_if_to_run_in_docker python3 ci/cli/build.py jax-cuda-pjrt --mode=$JAXCI_CLI_BUILD_MODE
+fi
+
+# After building `jaxlib`, `jax-cuda-plugin`, and `jax-cuda-pjrt`, we run
+# `auditwheel show` to ensure manylinux compliance.
+if [[ "$JAXCI_RUN_AUDITWHEEL" == 1 ]]; then
+  check_if_to_run_in_docker ./ci/utilities/run_auditwheel.sh
+fi
diff --git a/ci/cli/build.py b/ci/cli/build.py
new file mode 100644
index 000000000000..7aac33c29099
--- /dev/null
+++ b/ci/cli/build.py
@@ -0,0 +1,709 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# CLI for building JAX artifacts.
+import argparse +import asyncio +import logging +import os +import platform +import collections +import sys +import subprocess +from helpers import command, tools + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +BANNER = r""" + _ _ __ __ + | | / \ \ \/ / + _ | |/ _ \ \ / +| |_| / ___ \/ \ + \___/_/ \/_/\_\ + +""" + +EPILOG = """ +From the root directory of the JAX repository, run + python ci/cli/build.py [jaxlib | jax-cuda-plugin | jax-cuda-pjrt | jax-rocm-plugin | jax-rocm-pjrt] +or + python3 ci/cli/build.py [jaxlib | jax-cuda-plugin | jax-cuda-pjrt | jax-rocm-plugin | jax-rocm-pjrt] + +to build one of: jaxlib, jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, or jax-rocm-pjrt. +""" + +ArtifactBuildSpec = collections.namedtuple( + "ArtifactBuildSpec", + ["bazel_build_target", "wheel_binary"], +) + +# Define the build target and resulting wheel binary for each artifact. +ARTIFACT_BUILD_TARGET_DICT = { + "jaxlib": ArtifactBuildSpec("//jaxlib/tools:build_wheel", "bazel-bin/jaxlib/tools/build_wheel"), + "jax-cuda-plugin": ArtifactBuildSpec("//jaxlib/tools:build_gpu_kernels_wheel", "bazel-bin/jaxlib/tools/build_gpu_kernels_wheel"), + "jax-cuda-pjrt": ArtifactBuildSpec("//jaxlib/tools:build_gpu_plugin_wheel", "bazel-bin/jaxlib/tools/build_gpu_plugin_wheel"), + "jax-rocm-plugin": ArtifactBuildSpec("//jaxlib/tools:build_gpu_kernels_wheel", "bazel-bin/jaxlib/tools/build_gpu_kernels_wheel"), + "jax-rocm-pjrt": ArtifactBuildSpec("//jaxlib/tools:build_gpu_plugin_wheel", "bazel-bin/jaxlib/tools/build_gpu_plugin_wheel"), +} + +def get_bazelrc_config(os_name: str, arch: str, artifact: str, mode:str, use_rbe: bool): + """ + Returns the bazelrc config for the given architecture, OS, and build mode. + Args: + os_name: The name of the OS. + arch: The architecture of the host system. + artifact: The artifact to build. + mode: CLI build mode. + use_rbe: Whether to use RBE. + """ + + # When building ROCm packages, we only inherit `--config=rocm` from .bazelrc + if "rocm" in artifact: + logger.debug("Building ROCm package. Using --config=rocm.") + return "rocm" + + bazelrc_config = f"{os_name}_{arch}" + + # When the CLI is run by invoking ci/build_artifacts.sh, the CLI runs in CI + # mode and will use one of the "ci_" configs in the .bazelrc. We want to run + # certain CI builds with RBE and we also want to allow users the flexibility + # to build JAX artifacts either by running the CLI or by running + # ci/build_artifacts.sh. Because RBE requires permissions, we cannot enable it + # by default in ci/build_artifacts.sh. Instead, we have the CI builds set + # JAXCI_BUILD_ARTIFACT_WITH_RBE to 1 to enable RBE. + if os.environ.get("JAXCI_BUILD_ARTIFACT_WITH_RBE", "0") == "1": + use_rbe = True + + # In CI builds, we want to use RBE where possible. At the moment, RBE is only + # supported on Linux x86 and Windows. If an user is requesting RBE, the CLI + # will use RBE if the host system supports it, otherwise it will use the + # local config. + if use_rbe and ((os_name == "linux" and arch == "x86_64") \ + or (os_name == "windows" and arch == "amd64")): + bazelrc_config = "rbe_" + bazelrc_config + elif mode == "local": + # Show warning if RBE is requested on an unsupported platform. + if use_rbe: + logger.warning("RBE is not supported on %s_%s. Using Local config instead.", os_name, arch) + + # If building `jaxlib` on Linux Aarch64, we use the default configs. No + # custom local config is present in JAX's .bazelrc. 
+ if os_name == "linux" and arch == "aarch64" and artifact == "jaxlib": + logger.debug("Linux Aarch64 CPU builds do not have custom local config in JAX's root .bazelrc. Running with default configs.") + bazelrc_config = "" + return bazelrc_config + + bazelrc_config = "local_" + bazelrc_config + else: + # Show warning if RBE is requested on an unsupported platform. + if use_rbe: + logger.warning("RBE is not supported on %s_%s. Using CI config instead.", os_name, arch) + + # Let user know that RBE is available for this platform. + if (os_name == "linux" and arch == "x86_64")or (os_name == "windows" and arch == "amd64"): + logger.info("RBE support is available for this platform. If you want to use RBE and have the required permissions, run the CLI with `--use_rbe` or set `JAXCI_BUILD_ARTIFACT_WITH_RBE=1`") + + bazelrc_config = "ci_" + bazelrc_config + + # When building jax-cuda-plugin or jax-cuda-pjrt, append "_cuda" to the + # bazelrc config to use the CUDA specific configs. + if artifact == "jax-cuda-plugin" or artifact == "jax-cuda-pjrt": + bazelrc_config = bazelrc_config + "_cuda" + + return bazelrc_config + +def get_jaxlib_git_hash(): + """Returns the git hash of the current repository.""" + res = subprocess.run( + ["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True + ) + return res.stdout + +# Set Clang as the C++ compiler if requested. CI builds use Clang by default +# via the toolchain used by the "ci_" configs in the .bazelrc. For Local builds, +# Bazel uses the default C++ compiler on the system which is GCC for Linux and +# MSVC for Windows. +def set_clang_as_compiler(bazel_command: command.CommandBuilder, clang_path: str): + """ + Sets Clang as the C++ compiler in the Bazel command. + Args: + bazel_command: An instance of command.CommandBuilder. + clang_path: The path to Clang. + """ + # Find the path to Clang. + absolute_clang_path = tools.get_clang_path(clang_path) + if absolute_clang_path: + logger.debug("Adding Clang as the C++ compiler to Bazel...") + bazel_command.append(f"--action_env CLANG_COMPILER_PATH='{absolute_clang_path}'") + bazel_command.append(f"--repo_env CC='{absolute_clang_path}'") + bazel_command.append(f"--repo_env BAZEL_COMPILER='{absolute_clang_path}'") + # Inherit Clang specific settings from the .bazelrc + bazel_command.append("--config=clang") + else: + logger.debug("Could not find path to Clang. Continuing without Clang.") + +def adjust_paths_for_windows(output_dir: str, arch: str) -> tuple[str, str, str]: + """ + Adjusts the paths to be compatible with Windows. + Args: + output_dir: The output directory for the wheel. + arch: The architecture of the host system. + Returns: + A tuple of the adjusted paths. + """ + logger.debug("Adjusting paths for Windows...") + output_dir = output_dir.replace("/", "\\") + + # Change to upper case to match the case in + # "jax/tools/build_utils.py" for Windows. + arch = arch.upper() + + return (output_dir, arch) + +def parse_and_append_bazel_options(bazel_command: command.CommandBuilder, bazel_options: str): + """ + Parses the bazel options and appends them to the bazel command. + Args: + bazel_command: An instance of command.CommandBuilder. + bazel_options: The bazel options to parse and append. + """ + for option in bazel_options.split(" "): + bazel_command.append(option) + +def construct_requirements_update_command(bazel_command: command.CommandBuilder, additional_build_options: str, python_version: str, update_nightly: bool): + """ + Constructs the Bazel command to run the requirements update. 
+ Args: + bazel_command: An instance of command.CommandBuilder. + additional_build_options: Additional build options to pass to Bazel. + python_version: Hermetic Python version to use. + update_nightly: Whether to update the nightly requirements file. + """ + bazel_command.append("run") + + if python_version: + logging.debug("Setting Hermetic Python version to %s", python_version) + bazel_command.append(f"--repo_env=HERMETIC_PYTHON_VERSION={python_version}") + + if additional_build_options: + logging.debug("Using additional build options: %s", additional_build_options) + parse_and_append_bazel_options(bazel_command, additional_build_options) + + if update_nightly: + bazel_command.append("//build:requirements_nightly.update") + else: + bazel_command.append("//build:requirements.update") + +def add_python_version_argument(parser: argparse.ArgumentParser): + """ + Add Python version argument to the parser. + Args: + parser: An instance of argparse.ArgumentParser. + """ + parser.add_argument( + "--python_version", + type=str, + choices=["3.10", "3.11", "3.12"], + default="3.12", + help="Python version to use", + ) + +def add_cuda_version_argument(parser: argparse.ArgumentParser): + """ + Add CUDA version argument to the parser. + Args: + parser: An instance of argparse.ArgumentParser. + """ + parser.add_argument( + "--cuda_version", + type=str, + default="12.3.2", + help="CUDA version to use", + ) + +def add_cudnn_version_argument(parser: argparse.ArgumentParser): + """ + Add cuDNN version argument to the parser. + Args: + parser: An instance of argparse.ArgumentParser. + """ + parser.add_argument( + "--cudnn_version", + type=str, + default="9.1.1", + help="cuDNN version to use", + ) + +def add_disable_nccl_argument(parser: argparse.ArgumentParser): + """ + Add an argument to allow disabling NCCL for CUDA/ROCM builds. + Args: + parser: An instance of argparse.ArgumentParser. + """ + parser.add_argument( + "--disable_nccl", + action="store_true", + help="Whether to disable NCCL for CUDA/ROCM builds.", + ) + +def add_cuda_compute_capabilities_argument(parser: argparse.ArgumentParser): + """ + Add an argument to set the CUDA compute capabilities. + Args: + parser: An instance of argparse.ArgumentParser. + """ + parser.add_argument( + "--cuda_compute_capabilities", + type=str, + default=None, + help="A comma-separated list of CUDA compute capabilities to support.", + ) + +def add_rocm_version_argument(parser: argparse.ArgumentParser): + """ + Add ROCm version argument to the parser. + Args: + parser: An instance of argparse.ArgumentParser. + """ + parser.add_argument( + "--rocm_version", + type=str, + default="60", + help="ROCm version to use", + ) + + +def add_rocm_amdgpu_targets_argument(parser: argparse.ArgumentParser): + """ + Add an argument to set the ROCm amdgpu targets. + Args: + parser: An instance of argparse.ArgumentParser. + """ + parser.add_argument( + "--rocm_amdgpu_targets", + type=str, + default="gfx900,gfx906,gfx908,gfx90a,gfx1030", + help="A comma-separated list of ROCm amdgpu targets to support.", + ) + +def add_rocm_path_argument(parser: argparse.ArgumentParser): + """ + Add an argument to set the ROCm toolkit path. + Args: + parser: An instance of argparse.ArgumentParser. + """ + parser.add_argument( + "--rocm_path", + type=str, + default="", + help="Path to the ROCm toolkit.", + ) + +def add_global_arguments(parser: argparse.ArgumentParser): + """ + Add global arguments to the parser. + Args: + parser: An instance of argparse.ArgumentParser. + """ + # Set the build mode. 
This is used to determine the Bazelrc config to use. + # Local selects the "local_" config and CI selects the "ci_" config. CI + # configs inherit local configs and set a custom C++ toolchain that needs to + # be present on the system. + parser.add_argument( + "--mode", + type=str, + choices=["ci", "local"], + default="local", + help=""" + Sets the build mode to use. + If set to "ci", the CLI will assume the build is being run in CI or CI + like environment and will use the "ci_" configs in the .bazelrc. + If set to "local", the CLI will use the "local_" configs in the + .bazelrc. + CI configs inherit the local configs and set a custom C++ toolchain to + use Clang and specific versioned standard libraries. As a result, CI + configs require the toolchain to be present on the system. + When set to local, Bazel will use the default C++ compiler on the + system which is GCC for Linux and MSVC for Windows. If you want to use + Clang for local builds, use the `--use_clang` flag. + """, + ) + + # If set, the build will create an 'editable' build instead of a wheel. + parser.add_argument( + "--editable", + action="store_true", + help= + "Create an 'editable' build instead of a wheel.", + ) + + # Set Path to Bazel binary + parser.add_argument( + "--bazel_path", + type=str, + default="", + help= + """ + Path to the Bazel binary to use. The default is to find bazel via the + PATH; if none is found, downloads a fresh copy of Bazelisk from GitHub. + """, + ) + + # Use Clang as the C++ compiler. CI builds use Clang by default via the + # toolchain used by the "ci_" configs in the .bazelrc. + parser.add_argument( + "--use_clang", + action="store_true", + help= + """ + If set, the build will use Clang as the C++ compiler. Requires Clang to + be present on the PATH or a path is given with --clang_path. CI builds use + Clang by default. + """, + ) + + # Set the path to Clang. If not set, the build will attempt to find Clang on + # the PATH. + parser.add_argument( + "--clang_path", + type=str, + default="", + help= + """ + Path to the Clang binary to use. If not set and --use_clang is set, the + build will attempt to find Clang on the PATH. + """, + ) + + # Use RBE if available. Only available for Linux x86 and Windows and requires + # permissions. + parser.add_argument( + "--use_rbe", + action="store_true", + help= + """ + If set, the build will use RBE where possible. Currently, only Linux x86 + and Windows builds can use RBE. On other platforms, setting this flag will + be a no-op. RBE requires permissions to JAX's remote worker pool. Only + Googlers and CI builds can use RBE. + """, + ) + + # Set the path to local XLA repository. If not set, the build will use the + # XLA at the pinned version in workspace.bzl. CI builds set this via the + # JAXCI_XLA_GIT_DIR environment variable. + parser.add_argument( + "--local_xla_path", + type=str, + default=os.environ.get("JAXCI_XLA_GIT_DIR", ""), + help= + """ + Path to local XLA repository to use. If not set, Bazel uses the XLA + at the pinned version in workspace.bzl. + """, + ) + + # Enabling native arch features will add --config=native_arch_posix to the + # Bazel command. This enables -march=native, which generates code targeted to + # use all features of the current machine. Not supported on Windows. + parser.add_argument( + "--enable_native_arch_features", + action="store_true", + help="Enables `-march=native` which generates code targeted to use all" + "features of the current machine. 
(not supported on Windows)", + ) + + # Enabling MKL DNN will add --config=mkl_open_source_only to the Bazel + # command. + parser.add_argument( + "--enable_mkl_dnn", + action="store_true", + help="Enables MKL-DNN.", + ) + + # Additional startup options to pass to Bazel. + parser.add_argument( + "--bazel_startup_options", + type=str, + default="", + help="Space-separated list of additional startup options to pass to Bazel. " + "E.g. --bazel_startup_options='--nobatch --noclient_debug'" + ) + + # Additional build options to pass to Bazel. + parser.add_argument( + "--bazel_build_options", + type=str, + default="", + help="Space-separated list of additional build options to pass to Bazel. " + "E.g. --bazel_build_options='--local_resources=HOST_CPUS --nosandbox_debug'" + ) + + # Directory in which artifacts should be stored. + parser.add_argument( + "--output_dir", + type=str, + default=os.environ.get("JAXCI_OUTPUT_DIR", os.path.join(os.getcwd(), "dist")), + help="Directory in which artifacts should be stored." + ) + + parser.add_argument( + "--requirements_update", + action="store_true", + help="If true, writes a .bazelrc and updates requirements_lock.txt for a " + "corresponding version of Python but does not build any artifacts." + ) + + parser.add_argument( + "--requirements_nightly_update", + action="store_true", + help="Same as --requirements_update, but will consider dev, nightly and " + "pre-release versions of packages." + ) + + # Use to invoke a dry run of the build. This will print the Bazel command that + # will be invoked but will not execute it. + parser.add_argument( + "--dry_run", + action="store_true", + help="Prints the Bazel command that will be invoked without executing it.", + ) + + # Use to enable verbose logging. + parser.add_argument( + "--verbose", + action="store_true", + help="Produce verbose output for debugging.", + ) + +async def main(): + parser = argparse.ArgumentParser( + description=( + "CLI for building one of the following packages from source: jaxlib, " + "jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, jax-rocm-pjrt." 
+ ), + epilog=EPILOG, + ) + + # Create subparsers for jax, jaxlib, plugin, pjrt + subparsers = parser.add_subparsers( + dest="command", required=True, help="Artifact to build" + ) + + # jaxlib subcommand + jaxlib_parser = subparsers.add_parser("jaxlib", help="Builds the jaxlib package.") + add_global_arguments(jaxlib_parser) + add_python_version_argument(jaxlib_parser) + + # jax-cuda-plugin subcommand + cuda_plugin_parser = subparsers.add_parser("jax-cuda-plugin", help="Builds the jax-cuda-plugin package.") + add_global_arguments(cuda_plugin_parser) + add_python_version_argument(cuda_plugin_parser) + add_cuda_version_argument(cuda_plugin_parser) + add_cudnn_version_argument(cuda_plugin_parser) + add_cuda_compute_capabilities_argument(cuda_plugin_parser) + add_disable_nccl_argument(cuda_plugin_parser) + + # jax-cuda-pjrt subcommand + cuda_pjrt_parser = subparsers.add_parser("jax-cuda-pjrt", help="Builds the jax-cuda-pjrt package.") + add_global_arguments(cuda_pjrt_parser) + add_cuda_version_argument(cuda_pjrt_parser) + add_cudnn_version_argument(cuda_pjrt_parser) + add_cuda_compute_capabilities_argument(cuda_pjrt_parser) + add_disable_nccl_argument(cuda_pjrt_parser) + + # jax-rocm-plugin subcommand + rocm_plugin_parser = subparsers.add_parser("jax-rocm-plugin", help="Builds the jax-rocm-plugin package.") + add_global_arguments(rocm_plugin_parser) + add_python_version_argument(rocm_plugin_parser) + add_rocm_version_argument(rocm_plugin_parser) + add_rocm_amdgpu_targets_argument(rocm_plugin_parser) + add_rocm_path_argument(rocm_plugin_parser) + add_disable_nccl_argument(rocm_plugin_parser) + + # jax-rocm-pjrt subcommand + rocm_pjrt_parser = subparsers.add_parser("jax-rocm-pjrt", help="Builds the jax-rocm-pjrt package.") + add_global_arguments(rocm_pjrt_parser) + add_rocm_version_argument(rocm_pjrt_parser) + add_rocm_amdgpu_targets_argument(rocm_pjrt_parser) + add_rocm_path_argument(rocm_pjrt_parser) + add_disable_nccl_argument(rocm_pjrt_parser) + + # Get the host systems architecture + arch = platform.machine().lower() + # Get the host system OS + os_name = platform.system().lower() + + args = parser.parse_args() + + logger.info("%s", BANNER) + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + logger.info("Verbose logging enabled.") + + logger.info( + "Building %s for %s %s...", + args.command, + os_name, + arch, + ) + + # Find the path to Bazel + bazel_path = tools.get_bazel_path(args.bazel_path) + + executor = command.SubprocessExecutor() + + # Start constructing the Bazel command + bazel_command = command.CommandBuilder(bazel_path) + + if args.bazel_startup_options: + logging.debug("Using additional Bazel startup options: %s", args.bazel_startup_options) + parse_and_append_bazel_options(bazel_command, args.bazel_startup_options) + + # Temporary; when we make the new scripts as the default we can remove this. + bazel_command.append("--bazelrc=ci/.bazelrc") + + # If the user requested a requirements update, construct the command and + # execute it. Exit without building any artifacts. 
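+  # For example, a hypothetical `python ci/cli/build.py jaxlib
+  # --requirements_update --python_version=3.11` resolves to roughly
+  # `bazel --bazelrc=ci/.bazelrc run --repo_env=HERMETIC_PYTHON_VERSION=3.11
+  # //build:requirements.update`.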
+ if args.requirements_update or args.requirements_nightly_update: + python_version = args.python_version if hasattr(args, "python_version") else "" + construct_requirements_update_command(bazel_command, args.bazel_build_options, python_version, args.requirements_nightly_update) + await executor.run(bazel_command.command, args.dry_run) + sys.exit(0) + + bazel_command.append("run") + + if args.enable_native_arch_features: + logging.debug("Enabling native target CPU features.") + bazel_command.append("--config=native_arch_posix") + + if args.enable_mkl_dnn: + logging.debug("Enabling MKL DNN.") + bazel_command.append("--config=mkl_open_source_only") + + if hasattr(args, "disable_nccl") and args.disable_nccl: + logging.debug("Disabling NCCL.") + bazel_command.append("--config=nonccl") + + # Set Clang as the C++ compiler if requested. If Clang cannot be found, the + # build will continue without Clang and instead use the system default. + if args.use_clang or args.clang_path: + set_clang_as_compiler(bazel_command, args.clang_path) + + if args.mode == "ci": + logging.debug("Running in CI mode. Run the CLI with --help for more details on what this means.") + + # JAX's .bazelrc has custom configs for each build type, architecture, and + # OS. Fetch the appropriate config and pass it to Bazel. A special case is + # when building for Linux Aarch64, which does not have a custom local config + # in JAX's .bazelrc. In this case, we build with the default configs. + # When building ROCm packages, we only use `--config=rocm` from .bazelrc. + bazelrc_config = get_bazelrc_config(os_name, arch, args.command, args.mode, args.use_rbe) + if bazelrc_config: + logging.debug("Using --config=%s from .bazelrc", bazelrc_config) + bazel_command.append(f"--config={bazelrc_config}") + + # Check if a local XLA path is set. + # When building artifacts for running tests, we clone XLA at HEAD into + # JAXCI_XLA_GIT_DIR and use that for building the artifacts. + if args.local_xla_path: + logging.debug("Setting local XLA path to %s", args.local_xla_path) + bazel_command.append(f"--override_repository=xla={args.local_xla_path}") + + # Set the Hermetic Python version. + if hasattr(args, "python_version"): + logging.debug("Setting Hermetic Python version to %s", args.python_version) + bazel_command.append(f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}") + else: + # While pjrt packages do not use the Python version, we set the default + # to 3.11 because Hermetic Python uses the system default if no Python + # version is set. On the Linux Arm64 Docker image, the system default is + # Python 3.9 which is not supported by JAX. + # TODO(srnitin): Update the Docker images so that we can remove this. + bazel_command.append("--repo_env=HERMETIC_PYTHON_VERSION=3.11") + + # Set the CUDA and cuDNN versions if they are not the default. Default values + # are set in the .bazelrc. 
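+  # For instance, a hypothetical `--cuda_version=12.4.0 --cudnn_version=9.2.0`
+  # adds `--repo_env=HERMETIC_CUDA_VERSION=12.4.0` and
+  # `--repo_env=HERMETIC_CUDNN_VERSION=9.2.0` to the Bazel command below.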
+ if "cuda" in args.command: + if args.cuda_version != "12.3.2": + logging.debug("Setting Hermetic CUDA version to %s", args.cuda_version) + bazel_command.append(f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}") + if args.cudnn_version != "9.1.1": + logging.debug("Setting Hermetic cuDNN version to %s", args.cudnn_version) + bazel_command.append(f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}") + if args.cuda_compute_capabilities: + logging.debug("Setting CUDA compute capabilities to %s", args.cuda_compute_capabilities) + bazel_command.append(f"--repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}") + + # If building ROCM packages, set the ROCm path and ROCm AMD GPU targets. + if "rocm" in args.command: + if args.rocm_path: + logging.debug("Setting ROCm path to %s", args.rocm_path) + bazel_command.append(f"--action_env ROCM_PATH='{args.rocm_path}'") + if args.rocm_amdgpu_targets: + logging.debug("Setting ROCm AMD GPU targets to %s", args.rocm_amdgpu_targets) + bazel_command.append(f"--action_env TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}") + + # Append any user specified Bazel build options. + if args.bazel_build_options: + logging.debug("Using additional Bazel build options: %s", args.bazel_build_options) + parse_and_append_bazel_options(bazel_command, args.bazel_build_options) + + # Append the build target to the Bazel command. + build_target, wheel_binary = ARTIFACT_BUILD_TARGET_DICT[args.command] + bazel_command.append(build_target) + + # Read output directory. Default is store the artifacts in the "dist/" + # directory in JAX's GitHub repository root. + output_dir = args.output_dir + + # If running on Windows, adjust the paths for compatibility. + if os_name == "windows": + output_dir, arch = adjust_paths_for_windows( + output_dir, arch + ) + + logger.debug("Storing artifacts in %s", output_dir) + + bazel_command.append("--") + + if args.editable: + logger.debug("Building an editable build.") + output_dir = os.path.join(output_dir, args.command) + bazel_command.append("--editable") + + bazel_command.append(f"--output_path={output_dir}") + bazel_command.append(f"--cpu={arch}") + + if "cuda" in args.command: + bazel_command.append("--enable-cuda=True") + major_cuda_version = args.cuda_version.split(".")[0] + bazel_command.append(f"--platform_version={major_cuda_version}") + + if "rocm" in args.command: + bazel_command.append("--enable-rocm=True") + bazel_command.append(f"--platform_version={args.rocm_version}") + + jaxlib_git_hash = get_jaxlib_git_hash() + bazel_command.append(f"--jaxlib_git_hash={jaxlib_git_hash}") + + # Execute the wheel build command. + await executor.run(bazel_command.command, args.dry_run) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/ci/cli/helpers/command.py b/ci/cli/helpers/command.py new file mode 100644 index 000000000000..e04153161ee6 --- /dev/null +++ b/ci/cli/helpers/command.py @@ -0,0 +1,102 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Helper script for running subprocess commands. +import asyncio +import dataclasses +import datetime +import os +import logging +from typing import Dict, Optional + +logger = logging.getLogger(__name__) + +class CommandBuilder: + def __init__(self, base_command: str): + self.command = base_command + + def append(self, parameter: str): + self.command += " {}".format(parameter) + return self + +@dataclasses.dataclass +class CommandResult: + """ + Represents the result of executing a subprocess command. + """ + + command: str + return_code: int = 2 # Defaults to not successful + logs: str = "" + start_time: datetime.datetime = dataclasses.field( + default_factory=datetime.datetime.now + ) + end_time: Optional[datetime.datetime] = None + +class SubprocessExecutor: + """ + Manages execution of subprocess commands with reusable environment and logging. + """ + + def __init__(self, environment: Dict[str, str] = None): + """ + + Args: + environment: + """ + self.environment = environment or dict(os.environ) + + async def run(self, cmd: str, dry_run: bool = False) -> CommandResult: + """ + Executes a subprocess command. + + Args: + cmd: The command to execute. + dry_run: If True, prints the command instead of executing it. + + Returns: + A CommandResult instance. + """ + result = CommandResult(command=cmd) + if dry_run: + logger.info("[DRY RUN] %s", cmd) + result.return_code = 0 # Dry run is a success + return result + + logger.info("[EXECUTING] %s", cmd) + + process = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=self.environment, + ) + + async def log_stream(stream, result: CommandResult): + while True: + line_bytes = await stream.readline() + if not line_bytes: + break + line = line_bytes.decode().rstrip() + result.logs += line + logger.info("%s", line) + + await asyncio.gather( + log_stream(process.stdout, result), log_stream(process.stderr, result) + ) + + result.return_code = await process.wait() + result.end_time = datetime.datetime.now() + logger.debug("Command finished with return code %s", result.return_code) + return result diff --git a/ci/cli/helpers/tools.py b/ci/cli/helpers/tools.py new file mode 100644 index 000000000000..f29c69a9e810 --- /dev/null +++ b/ci/cli/helpers/tools.py @@ -0,0 +1,160 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Helper script for setting up the tools used by the CLI. 
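+# For example, a `bazel --version` that prints "bazel 6.5.0" is only accepted
+# if .bazelversion also contains "6.5.0" (the version number here is
+# illustrative); otherwise Bazelisk v1.21.0 is downloaded, checksum-verified,
+# and used as the `bazel` binary instead.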
+import collections +import hashlib +import logging +import os +import platform +import shutil +import subprocess +import urllib.request + +logger = logging.getLogger(__name__) + +BAZELISK_BASE_URI = ( + "https://github.com/bazelbuild/bazelisk/releases/download/v1.21.0/" +) + +BazeliskPackage = collections.namedtuple("BazeliskPackage", ["file", "sha256"]) + +BAZELISK_PACKAGES = { + ("Linux", "x86_64"): BazeliskPackage( + file="bazelisk-linux-amd64", + sha256=( + "655a5c675dacf3b7ef4970688b6a54598aa30cbaa0b9e717cd1412c1ef9ec5a7" + ), + ), + ("Linux", "aarch64"): BazeliskPackage( + file="bazelisk-linux-arm64", + sha256=( + "ff793b461968e30d9f954c080f4acaa557edbdeab1ce276c02e4929b767ead66" + ), + ), + ("Darwin", "x86_64"): BazeliskPackage( + file="bazelisk-darwin", + sha256=( + "07ba3d6b90c28984237a6273f6b7de2fd714a1e3a65d1e78f9b342675ecb75e4" + ), + ), + ("Darwin", "arm64"): BazeliskPackage( + file="bazelisk-darwin-arm64", + sha256=( + "17529faeed52219ee170d59bd820c401f1645a95f95ee4ac3ebd06972edfb6ff" + ), + ), + ("Windows", "AMD64"): BazeliskPackage( + file="bazelisk-windows-amd64.exe", + sha256=( + "3da1895614f460692635f8baa0cab6bb35754fc87d9badbd2b3b2ba55873cf89" + ), + ), +} + +def guess_clang_paths(clang_path_flag): + """ + Yields a sequence of guesses about Clang path. Some of sequence elements + can be None. The resulting iterator is lazy and potentially has a side + effects. + """ + + yield clang_path_flag + yield shutil.which("clang") + +def get_clang_path(clang_path_flag): + for clang_path in guess_clang_paths(clang_path_flag): + if clang_path: + absolute_clang_path = os.path.realpath(clang_path) + logger.debug("Found path to Clang: %s.", absolute_clang_path) + return absolute_clang_path + +def get_jax_supported_bazel_version(filename: str = ".bazelversion"): + """Reads the contents of .bazelversion into a string. + + Args: + filename: The path to ".bazelversion". + + Returns: + The Bazel version as a string, or None if the file doesn't exist. + """ + try: + with open(filename, 'r') as file: + content = file.read() + return content.strip() + except FileNotFoundError: + print(f"Error: File '{filename}' not found.") + return None + +def get_bazel_path(bazel_path_flag): + for bazel_path in guess_bazel_paths(bazel_path_flag): + if bazel_path and verify_bazel_version(bazel_path): + logger.debug("Found a compatible Bazel installation.") + return bazel_path + logger.debug("Unable not find a compatible Bazel installation. Downloading Bazelisk...") + return download_and_verify_bazelisk() + +def verify_bazel_version(bazel_path): + """ + Verifies if the version of Bazel is compatible with JAX's required Bazel + version. + """ + system_bazel_version = subprocess.check_output([bazel_path, "--version"]).strip().decode('UTF-8') + # `bazel --version` returns the version as "bazel a.b.c" so we split the + # result to get only the version numbers. + system_bazel_version = system_bazel_version.split(" ")[1] + expected_bazel_version = get_jax_supported_bazel_version() + if expected_bazel_version != system_bazel_version: + logger.debug("Bazel version mismatch. JAX requires %s but got %s when `%s --version` was run", expected_bazel_version, system_bazel_version, bazel_path) + return False + return True + +def guess_bazel_paths(bazel_path_flag): + """ + Yields a sequence of guesses about bazel path. Some of sequence elements + can be None. The resulting iterator is lazy and potentially has a side + effects. 
+ """ + yield bazel_path_flag + # For when Bazelisk was downloaded and is present on the root JAX directory + yield shutil.which("./bazel") + yield shutil.which("bazel") + +def download_and_verify_bazelisk(): + """Downloads and verifies Bazelisk.""" + system = platform.system() + machine = platform.machine() + downloaded_filename = "bazel" + expected_sha256 = BAZELISK_PACKAGES[system, machine].sha256 + + # Download Bazelisk and store it as "bazel". + logger.debug("Downloading Bazelisk...") + _, _ = urllib.request.urlretrieve(BAZELISK_BASE_URI + BAZELISK_PACKAGES[system, machine].file, downloaded_filename) + + with open(downloaded_filename, "rb") as downloaded_file: + contents = downloaded_file.read() + + calculated_sha256 = hashlib.sha256(contents).hexdigest() + + # Verify checksum + logger.debug("Verifying the checksum...") + if calculated_sha256 != expected_sha256: + raise ValueError("SHA256 checksum mismatch. Download may be corrupted.") + logger.debug("Checksum verified!") + + logger.debug("Setting the Bazelisk binary to executable mode...") + subprocess.run(["chmod", "+x", downloaded_filename], check=True) + + return os.path.realpath(downloaded_filename) + diff --git a/ci/envs/build_artifacts/jax b/ci/envs/build_artifacts/jax new file mode 100644 index 000000000000..2ac4a6e834ae --- /dev/null +++ b/ci/envs/build_artifacts/jax @@ -0,0 +1,23 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default + +# Build JAX artifact. +export JAXCI_BUILD_JAX="1" + +# Note Python version of the container does not matter as `jax` is a pure +# Python package. +export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12" \ No newline at end of file diff --git a/ci/envs/build_artifacts/jax-cuda-pjrt b/ci/envs/build_artifacts/jax-cuda-pjrt new file mode 100644 index 000000000000..a5c91f2060a0 --- /dev/null +++ b/ci/envs/build_artifacts/jax-cuda-pjrt @@ -0,0 +1,47 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default + +# Enable jax-cuda-pjrt build. +export JAXCI_BUILD_PJRT="1" + +# Enable wheel audit to check for manylinux compliance. 
+export JAXCI_RUN_AUDITWHEEL="1" + +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# Linux x86 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then + # Note Python version of the container does not matter for Bazel builds and + # Bazel tests. JAX supports hermetic Python and thus the actual Python version + # of the artifact is controlled by the value set in `HERMETIC_PYTHON_VERSION`. + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12" +fi + +# Linux Aarch64 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:tf-2-18-multi-python" +fi + +# When building artifacts for running tests, we need to disable x64 mode and +# clone XLA at HEAD. +if [[ $JAXCI_SETUP_TEST_ENVIRONMENT == 1 ]]; then + # Disable x64 mode + export JAX_ENABLE_X64=0 + # Clone XLA at HEAD. + export JAXCI_CLONE_MAIN_XLA=1 +fi diff --git a/ci/envs/build_artifacts/jax-cuda-plugin b/ci/envs/build_artifacts/jax-cuda-plugin new file mode 100644 index 000000000000..49e11dffd71f --- /dev/null +++ b/ci/envs/build_artifacts/jax-cuda-plugin @@ -0,0 +1,47 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default + +# Enable jax-cuda-plugin build +export JAXCI_BUILD_PLUGIN="1" + +# Enable wheel audit to check for manylinux compliance. +export JAXCI_RUN_AUDITWHEEL="1" + +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# Linux x86 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then + # Note Python version of the container does not matter for Bazel builds and + # Bazel tests. JAX supports hermetic Python and thus the actual Python version + # of the artifact is controlled by the value set in `HERMETIC_PYTHON_VERSION`. + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12" +fi + +# Linux Aarch64 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:tf-2-18-multi-python" +fi + +# When building artifacts for running tests, we need to disable x64 mode and +# clone XLA at HEAD. +if [[ $JAXCI_SETUP_TEST_ENVIRONMENT == 1 ]]; then + # Disable x64 mode + export JAX_ENABLE_X64=0 + # Clone XLA at HEAD. + export JAXCI_CLONE_MAIN_XLA=1 +fi diff --git a/ci/envs/build_artifacts/jaxlib b/ci/envs/build_artifacts/jaxlib new file mode 100644 index 000000000000..cfb99c7f79a8 --- /dev/null +++ b/ci/envs/build_artifacts/jaxlib @@ -0,0 +1,60 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default + +# Enable jaxlib build. +export JAXCI_BUILD_JAXLIB="1" + +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# Linux x86 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then + # Enable wheel audit to check for manylinux compliance. + export JAXCI_RUN_AUDITWHEEL=1 + + # Note Python version of the container does not matter for Bazel builds and + # Bazel tests. JAX supports hermetic Python and thus the actual Python version + # of the artifact is controlled by the value set in `HERMETIC_PYTHON_VERSION`. + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12" +fi + +# Linux Aarch64 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then + # Enable wheel audit to check for manylinux compliance. + export JAXCI_RUN_AUDITWHEEL=1 + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:jax-" +fi + +# Windows specific settings +if [[ $os =~ "msys_nt" ]]; then + export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc" +fi + +# Mac specific settings +if [[ $os == "macos" ]]; then + # Mac builds do not run in Docker. + export JAXCI_RUN_DOCKER_CONTAINER=0 +fi + +# When building artifacts for running tests, we need to disable x64 mode and +# clone XLA at HEAD. +if [[ $JAXCI_SETUP_TEST_ENVIRONMENT == 1 ]]; then + # Disable x64 mode + export JAX_ENABLE_X64=0 + # Clone XLA at HEAD. + export JAXCI_CLONE_MAIN_XLA=1 +fi \ No newline at end of file diff --git a/ci/envs/default b/ci/envs/default new file mode 100644 index 000000000000..c4be6a36fe75 --- /dev/null +++ b/ci/envs/default @@ -0,0 +1,115 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file contains all the default values for the environment variables +# used in the JAX CI scripts. +# +# The default values are set here. Other build specifc envs such as those in the +# "build_artifacts" and "run_tests" directory source this file and override the +# default values depening on the build type. + +# This is expected to be the root of the JAX git repository. +export JAXCI_JAX_GIT_DIR=$(pwd) + +# Controls the version of Hermetic Python to use. Use system default if not +# set. 
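+# For example, a system `python3 -V` of "Python 3.12.4" resolves to "3.12"
+# below unless JAXCI_HERMETIC_PYTHON_VERSION is already set.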
+export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{print $2}' | awk -F. '{print $1"."$2}')} + +# Controls the location where the artifacts are stored. +export JAXCI_OUTPUT_DIR="$(pwd)/dist" + +# Release tag to use for the build. +export JAXCI_RELEASE_TAG="${JAXCI_RELEASE_TAG:-}" + +# ############################################################################# +# Artifact build specific environment variables. +# ############################################################################# + +# The build CLI can be run in either "ci" or "local" mode. This is used to +# determine which .bazelrc configs to pass to Bazel. CI mode uses JAX's custom +# toolchain and local mode uses the default Bazel toolchain. +export JAXCI_CLI_BUILD_MODE=ci + +# If set to 1, the build CLI will use RBE to build the artifacts. Available for +# Linux x86 and Windows. RBE requires authentication to JAX's GCP project so +# only CI builds and Googlers can use RBE. +export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} + +# Environment variables that control which artifact to build. Used by +# `build_artifacts.sh` +export JAXCI_BUILD_JAX=0 +export JAXCI_BUILD_JAXLIB=0 +export JAXCI_BUILD_PLUGIN=0 +export JAXCI_BUILD_PJRT=0 +export JAXCI_RUN_AUDITWHEEL=0 + +# ############################################################################# +# Docker specific environment variables. +# ############################################################################# + +# Docker specifc environment variables. Used by `run_docker_container.sh` +export JAXCI_RUN_DOCKER_CONTAINER=${JAXCI_RUN_DOCKER_CONTAINER:-1} +export JAXCI_DOCKER_WORK_DIR="/jax" +export JAXCI_DOCKER_IMAGE="" +export JAXCI_DOCKER_ARGS="" + +# ############################################################################# +# Test specific environment variables. +# ############################################################################# + +# Used by envs inside ci/build_artifacts. When set to 1, we disable x64 mode +# and clone XLA at HEAD. +export JAXCI_SETUP_TEST_ENVIRONMENT=${JAXCI_SETUP_TEST_ENVIRONMENT:-0} + +# Set when running tests locally where we need the wheels to be installed on +# the system. +export JAXCI_INSTALL_WHEELS_LOCALLY=0 + +# JAXCI_PYTHON is used to install the wheels locally. It needs to match the +# version of the hermetic Python used by Bazel. +export JAXCI_PYTHON=python${JAXCI_HERMETIC_PYTHON_VERSION} + +# Bazel test environment variables. +export JAXCI_RUN_BAZEL_TEST_CPU=0 +export JAXCI_RUN_BAZEL_TEST_GPU_LOCAL=0 +export JAXCI_RUN_BAZEL_TEST_GPU_RBE=0 + +# Pytest environment variables. +export JAXCI_RUN_PYTEST_CPU=0 +export JAXCI_RUN_PYTEST_GPU=0 +export JAXCI_RUN_PYTEST_TPU=0 +export JAXCI_TPU_CORES="" + +# If set to 1, the script will clone the main XLA repository at HEAD, set its +# path in JAXCI_XLA_GIT_DIR and use it to build the artifacts or run the tests. +export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0} + +# Enable this globally across all builds. +export JAX_SKIP_SLOW_TESTS=true + +# ############################################################################# +# Variables that can be overridden by the user. +# ############################################################################# +# Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository if you want to +# use a local copy of XLA instead of the pinned version in the WORKSPACE. +export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-} + +# Set JAXCI_XLA_COMMIT to the commit to use for the XLA repository. 
Requires +# the path to the local copy of XLA to be set in JAXCI_XLA_GIT_DIR. +export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} + +# When running tests locally, JAX artifacts are built with CUDA 12.3 and then +# tested with a Docker image with CUDA 12.3 and CUDA 12.1. By default, we set +# the CUDA version of the Docker image to 12.3. +export JAXCI_DOCKER_CUDA_VERSION=${JAXCI_DOCKER_CUDA_VERSION:-12.3} \ No newline at end of file diff --git a/ci/envs/run_tests/bazel_cpu b/ci/envs/run_tests/bazel_cpu new file mode 100644 index 000000000000..80c531696dbf --- /dev/null +++ b/ci/envs/run_tests/bazel_cpu @@ -0,0 +1,38 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default + +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# Enable Bazel CPU tests. +export JAXCI_RUN_BAZEL_TEST_CPU=1 + +# Linux x86 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then + export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython" +fi + +# Linux Aarch64 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:tf-2-18-multi-python" +fi + +# Disable x64 mode +export JAX_ENABLE_X64=0 + +# Clone XLA at HEAD. +export JAXCI_CLONE_MAIN_XLA=1 \ No newline at end of file diff --git a/ci/envs/run_tests/bazel_gpu_local b/ci/envs/run_tests/bazel_gpu_local new file mode 100644 index 000000000000..cf8922773eb4 --- /dev/null +++ b/ci/envs/run_tests/bazel_gpu_local @@ -0,0 +1,34 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default + +# Enable local Bazel GPU tests +export JAXCI_RUN_BAZEL_TEST_GPU_LOCAL=1 + +# Only Linux x86 runs local GPU tests at the moment. +export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda${JAXCI_DOCKER_CUDA_VERSION}-cudnn9.1-ubuntu20.04-manylinux2014-multipython" +export JAXCI_DOCKER_ARGS="--shm-size=16g --gpus all" + +export NCCL_DEBUG=WARN + +# Disable x64 mode +export JAX_ENABLE_X64=0 + +# Clone XLA at HEAD. 
+export JAXCI_CLONE_MAIN_XLA=1 + +# Set per device memory limit to 20.5 GB +export TF_PER_DEVICE_MEMORY_LIMIT_MB=20480 diff --git a/ci/envs/run_tests/bazel_gpu_rbe b/ci/envs/run_tests/bazel_gpu_rbe new file mode 100644 index 000000000000..ff7c9983a8d6 --- /dev/null +++ b/ci/envs/run_tests/bazel_gpu_rbe @@ -0,0 +1,33 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default + +# Enable Bazel GPU tests with RBE +export JAXCI_RUN_BAZEL_TEST_GPU_RBE=1 + +# Only Linux x86 runs local GPU tests at the moment. +export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython" +export JAXCI_DOCKER_ARGS="--gpus all" + +# TODO(srnitin): Figure out where this gets used +export JAX_CUDA_VERSION=12 +export JAX_CUDNN_VERSION=9.1 + +# Disable x64 mode +export JAX_ENABLE_X64=0 + +# Clone XLA at HEAD. +export JAXCI_CLONE_MAIN_XLA=1 \ No newline at end of file diff --git a/ci/envs/run_tests/pytest_cpu b/ci/envs/run_tests/pytest_cpu new file mode 100644 index 000000000000..a49e68330ea5 --- /dev/null +++ b/ci/envs/run_tests/pytest_cpu @@ -0,0 +1,30 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default + +# Enable CPU Pytests +export JAXCI_RUN_PYTEST_CPU=1 + +# Install jaxlib wheel locally. +export JAXCI_INSTALL_WHEELS_LOCALLY=1 + +# Disable x64 mode +export JAX_ENABLE_X64=0 + +# Clone XLA at HEAD. +export JAXCI_CLONE_MAIN_XLA=1 + +export TF_CPP_MIN_LOG_LEVEL=0 diff --git a/ci/envs/run_tests/pytest_gpu b/ci/envs/run_tests/pytest_gpu new file mode 100644 index 000000000000..1e9c4f2cff06 --- /dev/null +++ b/ci/envs/run_tests/pytest_gpu @@ -0,0 +1,39 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default + +# Enable GPU Pytests +export JAXCI_RUN_PYTEST_GPU=1 + +# Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels locally. +export JAXCI_INSTALL_WHEELS_LOCALLY=1 + +# Only Linux x86 runs local GPU tests at the moment. +export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda${JAXCI_DOCKER_CUDA_VERSION}-cudnn9.1-ubuntu20.04-manylinux2014-multipython" +export JAXCI_DOCKER_ARGS="--shm-size=16g --gpus all" + +# TODO(srnitin): Figure out where this gets used +export JAX_CUDA_VERSION=12 +export JAX_CUDA_FULL_VERSION=12.3 +export JAX_DOCKER_CUDA_FULL_VERSION=12.1 +export JAX_CUDNN_VERSION=9.1 +export JAX_CUDA_PLUGIN='True' + +# Disable x64 mode +export JAX_ENABLE_X64=0 + +# Clone XLA at HEAD. +export JAXCI_CLONE_MAIN_XLA=1 \ No newline at end of file diff --git a/ci/envs/run_tests/pytest_tpu b/ci/envs/run_tests/pytest_tpu new file mode 100644 index 000000000000..b5b23fa7f075 --- /dev/null +++ b/ci/envs/run_tests/pytest_tpu @@ -0,0 +1,28 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Inherit default JAXCI environment variables. +source ci/envs/default + +export JAXCI_INSTALL_WHEELS_LOCALLY=1 + +export JAXCI_RUN_PYTEST_TPU=1 + +export JAXCI_TPU_CORES=8 + +# Disable x64 mode +export JAX_ENABLE_X64=0 + +# Clone XLA at HEAD. +export JAXCI_CLONE_MAIN_XLA=1 diff --git a/ci/run_bazel_test.sh b/ci/run_bazel_test.sh new file mode 100755 index 000000000000..b03c07a69fb5 --- /dev/null +++ b/ci/run_bazel_test.sh @@ -0,0 +1,85 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +## +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Source JAXCI environment variables. +source "ci/utilities/setup_envs.sh" "$1" +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +if [[ $JAXCI_RUN_BAZEL_TEST_CPU == 1 ]]; then + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # If running on Mac or Linux Aarch64, we only build the test targets and + # not run them. These platforms do not have native RBE support so we + # cross-compile them on the Linux x86 RBE pool. 
As the tests still need + # to be run on the host machine and because running the tests on a single + # machine can take a long time, we skip running them on these platforms. + if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ); then + echo "Building RBE CPU tests..." + check_if_to_run_in_docker bazel --bazelrc=ci/.bazelrc build --config=rbe_cross_compile_${os}_${arch} \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + //tests:cpu_tests //tests:backend_independent_tests + else + echo "Running RBE CPU tests..." + check_if_to_run_in_docker bazel --bazelrc=ci/.bazelrc test --config=rbe_${os}_${arch} \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + //tests:cpu_tests //tests:backend_independent_tests + fi +fi + +# Run Bazel GPU tests locally. +if [[ $JAXCI_RUN_BAZEL_TEST_GPU_LOCAL == 1 ]]; then + check_if_to_run_in_docker nvidia-smi + echo "Running local GPU tests..." + + #check_if_to_run_in_docker "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + + # Only Linux x86 builds run GPU tests + # Runs non-multiaccelerator tests with one GPU apiece. + # It appears --run_under needs an absolute path. + check_if_to_run_in_docker bazel --bazelrc=ci/.bazelrc test --config=ci_linux_x86_64_cuda \ + --config=non_multiaccelerator_local \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --run_under "${JAXCI_JAX_GIT_DIR}/build/parallel_accelerator_execute.sh" --test_timeout=3000 \ + --test_env=JAX_PLATFORM_NAME="gpu" \ + //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests || true + echo "Finished running non-multiaccelerator tests..." + + # Runs multiaccelerator tests with all GPUs. + check_if_to_run_in_docker bazel --bazelrc=ci/.bazelrc test --config=ci_linux_x86_64_cuda \ + --config=multiaccelerator_local \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" --test_timeout=3000 \ + //tests:gpu_tests //tests/pallas:gpu_tests || true + echo "Finished running multiaccelerator tests..." +fi + +# Run Bazel GPU tests with RBE. +if [[ $JAXCI_RUN_BAZEL_TEST_GPU_RBE == 1 ]]; then + check_if_to_run_in_docker nvidia-smi + echo "Running RBE GPU tests..." + + # Only Linux x86 builds run GPU tests + # Runs non-multiaccelerator tests with one GPU apiece. + check_if_to_run_in_docker bazel --bazelrc=ci/.bazelrc test --config=rbe_linux_x86_64_cuda \ + --config=non_multiaccelerator \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests //docs/... +fi diff --git a/ci/run_pytest.sh b/ci/run_pytest.sh new file mode 100755 index 000000000000..feb8399e6de7 --- /dev/null +++ b/ci/run_pytest.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Source JAXCI environment variables. +source "ci/utilities/setup_envs.sh" "$1" +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +check_if_to_run_in_docker "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + +if [[ $JAXCI_RUN_PYTEST_CPU == 1 ]]; then + echo "Running CPU tests..." + check_if_to_run_in_docker "$JAXCI_PYTHON" -m pytest -n auto --tb=short --maxfail=20 tests examples +fi + +if [[ $JAXCI_RUN_PYTEST_GPU == 1 ]]; then + echo "Running GPU tests..." + export XLA_PYTHON_CLIENT_ALLOCATOR=platform + export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 + check_if_to_run_in_docker "$JAXCI_PYTHON" -m pytest -n 4 --tb=short --maxfail=20 \ + tests examples \ + --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ + --deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \ + --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \ + --deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric +fi + +if [[ $JAXCI_RUN_PYTEST_TPU == 1 ]]; then + echo "Running TPU tests..." + # Run single-accelerator tests in parallel + export JAX_ENABLE_TPU_XDIST=true + + check_if_to_run_in_docker "$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)' + check_if_to_run_in_docker "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ + --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ + --maxfail=20 -m "not multiaccelerator" tests examples + + # Run Pallas printing tests, which need to run with I/O capturing disabled. + export TPU_STDERR_LOG_LEVEL=0 + check_if_to_run_in_docker "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest + + # Run multi-accelerator across all chips + check_if_to_run_in_docker "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests +fi diff --git a/ci/utilities/convert_msys_paths_to_win_paths.py b/ci/utilities/convert_msys_paths_to_win_paths.py new file mode 100644 index 000000000000..c2b4a6bdc58f --- /dev/null +++ b/ci/utilities/convert_msys_paths_to_win_paths.py @@ -0,0 +1,74 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Converts MSYS Linux-like paths stored in env variables to Windows paths. 
+
+This is necessary on Windows because some applications (for example, Docker)
+do not understand or handle the Linux-like paths that MSYS uses.
+"""
+import argparse
+import os
+import subprocess
+import sys
+
+def msys_to_windows_path(msys_path):
+  """Converts an MSYS path to a Windows path using cygpath.
+
+  Args:
+    msys_path: The MSYS path to convert.
+
+  Returns:
+    The corresponding Windows path, or None if the conversion failed.
+  """
+  try:
+    # Use cygpath with the -w flag to convert to Windows format.
+    process = subprocess.run(['cygpath', '-w', msys_path], capture_output=True, text=True, check=True)
+    windows_path = process.stdout.strip()
+    return windows_path
+  except FileNotFoundError:
+    # Report errors on stderr so they are not executed when the output of this
+    # script is piped into `source`.
+    print("Error: cygpath not found. Make sure it's in your PATH.", file=sys.stderr)
+    return None
+  except subprocess.CalledProcessError as e:
+    print(f"Error converting path: {e}", file=sys.stderr)
+    return None
+
+def main(parsed_args: argparse.Namespace):
+  converted_paths = {}
+
+  for var, value in os.environ.items():
+    if parsed_args.blacklist and var in parsed_args.blacklist:
+      continue
+    if "_DIR" in var or (parsed_args.whitelist and var in parsed_args.whitelist):
+      converted_path = msys_to_windows_path(value)
+      # Skip variables whose conversion failed so we do not emit
+      # `export VAR="None"`.
+      if converted_path is not None:
+        converted_paths[var] = converted_path
+
+  var_str = '\n'.join(f'export {k}="{v}"'
+                      for k, v in converted_paths.items())
+  # The string can then be piped into `source`, to re-set the
+  # 'converted' variables.
+  print(var_str)
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser(description=(
+      'Convert MSYS paths in environment variables to Windows paths.'))
+  parser.add_argument('--blacklist',
+                      nargs='*',
+                      help='List of variables to ignore')
+  parser.add_argument('--whitelist',
+                      nargs='*',
+                      help='List of variables to include')
+  args = parser.parse_args()
+
+  main(args)
diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh
new file mode 100644
index 000000000000..de2e0fcc3874
--- /dev/null
+++ b/ci/utilities/install_wheels_locally.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Install wheels stored in `JAXCI_OUTPUT_DIR` locally using the Python binary
+# set in JAXCI_PYTHON. Use the absolute path to the `find` utility to avoid
+# using the Windows version of `find` on Windows.
+WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) )
+
+if [[ -z "$WHEELS" ]]; then
+  echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR"
+  exit 1
+fi
+
+echo "Installing the following wheels:"
+echo "${WHEELS[@]}"
+"$JAXCI_PYTHON" -m pip install "${WHEELS[@]}"
+
+echo "Installing the JAX package in editable mode at the current commit..."
+# Install the JAX package at the current commit.
+"$JAXCI_PYTHON" -m pip install -U -e .
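+
+# Optional sanity check (a sketch, not required by the install step): print the
+# installed jax and jaxlib versions so the CI logs show which wheels were
+# actually picked up. Assumes `jax` and `jaxlib` import cleanly under
+# JAXCI_PYTHON.
+"$JAXCI_PYTHON" -c "import jax, jaxlib.version; print('jax', jax.__version__, '/ jaxlib', jaxlib.version.__version__)"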
diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh new file mode 100755 index 000000000000..f22caf5b9b02 --- /dev/null +++ b/ci/utilities/run_auditwheel.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Runs auditwheel to ensure manylinux compatibility. + +# Get a list of all the wheels in the output directory. Only look for wheels +# that need to be verified for manylinux compliance. +WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) + +if [[ -z "$WHEELS" ]]; then + echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" + exit 1 +fi + +for wheel in $WHEELS; do + printf "\nRunning auditwheel on the following wheel:" + ls $wheel + OUTPUT_FULL=$(python3 -m auditwheel show $wheel) + # Remove the wheel name from the output to avoid false positives. + wheel_name=$(basename $wheel) + OUTPUT=${OUTPUT_FULL//${wheel_name}/} + + # If a wheel is manylinux2014 compliant, `auditwheel show` will return the + # platform tag as manylinux_2_17. manylinux2014 is an alias for + # manylinux_2_17. + if echo "$OUTPUT" | grep -q "manylinux_2_17"; then + printf "\n$wheel_name is manylinux2014 compliant.\n" + else + echo "$OUTPUT_FULL" + printf "\n$wheel_name is NOT manylinux2014 compliant.\n" + exit 1 + fi +done \ No newline at end of file diff --git a/ci/utilities/run_docker_container.sh b/ci/utilities/run_docker_container.sh new file mode 100644 index 000000000000..7377d8e31c9d --- /dev/null +++ b/ci/utilities/run_docker_container.sh @@ -0,0 +1,95 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Set up the Docker container and start it for JAX CI jobs. + +# Keep the existing "jax" container if it's already present. +if ! docker container inspect jax >/dev/null 2>&1 ; then + # Simple retry logic for docker-pull errors. Sleeps if a pull fails. + # Pulling an already-pulled container image will finish instantly, so + # repeating the command costs nothing. + docker pull "$JAXCI_DOCKER_IMAGE" || sleep 15 + docker pull "$JAXCI_DOCKER_IMAGE" + + if [[ "$(uname -s)" =~ "MSYS_NT" ]]; then + # Docker on Windows doesn't support the `host` networking mode, and so + # port-forwarding is required for the container to detect it's running on GCE. 
+    export IP_ADDR=$(powershell -command "(Get-NetIPAddress -AddressFamily IPv4 -InterfaceAlias 'vEthernet (nat)').IPAddress")
+    netsh interface portproxy add v4tov4 listenaddress=$IP_ADDR listenport=80 connectaddress=169.254.169.254 connectport=80
+    JAXCI_DOCKER_ARGS="$JAXCI_DOCKER_ARGS -e GCE_METADATA_HOST=$IP_ADDR"
+  else
+    # The volume mapping flag below shares the user's gcloud credentials, if any,
+    # with the container, in case the user has credentials stored there.
+    # This would allow Bazel to authenticate for RBE.
+    # Note: JAX's CI does not have any credentials stored there.
+    JAXCI_DOCKER_ARGS="$JAXCI_DOCKER_ARGS -v $HOME/.config/gcloud:/root/.config/gcloud"
+  fi
+
+  # If the XLA repository on the local system is to be used, map it into the
+  # container and set the JAXCI_XLA_GIT_DIR environment variable to the
+  # container path.
+  if [[ -n $JAXCI_XLA_GIT_DIR ]]; then
+    JAXCI_DOCKER_ARGS="$JAXCI_DOCKER_ARGS -v $JAXCI_XLA_GIT_DIR:$JAXCI_DOCKER_WORK_DIR/xla -e JAXCI_XLA_GIT_DIR=$JAXCI_DOCKER_WORK_DIR/xla"
+  fi
+
+  # Set the output directory to the container path.
+  export JAXCI_OUTPUT_DIR=$JAXCI_DOCKER_WORK_DIR/dist
+
+  # Capture the environment variables that get set by JAXCI_ENV_FILE and store
+  # them in a file. This is needed so that we know which envs to pass to the
+  # Docker container when we start it below. An easier solution would be to
+  # just grep for "JAXCI_" variables but unfortunately, this is not robust as
+  # there are some variables such as `JAX_ENABLE_X64`, `NCCL_DEBUG`, etc. that
+  # are used by JAX but do not have the `JAXCI_` prefix.
+  envs_after=$(mktemp)
+  env > "$envs_after"
+
+  jax_ci_envs=$(mktemp)
+
+  # Only get the new environment variables set by JAXCI_ENV_FILE. Use
+  # "envs_before", which gets set in setup_envs.sh, for the initial environment
+  # variables. diff exits with a non-zero return code when the files differ,
+  # which can end the build abruptly, so we use "|| true" to ignore the return
+  # code and continue.
+  diff <(sort "$envs_before") <(sort "$envs_after") | grep "^> " | sed 's/^> //' | grep -v "^BASH_FUNC" > "$jax_ci_envs" || true
+
+  # Start the container. `user_set_jaxci_envs` is read after `jax_ci_envs` to
+  # allow the user to override any environment variables set by JAXCI_ENV_FILE.
+  docker run --env-file $jax_ci_envs --env-file "$user_set_jaxci_envs" $JAXCI_DOCKER_ARGS --name jax \
+    -w $JAXCI_DOCKER_WORK_DIR -itd --rm \
+    -v "$JAXCI_JAX_GIT_DIR:$JAXCI_DOCKER_WORK_DIR" \
+    "$JAXCI_DOCKER_IMAGE" \
+  bash
+
+  if [[ "$(uname -s)" =~ "MSYS_NT" ]]; then
+    # Allow requests from the container.
+    CONTAINER_IP_ADDR=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' jax)
+    netsh advfirewall firewall add rule name="Allow Metadata Proxy" dir=in action=allow protocol=TCP localport=80 remoteip="$CONTAINER_IP_ADDR"
+  fi
+fi
+
+# Update `check_if_to_run_in_docker` to execute the commands inside the Docker
+# container.
+check_if_to_run_in_docker() { docker exec jax "$@"; }
+
+# Update `JAXCI_OUTPUT_DIR`, `JAXCI_JAX_GIT_DIR`, and `JAXCI_XLA_GIT_DIR` with
+# the new Docker paths in the host shell environment. This is needed because,
+# when running commands with `docker exec`, the command lines are still built
+# in the host shell environment, so these variables need to hold the Docker
+# paths.
+export JAXCI_OUTPUT_DIR=$JAXCI_DOCKER_WORK_DIR/dist
+export JAXCI_JAX_GIT_DIR=$JAXCI_DOCKER_WORK_DIR
+export JAXCI_XLA_GIT_DIR=$JAXCI_DOCKER_WORK_DIR/xla
+
+check_if_to_run_in_docker git config --global --add safe.directory $JAXCI_DOCKER_WORK_DIR
\ No newline at end of file
diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh
new file mode 100644
index 000000000000..34ddc86d3900
--- /dev/null
+++ b/ci/utilities/setup_build_environment.sh
@@ -0,0 +1,104 @@
+#!/bin/bash
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Set up the build environment for JAX CI jobs. This script depends on the
+# environment variables set in `setup_envs.sh`.
+# -e: abort script if one command fails
+# -u: error if undefined variable used
+# -x: log all commands
+# -o pipefail: entire command fails if pipe fails. watch out for yes | ...
+# -o history: record shell history
+# -o allexport: export all functions and variables to be available to sub-scripts
+set -exuo pipefail -o history -o allexport
+
+# Pre-emptively mark the git directory as safe. This is necessary for JAX CI
+# jobs running on Linux runners in GitHub Actions. Without this, git complains
+# that the directory has dubious ownership and refuses to run any commands.
+# Avoid running this on Windows runners, as git runs into issues with not being
+# able to lock the config file there. Other git commands seem to work on the
+# Windows runners, so we can skip this step.
+if [[ ! $(uname -s) =~ "MSYS_NT" ]]; then
+  git config --global --add safe.directory $JAXCI_JAX_GIT_DIR
+fi
+
+# When building release artifacts, check out the release tag. JAX CI jobs build
+# from the main branch by default.
+if [[ -n "$JAXCI_RELEASE_TAG" ]]; then
+  git checkout tags/"$JAXCI_RELEASE_TAG"
+fi
+
+# When running tests, check out XLA at HEAD unless a local XLA directory has
+# already been provided.
+if [[ -z ${JAXCI_XLA_GIT_DIR} ]] && [[ "$JAXCI_CLONE_MAIN_XLA" == 1 ]]; then
+  if [[ ! -d $(pwd)/xla ]]; then
+    echo "Cloning XLA at HEAD to $(pwd)/xla"
+    git clone --depth=1 https://github.com/openxla/xla.git $(pwd)/xla
+  fi
+  export JAXCI_XLA_GIT_DIR=$(pwd)/xla
+fi
+
+# If a path to XLA is provided, use that to build JAX or run tests.
+if [[ ! -z ${JAXCI_XLA_GIT_DIR} ]]; then
+  echo "Overriding XLA to be read from $JAXCI_XLA_GIT_DIR instead of the pinned"
+  echo "version in the WORKSPACE."
+  echo "If you would like to revert this behavior, unset JAXCI_XLA_GIT_DIR and"
+  echo "JAXCI_CLONE_MAIN_XLA in your environment."
+
+  # If an XLA commit is provided, check out XLA at that commit.
+  if [[ ! -z "$JAXCI_XLA_COMMIT" ]]; then
+    pushd "$JAXCI_XLA_GIT_DIR"
+
+    git fetch --depth=1 origin "$JAXCI_XLA_COMMIT"
+    echo "JAXCI_XLA_COMMIT is set. Checking out XLA at $JAXCI_XLA_COMMIT"
+    git checkout "$JAXCI_XLA_COMMIT"
+
+    popd
+  fi
+fi
+
+# Set up check_if_to_run_in_docker, a helper function for executing steps that
+# can either be run locally or run under Docker.
+# run_docker_container.sh, below, redefines it as "docker exec".
+# Important: "check_if_to_run_in_docker foo | bar" is
+# "( check_if_to_run_in_docker foo ) | bar", and
+# not "check_if_to_run_in_docker (foo | bar)".
+# Therefore, "check_if_to_run_in_docker" commands cannot include pipes -- which
+# is probably for the better. If a pipe is necessary for something, it is
+# probably complex. Write a well-documented script under utilities/ to
+# encapsulate the functionality instead.
+check_if_to_run_in_docker() { "$@"; }
+
+# For Windows, convert MSYS Linux-like paths to Windows paths.
+if [[ $(uname -s) =~ "MSYS_NT" ]]; then
+  echo 'Converting MSYS Linux-like paths to Windows paths (for Docker, Python, etc.)'
+  # Convert all "_DIR" variables to Windows paths.
+  source <(python3 ./ci/utilities/convert_msys_paths_to_win_paths.py)
+fi
+
+# Set up and run the Docker container if needed.
+# Jobs running on GitHub Actions do not invoke this script. They define the
+# Docker image via the `container` field in the workflow file.
+if [[ "$JAXCI_RUN_DOCKER_CONTAINER" == 1 ]]; then
+  echo "Setting up the Docker container..."
+  source ./ci/utilities/run_docker_container.sh
+fi
+
+# When running pytest jobs, we need to install the wheels locally.
+if [[ "$JAXCI_INSTALL_WHEELS_LOCALLY" == 1 ]]; then
+  echo "Installing wheels locally..."
+  source ./ci/utilities/install_wheels_locally.sh
+fi
+
+# TODO: cleanup steps
\ No newline at end of file
diff --git a/ci/utilities/setup_envs.sh b/ci/utilities/setup_envs.sh
new file mode 100644
index 000000000000..8b460267cab8
--- /dev/null
+++ b/ci/utilities/setup_envs.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Source JAXCI environment variables.
+
+# If the user has not passed in a JAXCI_ENV_FILE, exit.
+if [[ -z "$1" ]]; then
+  echo "ERROR: No argument passed."
+  echo "setup_envs.sh requires that a path to a JAX CI env file be passed as"
+  echo "an argument when invoking the build scripts."
+  echo "If you are looking to build JAX artifacts, please pass in a"
+  echo "corresponding env file from the ci/envs/build_artifacts directory."
+  echo "If you are looking to run JAX tests, please pass in a"
+  echo "corresponding env file from the ci/envs/run_tests directory."
+  exit 1
+fi
+
+# Get the current environment variables and any user-set JAXCI_ environment
+# variables. We store these in a file and pass them to the Docker container
+# when setting up the container in `run_docker_container.sh`.
+# Store the current environment variables.
+envs_before=$(mktemp)
+env > "$envs_before"
+
+# Read any JAXCI_ environment variables set by the user.
+user_set_jaxci_envs=$(mktemp)
+env | grep ^JAXCI_ > "$user_set_jaxci_envs"
+
+# -e: abort script if one command fails
+# -u: error if undefined variable used
+# -x: log all commands
+# -o pipefail: entire command fails if pipe fails. watch out for yes | ...
+# -o history: record shell history
+# -o allexport: export all functions and variables to be available to sub-scripts
+set -exuo pipefail -o history -o allexport
+source "$1"
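+
+# Illustrative usage (the env file name below is a hypothetical placeholder;
+# substitute a real file from the ci/envs/build_artifacts or ci/envs/run_tests
+# directories mentioned above):
+#
+#   ./ci/run_pytest.sh ci/envs/run_tests/pytest_cpu.env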