diff --git a/src/complex_tensor/ops/__init__.py b/src/complex_tensor/ops/__init__.py index 7bb1163..3fea279 100644 --- a/src/complex_tensor/ops/__init__.py +++ b/src/complex_tensor/ops/__init__.py @@ -1,9 +1,11 @@ __all__ = [ "aten", + "prims", + "_c10d_functional", "COMPLEX_OPS_TABLE", "FORCE_TEST_LIST", "lookup_complex", ] -from . import aten +from . import _c10d_functional, aten, prims from ._common import COMPLEX_OPS_TABLE, FORCE_TEST_LIST, lookup_complex diff --git a/src/complex_tensor/ops/_c10d_functional.py b/src/complex_tensor/ops/_c10d_functional.py new file mode 100644 index 0000000..12ca2cf --- /dev/null +++ b/src/complex_tensor/ops/_c10d_functional.py @@ -0,0 +1,18 @@ +import torch + +from ._common import ( + register_force_test, + register_simple, +) + +_c10d_functional = torch.ops._c10d_functional + +# TODO (hameerabbasi): Not being tested +broadcast_impl = register_force_test( + _c10d_functional.broadcast, register_simple(_c10d_functional.broadcast) +) + +# TODO (hameerabbasi): Not being tested +broadcast__impl = register_force_test( + _c10d_functional.broadcast_, register_simple(_c10d_functional.broadcast_) +) diff --git a/src/complex_tensor/ops/_common.py b/src/complex_tensor/ops/_common.py index 01d2b43..5d2196f 100644 --- a/src/complex_tensor/ops/_common.py +++ b/src/complex_tensor/ops/_common.py @@ -4,7 +4,8 @@ import torch from torch._ops import OpOverloadPacket from torch._refs import is_complex -from torch.utils._pytree import tree_flatten, tree_unflatten +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten from ..complex_tensor import ComplexTensor @@ -63,6 +64,11 @@ def register_complex( """Decorator to register an implementation for some ops in some dispatch tables""" def inner(func): + if COMPLEX_OPS_TABLE.get(op, func) is not func: + raise RuntimeError( + "Attempted to register multiple functions for " + f"{op._qualified_op_name.replace('::', '.')}" + ) COMPLEX_OPS_TABLE[op] = func return func @@ -131,7 +137,7 @@ def ordered_impl(*args, **kwargs): def register_binary_nonlinear(op: OpType) -> Callable: def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: - a_r, a_i = split_complex_tensor(lhs) + a_r, a_i = split_complex_arg(lhs) b_r, b_i = split_complex_arg(rhs) out_dt, (a_r, a_i, b_r, b_i) = promote_real_cpu_tensors(a_r, a_i, b_r, b_i) real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs) @@ -146,10 +152,19 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTens def register_simple(op: OpType): - def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> ComplexTensor: x, y = split_complex_tensor(self) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + raise RuntimeError("Non-complex `dtype` specified, please write custom impl.") + + if dtype in COMPLEX_TO_REAL: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + u = op(x, *args, **kwargs) v = op(y, *args, **kwargs) + u_flat, u_spec = tree_flatten(u) v_flat, v_spec = tree_flatten(v) assert u_spec == v_spec @@ -161,3 +176,37 @@ def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: impl.__qualname__ = func_name return register_complex(op, impl) + + +def _as_complex_tensor(arg: torch.Tensor | Any) -> torch.Tensor | ComplexTensor | Any: + if ( + not isinstance(arg, ComplexTensor) + and isinstance(arg, torch.Tensor) + and arg.dtype in COMPLEX_TO_REAL + ): + 
       return ComplexTensor.from_interleaved(arg)
+    return arg
+
+
+def _as_interleaved(arg: ComplexTensor | Any) -> torch.Tensor | Any:
+    if isinstance(arg, ComplexTensor):
+        return arg.as_interleaved()
+    return arg
+
+
+class ComplexDispatchMode(TorchDispatchMode):
+    def __init__(self, _dispatch_key=None, *, _compile=False):
+        super().__init__(_dispatch_key)
+        self._compile = _compile
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        if self._compile:
+            func = torch.compile(func)
+
+        args = tree_map(_as_complex_tensor, args)
+        kwargs = tree_map(_as_complex_tensor, kwargs)
+
+        return tree_map(_as_interleaved, func(*args, **kwargs))
diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py
index 268db1c..d1fa23e 100644
--- a/src/complex_tensor/ops/aten.py
+++ b/src/complex_tensor/ops/aten.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 
 import torch
 
@@ -33,7 +33,7 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTens
         alpha = kwargs.pop("alpha", None)
         if alpha is not None:
             return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs)
-        a_r, a_i = split_complex_tensor(lhs)
+        a_r, a_i = split_complex_arg(lhs)
         b_r, b_i = split_complex_arg(rhs)
         out_dt, (a_r, a_i, b_r, b_i) = promote_real_cpu_tensors(a_r, a_i, b_r, b_i)
         u = op(a_r, b_r, *args, **kwargs)
@@ -62,31 +62,64 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b
     return self.is_pinned(device)
 
 
-slice_impl = register_simple(aten.slice)
-flatten_impl = register_simple(aten.flatten)
-view_impl = register_simple(aten.view)
-diagonal_impl = register_simple(aten.diagonal)
-expand_impl = register_simple(aten.expand)
-unsqueeze_impl = register_simple(aten.unsqueeze)
-mean_impl = register_simple(aten.mean)
-sum_impl = register_simple(aten.sum)
-clone_impl = register_simple(aten.clone)
-neg_impl = register_simple(aten.neg)
-flip_impl = register_simple(aten.flip)
-permute_impl = register_simple(aten.permute)
-repeat_impl = register_simple(aten.repeat)
-index_select_impl = register_simple(aten.index_select)
-split_with_sizes_impl = register_simple(aten.split_with_sizes)
-cumsum_impl = register_simple(aten.cumsum)
+SIMPLE_OPS_LIST = [
+    aten.slice,
+    aten.flatten,
+    aten.view,
+    aten.diagonal,
+    aten.expand,
+    aten.unsqueeze,
+    aten.unsqueeze_,
+    aten.mean,
+    aten.sum,
+    aten.clone,
+    aten.neg,
+    aten.flip,
+    aten.permute,
+    aten.repeat,
+    aten.index_select,
+    aten.split,
+    aten.split_with_sizes,
+    aten.cumsum,
+    aten.detach,
+    aten.select,
+    aten.squeeze,
+    aten.zero_,
+    aten.transpose,
+    aten.t,
+    aten.gather,
+]
+
+for simple_op in SIMPLE_OPS_LIST:
+    globals()[f"{str(simple_op).split('.', 1)}_impl"] = register_simple(simple_op)
 
 # TODO (hameerabbasi): Not being tested
-copy_impl = register_force_test(aten.copy, register_simple(aten.copy))
-# TODO (hameerabbasi): Not being tested
-_to_copy_impl = register_force_test(aten._to_copy, register_simple(aten._to_copy))
-# TODO (hameerabbasi): Not being tested
-col2im_impl = register_force_test(aten.col2im, register_simple(aten.col2im))
-# TODO (hameerabbasi): Not being tested
-alias_impl = register_force_test(aten.alias, register_simple(aten.alias))
+SIMPLE_FORCE_TESTED_OPS = [
+    aten.copy,
+    aten._to_copy,
+    aten.col2im,
+    aten.alias,
+    aten.lift_fresh,
+    aten._unsafe_view,
+    aten.index_put_,
+    aten.index,
+    aten._neg_view,
+    aten.avg_pool2d,
+    aten.avg_pool3d,
+    aten.avg_pool2d_backward,
+ 
   aten.avg_pool3d_backward,
+    aten.masked_scatter_backward,
+    aten.select_backward,
+    aten.slice_backward,
+    aten.embedding,
+]
+
+for simple_op in SIMPLE_FORCE_TESTED_OPS:
+    globals()[f"{str(simple_op).split('.', 1)}_impl"] = register_force_test(
+        simple_op, register_simple(simple_op)
+    )
+
+del simple_op
 
 # some binary ops which we can stamp out
 mul_impl = register_binary_nonlinear(aten.mul)
@@ -99,11 +132,20 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b
     aten.convolution, register_binary_nonlinear(aten.convolution)
 )
 
+slice_scatter_impl = register_force_test(
+    aten.slice_scatter, register_binary_linear(aten.slice_scatter)
+)
+select_scatter_impl = register_force_test(
+    aten.select_scatter, register_binary_linear(aten.select_scatter)
+)
+
 add_impl = register_binary_linear(aten.add)
 sub_impl = register_binary_linear(aten.sub)
+diagonal_scatter_impl = register_binary_linear(aten.diagonal_scatter)
 
 
 @register_complex(aten.div)
+@register_complex(aten.true_divide)
 def div_impl(lhs: ComplexTensor, rhs: ComplexTensor, *, rounding_mode=None):
     a_r, a_i = split_complex_tensor(lhs)
     b_r, b_i = split_complex_arg(rhs)
@@ -141,6 +183,23 @@ def prod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
     return ComplexTensor(u, v)
 
 
+@register_complex(aten.pow)
+def pow_impl(self: ComplexTensor, exponent: ComplexTensor) -> ComplexTensor:
+    return torch.exp(exponent * torch.log(self))
+
+
+@register_complex(aten.cumprod)
+def cumprod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
+    dtype = kwargs.pop("dtype", self.dtype)
+    kwargs["dtype"] = complex_to_real_dtype(dtype)
+
+    prod_r = torch.cumprod(torch.abs(self), *args, **kwargs)
+    sum_phi = torch.cumsum(torch.angle(self), *args, **kwargs)
+    u = prod_r * torch.cos(sum_phi)
+    v = prod_r * torch.sin(sum_phi)
+    return ComplexTensor(u, v)
+
+
 # unary funcs,
 # most of these are simple or require some kind of identity
 @register_complex(aten.abs)
@@ -170,7 +229,7 @@ def acos_impl(self: ComplexTensor) -> ComplexTensor:
     y2 = y**2
     a = (x**2) + y2
     b = torch.sqrt((a - 1) ** 2 + 4 * y2)
-    t = (a - 1 + b) / 2
+    t = 0.5 * (a - 1 + b)
 
     u = torch.acos(x / torch.sqrt(1 + t))
     v = torch.asinh(-torch.sign(y) * torch.sqrt(t))
@@ -180,19 +239,54 @@
 @register_complex(aten.asin)
 def asin_impl(self: ComplexTensor) -> ComplexTensor:
     x, y = split_complex_tensor(self)
     out_dt, (x, y) = promote_real_cpu_tensors(x, y)
     y2 = y**2
     a = (x**2) + y2
     b = torch.sqrt((a - 1) ** 2 + 4 * y2)
-    t = (a - 1 + b) / 2
+    t = 0.5 * (a - 1 + b)
 
-    u = torch.arcsin(x / torch.sqrt(1 + t))
-    v = torch.arcsinh(torch.sign(y) * torch.sqrt(t))
+    u = torch.asin(x / torch.sqrt(1 + t))
+    v = torch.asinh(torch.sign(y) * torch.sqrt(t))
 
     return ComplexTensor(u.to(out_dt), v.to(out_dt))
 
 
+@register_complex(aten.atan)
+def atan_impl(self: ComplexTensor) -> ComplexTensor:
+    x, y = split_complex_tensor(self)
+    out_dt, (x, y) = promote_real_cpu_tensors(x, y)
+
+    int1 = torch.log(ComplexTensor(-x, 1 - y) / ComplexTensor(x, 1 + y))
+    int1_re, int1_im = split_complex_tensor(int1)
+
+    out_re = 0.5 * int1_im
+    out_im = -0.5 * int1_re
+
+    return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt))
+
+
+@register_complex(aten.asinh)
+def asinh_impl(self: ComplexTensor) -> ComplexTensor:
+    return torch.log(self + torch.sqrt(self * self + 1))
+
+
+@register_complex(aten.acosh)
+def acosh_impl(self: ComplexTensor) -> ComplexTensor:
+    return torch.log(self + torch.sqrt(self * self - 1))
+
+
+@register_complex(aten.atanh)
+def 
atanh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_real_cpu_tensors(x, y) + + ret = 0.5 * (torch.log(ComplexTensor(1 + x, y)) - torch.log(ComplexTensor(1 - x, -y))) + ret_re, ret_im = split_complex_tensor(ret) + + return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + + @register_complex(aten.cos) def cos_impl(self: ComplexTensor) -> ComplexTensor: x, y = split_complex_tensor(self) @@ -204,11 +298,48 @@ def cos_impl(self: ComplexTensor) -> ComplexTensor: @register_complex(aten.cosh) def cosh_impl(self: ComplexTensor) -> ComplexTensor: - x, y = split_complex_tensor(self) - out_dt, (x, y) = promote_real_cpu_tensors(x, y) - u = torch.cosh(x) * torch.cos(y) - v = torch.sinh(x) * torch.sin(y) - return ComplexTensor(u.to(out_dt), v.to(out_dt)) + exp_x = torch.exp(self) + exp_nx = torch.reciprocal(exp_x) + return 0.5 * (exp_x + exp_nx) + + +@register_complex(aten.sin) +def sin_impl(self: ComplexTensor) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im) + + ret_re = torch.sin(self_re) * torch.cosh(self_im) + ret_im = torch.cos(self_re) * torch.sinh(self_im) + + return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + + +@register_complex(aten.sinh) +def sinh_impl(self: ComplexTensor) -> ComplexTensor: + exp_x = torch.exp(self) + exp_nx = torch.reciprocal(exp_x) + return 0.5 * (exp_x - exp_nx) + + +@register_complex(aten.tan) +def tan_impl(self: ComplexTensor) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im) + + cos_x = torch.cos(self_re) + sinh_y = torch.sinh(self_im) + + num_re = torch.sin(self_re) * cos_x + num_im = sinh_y * torch.cosh(self_im) + + den = cos_x * cos_x + sinh_y * sinh_y + + return ComplexTensor((num_re / den).to(out_dt), (num_im / den).to(out_dt)) + + +@register_complex(aten.tanh) +def tanh_impl(self: ComplexTensor) -> ComplexTensor: + return torch.sinh(self) / torch.cosh(self) @register_complex(aten.exp) @@ -232,6 +363,20 @@ def expm1_impl(self: ComplexTensor) -> ComplexTensor: return ComplexTensor(u.to(out_dt), v.to(out_dt)) +@register_complex(aten.log) +def log_impl(self: ComplexTensor) -> ComplexTensor: + re = torch.log(torch.abs(self)) + im = torch.angle(self) + return ComplexTensor(re, im) + + +@register_complex(aten.log1p) +def log1p_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + # TODO (hameerabbasi): The line below may have numerical issues + return torch.log(ComplexTensor(x + 1, y)) + + @register_complex(aten.any) def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: x, y = split_complex_tensor(self) @@ -246,7 +391,7 @@ def all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: @register_complex(aten.eq) def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor: - a_r, a_i = split_complex_tensor(self) + a_r, a_i = split_complex_arg(self) b_r, b_i = split_complex_arg(rhs) return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs) @@ -318,6 +463,7 @@ def isclose_impl( aten.sort, aten.topk, aten.round, + aten.fmod, ] @@ -351,23 +497,6 @@ def masked_scatter_impl( return ComplexTensor(ret_r, ret_i) -@register_force_test(aten.slice_scatter) -def slice_scatter_impl( - self: ComplexTensor, - source: ComplexTensor, - dim: int = 0, - start: int | None = None, - end: int | None = None, - step: int = 1, -) -> ComplexTensor: - 
self_r, self_i = split_complex_tensor(self) - source_r, source_i = split_complex_arg(source) - ret_r = torch.slice_scatter(self_r, source_r, dim=dim, start=start, end=end, step=step) - ret_i = torch.slice_scatter(self_i, source_i, dim=dim, start=start, end=end, step=step) - - return ComplexTensor(ret_r, ret_i) - - @register_complex(aten.index_put) def index_put_impl( self: ComplexTensor, @@ -395,16 +524,61 @@ def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> Comple @register_complex(aten.full_like) -def full_like_impl(input: ComplexTensor, fill_value: complex, *args, **kwargs) -> ComplexTensor: +def full_like_impl( + input: ComplexTensor, fill_value: complex, *args, dtype: torch.dtype | None = None, **kwargs +) -> torch.Tensor | ComplexTensor: + # Note: Cannot be merged with the cases below due to the `fill_value` argument input_r, input_i = split_complex_tensor(input) - fv_r, fv_i = split_complex_arg(fill_value) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs) + + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + fv_r, fv_i = split_complex_arg(fill_value) ret_r = torch.full_like(input_r, fv_r, *args, **kwargs) ret_i = torch.full_like(input_i, fv_i, *args, **kwargs) return ComplexTensor(ret_r, ret_i) +def register_like(op: OpType) -> Callable[..., torch.Tensor | ComplexTensor]: + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> torch.Tensor | ComplexTensor: + self_re, self_im = split_complex_tensor(self) + + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return op(self_re, *args, dtype=dtype, **kwargs) + + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + ret_re = op(self_re, *args, **kwargs) + ret_im = op(self_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + func_name = f"{str(op).split('.', 1)}_impl" + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +LIKE_OPS_LIST = [ + aten.empty_like, + aten.zeros_like, + aten.randn_like, + aten.new_zeros, +] + +for like_op in LIKE_OPS_LIST: + globals()[f"{str(like_op).split('.', 1)}_impl"] = register_like(like_op) + +del like_op + + @register_complex(aten.cat) def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor: tensors_r = [] @@ -425,8 +599,8 @@ def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor: def sgn_impl(self: ComplexTensor) -> ComplexTensor: self_r, self_i = split_complex_tensor(self) out_dt, (self_r, self_i) = promote_real_cpu_tensors(self_r, self_i) - mask = self != 0 abs_self = torch.abs(ComplexTensor(self_r, self_i)) + mask = abs_self != 0 masked_sgn = ComplexTensor( torch.div(self_r, abs_self).to(out_dt), torch.div(self_i, abs_self).to(out_dt) ) @@ -439,7 +613,7 @@ def sqrt_impl(self: ComplexTensor) -> ComplexTensor: out_dt, (self_r, self_i) = promote_real_cpu_tensors(self_r, self_i) self = ComplexTensor(self_r, self_i) self_abs_sqrt = torch.sqrt(torch.abs(self)) - self_half_angle = torch.angle(self) / 2 + self_half_angle = 0.5 * torch.angle(self) ret_r = self_abs_sqrt * torch.cos(self_half_angle) ret_i = self_abs_sqrt * torch.sin(self_half_angle) @@ -453,7 +627,7 @@ def rsqrt_impl(self: ComplexTensor) -> ComplexTensor: out_dt, (self_r, self_i) = promote_real_cpu_tensors(self_r, self_i) self = ComplexTensor(self_r, self_i) self_abs_rsqrt = torch.rsqrt(torch.abs(self)) - self_neg_half_angle = -torch.angle(self) / 2 + 
self_neg_half_angle = -0.5 * torch.angle(self) ret_r = self_abs_rsqrt * torch.cos(self_neg_half_angle) ret_i = self_abs_rsqrt * torch.sin(self_neg_half_angle) @@ -470,7 +644,7 @@ def addmm_impl( beta: complex = 1, alpha: complex = 1, ) -> ComplexTensor: - ret = alpha * input + beta * torch.mm(mat1, mat2) + ret = beta * input + alpha * torch.mm(mat1, mat2) ret_r, ret_i = split_complex_tensor(ret) if out_dtype is not None: out_dtype = COMPLEX_TO_REAL[out_dtype] @@ -502,3 +676,143 @@ def nonzero_impl(self: ComplexTensor, other: ComplexTensor, *args, **kwargs) -> @register_complex(aten.logical_not) def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: return torch.logical_not(elemwise_nonzero(self), *args, **kwargs) + + +@register_complex(aten.view_as_real) +def view_as_real_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.stack([re, im], dim=-1) + + +@register_complex(aten.linalg_vector_norm) +def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs) + + +@register_force_test(aten.copy_) +def copy__impl(self: ComplexTensor, src, *args, **kwargs): + self_re, self_im = split_complex_tensor(self) + src_re, src_im = split_complex_arg(src) + + ret_re = self_re.copy_(src_re, *args, **kwargs) + ret_im = self_im.copy_(src_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten._local_scalar_dense) +def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex: + x, y = split_complex_tensor(self) + u = aten._local_scalar_dense(x, *args, **kwargs) + v = aten._local_scalar_dense(y, *args, **kwargs) + return complex(u, v) + + +@register_complex(aten.allclose) +def allclose_impl( + input: torch.Tensor, + other: torch.Tensor, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> complex: + return torch.all(torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)).item() + + +@register_complex(aten.stack) +def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor: + re_im_tuples = [split_complex_arg(self_i) for self_i in self] + u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs) + v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs) + return ComplexTensor(u, v) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj_physical) +def _conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, -im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj) +def _conj_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, aten._neg_view(im)) + + +@register_complex(aten.index_add) +def index_add_impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add(dim, index, source_re) + ret_im = self_im.index_add(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.index_add_) +def index_add__impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if 
alpha is not None: + source = source * alpha + + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add_(dim, index, source_re) + ret_im = self_im.index_add_(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.masked_fill) +def masked_fill_impl(self: ComplexTensor, mask: torch.Tensor, value: complex) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill(mask, value_re) + ret_im = self_im.masked_fill(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.masked_fill_) +def masked_fill__impl(self: ComplexTensor, mask: torch.Tensor, value: complex) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill_(mask, value_re) + ret_im = self_im.masked_fill_(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.constant_pad_nd) +def constant_pad_nd_impl(self: ComplexTensor, pad, value: complex | None = None) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + if value is None: + ret_re = aten.constant_pad_nd(self_re, pad) + ret_im = aten.constant_pad_nd(self_im, pad) + else: + value_re, value_im = split_complex_arg(value) + ret_re = aten.constant_pad_nd(self_re, pad, value_re) + ret_im = aten.constant_pad_nd(self_im, pad, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.var) +def var_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + self_re, self_im = split_complex_tensor(self) + return torch.var(self_re, *args, **kwargs) + torch.var(self_im, *args, **kwargs) diff --git a/src/complex_tensor/ops/prims.py b/src/complex_tensor/ops/prims.py index bea46c2..95ba84f 100644 --- a/src/complex_tensor/ops/prims.py +++ b/src/complex_tensor/ops/prims.py @@ -3,13 +3,16 @@ from ..complex_tensor import ComplexTensor from ._common import ( complex_to_real_dtype, + register_complex, register_force_test, split_complex_tensor, ) prims = torch.ops.prims +aten = torch.ops.aten +# TODO (hameerabbasi): Not being tested @register_force_test(prims.convert_element_type) def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor: dtype = complex_to_real_dtype(dtype) @@ -18,3 +21,13 @@ def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTe v_out = prims.convert_element_type(v, dtype) return ComplexTensor(u_out, v_out) + + +@register_complex(prims.conj_physical) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj_physical(self) + + +@register_complex(prims.conj) +def conj_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj(self) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index c160ee0..914b8b4 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -4,11 +4,19 @@ from torch._ops import OpOverload from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_utils import parametrize, run_tests +from torch.testing._internal.common_utils import ( + parametrize, + run_tests, +) from torch.testing._internal.opinfo.core import OpInfo from complex_tensor.ops import COMPLEX_OPS_TABLE, 
FORCE_TEST_LIST -from complex_tensor.test.utils import COMPLEX_DTYPES, TestCase, TestDescriptor, _as_complex_tensor +from complex_tensor.ops._common import _as_complex_tensor +from complex_tensor.test.utils import ( + COMPLEX_DTYPES, + TestCase, + TestDescriptor, +) torch._dynamo.config.recompile_limit = float("inf") torch._dynamo.config.accumulated_recompile_limit = float("inf") @@ -23,8 +31,7 @@ def _get_opname_from_aten_op(aten_op): if isinstance(aten_op, OpOverload): aten_op = aten_op.overloadpacket - _, name = str(aten_op).split(".", 1) - return name + return aten_op._qualified_op_name.split("::")[-1] force_test_names = set(map(_get_opname_from_aten_op, FORCE_TEST_LIST)) @@ -47,7 +54,7 @@ def _get_opname_from_aten_op(aten_op): sorted([op._qualified_op_name.replace("::", ".") for op in non_tested_ops]) ) warnings.warn( - "Not all ops are tested. List of missing ops:" + "Not all implemented ops are tested. List of ops missing tests:" f"\n{textwrap.indent(list_missing_ops, ' ')}", UserWarning, stacklevel=2, @@ -57,8 +64,34 @@ def _get_opname_from_aten_op(aten_op): SKIPS = { TestDescriptor(op_name="real"): "`aten.real` does not hit `__torch_dispatch__`", TestDescriptor(op_name="imag"): "`aten.imag` does not hit `__torch_dispatch__`", - TestDescriptor(op_name="repeat", dtype=torch.complex64, compile=True): "Heisenbug", - TestDescriptor(op_name="repeat", dtype=torch.complex128, compile=True): "Heisenbug", + TestDescriptor(op_name="conj"): "`prims.conj` does not hit `__torch_dispatch__`", + TestDescriptor( + op_name="conj_physical" + ): "`prims.conj_physical` does not hit `__torch_dispatch__`", + TestDescriptor(op_name="empty_like"): "Inconsistent output", + TestDescriptor(op_name="repeat", compile=True): "Heisenbug", + TestDescriptor( + op_name="allclose", compile=True + ): "`aten.allclose` requires data-dependent control-flow", + TestDescriptor(op_name="randn_like"): "Inconsistent output", + TestDescriptor( + op_name="var", compile=True + ): "`aten.var` doesn't return valid results with `torch.compile`", +} + +EXTRA_KWARGS = { + TestDescriptor(op_name="asinh", dtype=torch.complex64, gradcheck=False): { + "rtol": 2e-5, + "atol": 5e-5, + }, + TestDescriptor(op_name="tanh", dtype=torch.complex64, gradcheck=False): { + "rtol": 1e-4, + "atol": 1e-5, + }, + TestDescriptor(op_name="pow", dtype=torch.complex64, gradcheck=False): { + "rtol": 2e-2, + "atol": 2e-6, + }, } @@ -76,11 +109,19 @@ def test_maybe_error(self, device, dtype, op: OpInfo, compile: bool): self.check_consistency(device, dtype, op, compile) def check_consistency(self, device, dtype, op: OpInfo, compile: bool) -> None: - test_info = TestDescriptor(op_name=op.name, device=device, dtype=dtype, compile=compile) + test_info = TestDescriptor( + op_name=op.name, device=device, dtype=dtype, compile=compile, gradcheck=False + ) for xfail_info, reason in SKIPS.items(): if xfail_info.matches(test_info): self.skipTest(reason) + kwargs = {} + for extra_info, extra_kw in EXTRA_KWARGS.items(): + if extra_info.matches(test_info): + kwargs = extra_kw + break + sample_inputs = op.sample_inputs(device, dtype) op_eager = op if compile: @@ -96,7 +137,7 @@ def expected(sample_input=sample_input): def actual(subclass_sample=subclass_sample): return op(subclass_sample.input, *subclass_sample.args, **subclass_sample.kwargs) - self.assertSameResult(expected, actual, ignore_exc_types=compile) + self.assertSameResult(expected, actual, ignore_exc_types=compile, **kwargs) instantiate_device_type_tests(TestComplexTensor, globals()) diff --git 
a/src/complex_tensor/test/utils.py b/src/complex_tensor/test/utils.py index d06205e..f716401 100644 --- a/src/complex_tensor/test/utils.py +++ b/src/complex_tensor/test/utils.py @@ -8,9 +8,9 @@ from torch.testing._internal.common_utils import TestCase as PytorchTestCase from torch.utils._pytree import tree_flatten -from complex_tensor.complex_tensor import ComplexTensor +from complex_tensor.ops._common import COMPLEX_TO_REAL, _as_interleaved -COMPLEX_DTYPES = {torch.complex128, torch.complex64, torch.complex32} +COMPLEX_DTYPES = set(COMPLEX_TO_REAL) @dataclass(frozen=True) @@ -19,6 +19,7 @@ class TestDescriptor: device: str | None = field(default=None) dtype: torch.dtype | None = field(default=None) compile: bool | None = field(default=None) + gradcheck: bool | None = field(default=None) def matches(self, other: TestDescriptor) -> bool: fields1 = fields(self) @@ -35,51 +36,43 @@ def matches(self, other: TestDescriptor) -> bool: return True -def _as_complex_tensor(arg): - if ( - not isinstance(arg, ComplexTensor) - and isinstance(arg, torch.Tensor) - and arg.dtype in COMPLEX_DTYPES - ): - return ComplexTensor.from_interleaved(arg) - return arg - - class TestCase(PytorchTestCase): def assertSameResult( self, - f1: Callable[[], Any], - f2: Callable[[], Any], + expected: Callable[[], Any], + actual: Callable[[], Any], ignore_exc_types: bool = False, *args, **kwargs, ) -> None: try: - result_1 = f1() - exception_1 = None + result_e = expected() + exception_e = None except Exception as e: # noqa: BLE001 - result_1 = None - exception_1 = e + result_e = None + exception_e = e try: - result_2 = f2() - exception_2 = None + result_a = actual() + exception_a = None except Exception as e: # noqa: BLE001 - result_2 = None - exception_2 = e + result_a = None + exception_a = e # Special case: compiled versions don't match the error type exactly. - if ((exception_1 is None) != (exception_2 is None)) or not ignore_exc_types: - self.assertIs(type(exception_1), type(exception_2), f"\n{exception_1=}\n{exception_2=}") + if ((exception_e is None) != (exception_a is None)) or not ignore_exc_types: + if exception_a is not None and exception_e is None: + raise exception_a + self.assertIs(type(exception_e), type(exception_a), f"\n{exception_e=}\n{exception_a=}") - if exception_1 is None: - flattened_1, spec_1 = tree_flatten(result_1) - flattened_2, spec_2 = tree_flatten(result_2) + if exception_e is None: + flattened_e, spec_e = tree_flatten(result_e) + flattened_a, spec_a = tree_flatten(result_a) self.assertEqual( - spec_1, spec_2, "Both functions must return a result with the same tree structure." + spec_e, spec_a, "Both functions must return a result with the same tree structure." 
) - for f1, f2 in zip(flattened_1, flattened_2, strict=False): - f1 = _as_complex_tensor(f1) - f2 = _as_complex_tensor(f1) + for value_e, value_a in zip(flattened_e, flattened_a, strict=True): + value_e = _as_interleaved(value_e) + value_a = _as_interleaved(value_a) - self.assertEqual(f1, f2, *args, **kwargs) + self.assertEqual(value_e, value_a, *args, **kwargs) diff --git a/uv.lock b/uv.lock index e536c7a..e9ff991 100644 --- a/uv.lock +++ b/uv.lock @@ -34,8 +34,8 @@ dependencies = [ { name = "expecttest" }, { name = "numpy" }, { name = "pytest" }, - { name = "torch", version = "2.9.0.dev20250828", source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform == 'darwin'" }, - { name = "torch", version = "2.9.0.dev20250828+cpu", source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform != 'darwin'" }, + { name = "torch", version = "2.9.0.dev20250907", source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.9.0.dev20250907+cpu", source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform != 'darwin'" }, ] [package.optional-dependencies] @@ -99,11 +99,11 @@ wheels = [ [[package]] name = "executing" -version = "2.2.0" +version = "2.2.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693, upload-time = "2025-01-22T15:41:29.403Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/28/c14e053b6762b1044f34a13aab6859bbf40456d37d23aa286ac24cfd9a5d/executing-2.2.1.tar.gz", hash = "sha256:3632cc370565f6648cc328b32435bd120a1e4ebb20c77e3fdde9a13cd1e533c4", size = 1129488, upload-time = "2025-09-01T09:48:10.866Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702, upload-time = "2025-01-22T15:41:25.929Z" }, + { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, ] [[package]] @@ -125,10 +125,11 @@ wheels = [ [[package]] name = "fsspec" -version = "2025.7.0" -source = { registry = "https://download.pytorch.org/whl/nightly/cpu" } +version = "2025.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/de/e0/bab50af11c2d75c9c4a2a26a5254573c0bd97cea152254401510950486fa/fsspec-2025.9.0.tar.gz", hash = "sha256:19fd429483d25d28b65ec68f9f4adc16c17ea2c7c7bf54ec61360d478fb19c19", size = 304847, upload-time = "2025-09-02T19:10:49.215Z" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/fsspec-2025.7.0-py3-none-any.whl" }, + { url = "https://files.pythonhosted.org/packages/47/71/70db47e4f6ce3e5c37a607355f80da8860a33226be640226ac52cb05ef2e/fsspec-2025.9.0-py3-none-any.whl", hash = "sha256:530dc2a2af60a414a832059574df4a6e10cce927f6f4a78209390fe38955cfb7", size = 199289, upload-time = "2025-09-02T19:10:47.708Z" }, ] [[package]] @@ -147,7 +148,7 @@ source = { 
registry = "https://pypi.org/simple" } dependencies = [ { name = "decorator" }, { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "ipython", version = "9.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "tomli", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/3d/1b/7e07e7b752017f7693a0f4d41c13e5ca29ce8cbcfdcc1fd6c4ad8c0a27a0/ipdb-0.13.13.tar.gz", hash = "sha256:e3ac6018ef05126d442af680aad863006ec19d02290561ac88b8b1c0b0cfc726", size = 17042, upload-time = "2023-03-09T15:40:57.487Z" } @@ -183,7 +184,7 @@ wheels = [ [[package]] name = "ipython" -version = "9.4.0" +version = "9.5.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform != 'darwin'", @@ -204,9 +205,9 @@ dependencies = [ { name = "traitlets", marker = "python_full_version >= '3.11'" }, { name = "typing-extensions", marker = "python_full_version == '3.11.*'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/54/80/406f9e3bde1c1fd9bf5a0be9d090f8ae623e401b7670d8f6fdf2ab679891/ipython-9.4.0.tar.gz", hash = "sha256:c033c6d4e7914c3d9768aabe76bbe87ba1dc66a92a05db6bfa1125d81f2ee270", size = 4385338, upload-time = "2025-07-01T11:11:30.606Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6e/71/a86262bf5a68bf211bcc71fe302af7e05f18a2852fdc610a854d20d085e6/ipython-9.5.0.tar.gz", hash = "sha256:129c44b941fe6d9b82d36fc7a7c18127ddb1d6f02f78f867f402e2e3adde3113", size = 4389137, upload-time = "2025-08-29T12:15:21.519Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/63/f8/0031ee2b906a15a33d6bfc12dd09c3dfa966b3cb5b284ecfb7549e6ac3c4/ipython-9.4.0-py3-none-any.whl", hash = "sha256:25850f025a446d9b359e8d296ba175a36aedd32e83ca9b5060430fe16801f066", size = 611021, upload-time = "2025-07-01T11:11:27.85Z" }, + { url = "https://files.pythonhosted.org/packages/08/2a/5628a99d04acb2d2f2e749cdf4ea571d2575e898df0528a090948018b726/ipython-9.5.0-py3-none-any.whl", hash = "sha256:88369ffa1d5817d609120daa523a6da06d02518e582347c29f8451732a9c5e72", size = 612426, upload-time = "2025-08-29T12:15:18.866Z" }, ] [[package]] @@ -469,7 +470,7 @@ wheels = [ [[package]] name = "pytest" -version = "8.4.1" +version = "8.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, @@ -480,9 +481,9 @@ dependencies = [ { name = "pygments" }, { name = "tomli", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = 
"sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, + { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, ] [[package]] @@ -573,7 +574,7 @@ wheels = [ [[package]] name = "torch" -version = "2.9.0.dev20250828" +version = "2.9.0.dev20250907" source = { registry = "https://download.pytorch.org/whl/nightly/cpu" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform == 'darwin'", @@ -591,18 +592,18 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform == 'darwin'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp310-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp311-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp312-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp313-cp313t-macosx_14_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp313-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp314-cp314-macosx_14_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp314-cp314t-macosx_14_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp310-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp311-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp312-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp313-cp313t-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp313-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp314-cp314-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp314-cp314t-macosx_11_0_arm64.whl" }, ] [[package]] name = "torch" -version = "2.9.0.dev20250828+cpu" +version = "2.9.0.dev20250907+cpu" source = { registry = "https://download.pytorch.org/whl/nightly/cpu" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform != 'darwin'", @@ -620,34 +621,37 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform != 'darwin'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp310-cp310-linux_s390x.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp310-cp310-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp311-cp311-linux_s390x.whl" }, - { url = 
"https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp311-cp311-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp312-cp312-linux_s390x.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp312-cp312-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp312-cp312-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313-linux_s390x.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313t-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314t-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp310-cp310-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp310-cp310-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp311-cp311-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl" 
}, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp311-cp311-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp311-cp311-win_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp312-cp312-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp312-cp312-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp312-cp312-win_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313-win_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313t-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313t-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314t-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314t-win_amd64.whl" }, ] [[package]]
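

Usage sketch (reviewer note, not part of the diff): a minimal example of how the new ComplexDispatchMode added in src/complex_tensor/ops/_common.py might be exercised end to end. It assumes the mode is importable from complex_tensor.ops._common as introduced above and that ComplexTensor.from_interleaved/as_interleaved behave as referenced in the diff; op coverage follows the registration tables in aten.py.

    import torch

    from complex_tensor.ops._common import ComplexDispatchMode

    z = torch.randn(4, dtype=torch.complex64)
    w = torch.randn(4, dtype=torch.complex64)

    with ComplexDispatchMode():
        # Complex inputs are rewrapped as ComplexTensor, dispatched through the
        # registered implementations (here aten.mul), and mapped back to
        # interleaved complex tensors on the way out.
        out = torch.mul(z, w)

    torch.testing.assert_close(out, z * w)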