Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
97b17af
[typing] add type annotations to the first several lax_numpy functions
jakevdp Oct 21, 2022
4714a5c
Add regression test for #12920
jakevdp Oct 21, 2022
ca7d05f
[typing] fix incorrect type annotation on lax.argmax/argmin
jakevdp Oct 21, 2022
4e8fbd0
Add delete method to GlobalDeviceArray and ShardedBuffer.
zhangqiaorjc Oct 21, 2022
e219d55
Roll-back #12892 because CUSPARSE_SPMV_COO_ALG2 is not available in C…
tlu7 Oct 21, 2022
4acc293
Merge pull request #12923 from jakevdp:nperseg-test
a-googler Oct 21, 2022
64e996e
Merge pull request #12925 from jakevdp:annotate-lax
a-googler Oct 21, 2022
a4e3663
Merge pull request #12921 from jakevdp:lax-numpy-dtypes
a-googler Oct 21, 2022
9956ad2
Add more pjit tests and make some tests go via actual computations ra…
yashk2810 Oct 21, 2022
3be5ab2
Allow calling `initialize_cache` a second time if the path is the same.
a-googler Oct 22, 2022
b07c586
[mhlo] Use 11 out of 12 new shared type inferences from StableHLO.
Oct 22, 2022
5784d61
implement truncnorm in jax.scipy.stats
adrn Oct 3, 2022
67fa7c2
Typo fix.
a-googler Oct 24, 2022
48e680c
CI: avoid raising error when wrapped function is None
jakevdp Oct 24, 2022
894093c
Move jaxlib cpu kernels under jaxlib/cpu/.
hawkinsp Oct 24, 2022
8f2f9f4
Merge pull request #12646 from adrn:truncnorm
a-googler Oct 24, 2022
b892108
Merge pull request #12950 from jakevdp:fix-ci-error
a-googler Oct 24, 2022
28def73
Fix typo in 9419-jax-versioning.md
eltociear Oct 24, 2022
964988c
Merge pull request #12953 from eltociear:patch-1
a-googler Oct 24, 2022
9ade89e
jnp.linalg.lstsq: handle zero-size inputs
jakevdp Oct 24, 2022
56d42c0
[typing] annotate next batch of lax_numpy
jakevdp Oct 24, 2022
15b415b
Merge pull request #12951 from jakevdp:annotate-lax-numpy
a-googler Oct 24, 2022
70f659a
Merge pull request #12957 from jakevdp:fix-lstsq
a-googler Oct 24, 2022
21d02ac
self_hosted_runner_utils
skye Oct 21, 2022
be12bfd
cloud-tpu-ci-nightly.yml
skye Oct 21, 2022
cd02f62
add pull_request event
skye Oct 24, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
name: Cloud-TPU-CI-Nightly

on:
schedule:
- cron: "0 12 * * *" # Daily at 12:00 UTC
workflow_dispatch: # allows triggering the workflow run manually
repository_dispatch: # allows triggering the workflow via HTTP
pull_request:

jobs:
cloud-tpu-test:
runs-on: tpu
defaults:
run:
shell: bash -l {0}
strategy:
fail-fast: false
matrix:
python-version: ["3.10"] # TODO(jakevdp): update to 3.11 when available.
outputs:
artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }}
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install JAX test requirements
run: |
pip install -r build/test-requirements.txt
pip install pytest-reportlog
- name: Install JAX
run: |
pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
- name: Run tests
if: success()
id: status
env:
JAX_PLATFORMS: tpu,cpu
run: |
pytest --tb=short \
--report-log output-${{ matrix.python-version }}-log.jsonl \
tests/compilation_cache_test.py \
|| (
echo '::set-output name=ARTIFACTS_AVAILABLE::true' && false
)
# run: |
# pytest --tb=short \
# --deselect tests/callback_test.py \
# --deselect tests/checkify_test.py \
# --deselect tests/debugger_test.py \
# --deselect tests/debugging_primitives_test.py \
# --deselect tests/jaxpr_effects_test.py-rf \
# --report-log output-${{ matrix.python-version }}-log.jsonl \
# tests \
# || (
# echo '::set-output name=ARTIFACTS_AVAILABLE::true' && false
# )
- name: Upload artifacts
# if: |
# failure()
# && steps.status.outcome == 'failure'
# && github.event_name == 'schedule'
# && github.repository == 'google/jax'
if: failure()
uses: actions/upload-artifact@v3
with:
name: output-${{ matrix.python-version }}-log.jsonl
path: output-${{ matrix.python-version }}-log.jsonl
retention-days: 5
1 change: 1 addition & 0 deletions .github/workflows/self_hosted_runner_utils/runner.env
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ACTIONS_RUNNER_HOOK_JOB_STARTED=~/jax/.github/workflows/self_hosted_runner_utils/validate_job.sh
20 changes: 20 additions & 0 deletions .github/workflows/self_hosted_runner_utils/start_github_runner.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/bin/bash

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# More or less copied from
# https://github.com/iree-org/iree/tree/main/build_tools/github_actions/runner/config

~/actions-runner/run.sh > /tmp/actions-runner.`date +"%Y%m%d-%H%M"`.log
47 changes: 47 additions & 0 deletions .github/workflows/self_hosted_runner_utils/validate_job.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/bin/bash

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# More or less copied from
# https://github.com/iree-org/iree/tree/main/build_tools/github_actions/runner/config

set -euo pipefail

ALLOWED_EVENTS=(
"schedule"
"workflow_dispatch"
)

# Tests if the first argument is contained in the array in the second argument.
# Usage `is_contained "element" "${array[@]}"`
is_contained() {
local e;
local match="$1"
shift
for e in "$@"; do
if [[ "${e}" == "${match}" ]]; then
return 0
fi
done
return 1
}

if ! is_contained "${GITHUB_EVENT_NAME}" "${ALLOWED_EVENTS[@]}"; then
echo "Event type '${GITHUB_EVENT_NAME}' is not allowed on this runner. Aborting workflow."
# clean up any nefarious stuff we may have fetched in job setup.
cd ~/actions-runner/_work
rm -rfv _actions/ _temp/
exit 1
fi
6 changes: 4 additions & 2 deletions build/build_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,7 @@ def prepare_wheel(sources_path):
copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py")
copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}")
copy_to_jaxlib("__main__/jaxlib/lapack.py")
copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}")
copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
copy_to_jaxlib(f"__main__/jaxlib/_ducc_fft.{pyext}")
copy_to_jaxlib("__main__/jaxlib/ducc_fft.py")
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
Expand All @@ -180,6 +178,10 @@ def prepare_wheel(sources_path):
copy_to_jaxlib("__main__/jaxlib/version.py")
copy_to_jaxlib("__main__/jaxlib/xla_client.py")
copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}")
cpu_dir = os.path.join(jaxlib_dir, "cpu")
os.makedirs(cpu_dir)
copy_file(f"__main__/jaxlib/cpu/_lapack.{pyext}", dst_dir=cpu_dir)
copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir)

cuda_dir = os.path.join(jaxlib_dir, "cuda")
if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"):
Expand Down
13 changes: 13 additions & 0 deletions docs/jax.scipy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,19 @@ jax.scipy.stats.t
logpdf
pdf

jax.scipy.stats.truncnorm
~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.truncnorm
.. autosummary::
:toctree: _autosummary

cdf
logcdf
logpdf
logsf
pdf
sf

jax.scipy.stats.uniform
~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.uniform
Expand Down
2 changes: 1 addition & 1 deletion docs/jep/9419-jax-versioning.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ level.
[as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE).
To update the version of XLA used during the build, one must update the pinned
version in the Bazel `WORKSPACE`. This is done manually on an
as-needed basis, but can be overriden on a build-by-build basis.
as-needed basis, but can be overridden on a build-by-build basis.


## How do we make changes across the `jax` and `jaxlib` boundary between releases?
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,13 +960,13 @@ def transpose(operand: ArrayLike, permutation: Sequence[int]) -> Array:
return transpose_p.bind(operand, permutation=permutation)

def argmin(operand: ArrayLike, axis: int,
index_dtype: DTypeLike) -> Tuple[Array, Array]:
index_dtype: DTypeLike) -> Array:
"""Computes the index of the minimum element along ``axis``."""
return argmin_p.bind(operand, axes=(axis,),
index_dtype=dtypes.canonicalize_dtype(index_dtype))

def argmax(operand: ArrayLike, axis: int,
index_dtype: DTypeLike) -> Tuple[Array, Array]:
index_dtype: DTypeLike) -> Array:
"""Computes the index of the maximum element along ``axis``."""
return argmax_p.bind(operand, axes=(axis,),
index_dtype=dtypes.canonicalize_dtype(index_dtype))
Expand Down
29 changes: 21 additions & 8 deletions jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
from jax._src.lib import lapack
from jax._src.lib import mlir_api_version

from jax._src.lib import gpu_linalg
from jax._src.lib import gpu_solver
Expand Down Expand Up @@ -873,10 +874,16 @@ def _triangular_solve_lowering(
transpose = "NO_TRANSPOSE"
else:
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
return mhlo.TriangularSolveOp(
mlir.aval_to_ir_type(out_aval), a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
mhlo.TransposeAttr.get(transpose)).results
if mlir_api_version < 36:
return mhlo.TriangularSolveOp(
mlir.aval_to_ir_type(out_aval), a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
mhlo.TransposeAttr.get(transpose)).results
else:
return mhlo.TriangularSolveOp(
a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
mhlo.TransposeAttr.get(transpose)).results

mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)

Expand All @@ -900,10 +907,16 @@ def _triangular_solve_cpu_lower(
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
else:
transpose = "NO_TRANSPOSE"
return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower),
ir.BoolAttr.get(unit_diagonal),
mhlo.TransposeAttr.get(transpose)).results
if mlir_api_version < 36:
return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower),
ir.BoolAttr.get(unit_diagonal),
mhlo.TransposeAttr.get(transpose)).results
else:
return mhlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower),
ir.BoolAttr.get(unit_diagonal),
mhlo.TransposeAttr.get(transpose)).results

mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
platform='cpu')
Expand Down
Loading