From 97b17af5bef31864bd437b84a3ff473d955937a9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 21 Oct 2022 11:59:53 -0700 Subject: [PATCH 01/18] [typing] add type annotations to the first several lax_numpy functions --- jax/_src/numpy/lax_numpy.py | 126 ++++++++++-------- jax/_src/state/primitives.py | 2 +- .../tests/flax_models/bilstm_classifier.py | 2 +- 3 files changed, 71 insertions(+), 59 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3a673c549d60..5ef7f089ebe8 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -26,10 +26,10 @@ import builtins import collections -from functools import partial, wraps as functools_wraps +from functools import partial import operator import types -from typing import overload, Any, Callable, Sequence, FrozenSet, Optional, Tuple, Union +from typing import overload, Any, Callable, Dict, Sequence, FrozenSet, Optional, Tuple, Union from textwrap import dedent as _dedent import warnings @@ -50,7 +50,7 @@ from jax._src import dtypes from jax._src.api_util import _ensure_index_tuple from jax._src.lax.lax import (_array_copy, _sort_lt_comparator, - _sort_le_comparator) + _sort_le_comparator, PrecisionLike) from jax._src.lax import lax as lax_internal from jax._src.numpy.ndarray import ndarray from jax._src.numpy.reductions import ( # noqa: F401 @@ -77,7 +77,7 @@ _register_stackable, _stackable, _where, _wraps) from jax._src.numpy.vectorize import vectorize from jax._src.ops import scatter -from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape +from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio, partition_list, canonicalize_axis as _canonicalize_axis) @@ -145,7 +145,9 @@ def iscomplexobj(x): shape = _shape = np.shape ndim = _ndim = np.ndim size = np.size -_dtype = partial(dtypes.dtype, canonicalize=True) + +def _dtype(x: Any) -> DType: + return dtypes.dtype(x, canonicalize=True) # At present JAX doesn't have a reason to distinguish between scalars and arrays # in its object system. Further, we want JAX scalars to have the same type @@ -154,22 +156,22 @@ def iscomplexobj(x): # types return JAX arrays when instantiated. class _ScalarMeta(type): - def __hash__(self): + def __hash__(self) -> int: return hash(self.dtype.type) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return id(self) == id(other) or self.dtype.type == other - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not (self == other) - def __call__(self, x): + def __call__(self, x: Any) -> Array: return asarray(x, dtype=self.dtype) - def __instancecheck__(self, instance): + def __instancecheck__(self, instance: Any) -> bool: return isinstance(instance, self.dtype.type) -def _make_scalar_type(np_scalar_type): +def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: meta = _ScalarMeta(np_scalar_type.__name__, (object,), {"dtype": np.dtype(np_scalar_type)}) meta.__module__ = _PUBLIC_MODULE_NAME @@ -226,7 +228,8 @@ def _make_scalar_type(np_scalar_type): savez = np.savez @_wraps(np.dtype) -def _jnp_dtype(obj, align=False, copy=False): +def _jnp_dtype(obj: Optional[DTypeLike], *, align: bool = False, + copy: bool = False) -> DType: """Similar to np.dtype, but respects JAX dtype defaults.""" if obj is None: obj = dtypes.float_ @@ -236,7 +239,7 @@ def _jnp_dtype(obj, align=False, copy=False): ### utility functions -_DEFAULT_TYPEMAP = { +_DEFAULT_TYPEMAP: Dict[type, _ScalarMeta] = { np.bool_: bool_, np.int_: int_, np.float_: float_, @@ -246,13 +249,13 @@ def _jnp_dtype(obj, align=False, copy=False): _lax_const = lax_internal._const -def _result_dtype(op, *args): +def _result_dtype(op: Callable[..., ArrayLike], *args: Any) -> DType: """Compute result dtype of applying op to arguments with given dtypes.""" - args = [np.ones((0,) * ndim(arg), _dtype(arg)) for arg in args] - return _dtype(op(*args)) + np_args = [np.ones((0,) * ndim(arg), _dtype(arg)) for arg in args] + return _dtype(op(*np_args)) -def _convert_and_clip_integer(val, dtype): +def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: """ Convert integer-typed val to specified integer dtype, clipping to dtype range rather than wrapping. @@ -294,7 +297,7 @@ def _convert_and_clip_integer(val, dtype): @_wraps(np.load, update_doc=False) -def load(*args, **kwargs): +def load(*args: Any, **kwargs: Any) -> Array: # The main purpose of this wrapper is to recover bfloat16 data types. # Note: this will only work for files created via np.save(), not np.savez(). out = np.load(*args, **kwargs) @@ -312,20 +315,20 @@ def load(*args, **kwargs): @_wraps(np.fmin, module='numpy') @jit -def fmin(x1, x2): - return where((x1 < x2) | isnan(x2), x1, x2) +def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: + return where(less(x1, x2) | isnan(x2), x1, x2) @_wraps(np.fmax, module='numpy') @jit -def fmax(x1, x2): - return where((x1 > x2) | isnan(x2), x1, x2) +def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: + return where(greater(x1, x2) | isnan(x2), x1, x2) @_wraps(np.issubdtype) -def issubdtype(arg1, arg2): +def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: return dtypes.issubdtype(arg1, arg2) @_wraps(np.isscalar) -def isscalar(element): +def isscalar(element: Any) -> bool: if hasattr(element, '__jax_array__'): element = element.__jax_array__() return dtypes.is_python_scalar(element) or np.isscalar(element) @@ -333,36 +336,36 @@ def isscalar(element): iterable = np.iterable @_wraps(np.result_type) -def result_type(*args): +def result_type(*args: ArrayLike) -> DType: return dtypes.result_type(*args) @_wraps(np.trapz) @partial(jit, static_argnames=('axis',)) -def trapz(y, x=None, dx=1.0, axis: int = -1): +def trapz(y: ArrayLike, x: Optional[ArrayLike] = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: if x is None: _check_arraylike('trapz', y) - y, = _promote_dtypes_inexact(y) + y_arr, = _promote_dtypes_inexact(y) else: _check_arraylike('trapz', y, x) - y, x = _promote_dtypes_inexact(y, x) - if ndim(x) == 1: - dx = diff(x) + y_arr, x_arr = _promote_dtypes_inexact(y, x) + if x_arr.ndim == 1: + dx = diff(x_arr) else: - dx = moveaxis(diff(x, axis=axis), axis, -1) - y = moveaxis(y, axis, -1) - return 0.5 * (dx * (y[..., 1:] + y[..., :-1])).sum(-1) + dx = moveaxis(diff(x_arr, axis=axis), axis, -1) + y_arr = moveaxis(y_arr, axis, -1) + return 0.5 * (dx * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) @_wraps(np.trunc, module='numpy') @jit -def trunc(x): +def trunc(x: ArrayLike) -> Array: _check_arraylike('trunc', x) return where(lax.lt(x, _lax_const(x, 0)), ceil(x), floor(x)) @partial(jit, static_argnums=(2, 3, 4)) -def _conv(x, y, mode, op, precision): +def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike) -> Array: if ndim(x) != 1 or ndim(y) != 1: raise ValueError(f"{op}() only support 1-dimensional inputs.") x, y = _promote_dtypes_inexact(x, y) @@ -396,49 +399,57 @@ def _conv(x, y, mode, op, precision): @_wraps(np.convolve, lax_description=_PRECISION_DOC) @partial(jit, static_argnames=('mode', 'precision')) -def convolve(a, v, mode='full', *, precision=None): +def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, + precision: PrecisionLike = None) -> Array: _check_arraylike("convolve", a, v) - return _conv(a, v, mode, 'convolve', precision) + return _conv(asarray(a), asarray(v), mode, 'convolve', precision) @_wraps(np.correlate, lax_description=_PRECISION_DOC) @partial(jit, static_argnames=('mode', 'precision')) -def correlate(a, v, mode='valid', *, precision=None): +def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, + precision: PrecisionLike = None) -> Array: _check_arraylike("correlate", a, v) - return _conv(a, v, mode, 'correlate', precision) + return _conv(asarray(a), asarray(v), mode, 'correlate', precision) @_wraps(np.histogram_bin_edges) -def histogram_bin_edges(a, bins=10, range=None, weights=None): +def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, + range: Union[None, Array, Sequence[ArrayLike]] = None, + weights: Optional[ArrayLike] = None) -> Array: del weights # unused, because string bins is not supported. if isinstance(bins, str): raise NotImplementedError("string values for `bins` not implemented.") _check_arraylike("histogram_bin_edges", a, bins) - a = ravel(a) - dtype = dtypes.to_inexact_dtype(_dtype(a)) + arr = ravel(a) + dtype = dtypes.to_inexact_dtype(arr.dtype) if _ndim(bins) == 1: return asarray(bins, dtype=dtype) bins = core.concrete_or_error(operator.index, bins, "bins argument of histogram_bin_edges") if range is None: - range = [a.min(), a.max()] + range = [arr.min(), arr.max()] range = asarray(range, dtype=dtype) - if range.shape != (2,): + if shape(range) != (2,): raise ValueError("`range` must be either None or a sequence of scalars.") range = (where(ptp(range) == 0, range[0] - 0.5, range[0]), where(ptp(range) == 0, range[1] + 0.5, range[1])) + assert range is not None return linspace(range[0], range[1], bins + 1, dtype=dtype) @_wraps(np.histogram) -def histogram(a, bins=10, range=None, weights=None, density=None): +def histogram(a: ArrayLike, bins: ArrayLike = 10, + range: Optional[Sequence[ArrayLike]] = None, + weights: Optional[ArrayLike] = None, + density: Optional[bool] = None) -> Tuple[Array, Array]: if weights is None: _check_arraylike("histogram", a, bins) a = ravel(*_promote_dtypes_inexact(a)) weights = ones_like(a) else: _check_arraylike("histogram", a, bins, weights) - if a.shape != weights.shape: + if shape(a) != shape(weights): raise ValueError("weights should have the same shape as a.") a, weights = map(ravel, _promote_dtypes_inexact(a, weights)) @@ -528,11 +539,11 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=None): """ @_wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC) -def transpose(a, axes=None): +def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = None) -> Array: _stackable(a) or _check_arraylike("transpose", a) - axes = np.arange(ndim(a))[::-1] if axes is None else axes - axes = tuple(_canonicalize_axis(i, ndim(a)) for i in axes) - return lax.transpose(a, axes) + axes_ = list(range(ndim(a))[::-1]) if axes is None else axes + axes_ = [_canonicalize_axis(i, ndim(a)) for i in axes_] + return lax.transpose(a, axes_) @_wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC) @@ -744,12 +755,13 @@ def isrealobj(x): @_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC) -def reshape(a, newshape, order="C"): +def reshape(a: ArrayLike, newshape: Shape, order: str = "C") -> Array: _stackable(a) or _check_arraylike("reshape", a) try: - return a.reshape(newshape, order=order) # forward to method for ndarrays + # forward to method for ndarrays + return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr] except AttributeError: - return _reshape(a, newshape, order=order) + return _reshape(asarray(a), newshape, order=order) def _compute_newshape(a, newshape): """Fixes a -1 value in newshape, if present.""" @@ -764,19 +776,19 @@ def _compute_newshape(a, newshape): for d in newshape) -def _reshape(a, *args, order="C"): +def _reshape(a: Array, *args: Any, order: str = "C") -> Array: newshape = _compute_newshape(a, args[0] if len(args) == 1 else args) if order == "C": return lax.reshape(a, newshape, None) elif order == "F": - dims = np.arange(ndim(a))[::-1] + dims = list(range(ndim(a))[::-1]) return lax.reshape(a, newshape[::-1], dims).T elif order == "A": raise NotImplementedError("np.reshape order=A is not implemented.") else: raise ValueError(f"Unexpected value for 'order' argument: {order}.") -def _transpose(a, *args): +def _transpose(a: Array, *args: Any) -> Array: if not args: axis = None elif len(args) == 1: @@ -787,7 +799,7 @@ def _transpose(a, *args): @_wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('order',), inline=True) -def ravel(a, order="C"): +def ravel(a: ArrayLike, order: str = "C") -> Array: _stackable(a) or _check_arraylike("ravel", a) if order == "K": raise NotImplementedError("Ravel not implemented for order='K'.") diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index ad64f629d3f0..30880714e7b6 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -68,7 +68,7 @@ def _get_impl(ref: Ref, *idx: int, **_): Indexer = Tuple[Union[int, slice, jnp.ndarray], ...] def _unpack_idx(idx: Indexer, ndim: int - ) -> Tuple[Tuple[int, ...], Tuple[bool, ...]]: + ) -> Tuple[Tuple[Array, ...], Tuple[bool, ...]]: indexed_dims_ = [type(i) != slice for i in idx] _, non_slice_idx = partition_list(indexed_dims_, idx) indexed_dims = indexed_dims_ + [False] * (ndim - len(indexed_dims_)) diff --git a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py index 022f221cd29a..4b3900ff2c23 100644 --- a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py +++ b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py @@ -124,7 +124,7 @@ class Embedder(nn.Module): word_dropout_rate: float = 0. unk_idx: Optional[int] = None deterministic: Optional[bool] = None - dtype: jnp.dtype = jnp.float32 + dtype: jnp.dtype = jnp.dtype('float32') def setup(self): self.embedding = self.param( From 4714a5cc8f668095b5ee47f09390480997c5401e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 21 Oct 2022 12:52:32 -0700 Subject: [PATCH 02/18] Add regression test for #12920 --- jax/_src/third_party/scipy/signal_helper.py | 9 +++------ tests/scipy_signal_test.py | 6 ++++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/jax/_src/third_party/scipy/signal_helper.py b/jax/_src/third_party/scipy/signal_helper.py index 5f072bc17b40..1f9d0995d7d5 100644 --- a/jax/_src/third_party/scipy/signal_helper.py +++ b/jax/_src/third_party/scipy/signal_helper.py @@ -39,7 +39,7 @@ def _triage_segments(window: Union[ArrayLike, str, Tuple[Any, ...]], nperseg: Op 256. If window is array_like, nperseg is set to the length of the window. """ if isinstance(window, (str, tuple)): - nperseg_int = 256 if nperseg is None else int(nperseg) + nperseg_int = input_length if nperseg is None else int(nperseg) if nperseg_int > input_length: warnings.warn(f'nperseg = {nperseg_int} is greater than input length ' f' = {input_length}, using nperseg = {input_length}') @@ -47,16 +47,13 @@ def _triage_segments(window: Union[ArrayLike, str, Tuple[Any, ...]], nperseg: Op win = jnp.array(osp_signal.get_window(window, nperseg_int), dtype=dtype) else: win = jnp.asarray(window) + nperseg_int = win.size if nperseg is None else int(nperseg) if win.ndim != 1: raise ValueError('window must be 1-D') if input_length < win.size: raise ValueError('window is longer than input signal') - if nperseg is None: - nperseg_int = win.size - elif nperseg != win.size: + if nperseg_int != win.size: raise ValueError("value specified for nperseg is different from length of window") - else: - nperseg_int = int(nperseg) return win, nperseg_int diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 75d22bd0fb19..03dd28eda26c 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -313,22 +313,24 @@ def osp_fun(x): for shape, nperseg, noverlap, timeaxis in welch_test_shapes ], use_nperseg=[False, True], + use_window=[False, True], use_noverlap=[False, True], dtype=jtu.dtypes.floating + jtu.dtypes.integer, ) def testWelchWithDefaultStepArgsAgainstNumpy( self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap, - timeaxis): + use_window, timeaxis): kwargs = {'axis': timeaxis} if use_nperseg: kwargs['nperseg'] = nperseg - else: + if use_window: kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg), dtype=dtypes.to_complex_dtype(dtype)) if use_noverlap: kwargs['noverlap'] = noverlap + @jtu.ignore_warning(message="nperseg = 256 is greater than") def osp_fun(x): freqs, Pxx = osp_signal.welch(x, **kwargs) return freqs.astype(_real_dtype(dtype)), Pxx.astype(_real_dtype(dtype)) From ca7d05f4f1f57e2968f95c27c13cf00297f0e09b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 21 Oct 2022 14:37:59 -0700 Subject: [PATCH 03/18] [typing] fix incorrect type annotation on lax.argmax/argmin --- jax/_src/lax/lax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a51602ac9c20..cb3120082496 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -960,13 +960,13 @@ def transpose(operand: ArrayLike, permutation: Sequence[int]) -> Array: return transpose_p.bind(operand, permutation=permutation) def argmin(operand: ArrayLike, axis: int, - index_dtype: DTypeLike) -> Tuple[Array, Array]: + index_dtype: DTypeLike) -> Array: """Computes the index of the minimum element along ``axis``.""" return argmin_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) def argmax(operand: ArrayLike, axis: int, - index_dtype: DTypeLike) -> Tuple[Array, Array]: + index_dtype: DTypeLike) -> Array: """Computes the index of the maximum element along ``axis``.""" return argmax_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) From 4e8fbd0239fed976cea98fa62dd88257f976abbd Mon Sep 17 00:00:00 2001 From: Qiao Zhang Date: Fri, 21 Oct 2022 14:41:56 -0700 Subject: [PATCH 04/18] Add delete method to GlobalDeviceArray and ShardedBuffer. This ensures all existing JAX buffer types have a `delete` method that can be used to free device buffer allocation eagerly. User code sometimes have lingering python refs due to cyclic deps and other reasons, yet users may know for sure that certain arrays will no longer be used after a certain point. Calling `foo_array.delete()` for DeviceArray/ShardedDeviceArray/GlobalDeviceArray/Array allows users to force free the device side allocation to minimize device memory usage. PiperOrigin-RevId: 482892157 --- jax/experimental/global_device_array.py | 31 +++++++++++++++++++++++++ tests/global_device_array_test.py | 10 ++++++++ 2 files changed, 41 insertions(+) diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py index 78bd520e0891..73bf6cd94e14 100644 --- a/jax/experimental/global_device_array.py +++ b/jax/experimental/global_device_array.py @@ -27,6 +27,7 @@ from jax.interpreters import pxla, xla, mlir from jax._src.util import prod, safe_zip from jax._src.api import device_put +from jax._src.lib import xla_extension_version from jax.interpreters.pxla import PartitionSpec Shape = Tuple[int, ...] @@ -331,6 +332,7 @@ def _init_buffers(self, device_buffers): @property def _device_buffers(self): + self._check_if_deleted() if self._maybe_device_buffers is None: self._maybe_device_buffers = self._sharded_buffer.get_device_buffers() # type: ignore return self._maybe_device_buffers @@ -393,14 +395,17 @@ def _create_local_shards(self) -> Sequence[Shard]: @pxla.maybe_cached_property def local_shards(self) -> Sequence[Shard]: + self._check_if_deleted() return self._create_local_shards() @pxla.maybe_cached_property def addressable_shards(self) -> Sequence[Shard]: + self._check_if_deleted() return self.local_shards @property def global_shards(self) -> Sequence[Shard]: + self._check_if_deleted() if self.mesh.size == len(self._local_devices): return self.addressable_shards @@ -424,6 +429,7 @@ def global_shards(self) -> Sequence[Shard]: @property def _value(self): + self._check_if_deleted() if self.is_fully_replicated: return np.asarray(self._device_buffers[0]) @@ -440,15 +446,19 @@ def _value(self): return npy_value def __array__(self, dtype=None, context=None): + self._check_if_deleted() return self._value if dtype is None else self._value.astype(dtype) def local_data(self, index) -> DeviceArray: + self._check_if_deleted() return pxla._set_aval(self._device_buffers[index]) def addressable_data(self, index) -> DeviceArray: + self._check_if_deleted() return self.local_data(index) def block_until_ready(self): + self._check_if_deleted() # self._sharded_buffer can be None if xla_extension_version < 90 or # _DeviceArray is used. if self._sharded_buffer is None: @@ -458,6 +468,26 @@ def block_until_ready(self): self._sharded_buffer.block_until_ready() # type: ignore return self + def _check_if_deleted(self): + if self.is_deleted(): + raise RuntimeError("GlobalDeviceArray has been deleted.") + + def is_deleted(self): + return self._sharded_buffer is None and self._maybe_device_buffers is None + + def delete(self): + if self._sharded_buffer: + if xla_extension_version >= 101: + self._sharded_buffer.delete() + else: + for b in self._sharded_buffer.get_device_buffers(): + b.delete() + self._sharded_buffer = None + if self._maybe_device_buffers: + for b in self._maybe_device_buffers: + b.delete() + self._maybe_device_buffers = None + @property def sharding(self): return jax.sharding.MeshPspecSharding(self._global_mesh, self.mesh_axes) @@ -635,6 +665,7 @@ def _gda_mlir_constant_handler(val, canonicalize_types=True): def _gda_shard_arg(x, devices, indices, mode): + x._check_if_deleted() if mode == pxla.InputsHandlerMode.pmap: raise RuntimeError('GDA is not supported with pmap.') # self._sharded_buffer can be None if xla_extension_version < 90 or diff --git a/tests/global_device_array_test.py b/tests/global_device_array_test.py index 3070f7ca6d02..20623cc130ea 100644 --- a/tests/global_device_array_test.py +++ b/tests/global_device_array_test.py @@ -379,6 +379,16 @@ def test_gda_value(self, mesh_axes): gda, global_data = create_gda(input_shape, global_mesh, mesh_axes) self.assertArraysEqual(gda._value, global_data) + def test_gda_delete(self): + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + input_shape = (8, 2) + gda, _ = create_gda(input_shape, global_mesh, P("x", "y")) + gda._check_if_deleted() + gda.delete() + with self.assertRaisesRegex(RuntimeError, + "GlobalDeviceArray has been deleted."): + gda._check_if_deleted() + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From e219d55c366d510316642d46eaf00cf288159f35 Mon Sep 17 00:00:00 2001 From: Tianjian Lu Date: Fri, 21 Oct 2022 15:05:42 -0700 Subject: [PATCH 05/18] Roll-back #12892 because CUSPARSE_SPMV_COO_ALG2 is not available in CUDA 11.1 PiperOrigin-RevId: 482897448 --- jaxlib/cuda/cusparse.cc | 4 ++-- jaxlib/cuda/cusparse_kernels.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jaxlib/cuda/cusparse.cc b/jaxlib/cuda/cusparse.cc index 3a5c331992c4..0deb0b8b06b0 100644 --- a/jaxlib/cuda/cusparse.cc +++ b/jaxlib/cuda/cusparse.cc @@ -295,7 +295,7 @@ std::pair BuildCsrMatvecDescriptor( CudaConst beta = CudaZero(y.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, - CUSPARSE_SPMV_CSR_ALG2, &buffer_size))); + CUSPARSE_MV_ALG_DEFAULT, &buffer_size))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); @@ -468,7 +468,7 @@ std::pair BuildCooMatvecDescriptor( CudaConst beta = CudaZero(y.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, - CUSPARSE_SPMV_COO_ALG2, &buffer_size))); + CUSPARSE_MV_ALG_DEFAULT, &buffer_size))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); diff --git a/jaxlib/cuda/cusparse_kernels.cc b/jaxlib/cuda/cusparse_kernels.cc index ffc23c55e4f9..44e2877d374b 100644 --- a/jaxlib/cuda/cusparse_kernels.cc +++ b/jaxlib/cuda/cusparse_kernels.cc @@ -277,7 +277,7 @@ static absl::Status CsrMatvec_(cudaStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(JAX_AS_STATUS( cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y, - d.y.type, CUSPARSE_SPMV_CSR_ALG2, buf))); + d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); @@ -474,7 +474,7 @@ static absl::Status CooMatvec_(cudaStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(JAX_AS_STATUS( cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y, - d.y.type, CUSPARSE_SPMV_COO_ALG2, buf))); + d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); From 9956ad2f893a7230587788e237e8d5214435e418 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 21 Oct 2022 16:53:14 -0700 Subject: [PATCH 06/18] Add more pjit tests and make some tests go via actual computations rather than trivial computation. PiperOrigin-RevId: 482919649 --- jax/_src/test_util.py | 2 ++ tests/pjit_test.py | 61 ++++++++++++++++++++++++++++++++----------- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 595fc734b7c2..a1c72872438d 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -797,6 +797,8 @@ def setUp(self): def tearDown(self): for key, value in self._original_config.items(): config.update(key, value) + # TODO(parkers): Remove this when a real fix for most_recent_entry lands. + dispatch.xla_callable.most_recent_entry() super().tearDown() def rng(self): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index c28ac67ec3a7..4a8576947a7d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1818,6 +1818,10 @@ def test_numpy_array_input(self): @jax_array(True) def test_unspecified_out_axis_resources(self): + # TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend. + if (xla_bridge.get_backend().runtime_type == 'stream_executor' and + jtu.device_under_test() == 'tpu'): + self.skipTest('Does not work with the cloud TPU SE runtime.') def _checks(out, input_data): self.assertIsInstance(out, array.ArrayImpl) @@ -1835,26 +1839,27 @@ def _checks(out, input_data): input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes) - f = pjit(lambda x: x) + f = pjit(lambda x: x * 2) out = f(input_array) - _checks(out, input_data) + _checks(out, input_data * 2) out2 = f(out) - _checks(out2, input_data) + _checks(out2, input_data * 4) @parameterized.named_parameters( - ('mesh1', (4, 2), (2, 1), (2, 2), (1, 2), (8, 2)), - ('mesh2', (2, 2), (4, 1), (4, 2), (2, 2), (8, 2)), - ('mesh3', (2, 1), (4, 2), (4, 2), (4, 2), (8, 2)), + ('mesh1', (4, 2), (2, 8), (2, 2), (1, 2), (8, 2)), + ('mesh2', (2, 2), (4, 8), (4, 2), (2, 2), (8, 2)), + ('mesh3', (2, 1), (4, 8), (4, 2), (4, 2), (8, 2)), ) @jax_array(True) def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape, s2_shape, s3_shape, s4_shape): - # Disable on SE runtime type because XLA sharding propagation is not - # supported. - if xla_bridge.get_backend().runtime_type == 'se': - raise unittest.SkipTest('Needs TFRT runtime.') + # TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend. + if (xla_bridge.get_backend().runtime_type == 'stream_executor' and + jtu.device_under_test() == 'tpu'): + self.skipTest('Does not work with the cloud TPU SE runtime.') + global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y')) global_input_shape = (8, 2) @@ -1870,14 +1875,15 @@ def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape, @pjit def f(tree): return tree - out_tree = f((a1, (a2, (a3, a4)))) + out_tree = f((a1 @ a1.T, (a2, (a3 * 2, a4)))) (out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree) self.assertIsInstance(out1, array.ArrayImpl) - self.assertEqual(out1.shape, (8, 2)) + self.assertEqual(out1.shape, (8, 8)) self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape) for s in out1.addressable_shards: - self.assertArraysEqual(s.data._arrays[0], input_data[s.index]) + self.assertArraysEqual( + s.data._arrays[0], (input_data @ input_data.T)[s.index]) self.assertIsInstance(out2, array.ArrayImpl) self.assertEqual(out2.shape, (8, 2)) @@ -1889,7 +1895,7 @@ def f(tree): self.assertEqual(out3.shape, (8, 2)) self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape) for s in out3.addressable_shards: - self.assertArraysEqual(s.data._arrays[0], input_data[s.index]) + self.assertArraysEqual(s.data._arrays[0], (input_data * 2)[s.index]) self.assertIsInstance(out4, array.ArrayImpl) self.assertEqual(out4.shape, (8, 2)) @@ -2176,7 +2182,7 @@ def test_pjit_different_device_recompilation(self): @jax_array(True) def test_grad_of_pjit_single_device_sharding(self): a = jnp.array(16, dtype=jnp.float32) - f = lambda x: x + f = lambda x: x * 3 out = jax.grad(pjit(f))(a) self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, jax.grad(f)(a)) @@ -2488,6 +2494,31 @@ def _invoke_with_mesh_twice(arg_tuple): for i, x, y in zip(range(n), xs, ys): self.assertAllClose(x + i, y) + @jax_array(True) + def test_trivial_computation(self): + shape = (8, 2) + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = MeshPspecSharding(mesh, P('x', 'y')) + inp_data = np.arange(prod(shape)).reshape(shape) + arr = jax.device_put(inp_data, s) + out = pjit(lambda x: x)(arr) + self.assertArraysEqual(out, inp_data) + + @jax_array(True) + def test_multi_device_pjit_mul(self): + shape = (8, 2) + mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + inp_data = np.arange(prod(shape)).reshape(shape) + arr1 = jax.device_put(inp_data, MeshPspecSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(inp_data, MeshPspecSharding(mesh, P(None, 'y'))) + + out1, out2 = pjit(lambda x, y: (x @ x.T, y * 2))(arr1, arr2) + + self.assertArraysEqual(out1, inp_data @ inp_data.T) + self.assertEqual(out1.shape, (8, 8)) + self.assertArraysEqual(out2, inp_data * 2) + self.assertEqual(out2.shape, (8, 2)) + class TempSharding(Sharding): From 3be5ab218a9f3b875ae5911b2f30fb4798955cac Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 21 Oct 2022 19:53:28 -0700 Subject: [PATCH 07/18] Allow calling `initialize_cache` a second time if the path is the same. PiperOrigin-RevId: 482945880 --- .../compilation_cache/compilation_cache.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py index b5621220cd73..64e6a88e0ad9 100644 --- a/jax/experimental/compilation_cache/compilation_cache.py +++ b/jax/experimental/compilation_cache/compilation_cache.py @@ -20,6 +20,7 @@ from typing import List, Optional from jax.experimental.compilation_cache.gfile_cache import GFileCache +from jax._src import path as pathlib from jax._src.lib import version_str as jaxlib_version_str from jax._src.lib import xla_client from jax.interpreters import xla @@ -31,8 +32,18 @@ def initialize_cache(path): """Creates a global cache object. Should only be called once per process. + + Will throw an assertion error if called a second time with a different path. + + Args: + path: path for the cache directory. + """ global _cache + if _cache is not None and _cache._path == pathlib.Path(path): + logger.warning("Cache already previoulsy initialized at %s", _cache._path) + return + assert _cache == None, f"The cache path has already been initialized to {_cache._path}" _cache = GFileCache(path) logger.warning("Initialized persistent compilation cache at %s", path) From b07c586565726e32a6d1613fd75d1394a8569695 Mon Sep 17 00:00:00 2001 From: Xin Zhou Date: Fri, 21 Oct 2022 20:33:33 -0700 Subject: [PATCH 08/18] [mhlo] Use 11 out of 12 new shared type inferences from StableHLO. The shape function of DotGeneralOp can't be integrated into MHLO yet: the shape function only predicts return shape but not able to predict element type. However, the current python binding infra will generate the constructor __init__() without the `return` as the first arg, which assumes the shape function can provide a fully inferred type (including an accurate element type). This leads to "inferred type does not match actual result type" errors in JAX. This needs a future solution. This CL is the corresponding change with https://github.com/openxla/stablehlo/pull/269 Related Python __init__() interface changes (used by JAX): batch_norm_grad: not used by JAX batch_norm_inference: not used by JAX batch_norm_training: not used by JAX case: no change* dot_general: open new b/253644255 to track the issue if: no change* map: no change* reduce: no change* reduce_window: no change* sort: no change* triangular_solve: updated in `linalg.py` while: no change* no change*: the signature of __init()__ for the op is not changed because of existence of regions https://github.com/llvm/llvm-project/blob/main/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp#L577 PiperOrigin-RevId: 482951512 --- jax/_src/lax/linalg.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index d70c80be9a85..a0c77bec2957 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -41,6 +41,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import svd as lax_svd from jax._src.lib import lapack +from jax._src.lib import mlir_api_version from jax._src.lib import gpu_linalg from jax._src.lib import gpu_solver @@ -873,10 +874,16 @@ def _triangular_solve_lowering( transpose = "NO_TRANSPOSE" else: transpose = "ADJOINT" if conjugate_a else "TRANSPOSE" - return mhlo.TriangularSolveOp( - mlir.aval_to_ir_type(out_aval), a, b, ir.BoolAttr.get(left_side), - ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), - mhlo.TransposeAttr.get(transpose)).results + if mlir_api_version < 36: + return mhlo.TriangularSolveOp( + mlir.aval_to_ir_type(out_aval), a, b, ir.BoolAttr.get(left_side), + ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), + mhlo.TransposeAttr.get(transpose)).results + else: + return mhlo.TriangularSolveOp( + a, b, ir.BoolAttr.get(left_side), + ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), + mhlo.TransposeAttr.get(transpose)).results mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering) @@ -900,10 +907,16 @@ def _triangular_solve_cpu_lower( transpose = "ADJOINT" if conjugate_a else "TRANSPOSE" else: transpose = "NO_TRANSPOSE" - return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side), - ir.BoolAttr.get(lower), - ir.BoolAttr.get(unit_diagonal), - mhlo.TransposeAttr.get(transpose)).results + if mlir_api_version < 36: + return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side), + ir.BoolAttr.get(lower), + ir.BoolAttr.get(unit_diagonal), + mhlo.TransposeAttr.get(transpose)).results + else: + return mhlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side), + ir.BoolAttr.get(lower), + ir.BoolAttr.get(unit_diagonal), + mhlo.TransposeAttr.get(transpose)).results mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower, platform='cpu') From 5784d61048facfa9dac1f1d309bde2d60a32810c Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan Date: Mon, 3 Oct 2022 17:46:28 -0400 Subject: [PATCH 09/18] implement truncnorm in jax.scipy.stats fix some shape and type issues import into namespace imports into non-_src library working logpdf test cleanup working tests for cdf and sf after fixing select relax need for x to be in (a, b) ensure behavior with invalid input matches scipy remove enforcing valid parameters in tests added truncnorm to docs whoops alphabetical fix linter error fix circular import issue --- docs/jax.scipy.rst | 13 +++ jax/_src/scipy/stats/truncnorm.py | 130 ++++++++++++++++++++++++++++++ jax/scipy/stats/__init__.py | 1 + jax/scipy/stats/truncnorm.py | 22 +++++ tests/scipy_stats_test.py | 107 ++++++++++++++++++++++++ 5 files changed, 273 insertions(+) create mode 100644 jax/_src/scipy/stats/truncnorm.py create mode 100644 jax/scipy/stats/truncnorm.py diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 0bf581c6b3ff..408622bbce57 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -311,6 +311,19 @@ jax.scipy.stats.t logpdf pdf +jax.scipy.stats.truncnorm +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: jax.scipy.stats.truncnorm +.. autosummary:: + :toctree: _autosummary + + cdf + logcdf + logpdf + logsf + pdf + sf + jax.scipy.stats.uniform ~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: jax.scipy.stats.uniform diff --git a/jax/_src/scipy/stats/truncnorm.py b/jax/_src/scipy/stats/truncnorm.py new file mode 100644 index 000000000000..9838c9016eb2 --- /dev/null +++ b/jax/_src/scipy/stats/truncnorm.py @@ -0,0 +1,130 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import scipy.stats as osp_stats + +from jax import lax +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_args_inexact +from jax._src.scipy.stats import norm +from jax._src.scipy.special import logsumexp, log_ndtr, ndtr + + +def _log_diff(x, y): + return logsumexp( + jnp.array([x, y]), + b=jnp.array([jnp.ones_like(x), -jnp.ones_like(y)]), + axis=0 + ) + + +def _log_gauss_mass(a, b): + """Log of Gaussian probability mass within an interval""" + a, b = jnp.array(a), jnp.array(b) + a, b = jnp.broadcast_arrays(a, b) + + # Note: Docstring carried over from scipy + # Calculations in right tail are inaccurate, so we'll exploit the + # symmetry and work only in the left tail + case_left = b <= 0 + case_right = a > 0 + case_central = ~(case_left | case_right) + + def mass_case_left(a, b): + return _log_diff(log_ndtr(b), log_ndtr(a)) + + def mass_case_right(a, b): + return mass_case_left(-b, -a) + + def mass_case_central(a, b): + # Note: Docstring carried over from scipy + # Previously, this was implemented as: + # left_mass = mass_case_left(a, 0) + # right_mass = mass_case_right(0, b) + # return _log_sum(left_mass, right_mass) + # Catastrophic cancellation occurs as np.exp(log_mass) approaches 1. + # Correct for this with an alternative formulation. + # We're not concerned with underflow here: if only one term + # underflows, it was insignificant; if both terms underflow, + # the result can't accurately be represented in logspace anyway + # because sc.log1p(x) ~ x for small x. + return jnp.log1p(-ndtr(a) - ndtr(-b)) + + out = jnp.select( + [case_left, case_right, case_central], + [mass_case_left(a, b), mass_case_right(a, b), mass_case_central(a, b)] + ) + return out + + +@_wraps(osp_stats.truncnorm.logpdf, update_doc=False) +def logpdf(x, a, b, loc=0, scale=1): + x, a, b, loc, scale = _promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale) + val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b)) + + x_scaled = lax.div(lax.sub(x, loc), scale) + val = jnp.where((x_scaled < a) | (x_scaled > b), -jnp.inf, val) + val = jnp.where(a >= b, jnp.nan, val) + return val + + +@_wraps(osp_stats.truncnorm.pdf, update_doc=False) +def pdf(x, a, b, loc=0, scale=1): + return lax.exp(logpdf(x, a, b, loc, scale)) + + +@_wraps(osp_stats.truncnorm.logsf, update_doc=False) +def logsf(x, a, b, loc=0, scale=1): + x, a, b, loc, scale = _promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale) + x, a, b = jnp.broadcast_arrays(x, a, b) + x = lax.div(lax.sub(x, loc), scale) + logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b) + logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b) + + logsf = jnp.select( + # third condition: avoid catastrophic cancellation (from scipy) + [x >= b, x <= a, logsf > -0.1, x > a], + [-jnp.inf, 0, jnp.log1p(-jnp.exp(logcdf)), logsf] + ) + logsf = jnp.where(a >= b, jnp.nan, logsf) + return logsf + + +@_wraps(osp_stats.truncnorm.sf, update_doc=False) +def sf(x, a, b, loc=0, scale=1): + return lax.exp(logsf(x, a, b, loc, scale)) + + +@_wraps(osp_stats.truncnorm.logcdf, update_doc=False) +def logcdf(x, a, b, loc=0, scale=1): + x, a, b, loc, scale = _promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale) + x, a, b = jnp.broadcast_arrays(x, a, b) + x = lax.div(lax.sub(x, loc), scale) + logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b) + logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b) + + logcdf = jnp.select( + # third condition: avoid catastrophic cancellation (from scipy) + [x >= b, x <= a, logcdf > -0.1, x > a], + [0, -jnp.inf, jnp.log1p(-jnp.exp(logsf)), logcdf] + ) + logcdf = jnp.where(a >= b, jnp.nan, logcdf) + return logcdf + + +@_wraps(osp_stats.truncnorm.cdf, update_doc=False) +def cdf(x, a, b, loc=0, scale=1): + return lax.exp(logcdf(x, a, b, loc, scale)) diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index 1f71491e9c7e..3d08c712402a 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -32,5 +32,6 @@ from jax.scipy.stats import chi2 as chi2 from jax.scipy.stats import betabinom as betabinom from jax.scipy.stats import gennorm as gennorm +from jax.scipy.stats import truncnorm as truncnorm from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde from jax._src.scipy.stats._core import mode as mode diff --git a/jax/scipy/stats/truncnorm.py b/jax/scipy/stats/truncnorm.py new file mode 100644 index 000000000000..3d85d47067e2 --- /dev/null +++ b/jax/scipy/stats/truncnorm.py @@ -0,0 +1,22 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.scipy.stats.truncnorm import ( + cdf as cdf, + logcdf as logcdf, + logpdf as logpdf, + pdf as pdf, + logsf as logsf, + sf as sf +) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index cf8715f04e9e..0007a099e99f 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -474,6 +474,113 @@ def args_maker(): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + @genNamedParametersNArgs(5) + def testTruncnormLogPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.logpdf + lax_fun = lsp_stats.truncnorm.logpdf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(5) + def testTruncnormPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.pdf + lax_fun = lsp_stats.truncnorm.pdf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(5) + def testTruncnormLogCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.logcdf + lax_fun = lsp_stats.truncnorm.logcdf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(5) + def testTruncnormCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.cdf + lax_fun = lsp_stats.truncnorm.cdf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(5) + def testTruncnormLogSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.logsf + lax_fun = lsp_stats.truncnorm.logsf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(5) + def testTruncnormSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.truncnorm.sf + lax_fun = lsp_stats.truncnorm.sf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testParetoLogPdf(self, shapes, dtypes): From 67fa7c27d56579db44c3a65eb794e46717d615f2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Oct 2022 07:53:22 -0700 Subject: [PATCH 10/18] Typo fix. PiperOrigin-RevId: 483380789 --- jax/experimental/compilation_cache/compilation_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py index 64e6a88e0ad9..e3b000cd8288 100644 --- a/jax/experimental/compilation_cache/compilation_cache.py +++ b/jax/experimental/compilation_cache/compilation_cache.py @@ -41,7 +41,7 @@ def initialize_cache(path): """ global _cache if _cache is not None and _cache._path == pathlib.Path(path): - logger.warning("Cache already previoulsy initialized at %s", _cache._path) + logger.warning("Cache already previously initialized at %s", _cache._path) return assert _cache == None, f"The cache path has already been initialized to {_cache._path}" From 48e680c839397bd162feaacd1e127e4e17119469 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 24 Oct 2022 08:57:53 -0700 Subject: [PATCH 11/18] CI: avoid raising error when wrapped function is None --- jax/_src/numpy/util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 62d488311951..5ed44611b08c 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -154,6 +154,10 @@ def _wraps( be determined from the wrapped function itself. """ def wrap(op): + op.__np_wrapped__ = fun + # Allows this pattern: @wraps(getattr(np, 'new_function', None)) + if fun is None: + return op docstr = getattr(fun, "__doc__", None) name = getattr(fun, "__name__", getattr(op, "__name__", str(op))) try: @@ -203,7 +207,6 @@ def wrap(op): docstr = fun.__doc__ op.__doc__ = docstr - op.__np_wrapped__ = fun for attr in ['__name__', '__qualname__']: try: value = getattr(fun, attr) From 894093c0fb7e80838131149304bc45a4ae7a6573 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 24 Oct 2022 10:02:12 -0700 Subject: [PATCH 12/18] Move jaxlib cpu kernels under jaxlib/cpu/. No functional changes intended. PiperOrigin-RevId: 483413031 --- build/build_wheel.py | 6 +- jaxlib/BUILD | 90 ++------------ jaxlib/cpu/BUILD | 112 ++++++++++++++++++ jaxlib/{ => cpu}/cpu_kernels.cc | 4 +- jaxlib/{ => cpu}/ducc_fft.cc | 4 +- jaxlib/{ => cpu}/ducc_fft.fbs | 0 jaxlib/{ => cpu}/ducc_fft_kernels.cc | 2 +- jaxlib/{ => cpu}/ducc_fft_kernels.h | 5 + jaxlib/{ => cpu}/lapack.cc | 2 +- jaxlib/{ => cpu}/lapack_kernels.cc | 2 +- jaxlib/{ => cpu}/lapack_kernels.h | 6 +- .../{ => cpu}/lapack_kernels_using_lapack.cc | 2 +- jaxlib/ducc_fft.py | 2 +- jaxlib/lapack.py | 2 +- jaxlib/setup.py | 1 + 15 files changed, 142 insertions(+), 98 deletions(-) create mode 100644 jaxlib/cpu/BUILD rename jaxlib/{ => cpu}/cpu_kernels.cc (98%) rename jaxlib/{ => cpu}/ducc_fft.cc (96%) rename jaxlib/{ => cpu}/ducc_fft.fbs (100%) rename jaxlib/{ => cpu}/ducc_fft_kernels.cc (99%) rename jaxlib/{ => cpu}/ducc_fft_kernels.h (86%) rename jaxlib/{ => cpu}/lapack.cc (99%) rename jaxlib/{ => cpu}/lapack_kernels.cc (99%) rename jaxlib/{ => cpu}/lapack_kernels.h (98%) rename jaxlib/{ => cpu}/lapack_kernels_using_lapack.cc (99%) diff --git a/build/build_wheel.py b/build/build_wheel.py index f10efab99997..b6a290455492 100644 --- a/build/build_wheel.py +++ b/build/build_wheel.py @@ -169,9 +169,7 @@ def prepare_wheel(sources_path): copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py") copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}") copy_to_jaxlib("__main__/jaxlib/lapack.py") - copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}") copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py") - copy_to_jaxlib(f"__main__/jaxlib/_ducc_fft.{pyext}") copy_to_jaxlib("__main__/jaxlib/ducc_fft.py") copy_to_jaxlib("__main__/jaxlib/gpu_prng.py") copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py") @@ -180,6 +178,10 @@ def prepare_wheel(sources_path): copy_to_jaxlib("__main__/jaxlib/version.py") copy_to_jaxlib("__main__/jaxlib/xla_client.py") copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}") + cpu_dir = os.path.join(jaxlib_dir, "cpu") + os.makedirs(cpu_dir) + copy_file(f"__main__/jaxlib/cpu/_lapack.{pyext}", dst_dir=cpu_dir) + copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir) cuda_dir = os.path.join(jaxlib_dir, "cuda") if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"): diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 5a6a579e51cc..f5a2cc82932a 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -17,7 +17,6 @@ load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", - "flatbuffer_cc_library", "if_windows", "pybind_extension", ) @@ -42,9 +41,9 @@ py_library( ], data = [":xla_extension"], deps = [ - ":_ducc_fft", - ":_lapack", ":cpu_feature_guard", + "//jaxlib/cpu:_ducc_fft", + "//jaxlib/cpu:_lapack", "//jaxlib/mlir", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:chlo_dialect", @@ -132,8 +131,8 @@ cc_library( ], ) -# CPU kernels - +# This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong +# target architecture. pybind_extension( name = "cpu_feature_guard", srcs = ["cpu_feature_guard.c"], @@ -143,89 +142,14 @@ pybind_extension( ], ) -# LAPACK - -cc_library( - name = "lapack_kernels", - srcs = ["lapack_kernels.cc"], - hdrs = ["lapack_kernels.h"], - deps = [ - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", - "@com_google_absl//absl/base:dynamic_annotations", - ], -) - -cc_library( - name = "lapack_kernels_using_lapack", - srcs = ["lapack_kernels_using_lapack.cc"], - deps = [":lapack_kernels"], - alwayslink = 1, -) - -pybind_extension( - name = "_lapack", - srcs = ["lapack.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_lapack", - deps = [ - ":kernel_pybind11_helpers", - ":lapack_kernels", - "@pybind11", - ], -) - -# DUCC (CPU FFTs) - -flatbuffer_cc_library( - name = "ducc_fft_flatbuffers_cc", - srcs = ["ducc_fft.fbs"], -) - -cc_library( - name = "ducc_fft_kernels", - srcs = ["ducc_fft_kernels.cc"], - hdrs = ["ducc_fft_kernels.h"], - copts = ["-fexceptions"], # DUCC may throw. - features = ["-use_header_modules"], - deps = [ - ":ducc_fft_flatbuffers_cc", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", - "@ducc", - "@flatbuffers//:runtime_cc", - ], -) - -pybind_extension( - name = "_ducc_fft", - srcs = ["ducc_fft.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_ducc_fft", - deps = [ - ":ducc_fft_flatbuffers_cc", - ":ducc_fft_kernels", - ":kernel_pybind11_helpers", - "@flatbuffers//:runtime_cc", - "@pybind11", - ], -) +# CPU kernels +# TODO(phawkins): Remove this forwarding target. cc_library( name = "cpu_kernels", - srcs = ["cpu_kernels.cc"], visibility = ["//visibility:public"], deps = [ - ":ducc_fft_kernels", - ":lapack_kernels", - ":lapack_kernels_using_lapack", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry", + "//jaxlib/cpu:cpu_kernels", ], alwayslink = 1, ) diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD new file mode 100644 index 000000000000..0ba15d760e85 --- /dev/null +++ b/jaxlib/cpu/BUILD @@ -0,0 +1,112 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# JAX is Autograd and XLA + +load( + "//jaxlib:jax.bzl", + "flatbuffer_cc_library", + "pybind_extension", +) + +licenses(["notice"]) + +package(default_visibility = ["//:__subpackages__"]) + +# LAPACK + +cc_library( + name = "lapack_kernels", + srcs = ["lapack_kernels.cc"], + hdrs = ["lapack_kernels.h"], + deps = [ + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@com_google_absl//absl/base:dynamic_annotations", + ], +) + +cc_library( + name = "lapack_kernels_using_lapack", + srcs = ["lapack_kernels_using_lapack.cc"], + deps = [":lapack_kernels"], + alwayslink = 1, +) + +pybind_extension( + name = "_lapack", + srcs = ["lapack.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_lapack", + deps = [ + ":lapack_kernels", + "//jaxlib:kernel_pybind11_helpers", + "@pybind11", + ], +) + +# DUCC (CPU FFTs) + +flatbuffer_cc_library( + name = "ducc_fft_flatbuffers_cc", + srcs = ["ducc_fft.fbs"], +) + +cc_library( + name = "ducc_fft_kernels", + srcs = ["ducc_fft_kernels.cc"], + hdrs = ["ducc_fft_kernels.h"], + copts = ["-fexceptions"], # DUCC may throw. + features = ["-use_header_modules"], + deps = [ + ":ducc_fft_flatbuffers_cc", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@ducc", + "@flatbuffers//:runtime_cc", + ], +) + +pybind_extension( + name = "_ducc_fft", + srcs = ["ducc_fft.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_ducc_fft", + deps = [ + ":ducc_fft_flatbuffers_cc", + ":ducc_fft_kernels", + "//jaxlib:kernel_pybind11_helpers", + "@flatbuffers//:runtime_cc", + "@pybind11", + ], +) + +cc_library( + name = "cpu_kernels", + srcs = ["cpu_kernels.cc"], + visibility = ["//visibility:public"], + deps = [ + ":ducc_fft_kernels", + ":lapack_kernels", + ":lapack_kernels_using_lapack", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry", + ], + alwayslink = 1, +) diff --git a/jaxlib/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc similarity index 98% rename from jaxlib/cpu_kernels.cc rename to jaxlib/cpu/cpu_kernels.cc index 7e8f0fc56ffb..d7ead0a81afc 100644 --- a/jaxlib/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -16,8 +16,8 @@ limitations under the License. // This file is not used by JAX itself, but exists to assist with running // JAX-generated HLO code from outside of JAX. -#include "jaxlib/lapack_kernels.h" -#include "jaxlib/ducc_fft_kernels.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/cpu/ducc_fft_kernels.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" namespace jax { diff --git a/jaxlib/ducc_fft.cc b/jaxlib/cpu/ducc_fft.cc similarity index 96% rename from jaxlib/ducc_fft.cc rename to jaxlib/cpu/ducc_fft.cc index b5211667d776..d3bba133cdaf 100644 --- a/jaxlib/ducc_fft.cc +++ b/jaxlib/cpu/ducc_fft.cc @@ -18,8 +18,8 @@ limitations under the License. #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" -#include "jaxlib/ducc_fft_generated.h" -#include "jaxlib/ducc_fft_kernels.h" +#include "jaxlib/cpu/ducc_fft_generated.h" +#include "jaxlib/cpu/ducc_fft_kernels.h" #include "jaxlib/kernel_pybind11_helpers.h" namespace py = pybind11; diff --git a/jaxlib/ducc_fft.fbs b/jaxlib/cpu/ducc_fft.fbs similarity index 100% rename from jaxlib/ducc_fft.fbs rename to jaxlib/cpu/ducc_fft.fbs diff --git a/jaxlib/ducc_fft_kernels.cc b/jaxlib/cpu/ducc_fft_kernels.cc similarity index 99% rename from jaxlib/ducc_fft_kernels.cc rename to jaxlib/cpu/ducc_fft_kernels.cc index 6bed7ffbfebf..789de83eb82c 100644 --- a/jaxlib/ducc_fft_kernels.cc +++ b/jaxlib/cpu/ducc_fft_kernels.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" -#include "jaxlib/ducc_fft_generated.h" +#include "jaxlib/cpu/ducc_fft_generated.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" #include "ducc/src/ducc0/fft/fft.h" diff --git a/jaxlib/ducc_fft_kernels.h b/jaxlib/cpu/ducc_fft_kernels.h similarity index 86% rename from jaxlib/ducc_fft_kernels.h rename to jaxlib/cpu/ducc_fft_kernels.h index 2491e7e7a886..b0ababd7659b 100644 --- a/jaxlib/ducc_fft_kernels.h +++ b/jaxlib/cpu/ducc_fft_kernels.h @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef JAXLIB_CPU_DUCC_FFT_KERNELS_H_ +#define JAXLIB_CPU_DUCC_FFT_KERNELS_H_ + #include "tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { @@ -20,3 +23,5 @@ namespace jax { void DuccFft(void* out, void** in, XlaCustomCallStatus*); } // namespace jax + +#endif // JAXLIB_CPU_DUCC_FFT_KERNELS_H_ diff --git a/jaxlib/lapack.cc b/jaxlib/cpu/lapack.cc similarity index 99% rename from jaxlib/lapack.cc rename to jaxlib/cpu/lapack.cc index 36cd22c7dd0c..6d52d8d1f475 100644 --- a/jaxlib/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "jaxlib/kernel_pybind11_helpers.h" -#include "jaxlib/lapack_kernels.h" +#include "jaxlib/cpu/lapack_kernels.h" #include "include/pybind11/pybind11.h" namespace jax { diff --git a/jaxlib/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc similarity index 99% rename from jaxlib/lapack_kernels.cc rename to jaxlib/cpu/lapack_kernels.cc index 6d29613fa246..afe1466832fa 100644 --- a/jaxlib/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/lapack_kernels.h" +#include "jaxlib/cpu/lapack_kernels.h" #include #include diff --git a/jaxlib/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h similarity index 98% rename from jaxlib/lapack_kernels.h rename to jaxlib/cpu/lapack_kernels.h index 17ac6e160afa..8df85741f096 100644 --- a/jaxlib/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_LAPACK_KERNELS_H_ -#define JAXLIB_LAPACK_KERNELS_H_ +#ifndef JAXLIB_CPU_LAPACK_KERNELS_H_ +#define JAXLIB_CPU_LAPACK_KERNELS_H_ #include #include @@ -178,4 +178,4 @@ struct ComplexGees { } // namespace jax -#endif // JAXLIB_LAPACK_KERNELS_H_ +#endif // JAXLIB_CPU_LAPACK_KERNELS_H_ diff --git a/jaxlib/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc similarity index 99% rename from jaxlib/lapack_kernels_using_lapack.cc rename to jaxlib/cpu/lapack_kernels_using_lapack.cc index 908e8cdfeb63..9c380828eeb0 100644 --- a/jaxlib/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/lapack_kernels.h" +#include "jaxlib/cpu/lapack_kernels.h" // From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but // a C++ user should link against LAPACK directly. This is needed when using diff --git a/jaxlib/ducc_fft.py b/jaxlib/ducc_fft.py index f805234fee5f..c4e5477bf545 100644 --- a/jaxlib/ducc_fft.py +++ b/jaxlib/ducc_fft.py @@ -19,7 +19,7 @@ from .mhlo_helpers import custom_call -from . import _ducc_fft +from .cpu import _ducc_fft import numpy as np from jaxlib import xla_client diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index b842140eab68..8a248f12838c 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -22,7 +22,7 @@ from jaxlib import xla_client from .mhlo_helpers import custom_call -from . import _lapack +from .cpu import _lapack for _name, _value in _lapack.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="cpu") diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 6a4cfa953733..6cebc0deb3ac 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -61,6 +61,7 @@ def has_ext_modules(self): '*.so', '*.pyd*', 'py.typed', + 'cpu/*', 'cuda/*', 'cuda/nvvm/libdevice/libdevice*', 'mlir/*.py', From 28def736d1ed22d0638424f8e003ef0f4f3b5b15 Mon Sep 17 00:00:00 2001 From: Ikko Ashimine Date: Tue, 25 Oct 2022 03:26:48 +0900 Subject: [PATCH 13/18] Fix typo in 9419-jax-versioning.md overriden -> overridden --- docs/jep/9419-jax-versioning.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md index 9e363b86a3e5..598c39f5155d 100644 --- a/docs/jep/9419-jax-versioning.md +++ b/docs/jep/9419-jax-versioning.md @@ -149,7 +149,7 @@ level. [as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE). To update the version of XLA used during the build, one must update the pinned version in the Bazel `WORKSPACE`. This is done manually on an -as-needed basis, but can be overriden on a build-by-build basis. +as-needed basis, but can be overridden on a build-by-build basis. ## How do we make changes across the `jax` and `jaxlib` boundary between releases? From 9ade89ea62e59e0d771903d65325d6fed60f6c4c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 24 Oct 2022 14:10:31 -0700 Subject: [PATCH 14/18] jnp.linalg.lstsq: handle zero-size inputs --- jax/_src/numpy/linalg.py | 25 +++++++++++++++---------- tests/linalg_test.py | 3 +++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 49cbbeb51f1f..0c5f07df4bd2 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -601,17 +601,22 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *, f"{b.ndim}-dimensional array given. Array must be one or two-dimensional") m, n = a.shape dtype = a.dtype - if rcond is None: - rcond = jnp.finfo(dtype).eps * max(n, m) + if a.size == 0: + s = jnp.empty(0, dtype=a.dtype) + rank = jnp.array(0, dtype=int) + x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype) else: - rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond) - u, s, vt = svd(a, full_matrices=False) - mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0] - rank = mask.sum() - safe_s = jnp.where(mask, s, 1).astype(a.dtype) - s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis] - uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) - x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) + if rcond is None: + rcond = jnp.finfo(dtype).eps * max(n, m) + else: + rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond) + u, s, vt = svd(a, full_matrices=False) + mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0] + rank = mask.sum() + safe_s = jnp.where(mask, s, 1).astype(a.dtype) + s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis] + uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) + x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) # Numpy returns empty residuals in some cases. To allow compilation, we # default to returning full residuals in all cases. if numpy_resid and (rank < n or m <= n): diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 0169380c0148..7767f4619b26 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -890,6 +890,9 @@ def testMultiDot(self, shapes, dtype): ((4, 6), (4,)), ((6, 6), (6, 1)), ((8, 6), (8, 4)), + ((0, 3), (0,)), + ((3, 0), (3,)), + ((3, 1), (3, 0)), ] ], rcond=[-1, None, 0.5], From 56d42c0edfccf091c7600f120d5e75a6d5158feb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 24 Oct 2022 14:21:33 -0700 Subject: [PATCH 15/18] [typing] annotate next batch of lax_numpy --- jax/_src/numpy/lax_numpy.py | 211 +++++++++++++++++++----------------- 1 file changed, 112 insertions(+), 99 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 5ef7f089ebe8..1586f562eb9f 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -29,7 +29,8 @@ from functools import partial import operator import types -from typing import overload, Any, Callable, Dict, Sequence, FrozenSet, Optional, Tuple, Union +from typing import ( + overload, Any, Callable, Dict, Sequence, FrozenSet, List, Optional, Tuple, Union) from textwrap import dedent as _dedent import warnings @@ -135,7 +136,7 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: set_printoptions = np.set_printoptions @_wraps(np.iscomplexobj) -def iscomplexobj(x): +def iscomplexobj(x: Any) -> bool: try: typ = x.dtype.type except AttributeError: @@ -431,7 +432,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, range = [arr.min(), arr.max()] range = asarray(range, dtype=dtype) if shape(range) != (2,): - raise ValueError("`range` must be either None or a sequence of scalars.") + raise ValueError(f"`range` must be either None or a sequence of scalars, got {range}") range = (where(ptp(range) == 0, range[0] - 0.5, range[0]), where(ptp(range) == 0, range[1] + 0.5, range[1])) assert range is not None @@ -463,10 +464,13 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, return counts, bin_edges @_wraps(np.histogram2d) -def histogram2d(x, y, bins=10, range=None, weights=None, density=None): +def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10, + range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]]=None, + weights: Optional[ArrayLike] = None, + density: Optional[bool] = None) -> Tuple[Array, Array, Array]: _check_arraylike("histogram2d", x, y) try: - N = len(bins) + N = len(bins) # type: ignore[arg-type] except TypeError: N = 1 @@ -479,47 +483,51 @@ def histogram2d(x, y, bins=10, range=None, weights=None, density=None): return hist, edges[0], edges[1] @_wraps(np.histogramdd) -def histogramdd(sample, bins=10, range=None, weights=None, density=None): +def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10, + range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = None, + weights: Optional[ArrayLike] = None, + density: Optional[bool] = None) -> Tuple[Array, List[Array]]: if weights is None: _check_arraylike("histogramdd", sample) sample, = _promote_dtypes_inexact(sample) else: _check_arraylike("histogramdd", sample, weights) - if weights.shape != sample.shape[:1]: + if shape(weights) != shape(sample)[:1]: raise ValueError("should have one weight for each sample.") sample, weights = _promote_dtypes_inexact(sample, weights) N, D = shape(sample) if range is not None and ( - len(range) != D or _any(r is not None and len(r) != 2 for r in range)): + len(range) != D or _any(r is not None and shape(r)[0] != 2 for r in range)): # type: ignore[arg-type] raise ValueError(f"For sample.shape={(N, D)}, range must be a sequence " f"of {D} pairs or Nones; got range={range}") try: - num_bins = len(bins) - if num_bins != D: - raise ValueError("should be a bin for each dimension.") + num_bins = len(bins) # type: ignore[arg-type] except TypeError: # when bin_size is integer, the same bin is used for each dimension - bins = D * [bins] + bins_per_dimension: List[ArrayLike] = D * [bins] # type: ignore[assignment] + else: + if num_bins != D: + raise ValueError("should be a bin for each dimension.") + bins_per_dimension = list(bins) # type: ignore[arg-type] - bin_idx_by_dim = D * [None] - nbins = np.empty(D, int) - bin_edges_by_dim = D * [None] - dedges = D * [None] + bin_idx_by_dim: List[Array] = [] + bin_edges_by_dim: List[Array] = [] for i in builtins.range(D): range_i = None if range is None else range[i] - bin_edges = histogram_bin_edges(sample[:, i], bins[i], range_i, weights) + bin_edges = histogram_bin_edges(sample[:, i], bins_per_dimension[i], range_i, weights) bin_idx = searchsorted(bin_edges, sample[:, i], side='right') bin_idx = where(sample[:, i] == bin_edges[-1], bin_idx - 1, bin_idx) - bin_idx_by_dim[i] = bin_idx - nbins[i] = len(bin_edges) + 1 - bin_edges_by_dim[i] = bin_edges - dedges[i] = diff(bin_edges_by_dim[i]) + bin_idx_by_dim.append(bin_idx) + bin_edges_by_dim.append(bin_edges) - xy = ravel_multi_index(bin_idx_by_dim, nbins, mode='clip') - hist = bincount(xy, weights, length=nbins.prod()) + nbins = tuple(len(bin_edges) + 1 for bin_edges in bin_edges_by_dim) + dedges = [diff(bin_edges) for bin_edges in bin_edges_by_dim] + + xy = ravel_multi_index(tuple(bin_idx_by_dim), nbins, mode='clip') + hist = bincount(xy, weights, length=_prod(nbins)) hist = reshape(hist, nbins) core = D*(slice(1, -1),) hist = hist[core] @@ -548,7 +556,7 @@ def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = None) -> Array: @_wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('k', 'axes')) -def rot90(m, k=1, axes=(0, 1)): +def rot90(m: ArrayLike, k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Array: _check_arraylike("rot90", m) ax1, ax2 = axes ax1 = _canonicalize_axis(ax1, ndim(m)) @@ -557,11 +565,11 @@ def rot90(m, k=1, axes=(0, 1)): raise ValueError("Axes must be different") # same as numpy error k = k % 4 if k == 0: - return m + return asarray(m) elif k == 2: return flip(flip(m, ax1), ax2) else: - perm = list(range(m.ndim)) + perm = list(range(ndim(m))) perm[ax1], perm[ax2] = perm[ax2], perm[ax1] if k == 1: return transpose(flip(m, ax2), perm) @@ -570,12 +578,12 @@ def rot90(m, k=1, axes=(0, 1)): @_wraps(np.flip, lax_description=_ARRAY_VIEW_DOC) -def flip(m, axis: Optional[Union[int, Tuple[int, ...]]] = None): - return _flip(m, _ensure_optional_axes(axis)) +def flip(m: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + _check_arraylike("flip", m) + return _flip(asarray(m), _ensure_optional_axes(axis)) @partial(jit, static_argnames=('axis',)) -def _flip(m, axis: Optional[Union[int, Tuple[int, ...]]] = None): - _check_arraylike("flip", m) +def _flip(m: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: if axis is None: return lax.rev(m, list(range(len(shape(m))))) axis = _ensure_index_tuple(axis) @@ -583,29 +591,31 @@ def _flip(m, axis: Optional[Union[int, Tuple[int, ...]]] = None): @_wraps(np.fliplr, lax_description=_ARRAY_VIEW_DOC) -def fliplr(m): - return _flip(m, 1) +def fliplr(m: ArrayLike) -> Array: + _check_arraylike("fliplr", m) + return _flip(asarray(m), 1) @_wraps(np.flipud, lax_description=_ARRAY_VIEW_DOC) -def flipud(m): - return _flip(m, 0) +def flipud(m: ArrayLike) -> Array: + _check_arraylike("flipud", m) + return _flip(asarray(m), 0) @_wraps(np.iscomplex) @jit -def iscomplex(x): +def iscomplex(x: ArrayLike) -> Array: i = imag(x) return lax.ne(i, _lax_const(i, 0)) @_wraps(np.isreal) @jit -def isreal(x): +def isreal(x: ArrayLike) -> Array: i = imag(x) return lax.eq(i, _lax_const(i, 0)) @_wraps(np.angle) @partial(jit, static_argnames=['deg']) -def angle(z, deg=False): +def angle(z: ArrayLike, deg: bool = False) -> Array: re = real(z) im = imag(z) dtype = _dtype(re) @@ -620,41 +630,44 @@ def angle(z, deg=False): @_wraps(np.diff) @partial(jit, static_argnames=('n', 'axis')) -def diff(a, n=1, axis: int = -1, prepend=None, append=None): +def diff(a: ArrayLike, n: int = 1, axis: int = -1, + prepend: Optional[ArrayLike] = None, + append: Optional[ArrayLike] = None) -> Array: _check_arraylike("diff", a) + arr = asarray(a) n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.diff") if n == 0: - return a + return arr if n < 0: raise ValueError(f"order must be non-negative but got {n}") - if ndim(a) == 0: + if arr.ndim == 0: raise ValueError(f"diff requires input that is at least one dimensional; got {a}") - nd = a.ndim + nd = arr.ndim axis = _canonicalize_axis(axis, nd) - combined = [] + combined: List[Array] = [] if prepend is not None: _check_arraylike("diff", prepend) if isscalar(prepend): - shape = list(a.shape) + shape = list(arr.shape) shape[axis] = 1 prepend = broadcast_to(prepend, tuple(shape)) - combined.append(prepend) + combined.append(asarray(prepend)) - combined.append(a) + combined.append(arr) if append is not None: _check_arraylike("diff", append) if isscalar(append): - shape = list(a.shape) + shape = list(arr.shape) shape[axis] = 1 append = broadcast_to(append, tuple(shape)) - combined.append(append) + combined.append(asarray(append)) if len(combined) > 1: - a = concatenate(combined, axis) + arr = concatenate(combined, axis) slice1 = [slice(None)] * nd slice2 = [slice(None)] * nd @@ -663,11 +676,11 @@ def diff(a, n=1, axis: int = -1, prepend=None, append=None): slice1_tuple = tuple(slice1) slice2_tuple = tuple(slice2) - op = not_equal if a.dtype == np.bool_ else subtract + op = not_equal if arr.dtype == np.bool_ else subtract for _ in range(n): - a = op(a[slice1_tuple], a[slice2_tuple]) + arr = op(arr[slice1_tuple], arr[slice2_tuple]) - return a + return arr _EDIFF1D_DOC = """\ Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not @@ -677,23 +690,25 @@ def diff(a, n=1, axis: int = -1, prepend=None, append=None): @_wraps(np.ediff1d, lax_description=_EDIFF1D_DOC) @jit -def ediff1d(ary, to_end=None, to_begin=None): +def ediff1d(ary: ArrayLike, to_end: Optional[ArrayLike] = None, + to_begin: Optional[ArrayLike] = None) -> Array: _check_arraylike("ediff1d", ary) - ary = ravel(ary) - result = lax.sub(ary[1:], ary[:-1]) + arr = ravel(ary) + result = lax.sub(arr[1:], arr[:-1]) if to_begin is not None: _check_arraylike("ediff1d", to_begin) - result = concatenate((ravel(asarray(to_begin, dtype=ary.dtype)), result)) + result = concatenate((ravel(asarray(to_begin, dtype=arr.dtype)), result)) if to_end is not None: _check_arraylike("ediff1d", to_end) - result = concatenate((result, ravel(asarray(to_end, dtype=ary.dtype)))) + result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype)))) return result @_wraps(np.gradient, skip_params=['edge_order']) @partial(jit, static_argnames=('axis', 'edge_order')) -def gradient(f, *varargs, axis: Optional[Union[int, Tuple[int, ...]]] = None, - edge_order=None): +def gradient(f: ArrayLike, *varargs, axis: Optional[Union[int, Tuple[int, ...]]] = None, + edge_order: Optional[int] = None) -> Union[Array, List[Array]]: + # TODO(jakevdp): call check_arraylike on f if edge_order is not None: raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") @@ -706,7 +721,7 @@ def gradient_along_axis(a, h, axis): ), axis) return a_grad / h - a = f + a = asarray(f) axis_tuple: Tuple[int, ...] if axis is None: axis_tuple = tuple(range(a.ndim)) @@ -741,19 +756,14 @@ def gradient_along_axis(a, h, axis): # TODO: use jax.lax loop tools if possible a_grad = [gradient_along_axis(a, h, ax) for ax, h in zip(axis_tuple, dx)] - - if len(axis_tuple) == 1: - a_grad = a_grad[0] - - return a_grad + return a_grad[0] if len(axis_tuple) == 1 else a_grad @_wraps(np.isrealobj) -def isrealobj(x): +def isrealobj(x: Any) -> bool: return not iscomplexobj(x) - @_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC) def reshape(a: ArrayLike, newshape: Shape, order: str = "C") -> Array: _stackable(a) or _check_arraylike("reshape", a) @@ -763,7 +773,7 @@ def reshape(a: ArrayLike, newshape: Shape, order: str = "C") -> Array: except AttributeError: return _reshape(asarray(a), newshape, order=order) -def _compute_newshape(a, newshape): +def _compute_newshape(a: ArrayLike, newshape: Shape) -> Shape: """Fixes a -1 value in newshape, if present.""" # other errors, like having more than one -1, are caught downstream, in # reshape_shape_rule. @@ -807,12 +817,13 @@ def ravel(a: ArrayLike, order: str = "C") -> Array: @_wraps(np.ravel_multi_index) -def ravel_multi_index(multi_index, dims, mode='raise', order='C'): +def ravel_multi_index(multi_index: Tuple[ArrayLike, ...], dims: Tuple[int, ...], + mode: str = 'raise', order: str = 'C') -> Array: assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}" dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims) _check_arraylike("ravel_multi_index", *multi_index) - multi_index = [asarray(i) for i in multi_index] - for index in multi_index: + multi_index_arr = [asarray(i) for i in multi_index] + for index in multi_index_arr: if mode == 'raise': core.concrete_or_error(array, index, "The error occurred because ravel_multi_index was jit-compiled" @@ -820,12 +831,12 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'): if not issubdtype(_dtype(index), integer): raise TypeError("only int indices permitted") if mode == "raise": - if _any(any((i < 0) | (i >= d)) for i, d in zip(multi_index, dims)): + if _any(any((i < 0) | (i >= d)) for i, d in zip(multi_index_arr, dims)): raise ValueError("invalid entry in coordinates array") elif mode == "clip": - multi_index = [clip(i, 0, d - 1) for i, d in zip(multi_index, dims)] + multi_index_arr = [clip(i, 0, d - 1) for i, d in zip(multi_index_arr, dims)] elif mode == "wrap": - multi_index = [i % d for i, d in zip(multi_index, dims)] + multi_index_arr = [i % d for i, d in zip(multi_index_arr, dims)] else: raise ValueError(f"invalid mode={mode!r}. Expected 'raise', 'wrap', or 'clip'") @@ -836,9 +847,9 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'): else: raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'") - result = array(0, dtype=(multi_index[0].dtype if multi_index + result = array(0, dtype=(multi_index_arr[0].dtype if multi_index_arr else dtypes.canonicalize_dtype(int_))) - for i, s in zip(multi_index, strides): + for i, s in zip(multi_index_arr, strides): result = result + i * int(s) return result @@ -849,8 +860,9 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'): """ @_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC) -def unravel_index(indices, shape): +def unravel_index(indices: ArrayLike, shape: Shape) -> Tuple[Array, ...]: _check_arraylike("unravel_index", indices) + indices_arr = asarray(indices) # Note: we do not convert shape to an array, because it may be passed as a # tuple of weakly-typed values, and asarray() would strip these weak types. try: @@ -861,39 +873,39 @@ def unravel_index(indices, shape): raise ValueError("unravel_index: shape should be a scalar or 1D sequence.") out_indices = [None] * len(shape) for i, s in reversed(list(enumerate(shape))): - indices, out_indices[i] = divmod(indices, s) - oob_pos = indices > 0 - oob_neg = indices < -1 + indices_arr, out_indices[i] = divmod(indices_arr, s) + oob_pos = indices_arr > 0 + oob_neg = indices_arr < -1 return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i)) for s, i in zip(shape, out_indices)) @_wraps(np.resize) @partial(jit, static_argnames=('new_shape',)) -def resize(a, new_shape): +def resize(a: ArrayLike, new_shape: Shape) -> Array: _check_arraylike("resize", a) new_shape = _ensure_index_tuple(new_shape) if _any(dim_length < 0 for dim_length in new_shape): raise ValueError("all elements of `new_shape` must be non-negative") - a = ravel(a) + arr = ravel(a) new_size = _prod(new_shape) - if a.size == 0 or new_size == 0: - return zeros_like(a, shape=new_shape) + if arr.size == 0 or new_size == 0: + return zeros_like(arr, shape=new_shape) - repeats = ceil_of_ratio(new_size, a.size) - a = tile(a, repeats)[:new_size] + repeats = ceil_of_ratio(new_size, arr.size) + arr = tile(arr, repeats)[:new_size] - return reshape(a, new_shape) + return reshape(arr, new_shape) @_wraps(np.squeeze, lax_description=_ARRAY_VIEW_DOC) -def squeeze(a, axis: Optional[Union[int, Tuple[int, ...]]] = None): - return _squeeze(a, _ensure_index_tuple(axis) if axis is not None else None) +def squeeze(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + _check_arraylike("squeeze", a) + return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None) @partial(jit, static_argnames=('axis',), inline=True) -def _squeeze(a, axis): - _check_arraylike("squeeze", a) +def _squeeze(a: Array, axis: Tuple[int]) -> Array: if axis is None: a_shape = shape(a) axis = tuple(i for i, d in enumerate(a_shape) if d == 1) @@ -901,17 +913,17 @@ def _squeeze(a, axis): @_wraps(np.expand_dims) -def expand_dims(a, axis: Union[int, Sequence[int]]): +def expand_dims(a: ArrayLike, axis: Union[int, Sequence[int]]) -> Array: _stackable(a) or _check_arraylike("expand_dims", a) axis = _ensure_index_tuple(axis) if hasattr(a, "expand_dims"): - return a.expand_dims(axis) + return a.expand_dims(axis) # type: ignore return lax.expand_dims(a, axis) @_wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('axis1', 'axis2'), inline=True) -def swapaxes(a, axis1: int, axis2: int): +def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: _check_arraylike("swapaxes", a) perm = np.arange(ndim(a)) perm[axis1], perm[axis2] = perm[axis2], perm[axis1] @@ -919,14 +931,14 @@ def swapaxes(a, axis1: int, axis2: int): @_wraps(np.moveaxis, lax_description=_ARRAY_VIEW_DOC) -def moveaxis(a, source: Union[int, Sequence[int]], - destination: Union[int, Sequence[int]]): - return _moveaxis(a, _ensure_index_tuple(source), +def moveaxis(a: ArrayLike, source: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]]) -> Array: + _check_arraylike("moveaxis", a) + return _moveaxis(asarray(a), _ensure_index_tuple(source), _ensure_index_tuple(destination)) @partial(jit, static_argnames=('source', 'destination'), inline=True) -def _moveaxis(a, source: Tuple[int, ...], destination: Tuple[int, ...]): - _check_arraylike("moveaxis", a) +def _moveaxis(a: Array, source: Tuple[int, ...], destination: Tuple[int, ...]) -> Array: source = tuple(_canonicalize_axis(i, ndim(a)) for i in source) destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination) if len(source) != len(destination): @@ -940,7 +952,8 @@ def _moveaxis(a, source: Tuple[int, ...], destination: Tuple[int, ...]): @_wraps(np.isclose) @partial(jit, static_argnames=('equal_nan',)) -def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): +def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, + equal_nan: bool = False) -> Array: a, b = _promote_args("isclose", a, b) dtype = _dtype(a) if issubdtype(dtype, inexact): From 21d02acbaac946c9090e75f713294091485137cc Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Fri, 21 Oct 2022 20:53:51 +0000 Subject: [PATCH 16/18] self_hosted_runner_utils --- .../self_hosted_runner_utils/runner.env | 1 + .../start_github_runner.sh | 20 ++++++++ .../self_hosted_runner_utils/validate_job.sh | 47 +++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 .github/workflows/self_hosted_runner_utils/runner.env create mode 100755 .github/workflows/self_hosted_runner_utils/start_github_runner.sh create mode 100644 .github/workflows/self_hosted_runner_utils/validate_job.sh diff --git a/.github/workflows/self_hosted_runner_utils/runner.env b/.github/workflows/self_hosted_runner_utils/runner.env new file mode 100644 index 000000000000..741fd558c578 --- /dev/null +++ b/.github/workflows/self_hosted_runner_utils/runner.env @@ -0,0 +1 @@ +ACTIONS_RUNNER_HOOK_JOB_STARTED=~/jax/.github/workflows/self_hosted_runner_utils/validate_job.sh diff --git a/.github/workflows/self_hosted_runner_utils/start_github_runner.sh b/.github/workflows/self_hosted_runner_utils/start_github_runner.sh new file mode 100755 index 000000000000..7f2d00109a1d --- /dev/null +++ b/.github/workflows/self_hosted_runner_utils/start_github_runner.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# More or less copied from +# https://github.com/iree-org/iree/tree/main/build_tools/github_actions/runner/config + +~/actions-runner/run.sh > /tmp/actions-runner.`date +"%Y%m%d-%H%M"`.log diff --git a/.github/workflows/self_hosted_runner_utils/validate_job.sh b/.github/workflows/self_hosted_runner_utils/validate_job.sh new file mode 100644 index 000000000000..7f47f4b638cb --- /dev/null +++ b/.github/workflows/self_hosted_runner_utils/validate_job.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# More or less copied from +# https://github.com/iree-org/iree/tree/main/build_tools/github_actions/runner/config + +set -euo pipefail + +ALLOWED_EVENTS=( + "schedule" + "workflow_dispatch" +) + +# Tests if the first argument is contained in the array in the second argument. +# Usage `is_contained "element" "${array[@]}"` +is_contained() { + local e; + local match="$1" + shift + for e in "$@"; do + if [[ "${e}" == "${match}" ]]; then + return 0 + fi + done + return 1 +} + +if ! is_contained "${GITHUB_EVENT_NAME}" "${ALLOWED_EVENTS[@]}"; then + echo "Event type '${GITHUB_EVENT_NAME}' is not allowed on this runner. Aborting workflow." + # clean up any nefarious stuff we may have fetched in job setup. + cd ~/actions-runner/_work + rm -rfv _actions/ _temp/ + exit 1 +fi From be12bfd071340bca12a642681d088502a5e442ae Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Fri, 21 Oct 2022 21:25:05 +0000 Subject: [PATCH 17/18] cloud-tpu-ci-nightly.yml --- .github/workflows/cloud-tpu-ci-nightly.yml | 69 ++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 .github/workflows/cloud-tpu-ci-nightly.yml diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml new file mode 100644 index 000000000000..a1138e2276dc --- /dev/null +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -0,0 +1,69 @@ +name: Cloud-TPU-CI-Nightly + +on: + schedule: + - cron: "0 12 * * *" # Daily at 12:00 UTC + workflow_dispatch: # allows triggering the workflow run manually + repository_dispatch: # allows triggering the workflow via HTTP + +jobs: + cloud-tpu-test: + runs-on: tpu + defaults: + run: + shell: bash -l {0} + strategy: + fail-fast: false + matrix: + python-version: ["3.10"] # TODO(jakevdp): update to 3.11 when available. + outputs: + artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }} + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install JAX test requirements + run: | + pip install -r build/test-requirements.txt + pip install pytest-reportlog + - name: Install JAX + run: | + pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + - name: Run tests + if: success() + id: status + env: + JAX_PLATFORMS: tpu,cpu + run: | + pytest --tb=short \ + --report-log output-${{ matrix.python-version }}-log.jsonl \ + tests/compilation_cache_test.py \ + || ( + echo '::set-output name=ARTIFACTS_AVAILABLE::true' && false + ) + # run: | + # pytest --tb=short \ + # --deselect tests/callback_test.py \ + # --deselect tests/checkify_test.py \ + # --deselect tests/debugger_test.py \ + # --deselect tests/debugging_primitives_test.py \ + # --deselect tests/jaxpr_effects_test.py-rf \ + # --report-log output-${{ matrix.python-version }}-log.jsonl \ + # tests \ + # || ( + # echo '::set-output name=ARTIFACTS_AVAILABLE::true' && false + # ) + - name: Upload artifacts + # if: | + # failure() + # && steps.status.outcome == 'failure' + # && github.event_name == 'schedule' + # && github.repository == 'google/jax' + if: failure() + uses: actions/upload-artifact@v3 + with: + name: output-${{ matrix.python-version }}-log.jsonl + path: output-${{ matrix.python-version }}-log.jsonl + retention-days: 5 From cd02f622412dabb40f8a087e8a4a0f9f073b84c7 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 24 Oct 2022 23:32:09 +0000 Subject: [PATCH 18/18] add pull_request event --- .github/workflows/cloud-tpu-ci-nightly.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index a1138e2276dc..25e8918d4a70 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -5,6 +5,7 @@ on: - cron: "0 12 * * *" # Daily at 12:00 UTC workflow_dispatch: # allows triggering the workflow run manually repository_dispatch: # allows triggering the workflow via HTTP + pull_request: jobs: cloud-tpu-test: