diff --git a/.github/workflows/build_jaxlib.yml b/.github/workflows/build_jaxlib.yml new file mode 100644 index 000000000000..80978bbbcfef --- /dev/null +++ b/.github/workflows/build_jaxlib.yml @@ -0,0 +1,30 @@ +name: Build Jaxlib + +# dummy; a small change so that "wait for connection" detects the github labels applied on the PR +on: + pull_request: + branches: + - main + +jobs: + build: + strategy: + matrix: + runner: ["arc-linux-x86-n2-64"] + + runs-on: ${{ matrix.runner }} + container: + image: "index.docker.io/tensorflow/build@sha256:e042f39c18a01012349ad42ae69f6c9fb6e8cc912ddf362d47f6caf1f97f54c3" + + env: + ENV_FILE: ci/envs/build_artifacts/jaxlib + JAXCI_USE_DOCKER: 0 + JAXCI_USE_RBE: 1 + + steps: + - uses: actions/checkout@v3 + # Halt for testing + - name: Wait For Connection + uses: ./actions/ci_connection/ + - name: Run build script + run: ./ci/build_artifacts.sh diff --git a/actions/ci_connection/notify_connection.py b/actions/ci_connection/notify_connection.py index 30e7f7905473..2d7fb4394056 100644 --- a/actions/ci_connection/notify_connection.py +++ b/actions/ci_connection/notify_connection.py @@ -14,6 +14,7 @@ import time import threading +import os import subprocess from multiprocessing.connection import Client @@ -42,8 +43,22 @@ def timer(conn): timer_thread.start() print("Entering interactive bash session") + + # Hard-coded for now for demo purposes. + next_command = "bash ci/build_artifacts.sh" + # Print the "next" commands to be run + # TODO: actually get this data from workflow files + print(f"The next command that would have run is:\n\n{next_command}") + + # Set the hardcoded envs for testing purposes + # TODO: sync env vars + sub_env = os.environ.copy() + sub_env["ENV_FILE"] = "ci/envs/build_artifacts/jaxlib" + sub_env["JAXCI_USE_DOCKER"] = "0" + sub_env["JAXCI_USE_RBE"] = "1" + # Enter interactive bash session - subprocess.run(["/bin/bash", "-i"]) + subprocess.run(["/bin/bash", "-i"], env=sub_env) print("Exiting interactive bash session") with lock: diff --git a/actions/ci_connection/wait_for_connection.py b/actions/ci_connection/wait_for_connection.py index e752d0afb747..24a3057d74b4 100644 --- a/actions/ci_connection/wait_for_connection.py +++ b/actions/ci_connection/wait_for_connection.py @@ -20,9 +20,9 @@ import sys last_time = time.time() -timeout = 600 # 10 minutes for initial connection +timeout = 1800 # 30 minutes for initial connection keep_alive_timeout = ( - 900 # 30 minutes for keep-alive if no closed message (allow for reconnects) + 3600 # 30 minutes for keep-alive if no closed message (allow for reconnects) ) @@ -102,7 +102,7 @@ def timer(): print("Googler connection only\nSee go/ for details") print( - f"Connection string: ml-actions-connect --runner={host} --ns={ns} --loc={location} --cluster={cluster} --halt_directory={actions_path}" + f"Connection string: ml-actions-connect --runner={host} --ns={ns} --loc={location} --cluster={cluster} --halt_directory={actions_path} --project=ml-velocity-actions-testing" ) # Thread is running as a daemon so it will quit when the diff --git a/ci/.bazelrc b/ci/.bazelrc new file mode 100644 index 000000000000..8843e7e978a3 --- /dev/null +++ b/ci/.bazelrc @@ -0,0 +1,469 @@ +# ############################################################################# +# All default build options below. These apply to all build commands. +# ############################################################################# +# Make Bazel print out all options from rc files. 
+build --announce_rc + +# Required by OpenXLA +# https://github.com/openxla/xla/issues/1323 +build --nocheck_visibility + +# By default, execute all actions locally. +build --spawn_strategy=local + +# Enable host OS specific configs. For instance, "build:linux" will be used +# automatically when building on Linux. +build --enable_platform_specific_config + +build --experimental_cc_shared_library + +# Disable enabled-by-default TensorFlow features that we don't care about. +build --define=no_gcp_support=true + +# Do not use C-Ares when building gRPC. +build --define=grpc_no_ares=true + +build --define=tsl_link_protobuf=true + +# Enable optimization. +build -c opt + +# Suppress all warning messages. +build --output_filter=DONT_MATCH_ANYTHING + +build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. + +build --verbose_failures=true + +# ############################################################################# +# Platform Specific configs below. These are automatically picked up by Bazel +# depending on the platform that is running the build. If you would like to +# disable this behavior, pass in `--noenable_platform_specific_config` +# ############################################################################# +build:linux --config=posix +build:linux --copt=-Wno-unknown-warning-option + +# Workaround for gcc 10+ warnings related to upb. +# See https://github.com/tensorflow/tensorflow/issues/39467 +build:linux --copt=-Wno-stringop-truncation +build:linux --copt=-Wno-array-parameter + +build:macos --config=posix +build:macos --apple_platform_type=macos + +# Windows has a relatively short command line limit, which JAX has begun to hit. +# See https://docs.bazel.build/versions/main/windows.html +build:windows --features=compiler_param_file +build:windows --features=archive_param_file + +# Tensorflow uses M_* math constants that only get defined by MSVC headers if +# _USE_MATH_DEFINES is defined. +build:windows --copt=/D_USE_MATH_DEFINES +build:windows --host_copt=/D_USE_MATH_DEFINES +# Make sure to include as little of windows.h as possible +build:windows --copt=-DWIN32_LEAN_AND_MEAN +build:windows --host_copt=-DWIN32_LEAN_AND_MEAN +build:windows --copt=-DNOGDI +build:windows --host_copt=-DNOGDI +# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/ +# otherwise, there will be some compiling error due to preprocessing. +build:windows --copt=/Zc:preprocessor +build:windows --cxxopt=/std:c++17 +build:windows --host_cxxopt=/std:c++17 +# Generate PDB files, to generate useful PDBs, in opt compilation_mode +# --copt /Z7 is needed. +build:windows --linkopt=/DEBUG +build:windows --host_linkopt=/DEBUG +build:windows --linkopt=/OPT:REF +build:windows --host_linkopt=/OPT:REF +build:windows --linkopt=/OPT:ICF +build:windows --host_linkopt=/OPT:ICF +build:windows --incompatible_strict_action_env=true + +# ############################################################################# +# Feature-specific configurations. These are used by the Local and CI configs +# below depending on the type of build. E.g. `local_linux_x86_64` inherits the +# Linux x86 configs such as `avx_linux` and `mkl_open_source_only`, +# `local_cuda_base` inherits `cuda` and `build_cuda_with_nvcc`, etc. 
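+#
+# As a rough illustration (the build target below is only an example), a local
+# CUDA build on Linux x86 invoked as
+#   bazel --bazelrc=ci/.bazelrc build --config=local_linux_x86_64_cuda //jaxlib/tools:build_wheel
+# resolves through local_linux_x86_64 (avx_linux, avx_posix,
+# mkl_open_source_only) and local_cuda_base (cuda, build_cuda_with_nvcc).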
+# ############################################################################# +build:nonccl --define=no_nccl_support=true + +build:posix --copt=-fvisibility=hidden +build:posix --copt=-Wno-sign-compare +build:posix --cxxopt=-std=c++17 +build:posix --host_cxxopt=-std=c++17 + +build:avx_posix --copt=-mavx +build:avx_posix --host_copt=-mavx + +build:native_arch_posix --copt=-march=native +build:native_arch_posix --host_copt=-march=native + +build:avx_linux --copt=-mavx +build:avx_linux --host_copt=-mavx + +build:avx_windows --copt=/arch:AVX + +build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 + +# Disable clang extention that rejects type definitions within offsetof. +# This was added in clang-16 by https://reviews.llvm.org/D133574. +# Can be removed once upb is updated, since a type definition is used within +# offset of in the current version of ubp. +# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. +build:clang --copt=-Wno-gnu-offsetof-extensions +# Disable clang extention that rejects unknown arguments. +build:clang --copt=-Qunused-arguments + +# Configs for CUDA +build:cuda --repo_env TF_NEED_CUDA=1 +build:cuda --repo_env TF_NCCL_USE_STUB=1 +# "sm" means we emit only cubin, which is forward compatible within a GPU generation. +# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" +build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain +build:cuda --@local_config_cuda//:enable_cuda +build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true + +# Default hermetic CUDA and CUDNN versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + +# This flag is needed to include CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_cuda_libs=true + +# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, +# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to +# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA +# packages. +# This has pros and cons: +# * pro: we'll ignore other CUDA installations, which has frequently confused +# users in the past. By setting RPATH, we'll always use the NVIDIA pip +# packages if they are installed. +# * con: the user cannot override the CUDA installation location +# via LD_LIBRARY_PATH, if the nvidia-... pip packages are installed. This is +# acceptable, because the workaround is "remove the nvidia-..." pip packages. +# The list of CUDA pip packages that JAX depends on are present in setup.py. 
+build:cuda --linkopt=-Wl,--disable-new-dtags + +# Build CUDA and other C++ targets with Clang +build:build_cuda_with_clang --@local_config_cuda//:cuda_compiler=clang + +# Build CUDA with NVCC and other C++ targets with Clang +build:build_cuda_with_nvcc --action_env=TF_NVCC_CLANG="1" +build:build_cuda_with_nvcc --@local_config_cuda//:cuda_compiler=nvcc + +# Requires MSVC and LLVM to be installed +build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl +build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl +build:win_clang --compiler=clang-cl + +# Configs for building ROCM +build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain +build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true +build:rocm --repo_env TF_NEED_ROCM=1 +build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030" + +# ############################################################################# +# Cache options below. +# ############################################################################# +# Public read-only cache for macOS builds. The "oct2023" in the URL is just the +# date when the bucket was created and can be disregarded. It still contains the +# latest cache that is being used. +build:macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false +# Cache pushes are limited to Jax's CI system. +build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials + +# ############################################################################# +# Local Build config options below. Use these configs to build JAX locally. +# ############################################################################# +# Set base CUDA configs. These are inherited by the Linux x86 and Linux Aarch64 +# CUDA configs. +build:local_cuda_base --config=cuda + +# JAX uses NVCC to build CUDA targets. If you would like to build CUDA targets +# with Clang, change this to `--config=build_cuda_with_clang` +build:local_cuda_base --config=build_cuda_with_nvcc + +# Linux x86 Local configs +build:local_linux_x86_64 --config=avx_linux +build:local_linux_x86_64 --config=avx_posix +build:local_linux_x86_64 --config=mkl_open_source_only + +build:local_linux_x86_64_cuda --config=local_linux_x86_64 +build:local_linux_x86_64_cuda --config=local_cuda_base + +# Linux Aarch64 Local configs +# No custom config for Linux Aarch64. If building for CPU, run +# `bazel build|test //path/to:target`. If building for CUDA, run +# `bazel build|test --config=local_cuda_base //path/to:target`. +build:local_linux_aarch64_cuda --config=local_cuda_base + +# Mac x86 Local configs +# For Mac x86, we target compatibility with macOS 10.14. +build:local_darwin_x86_64 --macos_minimum_os=10.14 +# Read-only cache to boost build times. +build:local_darwin_x86_64 --config=macos_cache + +# Mac Arm64 CI configs +# For Mac Arm64, we target compatibility with macOS 12. +build:local_darwin_arm64 --macos_minimum_os=12.0 +# Read-only cache to boost build times. +build:local_darwin_arm64 --config=macos_cache_push + +# Windows x86 Local configs +build:local_windows_x86_64 --config=avx_windows + +# ############################################################################# +# CI Build config options below. # JAX uses these configs in CI builds for +# building artifacts and when running Bazel tests. 
+# +# These configs are pretty much the same as the local build configs above. The +# difference is that, in CI, we build with Clang for and pass in a custom +# non-hermetic toolchain to ensure manylinux compliance for Linux builds and +# for using RBE on Windows. Because the toolchain is non-hermetic, it requires +# specific versions of the compiler and other tools to be present on the system +# in specific locations, which is why the Linux and Windows builds are run in a +# Docker container. +# ############################################################################# + +# Linux x86 CI configs +# Inherit the local Linux x86 configs. +build:ci_linux_x86_64 --config=local_linux_x86_64 + +# CI builds use Clang as the default compiler so we inherit Clang +# specific configs +build:ci_linux_x86_64 --config=clang + +# TODO(b/356695103): We do not have a CPU only toolchain so we use the CUDA +# toolchain for both CPU and GPU builds. +build:ci_linux_x86_64 --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" +build:ci_linux_x86_64 --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" +build:ci_linux_x86_64 --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64" + +# The toolchain in `--config=cuda` needs to be read before the toolchain in +# `--config=ci_linux_x86_64`. Otherwise, we run into issues with manylinux +# compliance. +build:ci_linux_x86_64_cuda --config=local_cuda_base +build:ci_linux_x86_64_cuda --config=ci_linux_x86_64 +build:ci_linux_x86_64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" + +# Linux Aarch64 CI configs +build:ci_linux_aarch64_base --config=clang +build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10" + +build:ci_linux_aarch64 --config=ci_linux_aarch64_base +build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" +build:ci_linux_aarch64 --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" + +# CUDA configs for Linux Aarch64 do not pass in the crosstool top flag from +# above because the Aarch64 toolchain rule does not support building with NVCC. +# Instead, we use `@local_config_cuda//crosstool:toolchain` from --config=cuda +# and set `CLANG_CUDA_COMPILER_PATH` to define the toolchain so that we can +# use Clang for the C++ targets and NVCC to build CUDA targets. +build:ci_linux_aarch64_cuda --config=ci_linux_aarch64_base +build:ci_linux_aarch64_cuda --config=local_cuda_base +build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" + +# Mac x86 CI configs +build:ci_darwin_x86_64 --config=local_darwin_arm64 +# Mac CI builds read and push cache to/from GCS bucket. +build:ci_darwin_x86_64 --config=macos_cache_push + +# Mac Arm64 CI configs +build:ci_darwin_arm64 --config=local_darwin_arm64 +# CI builds read and push cache to/from GCS bucket. 
+build:ci_darwin_arm64 --config=macos_cache_push + +# Windows x86 CI configs +build:ci_windows_x86_64 --config=local_windows_x86_64 +build:ci_windows_x86_64 --config=clang +# Set the toolchains +build:ci_windows_x86_64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" +build:ci_windows_x86_64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" +build:ci_windows_x86_64 --compiler=clang-cl +build:ci_windows_x86_64 --linkopt=/FORCE:MULTIPLE +build:ci_windows_x86_64 --host_linkopt=/FORCE:MULTIPLE + +# ############################################################################# +# RBE config options below. These inherit the CI configs above and set the +# remote execution backend and authentication options required to run builds +# with RBE. Linux x86 and Windows builds use RBE. +# ############################################################################# +# Flag to enable remote config +common --experimental_repo_remote_exec + +# Allow creation of resultstore URLs for any bazel invocation +build:resultstore --google_default_credentials +build:resultstore --bes_backend=buildeventservice.googleapis.com +build:resultstore --bes_instance_name="tensorflow-testing" +build:resultstore --bes_results_url="https://source.cloud.google.com/results/invocations" +build:resultstore --bes_timeout=600s + +build:rbe --config=resultstore +build:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1 +build:rbe --define=EXECUTOR=remote +build:rbe --flaky_test_attempts=3 +build:rbe --jobs=200 +build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:rbe --remote_timeout=3600 +build:rbe --spawn_strategy=remote,worker,standalone,local +# Attempt to minimize the amount of data transfer between bazel and the remote +# workers: +build:rbe --remote_download_toplevel +test:rbe --test_env=USER=anon + +# RBE configs for Linux x86 +# Set the remote worker pool +common:rbe_linux_x86_64_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance + +build:rbe_linux_x86_64_base --config=rbe +build:rbe_linux_x86_64_base --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +build:rbe_linux_x86_64_base --linkopt=-lrt +build:rbe_linux_x86_64_base --host_linkopt=-lrt +build:rbe_linux_x86_64_base --linkopt=-lm +build:rbe_linux_x86_64_base --host_linkopt=-lm + +# Set the host, execution, and target platform +build:rbe_linux_x86_64_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +build:rbe_linux_x86_64_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +build:rbe_linux_x86_64_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" + +# Python config is the same across all containers because the binary is the same +build:rbe_linux_x86_64_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python" + +build:rbe_linux_x86_64 --config=rbe_linux_x86_64_base +build:rbe_linux_x86_64 --config=ci_linux_x86_64 + +build:rbe_linux_x86_64_cuda --config=rbe_linux_x86_64_base +build:rbe_linux_x86_64_cuda --config=ci_linux_x86_64_cuda +build:rbe_linux_x86_64_cuda --repo_env=REMOTE_GPU_TESTING=1 + +# RBE configs for Windows +# Set the remote worker pool +common:rbe_windows_x86_64 --remote_instance_name=projects/tensorflow-testing/instances/windows + +build:rbe_windows_x86_64 --config=rbe + +# Set the host, 
execution, and target platform +build:rbe_windows_x86_64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_windows_x86_64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_windows_x86_64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" + +build:rbe_windows_x86_64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe +build:rbe_windows_x86_64 --enable_runfiles +build:rbe_windows_x86_64 --define=override_eigen_strong_inline=true + +# Don't build the python zip archive in the RBE build. +build:rbe_windows_x86_64 --nobuild_python_zip + +build:rbe_windows_x86_64 --config=ci_windows_x86_64 + +# ############################################################################# +# Cross-compile config options below. Native RBE support does not exist for +# Linux Aarch64 and Mac x86. So, we use the cross-compile toolchain to build +# targets for Linux Aarch64 and Mac x86 on the Linux x86 RBE pool. +# ############################################################################# +# Set execution platform to Linux x86 +# Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" +# flags seem to be actually used to specify the execution platform details. It +# seems it is this way because these flags are old and predate the distinction +# between host and execution platform. +build:cross_compile_base --host_cpu=k8 +build:cross_compile_base --host_crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_base --extra_execution_platforms=@xla//tools/toolchains/cross_compile/config:linux_x86_64 + +# Linux Aarch64 +build:cross_compile_linux_arm64 --config=cross_compile_base + +# Set the target CPU to Aarch64 +build:cross_compile_linux_arm64 --platforms=@xla//tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_arm64 --cpu=aarch64 +build:cross_compile_linux_arm64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite + +build:rbe_cross_compile_base --config=rbe + +# RBE cross-compile configs for Linux Aarch64 +build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 +build:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base + +# Mac x86 +build:cross_compile_macos_x86 --config=cross_compile_base +build:cross_compile_macos_x86 --config=nonccl +# Target Catalina (10.15) as the minimum supported OS +build:cross_compile_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + +# Set the target CPU to Darwin x86 +build:cross_compile_macos_x86 --platforms=@xla//tools/toolchains/cross_compile/config:darwin_x86_64 +build:cross_compile_macos_x86 --cpu=darwin +build:cross_compile_macos_x86 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +# When RBE cross-compiling for macOS, we need to explicitly register the +# toolchain. Otherwise, oddly, RBE complains that a "docker container must be +# specified". +build:cross_compile_macos_x86 --extra_toolchains=@xla//tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain +# Map --platforms=darwin_x86_64 to --cpu=darwin and vice-versa to make selects() +# and transistions that use these flags work. The flag --platform_mappings needs +# to be set to a file that exists relative to the package path roots. 
+build:cross_compile_macos_x86 --platform_mappings=platform_mappings + +# RBE cross-compile configs for Darwin x86 +build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 +build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base + +# ############################################################################# +# Test specific config options below. These are used when `bazel test` is run. +# ############################################################################# +test --test_output=errors + +# Configs for for running GPU tests. +test:gpu --test_env=TF_CPP_MIN_LOG_LEVEL=0 +test:gpu --test_env=JAX_SKIP_SLOW_TESTS=1 --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform + +# Configs for local_gpu +test:local_gpu --config=gpu +# Disable building jaxlib. Instead we depend on the local wheel. +test:local_gpu --//jax:build_jaxlib=false + +# Non-multiaccelerator tests with one GPU apiece. Non-multiaccelerator tests +# are run on RBE and locally. +test:non_multiaccelerator --config=gpu +test:non_multiaccelerator --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow +test:non_multiaccelerator --test_tag_filters=-multiaccelerator + +test:non_multiaccelerator_local --config=non_multiaccelerator +test:non_multiaccelerator_local --config=local_gpu +# `JAX_ACCELERATOR_COUNT` needs to match the number of GPUs in the VM. +test:non_multiaccelerator_local --test_env=JAX_TESTS_PER_ACCELERATOR=12 --test_env=JAX_ACCELERATOR_COUNT=4 + +# The product of the `JAX_ACCELERATOR_COUNT`` and `JAX_TESTS_PER_ACCELERATOR` +# should match the VM's CPU core count (set in `--local_test_jobs`). +test:non_multiaccelerator_local --local_test_jobs=48 + +# Runs multiaccelerator tests with all GPUs. Multiaccelerator GPU tests are +# only run locally +test:multiaccelerator_local --config=local_gpu +test:multiaccelerator_local --jobs=8 --test_tag_filters=multiaccelerator + +############################################################################# +# Some configs to make getting some forms of debug builds. In general, the +# codebase is only regularly built with optimizations. Use 'debug_symbols' to +# just get symbols for the parts of XLA/PJRT that jaxlib uses. +# Or try 'debug' to get a build with assertions enabled and minimal +# optimizations. +# Include these in a local .bazelrc.user file as: +# build --config=debug_symbols +# Or: +# build --config=debug +# +# Additional files can be opted in for debug symbols by adding patterns +# to a per_file_copt similar to below. +############################################################################# + +build:debug_symbols --strip=never --per_file_copt="xla/pjrt|xla/python@-g3" +build:debug --config debug_symbols -c fastbuild + +# Load `.jax_configure.bazelrc` file written by build.py +try-import %workspace%/.jax_configure.bazelrc + +# Load rc file with user-specific options. +try-import %workspace%/.bazelrc.user \ No newline at end of file diff --git a/ci/README.md b/ci/README.md new file mode 100644 index 000000000000..c44e677b898e --- /dev/null +++ b/ci/README.md @@ -0,0 +1,134 @@ +# JAX continuous integration + +> **Warning** This folder is still under construction. It is part of an ongoing +> effort to improve the structure of CI and build related files within the +> JAX repo. This warning will be removed when the contents of this +> directory are stable and appropriate documentation around its usage is in +> place. 
+
+Maintainer: ML Velocity team @ Google
+
+********************************************************************************
+
+The CI folder contains the configuration files and scripts used to build, test,
+and deploy JAX. This folder is typically used by continuous integration (CI)
+tools to build and test JAX whenever there is a change to the code.
+
+## JAX's Official CI and Build/Test Scripts
+
+JAX's official CI jobs run the scripts in this folder. The CI scripts require
+an env file to be set in `ENV_FILE` that sets various configuration settings.
+These "env" files are organized by build type. For example,
+`ci/envs/build_artifacts/jaxlib` contains the configs for building the `jaxlib`
+package. The scripts are intended to be used across different platforms and
+architectures and currently support the following systems: Linux x86,
+Linux Arm64, Mac x86, Mac Arm64, and Windows x86.
+
+If you would like to test these scripts, follow the instructions below.
+
+### Choose how you would like to build:
+
+#### Shell Script
+
+The artifact build script (`ci/build_artifacts.sh`) invokes the build CLI,
+`ci/cli/build.py`, which in turn invokes the Bazel command that builds the
+requested JAX artifact. Follow the instructions below to invoke the CI script
+to build a JAX artifact of your choice. These scripts can build the `jax`,
+`jaxlib`, `jax-cuda-plugin`, and `jax-cuda-pjrt` artifacts. Note that all
+commands are meant to be run from the root of this repository.
+
+**Docker (soft prerequisite)**
+
+We recommend running the CI scripts in Docker where possible. This ensures the
+right build environment is set up before the artifact is built. If you would
+like to disable Docker, run:
+
+```
+export JAXCI_USE_DOCKER=0
+export JAXCI_CLI_BUILD_MODE=local
+```
+
+**Changing Python version**
+
+By default, the build uses Python 3.12. If you would like to change this, set
+`JAXCI_HERMETIC_PYTHON_VERSION`, e.g. `export JAXCI_HERMETIC_PYTHON_VERSION=3.11`.
+
+**RBE support**
+
+If you are running this on a Linux x86 or a Windows machine, you have the option
+to use RBE to speed up the build. Please note this requires permissions to JAX's
+remote worker pool and RBE configs. To enable RBE, run `export JAXCI_USE_RBE=1`.
+
+**How to run the script**
+
+```
+1. Set ENV_FILE to one of the envs inside ci/envs/build_artifacts based on the
+artifact you want to build and your system.
+E.g. export ENV_FILE=ci/envs/build_artifacts/jaxlib
+2. Run: bash ci/build_artifacts.sh
+```
+
+**Known Bugs**
+
+1. Building `jax` fails due to Python missing the `build` dependency.
+2. The auditwheel script fails on Linux Arm64's Docker image due to Python
+missing the `auditwheel` dependency.
+3. If RBE is used to build the target for Windows, building the wheel fails
+due to a permission denied error.
+
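+**Example: combining the options above**
+
+As a rough end-to-end sketch (the env file and version below are only
+illustrative), a Linux x86 `jaxlib` build that overrides the Python version and
+enables RBE would look like:
+
+```
+export ENV_FILE=ci/envs/build_artifacts/jaxlib
+export JAXCI_HERMETIC_PYTHON_VERSION=3.11
+export JAXCI_USE_RBE=1
+bash ci/build_artifacts.sh
+```
+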
+
+
+#### Build CLI
+
+Follow the instructions below to invoke the build CLI to build a JAX artifact
+of your choice. The CLI can build the `jaxlib`, `jax-cuda-plugin`, and
+`jax-cuda-pjrt` artifacts. Note that all commands are meant to be run from the
+root of this repository.
+
+By default, the CLI runs in local mode and will pick the "local_" configs in
+the `ci/.bazelrc` file. On Linux systems, Bazel defaults to using GCC as the
+compiler. To change this, add `--use_clang` to your command. This requires
+Clang to be present on the system and on the PATH. If your Clang binary is not
+on the PATH, set its path using `--clang_path`.
+
+**Build Modes**
+
+If you want to run with the configs that the CI builds use, switch the mode by
+setting `--mode=ci`. Please note that CI mode depends on a custom toolchain
+that JAX uses; the build expects this toolchain to be present on the system. As
+such, CI mode is usually run from within a Docker container. See
+`JAXCI_DOCKER_IMAGE` inside the `ci/envs/build_artifacts` env files to see
+which image we use for each platform.
+
+**RBE support**
+
+If you are running this on a Linux x86 or a Windows machine, you have the option
+to use RBE to speed up the build. Please note this requires permissions to JAX's
+remote worker pool and RBE configs. To enable RBE, add `--use_rbe` to your command.
+
+**Changing Python version**
+
+If you would like to change the Python version of the artifact, add
+`--python_version` to your command, e.g. `--python_version=3.11`.
+By default, the CLI uses Python 3.12.
+
+**Local XLA dependency**
+
+JAX artifacts built by the CLI depend on the XLA version pinned in JAX's
+`workspace.bzl`. If you would like to depend on a local copy of XLA instead,
+set `--local_xla_path` to its path.
+
+**Dry Run**
+
+If you would like to perform a dry run, add `--dry_run` to your command. This
+prints the `bazel` command that the CLI would have invoked (see the combined
+example after the list below).
+
+**Some example invocations**
+
+1. For building `jaxlib`, run `python ci/cli/build.py jaxlib`
+2. For building `jax-cuda-plugin` for Python 3.11, run `python ci/cli/build.py jax-cuda-plugin --python_version=3.11`
+3. For building `jax-cuda-pjrt` with RBE, run `python ci/cli/build.py jax-cuda-pjrt --use_rbe`
+
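+As a further illustration (this flag combination is only an example), a dry run
+that prints the Bazel command for a CI-mode `jax-cuda-plugin` build with Clang
+would be:
+
+```
+python ci/cli/build.py jax-cuda-plugin --mode=ci --use_clang --dry_run
+```
+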
\ No newline at end of file diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh new file mode 100755 index 000000000000..491fe9e19306 --- /dev/null +++ b/ci/build_artifacts.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright 2024 JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +source "ci/utilities/setup.sh" + +# Build the jax artifact +if [[ "$JAXCI_BUILD_JAX_ENABLE" == 1 ]]; then + jaxrun python3 -m build --outdir $JAXCI_OUTPUT_DIR +fi + +# Build the jaxlib CPU artifact +if [[ "$JAXCI_BUILD_JAXLIB_ENABLE" == 1 ]]; then + jaxrun python3 ci/cli/build.py jaxlib --mode=$JAXCI_CLI_BUILD_MODE --python_version=$JAXCI_HERMETIC_PYTHON_VERSION +fi + +# Build the jax-cuda-plugin artifact +if [[ "$JAXCI_BUILD_PLUGIN_ENABLE" == 1 ]]; then + jaxrun python3 ci/cli/build.py jax-cuda-plugin --mode=$JAXCI_CLI_BUILD_MODE --python_version=$JAXCI_HERMETIC_PYTHON_VERSION +fi + +# Build the jax-cuda-pjrt artifact +if [[ "$JAXCI_BUILD_PJRT_ENABLE" == 1 ]]; then + jaxrun python3 ci/cli/build.py jax-cuda-pjrt --mode=$JAXCI_CLI_BUILD_MODE +fi + +# After building `jaxlib`, `jaxcuda-plugin`, and `jax-cuda-pjrt`, we run +# `auditwheel show` to ensure manylinux compliance. +if [[ "$JAXCI_WHEEL_AUDIT_ENABLE" == 1 ]]; then + jaxrun ./ci/utilities/run_auditwheel.sh +fi diff --git a/ci/cli/build.py b/ci/cli/build.py new file mode 100644 index 000000000000..7bf2df07a4ac --- /dev/null +++ b/ci/cli/build.py @@ -0,0 +1,352 @@ +#!/usr/bin/python +import argparse +import asyncio +import logging +import os +import platform +import collections +import sys +import subprocess +from helpers import command, tools + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +ArtifactBuildSpec = collections.namedtuple( + "ArtifactBuildSpec", + ["bazel_build_target", "wheel_binary"], +) + +ARTIFACT_BUILD_TARGET_DICT = { + "jaxlib": ArtifactBuildSpec("//jaxlib/tools:build_wheel", "bazel-bin/jaxlib/tools/build_wheel"), + "jax-cuda-plugin": ArtifactBuildSpec("//jaxlib/tools:build_gpu_kernels_wheel", "bazel-bin/jaxlib/tools/build_gpu_kernels_wheel"), + "jax-cuda-pjrt": ArtifactBuildSpec("//jaxlib/tools:build_gpu_plugin_wheel", "bazel-bin/jaxlib/tools/build_gpu_plugin_wheel"), +} + +def add_python_argument(parser: argparse.ArgumentParser): + """Add Python version argument to the parser.""" + parser.add_argument( + "--python_version", + type=str, + choices=["3.10", "3.11", "3.12"], + default="3.12", + help="Python version to use", + ) + +# Target system is assumed to be the host sytem (auto-detected) unless +# specified otherwise, e.g. 
for cross-compile builds +# allow override to pass in custom flags for certain builds like the RBE +# jobs +def add_system_argument(parser: argparse.ArgumentParser): + """Add Target System argument to the parser.""" + parser.add_argument( + "--target_system", + type=str, + default="", + choices=["linux_x86_64", "linux_aarch64", "darwin_x86_64", "darwin_arm64", "windows_x86_64"], + help="Target system to build for", + ) + +def add_cuda_argument(parser: argparse.ArgumentParser): + """Add CUDA version argument to the parser.""" + parser.add_argument( + "--cuda_version", + type=str, + default="12.3.2", + help="CUDA version to use", + ) + +def add_cudnn_argument(parser: argparse.ArgumentParser): + """Add cuDNN version argument to the parser.""" + parser.add_argument( + "--cudnn_version", + type=str, + default="9.1.1", + help="cuDNN version to use", + ) + +def get_bazelrc_config(os_name: str, arch: str, artifact: str, mode:str, use_rbe: bool): + """Returns the bazelrc config for the given architecture, OS, and build type.""" + bazelrc_config = f"{os_name}_{arch}" + + # When the CLI is run by invoking ci/build_artifacts.sh, the CLI runs in CI + # mode by default and will use one of the "ci_" configs in the .bazelrc. We + # want to run certain CI builds with RBE and we also want to allow users the + # flexibility to build JAX artifacts either by running the CLI or by running + # ci/build_artifacts.sh. Because RBE requires permissions, we cannot enable it + # by default in ci/build_artifacts.sh. Instead, we do not set `--use_rbe` in + # build_artifacts.sh and have the CI builds set JAXCI_USE_RBE to 1 to enable + # RBE. + if os.environ.get("JAXCI_USE_RBE", "0") == "1": + use_rbe = True + + # In CI, we want to use RBE where possible. At the moment, RBE is only + # supported on Linux x86 and Windows. If an user is requesting RBE, the CLI + # will use RBE if the host system supports it, otherwise it will use the + # local config. + if use_rbe and (os_name == "linux" or os_name == "windows") and arch == "x86_64": + bazelrc_config = "rbe_" + bazelrc_config + elif mode == "local": + if use_rbe: + logger.warning("RBE is not supported on %s_%s. Using Local config instead.", os_name, arch) + if os_name == "linux" and arch == "aarch64" and artifact == "jaxlib": + logger.info("Linux Aarch64 CPU builds do not have custom local config in JAX's root .bazelrc. Running with default configs.") + bazelrc_config = "" + return bazelrc_config + bazelrc_config = "local_" + bazelrc_config + else: + if use_rbe: + logger.warning("RBE is not supported on %s_%s. Using CI config instead.", os_name, arch) + elif (os_name == "linux" or os_name == "windows") and arch == "x86_64": + logger.info("RBE support is available for this platform. If you want to use RBE and have the required permissions, run the CLI with `--use_rbe` or set `JAXCI_USE_RBE=1`") + bazelrc_config = "ci_" + bazelrc_config + + if artifact == "jax-cuda-plugin" or artifact == "jax-cuda-pjrt": + bazelrc_config = bazelrc_config + "_cuda" + + return bazelrc_config + +def get_jaxlib_git_hash(): + """Returns the git hash of the current repository.""" + res = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True) + return res.stdout + +def check_whether_running_tests(): + """ + Returns True if running tests, False otherwise. When running tests, JAX + artifacts are built with `JAX_ENABLE_X64=0` and the XLA repository is checked + out at HEAD instead of the pinned version. 
+ """ + return os.environ.get("JAXCI_RUN_TESTS", "0") == "1" + +async def main(): + parser = argparse.ArgumentParser( + description=( + "JAX CLI for building/testing jaxlib, jaxl-cuda-plugin, and jax-cuda-pjrt." + ), + ) + + parser.add_argument( + "--mode", + type=str, + choices=["ci", "local"], + default="local", + help= + """ + Flags as requesting a CI or CI like build. Setting this flag to CI + will assume multiple settings expected in CI builds. These are set by + the CI options in .bazelrc. To see best how this flag resolves you can + run the artifact of choice with "--mode=[ci|local] --dry-run" to get the + commands issued to Bazel for that artifact. + """, + ) + + parser.add_argument( + "--build_target_only", + action="store_true", + help="If set, the tool will only build the target and not the wheel.", + ) + + parser.add_argument( + "--bazel_path", + type=str, + default="", + help= + """ + Path to the Bazel binary to use. The default is to find bazel via the + PATH; if none is found, downloads a fresh copy of Bazelisk from + GitHub. + """, + ) + + parser.add_argument( + "--use_rbe", + action="store_true", + help= + """ + If set, the build will use RBE where possible. Currently, only Linux x86 + and Windows builds can use RBE. On other platforms, setting this flag will + be a no-op. RBE requires permissions to JAX's remote worker pool. Only + Googlers and CI builds can use RBE. + """, + ) + + parser.add_argument( + "--use_clang", + action="store_true", + help= + """ + If set, the build will use Clang as the C++ compiler. Requires Clang to + be present on the PATH or a path is given with --clang_path. CI builds use + Clang by default. + """, + ) + + parser.add_argument( + "--clang_path", + type=str, + default="", + help= + """ + Path to the Clang binary to use. If not set and --use_clang is set, the + build will attempt to find Clang on the PATH. + """, + ) + + parser.add_argument( + "--local_xla_path", + type=str, + default=os.environ.get("JAXCI_XLA_GIT_DIR", ""), + help= + """ + Path to local XLA repository to use. If not set, Bazel uses the XLA + at the pinned version in workspace.bzl. + """, + ) + + parser.add_argument( + "--dry_run", + action="store_true", + help="Prints the Bazel command that is going will be invoked.", + ) + parser.add_argument("--verbose", action="store_true", help="Verbose output") + + global_args, remaining_args = parser.parse_known_args() + + # Create subparsers for jax, jaxlib, plugin, pjrt + subparsers = parser.add_subparsers( + dest="command", required=True, help="Artifact to build" + ) + + # Jaxlib subcommand + jaxlib_parser = subparsers.add_parser("jaxlib", help="Builds the jaxlib package.") + add_python_argument(jaxlib_parser) + add_system_argument(jaxlib_parser) + + # jax-cuda-plugin subcommand + plugin_parser = subparsers.add_parser("jax-cuda-plugin", help="Builds the jax-cuda-plugin package.") + add_python_argument(plugin_parser) + add_cuda_argument(plugin_parser) + add_cudnn_argument(plugin_parser) + add_system_argument(plugin_parser) + + # jax-cuda-pjrt subcommand + pjrt_parser = subparsers.add_parser("jax-cuda-pjrt", help="Builds the jax-cuda-pjrt package.") + add_cuda_argument(pjrt_parser) + add_cudnn_argument(pjrt_parser) + add_system_argument(pjrt_parser) + + # Get the host systems architecture + arch = platform.machine() + # On Windows, this returns "amd64" instead of "x86_64. However, they both + # are essentially the same. 
+ if arch.lower() == "amd64": + arch = "x86_64" + + # Get the host system OS + os_name = platform.system().lower() + + args = parser.parse_args(remaining_args) + + for key, value in vars(global_args).items(): + setattr(args, key, value) + + logger.info( + "Building %s for %s %s...", + args.command, + os_name, + arch, + ) + + # Only jaxlib and jax-cuda-plugin are built for a specific python version + if args.command == "jaxlib" or args.command == "jax-cuda-plugin": + logger.info("Using Python version %s", args.python_version) + + if args.command == "jax-cuda-plugin" or args.command == "jax-cuda-pjrt": + logger.info("Using CUDA version %s", args.cuda_version) + logger.info("Using cuDNN version %s", args.cudnn_version) + + # Find the path to Bazel + bazel_path = tools.get_bazel_path(args.bazel_path) + + executor = command.SubprocessExecutor() + + bazel_command = command.CommandBuilder(bazel_path) + # Temporary; when we make the new scripts as the default we can remove this. + bazel_command.append("--bazelrc=ci/.bazelrc") + + bazel_command.append("build") + + if args.use_clang or args.clang_path: + # Find the path to Clang + clang_path = tools.get_clang_path(args.clang_path) + if clang_path: + bazel_command.append(f"--action_env CLANG_COMPILER_PATH='{clang_path}'") + bazel_command.append(f"--repo_env CC='{clang_path}'") + bazel_command.append(f"--repo_env BAZEL_COMPILER='{clang_path}'") + bazel_command.append("--config=clang") + + if args.mode == "ci": + logging.info("Running in CI mode. Run the CLI with --help for more details on what this means.") + + # JAX's .bazelrc has custom configs for each build type, architecture, and + # OS. Fetch the appropriate config and pass it to Bazel. A special case is + # when building for Linux Aarch64, which does not have a custom local config + # in JAX's .bazelrc. In this case, we build with the default configs. + bazelrc_config = get_bazelrc_config(os_name, arch, args.command, args.mode, args.use_rbe) + if bazelrc_config: + bazel_command.append(f"--config={bazelrc_config}") + + # Check if we are running tests or if a local XLA path is set. + # When running tests, JAX arifacts and tests are run with XLA at head. + if check_whether_running_tests() or args.local_xla_path: + bazel_command.append(f"--override_repository=xla='{args.local_xla_path}'") + + if hasattr(args, "python_version"): + bazel_command.append(f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}") + + # Set the CUDA and cuDNN versions if they are not the default. + if hasattr(args, "cuda_version") and args.cuda_version != "12.3.2": + bazel_command.append(f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}") + + if hasattr(args, "cudnn_version") and args.cudnn_version != "9.1.1": + bazel_command.append(f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}") + + build_target, wheel_binary = ARTIFACT_BUILD_TARGET_DICT[args.command] + bazel_command.append(build_target) + + logger.info("Bazel build command:\n%s\n", bazel_command.command) + + if args.dry_run: + logger.info("CLI is in dry run mode. Exiting without invoking Bazel.") + sys.exit(0) + + await executor.run(bazel_command.command) + + if not args.build_target_only: + logger.info("Building wheel...") + run_wheel_binary = command.CommandBuilder(wheel_binary) + + # Read output directory from environment variable. If not set, set it to + # dist/ in the current working directory. 
+ output_dir = os.getenv("JAXCI_OUTPUT_DIR", os.path.join(os.getcwd(), "dist")) + run_wheel_binary.append(f"--output_path={output_dir}") + + run_wheel_binary.append(f"--cpu={arch}") + + if args.command == "jax-cuda-plugin" or args.command == "jax-cuda-pjrt": + run_wheel_binary.append("--enable-cuda=True") + major_cuda_version = args.cuda_version.split(".")[0] + run_wheel_binary.append(f"--platform_version={major_cuda_version}") + + jaxlib_git_hash = get_jaxlib_git_hash() + run_wheel_binary.append(f"--jaxlib_git_hash={jaxlib_git_hash}") + + logger.info("Wheel build command:\n%s\n", run_wheel_binary.command) + await executor.run(run_wheel_binary.command) + +if __name__ == "__main__": + asyncio.run(main()) + \ No newline at end of file diff --git a/ci/cli/helpers/command.py b/ci/cli/helpers/command.py new file mode 100644 index 000000000000..695822f4497c --- /dev/null +++ b/ci/cli/helpers/command.py @@ -0,0 +1,106 @@ +import asyncio +import dataclasses +import datetime +import os +import logging +from typing import Dict, Optional + +logger = logging.getLogger() + +class CommandBuilder: + def __init__(self, base_command: str): + self.command = base_command + + def append(self, parameter: str): + self.command += " {}".format(parameter) + return self + + +@dataclasses.dataclass +class CommandResult: + """ + Represents the result of executing a subprocess command. + """ + + command: str + return_code: int = 2 # Defaults to not successful + logs: str = "" + start_time: datetime.datetime = dataclasses.field( + default_factory=datetime.datetime.now + ) + end_time: Optional[datetime.datetime] = None + + # def logger.info(self): + # """ + # Prints a summary of the command execution. + # """ + # duration = ( + # (self.end_time - self.start_time).total_seconds() if self.end_time else None + # ) + # logger.info(f"Command: {self.get_command()}") + # logger.info(f"Return code: {self.return_code}") + # logger.info(f"Duration: {duration:.3f} seconds" if duration else "Command still running") + # if self.logs: + # logger.info("Logs:") + # logger.info(self.logs) + + +class SubprocessExecutor: + """ + Manages execution of subprocess commands with reusable environment and logging. + """ + + def __init__(self, environment: Dict[str, str] = dict(os.environ)): + self.environment = environment + + def set_verbose(self, verbose: bool): + """Enables or disables verbose logging.""" + self._verbose = verbose + + def update_environment(self, new_env: Dict[str, str]): + """Updates the environment with new key-value pairs.""" + self.environment.update(new_env) + + async def run(self, cmd: str, dry_run: bool = False) -> CommandResult: + """ + Executes a subprocess command. + + Args: + cmd: The command to execute. + dry_run: If True, prints the command instead of executing it. + + Returns: + A CommandResult instance. 
+ """ + result = CommandResult(command=cmd) + if dry_run: + logger.info("[DRY RUN] %s", cmd) + result.return_code = 0 # Dry run is a success + return result + + logger.debug("Executing: %s", cmd) + + process = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=self.environment, + ) + + async def log_stream(stream, result: CommandResult): + while True: + line_bytes = await stream.readline() + if not line_bytes: + break + line = line_bytes.decode().rstrip() + result.logs += line + logger.info("%s", line) + + await asyncio.gather( + log_stream(process.stdout, result), log_stream(process.stderr, result) + ) + + result.return_code = await process.wait() + result.end_time = datetime.datetime.now() + logger.debug("Command finished with return code %s", result.return_code) + return result diff --git a/ci/cli/helpers/tools.py b/ci/cli/helpers/tools.py new file mode 100644 index 000000000000..85567ac6bb1b --- /dev/null +++ b/ci/cli/helpers/tools.py @@ -0,0 +1,146 @@ +import collections +import hashlib +import logging +import os +import platform +import shutil +import subprocess +import urllib.request + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +BAZELISK_BASE_URI = ( + "https://github.com/bazelbuild/bazelisk/releases/download/v1.20.0/" +) +BazeliskPackage = collections.namedtuple("BazeliskPackage", ["file", "sha256"]) +BAZELISK_PACKAGES = { + ("Linux", "x86_64"): BazeliskPackage( + file="bazelisk-linux-amd64", + sha256=( + "d9af1fa808c0529753c3befda75123236a711d971d3485a390507122148773a3" + ), + ), + ("Linux", "aarch64"): BazeliskPackage( + file="bazelisk-linux-arm64", + sha256=( + "467ec3821aca5e278c8570b7c25e0dfc1a061d2873be89e4a266aaf488148426" + ), + ), + ("Darwin", "x86_64"): BazeliskPackage( + file="bazelisk-darwin", + sha256=( + "9a4b169038a63ebf60a9b4f367b449ab9b484c4ec7d1ef9f6b7a4196dfd50f33" + ), + ), + ("Darwin", "arm64"): BazeliskPackage( + file="bazelisk-darwin-arm64", + sha256=( + "29753341c0ddc35931fb240e247fbba0b83ef81bccc2433dd075363ec02a67a6" + ), + ), + ("Windows", "AMD64"): BazeliskPackage( + file="bazelisk-windows-amd64.exe", + sha256=( + "4175ce7ef4b552fb17e93ce49a245679dc26a35cf2fbc7c3146daca6ffc7a81e" + ), + ), +} + +def guess_clang_paths(clang_path_flag): + """ + Yields a sequence of guesses about Clang path. Some of sequence elements + can be None. The resulting iterator is lazy and potentially has a side + effects. + """ + + yield clang_path_flag + yield shutil.which("clang") + +def get_clang_path(clang_path_flag): + for clang_path in guess_clang_paths(clang_path_flag): + if clang_path: + absolute_clang_path = os.path.realpath(clang_path) + logger.info("Found path to Clang: %s.", absolute_clang_path) + return absolute_clang_path + logger.warning("Could not find path to Clang. Continuing without Clang.") + +def get_jax_supported_bazel_version(filename: str = ".bazelversion"): + """Reads the contents of .bazelversion into a string. + + Args: + filename: The path to ".bazelversion". + + Returns: + The Bazel version as a string, or None if the file doesn't exist. 
+ """ + try: + with open(filename, 'r') as file: + content = file.read() + return content.strip() + except FileNotFoundError: + print(f"Error: File '{filename}' not found.") + return None + +def get_bazel_path(bazel_path_flag): + for bazel_path in guess_bazel_paths(bazel_path_flag): + if bazel_path and verify_bazel_version(bazel_path): + logger.info("Found a compatible Bazel installation.") + return bazel_path + logger.info("Unable not find a compatible Bazel installation.") + return download_and_verify_bazelisk() + +def verify_bazel_version(bazel_path): + """ Verifies if the version of Bazel is compatible with JAX's required + Bazel version. + """ + system_bazel_version = subprocess.check_output([bazel_path, "--version"]).strip().decode('UTF-8') + # `bazel --version` returns the version as "bazel a.b.c" so we split the + # result to get only the version numbers. + system_bazel_version = system_bazel_version.split(" ")[1] + expected_bazel_version = get_jax_supported_bazel_version() + if expected_bazel_version != system_bazel_version: + logger.info("Bazel version mismatch. JAX requires %s but got %s when `%s --version` was run", expected_bazel_version, system_bazel_version, bazel_path) + return False + return True + +def guess_bazel_paths(bazel_path_flag): + """Yields a sequence of guesses about bazel path. Some of sequence elements + can be None. The resulting iterator is lazy and potentially has a side + effects. + """ + + yield bazel_path_flag + # For when Bazelisk was downloaded and is present on the root JAX directory + yield shutil.which("./bazel") + yield shutil.which("bazel") + +def download_and_verify_bazelisk(): + """Downloads and verifies Bazelisk.""" + system = platform.system() + machine = platform.machine() + downloaded_filename = "bazel" + expected_sha256 = BAZELISK_PACKAGES[system, machine].sha256 + + # Download Bazelisk and store it as "bazel". + logger.info("Downloading Bazelisk...") + _, _ = urllib.request.urlretrieve(BAZELISK_BASE_URI + BAZELISK_PACKAGES[system, machine].file, downloaded_filename) + + with open(downloaded_filename, "rb") as downloaded_file: + contents = downloaded_file.read() + + calculated_sha256 = hashlib.sha256(contents).hexdigest() + + # Verify checksum + logger.info("Verifying the checksum...") + if calculated_sha256 != expected_sha256: + raise ValueError("SHA256 checksum mismatch. Download may be corrupted.") + logger.info("Checksum verified!") + + logger.info("Setting the Bazelisk binary to executable mode...") + subprocess.run(["chmod", "+x", downloaded_filename], check=True) + + return os.path.realpath(downloaded_filename) + diff --git a/ci/envs/build_artifacts/jax b/ci/envs/build_artifacts/jax new file mode 100644 index 000000000000..41b4be2c4555 --- /dev/null +++ b/ci/envs/build_artifacts/jax @@ -0,0 +1,3 @@ +# Inherit default environment variables. +source ci/envs/default +export JAXCI_BUILD_JAX_ENABLE="1" \ No newline at end of file diff --git a/ci/envs/build_artifacts/jax-cuda-pjrt b/ci/envs/build_artifacts/jax-cuda-pjrt new file mode 100644 index 000000000000..b5b3dc0eb05f --- /dev/null +++ b/ci/envs/build_artifacts/jax-cuda-pjrt @@ -0,0 +1,23 @@ +# Inherit default environment variables. +source ci/envs/default + +# Enable jax-cuda-pjrt build. +export JAXCI_BUILD_PJRT_ENABLE="1" +# Enable wheel audit to check for manylinux compliance. 
+export JAXCI_WHEEL_AUDIT_ENABLE="1" + +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# Linux x86 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then + # Note Python version of the container does not matter. JAX supports hermetic + # Python and thus the actual Python version of the artifact is controlled by + # the value set in `HERMETIC_PYTHON_VERSION` + export JAXCI_DOCKER_IMAGE="tensorflow/build:2.18-python3.10" +fi + +# Linux Aarch64 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:tf-2-18-multi-python" +fi diff --git a/ci/envs/build_artifacts/jax-cuda-plugin b/ci/envs/build_artifacts/jax-cuda-plugin new file mode 100644 index 000000000000..1d6681b358ce --- /dev/null +++ b/ci/envs/build_artifacts/jax-cuda-plugin @@ -0,0 +1,23 @@ +# Inherit default environment variables. +source ci/envs/default + +# Enable jax-cuda-plugin build +export JAXCI_BUILD_PLUGIN_ENABLE="1" +# Enable wheel audit to check for manylinux compliance. +export JAXCI_WHEEL_AUDIT_ENABLE="1" + +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# Linux x86 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then + # Note Python version of the container does not matter. JAX supports hermetic + # Python and thus the actual Python version of the artifact is controlled by + # the value set in `HERMETIC_PYTHON_VERSION` + export JAXCI_DOCKER_IMAGE="tensorflow/build:2.18-python3.10" +fi + +# Linux Aarch64 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:tf-2-18-multi-python" +fi diff --git a/ci/envs/build_artifacts/jaxlib b/ci/envs/build_artifacts/jaxlib new file mode 100644 index 000000000000..875e42ed435c --- /dev/null +++ b/ci/envs/build_artifacts/jaxlib @@ -0,0 +1,47 @@ +# Inherit default environment variables. +source ci/envs/default + +# Enable jaxlib build. +export JAXCI_BUILD_JAXLIB_ENABLE="1" + +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# Linux x86 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then + # Enable wheel audit to check for manylinux compliance. + export JAXCI_WHEEL_AUDIT_ENABLE=1 + + # Note Python version of the container does not matter. JAX supports hermetic + # Python and thus the actual Python version of the artifact is controlled by + # the value set in `HERMETIC_PYTHON_VERSION` + export JAXCI_DOCKER_IMAGE="tensorflow/build:2.18-python3.10" +fi + +# Linux Aarch64 specifc settings +if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then + # Enable wheel audit to check for manylinux compliance. + export JAXCI_WHEEL_AUDIT_ENABLE=1 + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:tf-2-18-multi-python" +fi + +# Windows specific settings +if [[ $os =~ "msys_nt" ]]; then + export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc" + + # While we run these scripts in Cygwin on Windows, + # the path needs to be in Windows format. 
diff --git a/ci/envs/build_artifacts/jaxlib b/ci/envs/build_artifacts/jaxlib
new file mode 100644
index 000000000000..875e42ed435c
--- /dev/null
+++ b/ci/envs/build_artifacts/jaxlib
@@ -0,0 +1,47 @@
+# Inherit default environment variables.
+source ci/envs/default
+
+# Enable the jaxlib build.
+export JAXCI_BUILD_JAXLIB_ENABLE="1"
+
+os=$(uname -s | awk '{print tolower($0)}')
+arch=$(uname -m)
+
+# Linux x86 specific settings
+if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then
+  # Enable wheel audit to check for manylinux compliance.
+  export JAXCI_WHEEL_AUDIT_ENABLE=1
+
+  # Note: the Python version of the container does not matter. JAX supports
+  # hermetic Python, so the actual Python version of the artifact is controlled
+  # by the value set in `HERMETIC_PYTHON_VERSION`.
+  export JAXCI_DOCKER_IMAGE="tensorflow/build:2.18-python3.10"
+fi
+
+# Linux Aarch64 specific settings
+if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then
+  # Enable wheel audit to check for manylinux compliance.
+  export JAXCI_WHEEL_AUDIT_ENABLE=1
+  export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:tf-2-18-multi-python"
+fi
+
+# Windows specific settings
+if [[ $os =~ "msys_nt" ]]; then
+  export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc"
+
+  # While we run these scripts in Cygwin on Windows,
+  # the paths need to be in Windows format.
+  export JAXCI_GIT_DIR=$(cygpath -w "/tmpfs/src/github/jax")
+  export JAXCI_OUTPUT_DIR=$(cygpath -w /c/jax/dist)
+  export JAXCI_CONTAINER_WORK_DIR=$(cygpath -w /c/jax/)
+
+  if [[ -n $JAXCI_XLA_GIT_DIR ]]; then
+    export JAXCI_XLA_GIT_DIR=$(cygpath -w "$JAXCI_XLA_GIT_DIR")
+  fi
+fi
+
+# Mac specific settings
+if [[ $os == "macos" ]]; then
+  # Mac builds do not run in Docker.
+  export JAXCI_USE_DOCKER=0
+fi
\ No newline at end of file
diff --git a/ci/envs/default b/ci/envs/default
new file mode 100644
index 000000000000..39d781be335d
--- /dev/null
+++ b/ci/envs/default
@@ -0,0 +1,58 @@
+# This file contains all the default values for the environment variables
+# used in the JAX CI scripts.
+#
+# The default values are set here. Other build-specific envs, such as those
+# in the "build_artifacts" and "run_tests" directories, source this file and
+# override the default values depending on the build type.
+
+# The build CLI can be run in either "ci" or "local" mode. This is used to
+# determine which .bazelrc configs to pass to the CLI. If the variable is not
+# set, we default to CI mode.
+export JAXCI_CLI_BUILD_MODE=${JAXCI_CLI_BUILD_MODE:-ci}
+
+# Environment variables that control which artifact to build. Used by
+# `build_artifacts.sh`.
+export JAXCI_BUILD_JAX_ENABLE=""
+export JAXCI_BUILD_JAXLIB_ENABLE=""
+export JAXCI_BUILD_PLUGIN_ENABLE=""
+export JAXCI_BUILD_PJRT_ENABLE=""
+export JAXCI_WHEEL_AUDIT_ENABLE=""
+
+# Docker-specific environment variables. Used by `setup_docker.sh`.
+export JAXCI_USE_DOCKER=${JAXCI_USE_DOCKER:-1}
+export JAXCI_DOCKER_IMAGE=""
+export JAXCI_CONTAINER_WORK_DIR="/jax"
+export JAXCI_DOCKER_ARGS=""
+
+# Controls the version of hermetic Python to use. Defaults to 3.12 if not
+# set.
+export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-3.12}
+
+# Controls the location where the artifacts are stored in the Docker container.
+export JAXCI_OUTPUT_DIR="$JAXCI_CONTAINER_WORK_DIR/dist"
+
+# Release tag to use for the build.
+export JAXCI_RELEASE_TAG=""
+
+# This is expected to be the root of the JAX git repository.
+export JAXCI_GIT_DIR=$(pwd)
+
+# Test-specific environment variables below. Used by `run_bazel_test.sh` and
+# `run_pytest.sh`.
+export JAXCI_RUN_TESTS=${JAXCI_RUN_TESTS:-0}
+
+# When running tests, we disable x64 mode.
+if [[ $JAXCI_RUN_TESTS == 1 ]]; then
+  export JAX_ENABLE_X64=0
+fi
+
+export JAXCI_RUN_BAZEL_GPU_TEST_LOCAL=${JAXCI_RUN_BAZEL_GPU_TEST_LOCAL:-0}
+export JAXCI_RUN_BAZEL_GPU_TEST_RBE=${JAXCI_RUN_BAZEL_GPU_TEST_RBE:-0}
+
+# Allow overriding the XLA git repository path and commit.
+export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-}
+export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-}
+
+# JAXCI_PYTHON is used to install the wheels locally. It needs to match the
+# version of the hermetic Python used by Bazel.
+export JAXCI_PYTHON=python${JAXCI_HERMETIC_PYTHON_VERSION}
\ No newline at end of file
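A minimal sketch of the ${VAR:-default} override pattern used throughout ci/envs/default: any value exported before the file is sourced takes precedence over the default. The Python version below is only an example:

  # Pin a non-default hermetic Python for one run.
  export JAXCI_HERMETIC_PYTHON_VERSION=3.11
  source ci/envs/build_artifacts/jaxlib
  echo "$JAXCI_HERMETIC_PYTHON_VERSION"   # prints 3.11, not the 3.12 default
  echo "$JAXCI_PYTHON"                    # resolves to python3.11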
+export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda${JAXCI_DOCKER_CUDA_VERSION}-cudnn9.1-ubuntu20.04-manylinux2014-multipython" +export JAXCI_DOCKER_ARGS="--gpus all" + +export TF_CPP_MIN_LOG_LEVEL=0 +export NCCL_DEBUG=WARN \ No newline at end of file diff --git a/ci/envs/run_tests/bazel_rbe_gpu b/ci/envs/run_tests/bazel_rbe_gpu new file mode 100644 index 000000000000..99ee2b59746a --- /dev/null +++ b/ci/envs/run_tests/bazel_rbe_gpu @@ -0,0 +1,9 @@ +# Inherit default environment variables. +source ci/envs/default + +export JAXCI_RUN_TESTS=1 +export JAXCI_RUN_BAZEL_GPU_TEST_RBE=1 + +# Only Linux x86 runs local GPU tests at the moment. +export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython" +export JAXCI_DOCKER_ARGS="--gpus all" \ No newline at end of file diff --git a/ci/run_bazel_test_cpu.sh b/ci/run_bazel_test_cpu.sh new file mode 100644 index 000000000000..5126cd76b5a2 --- /dev/null +++ b/ci/run_bazel_test_cpu.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright 2024 JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +source "ci/utilities/setup.sh" + +jaxrun "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +if [[ $JAXCI_BUILD_BAZEL_TEST_ENABLE == 1 ]]; then + # Bazel build on RBE CPU. Used when RBE is not available for the platform. E.g + # Linux Aarch64 + jaxrun bazel --bazelrc=ci/.bazelrc build --config=rbe_cross_compile_${os}_${arch} \ + --override_repository=xla="${KOKORO_ARTIFACTS_DIR}"/xla \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + //tests:cpu_tests //tests:backend_independent_tests +else + # Bazel test on RBE CPU. Only Linux x86_64 can run tests on RBE at the moment. + jaxrun bazel --bazelrc=ci/.bazelrc test --config=rbe_${os}_${arch} \ + --override_repository=xla="${KOKORO_ARTIFACTS_DIR}"/xla \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + //tests:cpu_tests //tests:backend_independent_tests +fi \ No newline at end of file diff --git a/ci/run_bazel_test_gpu.sh b/ci/run_bazel_test_gpu.sh new file mode 100644 index 000000000000..e94d02599fe1 --- /dev/null +++ b/ci/run_bazel_test_gpu.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2024 JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +source "ci/utilities/setup.sh" + +jaxrun nvidia-smi + +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +if [[ $JAXCI_RUN_BAZEL_GPU_TEST_LOCAL == 1 ]]; then + echo "Running local GPU tests..." + + jaxrun "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + + # Only Linux x86 builds run these for now. + # Runs non-multiaccelerator tests with one GPU apiece. + # It appears --run_under needs an absolute path. + jaxrun bazel --bazelrc=ci/.bazelrc test --config=ci_${os}_${arch}_cuda \ + --config=non_multiaccelerator_local \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --run_under "$JAXCI_CONTAINER_WORK_DIR/build/parallel_accelerator_execute.sh" \ + //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests + + # Runs multiaccelerator tests with all GPUs. + jaxrun bazel --bazelrc=ci/.bazelrc test --config=ci_${os}_${arch}_cuda \ + --config=multiaccelerator_local \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + //tests:gpu_tests //tests/pallas:gpu_tests +fi + +if [[ $JAXCI_RUN_BAZEL_GPU_TEST_RBE == 1 ]]; then + echo "Running RBE GPU tests..." + # RBE GPU tests. Only Linux x86 builds run these for now. + # Runs non-multiaccelerator tests with one GPU apiece. + jaxrun bazel --bazelrc=ci/.bazelrc test --config=rbe_${os}_${arch}_cuda \ + --config=non_multiaccelerator \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests //docs/... +fi \ No newline at end of file diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh new file mode 100755 index 000000000000..4506f81a058f --- /dev/null +++ b/ci/utilities/run_auditwheel.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# Copyright 2024 JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Runs auditwheel to ensure manylinux compatibility. + +# Get a list of all the wheels in the output directory. Only look for wheels +# that need to be verified for manylinux compliance. +WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) + +for wheel in $WHEELS; do + printf "\nRunning auditwheel on the following wheel:" + ls $wheel + OUTPUT_FULL=$(python3 -m auditwheel show $wheel) + # Remove the wheel name from the output to avoid false positives. 
+ wheel_name=$(basename $wheel) + OUTPUT=${OUTPUT_FULL//${wheel_name}/} + + # If a wheel is manylinux2014 compliant, `auditwheel show` will return the + # platform tag as manylinux_2_17. manylinux2014 is an alias for + # manylinux_2_17. + if echo "$OUTPUT" | grep -q "manylinux_2_17"; then + printf "\nThe wheel is manylinux2014 compliant.\n" + else + echo "$OUTPUT_FULL" + printf "\nThe wheel is NOT manylinux2014 compliant.\n" + exit 1 + fi +done \ No newline at end of file diff --git a/ci/utilities/setup.sh b/ci/utilities/setup.sh new file mode 100644 index 000000000000..19c83d0a5d7b --- /dev/null +++ b/ci/utilities/setup.sh @@ -0,0 +1,94 @@ +#!/bin/bash +# Copyright 2024 JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Common setup for all JAX builds. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o pipefail: entire command fails if pipe fails. watch out for yes | ... +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -euo pipefail -o history -o allexport + +if [[ -z "${ENV_FILE+dummy}" ]]; then + echo "Setup script requires an ENV_FILE to be set." + echo "If you are looking to build JAX artifacts, please set ENV_FILE to an" + echo "env file in the ci/envs/build_artifacts directory." + echo "If you are looking to run JAX tests, please set ENV_FILE to an" + echo "env file in the ci/envs/run_tests directory." + exit 1 +fi +set -x +source "$ENV_FILE" + +# Pre-emptively mark the git directory as safe. This is necessary for JAX CI +# jobs running on GitHub Actions. Without this, git complains that the directory +# has dubious ownership and refuses to run any commands. +git config --global --add safe.directory $JAXCI_GIT_DIR + +# Decide whether to use the release tag. JAX CI jobs build from the main +# branch by default. +if [[ -n "$JAXCI_RELEASE_TAG" ]]; then + git checkout tags/"$JAXCI_RELEASE_TAG" +fi + +# Setup jaxrun, a helper function for executing steps that can either be run +# locally or run under Docker. setup_docker.sh, below, redefines it as "docker +# exec". +# Important: "jaxrun foo | bar" is "( jaxrun foo ) | bar", not "jaxrun (foo | bar)". +# Therefore, "jaxrun" commands cannot include pipes -- which is +# probably for the better. If a pipe is necessary for something, it is probably +# complex. Write a well-documented script under utilities/ to encapsulate the +# functionality instead. +jaxrun() { "$@"; } + +# When running tests, we need to check out XLA at HEAD. +if [[ -n ${JAXCI_XLA_GIT_DIR} ]] && [[ "$JAXCI_RUN_TESTS" == 1 ]]; then + if [[ ! -d $(pwd)/xla ]]; then + rm -rf $(pwd)/xla + echo "Checking out XLA..." 
+ jaxrun git clone --depth=1 https://github.com/openxla/xla.git $(pwd)/xla + echo "Using XLA from $(pwd)/xla" + fi +fi + +if [[ -n ${JAXCI_XLA_GIT_DIR} ]]; then + echo "Using XLA from $JAXCI_XLA_GIT_DIR" +fi + +if [[ -n "$JAXCI_XLA_COMMIT" ]]; then + jaxrun pushd "$JAXCI_XLA_GIT_DIR" + + jaxrun git fetch --depth=1 origin "$JAXCI_XLA_COMMIT" + jaxrun git checkout "$JAXCI_XLA_COMMIT" + jaxrun echo "XLA git hash: $(git rev-parse HEAD)" + + jaxrun popd +fi + +# All builds except for Mac run under Docker. +# GitHub actions do not need to invoke this script. It always runs in a Docker +# container. The image and the runner type are set in the workflow file. +if [[ "$JAXCI_USE_DOCKER" == 1 ]]; then + source ./ci/utilities/setup_docker.sh +fi + +if [[ "$JAXCI_RUN_TESTS" == 1 ]]; then + source ./ci/utilities/setup_test_environment.sh +fi + +# TODO: cleanup steps \ No newline at end of file diff --git a/ci/utilities/setup_docker.sh b/ci/utilities/setup_docker.sh new file mode 100644 index 000000000000..4e94121446b1 --- /dev/null +++ b/ci/utilities/setup_docker.sh @@ -0,0 +1,70 @@ +#!/bin/bash +# Copyright 2024 JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Set up Docker for JAX CI jobs. + +# Keep the existing "jax" container if it's already present. +if ! docker container inspect jax >/dev/null 2>&1 ; then + # Simple retry logic for docker-pull errors. Sleeps if a pull fails. + # Pulling an already-pulled container image will finish instantly, so + # repeating the command costs nothing. + docker pull "$JAXCI_DOCKER_IMAGE" || sleep 15 + docker pull "$JAXCI_DOCKER_IMAGE" + + if [[ "$(uname -s)" =~ "MSYS_NT" ]]; then + # Docker on Windows doesn't support the `host` networking mode, and so + # port-forwarding is required for the container to detect it's running on GCE. + export IP_ADDR=$(powershell -command "(Get-NetIPAddress -AddressFamily IPv4 -InterfaceAlias 'vEthernet (nat)').IPAddress") + netsh interface portproxy add v4tov4 listenaddress=$IP_ADDR listenport=80 connectaddress=169.254.169.254 connectport=80 + JAXCI_DOCKER_ARGS="$JAXCI_DOCKER_ARGS -e GCE_METADATA_HOST=$IP_ADDR" + else + # The volume mapping flag below shares the user's gcloud credentials, if any, + # with the container, in case the user has credentials stored there. + # This would allow Bazel to authenticate for RBE. + # Note: JAX's CI does not have any credentials stored there. + JAXCI_DOCKER_ARGS="$JAXCI_DOCKER_ARGS -v $HOME/.config/gcloud:/root/.config/gcloud" + fi + + # If XLA repository on the local system is to be used, map it to the container + # and set the JAXCI_XLA_GIT_DIR environment variable to the container path. 
+  if [[ -n $JAXCI_XLA_GIT_DIR ]]; then
+    JAXCI_DOCKER_ARGS="$JAXCI_DOCKER_ARGS -v $JAXCI_XLA_GIT_DIR:$JAXCI_CONTAINER_WORK_DIR/xla -e JAXCI_XLA_GIT_DIR=$JAXCI_CONTAINER_WORK_DIR/xla"
+    # Also update `JAXCI_XLA_GIT_DIR` to the container path in the host shell
+    # environment, because commands run with `docker exec` are issued from the
+    # host shell environment. See `run_bazel_test_gpu.sh` for where this is
+    # needed.
+    export JAXCI_XLA_GIT_DIR=$JAXCI_CONTAINER_WORK_DIR/xla
+  fi
+
+  # When running `bazel test` and specifying dependencies on local wheels,
+  # Bazel will look for them in the ../dist directory by default. This can be
+  # overridden by setting `local_wheel_dist_folder`.
+  docker run --env-file <(env | grep ^JAXCI_ ) $JAXCI_DOCKER_ARGS --name jax \
+    -w $JAXCI_CONTAINER_WORK_DIR -itd --rm \
+    -v "$JAXCI_GIT_DIR:$JAXCI_CONTAINER_WORK_DIR" \
+    -e local_wheel_dist_folder=$JAXCI_OUTPUT_DIR \
+    "$JAXCI_DOCKER_IMAGE" \
+    bash
+
+  if [[ "$(uname -s)" =~ "MSYS_NT" ]]; then
+    # Allow requests from the container.
+    CONTAINER_IP_ADDR=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' jax)
+    netsh advfirewall firewall add rule name="Allow Metadata Proxy" dir=in action=allow protocol=TCP localport=80 remoteip="$CONTAINER_IP_ADDR"
+  fi
+fi
+jaxrun() { docker exec jax "$@"; }
+
+jaxrun git config --global --add safe.directory $JAXCI_CONTAINER_WORK_DIR
\ No newline at end of file
diff --git a/ci/utilities/setup_test_environment.sh b/ci/utilities/setup_test_environment.sh
new file mode 100644
index 000000000000..1987c7fa94f7
--- /dev/null
+++ b/ci/utilities/setup_test_environment.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+# Copyright 2024 JAX Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Set up the environment for JAX tests.
+
+if [[ $JAXCI_RUN_BAZEL_GPU_TEST_LOCAL == 1 ]]; then
+  # Install the `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` wheels.
+  jaxrun bash -c "$JAXCI_PYTHON -m pip install $JAXCI_OUTPUT_DIR/*.whl"
+
+  # Install the JAX package at the current commit.
+  # TODO(srnitin): Check if this is needed when running Bazel tests.
+  jaxrun "$JAXCI_PYTHON" -m pip install -U -e .
+fi
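Putting the pieces together, a local end-to-end run of the GPU test flow added here could look like the sketch below. The sequence is inferred from the scripts above (build the wheels that setup_test_environment.sh installs, then run the local GPU tests) and is not itself part of this change:

  # Build the wheels that setup_test_environment.sh later installs.
  ENV_FILE=ci/envs/build_artifacts/jaxlib ./ci/build_artifacts.sh
  ENV_FILE=ci/envs/build_artifacts/jax-cuda-plugin ./ci/build_artifacts.sh
  ENV_FILE=ci/envs/build_artifacts/jax-cuda-pjrt ./ci/build_artifacts.sh

  # Then run the local GPU Bazel tests against those wheels.
  ENV_FILE=ci/envs/run_tests/bazel_local_gpu ./ci/run_bazel_test_gpu.sh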