From 9426c46e011bc385177226bb4b53f84d9517d2f7 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 2 Sep 2025 17:29:03 +0200 Subject: [PATCH 01/23] Add gradient checks: first attempt. --- src/complex_tensor/ops/_common.py | 32 ++++++++++++- src/complex_tensor/ops/aten.py | 70 ++++++++++++++++++++++++++++- src/complex_tensor/test/test_ops.py | 27 ++++++++++- src/complex_tensor/test/utils.py | 14 +----- 4 files changed, 126 insertions(+), 17 deletions(-) diff --git a/src/complex_tensor/ops/_common.py b/src/complex_tensor/ops/_common.py index 01d2b43..481504e 100644 --- a/src/complex_tensor/ops/_common.py +++ b/src/complex_tensor/ops/_common.py @@ -4,7 +4,8 @@ import torch from torch._ops import OpOverloadPacket from torch._refs import is_complex -from torch.utils._pytree import tree_flatten, tree_unflatten +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten from ..complex_tensor import ComplexTensor @@ -131,7 +132,7 @@ def ordered_impl(*args, **kwargs): def register_binary_nonlinear(op: OpType) -> Callable: def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: - a_r, a_i = split_complex_tensor(lhs) + a_r, a_i = split_complex_arg(lhs) b_r, b_i = split_complex_arg(rhs) out_dt, (a_r, a_i, b_r, b_i) = promote_real_cpu_tensors(a_r, a_i, b_r, b_i) real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs) @@ -161,3 +162,30 @@ def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: impl.__qualname__ = func_name return register_complex(op, impl) + + +def _as_complex_tensor(arg: torch.Tensor | Any) -> torch.Tensor | ComplexTensor | Any: + if ( + not isinstance(arg, ComplexTensor) + and isinstance(arg, torch.Tensor) + and arg.dtype in COMPLEX_TO_REAL + ): + return ComplexTensor.from_interleaved(arg) + return arg + + +def _as_interleaved(arg: ComplexTensor | Any) -> torch.Tensor | Any: + if 
isinstance(arg, ComplexTensor): + return arg.as_interleaved() + return arg + + +class ComplexDispatchMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + args = tree_map(_as_complex_tensor, args) + kwargs = tree_map(_as_complex_tensor, kwargs) + + return tree_map(_as_interleaved, func(*args, **kwargs)) diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py index 268db1c..c7fa550 100644 --- a/src/complex_tensor/ops/aten.py +++ b/src/complex_tensor/ops/aten.py @@ -33,7 +33,7 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTens alpha = kwargs.pop("alpha", None) if alpha is not None: return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs) - a_r, a_i = split_complex_tensor(lhs) + a_r, a_i = split_complex_arg(lhs) b_r, b_i = split_complex_arg(rhs) out_dt, (a_r, a_i, b_r, b_i) = promote_real_cpu_tensors(a_r, a_i, b_r, b_i) u = op(a_r, b_r, *args, **kwargs) @@ -78,6 +78,9 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b index_select_impl = register_simple(aten.index_select) split_with_sizes_impl = register_simple(aten.split_with_sizes) cumsum_impl = register_simple(aten.cumsum) +detach_impl = register_simple(aten.detach) +select_impl = register_simple(aten.select) +squeeze_impl = register_simple(aten.squeeze) # TODO (hameerabbasi): Not being tested copy_impl = register_force_test(aten.copy, register_simple(aten.copy)) @@ -502,3 +505,68 @@ def nonzero_impl(self: ComplexTensor, other: ComplexTensor, *args, **kwargs) -> @register_complex(aten.logical_not) def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: return torch.logical_not(elemwise_nonzero(self), *args, **kwargs) + + +@register_complex(aten.view_as_real) +def view_as_real_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.stack([re, im], dim=-1) + + 
+@register_complex(aten.linalg_vector_norm) +def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs) + + +@register_force_test(aten.copy_) +def copy__impl(self: ComplexTensor, src, *args, **kwargs): + self_re, self_im = split_complex_tensor(self) + src_re, src_im = split_complex_arg(src) + + ret_re = self_re.copy_(src_re, *args, **kwargs) + ret_im = self_im.copy_(src_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.new_zeros) +def new_zeros_impl( + self: ComplexTensor, size, *, dtype=None, **kwargs +) -> ComplexTensor | torch.Tensor: + if dtype is not None and torch.dtype(dtype) not in COMPLEX_TO_REAL: + return self.re.new_zeros(self, size, dtype=dtype, **kwargs) + + if dtype is not None: + dtype = COMPLEX_TO_REAL[torch.dtype(dtype)] + + re = self.re.new_zeros(size, dtype=dtype, **kwargs) + im = self.im.new_zeros(size, dtype=dtype, **kwargs) + + return ComplexTensor(re, im) + + +@register_complex(aten._local_scalar_dense) +def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex: + x, y = split_complex_tensor(self) + u = aten._local_scalar_dense(x, *args, **kwargs) + v = aten._local_scalar_dense(y, *args, **kwargs) + return complex(u, v) + + +@register_complex(aten.allclose) +def allclose_impl( + input: torch.Tensor, + other: torch.Tensor, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> torch.Tensor: + return torch.all(torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)) + + +@register_complex(aten.stack) +def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor: + re_im_tuples = [split_complex_arg(self_i) for self_i in self] + u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs) + v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs) + return ComplexTensor(u, v) diff --git a/src/complex_tensor/test/test_ops.py 
b/src/complex_tensor/test/test_ops.py index c160ee0..01b2482 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -4,11 +4,22 @@ from torch._ops import OpOverload from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_utils import parametrize, run_tests +from torch.testing._internal.common_utils import ( + TestGradients, + parametrize, + run_tests, + unMarkDynamoStrictTest, +) from torch.testing._internal.opinfo.core import OpInfo from complex_tensor.ops import COMPLEX_OPS_TABLE, FORCE_TEST_LIST -from complex_tensor.test.utils import COMPLEX_DTYPES, TestCase, TestDescriptor, _as_complex_tensor +from complex_tensor.test.utils import ( + COMPLEX_DTYPES, + ComplexDispatchMode, + TestCase, + TestDescriptor, + _as_complex_tensor, +) torch._dynamo.config.recompile_limit = float("inf") torch._dynamo.config.accumulated_recompile_limit = float("inf") @@ -99,7 +110,19 @@ def actual(subclass_sample=subclass_sample): self.assertSameResult(expected, actual, ignore_exc_types=compile) +@unMarkDynamoStrictTest +class TestComplexBwdGradients(TestGradients): + @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES)) + def test_fn_grad(self, device, dtype, op: OpInfo) -> None: + if dtype not in op.supported_backward_dtypes(torch.device(device).type): + self.skipTest("Skipped! 
Dtype is not in supported backward dtypes!") + else: + with ComplexDispatchMode(): + self._grad_test_helper(device, dtype, op, op.get_op()) + + instantiate_device_type_tests(TestComplexTensor, globals()) +instantiate_device_type_tests(TestComplexBwdGradients, globals()) if __name__ == "__main__": run_tests() diff --git a/src/complex_tensor/test/utils.py b/src/complex_tensor/test/utils.py index d06205e..8c1f550 100644 --- a/src/complex_tensor/test/utils.py +++ b/src/complex_tensor/test/utils.py @@ -8,9 +8,9 @@ from torch.testing._internal.common_utils import TestCase as PytorchTestCase from torch.utils._pytree import tree_flatten -from complex_tensor.complex_tensor import ComplexTensor +from complex_tensor.ops._common import COMPLEX_TO_REAL, _as_complex_tensor -COMPLEX_DTYPES = {torch.complex128, torch.complex64, torch.complex32} +COMPLEX_DTYPES = set(COMPLEX_TO_REAL) @dataclass(frozen=True) @@ -35,16 +35,6 @@ def matches(self, other: TestDescriptor) -> bool: return True -def _as_complex_tensor(arg): - if ( - not isinstance(arg, ComplexTensor) - and isinstance(arg, torch.Tensor) - and arg.dtype in COMPLEX_DTYPES - ): - return ComplexTensor.from_interleaved(arg) - return arg - - class TestCase(PytorchTestCase): def assertSameResult( self, From 4a0b409e28a5b864127fe01238c4f55a45dc9069 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 3 Sep 2025 11:57:14 +0200 Subject: [PATCH 02/23] Add a bunch of ops trying to get gradients to work. 
--- src/complex_tensor/ops/aten.py | 28 ++++++++++++++++++++++------ src/complex_tensor/ops/prims.py | 8 ++++++++ src/complex_tensor/test/test_ops.py | 3 ++- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py index c7fa550..efb4cb8 100644 --- a/src/complex_tensor/ops/aten.py +++ b/src/complex_tensor/ops/aten.py @@ -81,6 +81,8 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b detach_impl = register_simple(aten.detach) select_impl = register_simple(aten.select) squeeze_impl = register_simple(aten.squeeze) +zero__impl = register_simple(aten.zero_) +transpose_impl = register_simple(aten.transpose) # TODO (hameerabbasi): Not being tested copy_impl = register_force_test(aten.copy, register_simple(aten.copy)) @@ -533,14 +535,14 @@ def copy__impl(self: ComplexTensor, src, *args, **kwargs): def new_zeros_impl( self: ComplexTensor, size, *, dtype=None, **kwargs ) -> ComplexTensor | torch.Tensor: - if dtype is not None and torch.dtype(dtype) not in COMPLEX_TO_REAL: - return self.re.new_zeros(self, size, dtype=dtype, **kwargs) + self_re, self_im = split_complex_tensor(self) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return self_re.new_zeros(size, dtype=dtype, **kwargs) if dtype is not None: - dtype = COMPLEX_TO_REAL[torch.dtype(dtype)] - - re = self.re.new_zeros(size, dtype=dtype, **kwargs) - im = self.im.new_zeros(size, dtype=dtype, **kwargs) + dtype = COMPLEX_TO_REAL[dtype] + re = self_re.new_zeros(size, dtype=dtype, **kwargs) + im = self_im.new_zeros(size, dtype=dtype, **kwargs) return ComplexTensor(re, im) @@ -570,3 +572,17 @@ def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor: u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs) v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs) return ComplexTensor(u, v) + + +@register_complex(aten.randn_like) +def randn_like_impl(self: ComplexTensor, *, dtype=None, 
**kwargs) -> ComplexTensor | torch.Tensor: + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return torch.randn_like(self.re, dtype=dtype, **kwargs) + + if dtype is not None: + dtype = COMPLEX_TO_REAL[dtype] + + self_re, self_im = split_complex_tensor(self) + ret_re = torch.randn_like(self_re, dtype=dtype, **kwargs) / 2 + ret_im = torch.randn_like(self_im, dtype=dtype, **kwargs) / 2 + return ComplexTensor(ret_re, ret_im) diff --git a/src/complex_tensor/ops/prims.py b/src/complex_tensor/ops/prims.py index bea46c2..555d012 100644 --- a/src/complex_tensor/ops/prims.py +++ b/src/complex_tensor/ops/prims.py @@ -3,6 +3,7 @@ from ..complex_tensor import ComplexTensor from ._common import ( complex_to_real_dtype, + register_complex, register_force_test, split_complex_tensor, ) @@ -18,3 +19,10 @@ def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTe v_out = prims.convert_element_type(v, dtype) return ComplexTensor(u_out, v_out) + + +@register_complex(prims.conj_physical) +@register_complex(prims.conj) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, -im) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index 01b2482..a39b5f9 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -13,9 +13,9 @@ from torch.testing._internal.opinfo.core import OpInfo from complex_tensor.ops import COMPLEX_OPS_TABLE, FORCE_TEST_LIST +from complex_tensor.ops._common import ComplexDispatchMode from complex_tensor.test.utils import ( COMPLEX_DTYPES, - ComplexDispatchMode, TestCase, TestDescriptor, _as_complex_tensor, @@ -118,6 +118,7 @@ def test_fn_grad(self, device, dtype, op: OpInfo) -> None: self.skipTest("Skipped! 
Dtype is not in supported backward dtypes!") else: with ComplexDispatchMode(): + op.gradcheck_fast_mode = False self._grad_test_helper(device, dtype, op, op.get_op()) From dbf695ba9897c624c685974ab650d75533e37f7f Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 3 Sep 2025 13:02:02 +0200 Subject: [PATCH 03/23] Clarify warning message slightly. --- src/complex_tensor/test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index a39b5f9..a111db0 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -58,7 +58,7 @@ def _get_opname_from_aten_op(aten_op): sorted([op._qualified_op_name.replace("::", ".") for op in non_tested_ops]) ) warnings.warn( - "Not all ops are tested. List of missing ops:" + "Not all implemented ops are tested. List of ops missing tests:" f"\n{textwrap.indent(list_missing_ops, ' ')}", UserWarning, stacklevel=2, From 9c3c81cabb013b79fcd2167d4da786c7e20c2947 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 3 Sep 2025 15:12:22 +0200 Subject: [PATCH 04/23] Skip `allclose` for `compile=True` due to data-dependency. 
--- src/complex_tensor/test/test_ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index a111db0..8868197 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -70,6 +70,9 @@ def _get_opname_from_aten_op(aten_op): TestDescriptor(op_name="imag"): "`aten.imag` does not hit `__torch_dispatch__`", TestDescriptor(op_name="repeat", dtype=torch.complex64, compile=True): "Heisenbug", TestDescriptor(op_name="repeat", dtype=torch.complex128, compile=True): "Heisenbug", + TestDescriptor( + op_name="allclose", compile=True + ): "`aten.allclose` requires data-dependent control-flow", } From c309f2fd5f62e3dee03f25306851813314ff61b5 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 3 Sep 2025 17:09:03 +0200 Subject: [PATCH 05/23] Small fix for `allclose`. --- src/complex_tensor/ops/aten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py index efb4cb8..a939425 100644 --- a/src/complex_tensor/ops/aten.py +++ b/src/complex_tensor/ops/aten.py @@ -562,8 +562,8 @@ def allclose_impl( rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False, -) -> torch.Tensor: - return torch.all(torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)) +) -> complex: + return torch.all(torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)).item() @register_complex(aten.stack) From 5b4d80c245f4fab808fe6e5b5b4ffd404bbdd228 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 4 Sep 2025 08:36:37 +0200 Subject: [PATCH 06/23] Add a number of ops for backward computations. 
--- src/complex_tensor/ops/aten.py | 99 +++++++++++++++++++++++++++++ src/complex_tensor/test/test_ops.py | 3 + 2 files changed, 102 insertions(+) diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py index a939425..5cff7f1 100644 --- a/src/complex_tensor/ops/aten.py +++ b/src/complex_tensor/ops/aten.py @@ -83,6 +83,8 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b squeeze_impl = register_simple(aten.squeeze) zero__impl = register_simple(aten.zero_) transpose_impl = register_simple(aten.transpose) +t_impl = register_simple(aten.t) +zeros_like_impl = register_simple(aten.zeros_like) # TODO (hameerabbasi): Not being tested copy_impl = register_force_test(aten.copy, register_simple(aten.copy)) @@ -92,6 +94,14 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b col2im_impl = register_force_test(aten.col2im, register_simple(aten.col2im)) # TODO (hameerabbasi): Not being tested alias_impl = register_force_test(aten.alias, register_simple(aten.alias)) +# TODO (hameerabbasi): Not being tested +lift_fresh_impl = register_force_test(aten.lift_fresh, register_simple(aten.lift_fresh)) +# TODO (hameerabbasi): Not being tested +_unsafe_view_impl = register_force_test(aten._unsafe_view, register_simple(aten._unsafe_view)) +# TODO (hameerabbasi): Not being tested +index_put__impl = register_force_test(aten.index_put_, register_simple(aten.index_put_)) +# TODO (hameerabbasi): Not being tested +index_impl = register_force_test(aten.index, register_simple(aten.index)) # some binary ops which we can stamp out mul_impl = register_binary_nonlinear(aten.mul) @@ -146,6 +156,23 @@ def prod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: return ComplexTensor(u, v) +@register_complex(aten.pow) +def pow_impl(self: ComplexTensor, exponent: ComplexTensor) -> ComplexTensor: + return torch.exp(exponent * torch.log(self)) + + +@register_complex(aten.cumprod) +def cumprod_impl(self: 
ComplexTensor, *args, **kwargs) -> ComplexTensor: + dtype = kwargs.pop("dtype", self.dtype) + kwargs["dtype"] = complex_to_real_dtype(dtype) + + prod_r = torch.cumprod(torch.abs(self), *args, **kwargs) + sum_phi = torch.cumsum(torch.angle(self), *args, **kwargs) + u = prod_r * torch.cos(sum_phi) + v = prod_r * torch.sin(sum_phi) + return ComplexTensor(u, v) + + # unary funcs, # most of these are simple or require some kind of identity @register_complex(aten.abs) @@ -237,6 +264,13 @@ def expm1_impl(self: ComplexTensor) -> ComplexTensor: return ComplexTensor(u.to(out_dt), v.to(out_dt)) +@register_complex(aten.log) +def log_impl(self: ComplexTensor) -> ComplexTensor: + re = torch.log(torch.abs(self)) + im = torch.angle(self) + return ComplexTensor(re, im) + + @register_complex(aten.any) def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: x, y = split_complex_tensor(self) @@ -586,3 +620,68 @@ def randn_like_impl(self: ComplexTensor, *, dtype=None, **kwargs) -> ComplexTens ret_re = torch.randn_like(self_re, dtype=dtype, **kwargs) / 2 ret_im = torch.randn_like(self_im, dtype=dtype, **kwargs) / 2 return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj_physical) +@register_complex(aten._conj) +def _conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, -im) + + +@register_complex(aten.index_add) +def index_add_impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add(dim, index, source_re) + ret_im = self_im.index_add(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.index_add_) +def 
index_add__impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add_(dim, index, source_re) + ret_im = self_im.index_add_(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.masked_fill) +def masked_fill_impl(self: ComplexTensor, mask: torch.Tensor, value: complex) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill(mask, value_re) + ret_im = self_im.masked_fill(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.masked_fill_) +def masked_fill__impl(self: ComplexTensor, mask: torch.Tensor, value: complex) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill_(mask, value_re) + ret_im = self_im.masked_fill_(mask, value_im) + + return ComplexTensor(ret_re, ret_im) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index 8868197..138817e 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -73,6 +73,9 @@ def _get_opname_from_aten_op(aten_op): TestDescriptor( op_name="allclose", compile=True ): "`aten.allclose` requires data-dependent control-flow", + TestDescriptor( + op_name="randn_like", compile=True + ): "`aten.randn_like` doesn't support `torch.compile`", } From 5b8545faf25962a3e5fa90ecbae4955567dbf825 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 4 Sep 2025 09:13:44 +0200 Subject: [PATCH 07/23] Adjust gradient tests. 
--- src/complex_tensor/test/test_ops.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index 138817e..cc7c9ba 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -122,10 +122,11 @@ class TestComplexBwdGradients(TestGradients): def test_fn_grad(self, device, dtype, op: OpInfo) -> None: if dtype not in op.supported_backward_dtypes(torch.device(device).type): self.skipTest("Skipped! Dtype is not in supported backward dtypes!") - else: - with ComplexDispatchMode(): - op.gradcheck_fast_mode = False - self._grad_test_helper(device, dtype, op, op.get_op()) + + with ComplexDispatchMode(): + op.gradcheck_fast_mode = False + op.check_batched_grad = False + self._grad_test_helper(device, dtype, op, op.get_op()) instantiate_device_type_tests(TestComplexTensor, globals()) From 5bf3c83c39cf4f862ce2556942adb618c97ddbab Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 4 Sep 2025 09:18:01 +0200 Subject: [PATCH 08/23] Update lockfile. 
--- uv.lock | 99 +++++++++++++++++++++++++++++---------------------------- 1 file changed, 50 insertions(+), 49 deletions(-) diff --git a/uv.lock b/uv.lock index e536c7a..5c4caf0 100644 --- a/uv.lock +++ b/uv.lock @@ -34,8 +34,8 @@ dependencies = [ { name = "expecttest" }, { name = "numpy" }, { name = "pytest" }, - { name = "torch", version = "2.9.0.dev20250828", source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform == 'darwin'" }, - { name = "torch", version = "2.9.0.dev20250828+cpu", source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform != 'darwin'" }, + { name = "torch", version = "2.9.0.dev20250903", source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.9.0.dev20250903+cpu", source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform != 'darwin'" }, ] [package.optional-dependencies] @@ -99,11 +99,11 @@ wheels = [ [[package]] name = "executing" -version = "2.2.0" +version = "2.2.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693, upload-time = "2025-01-22T15:41:29.403Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/28/c14e053b6762b1044f34a13aab6859bbf40456d37d23aa286ac24cfd9a5d/executing-2.2.1.tar.gz", hash = "sha256:3632cc370565f6648cc328b32435bd120a1e4ebb20c77e3fdde9a13cd1e533c4", size = 1129488, upload-time = "2025-09-01T09:48:10.866Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702, 
upload-time = "2025-01-22T15:41:25.929Z" }, + { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, ] [[package]] @@ -125,10 +125,11 @@ wheels = [ [[package]] name = "fsspec" -version = "2025.7.0" -source = { registry = "https://download.pytorch.org/whl/nightly/cpu" } +version = "2025.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/de/e0/bab50af11c2d75c9c4a2a26a5254573c0bd97cea152254401510950486fa/fsspec-2025.9.0.tar.gz", hash = "sha256:19fd429483d25d28b65ec68f9f4adc16c17ea2c7c7bf54ec61360d478fb19c19", size = 304847, upload-time = "2025-09-02T19:10:49.215Z" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/fsspec-2025.7.0-py3-none-any.whl" }, + { url = "https://files.pythonhosted.org/packages/47/71/70db47e4f6ce3e5c37a607355f80da8860a33226be640226ac52cb05ef2e/fsspec-2025.9.0-py3-none-any.whl", hash = "sha256:530dc2a2af60a414a832059574df4a6e10cce927f6f4a78209390fe38955cfb7", size = 199289, upload-time = "2025-09-02T19:10:47.708Z" }, ] [[package]] @@ -147,7 +148,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "decorator" }, { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "ipython", version = "9.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "tomli", marker = "python_full_version < '3.11'" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/3d/1b/7e07e7b752017f7693a0f4d41c13e5ca29ce8cbcfdcc1fd6c4ad8c0a27a0/ipdb-0.13.13.tar.gz", hash = "sha256:e3ac6018ef05126d442af680aad863006ec19d02290561ac88b8b1c0b0cfc726", size = 17042, upload-time = "2023-03-09T15:40:57.487Z" } @@ -183,7 +184,7 @@ wheels = [ [[package]] name = "ipython" -version = "9.4.0" +version = "9.5.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform != 'darwin'", @@ -204,9 +205,9 @@ dependencies = [ { name = "traitlets", marker = "python_full_version >= '3.11'" }, { name = "typing-extensions", marker = "python_full_version == '3.11.*'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/54/80/406f9e3bde1c1fd9bf5a0be9d090f8ae623e401b7670d8f6fdf2ab679891/ipython-9.4.0.tar.gz", hash = "sha256:c033c6d4e7914c3d9768aabe76bbe87ba1dc66a92a05db6bfa1125d81f2ee270", size = 4385338, upload-time = "2025-07-01T11:11:30.606Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6e/71/a86262bf5a68bf211bcc71fe302af7e05f18a2852fdc610a854d20d085e6/ipython-9.5.0.tar.gz", hash = "sha256:129c44b941fe6d9b82d36fc7a7c18127ddb1d6f02f78f867f402e2e3adde3113", size = 4389137, upload-time = "2025-08-29T12:15:21.519Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/63/f8/0031ee2b906a15a33d6bfc12dd09c3dfa966b3cb5b284ecfb7549e6ac3c4/ipython-9.4.0-py3-none-any.whl", hash = "sha256:25850f025a446d9b359e8d296ba175a36aedd32e83ca9b5060430fe16801f066", size = 611021, upload-time = "2025-07-01T11:11:27.85Z" }, + { url = "https://files.pythonhosted.org/packages/08/2a/5628a99d04acb2d2f2e749cdf4ea571d2575e898df0528a090948018b726/ipython-9.5.0-py3-none-any.whl", hash = "sha256:88369ffa1d5817d609120daa523a6da06d02518e582347c29f8451732a9c5e72", size = 612426, upload-time = "2025-08-29T12:15:18.866Z" }, ] [[package]] @@ -573,7 +574,7 @@ wheels = [ [[package]] name = "torch" -version = "2.9.0.dev20250828" +version = "2.9.0.dev20250903" source = { 
registry = "https://download.pytorch.org/whl/nightly/cpu" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform == 'darwin'", @@ -591,18 +592,18 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform == 'darwin'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp310-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp311-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp312-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp313-cp313t-macosx_14_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp313-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp314-cp314-macosx_14_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828-cp314-cp314t-macosx_14_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp310-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp311-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp312-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp313-cp313t-macosx_14_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp313-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp314-cp314-macosx_14_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp314-cp314t-macosx_14_0_arm64.whl" }, ] [[package]] name = "torch" -version = "2.9.0.dev20250828+cpu" +version = 
"2.9.0.dev20250903+cpu" source = { registry = "https://download.pytorch.org/whl/nightly/cpu" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform != 'darwin'", @@ -620,34 +621,34 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform != 'darwin'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp310-cp310-linux_s390x.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp310-cp310-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp311-cp311-linux_s390x.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp311-cp311-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp312-cp312-linux_s390x.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp312-cp312-win_amd64.whl" }, - { url = 
"https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp312-cp312-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313-linux_s390x.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp313-cp313t-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250828%2Bcpu-cp314-cp314t-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp310-cp310-linux_s390x.whl" }, + { url = 
"https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp310-cp310-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp311-cp311-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp311-cp311-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp311-cp311-win_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp312-cp312-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp312-cp312-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp312-cp312-win_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl" }, + { url = 
"https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313-win_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313t-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314t-win_amd64.whl" }, ] [[package]] From bbbee1892742d019e5b9bcf81b59d43bc324c104 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 4 Sep 2025 10:05:33 +0200 Subject: [PATCH 09/23] Add a few ops. 
--- src/complex_tensor/ops/__init__.py | 4 +++- src/complex_tensor/ops/_c10d_functional.py | 18 ++++++++++++++++++ src/complex_tensor/ops/prims.py | 1 + src/complex_tensor/test/test_ops.py | 1 - 4 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 src/complex_tensor/ops/_c10d_functional.py diff --git a/src/complex_tensor/ops/__init__.py b/src/complex_tensor/ops/__init__.py index 7bb1163..3fea279 100644 --- a/src/complex_tensor/ops/__init__.py +++ b/src/complex_tensor/ops/__init__.py @@ -1,9 +1,11 @@ __all__ = [ "aten", + "prims", + "_c10d_functional", "COMPLEX_OPS_TABLE", "FORCE_TEST_LIST", "lookup_complex", ] -from . import aten +from . import _c10d_functional, aten, prims from ._common import COMPLEX_OPS_TABLE, FORCE_TEST_LIST, lookup_complex diff --git a/src/complex_tensor/ops/_c10d_functional.py b/src/complex_tensor/ops/_c10d_functional.py new file mode 100644 index 0000000..12ca2cf --- /dev/null +++ b/src/complex_tensor/ops/_c10d_functional.py @@ -0,0 +1,18 @@ +import torch + +from ._common import ( + register_force_test, + register_simple, +) + +_c10d_functional = torch.ops._c10d_functional + +# TODO (hameerabbasi): Not being tested +broadcast_impl = register_force_test( + _c10d_functional.broadcast, register_simple(_c10d_functional.broadcast) +) + +# TODO (hameerabbasi): Not being tested +broadcast__impl = register_force_test( + _c10d_functional.broadcast_, register_simple(_c10d_functional.broadcast_) +) diff --git a/src/complex_tensor/ops/prims.py b/src/complex_tensor/ops/prims.py index 555d012..928149e 100644 --- a/src/complex_tensor/ops/prims.py +++ b/src/complex_tensor/ops/prims.py @@ -11,6 +11,7 @@ prims = torch.ops.prims +# TODO (hameerabbasi): Not being tested @register_force_test(prims.convert_element_type) def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor: dtype = complex_to_real_dtype(dtype) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index 
cc7c9ba..a5a3ed1 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -125,7 +125,6 @@ def test_fn_grad(self, device, dtype, op: OpInfo) -> None: with ComplexDispatchMode(): op.gradcheck_fast_mode = False - op.check_batched_grad = False self._grad_test_helper(device, dtype, op, op.get_op()) From d59e11ac673710417bc74479dc79ea2a7e6efe3a Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 4 Sep 2025 10:12:21 +0200 Subject: [PATCH 10/23] Add a guard to avoid registering multiple impls. --- src/complex_tensor/ops/_common.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/complex_tensor/ops/_common.py b/src/complex_tensor/ops/_common.py index 481504e..8450064 100644 --- a/src/complex_tensor/ops/_common.py +++ b/src/complex_tensor/ops/_common.py @@ -64,6 +64,11 @@ def register_complex( """Decorator to register an implementation for some ops in some dispatch tables""" def inner(func): + if COMPLEX_OPS_TABLE.get(op, func) is not func: + raise RuntimeError( + "Attempted to register multiple functions for " + f"{op._qualified_op_name.replace('::', '.')}" + ) COMPLEX_OPS_TABLE[op] = func return func From c09eb171e86c1d7ab3d419fc632d8b664c1b3326 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 9 Sep 2025 09:20:28 +0200 Subject: [PATCH 11/23] Add a lot more trigonometric functions. 
--- src/complex_tensor/ops/_common.py | 7 ++ src/complex_tensor/ops/aten.py | 123 ++++++++++++++++++++++++---- src/complex_tensor/test/test_ops.py | 32 ++++++-- src/complex_tensor/test/utils.py | 1 + uv.lock | 87 ++++++++++---------- 5 files changed, 185 insertions(+), 65 deletions(-) diff --git a/src/complex_tensor/ops/_common.py b/src/complex_tensor/ops/_common.py index 8450064..959cfb0 100644 --- a/src/complex_tensor/ops/_common.py +++ b/src/complex_tensor/ops/_common.py @@ -186,10 +186,17 @@ def _as_interleaved(arg: ComplexTensor | Any) -> torch.Tensor | Any: class ComplexDispatchMode(TorchDispatchMode): + def __init__(self, _dispatch_key=None, *, _compile=False): + super().__init__(_dispatch_key) + self._compile = _compile + def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} + if compile: + func = torch.compile(func) + args = tree_map(_as_complex_tensor, args) kwargs = tree_map(_as_complex_tensor, kwargs) diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py index 5cff7f1..0e6169f 100644 --- a/src/complex_tensor/ops/aten.py +++ b/src/complex_tensor/ops/aten.py @@ -85,6 +85,9 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b transpose_impl = register_simple(aten.transpose) t_impl = register_simple(aten.t) zeros_like_impl = register_simple(aten.zeros_like) +masked_scatter_backward_impl = register_simple(aten.masked_scatter_backward) +select_backward_impl = register_simple(aten.select_backward) +slice_backward_impl = register_simple(aten.slice_backward) # TODO (hameerabbasi): Not being tested copy_impl = register_force_test(aten.copy, register_simple(aten.copy)) @@ -202,7 +205,7 @@ def acos_impl(self: ComplexTensor) -> ComplexTensor: y2 = y**2 a = (x**2) + y2 b = torch.sqrt((a - 1) ** 2 + 4 * y2) - t = (a - 1 + b) / 2 + t = 0.5 * (a - 1 + b) u = torch.acos(x / torch.sqrt(1 + t)) v = torch.asinh(-torch.sign(y) * torch.sqrt(t)) @@ -212,19 +215,52 @@ def 
acos_impl(self: ComplexTensor) -> ComplexTensor: @register_complex(aten.asin) def asin_impl(self: ComplexTensor) -> ComplexTensor: x, y = split_complex_tensor(self) - out_dt, (x, y) = promote_real_cpu_tensors(x, y) + y2 = y**2 a = (x**2) + y2 b = torch.sqrt((a - 1) ** 2 + 4 * y2) - t = (a - 1 + b) / 2 + t = 0.5 * (a - 1 + b) - u = torch.arcsin(x / torch.sqrt(1 + t)) - v = torch.arcsinh(torch.sign(y) * torch.sqrt(t)) + u = torch.asin(x / torch.sqrt(1 + t)) + v = torch.asinh(torch.sign(y) * torch.sqrt(t)) return ComplexTensor(u.to(out_dt), v.to(out_dt)) +@register_complex(aten.atan) +def atan_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_real_cpu_tensors(x, y) + + int1 = 0.5 * (torch.log(ComplexTensor(1 - y, x)) - torch.log(ComplexTensor(1 + y, -x))) + + int1_re, int1_im = split_complex_tensor(int1) + + return ComplexTensor(-int1_im.to(out_dt), int1_re.to(out_dt)) + + +@register_complex(aten.asinh) +def asinh_impl(self: ComplexTensor) -> ComplexTensor: + return torch.log(self + torch.sqrt(self * self + 1)) + + +@register_complex(aten.acosh) +def acosh_impl(self: ComplexTensor) -> ComplexTensor: + return torch.log(self + torch.sqrt(self * self - 1)) + + +@register_complex(aten.atanh) +def atanh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_real_cpu_tensors(x, y) + + ret = 0.5 * (torch.log(ComplexTensor(1 + x, y)) - torch.log(ComplexTensor(1 - x, -y))) + ret_re, ret_im = split_complex_tensor(ret) + + return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + + @register_complex(aten.cos) def cos_impl(self: ComplexTensor) -> ComplexTensor: x, y = split_complex_tensor(self) @@ -236,11 +272,68 @@ def cos_impl(self: ComplexTensor) -> ComplexTensor: @register_complex(aten.cosh) def cosh_impl(self: ComplexTensor) -> ComplexTensor: - x, y = split_complex_tensor(self) - out_dt, (x, y) = promote_real_cpu_tensors(x, y) - u = torch.cosh(x) * torch.cos(y) - v 
= torch.sinh(x) * torch.sin(y) - return ComplexTensor(u.to(out_dt), v.to(out_dt)) + self_re, self_im = split_complex_tensor(self) + out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im) + ret_re = torch.cosh(self_re) * torch.cos(self_im) + ret_im = torch.sinh(self_re) * torch.sin(self_im) + return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + + +@register_complex(aten.sin) +def sin_impl(self: ComplexTensor) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im) + + ret_re = torch.sin(self_re) * torch.cosh(self_im) + ret_im = torch.cos(self_re) * torch.sinh(self_im) + + return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + + +@register_complex(aten.sinh) +def sinh_impl(self: ComplexTensor) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im) + + ret_re = torch.sinh(self_re) * torch.cos(self_im) + ret_im = torch.cosh(self_re) * torch.sin(self_im) + + return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + + +@register_complex(aten.tan) +def tan_impl(self: ComplexTensor) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im) + + cos_x = torch.cos(self_re) + sinh_y = torch.sinh(self_im) + + num_re = torch.sin(self_re) * cos_x + num_im = sinh_y * torch.cosh(self_im) + + den = cos_x * cos_x + sinh_y * sinh_y + + return ComplexTensor((num_re / den).to(out_dt), (num_im / den).to(out_dt)) + + +@register_complex(aten.tanh) +def tanh_impl(self: ComplexTensor) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im) + + tanh_x = torch.tanh(self_re) + tan_y = torch.tan(self_im) + + tanh2_x = tanh_x * tanh_x + tan2_y = tan_y * tan_y + + num_re = tanh_x * (1 + tan2_y) + num_im = -tan_y * (1 + tanh2_x) + + den 
= 1 + tanh2_x * tan2_y + + return ComplexTensor((num_re / den).to(out_dt), (num_im / den).to(out_dt)) @register_complex(aten.exp) @@ -464,8 +557,8 @@ def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor: def sgn_impl(self: ComplexTensor) -> ComplexTensor: self_r, self_i = split_complex_tensor(self) out_dt, (self_r, self_i) = promote_real_cpu_tensors(self_r, self_i) - mask = self != 0 abs_self = torch.abs(ComplexTensor(self_r, self_i)) + mask = abs_self != 0 masked_sgn = ComplexTensor( torch.div(self_r, abs_self).to(out_dt), torch.div(self_i, abs_self).to(out_dt) ) @@ -478,7 +571,7 @@ def sqrt_impl(self: ComplexTensor) -> ComplexTensor: out_dt, (self_r, self_i) = promote_real_cpu_tensors(self_r, self_i) self = ComplexTensor(self_r, self_i) self_abs_sqrt = torch.sqrt(torch.abs(self)) - self_half_angle = torch.angle(self) / 2 + self_half_angle = 0.5 * torch.angle(self) ret_r = self_abs_sqrt * torch.cos(self_half_angle) ret_i = self_abs_sqrt * torch.sin(self_half_angle) @@ -492,7 +585,7 @@ def rsqrt_impl(self: ComplexTensor) -> ComplexTensor: out_dt, (self_r, self_i) = promote_real_cpu_tensors(self_r, self_i) self = ComplexTensor(self_r, self_i) self_abs_rsqrt = torch.rsqrt(torch.abs(self)) - self_neg_half_angle = -torch.angle(self) / 2 + self_neg_half_angle = -0.5 * torch.angle(self) ret_r = self_abs_rsqrt * torch.cos(self_neg_half_angle) ret_i = self_abs_rsqrt * torch.sin(self_neg_half_angle) @@ -617,8 +710,8 @@ def randn_like_impl(self: ComplexTensor, *, dtype=None, **kwargs) -> ComplexTens dtype = COMPLEX_TO_REAL[dtype] self_re, self_im = split_complex_tensor(self) - ret_re = torch.randn_like(self_re, dtype=dtype, **kwargs) / 2 - ret_im = torch.randn_like(self_im, dtype=dtype, **kwargs) / 2 + ret_re = 0.5 * torch.randn_like(self_re, dtype=dtype, **kwargs) + ret_im = 0.5 * torch.randn_like(self_im, dtype=dtype, **kwargs) return ComplexTensor(ret_re, ret_im) diff --git a/src/complex_tensor/test/test_ops.py 
b/src/complex_tensor/test/test_ops.py index a5a3ed1..1af7468 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -8,7 +8,6 @@ TestGradients, parametrize, run_tests, - unMarkDynamoStrictTest, ) from torch.testing._internal.opinfo.core import OpInfo @@ -34,8 +33,7 @@ def _get_opname_from_aten_op(aten_op): if isinstance(aten_op, OpOverload): aten_op = aten_op.overloadpacket - _, name = str(aten_op).split(".", 1) - return name + return aten_op._qualified_op_name.split("::")[-1] force_test_names = set(map(_get_opname_from_aten_op, FORCE_TEST_LIST)) @@ -93,7 +91,9 @@ def test_maybe_error(self, device, dtype, op: OpInfo, compile: bool): self.check_consistency(device, dtype, op, compile) def check_consistency(self, device, dtype, op: OpInfo, compile: bool) -> None: - test_info = TestDescriptor(op_name=op.name, device=device, dtype=dtype, compile=compile) + test_info = TestDescriptor( + op_name=op.name, device=device, dtype=dtype, compile=compile, gradcheck=False + ) for xfail_info, reason in SKIPS.items(): if xfail_info.matches(test_info): self.skipTest(reason) @@ -116,15 +116,31 @@ def actual(subclass_sample=subclass_sample): self.assertSameResult(expected, actual, ignore_exc_types=compile) -@unMarkDynamoStrictTest class TestComplexBwdGradients(TestGradients): - @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES)) - def test_fn_grad(self, device, dtype, op: OpInfo) -> None: + @parametrize("compile", [False]) + @ops(implemented_op_db, allowed_dtypes=[torch.complex128]) + def test_fn_grad(self, device, dtype, op: OpInfo, compile: bool) -> None: if dtype not in op.supported_backward_dtypes(torch.device(device).type): self.skipTest("Skipped! 
Dtype is not in supported backward dtypes!") - with ComplexDispatchMode(): + self.check_consistency(device, dtype, op, compile) + + def check_consistency(self, device, dtype, op: OpInfo, compile: bool) -> None: + test_info = TestDescriptor( + op_name=op.name, device=device, dtype=dtype, compile=compile, gradcheck=True + ) + for xfail_info, reason in SKIPS.items(): + if xfail_info.matches(test_info): + self.skipTest(reason) + + try: + self._grad_test_helper(device, dtype, op, op.get_op()) + except Exception: # noqa: BLE001 + self.skipTest("Fails even without `ComplexDispatchMode` mode on.") + + with ComplexDispatchMode(_compile=compile): op.gradcheck_fast_mode = False + op.skip_correctness_check_compile_vs_eager = True self._grad_test_helper(device, dtype, op, op.get_op()) diff --git a/src/complex_tensor/test/utils.py b/src/complex_tensor/test/utils.py index 8c1f550..5602628 100644 --- a/src/complex_tensor/test/utils.py +++ b/src/complex_tensor/test/utils.py @@ -19,6 +19,7 @@ class TestDescriptor: device: str | None = field(default=None) dtype: torch.dtype | None = field(default=None) compile: bool | None = field(default=None) + gradcheck: bool | None = field(default=None) def matches(self, other: TestDescriptor) -> bool: fields1 = fields(self) diff --git a/uv.lock b/uv.lock index 5c4caf0..e9ff991 100644 --- a/uv.lock +++ b/uv.lock @@ -34,8 +34,8 @@ dependencies = [ { name = "expecttest" }, { name = "numpy" }, { name = "pytest" }, - { name = "torch", version = "2.9.0.dev20250903", source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform == 'darwin'" }, - { name = "torch", version = "2.9.0.dev20250903+cpu", source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform != 'darwin'" }, + { name = "torch", version = "2.9.0.dev20250907", source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.9.0.dev20250907+cpu", 
source = { registry = "https://download.pytorch.org/whl/nightly/cpu" }, marker = "sys_platform != 'darwin'" }, ] [package.optional-dependencies] @@ -470,7 +470,7 @@ wheels = [ [[package]] name = "pytest" -version = "8.4.1" +version = "8.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, @@ -481,9 +481,9 @@ dependencies = [ { name = "pygments" }, { name = "tomli", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, + { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, ] [[package]] @@ -574,7 +574,7 @@ wheels = [ [[package]] name = "torch" -version = "2.9.0.dev20250903" +version = "2.9.0.dev20250907" source = { registry = "https://download.pytorch.org/whl/nightly/cpu" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform == 'darwin'", @@ -592,18 +592,18 @@ dependencies = [ { name = 
"typing-extensions", marker = "sys_platform == 'darwin'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp310-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp311-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp312-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp313-cp313t-macosx_14_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp313-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp314-cp314-macosx_14_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903-cp314-cp314t-macosx_14_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp310-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp311-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp312-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp313-cp313t-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp313-none-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp314-cp314-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907-cp314-cp314t-macosx_11_0_arm64.whl" }, ] [[package]] name = "torch" -version = "2.9.0.dev20250903+cpu" +version = "2.9.0.dev20250907+cpu" source = { registry = "https://download.pytorch.org/whl/nightly/cpu" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform != 'darwin'", @@ -621,34 +621,37 @@ 
dependencies = [ { name = "typing-extensions", marker = "sys_platform != 'darwin'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp310-cp310-linux_s390x.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp310-cp310-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp311-cp311-linux_s390x.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp311-cp311-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp312-cp312-linux_s390x.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp312-cp312-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp312-cp312-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313-linux_s390x.whl" }, - { url = 
"https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp313-cp313t-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250903%2Bcpu-cp314-cp314t-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp310-cp310-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl" }, + { url = 
"https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp310-cp310-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp311-cp311-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp311-cp311-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp311-cp311-win_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp312-cp312-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp312-cp312-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp312-cp312-win_arm64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313-win_arm64.whl" }, + { url = 
"https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313t-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp313-cp313t-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314t-linux_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cpu/torch-2.9.0.dev20250907%2Bcpu-cp314-cp314t-win_amd64.whl" }, ] [[package]] From 086ec02c494977179c5f26451c45330360a37aa8 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 9 Sep 2025 09:21:15 +0200 Subject: [PATCH 12/23] Remove gradient checking temporarily. 
--- src/complex_tensor/test/test_ops.py | 31 ----------------------------- 1 file changed, 31 deletions(-) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index 1af7468..52b0906 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -5,14 +5,12 @@ from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( - TestGradients, parametrize, run_tests, ) from torch.testing._internal.opinfo.core import OpInfo from complex_tensor.ops import COMPLEX_OPS_TABLE, FORCE_TEST_LIST -from complex_tensor.ops._common import ComplexDispatchMode from complex_tensor.test.utils import ( COMPLEX_DTYPES, TestCase, @@ -116,36 +114,7 @@ def actual(subclass_sample=subclass_sample): self.assertSameResult(expected, actual, ignore_exc_types=compile) -class TestComplexBwdGradients(TestGradients): - @parametrize("compile", [False]) - @ops(implemented_op_db, allowed_dtypes=[torch.complex128]) - def test_fn_grad(self, device, dtype, op: OpInfo, compile: bool) -> None: - if dtype not in op.supported_backward_dtypes(torch.device(device).type): - self.skipTest("Skipped! 
Dtype is not in supported backward dtypes!") - - self.check_consistency(device, dtype, op, compile) - - def check_consistency(self, device, dtype, op: OpInfo, compile: bool) -> None: - test_info = TestDescriptor( - op_name=op.name, device=device, dtype=dtype, compile=compile, gradcheck=True - ) - for xfail_info, reason in SKIPS.items(): - if xfail_info.matches(test_info): - self.skipTest(reason) - - try: - self._grad_test_helper(device, dtype, op, op.get_op()) - except Exception: # noqa: BLE001 - self.skipTest("Fails even without `ComplexDispatchMode` mode on.") - - with ComplexDispatchMode(_compile=compile): - op.gradcheck_fast_mode = False - op.skip_correctness_check_compile_vs_eager = True - self._grad_test_helper(device, dtype, op, op.get_op()) - - instantiate_device_type_tests(TestComplexTensor, globals()) -instantiate_device_type_tests(TestComplexBwdGradients, globals()) if __name__ == "__main__": run_tests() From 8dd8ef2b5065e965a0db1513852b0f071a5995be Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 9 Sep 2025 09:50:18 +0200 Subject: [PATCH 13/23] Remove repetition when registering ops. 
--- src/complex_tensor/ops/aten.py | 89 ++++++++++++++++++---------------- 1 file changed, 48 insertions(+), 41 deletions(-) diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py index 0e6169f..b5647df 100644 --- a/src/complex_tensor/ops/aten.py +++ b/src/complex_tensor/ops/aten.py @@ -62,49 +62,56 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b return self.is_pinned(device) -slice_impl = register_simple(aten.slice) -flatten_impl = register_simple(aten.flatten) -view_impl = register_simple(aten.view) -diagonal_impl = register_simple(aten.diagonal) -expand_impl = register_simple(aten.expand) -unsqueeze_impl = register_simple(aten.unsqueeze) -mean_impl = register_simple(aten.mean) -sum_impl = register_simple(aten.sum) -clone_impl = register_simple(aten.clone) -neg_impl = register_simple(aten.neg) -flip_impl = register_simple(aten.flip) -permute_impl = register_simple(aten.permute) -repeat_impl = register_simple(aten.repeat) -index_select_impl = register_simple(aten.index_select) -split_with_sizes_impl = register_simple(aten.split_with_sizes) -cumsum_impl = register_simple(aten.cumsum) -detach_impl = register_simple(aten.detach) -select_impl = register_simple(aten.select) -squeeze_impl = register_simple(aten.squeeze) -zero__impl = register_simple(aten.zero_) -transpose_impl = register_simple(aten.transpose) -t_impl = register_simple(aten.t) -zeros_like_impl = register_simple(aten.zeros_like) -masked_scatter_backward_impl = register_simple(aten.masked_scatter_backward) -select_backward_impl = register_simple(aten.select_backward) -slice_backward_impl = register_simple(aten.slice_backward) +SIMPLE_OPS_LIST = [ + aten.slice, + aten.flatten, + aten.view, + aten.diagonal, + aten.expand, + aten.unsqueeze, + aten.mean, + aten.sum, + aten.clone, + aten.neg, + aten.flip, + aten.permute, + aten.repeat, + aten.index_select, + aten.split_with_sizes, + aten.cumsum, + aten.detach, + aten.select, + aten.squeeze, + 
aten.zero_, + aten.transpose, + aten.t, + aten.zeros_like, + aten.masked_scatter_backward, + aten.select_backward, + aten.slice_backward, +] + +for simple_op in SIMPLE_OPS_LIST: + globals()[f"{str(simple_op).split('.', 1)}_impl"] = register_simple(simple_op) # TODO (hameerabbasi): Not being tested -copy_impl = register_force_test(aten.copy, register_simple(aten.copy)) -# TODO (hameerabbasi): Not being tested -_to_copy_impl = register_force_test(aten._to_copy, register_simple(aten._to_copy)) -# TODO (hameerabbasi): Not being tested -col2im_impl = register_force_test(aten.col2im, register_simple(aten.col2im)) -# TODO (hameerabbasi): Not being tested -alias_impl = register_force_test(aten.alias, register_simple(aten.alias)) -# TODO (hameerabbasi): Not being tested -lift_fresh_impl = register_force_test(aten.lift_fresh, register_simple(aten.lift_fresh)) -# TODO (hameerabbasi): Not being tested -_unsafe_view_impl = register_force_test(aten._unsafe_view, register_simple(aten._unsafe_view)) -# TODO (hameerabbasi): Not being tested -index_put__impl = register_force_test(aten.index_put_, register_simple(aten.index_put_)) -# TODO (hameerabbasi): Not being tested -index_impl = register_force_test(aten.index, register_simple(aten.index)) +SIMPLE_FORCE_TESTED_OPS = [ + aten.copy, + aten._to_copy, + aten.col2im, + aten.alias, + aten.lift_fresh, + aten._unsafe_view, + aten.index_put_, + aten.index, +] + +for simple_op in SIMPLE_FORCE_TESTED_OPS: + globals()[f"{str(simple_op).split('.', 1)}_impl"] = register_force_test( + simple_op, register_simple(simple_op) + ) + +del simple_op # some binary ops which we can stamp out mul_impl = register_binary_nonlinear(aten.mul) From 6227c830acd73eea46555fa3b21b401401f8a858 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 10 Sep 2025 08:56:38 +0200 Subject: [PATCH 14/23] More ops. 
--- src/complex_tensor/ops/aten.py | 68 +++++++++++++++++++++++---------- src/complex_tensor/ops/prims.py | 10 +++-- 2 files changed, 54 insertions(+), 24 deletions(-) diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py index b5647df..15c1bb6 100644 --- a/src/complex_tensor/ops/aten.py +++ b/src/complex_tensor/ops/aten.py @@ -77,6 +77,7 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b aten.permute, aten.repeat, aten.index_select, + aten.split, aten.split_with_sizes, aten.cumsum, aten.detach, @@ -86,9 +87,7 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b aten.transpose, aten.t, aten.zeros_like, - aten.masked_scatter_backward, - aten.select_backward, - aten.slice_backward, + aten.empty_like, ] for simple_op in SIMPLE_OPS_LIST: @@ -104,6 +103,14 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b aten._unsafe_view, aten.index_put_, aten.index, + aten._neg_view, + aten.avg_pool2d, + aten.avg_pool3d, + aten.avg_pool2d_backward, + aten.avg_pool3d_backward, + aten.masked_scatter_backward, + aten.select_backward, + aten.slice_backward, ] for simple_op in SIMPLE_FORCE_TESTED_OPS: @@ -124,11 +131,20 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b aten.convolution, register_binary_nonlinear(aten.convolution) ) +slice_scatter_impl = register_force_test( + aten.slice_scatter, register_binary_linear(aten.slice_scatter) +) +select_scatter_impl = register_force_test( + aten.select_scatter, register_binary_linear(aten.select_scatter) +) + add_impl = register_binary_linear(aten.add) sub_impl = register_binary_linear(aten.sub) +diagonal_scatter_impl = register_binary_linear(aten.diagonal_scatter) @register_complex(aten.div) +@register_complex(aten.true_divide) def div_impl(lhs: ComplexTensor, rhs: ComplexTensor, *, rounding_mode=None): a_r, a_i = split_complex_tensor(lhs) b_r, b_i = split_complex_arg(rhs) @@ -371,6 +387,13 
@@ def log_impl(self: ComplexTensor) -> ComplexTensor: return ComplexTensor(re, im) +@register_complex(aten.log1p) +def log1p_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + # TODO (hameerabbasi): The line below may have numerical issues + return torch.log(ComplexTensor(x + 1, y)) + + @register_complex(aten.any) def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: x, y = split_complex_tensor(self) @@ -457,6 +480,7 @@ def isclose_impl( aten.sort, aten.topk, aten.round, + aten.fmod, ] @@ -490,23 +514,6 @@ def masked_scatter_impl( return ComplexTensor(ret_r, ret_i) -@register_force_test(aten.slice_scatter) -def slice_scatter_impl( - self: ComplexTensor, - source: ComplexTensor, - dim: int = 0, - start: int | None = None, - end: int | None = None, - step: int = 1, -) -> ComplexTensor: - self_r, self_i = split_complex_tensor(self) - source_r, source_i = split_complex_arg(source) - ret_r = torch.slice_scatter(self_r, source_r, dim=dim, start=start, end=end, step=step) - ret_i = torch.slice_scatter(self_i, source_i, dim=dim, start=start, end=end, step=step) - - return ComplexTensor(ret_r, ret_i) - - @register_complex(aten.index_put) def index_put_impl( self: ComplexTensor, @@ -724,12 +731,17 @@ def randn_like_impl(self: ComplexTensor, *, dtype=None, **kwargs) -> ComplexTens # TODO (hameerabbasi): Not being tested @register_complex(aten._conj_physical) -@register_complex(aten._conj) def _conj_physical_impl(self: ComplexTensor) -> ComplexTensor: re, im = split_complex_tensor(self) return ComplexTensor(re, -im) +@register_complex(aten._conj) +def _conj_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, aten._neg_view(im)) + + @register_complex(aten.index_add) def index_add_impl( self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs @@ -785,3 +797,17 @@ def masked_fill__impl(self: ComplexTensor, mask: torch.Tensor, value: complex) - ret_im = 
self_im.masked_fill_(mask, value_im) return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.constant_pad_nd) +def constant_pad_nd_impl(self: ComplexTensor, pad, value: complex | None = None): + self_re, self_im = split_complex_tensor(self) + if value is None: + ret_re = aten.constant_pad_nd(self_re, pad) + ret_im = aten.constant_pad_nd(self_im, pad) + else: + value_re, value_im = split_complex_arg(value) + ret_re = aten.constant_pad_nd(self_re, pad, value_re) + ret_im = aten.constant_pad_nd(self_im, pad, value_im) + + return ComplexTensor(ret_re, ret_im) diff --git a/src/complex_tensor/ops/prims.py b/src/complex_tensor/ops/prims.py index 928149e..95ba84f 100644 --- a/src/complex_tensor/ops/prims.py +++ b/src/complex_tensor/ops/prims.py @@ -9,6 +9,7 @@ ) prims = torch.ops.prims +aten = torch.ops.aten # TODO (hameerabbasi): Not being tested @@ -23,7 +24,10 @@ def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTe @register_complex(prims.conj_physical) -@register_complex(prims.conj) def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: - re, im = split_complex_tensor(self) - return ComplexTensor(re, -im) + return aten._conj_physical(self) + + +@register_complex(prims.conj) +def conj_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj(self) From e25c57c34d05e90e46ec828909712fbe7a5eb27c Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 10 Sep 2025 09:06:29 +0200 Subject: [PATCH 15/23] Remove `real` skip due to upstream fix. 
--- src/complex_tensor/test/test_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index 52b0906..1ebd9df 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -62,7 +62,6 @@ def _get_opname_from_aten_op(aten_op): SKIPS = { - TestDescriptor(op_name="real"): "`aten.real` does not hit `__torch_dispatch__`", TestDescriptor(op_name="imag"): "`aten.imag` does not hit `__torch_dispatch__`", TestDescriptor(op_name="repeat", dtype=torch.complex64, compile=True): "Heisenbug", TestDescriptor(op_name="repeat", dtype=torch.complex128, compile=True): "Heisenbug", From 8d89a2e58ae743af8077e49d765cb113b8f950a1 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 10 Sep 2025 09:25:43 +0200 Subject: [PATCH 16/23] Clearer error message in testing. --- src/complex_tensor/test/utils.py | 42 +++++++++++++++++--------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/src/complex_tensor/test/utils.py b/src/complex_tensor/test/utils.py index 5602628..2734b40 100644 --- a/src/complex_tensor/test/utils.py +++ b/src/complex_tensor/test/utils.py @@ -39,38 +39,40 @@ def matches(self, other: TestDescriptor) -> bool: class TestCase(PytorchTestCase): def assertSameResult( self, - f1: Callable[[], Any], - f2: Callable[[], Any], + expected: Callable[[], Any], + actual: Callable[[], Any], ignore_exc_types: bool = False, *args, **kwargs, ) -> None: try: - result_1 = f1() - exception_1 = None + result_e = expected() + exception_e = None except Exception as e: # noqa: BLE001 - result_1 = None - exception_1 = e + result_e = None + exception_e = e try: - result_2 = f2() - exception_2 = None + result_a = actual() + exception_a = None except Exception as e: # noqa: BLE001 - result_2 = None - exception_2 = e + result_a = None + exception_a = e # Special case: compiled versions don't match the error type exactly. 
- if ((exception_1 is None) != (exception_2 is None)) or not ignore_exc_types: - self.assertIs(type(exception_1), type(exception_2), f"\n{exception_1=}\n{exception_2=}") + if ((exception_e is None) != (exception_a is None)) or not ignore_exc_types: + if exception_a is not None and exception_e is None: + raise exception_a + self.assertIs(type(exception_e), type(exception_a), f"\n{exception_e=}\n{exception_a=}") - if exception_1 is None: - flattened_1, spec_1 = tree_flatten(result_1) - flattened_2, spec_2 = tree_flatten(result_2) + if exception_e is None: + flattened_e, spec_e = tree_flatten(result_e) + flattened_a, spec_a = tree_flatten(result_a) self.assertEqual( - spec_1, spec_2, "Both functions must return a result with the same tree structure." + spec_e, spec_a, "Both functions must return a result with the same tree structure." ) - for f1, f2 in zip(flattened_1, flattened_2, strict=False): - f1 = _as_complex_tensor(f1) - f2 = _as_complex_tensor(f1) + for value_e, value_a in zip(flattened_e, flattened_a, strict=False): + value_e = _as_complex_tensor(value_e) + value_a = _as_complex_tensor(value_a) - self.assertEqual(f1, f2, *args, **kwargs) + self.assertEqual(value_e, value_a, *args, **kwargs) From ff129e7cf7f577730f2d8a131ecf83924b742e33 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 10 Sep 2025 09:35:44 +0200 Subject: [PATCH 17/23] Merge two repeat skips into one. 
--- src/complex_tensor/test/test_ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index 1ebd9df..4ebea10 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -63,8 +63,7 @@ def _get_opname_from_aten_op(aten_op): SKIPS = { TestDescriptor(op_name="imag"): "`aten.imag` does not hit `__torch_dispatch__`", - TestDescriptor(op_name="repeat", dtype=torch.complex64, compile=True): "Heisenbug", - TestDescriptor(op_name="repeat", dtype=torch.complex128, compile=True): "Heisenbug", + TestDescriptor(op_name="repeat", compile=True): "Heisenbug", TestDescriptor( op_name="allclose", compile=True ): "`aten.allclose` requires data-dependent control-flow", From f139450492613a884e64d949489433b79870093f Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 10 Sep 2025 12:18:39 +0200 Subject: [PATCH 18/23] Implement `aten.var`. 
--- src/complex_tensor/ops/aten.py | 8 +++++++- src/complex_tensor/test/test_ops.py | 3 +++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py index 15c1bb6..42fdf5f 100644 --- a/src/complex_tensor/ops/aten.py +++ b/src/complex_tensor/ops/aten.py @@ -800,7 +800,7 @@ def masked_fill__impl(self: ComplexTensor, mask: torch.Tensor, value: complex) - @register_complex(aten.constant_pad_nd) -def constant_pad_nd_impl(self: ComplexTensor, pad, value: complex | None = None): +def constant_pad_nd_impl(self: ComplexTensor, pad, value: complex | None = None) -> ComplexTensor: self_re, self_im = split_complex_tensor(self) if value is None: ret_re = aten.constant_pad_nd(self_re, pad) @@ -811,3 +811,9 @@ def constant_pad_nd_impl(self: ComplexTensor, pad, value: complex | None = None) ret_im = aten.constant_pad_nd(self_im, pad, value_im) return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.var) +def var_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + self_re, self_im = split_complex_tensor(self) + return torch.var(self_re, *args, **kwargs) + torch.var(self_im, *args, **kwargs) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index 4ebea10..0357b32 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -70,6 +70,9 @@ def _get_opname_from_aten_op(aten_op): TestDescriptor( op_name="randn_like", compile=True ): "`aten.randn_like` doesn't support `torch.compile`", + TestDescriptor( + op_name="var", compile=True + ): "`aten.var` doesn't return valid results with `torch.compile`", } From a84d0a2fe618b01f1088767d1d94f1078efd2ec2 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 10 Sep 2025 14:48:34 +0200 Subject: [PATCH 19/23] Few more ops. 
--- src/complex_tensor/ops/aten.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py index 42fdf5f..f6fa5fa 100644 --- a/src/complex_tensor/ops/aten.py +++ b/src/complex_tensor/ops/aten.py @@ -69,6 +69,7 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b aten.diagonal, aten.expand, aten.unsqueeze, + aten.unsqueeze_, aten.mean, aten.sum, aten.clone, @@ -88,6 +89,7 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b aten.t, aten.zeros_like, aten.empty_like, + aten.gather, ] for simple_op in SIMPLE_OPS_LIST: @@ -111,6 +113,7 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b aten.masked_scatter_backward, aten.select_backward, aten.slice_backward, + aten.embedding, ] for simple_op in SIMPLE_FORCE_TESTED_OPS: From 38e2d105713201149ec100dc7d9cd62033fc43a7 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 11 Sep 2025 08:40:16 +0200 Subject: [PATCH 20/23] Fixes for tests + introduce custom tolerances. 
--- src/complex_tensor/ops/_common.py | 13 +++- src/complex_tensor/ops/aten.py | 106 ++++++++++++++++++---------- src/complex_tensor/test/test_ops.py | 26 +++++-- src/complex_tensor/test/utils.py | 8 +-- 4 files changed, 104 insertions(+), 49 deletions(-) diff --git a/src/complex_tensor/ops/_common.py b/src/complex_tensor/ops/_common.py index 959cfb0..54f61ac 100644 --- a/src/complex_tensor/ops/_common.py +++ b/src/complex_tensor/ops/_common.py @@ -152,10 +152,17 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTens def register_simple(op: OpType): - def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + def impl(self: ComplexTensor, *args, dtype=None, **kwargs) -> ComplexTensor: x, y = split_complex_tensor(self) - u = op(x, *args, **kwargs) - v = op(y, *args, **kwargs) + if dtype is None: + u = op(x, *args, **kwargs) + v = op(y, *args, **kwargs) + elif dtype in COMPLEX_TO_REAL: + dtype = COMPLEX_TO_REAL[dtype] + u = op(x, *args, dtype=dtype, **kwargs) + v = op(y, *args, dtype=dtype, **kwargs) + else: + raise RuntimeError("Non-complex `dtype` specified, please write custom impl.") u_flat, u_spec = tree_flatten(u) v_flat, v_spec = tree_flatten(v) assert u_spec == v_spec diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py index f6fa5fa..b68b62e 100644 --- a/src/complex_tensor/ops/aten.py +++ b/src/complex_tensor/ops/aten.py @@ -87,8 +87,6 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b aten.zero_, aten.transpose, aten.t, - aten.zeros_like, - aten.empty_like, aten.gather, ] @@ -259,11 +257,13 @@ def atan_impl(self: ComplexTensor) -> ComplexTensor: x, y = split_complex_tensor(self) out_dt, (x, y) = promote_real_cpu_tensors(x, y) - int1 = 0.5 * (torch.log(ComplexTensor(1 - y, x)) - torch.log(ComplexTensor(1 + y, -x))) - + int1 = torch.log(ComplexTensor(-x, 1 - y) / ComplexTensor(x, 1 + y)) int1_re, int1_im = split_complex_tensor(int1) - return 
ComplexTensor(-int1_im.to(out_dt), int1_re.to(out_dt)) + out_re = 0.5 * int1_im + out_im = -0.5 * int1_re + + return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt)) @register_complex(aten.asinh) @@ -298,11 +298,9 @@ def cos_impl(self: ComplexTensor) -> ComplexTensor: @register_complex(aten.cosh) def cosh_impl(self: ComplexTensor) -> ComplexTensor: - self_re, self_im = split_complex_tensor(self) - out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im) - ret_re = torch.cosh(self_re) * torch.cos(self_im) - ret_im = torch.sinh(self_re) * torch.sin(self_im) - return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + exp_x = torch.exp(self) + exp_nx = torch.reciprocal(exp_x) + return 0.5 * (exp_x + exp_nx) @register_complex(aten.sin) @@ -318,13 +316,9 @@ def sin_impl(self: ComplexTensor) -> ComplexTensor: @register_complex(aten.sinh) def sinh_impl(self: ComplexTensor) -> ComplexTensor: - self_re, self_im = split_complex_tensor(self) - out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im) - - ret_re = torch.sinh(self_re) * torch.cos(self_im) - ret_im = torch.cosh(self_re) * torch.sin(self_im) - - return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + exp_x = torch.exp(self) + exp_nx = torch.reciprocal(exp_x) + return 0.5 * (exp_x - exp_nx) @register_complex(aten.tan) @@ -345,21 +339,7 @@ def tan_impl(self: ComplexTensor) -> ComplexTensor: @register_complex(aten.tanh) def tanh_impl(self: ComplexTensor) -> ComplexTensor: - self_re, self_im = split_complex_tensor(self) - out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im) - - tanh_x = torch.tanh(self_re) - tan_y = torch.tan(self_im) - - tanh2_x = tanh_x * tanh_x - tan2_y = tan_y * tan_y - - num_re = tanh_x * (1 + tan2_y) - num_im = -tan_y * (1 + tanh2_x) - - den = 1 + tanh2_x * tan2_y - - return ComplexTensor((num_re / den).to(out_dt), (num_im / den).to(out_dt)) + return torch.sinh(self) / torch.cosh(self) @register_complex(aten.exp) @@ -411,7 +391,7 @@ def 
all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: @register_complex(aten.eq) def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor: - a_r, a_i = split_complex_tensor(self) + a_r, a_i = split_complex_arg(self) b_r, b_i = split_complex_arg(rhs) return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs) @@ -544,12 +524,63 @@ def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> Comple @register_complex(aten.full_like) -def full_like_impl(input: ComplexTensor, fill_value: complex, *args, **kwargs) -> ComplexTensor: +def full_like_impl( + input: ComplexTensor, fill_value: complex, *args, **kwargs +) -> torch.Tensor | ComplexTensor: + dtype = kwargs.pop("dtype", None) input_r, input_i = split_complex_tensor(input) fv_r, fv_i = split_complex_arg(fill_value) + if dtype is None: + ret_r = torch.full_like(input_r, fv_r, *args, **kwargs) + ret_i = torch.full_like(input_i, fv_i, *args, **kwargs) + return ComplexTensor(ret_r, ret_i) + + if dtype not in COMPLEX_TO_REAL: + return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs) + + dtype = COMPLEX_TO_REAL[dtype] + ret_r = torch.full_like(input_r, fv_r, *args, dtype=dtype, **kwargs) + ret_i = torch.full_like(input_i, fv_i, *args, dtype=dtype, **kwargs) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.empty_like) +def empty_like_impl( + input: ComplexTensor, fill_value: complex, *args, **kwargs +) -> torch.Tensor | ComplexTensor: + dtype = kwargs.pop("dtype", None) + input_r, input_i = split_complex_tensor(input) + if dtype is None: + ret_r = torch.empty_like(input_r, *args, **kwargs) + ret_i = torch.empty_like(input_i, *args, **kwargs) + return ComplexTensor(ret_r, ret_i) + + if dtype not in COMPLEX_TO_REAL: + return torch.empty_like(input_r, *args, dtype=dtype, **kwargs) - ret_r = torch.full_like(input_r, fv_r, *args, **kwargs) - ret_i = torch.full_like(input_i, fv_i, *args, **kwargs) + dtype = 
COMPLEX_TO_REAL[dtype] + ret_r = torch.empty_like(input_r, *args, dtype=dtype, **kwargs) + ret_i = torch.empty_like(input_i, *args, dtype=dtype, **kwargs) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.zeros_like) +def zeros_like_impl(input: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor: + dtype = kwargs.pop("dtype", None) + input_r, input_i = split_complex_tensor(input) + if dtype is None: + ret_r = torch.zeros_like(input_r, *args, **kwargs) + ret_i = torch.zeros_like(input_i, *args, **kwargs) + return ComplexTensor(ret_r, ret_i) + + if dtype not in COMPLEX_TO_REAL: + return torch.zeros_like(input_r, *args, dtype=dtype, **kwargs) + + dtype = COMPLEX_TO_REAL[dtype] + ret_r = torch.zeros_like(input_r, *args, dtype=dtype, **kwargs) + ret_i = torch.zeros_like(input_i, *args, dtype=dtype, **kwargs) return ComplexTensor(ret_r, ret_i) @@ -619,7 +650,7 @@ def addmm_impl( beta: complex = 1, alpha: complex = 1, ) -> ComplexTensor: - ret = alpha * input + beta * torch.mm(mat1, mat2) + ret = beta * input + alpha * torch.mm(mat1, mat2) ret_r, ret_i = split_complex_tensor(ret) if out_dtype is not None: out_dtype = COMPLEX_TO_REAL[out_dtype] @@ -739,6 +770,7 @@ def _conj_physical_impl(self: ComplexTensor) -> ComplexTensor: return ComplexTensor(re, -im) +# TODO (hameerabbasi): Not being tested @register_complex(aten._conj) def _conj_impl(self: ComplexTensor) -> ComplexTensor: re, im = split_complex_tensor(self) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index 0357b32..74696d4 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -11,11 +11,11 @@ from torch.testing._internal.opinfo.core import OpInfo from complex_tensor.ops import COMPLEX_OPS_TABLE, FORCE_TEST_LIST +from complex_tensor.ops._common import _as_complex_tensor from complex_tensor.test.utils import ( COMPLEX_DTYPES, TestCase, TestDescriptor, - _as_complex_tensor, ) 
torch._dynamo.config.recompile_limit = float("inf") @@ -62,19 +62,29 @@ def _get_opname_from_aten_op(aten_op): SKIPS = { + TestDescriptor(op_name="real"): "`aten.real` does not hit `__torch_dispatch__`", TestDescriptor(op_name="imag"): "`aten.imag` does not hit `__torch_dispatch__`", + TestDescriptor(op_name="conj"): "`prims.conj` does not hit `__torch_dispatch__`", + TestDescriptor( + op_name="conj_physical" + ): "`prims.conj_physical` does not hit `__torch_dispatch__`", + TestDescriptor(op_name="empty_like"): "Inconsistent output", TestDescriptor(op_name="repeat", compile=True): "Heisenbug", TestDescriptor( op_name="allclose", compile=True ): "`aten.allclose` requires data-dependent control-flow", - TestDescriptor( - op_name="randn_like", compile=True - ): "`aten.randn_like` doesn't support `torch.compile`", + TestDescriptor(op_name="randn_like"): "Inconsistent output", TestDescriptor( op_name="var", compile=True ): "`aten.var` doesn't return valid results with `torch.compile`", } +EXTRA_KWARGS = { + TestDescriptor(op_name="asinh", dtype=torch.complex64): {"rtol": 2e-5, "atol": 5e-5}, + TestDescriptor(op_name="tanh", dtype=torch.complex64): {"rtol": 1e-4, "atol": 1e-5}, + TestDescriptor(op_name="pow", dtype=torch.complex64): {"rtol": 2e-2, "atol": 2e-6}, +} + class TestComplexTensor(TestCase): _default_dtype_check_enabled = True @@ -97,6 +107,12 @@ def check_consistency(self, device, dtype, op: OpInfo, compile: bool) -> None: if xfail_info.matches(test_info): self.skipTest(reason) + kwargs = {} + for extra_info, extra_kw in EXTRA_KWARGS.items(): + if extra_info.matches(test_info): + kwargs = extra_kw + break + sample_inputs = op.sample_inputs(device, dtype) op_eager = op if compile: @@ -112,7 +128,7 @@ def expected(sample_input=sample_input): def actual(subclass_sample=subclass_sample): return op(subclass_sample.input, *subclass_sample.args, **subclass_sample.kwargs) - self.assertSameResult(expected, actual, ignore_exc_types=compile) + 
self.assertSameResult(expected, actual, ignore_exc_types=compile, **kwargs) instantiate_device_type_tests(TestComplexTensor, globals()) diff --git a/src/complex_tensor/test/utils.py b/src/complex_tensor/test/utils.py index 2734b40..f716401 100644 --- a/src/complex_tensor/test/utils.py +++ b/src/complex_tensor/test/utils.py @@ -8,7 +8,7 @@ from torch.testing._internal.common_utils import TestCase as PytorchTestCase from torch.utils._pytree import tree_flatten -from complex_tensor.ops._common import COMPLEX_TO_REAL, _as_complex_tensor +from complex_tensor.ops._common import COMPLEX_TO_REAL, _as_interleaved COMPLEX_DTYPES = set(COMPLEX_TO_REAL) @@ -71,8 +71,8 @@ def assertSameResult( self.assertEqual( spec_e, spec_a, "Both functions must return a result with the same tree structure." ) - for value_e, value_a in zip(flattened_e, flattened_a, strict=False): - value_e = _as_complex_tensor(value_e) - value_a = _as_complex_tensor(value_a) + for value_e, value_a in zip(flattened_e, flattened_a, strict=True): + value_e = _as_interleaved(value_e) + value_a = _as_interleaved(value_a) self.assertEqual(value_e, value_a, *args, **kwargs) From 16c0a774b453ffb8f6f41b08ff3742c515edf451 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 11 Sep 2025 08:43:21 +0200 Subject: [PATCH 21/23] Revert "Remove gradient checking temporarily." This reverts commit 086ec02c494977179c5f26451c45330360a37aa8. 
--- src/complex_tensor/test/test_ops.py | 53 +++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index 74696d4..4be220a 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -80,9 +80,18 @@ def _get_opname_from_aten_op(aten_op): } EXTRA_KWARGS = { - TestDescriptor(op_name="asinh", dtype=torch.complex64): {"rtol": 2e-5, "atol": 5e-5}, - TestDescriptor(op_name="tanh", dtype=torch.complex64): {"rtol": 1e-4, "atol": 1e-5}, - TestDescriptor(op_name="pow", dtype=torch.complex64): {"rtol": 2e-2, "atol": 2e-6}, + TestDescriptor(op_name="asinh", dtype=torch.complex64, gradcheck=False): { + "rtol": 2e-5, + "atol": 5e-5, + }, + TestDescriptor(op_name="tanh", dtype=torch.complex64, gradcheck=False): { + "rtol": 1e-4, + "atol": 1e-5, + }, + TestDescriptor(op_name="pow", dtype=torch.complex64, gradcheck=False): { + "rtol": 2e-2, + "atol": 2e-6, + }, } @@ -131,7 +140,45 @@ def actual(subclass_sample=subclass_sample): self.assertSameResult(expected, actual, ignore_exc_types=compile, **kwargs) +class TestComplexBwdGradients(TestCase): + @parametrize("compile", [False]) + @ops(implemented_op_db, allowed_dtypes=[torch.complex128]) + def test_fn_grad(self, device, dtype, op: OpInfo, compile: bool) -> None: + if dtype not in op.supported_backward_dtypes(torch.device(device).type): + self.skipTest("Skipped! 
Dtype is not in supported backward dtypes!") + + self.check_consistency(device, dtype, op, compile) + + def check_consistency(self, device, dtype, op: OpInfo, compile: bool) -> None: + test_info = TestDescriptor( + op_name=op.name, device=device, dtype=dtype, compile=compile, gradcheck=True + ) + for xfail_info, reason in SKIPS.items(): + if xfail_info.matches(test_info): + self.skipTest(reason) + + kwargs = {} + for extra_info, extra_kw in EXTRA_KWARGS.items(): + if extra_info.matches(test_info): + kwargs = extra_kw + break + + sample_inputs = op.sample_inputs(device, dtype) + if compile: + op = torch.compile(op, fullgraph=True) + + for sample_input in sample_inputs: + subclass_sample = sample_input.transform(_as_complex_tensor) + + def grad_fn(input, subclass_sample=subclass_sample): + return op(input, *subclass_sample.args, **subclass_sample.kwargs) + + subclass_sample.input.requires_grad_() + torch.autograd.gradcheck(grad_fn, subclass_sample.input, **kwargs) + + instantiate_device_type_tests(TestComplexTensor, globals()) +instantiate_device_type_tests(TestComplexBwdGradients, globals()) if __name__ == "__main__": run_tests() From aeac72e3f4fa1e1551e20f2eb2234d2fc510ac91 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 15 Sep 2025 09:46:44 +0200 Subject: [PATCH 22/23] Annotations for `dtype`, reduce duplication in `_like` functions. 
--- src/complex_tensor/ops/_common.py | 20 +++--- src/complex_tensor/ops/aten.py | 110 ++++++++++-------------------- 2 files changed, 48 insertions(+), 82 deletions(-) diff --git a/src/complex_tensor/ops/_common.py b/src/complex_tensor/ops/_common.py index 54f61ac..5d2196f 100644 --- a/src/complex_tensor/ops/_common.py +++ b/src/complex_tensor/ops/_common.py @@ -152,17 +152,19 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTens def register_simple(op: OpType): - def impl(self: ComplexTensor, *args, dtype=None, **kwargs) -> ComplexTensor: + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> ComplexTensor: x, y = split_complex_tensor(self) - if dtype is None: - u = op(x, *args, **kwargs) - v = op(y, *args, **kwargs) - elif dtype in COMPLEX_TO_REAL: - dtype = COMPLEX_TO_REAL[dtype] - u = op(x, *args, dtype=dtype, **kwargs) - v = op(y, *args, dtype=dtype, **kwargs) - else: + if dtype is not None and dtype not in COMPLEX_TO_REAL: raise RuntimeError("Non-complex `dtype` specified, please write custom impl.") + + if dtype in COMPLEX_TO_REAL: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + u = op(x, *args, **kwargs) + v = op(y, *args, **kwargs) + u_flat, u_spec = tree_flatten(u) v_flat, v_spec = tree_flatten(v) assert u_spec == v_spec diff --git a/src/complex_tensor/ops/aten.py b/src/complex_tensor/ops/aten.py index b68b62e..d1fa23e 100644 --- a/src/complex_tensor/ops/aten.py +++ b/src/complex_tensor/ops/aten.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import torch @@ -525,64 +525,58 @@ def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> Comple @register_complex(aten.full_like) def full_like_impl( - input: ComplexTensor, fill_value: complex, *args, **kwargs + input: ComplexTensor, fill_value: complex, *args, dtype: torch.dtype | None = None, **kwargs ) -> torch.Tensor | ComplexTensor: - 
dtype = kwargs.pop("dtype", None) + # Note: Cannot be merged with the cases below due to the `fill_value` argument input_r, input_i = split_complex_tensor(input) - fv_r, fv_i = split_complex_arg(fill_value) - if dtype is None: - ret_r = torch.full_like(input_r, fv_r, *args, **kwargs) - ret_i = torch.full_like(input_i, fv_i, *args, **kwargs) - return ComplexTensor(ret_r, ret_i) - - if dtype not in COMPLEX_TO_REAL: + if dtype is not None and dtype not in COMPLEX_TO_REAL: return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs) - dtype = COMPLEX_TO_REAL[dtype] - ret_r = torch.full_like(input_r, fv_r, *args, dtype=dtype, **kwargs) - ret_i = torch.full_like(input_i, fv_i, *args, dtype=dtype, **kwargs) + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + fv_r, fv_i = split_complex_arg(fill_value) + ret_r = torch.full_like(input_r, fv_r, *args, **kwargs) + ret_i = torch.full_like(input_i, fv_i, *args, **kwargs) return ComplexTensor(ret_r, ret_i) -@register_complex(aten.empty_like) -def empty_like_impl( - input: ComplexTensor, fill_value: complex, *args, **kwargs -) -> torch.Tensor | ComplexTensor: - dtype = kwargs.pop("dtype", None) - input_r, input_i = split_complex_tensor(input) - if dtype is None: - ret_r = torch.empty_like(input_r, *args, **kwargs) - ret_i = torch.empty_like(input_i, *args, **kwargs) - return ComplexTensor(ret_r, ret_i) +def register_like(op: OpType) -> Callable[..., torch.Tensor | ComplexTensor]: + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> torch.Tensor | ComplexTensor: + self_re, self_im = split_complex_tensor(self) - if dtype not in COMPLEX_TO_REAL: - return torch.empty_like(input_r, *args, dtype=dtype, **kwargs) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return op(self_re, *args, dtype=dtype, **kwargs) - dtype = COMPLEX_TO_REAL[dtype] - ret_r = torch.empty_like(input_r, *args, dtype=dtype, **kwargs) - ret_i = torch.empty_like(input_i, *args, 
dtype=dtype, **kwargs) + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] - return ComplexTensor(ret_r, ret_i) + ret_re = op(self_re, *args, **kwargs) + ret_im = op(self_im, *args, **kwargs) + return ComplexTensor(ret_re, ret_im) + + func_name = f"{str(op).split('.', 1)[1]}_impl" + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) -@register_complex(aten.zeros_like) -def zeros_like_impl(input: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor: - dtype = kwargs.pop("dtype", None) - input_r, input_i = split_complex_tensor(input) - if dtype is None: - ret_r = torch.zeros_like(input_r, *args, **kwargs) - ret_i = torch.zeros_like(input_i, *args, **kwargs) - return ComplexTensor(ret_r, ret_i) - if dtype not in COMPLEX_TO_REAL: - return torch.zeros_like(input_r, *args, dtype=dtype, **kwargs) +LIKE_OPS_LIST = [ + aten.empty_like, + aten.zeros_like, + aten.randn_like, + aten.new_zeros, +] - dtype = COMPLEX_TO_REAL[dtype] - ret_r = torch.zeros_like(input_r, *args, dtype=dtype, **kwargs) - ret_i = torch.zeros_like(input_i, *args, dtype=dtype, **kwargs) +for like_op in LIKE_OPS_LIST: + globals()[f"{str(like_op).split('.', 1)[1]}_impl"] = register_like(like_op) - return ComplexTensor(ret_r, ret_i) +del like_op @register_complex(aten.cat) @@ -706,22 +700,6 @@ def copy__impl(self: ComplexTensor, src, *args, **kwargs): return ComplexTensor(ret_re, ret_im) -@register_complex(aten.new_zeros) -def new_zeros_impl( - self: ComplexTensor, size, *, dtype=None, **kwargs -) -> ComplexTensor | torch.Tensor: - self_re, self_im = split_complex_tensor(self) - if dtype is not None and dtype not in COMPLEX_TO_REAL: - return self_re.new_zeros(size, dtype=dtype, **kwargs) - - if dtype is not None: - dtype = COMPLEX_TO_REAL[dtype] - re = self_re.new_zeros(size, dtype=dtype, **kwargs) - im = self_im.new_zeros(size, dtype=dtype, **kwargs) - - return ComplexTensor(re, im) - - @register_complex(aten._local_scalar_dense) def 
_local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex: x, y = split_complex_tensor(self) @@ -749,20 +727,6 @@ def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor: return ComplexTensor(u, v) -@register_complex(aten.randn_like) -def randn_like_impl(self: ComplexTensor, *, dtype=None, **kwargs) -> ComplexTensor | torch.Tensor: - if dtype is not None and dtype not in COMPLEX_TO_REAL: - return torch.randn_like(self.re, dtype=dtype, **kwargs) - - if dtype is not None: - dtype = COMPLEX_TO_REAL[dtype] - - self_re, self_im = split_complex_tensor(self) - ret_re = 0.5 * torch.randn_like(self_re, dtype=dtype, **kwargs) - ret_im = 0.5 * torch.randn_like(self_im, dtype=dtype, **kwargs) - return ComplexTensor(ret_re, ret_im) - - # TODO (hameerabbasi): Not being tested @register_complex(aten._conj_physical) def _conj_physical_impl(self: ComplexTensor) -> ComplexTensor: From def06938f52b147b2d99d0bd7109ceda31b25761 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 15 Sep 2025 09:56:33 +0200 Subject: [PATCH 23/23] Remove gradient checking temporarily. --- src/complex_tensor/test/test_ops.py | 38 ----------------------------- 1 file changed, 38 deletions(-) diff --git a/src/complex_tensor/test/test_ops.py b/src/complex_tensor/test/test_ops.py index 4be220a..914b8b4 100644 --- a/src/complex_tensor/test/test_ops.py +++ b/src/complex_tensor/test/test_ops.py @@ -140,45 +140,7 @@ def actual(subclass_sample=subclass_sample): self.assertSameResult(expected, actual, ignore_exc_types=compile, **kwargs) -class TestComplexBwdGradients(TestCase): - @parametrize("compile", [False]) - @ops(implemented_op_db, allowed_dtypes=[torch.complex128]) - def test_fn_grad(self, device, dtype, op: OpInfo, compile: bool) -> None: - if dtype not in op.supported_backward_dtypes(torch.device(device).type): - self.skipTest("Skipped! 
Dtype is not in supported backward dtypes!") - - self.check_consistency(device, dtype, op, compile) - - def check_consistency(self, device, dtype, op: OpInfo, compile: bool) -> None: - test_info = TestDescriptor( - op_name=op.name, device=device, dtype=dtype, compile=compile, gradcheck=True - ) - for xfail_info, reason in SKIPS.items(): - if xfail_info.matches(test_info): - self.skipTest(reason) - - kwargs = {} - for extra_info, extra_kw in EXTRA_KWARGS.items(): - if extra_info.matches(test_info): - kwargs = extra_kw - break - - sample_inputs = op.sample_inputs(device, dtype) - if compile: - op = torch.compile(op, fullgraph=True) - - for sample_input in sample_inputs: - subclass_sample = sample_input.transform(_as_complex_tensor) - - def grad_fn(input, subclass_sample=subclass_sample): - return op(input, *subclass_sample.args, **subclass_sample.kwargs) - - subclass_sample.input.requires_grad_() - torch.autograd.gradcheck(grad_fn, subclass_sample.input, **kwargs) - - instantiate_device_type_tests(TestComplexTensor, globals()) -instantiate_device_type_tests(TestComplexBwdGradients, globals()) if __name__ == "__main__": run_tests()