diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml new file mode 100644 index 000000000000..25e8918d4a70 --- /dev/null +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -0,0 +1,70 @@ +name: Cloud-TPU-CI-Nightly + +on: + schedule: + - cron: "0 12 * * *" # Daily at 12:00 UTC + workflow_dispatch: # allows triggering the workflow run manually + repository_dispatch: # allows triggering the workflow via HTTP + pull_request: + +jobs: + cloud-tpu-test: + runs-on: tpu + defaults: + run: + shell: bash -l {0} + strategy: + fail-fast: false + matrix: + python-version: ["3.10"] # TODO(jakevdp): update to 3.11 when available. + outputs: + artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }} + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install JAX test requirements + run: | + pip install -r build/test-requirements.txt + pip install pytest-reportlog + - name: Install JAX + run: | + pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + - name: Run tests + if: success() + id: status + env: + JAX_PLATFORMS: tpu,cpu + run: | + pytest --tb=short \ + --report-log output-${{ matrix.python-version }}-log.jsonl \ + tests/compilation_cache_test.py \ + || ( + echo '::set-output name=ARTIFACTS_AVAILABLE::true' && false + ) + # run: | + # pytest --tb=short \ + # --deselect tests/callback_test.py \ + # --deselect tests/checkify_test.py \ + # --deselect tests/debugger_test.py \ + # --deselect tests/debugging_primitives_test.py \ + # --deselect tests/jaxpr_effects_test.py-rf \ + # --report-log output-${{ matrix.python-version }}-log.jsonl \ + # tests \ + # || ( + # echo '::set-output name=ARTIFACTS_AVAILABLE::true' && false + # ) + - name: Upload artifacts + # if: | + # failure() + # && steps.status.outcome == 'failure' + # && github.event_name == 'schedule' + # && github.repository == 'google/jax' + if: failure() + uses: actions/upload-artifact@v3 + with: + name: output-${{ matrix.python-version }}-log.jsonl + path: output-${{ matrix.python-version }}-log.jsonl + retention-days: 5 diff --git a/.github/workflows/self_hosted_runner_utils/runner.env b/.github/workflows/self_hosted_runner_utils/runner.env new file mode 100644 index 000000000000..741fd558c578 --- /dev/null +++ b/.github/workflows/self_hosted_runner_utils/runner.env @@ -0,0 +1 @@ +ACTIONS_RUNNER_HOOK_JOB_STARTED=~/jax/.github/workflows/self_hosted_runner_utils/validate_job.sh diff --git a/.github/workflows/self_hosted_runner_utils/start_github_runner.sh b/.github/workflows/self_hosted_runner_utils/start_github_runner.sh new file mode 100755 index 000000000000..7f2d00109a1d --- /dev/null +++ b/.github/workflows/self_hosted_runner_utils/start_github_runner.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# More or less copied from +# https://github.com/iree-org/iree/tree/main/build_tools/github_actions/runner/config + +~/actions-runner/run.sh > /tmp/actions-runner.`date +"%Y%m%d-%H%M"`.log diff --git a/.github/workflows/self_hosted_runner_utils/validate_job.sh b/.github/workflows/self_hosted_runner_utils/validate_job.sh new file mode 100644 index 000000000000..7f47f4b638cb --- /dev/null +++ b/.github/workflows/self_hosted_runner_utils/validate_job.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# More or less copied from +# https://github.com/iree-org/iree/tree/main/build_tools/github_actions/runner/config + +set -euo pipefail + +ALLOWED_EVENTS=( + "schedule" + "workflow_dispatch" +) + +# Tests if the first argument is contained in the array in the second argument. +# Usage `is_contained "element" "${array[@]}"` +is_contained() { + local e; + local match="$1" + shift + for e in "$@"; do + if [[ "${e}" == "${match}" ]]; then + return 0 + fi + done + return 1 +} + +if ! is_contained "${GITHUB_EVENT_NAME}" "${ALLOWED_EVENTS[@]}"; then + echo "Event type '${GITHUB_EVENT_NAME}' is not allowed on this runner. Aborting workflow." + # clean up any nefarious stuff we may have fetched in job setup. + cd ~/actions-runner/_work + rm -rfv _actions/ _temp/ + exit 1 +fi diff --git a/build/build_wheel.py b/build/build_wheel.py index f10efab99997..b6a290455492 100644 --- a/build/build_wheel.py +++ b/build/build_wheel.py @@ -169,9 +169,7 @@ def prepare_wheel(sources_path): copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py") copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}") copy_to_jaxlib("__main__/jaxlib/lapack.py") - copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}") copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py") - copy_to_jaxlib(f"__main__/jaxlib/_ducc_fft.{pyext}") copy_to_jaxlib("__main__/jaxlib/ducc_fft.py") copy_to_jaxlib("__main__/jaxlib/gpu_prng.py") copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py") @@ -180,6 +178,10 @@ def prepare_wheel(sources_path): copy_to_jaxlib("__main__/jaxlib/version.py") copy_to_jaxlib("__main__/jaxlib/xla_client.py") copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}") + cpu_dir = os.path.join(jaxlib_dir, "cpu") + os.makedirs(cpu_dir) + copy_file(f"__main__/jaxlib/cpu/_lapack.{pyext}", dst_dir=cpu_dir) + copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir) cuda_dir = os.path.join(jaxlib_dir, "cuda") if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"): diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index ac92d8e82fab..cce695524061 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -318,6 +318,19 @@ jax.scipy.stats.t logpdf pdf +jax.scipy.stats.truncnorm +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: jax.scipy.stats.truncnorm +.. autosummary:: + :toctree: _autosummary + + cdf + logcdf + logpdf + logsf + pdf + sf + jax.scipy.stats.uniform ~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: jax.scipy.stats.uniform diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md index 9e363b86a3e5..598c39f5155d 100644 --- a/docs/jep/9419-jax-versioning.md +++ b/docs/jep/9419-jax-versioning.md @@ -149,7 +149,7 @@ level. [as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE). To update the version of XLA used during the build, one must update the pinned version in the Bazel `WORKSPACE`. This is done manually on an -as-needed basis, but can be overriden on a build-by-build basis. +as-needed basis, but can be overridden on a build-by-build basis. ## How do we make changes across the `jax` and `jaxlib` boundary between releases? diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a51602ac9c20..cb3120082496 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -960,13 +960,13 @@ def transpose(operand: ArrayLike, permutation: Sequence[int]) -> Array: return transpose_p.bind(operand, permutation=permutation) def argmin(operand: ArrayLike, axis: int, - index_dtype: DTypeLike) -> Tuple[Array, Array]: + index_dtype: DTypeLike) -> Array: """Computes the index of the minimum element along ``axis``.""" return argmin_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) def argmax(operand: ArrayLike, axis: int, - index_dtype: DTypeLike) -> Tuple[Array, Array]: + index_dtype: DTypeLike) -> Array: """Computes the index of the maximum element along ``axis``.""" return argmax_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index d70c80be9a85..a0c77bec2957 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -41,6 +41,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import svd as lax_svd from jax._src.lib import lapack +from jax._src.lib import mlir_api_version from jax._src.lib import gpu_linalg from jax._src.lib import gpu_solver @@ -873,10 +874,16 @@ def _triangular_solve_lowering( transpose = "NO_TRANSPOSE" else: transpose = "ADJOINT" if conjugate_a else "TRANSPOSE" - return mhlo.TriangularSolveOp( - mlir.aval_to_ir_type(out_aval), a, b, ir.BoolAttr.get(left_side), - ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), - mhlo.TransposeAttr.get(transpose)).results + if mlir_api_version < 36: + return mhlo.TriangularSolveOp( + mlir.aval_to_ir_type(out_aval), a, b, ir.BoolAttr.get(left_side), + ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), + mhlo.TransposeAttr.get(transpose)).results + else: + return mhlo.TriangularSolveOp( + a, b, ir.BoolAttr.get(left_side), + ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), + mhlo.TransposeAttr.get(transpose)).results mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering) @@ -900,10 +907,16 @@ def _triangular_solve_cpu_lower( transpose = "ADJOINT" if conjugate_a else "TRANSPOSE" else: transpose = "NO_TRANSPOSE" - return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side), - ir.BoolAttr.get(lower), - ir.BoolAttr.get(unit_diagonal), - mhlo.TransposeAttr.get(transpose)).results + if mlir_api_version < 36: + return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side), + ir.BoolAttr.get(lower), + ir.BoolAttr.get(unit_diagonal), + mhlo.TransposeAttr.get(transpose)).results + else: + return mhlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side), + ir.BoolAttr.get(lower), + ir.BoolAttr.get(unit_diagonal), + mhlo.TransposeAttr.get(transpose)).results mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower, platform='cpu') diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3a673c549d60..1586f562eb9f 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -26,10 +26,11 @@ import builtins import collections -from functools import partial, wraps as functools_wraps +from functools import partial import operator import types -from typing import overload, Any, Callable, Sequence, FrozenSet, Optional, Tuple, Union +from typing import ( + overload, Any, Callable, Dict, Sequence, FrozenSet, List, Optional, Tuple, Union) from textwrap import dedent as _dedent import warnings @@ -50,7 +51,7 @@ from jax._src import dtypes from jax._src.api_util import _ensure_index_tuple from jax._src.lax.lax import (_array_copy, _sort_lt_comparator, - _sort_le_comparator) + _sort_le_comparator, PrecisionLike) from jax._src.lax import lax as lax_internal from jax._src.numpy.ndarray import ndarray from jax._src.numpy.reductions import ( # noqa: F401 @@ -77,7 +78,7 @@ _register_stackable, _stackable, _where, _wraps) from jax._src.numpy.vectorize import vectorize from jax._src.ops import scatter -from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape +from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio, partition_list, canonicalize_axis as _canonicalize_axis) @@ -135,7 +136,7 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: set_printoptions = np.set_printoptions @_wraps(np.iscomplexobj) -def iscomplexobj(x): +def iscomplexobj(x: Any) -> bool: try: typ = x.dtype.type except AttributeError: @@ -145,7 +146,9 @@ def iscomplexobj(x): shape = _shape = np.shape ndim = _ndim = np.ndim size = np.size -_dtype = partial(dtypes.dtype, canonicalize=True) + +def _dtype(x: Any) -> DType: + return dtypes.dtype(x, canonicalize=True) # At present JAX doesn't have a reason to distinguish between scalars and arrays # in its object system. Further, we want JAX scalars to have the same type @@ -154,22 +157,22 @@ def iscomplexobj(x): # types return JAX arrays when instantiated. class _ScalarMeta(type): - def __hash__(self): + def __hash__(self) -> int: return hash(self.dtype.type) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return id(self) == id(other) or self.dtype.type == other - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not (self == other) - def __call__(self, x): + def __call__(self, x: Any) -> Array: return asarray(x, dtype=self.dtype) - def __instancecheck__(self, instance): + def __instancecheck__(self, instance: Any) -> bool: return isinstance(instance, self.dtype.type) -def _make_scalar_type(np_scalar_type): +def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: meta = _ScalarMeta(np_scalar_type.__name__, (object,), {"dtype": np.dtype(np_scalar_type)}) meta.__module__ = _PUBLIC_MODULE_NAME @@ -226,7 +229,8 @@ def _make_scalar_type(np_scalar_type): savez = np.savez @_wraps(np.dtype) -def _jnp_dtype(obj, align=False, copy=False): +def _jnp_dtype(obj: Optional[DTypeLike], *, align: bool = False, + copy: bool = False) -> DType: """Similar to np.dtype, but respects JAX dtype defaults.""" if obj is None: obj = dtypes.float_ @@ -236,7 +240,7 @@ def _jnp_dtype(obj, align=False, copy=False): ### utility functions -_DEFAULT_TYPEMAP = { +_DEFAULT_TYPEMAP: Dict[type, _ScalarMeta] = { np.bool_: bool_, np.int_: int_, np.float_: float_, @@ -246,13 +250,13 @@ def _jnp_dtype(obj, align=False, copy=False): _lax_const = lax_internal._const -def _result_dtype(op, *args): +def _result_dtype(op: Callable[..., ArrayLike], *args: Any) -> DType: """Compute result dtype of applying op to arguments with given dtypes.""" - args = [np.ones((0,) * ndim(arg), _dtype(arg)) for arg in args] - return _dtype(op(*args)) + np_args = [np.ones((0,) * ndim(arg), _dtype(arg)) for arg in args] + return _dtype(op(*np_args)) -def _convert_and_clip_integer(val, dtype): +def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: """ Convert integer-typed val to specified integer dtype, clipping to dtype range rather than wrapping. @@ -294,7 +298,7 @@ def _convert_and_clip_integer(val, dtype): @_wraps(np.load, update_doc=False) -def load(*args, **kwargs): +def load(*args: Any, **kwargs: Any) -> Array: # The main purpose of this wrapper is to recover bfloat16 data types. # Note: this will only work for files created via np.save(), not np.savez(). out = np.load(*args, **kwargs) @@ -312,20 +316,20 @@ def load(*args, **kwargs): @_wraps(np.fmin, module='numpy') @jit -def fmin(x1, x2): - return where((x1 < x2) | isnan(x2), x1, x2) +def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: + return where(less(x1, x2) | isnan(x2), x1, x2) @_wraps(np.fmax, module='numpy') @jit -def fmax(x1, x2): - return where((x1 > x2) | isnan(x2), x1, x2) +def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: + return where(greater(x1, x2) | isnan(x2), x1, x2) @_wraps(np.issubdtype) -def issubdtype(arg1, arg2): +def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: return dtypes.issubdtype(arg1, arg2) @_wraps(np.isscalar) -def isscalar(element): +def isscalar(element: Any) -> bool: if hasattr(element, '__jax_array__'): element = element.__jax_array__() return dtypes.is_python_scalar(element) or np.isscalar(element) @@ -333,36 +337,36 @@ def isscalar(element): iterable = np.iterable @_wraps(np.result_type) -def result_type(*args): +def result_type(*args: ArrayLike) -> DType: return dtypes.result_type(*args) @_wraps(np.trapz) @partial(jit, static_argnames=('axis',)) -def trapz(y, x=None, dx=1.0, axis: int = -1): +def trapz(y: ArrayLike, x: Optional[ArrayLike] = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: if x is None: _check_arraylike('trapz', y) - y, = _promote_dtypes_inexact(y) + y_arr, = _promote_dtypes_inexact(y) else: _check_arraylike('trapz', y, x) - y, x = _promote_dtypes_inexact(y, x) - if ndim(x) == 1: - dx = diff(x) + y_arr, x_arr = _promote_dtypes_inexact(y, x) + if x_arr.ndim == 1: + dx = diff(x_arr) else: - dx = moveaxis(diff(x, axis=axis), axis, -1) - y = moveaxis(y, axis, -1) - return 0.5 * (dx * (y[..., 1:] + y[..., :-1])).sum(-1) + dx = moveaxis(diff(x_arr, axis=axis), axis, -1) + y_arr = moveaxis(y_arr, axis, -1) + return 0.5 * (dx * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) @_wraps(np.trunc, module='numpy') @jit -def trunc(x): +def trunc(x: ArrayLike) -> Array: _check_arraylike('trunc', x) return where(lax.lt(x, _lax_const(x, 0)), ceil(x), floor(x)) @partial(jit, static_argnums=(2, 3, 4)) -def _conv(x, y, mode, op, precision): +def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike) -> Array: if ndim(x) != 1 or ndim(y) != 1: raise ValueError(f"{op}() only support 1-dimensional inputs.") x, y = _promote_dtypes_inexact(x, y) @@ -396,49 +400,57 @@ def _conv(x, y, mode, op, precision): @_wraps(np.convolve, lax_description=_PRECISION_DOC) @partial(jit, static_argnames=('mode', 'precision')) -def convolve(a, v, mode='full', *, precision=None): +def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, + precision: PrecisionLike = None) -> Array: _check_arraylike("convolve", a, v) - return _conv(a, v, mode, 'convolve', precision) + return _conv(asarray(a), asarray(v), mode, 'convolve', precision) @_wraps(np.correlate, lax_description=_PRECISION_DOC) @partial(jit, static_argnames=('mode', 'precision')) -def correlate(a, v, mode='valid', *, precision=None): +def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, + precision: PrecisionLike = None) -> Array: _check_arraylike("correlate", a, v) - return _conv(a, v, mode, 'correlate', precision) + return _conv(asarray(a), asarray(v), mode, 'correlate', precision) @_wraps(np.histogram_bin_edges) -def histogram_bin_edges(a, bins=10, range=None, weights=None): +def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, + range: Union[None, Array, Sequence[ArrayLike]] = None, + weights: Optional[ArrayLike] = None) -> Array: del weights # unused, because string bins is not supported. if isinstance(bins, str): raise NotImplementedError("string values for `bins` not implemented.") _check_arraylike("histogram_bin_edges", a, bins) - a = ravel(a) - dtype = dtypes.to_inexact_dtype(_dtype(a)) + arr = ravel(a) + dtype = dtypes.to_inexact_dtype(arr.dtype) if _ndim(bins) == 1: return asarray(bins, dtype=dtype) bins = core.concrete_or_error(operator.index, bins, "bins argument of histogram_bin_edges") if range is None: - range = [a.min(), a.max()] + range = [arr.min(), arr.max()] range = asarray(range, dtype=dtype) - if range.shape != (2,): - raise ValueError("`range` must be either None or a sequence of scalars.") + if shape(range) != (2,): + raise ValueError(f"`range` must be either None or a sequence of scalars, got {range}") range = (where(ptp(range) == 0, range[0] - 0.5, range[0]), where(ptp(range) == 0, range[1] + 0.5, range[1])) + assert range is not None return linspace(range[0], range[1], bins + 1, dtype=dtype) @_wraps(np.histogram) -def histogram(a, bins=10, range=None, weights=None, density=None): +def histogram(a: ArrayLike, bins: ArrayLike = 10, + range: Optional[Sequence[ArrayLike]] = None, + weights: Optional[ArrayLike] = None, + density: Optional[bool] = None) -> Tuple[Array, Array]: if weights is None: _check_arraylike("histogram", a, bins) a = ravel(*_promote_dtypes_inexact(a)) weights = ones_like(a) else: _check_arraylike("histogram", a, bins, weights) - if a.shape != weights.shape: + if shape(a) != shape(weights): raise ValueError("weights should have the same shape as a.") a, weights = map(ravel, _promote_dtypes_inexact(a, weights)) @@ -452,10 +464,13 @@ def histogram(a, bins=10, range=None, weights=None, density=None): return counts, bin_edges @_wraps(np.histogram2d) -def histogram2d(x, y, bins=10, range=None, weights=None, density=None): +def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10, + range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]]=None, + weights: Optional[ArrayLike] = None, + density: Optional[bool] = None) -> Tuple[Array, Array, Array]: _check_arraylike("histogram2d", x, y) try: - N = len(bins) + N = len(bins) # type: ignore[arg-type] except TypeError: N = 1 @@ -468,47 +483,51 @@ def histogram2d(x, y, bins=10, range=None, weights=None, density=None): return hist, edges[0], edges[1] @_wraps(np.histogramdd) -def histogramdd(sample, bins=10, range=None, weights=None, density=None): +def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10, + range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = None, + weights: Optional[ArrayLike] = None, + density: Optional[bool] = None) -> Tuple[Array, List[Array]]: if weights is None: _check_arraylike("histogramdd", sample) sample, = _promote_dtypes_inexact(sample) else: _check_arraylike("histogramdd", sample, weights) - if weights.shape != sample.shape[:1]: + if shape(weights) != shape(sample)[:1]: raise ValueError("should have one weight for each sample.") sample, weights = _promote_dtypes_inexact(sample, weights) N, D = shape(sample) if range is not None and ( - len(range) != D or _any(r is not None and len(r) != 2 for r in range)): + len(range) != D or _any(r is not None and shape(r)[0] != 2 for r in range)): # type: ignore[arg-type] raise ValueError(f"For sample.shape={(N, D)}, range must be a sequence " f"of {D} pairs or Nones; got range={range}") try: - num_bins = len(bins) - if num_bins != D: - raise ValueError("should be a bin for each dimension.") + num_bins = len(bins) # type: ignore[arg-type] except TypeError: # when bin_size is integer, the same bin is used for each dimension - bins = D * [bins] + bins_per_dimension: List[ArrayLike] = D * [bins] # type: ignore[assignment] + else: + if num_bins != D: + raise ValueError("should be a bin for each dimension.") + bins_per_dimension = list(bins) # type: ignore[arg-type] - bin_idx_by_dim = D * [None] - nbins = np.empty(D, int) - bin_edges_by_dim = D * [None] - dedges = D * [None] + bin_idx_by_dim: List[Array] = [] + bin_edges_by_dim: List[Array] = [] for i in builtins.range(D): range_i = None if range is None else range[i] - bin_edges = histogram_bin_edges(sample[:, i], bins[i], range_i, weights) + bin_edges = histogram_bin_edges(sample[:, i], bins_per_dimension[i], range_i, weights) bin_idx = searchsorted(bin_edges, sample[:, i], side='right') bin_idx = where(sample[:, i] == bin_edges[-1], bin_idx - 1, bin_idx) - bin_idx_by_dim[i] = bin_idx - nbins[i] = len(bin_edges) + 1 - bin_edges_by_dim[i] = bin_edges - dedges[i] = diff(bin_edges_by_dim[i]) + bin_idx_by_dim.append(bin_idx) + bin_edges_by_dim.append(bin_edges) + + nbins = tuple(len(bin_edges) + 1 for bin_edges in bin_edges_by_dim) + dedges = [diff(bin_edges) for bin_edges in bin_edges_by_dim] - xy = ravel_multi_index(bin_idx_by_dim, nbins, mode='clip') - hist = bincount(xy, weights, length=nbins.prod()) + xy = ravel_multi_index(tuple(bin_idx_by_dim), nbins, mode='clip') + hist = bincount(xy, weights, length=_prod(nbins)) hist = reshape(hist, nbins) core = D*(slice(1, -1),) hist = hist[core] @@ -528,16 +547,16 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=None): """ @_wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC) -def transpose(a, axes=None): +def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = None) -> Array: _stackable(a) or _check_arraylike("transpose", a) - axes = np.arange(ndim(a))[::-1] if axes is None else axes - axes = tuple(_canonicalize_axis(i, ndim(a)) for i in axes) - return lax.transpose(a, axes) + axes_ = list(range(ndim(a))[::-1]) if axes is None else axes + axes_ = [_canonicalize_axis(i, ndim(a)) for i in axes_] + return lax.transpose(a, axes_) @_wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('k', 'axes')) -def rot90(m, k=1, axes=(0, 1)): +def rot90(m: ArrayLike, k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Array: _check_arraylike("rot90", m) ax1, ax2 = axes ax1 = _canonicalize_axis(ax1, ndim(m)) @@ -546,11 +565,11 @@ def rot90(m, k=1, axes=(0, 1)): raise ValueError("Axes must be different") # same as numpy error k = k % 4 if k == 0: - return m + return asarray(m) elif k == 2: return flip(flip(m, ax1), ax2) else: - perm = list(range(m.ndim)) + perm = list(range(ndim(m))) perm[ax1], perm[ax2] = perm[ax2], perm[ax1] if k == 1: return transpose(flip(m, ax2), perm) @@ -559,12 +578,12 @@ def rot90(m, k=1, axes=(0, 1)): @_wraps(np.flip, lax_description=_ARRAY_VIEW_DOC) -def flip(m, axis: Optional[Union[int, Tuple[int, ...]]] = None): - return _flip(m, _ensure_optional_axes(axis)) +def flip(m: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + _check_arraylike("flip", m) + return _flip(asarray(m), _ensure_optional_axes(axis)) @partial(jit, static_argnames=('axis',)) -def _flip(m, axis: Optional[Union[int, Tuple[int, ...]]] = None): - _check_arraylike("flip", m) +def _flip(m: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: if axis is None: return lax.rev(m, list(range(len(shape(m))))) axis = _ensure_index_tuple(axis) @@ -572,29 +591,31 @@ def _flip(m, axis: Optional[Union[int, Tuple[int, ...]]] = None): @_wraps(np.fliplr, lax_description=_ARRAY_VIEW_DOC) -def fliplr(m): - return _flip(m, 1) +def fliplr(m: ArrayLike) -> Array: + _check_arraylike("fliplr", m) + return _flip(asarray(m), 1) @_wraps(np.flipud, lax_description=_ARRAY_VIEW_DOC) -def flipud(m): - return _flip(m, 0) +def flipud(m: ArrayLike) -> Array: + _check_arraylike("flipud", m) + return _flip(asarray(m), 0) @_wraps(np.iscomplex) @jit -def iscomplex(x): +def iscomplex(x: ArrayLike) -> Array: i = imag(x) return lax.ne(i, _lax_const(i, 0)) @_wraps(np.isreal) @jit -def isreal(x): +def isreal(x: ArrayLike) -> Array: i = imag(x) return lax.eq(i, _lax_const(i, 0)) @_wraps(np.angle) @partial(jit, static_argnames=['deg']) -def angle(z, deg=False): +def angle(z: ArrayLike, deg: bool = False) -> Array: re = real(z) im = imag(z) dtype = _dtype(re) @@ -609,41 +630,44 @@ def angle(z, deg=False): @_wraps(np.diff) @partial(jit, static_argnames=('n', 'axis')) -def diff(a, n=1, axis: int = -1, prepend=None, append=None): +def diff(a: ArrayLike, n: int = 1, axis: int = -1, + prepend: Optional[ArrayLike] = None, + append: Optional[ArrayLike] = None) -> Array: _check_arraylike("diff", a) + arr = asarray(a) n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.diff") if n == 0: - return a + return arr if n < 0: raise ValueError(f"order must be non-negative but got {n}") - if ndim(a) == 0: + if arr.ndim == 0: raise ValueError(f"diff requires input that is at least one dimensional; got {a}") - nd = a.ndim + nd = arr.ndim axis = _canonicalize_axis(axis, nd) - combined = [] + combined: List[Array] = [] if prepend is not None: _check_arraylike("diff", prepend) if isscalar(prepend): - shape = list(a.shape) + shape = list(arr.shape) shape[axis] = 1 prepend = broadcast_to(prepend, tuple(shape)) - combined.append(prepend) + combined.append(asarray(prepend)) - combined.append(a) + combined.append(arr) if append is not None: _check_arraylike("diff", append) if isscalar(append): - shape = list(a.shape) + shape = list(arr.shape) shape[axis] = 1 append = broadcast_to(append, tuple(shape)) - combined.append(append) + combined.append(asarray(append)) if len(combined) > 1: - a = concatenate(combined, axis) + arr = concatenate(combined, axis) slice1 = [slice(None)] * nd slice2 = [slice(None)] * nd @@ -652,11 +676,11 @@ def diff(a, n=1, axis: int = -1, prepend=None, append=None): slice1_tuple = tuple(slice1) slice2_tuple = tuple(slice2) - op = not_equal if a.dtype == np.bool_ else subtract + op = not_equal if arr.dtype == np.bool_ else subtract for _ in range(n): - a = op(a[slice1_tuple], a[slice2_tuple]) + arr = op(arr[slice1_tuple], arr[slice2_tuple]) - return a + return arr _EDIFF1D_DOC = """\ Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not @@ -666,23 +690,25 @@ def diff(a, n=1, axis: int = -1, prepend=None, append=None): @_wraps(np.ediff1d, lax_description=_EDIFF1D_DOC) @jit -def ediff1d(ary, to_end=None, to_begin=None): +def ediff1d(ary: ArrayLike, to_end: Optional[ArrayLike] = None, + to_begin: Optional[ArrayLike] = None) -> Array: _check_arraylike("ediff1d", ary) - ary = ravel(ary) - result = lax.sub(ary[1:], ary[:-1]) + arr = ravel(ary) + result = lax.sub(arr[1:], arr[:-1]) if to_begin is not None: _check_arraylike("ediff1d", to_begin) - result = concatenate((ravel(asarray(to_begin, dtype=ary.dtype)), result)) + result = concatenate((ravel(asarray(to_begin, dtype=arr.dtype)), result)) if to_end is not None: _check_arraylike("ediff1d", to_end) - result = concatenate((result, ravel(asarray(to_end, dtype=ary.dtype)))) + result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype)))) return result @_wraps(np.gradient, skip_params=['edge_order']) @partial(jit, static_argnames=('axis', 'edge_order')) -def gradient(f, *varargs, axis: Optional[Union[int, Tuple[int, ...]]] = None, - edge_order=None): +def gradient(f: ArrayLike, *varargs, axis: Optional[Union[int, Tuple[int, ...]]] = None, + edge_order: Optional[int] = None) -> Union[Array, List[Array]]: + # TODO(jakevdp): call check_arraylike on f if edge_order is not None: raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") @@ -695,7 +721,7 @@ def gradient_along_axis(a, h, axis): ), axis) return a_grad / h - a = f + a = asarray(f) axis_tuple: Tuple[int, ...] if axis is None: axis_tuple = tuple(range(a.ndim)) @@ -730,28 +756,24 @@ def gradient_along_axis(a, h, axis): # TODO: use jax.lax loop tools if possible a_grad = [gradient_along_axis(a, h, ax) for ax, h in zip(axis_tuple, dx)] - - if len(axis_tuple) == 1: - a_grad = a_grad[0] - - return a_grad + return a_grad[0] if len(axis_tuple) == 1 else a_grad @_wraps(np.isrealobj) -def isrealobj(x): +def isrealobj(x: Any) -> bool: return not iscomplexobj(x) - @_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC) -def reshape(a, newshape, order="C"): +def reshape(a: ArrayLike, newshape: Shape, order: str = "C") -> Array: _stackable(a) or _check_arraylike("reshape", a) try: - return a.reshape(newshape, order=order) # forward to method for ndarrays + # forward to method for ndarrays + return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr] except AttributeError: - return _reshape(a, newshape, order=order) + return _reshape(asarray(a), newshape, order=order) -def _compute_newshape(a, newshape): +def _compute_newshape(a: ArrayLike, newshape: Shape) -> Shape: """Fixes a -1 value in newshape, if present.""" # other errors, like having more than one -1, are caught downstream, in # reshape_shape_rule. @@ -764,19 +786,19 @@ def _compute_newshape(a, newshape): for d in newshape) -def _reshape(a, *args, order="C"): +def _reshape(a: Array, *args: Any, order: str = "C") -> Array: newshape = _compute_newshape(a, args[0] if len(args) == 1 else args) if order == "C": return lax.reshape(a, newshape, None) elif order == "F": - dims = np.arange(ndim(a))[::-1] + dims = list(range(ndim(a))[::-1]) return lax.reshape(a, newshape[::-1], dims).T elif order == "A": raise NotImplementedError("np.reshape order=A is not implemented.") else: raise ValueError(f"Unexpected value for 'order' argument: {order}.") -def _transpose(a, *args): +def _transpose(a: Array, *args: Any) -> Array: if not args: axis = None elif len(args) == 1: @@ -787,7 +809,7 @@ def _transpose(a, *args): @_wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('order',), inline=True) -def ravel(a, order="C"): +def ravel(a: ArrayLike, order: str = "C") -> Array: _stackable(a) or _check_arraylike("ravel", a) if order == "K": raise NotImplementedError("Ravel not implemented for order='K'.") @@ -795,12 +817,13 @@ def ravel(a, order="C"): @_wraps(np.ravel_multi_index) -def ravel_multi_index(multi_index, dims, mode='raise', order='C'): +def ravel_multi_index(multi_index: Tuple[ArrayLike, ...], dims: Tuple[int, ...], + mode: str = 'raise', order: str = 'C') -> Array: assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}" dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims) _check_arraylike("ravel_multi_index", *multi_index) - multi_index = [asarray(i) for i in multi_index] - for index in multi_index: + multi_index_arr = [asarray(i) for i in multi_index] + for index in multi_index_arr: if mode == 'raise': core.concrete_or_error(array, index, "The error occurred because ravel_multi_index was jit-compiled" @@ -808,12 +831,12 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'): if not issubdtype(_dtype(index), integer): raise TypeError("only int indices permitted") if mode == "raise": - if _any(any((i < 0) | (i >= d)) for i, d in zip(multi_index, dims)): + if _any(any((i < 0) | (i >= d)) for i, d in zip(multi_index_arr, dims)): raise ValueError("invalid entry in coordinates array") elif mode == "clip": - multi_index = [clip(i, 0, d - 1) for i, d in zip(multi_index, dims)] + multi_index_arr = [clip(i, 0, d - 1) for i, d in zip(multi_index_arr, dims)] elif mode == "wrap": - multi_index = [i % d for i, d in zip(multi_index, dims)] + multi_index_arr = [i % d for i, d in zip(multi_index_arr, dims)] else: raise ValueError(f"invalid mode={mode!r}. Expected 'raise', 'wrap', or 'clip'") @@ -824,9 +847,9 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'): else: raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'") - result = array(0, dtype=(multi_index[0].dtype if multi_index + result = array(0, dtype=(multi_index_arr[0].dtype if multi_index_arr else dtypes.canonicalize_dtype(int_))) - for i, s in zip(multi_index, strides): + for i, s in zip(multi_index_arr, strides): result = result + i * int(s) return result @@ -837,8 +860,9 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'): """ @_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC) -def unravel_index(indices, shape): +def unravel_index(indices: ArrayLike, shape: Shape) -> Tuple[Array, ...]: _check_arraylike("unravel_index", indices) + indices_arr = asarray(indices) # Note: we do not convert shape to an array, because it may be passed as a # tuple of weakly-typed values, and asarray() would strip these weak types. try: @@ -849,39 +873,39 @@ def unravel_index(indices, shape): raise ValueError("unravel_index: shape should be a scalar or 1D sequence.") out_indices = [None] * len(shape) for i, s in reversed(list(enumerate(shape))): - indices, out_indices[i] = divmod(indices, s) - oob_pos = indices > 0 - oob_neg = indices < -1 + indices_arr, out_indices[i] = divmod(indices_arr, s) + oob_pos = indices_arr > 0 + oob_neg = indices_arr < -1 return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i)) for s, i in zip(shape, out_indices)) @_wraps(np.resize) @partial(jit, static_argnames=('new_shape',)) -def resize(a, new_shape): +def resize(a: ArrayLike, new_shape: Shape) -> Array: _check_arraylike("resize", a) new_shape = _ensure_index_tuple(new_shape) if _any(dim_length < 0 for dim_length in new_shape): raise ValueError("all elements of `new_shape` must be non-negative") - a = ravel(a) + arr = ravel(a) new_size = _prod(new_shape) - if a.size == 0 or new_size == 0: - return zeros_like(a, shape=new_shape) + if arr.size == 0 or new_size == 0: + return zeros_like(arr, shape=new_shape) - repeats = ceil_of_ratio(new_size, a.size) - a = tile(a, repeats)[:new_size] + repeats = ceil_of_ratio(new_size, arr.size) + arr = tile(arr, repeats)[:new_size] - return reshape(a, new_shape) + return reshape(arr, new_shape) @_wraps(np.squeeze, lax_description=_ARRAY_VIEW_DOC) -def squeeze(a, axis: Optional[Union[int, Tuple[int, ...]]] = None): - return _squeeze(a, _ensure_index_tuple(axis) if axis is not None else None) +def squeeze(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + _check_arraylike("squeeze", a) + return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None) @partial(jit, static_argnames=('axis',), inline=True) -def _squeeze(a, axis): - _check_arraylike("squeeze", a) +def _squeeze(a: Array, axis: Tuple[int]) -> Array: if axis is None: a_shape = shape(a) axis = tuple(i for i, d in enumerate(a_shape) if d == 1) @@ -889,17 +913,17 @@ def _squeeze(a, axis): @_wraps(np.expand_dims) -def expand_dims(a, axis: Union[int, Sequence[int]]): +def expand_dims(a: ArrayLike, axis: Union[int, Sequence[int]]) -> Array: _stackable(a) or _check_arraylike("expand_dims", a) axis = _ensure_index_tuple(axis) if hasattr(a, "expand_dims"): - return a.expand_dims(axis) + return a.expand_dims(axis) # type: ignore return lax.expand_dims(a, axis) @_wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('axis1', 'axis2'), inline=True) -def swapaxes(a, axis1: int, axis2: int): +def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: _check_arraylike("swapaxes", a) perm = np.arange(ndim(a)) perm[axis1], perm[axis2] = perm[axis2], perm[axis1] @@ -907,14 +931,14 @@ def swapaxes(a, axis1: int, axis2: int): @_wraps(np.moveaxis, lax_description=_ARRAY_VIEW_DOC) -def moveaxis(a, source: Union[int, Sequence[int]], - destination: Union[int, Sequence[int]]): - return _moveaxis(a, _ensure_index_tuple(source), +def moveaxis(a: ArrayLike, source: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]]) -> Array: + _check_arraylike("moveaxis", a) + return _moveaxis(asarray(a), _ensure_index_tuple(source), _ensure_index_tuple(destination)) @partial(jit, static_argnames=('source', 'destination'), inline=True) -def _moveaxis(a, source: Tuple[int, ...], destination: Tuple[int, ...]): - _check_arraylike("moveaxis", a) +def _moveaxis(a: Array, source: Tuple[int, ...], destination: Tuple[int, ...]) -> Array: source = tuple(_canonicalize_axis(i, ndim(a)) for i in source) destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination) if len(source) != len(destination): @@ -928,7 +952,8 @@ def _moveaxis(a, source: Tuple[int, ...], destination: Tuple[int, ...]): @_wraps(np.isclose) @partial(jit, static_argnames=('equal_nan',)) -def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): +def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, + equal_nan: bool = False) -> Array: a, b = _promote_args("isclose", a, b) dtype = _dtype(a) if issubdtype(dtype, inexact): diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 49cbbeb51f1f..0c5f07df4bd2 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -601,17 +601,22 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *, f"{b.ndim}-dimensional array given. Array must be one or two-dimensional") m, n = a.shape dtype = a.dtype - if rcond is None: - rcond = jnp.finfo(dtype).eps * max(n, m) + if a.size == 0: + s = jnp.empty(0, dtype=a.dtype) + rank = jnp.array(0, dtype=int) + x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype) else: - rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond) - u, s, vt = svd(a, full_matrices=False) - mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0] - rank = mask.sum() - safe_s = jnp.where(mask, s, 1).astype(a.dtype) - s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis] - uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) - x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) + if rcond is None: + rcond = jnp.finfo(dtype).eps * max(n, m) + else: + rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond) + u, s, vt = svd(a, full_matrices=False) + mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0] + rank = mask.sum() + safe_s = jnp.where(mask, s, 1).astype(a.dtype) + s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis] + uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) + x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) # Numpy returns empty residuals in some cases. To allow compilation, we # default to returning full residuals in all cases. if numpy_resid and (rank < n or m <= n): diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 62d488311951..5ed44611b08c 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -154,6 +154,10 @@ def _wraps( be determined from the wrapped function itself. """ def wrap(op): + op.__np_wrapped__ = fun + # Allows this pattern: @wraps(getattr(np, 'new_function', None)) + if fun is None: + return op docstr = getattr(fun, "__doc__", None) name = getattr(fun, "__name__", getattr(op, "__name__", str(op))) try: @@ -203,7 +207,6 @@ def wrap(op): docstr = fun.__doc__ op.__doc__ = docstr - op.__np_wrapped__ = fun for attr in ['__name__', '__qualname__']: try: value = getattr(fun, attr) diff --git a/jax/_src/scipy/stats/truncnorm.py b/jax/_src/scipy/stats/truncnorm.py new file mode 100644 index 000000000000..9838c9016eb2 --- /dev/null +++ b/jax/_src/scipy/stats/truncnorm.py @@ -0,0 +1,130 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import scipy.stats as osp_stats + +from jax import lax +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_args_inexact +from jax._src.scipy.stats import norm +from jax._src.scipy.special import logsumexp, log_ndtr, ndtr + + +def _log_diff(x, y): + return logsumexp( + jnp.array([x, y]), + b=jnp.array([jnp.ones_like(x), -jnp.ones_like(y)]), + axis=0 + ) + + +def _log_gauss_mass(a, b): + """Log of Gaussian probability mass within an interval""" + a, b = jnp.array(a), jnp.array(b) + a, b = jnp.broadcast_arrays(a, b) + + # Note: Docstring carried over from scipy + # Calculations in right tail are inaccurate, so we'll exploit the + # symmetry and work only in the left tail + case_left = b <= 0 + case_right = a > 0 + case_central = ~(case_left | case_right) + + def mass_case_left(a, b): + return _log_diff(log_ndtr(b), log_ndtr(a)) + + def mass_case_right(a, b): + return mass_case_left(-b, -a) + + def mass_case_central(a, b): + # Note: Docstring carried over from scipy + # Previously, this was implemented as: + # left_mass = mass_case_left(a, 0) + # right_mass = mass_case_right(0, b) + # return _log_sum(left_mass, right_mass) + # Catastrophic cancellation occurs as np.exp(log_mass) approaches 1. + # Correct for this with an alternative formulation. + # We're not concerned with underflow here: if only one term + # underflows, it was insignificant; if both terms underflow, + # the result can't accurately be represented in logspace anyway + # because sc.log1p(x) ~ x for small x. + return jnp.log1p(-ndtr(a) - ndtr(-b)) + + out = jnp.select( + [case_left, case_right, case_central], + [mass_case_left(a, b), mass_case_right(a, b), mass_case_central(a, b)] + ) + return out + + +@_wraps(osp_stats.truncnorm.logpdf, update_doc=False) +def logpdf(x, a, b, loc=0, scale=1): + x, a, b, loc, scale = _promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale) + val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b)) + + x_scaled = lax.div(lax.sub(x, loc), scale) + val = jnp.where((x_scaled < a) | (x_scaled > b), -jnp.inf, val) + val = jnp.where(a >= b, jnp.nan, val) + return val + + +@_wraps(osp_stats.truncnorm.pdf, update_doc=False) +def pdf(x, a, b, loc=0, scale=1): + return lax.exp(logpdf(x, a, b, loc, scale)) + + +@_wraps(osp_stats.truncnorm.logsf, update_doc=False) +def logsf(x, a, b, loc=0, scale=1): + x, a, b, loc, scale = _promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale) + x, a, b = jnp.broadcast_arrays(x, a, b) + x = lax.div(lax.sub(x, loc), scale) + logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b) + logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b) + + logsf = jnp.select( + # third condition: avoid catastrophic cancellation (from scipy) + [x >= b, x <= a, logsf > -0.1, x > a], + [-jnp.inf, 0, jnp.log1p(-jnp.exp(logcdf)), logsf] + ) + logsf = jnp.where(a >= b, jnp.nan, logsf) + return logsf + + +@_wraps(osp_stats.truncnorm.sf, update_doc=False) +def sf(x, a, b, loc=0, scale=1): + return lax.exp(logsf(x, a, b, loc, scale)) + + +@_wraps(osp_stats.truncnorm.logcdf, update_doc=False) +def logcdf(x, a, b, loc=0, scale=1): + x, a, b, loc, scale = _promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale) + x, a, b = jnp.broadcast_arrays(x, a, b) + x = lax.div(lax.sub(x, loc), scale) + logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b) + logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b) + + logcdf = jnp.select( + # third condition: avoid catastrophic cancellation (from scipy) + [x >= b, x <= a, logcdf > -0.1, x > a], + [0, -jnp.inf, jnp.log1p(-jnp.exp(logsf)), logcdf] + ) + logcdf = jnp.where(a >= b, jnp.nan, logcdf) + return logcdf + + +@_wraps(osp_stats.truncnorm.cdf, update_doc=False) +def cdf(x, a, b, loc=0, scale=1): + return lax.exp(logcdf(x, a, b, loc, scale)) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index ad64f629d3f0..30880714e7b6 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -68,7 +68,7 @@ def _get_impl(ref: Ref, *idx: int, **_): Indexer = Tuple[Union[int, slice, jnp.ndarray], ...] def _unpack_idx(idx: Indexer, ndim: int - ) -> Tuple[Tuple[int, ...], Tuple[bool, ...]]: + ) -> Tuple[Tuple[Array, ...], Tuple[bool, ...]]: indexed_dims_ = [type(i) != slice for i in idx] _, non_slice_idx = partition_list(indexed_dims_, idx) indexed_dims = indexed_dims_ + [False] * (ndim - len(indexed_dims_)) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 595fc734b7c2..a1c72872438d 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -797,6 +797,8 @@ def setUp(self): def tearDown(self): for key, value in self._original_config.items(): config.update(key, value) + # TODO(parkers): Remove this when a real fix for most_recent_entry lands. + dispatch.xla_callable.most_recent_entry() super().tearDown() def rng(self): diff --git a/jax/_src/third_party/scipy/signal_helper.py b/jax/_src/third_party/scipy/signal_helper.py index 5f072bc17b40..1f9d0995d7d5 100644 --- a/jax/_src/third_party/scipy/signal_helper.py +++ b/jax/_src/third_party/scipy/signal_helper.py @@ -39,7 +39,7 @@ def _triage_segments(window: Union[ArrayLike, str, Tuple[Any, ...]], nperseg: Op 256. If window is array_like, nperseg is set to the length of the window. """ if isinstance(window, (str, tuple)): - nperseg_int = 256 if nperseg is None else int(nperseg) + nperseg_int = input_length if nperseg is None else int(nperseg) if nperseg_int > input_length: warnings.warn(f'nperseg = {nperseg_int} is greater than input length ' f' = {input_length}, using nperseg = {input_length}') @@ -47,16 +47,13 @@ def _triage_segments(window: Union[ArrayLike, str, Tuple[Any, ...]], nperseg: Op win = jnp.array(osp_signal.get_window(window, nperseg_int), dtype=dtype) else: win = jnp.asarray(window) + nperseg_int = win.size if nperseg is None else int(nperseg) if win.ndim != 1: raise ValueError('window must be 1-D') if input_length < win.size: raise ValueError('window is longer than input signal') - if nperseg is None: - nperseg_int = win.size - elif nperseg != win.size: + if nperseg_int != win.size: raise ValueError("value specified for nperseg is different from length of window") - else: - nperseg_int = int(nperseg) return win, nperseg_int diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py index b5621220cd73..e3b000cd8288 100644 --- a/jax/experimental/compilation_cache/compilation_cache.py +++ b/jax/experimental/compilation_cache/compilation_cache.py @@ -20,6 +20,7 @@ from typing import List, Optional from jax.experimental.compilation_cache.gfile_cache import GFileCache +from jax._src import path as pathlib from jax._src.lib import version_str as jaxlib_version_str from jax._src.lib import xla_client from jax.interpreters import xla @@ -31,8 +32,18 @@ def initialize_cache(path): """Creates a global cache object. Should only be called once per process. + + Will throw an assertion error if called a second time with a different path. + + Args: + path: path for the cache directory. + """ global _cache + if _cache is not None and _cache._path == pathlib.Path(path): + logger.warning("Cache already previously initialized at %s", _cache._path) + return + assert _cache == None, f"The cache path has already been initialized to {_cache._path}" _cache = GFileCache(path) logger.warning("Initialized persistent compilation cache at %s", path) diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py index 78bd520e0891..73bf6cd94e14 100644 --- a/jax/experimental/global_device_array.py +++ b/jax/experimental/global_device_array.py @@ -27,6 +27,7 @@ from jax.interpreters import pxla, xla, mlir from jax._src.util import prod, safe_zip from jax._src.api import device_put +from jax._src.lib import xla_extension_version from jax.interpreters.pxla import PartitionSpec Shape = Tuple[int, ...] @@ -331,6 +332,7 @@ def _init_buffers(self, device_buffers): @property def _device_buffers(self): + self._check_if_deleted() if self._maybe_device_buffers is None: self._maybe_device_buffers = self._sharded_buffer.get_device_buffers() # type: ignore return self._maybe_device_buffers @@ -393,14 +395,17 @@ def _create_local_shards(self) -> Sequence[Shard]: @pxla.maybe_cached_property def local_shards(self) -> Sequence[Shard]: + self._check_if_deleted() return self._create_local_shards() @pxla.maybe_cached_property def addressable_shards(self) -> Sequence[Shard]: + self._check_if_deleted() return self.local_shards @property def global_shards(self) -> Sequence[Shard]: + self._check_if_deleted() if self.mesh.size == len(self._local_devices): return self.addressable_shards @@ -424,6 +429,7 @@ def global_shards(self) -> Sequence[Shard]: @property def _value(self): + self._check_if_deleted() if self.is_fully_replicated: return np.asarray(self._device_buffers[0]) @@ -440,15 +446,19 @@ def _value(self): return npy_value def __array__(self, dtype=None, context=None): + self._check_if_deleted() return self._value if dtype is None else self._value.astype(dtype) def local_data(self, index) -> DeviceArray: + self._check_if_deleted() return pxla._set_aval(self._device_buffers[index]) def addressable_data(self, index) -> DeviceArray: + self._check_if_deleted() return self.local_data(index) def block_until_ready(self): + self._check_if_deleted() # self._sharded_buffer can be None if xla_extension_version < 90 or # _DeviceArray is used. if self._sharded_buffer is None: @@ -458,6 +468,26 @@ def block_until_ready(self): self._sharded_buffer.block_until_ready() # type: ignore return self + def _check_if_deleted(self): + if self.is_deleted(): + raise RuntimeError("GlobalDeviceArray has been deleted.") + + def is_deleted(self): + return self._sharded_buffer is None and self._maybe_device_buffers is None + + def delete(self): + if self._sharded_buffer: + if xla_extension_version >= 101: + self._sharded_buffer.delete() + else: + for b in self._sharded_buffer.get_device_buffers(): + b.delete() + self._sharded_buffer = None + if self._maybe_device_buffers: + for b in self._maybe_device_buffers: + b.delete() + self._maybe_device_buffers = None + @property def sharding(self): return jax.sharding.MeshPspecSharding(self._global_mesh, self.mesh_axes) @@ -635,6 +665,7 @@ def _gda_mlir_constant_handler(val, canonicalize_types=True): def _gda_shard_arg(x, devices, indices, mode): + x._check_if_deleted() if mode == pxla.InputsHandlerMode.pmap: raise RuntimeError('GDA is not supported with pmap.') # self._sharded_buffer can be None if xla_extension_version < 90 or diff --git a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py index 022f221cd29a..4b3900ff2c23 100644 --- a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py +++ b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py @@ -124,7 +124,7 @@ class Embedder(nn.Module): word_dropout_rate: float = 0. unk_idx: Optional[int] = None deterministic: Optional[bool] = None - dtype: jnp.dtype = jnp.float32 + dtype: jnp.dtype = jnp.dtype('float32') def setup(self): self.embedding = self.param( diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index 1f71491e9c7e..3d08c712402a 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -32,5 +32,6 @@ from jax.scipy.stats import chi2 as chi2 from jax.scipy.stats import betabinom as betabinom from jax.scipy.stats import gennorm as gennorm +from jax.scipy.stats import truncnorm as truncnorm from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde from jax._src.scipy.stats._core import mode as mode diff --git a/jax/scipy/stats/truncnorm.py b/jax/scipy/stats/truncnorm.py new file mode 100644 index 000000000000..3d85d47067e2 --- /dev/null +++ b/jax/scipy/stats/truncnorm.py @@ -0,0 +1,22 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.scipy.stats.truncnorm import ( + cdf as cdf, + logcdf as logcdf, + logpdf as logpdf, + pdf as pdf, + logsf as logsf, + sf as sf +) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 5a6a579e51cc..f5a2cc82932a 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -17,7 +17,6 @@ load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", - "flatbuffer_cc_library", "if_windows", "pybind_extension", ) @@ -42,9 +41,9 @@ py_library( ], data = [":xla_extension"], deps = [ - ":_ducc_fft", - ":_lapack", ":cpu_feature_guard", + "//jaxlib/cpu:_ducc_fft", + "//jaxlib/cpu:_lapack", "//jaxlib/mlir", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:chlo_dialect", @@ -132,8 +131,8 @@ cc_library( ], ) -# CPU kernels - +# This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong +# target architecture. pybind_extension( name = "cpu_feature_guard", srcs = ["cpu_feature_guard.c"], @@ -143,89 +142,14 @@ pybind_extension( ], ) -# LAPACK - -cc_library( - name = "lapack_kernels", - srcs = ["lapack_kernels.cc"], - hdrs = ["lapack_kernels.h"], - deps = [ - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", - "@com_google_absl//absl/base:dynamic_annotations", - ], -) - -cc_library( - name = "lapack_kernels_using_lapack", - srcs = ["lapack_kernels_using_lapack.cc"], - deps = [":lapack_kernels"], - alwayslink = 1, -) - -pybind_extension( - name = "_lapack", - srcs = ["lapack.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_lapack", - deps = [ - ":kernel_pybind11_helpers", - ":lapack_kernels", - "@pybind11", - ], -) - -# DUCC (CPU FFTs) - -flatbuffer_cc_library( - name = "ducc_fft_flatbuffers_cc", - srcs = ["ducc_fft.fbs"], -) - -cc_library( - name = "ducc_fft_kernels", - srcs = ["ducc_fft_kernels.cc"], - hdrs = ["ducc_fft_kernels.h"], - copts = ["-fexceptions"], # DUCC may throw. - features = ["-use_header_modules"], - deps = [ - ":ducc_fft_flatbuffers_cc", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", - "@ducc", - "@flatbuffers//:runtime_cc", - ], -) - -pybind_extension( - name = "_ducc_fft", - srcs = ["ducc_fft.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_ducc_fft", - deps = [ - ":ducc_fft_flatbuffers_cc", - ":ducc_fft_kernels", - ":kernel_pybind11_helpers", - "@flatbuffers//:runtime_cc", - "@pybind11", - ], -) +# CPU kernels +# TODO(phawkins): Remove this forwarding target. cc_library( name = "cpu_kernels", - srcs = ["cpu_kernels.cc"], visibility = ["//visibility:public"], deps = [ - ":ducc_fft_kernels", - ":lapack_kernels", - ":lapack_kernels_using_lapack", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry", + "//jaxlib/cpu:cpu_kernels", ], alwayslink = 1, ) diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD new file mode 100644 index 000000000000..0ba15d760e85 --- /dev/null +++ b/jaxlib/cpu/BUILD @@ -0,0 +1,112 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# JAX is Autograd and XLA + +load( + "//jaxlib:jax.bzl", + "flatbuffer_cc_library", + "pybind_extension", +) + +licenses(["notice"]) + +package(default_visibility = ["//:__subpackages__"]) + +# LAPACK + +cc_library( + name = "lapack_kernels", + srcs = ["lapack_kernels.cc"], + hdrs = ["lapack_kernels.h"], + deps = [ + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@com_google_absl//absl/base:dynamic_annotations", + ], +) + +cc_library( + name = "lapack_kernels_using_lapack", + srcs = ["lapack_kernels_using_lapack.cc"], + deps = [":lapack_kernels"], + alwayslink = 1, +) + +pybind_extension( + name = "_lapack", + srcs = ["lapack.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_lapack", + deps = [ + ":lapack_kernels", + "//jaxlib:kernel_pybind11_helpers", + "@pybind11", + ], +) + +# DUCC (CPU FFTs) + +flatbuffer_cc_library( + name = "ducc_fft_flatbuffers_cc", + srcs = ["ducc_fft.fbs"], +) + +cc_library( + name = "ducc_fft_kernels", + srcs = ["ducc_fft_kernels.cc"], + hdrs = ["ducc_fft_kernels.h"], + copts = ["-fexceptions"], # DUCC may throw. + features = ["-use_header_modules"], + deps = [ + ":ducc_fft_flatbuffers_cc", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@ducc", + "@flatbuffers//:runtime_cc", + ], +) + +pybind_extension( + name = "_ducc_fft", + srcs = ["ducc_fft.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_ducc_fft", + deps = [ + ":ducc_fft_flatbuffers_cc", + ":ducc_fft_kernels", + "//jaxlib:kernel_pybind11_helpers", + "@flatbuffers//:runtime_cc", + "@pybind11", + ], +) + +cc_library( + name = "cpu_kernels", + srcs = ["cpu_kernels.cc"], + visibility = ["//visibility:public"], + deps = [ + ":ducc_fft_kernels", + ":lapack_kernels", + ":lapack_kernels_using_lapack", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry", + ], + alwayslink = 1, +) diff --git a/jaxlib/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc similarity index 98% rename from jaxlib/cpu_kernels.cc rename to jaxlib/cpu/cpu_kernels.cc index 7e8f0fc56ffb..d7ead0a81afc 100644 --- a/jaxlib/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -16,8 +16,8 @@ limitations under the License. // This file is not used by JAX itself, but exists to assist with running // JAX-generated HLO code from outside of JAX. -#include "jaxlib/lapack_kernels.h" -#include "jaxlib/ducc_fft_kernels.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/cpu/ducc_fft_kernels.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" namespace jax { diff --git a/jaxlib/ducc_fft.cc b/jaxlib/cpu/ducc_fft.cc similarity index 96% rename from jaxlib/ducc_fft.cc rename to jaxlib/cpu/ducc_fft.cc index b5211667d776..d3bba133cdaf 100644 --- a/jaxlib/ducc_fft.cc +++ b/jaxlib/cpu/ducc_fft.cc @@ -18,8 +18,8 @@ limitations under the License. #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" -#include "jaxlib/ducc_fft_generated.h" -#include "jaxlib/ducc_fft_kernels.h" +#include "jaxlib/cpu/ducc_fft_generated.h" +#include "jaxlib/cpu/ducc_fft_kernels.h" #include "jaxlib/kernel_pybind11_helpers.h" namespace py = pybind11; diff --git a/jaxlib/ducc_fft.fbs b/jaxlib/cpu/ducc_fft.fbs similarity index 100% rename from jaxlib/ducc_fft.fbs rename to jaxlib/cpu/ducc_fft.fbs diff --git a/jaxlib/ducc_fft_kernels.cc b/jaxlib/cpu/ducc_fft_kernels.cc similarity index 99% rename from jaxlib/ducc_fft_kernels.cc rename to jaxlib/cpu/ducc_fft_kernels.cc index 6bed7ffbfebf..789de83eb82c 100644 --- a/jaxlib/ducc_fft_kernels.cc +++ b/jaxlib/cpu/ducc_fft_kernels.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" -#include "jaxlib/ducc_fft_generated.h" +#include "jaxlib/cpu/ducc_fft_generated.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" #include "ducc/src/ducc0/fft/fft.h" diff --git a/jaxlib/ducc_fft_kernels.h b/jaxlib/cpu/ducc_fft_kernels.h similarity index 86% rename from jaxlib/ducc_fft_kernels.h rename to jaxlib/cpu/ducc_fft_kernels.h index 2491e7e7a886..b0ababd7659b 100644 --- a/jaxlib/ducc_fft_kernels.h +++ b/jaxlib/cpu/ducc_fft_kernels.h @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef JAXLIB_CPU_DUCC_FFT_KERNELS_H_ +#define JAXLIB_CPU_DUCC_FFT_KERNELS_H_ + #include "tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { @@ -20,3 +23,5 @@ namespace jax { void DuccFft(void* out, void** in, XlaCustomCallStatus*); } // namespace jax + +#endif // JAXLIB_CPU_DUCC_FFT_KERNELS_H_ diff --git a/jaxlib/lapack.cc b/jaxlib/cpu/lapack.cc similarity index 99% rename from jaxlib/lapack.cc rename to jaxlib/cpu/lapack.cc index 36cd22c7dd0c..6d52d8d1f475 100644 --- a/jaxlib/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "jaxlib/kernel_pybind11_helpers.h" -#include "jaxlib/lapack_kernels.h" +#include "jaxlib/cpu/lapack_kernels.h" #include "include/pybind11/pybind11.h" namespace jax { diff --git a/jaxlib/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc similarity index 99% rename from jaxlib/lapack_kernels.cc rename to jaxlib/cpu/lapack_kernels.cc index 6d29613fa246..afe1466832fa 100644 --- a/jaxlib/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/lapack_kernels.h" +#include "jaxlib/cpu/lapack_kernels.h" #include #include diff --git a/jaxlib/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h similarity index 98% rename from jaxlib/lapack_kernels.h rename to jaxlib/cpu/lapack_kernels.h index 17ac6e160afa..8df85741f096 100644 --- a/jaxlib/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_LAPACK_KERNELS_H_ -#define JAXLIB_LAPACK_KERNELS_H_ +#ifndef JAXLIB_CPU_LAPACK_KERNELS_H_ +#define JAXLIB_CPU_LAPACK_KERNELS_H_ #include #include @@ -178,4 +178,4 @@ struct ComplexGees { } // namespace jax -#endif // JAXLIB_LAPACK_KERNELS_H_ +#endif // JAXLIB_CPU_LAPACK_KERNELS_H_ diff --git a/jaxlib/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc similarity index 99% rename from jaxlib/lapack_kernels_using_lapack.cc rename to jaxlib/cpu/lapack_kernels_using_lapack.cc index 908e8cdfeb63..9c380828eeb0 100644 --- a/jaxlib/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/lapack_kernels.h" +#include "jaxlib/cpu/lapack_kernels.h" // From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but // a C++ user should link against LAPACK directly. This is needed when using diff --git a/jaxlib/cuda/cusparse.cc b/jaxlib/cuda/cusparse.cc index 3a5c331992c4..0deb0b8b06b0 100644 --- a/jaxlib/cuda/cusparse.cc +++ b/jaxlib/cuda/cusparse.cc @@ -295,7 +295,7 @@ std::pair BuildCsrMatvecDescriptor( CudaConst beta = CudaZero(y.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, - CUSPARSE_SPMV_CSR_ALG2, &buffer_size))); + CUSPARSE_MV_ALG_DEFAULT, &buffer_size))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); @@ -468,7 +468,7 @@ std::pair BuildCooMatvecDescriptor( CudaConst beta = CudaZero(y.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, - CUSPARSE_SPMV_COO_ALG2, &buffer_size))); + CUSPARSE_MV_ALG_DEFAULT, &buffer_size))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); diff --git a/jaxlib/cuda/cusparse_kernels.cc b/jaxlib/cuda/cusparse_kernels.cc index ffc23c55e4f9..44e2877d374b 100644 --- a/jaxlib/cuda/cusparse_kernels.cc +++ b/jaxlib/cuda/cusparse_kernels.cc @@ -277,7 +277,7 @@ static absl::Status CsrMatvec_(cudaStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(JAX_AS_STATUS( cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y, - d.y.type, CUSPARSE_SPMV_CSR_ALG2, buf))); + d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); @@ -474,7 +474,7 @@ static absl::Status CooMatvec_(cudaStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(JAX_AS_STATUS( cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y, - d.y.type, CUSPARSE_SPMV_COO_ALG2, buf))); + d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); diff --git a/jaxlib/ducc_fft.py b/jaxlib/ducc_fft.py index f805234fee5f..c4e5477bf545 100644 --- a/jaxlib/ducc_fft.py +++ b/jaxlib/ducc_fft.py @@ -19,7 +19,7 @@ from .mhlo_helpers import custom_call -from . import _ducc_fft +from .cpu import _ducc_fft import numpy as np from jaxlib import xla_client diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index b842140eab68..8a248f12838c 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -22,7 +22,7 @@ from jaxlib import xla_client from .mhlo_helpers import custom_call -from . import _lapack +from .cpu import _lapack for _name, _value in _lapack.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="cpu") diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 6a4cfa953733..6cebc0deb3ac 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -61,6 +61,7 @@ def has_ext_modules(self): '*.so', '*.pyd*', 'py.typed', + 'cpu/*', 'cuda/*', 'cuda/nvvm/libdevice/libdevice*', 'mlir/*.py', diff --git a/tests/global_device_array_test.py b/tests/global_device_array_test.py index 3070f7ca6d02..20623cc130ea 100644 --- a/tests/global_device_array_test.py +++ b/tests/global_device_array_test.py @@ -379,6 +379,16 @@ def test_gda_value(self, mesh_axes): gda, global_data = create_gda(input_shape, global_mesh, mesh_axes) self.assertArraysEqual(gda._value, global_data) + def test_gda_delete(self): + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + input_shape = (8, 2) + gda, _ = create_gda(input_shape, global_mesh, P("x", "y")) + gda._check_if_deleted() + gda.delete() + with self.assertRaisesRegex(RuntimeError, + "GlobalDeviceArray has been deleted."): + gda._check_if_deleted() + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 0169380c0148..7767f4619b26 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -890,6 +890,9 @@ def testMultiDot(self, shapes, dtype): ((4, 6), (4,)), ((6, 6), (6, 1)), ((8, 6), (8, 4)), + ((0, 3), (0,)), + ((3, 0), (3,)), + ((3, 1), (3, 0)), ] ], rcond=[-1, None, 0.5], diff --git a/tests/pjit_test.py b/tests/pjit_test.py index c28ac67ec3a7..4a8576947a7d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1818,6 +1818,10 @@ def test_numpy_array_input(self): @jax_array(True) def test_unspecified_out_axis_resources(self): + # TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend. + if (xla_bridge.get_backend().runtime_type == 'stream_executor' and + jtu.device_under_test() == 'tpu'): + self.skipTest('Does not work with the cloud TPU SE runtime.') def _checks(out, input_data): self.assertIsInstance(out, array.ArrayImpl) @@ -1835,26 +1839,27 @@ def _checks(out, input_data): input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes) - f = pjit(lambda x: x) + f = pjit(lambda x: x * 2) out = f(input_array) - _checks(out, input_data) + _checks(out, input_data * 2) out2 = f(out) - _checks(out2, input_data) + _checks(out2, input_data * 4) @parameterized.named_parameters( - ('mesh1', (4, 2), (2, 1), (2, 2), (1, 2), (8, 2)), - ('mesh2', (2, 2), (4, 1), (4, 2), (2, 2), (8, 2)), - ('mesh3', (2, 1), (4, 2), (4, 2), (4, 2), (8, 2)), + ('mesh1', (4, 2), (2, 8), (2, 2), (1, 2), (8, 2)), + ('mesh2', (2, 2), (4, 8), (4, 2), (2, 2), (8, 2)), + ('mesh3', (2, 1), (4, 8), (4, 2), (4, 2), (8, 2)), ) @jax_array(True) def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape, s2_shape, s3_shape, s4_shape): - # Disable on SE runtime type because XLA sharding propagation is not - # supported. - if xla_bridge.get_backend().runtime_type == 'se': - raise unittest.SkipTest('Needs TFRT runtime.') + # TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend. + if (xla_bridge.get_backend().runtime_type == 'stream_executor' and + jtu.device_under_test() == 'tpu'): + self.skipTest('Does not work with the cloud TPU SE runtime.') + global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y')) global_input_shape = (8, 2) @@ -1870,14 +1875,15 @@ def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape, @pjit def f(tree): return tree - out_tree = f((a1, (a2, (a3, a4)))) + out_tree = f((a1 @ a1.T, (a2, (a3 * 2, a4)))) (out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree) self.assertIsInstance(out1, array.ArrayImpl) - self.assertEqual(out1.shape, (8, 2)) + self.assertEqual(out1.shape, (8, 8)) self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape) for s in out1.addressable_shards: - self.assertArraysEqual(s.data._arrays[0], input_data[s.index]) + self.assertArraysEqual( + s.data._arrays[0], (input_data @ input_data.T)[s.index]) self.assertIsInstance(out2, array.ArrayImpl) self.assertEqual(out2.shape, (8, 2)) @@ -1889,7 +1895,7 @@ def f(tree): self.assertEqual(out3.shape, (8, 2)) self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape) for s in out3.addressable_shards: - self.assertArraysEqual(s.data._arrays[0], input_data[s.index]) + self.assertArraysEqual(s.data._arrays[0], (input_data * 2)[s.index]) self.assertIsInstance(out4, array.ArrayImpl) self.assertEqual(out4.shape, (8, 2)) @@ -2176,7 +2182,7 @@ def test_pjit_different_device_recompilation(self): @jax_array(True) def test_grad_of_pjit_single_device_sharding(self): a = jnp.array(16, dtype=jnp.float32) - f = lambda x: x + f = lambda x: x * 3 out = jax.grad(pjit(f))(a) self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, jax.grad(f)(a)) @@ -2488,6 +2494,31 @@ def _invoke_with_mesh_twice(arg_tuple): for i, x, y in zip(range(n), xs, ys): self.assertAllClose(x + i, y) + @jax_array(True) + def test_trivial_computation(self): + shape = (8, 2) + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = MeshPspecSharding(mesh, P('x', 'y')) + inp_data = np.arange(prod(shape)).reshape(shape) + arr = jax.device_put(inp_data, s) + out = pjit(lambda x: x)(arr) + self.assertArraysEqual(out, inp_data) + + @jax_array(True) + def test_multi_device_pjit_mul(self): + shape = (8, 2) + mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + inp_data = np.arange(prod(shape)).reshape(shape) + arr1 = jax.device_put(inp_data, MeshPspecSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(inp_data, MeshPspecSharding(mesh, P(None, 'y'))) + + out1, out2 = pjit(lambda x, y: (x @ x.T, y * 2))(arr1, arr2) + + self.assertArraysEqual(out1, inp_data @ inp_data.T) + self.assertEqual(out1.shape, (8, 8)) + self.assertArraysEqual(out2, inp_data * 2) + self.assertEqual(out2.shape, (8, 2)) + class TempSharding(Sharding): diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 75d22bd0fb19..03dd28eda26c 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -313,22 +313,24 @@ def osp_fun(x): for shape, nperseg, noverlap, timeaxis in welch_test_shapes ], use_nperseg=[False, True], + use_window=[False, True], use_noverlap=[False, True], dtype=jtu.dtypes.floating + jtu.dtypes.integer, ) def testWelchWithDefaultStepArgsAgainstNumpy( self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap, - timeaxis): + use_window, timeaxis): kwargs = {'axis': timeaxis} if use_nperseg: kwargs['nperseg'] = nperseg - else: + if use_window: kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg), dtype=dtypes.to_complex_dtype(dtype)) if use_noverlap: kwargs['noverlap'] = noverlap + @jtu.ignore_warning(message="nperseg = 256 is greater than") def osp_fun(x): freqs, Pxx = osp_signal.welch(x, **kwargs) return freqs.astype(_real_dtype(dtype)), Pxx.astype(_real_dtype(dtype)) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index cf8715f04e9e..0007a099e99f 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -474,6 +474,113 @@ def args_maker(): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + @genNamedParametersNArgs(5) + def testTruncnormLogPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.logpdf + lax_fun = lsp_stats.truncnorm.logpdf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(5) + def testTruncnormPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.pdf + lax_fun = lsp_stats.truncnorm.pdf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(5) + def testTruncnormLogCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.logcdf + lax_fun = lsp_stats.truncnorm.logcdf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(5) + def testTruncnormCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.cdf + lax_fun = lsp_stats.truncnorm.cdf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(5) + def testTruncnormLogSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.logsf + lax_fun = lsp_stats.truncnorm.logsf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(5) + def testTruncnormSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.sf + lax_fun = lsp_stats.truncnorm.sf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testParetoLogPdf(self, shapes, dtypes):