Merged
Changes from 1 commit
23 commits
9426c46
Add gradient checks: first attempt.
hameerabbasi Sep 2, 2025
4a0b409
Add a bunch of ops trying to get gradients to work.
hameerabbasi Sep 3, 2025
dbf695b
Clarify warning message slightly.
hameerabbasi Sep 3, 2025
9c3c81c
Skip `allclose` for `compile=True` due to data-dependency.
hameerabbasi Sep 3, 2025
c309f2f
Small fix for `allclose`.
hameerabbasi Sep 3, 2025
5b4d80c
Add a number of ops for backward computations.
hameerabbasi Sep 4, 2025
5b8545f
Adjust gradient tests.
hameerabbasi Sep 4, 2025
5bf3c83
Update lockfile.
hameerabbasi Sep 4, 2025
bbbee18
Add a few ops.
hameerabbasi Sep 4, 2025
d59e11a
Add a guard to avoid registering multiple impls.
hameerabbasi Sep 4, 2025
c09eb17
Add a lot more trigonometric functions.
hameerabbasi Sep 9, 2025
086ec02
Remove gradient checking temporarily.
hameerabbasi Sep 9, 2025
8dd8ef2
Remove repetition when registering ops.
hameerabbasi Sep 9, 2025
6227c83
More ops.
hameerabbasi Sep 10, 2025
e25c57c
Remove `real` skip due to upstream fix.
hameerabbasi Sep 10, 2025
8d89a2e
Clearer error message in testing.
hameerabbasi Sep 10, 2025
ff129e7
Merge two repeat skips into one.
hameerabbasi Sep 10, 2025
f139450
Implement `aten.var`.
hameerabbasi Sep 10, 2025
a84d0a2
Few more ops.
hameerabbasi Sep 10, 2025
38e2d10
Fixes for tests + introduce custom tolerances.
hameerabbasi Sep 11, 2025
16c0a77
Revert "Remove gradient checking temporarily."
hameerabbasi Sep 11, 2025
aeac72e
Annotations for `dtype`, reduce duplication in `_like` functions.
hameerabbasi Sep 15, 2025
def0693
Remove gradient checking temporarily.
hameerabbasi Sep 15, 2025
13 changes: 10 additions & 3 deletions src/complex_tensor/ops/_common.py
@@ -152,10 +152,17 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTens


def register_simple(op: OpType):
def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
def impl(self: ComplexTensor, *args, dtype=None, **kwargs) -> ComplexTensor:
x, y = split_complex_tensor(self)
u = op(x, *args, **kwargs)
v = op(y, *args, **kwargs)
if dtype is None:
u = op(x, *args, **kwargs)
v = op(y, *args, **kwargs)
elif dtype in COMPLEX_TO_REAL:
dtype = COMPLEX_TO_REAL[dtype]
u = op(x, *args, dtype=dtype, **kwargs)
v = op(y, *args, dtype=dtype, **kwargs)
else:
raise RuntimeError("Non-complex `dtype` specified, please write custom impl.")
u_flat, u_spec = tree_flatten(u)
v_flat, v_spec = tree_flatten(v)
assert u_spec == v_spec
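For context, the new `dtype` branch in `register_simple` translates a requested complex dtype to its real counterpart and applies the wrapped op to the real and imaginary components separately. A minimal standalone sketch of that behaviour, using plain tensors in place of the package's `ComplexTensor` (the local `simple_sum` helper and `COMPLEX_TO_REAL` table here are illustrative stand-ins, not the package's API):

import torch

# Stand-in for the COMPLEX_TO_REAL table imported in _common.py.
COMPLEX_TO_REAL = {torch.complex128: torch.float64, torch.complex64: torch.float32}

def simple_sum(re: torch.Tensor, im: torch.Tensor, dtype: torch.dtype | None = None):
    # No dtype requested: apply the op to both components unchanged.
    if dtype is None:
        return torch.sum(re), torch.sum(im)
    # A complex dtype request is translated to the matching real dtype
    # and applied component-wise.
    if dtype in COMPLEX_TO_REAL:
        real_dtype = COMPLEX_TO_REAL[dtype]
        return torch.sum(re, dtype=real_dtype), torch.sum(im, dtype=real_dtype)
    raise RuntimeError("Non-complex `dtype` specified, please write custom impl.")

re, im = torch.rand(3, dtype=torch.float64), torch.rand(3, dtype=torch.float64)
print(simple_sum(re, im, dtype=torch.complex64))  # two float32 scalars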
106 changes: 69 additions & 37 deletions src/complex_tensor/ops/aten.py
@@ -87,8 +87,6 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b
aten.zero_,
aten.transpose,
aten.t,
aten.zeros_like,
aten.empty_like,
aten.gather,
]

@@ -259,11 +257,13 @@ def atan_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_real_cpu_tensors(x, y)

int1 = 0.5 * (torch.log(ComplexTensor(1 - y, x)) - torch.log(ComplexTensor(1 + y, -x)))

int1 = torch.log(ComplexTensor(-x, 1 - y) / ComplexTensor(x, 1 + y))
int1_re, int1_im = split_complex_tensor(int1)

return ComplexTensor(-int1_im.to(out_dt), int1_re.to(out_dt))
out_re = 0.5 * int1_im
out_im = -0.5 * int1_re

return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt))


@register_complex(aten.asinh)
@@ -298,11 +298,9 @@ def cos_impl(self: ComplexTensor) -> ComplexTensor:

@register_complex(aten.cosh)
def cosh_impl(self: ComplexTensor) -> ComplexTensor:
self_re, self_im = split_complex_tensor(self)
out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im)
ret_re = torch.cosh(self_re) * torch.cos(self_im)
ret_im = torch.sinh(self_re) * torch.sin(self_im)
return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt))
exp_x = torch.exp(self)
exp_nx = torch.reciprocal(exp_x)
return 0.5 * (exp_x + exp_nx)


@register_complex(aten.sin)
@@ -318,13 +316,9 @@ def sin_impl(self: ComplexTensor) -> ComplexTensor:

@register_complex(aten.sinh)
def sinh_impl(self: ComplexTensor) -> ComplexTensor:
self_re, self_im = split_complex_tensor(self)
out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im)

ret_re = torch.sinh(self_re) * torch.cos(self_im)
ret_im = torch.cosh(self_re) * torch.sin(self_im)

return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt))
exp_x = torch.exp(self)
exp_nx = torch.reciprocal(exp_x)
return 0.5 * (exp_x - exp_nx)


@register_complex(aten.tan)
@@ -345,21 +339,7 @@ def tan_impl(self: ComplexTensor) -> ComplexTensor:

@register_complex(aten.tanh)
def tanh_impl(self: ComplexTensor) -> ComplexTensor:
self_re, self_im = split_complex_tensor(self)
out_dt, (self_re, self_im) = promote_real_cpu_tensors(self_re, self_im)

tanh_x = torch.tanh(self_re)
tan_y = torch.tan(self_im)

tanh2_x = tanh_x * tanh_x
tan2_y = tan_y * tan_y

num_re = tanh_x * (1 + tan2_y)
num_im = -tan_y * (1 + tanh2_x)

den = 1 + tanh2_x * tan2_y

return ComplexTensor((num_re / den).to(out_dt), (num_im / den).to(out_dt))
return torch.sinh(self) / torch.cosh(self)


@register_complex(aten.exp)
@@ -411,7 +391,7 @@ def all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:

@register_complex(aten.eq)
def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor:
a_r, a_i = split_complex_tensor(self)
a_r, a_i = split_complex_arg(self)
b_r, b_i = split_complex_arg(rhs)
return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs)

@@ -544,12 +524,63 @@ def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> Comple


@register_complex(aten.full_like)
def full_like_impl(input: ComplexTensor, fill_value: complex, *args, **kwargs) -> ComplexTensor:
def full_like_impl(
input: ComplexTensor, fill_value: complex, *args, **kwargs
) -> torch.Tensor | ComplexTensor:
dtype = kwargs.pop("dtype", None)
input_r, input_i = split_complex_tensor(input)
fv_r, fv_i = split_complex_arg(fill_value)
if dtype is None:
ret_r = torch.full_like(input_r, fv_r, *args, **kwargs)
ret_i = torch.full_like(input_i, fv_i, *args, **kwargs)
return ComplexTensor(ret_r, ret_i)

if dtype not in COMPLEX_TO_REAL:
return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs)

dtype = COMPLEX_TO_REAL[dtype]
ret_r = torch.full_like(input_r, fv_r, *args, dtype=dtype, **kwargs)
ret_i = torch.full_like(input_i, fv_i, *args, dtype=dtype, **kwargs)

return ComplexTensor(ret_r, ret_i)


@register_complex(aten.empty_like)
def empty_like_impl(
input: ComplexTensor, fill_value: complex, *args, **kwargs
) -> torch.Tensor | ComplexTensor:
dtype = kwargs.pop("dtype", None)
input_r, input_i = split_complex_tensor(input)
if dtype is None:
ret_r = torch.empty_like(input_r, *args, **kwargs)
ret_i = torch.empty_like(input_i, *args, **kwargs)
return ComplexTensor(ret_r, ret_i)

if dtype not in COMPLEX_TO_REAL:
return torch.empty_like(input_r, *args, dtype=dtype, **kwargs)

ret_r = torch.full_like(input_r, fv_r, *args, **kwargs)
ret_i = torch.full_like(input_i, fv_i, *args, **kwargs)
dtype = COMPLEX_TO_REAL[dtype]
ret_r = torch.empty_like(input_r, *args, dtype=dtype, **kwargs)
ret_i = torch.empty_like(input_i, *args, dtype=dtype, **kwargs)

return ComplexTensor(ret_r, ret_i)


@register_complex(aten.zeros_like)
def zeros_like_impl(input: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor:
dtype = kwargs.pop("dtype", None)
input_r, input_i = split_complex_tensor(input)
if dtype is None:
ret_r = torch.zeros_like(input_r, *args, **kwargs)
ret_i = torch.zeros_like(input_i, *args, **kwargs)
return ComplexTensor(ret_r, ret_i)

if dtype not in COMPLEX_TO_REAL:
return torch.zeros_like(input_r, *args, dtype=dtype, **kwargs)

dtype = COMPLEX_TO_REAL[dtype]
ret_r = torch.zeros_like(input_r, *args, dtype=dtype, **kwargs)
ret_i = torch.zeros_like(input_i, *args, dtype=dtype, **kwargs)

return ComplexTensor(ret_r, ret_i)

@@ -619,7 +650,7 @@ def addmm_impl(
beta: complex = 1,
alpha: complex = 1,
) -> ComplexTensor:
ret = alpha * input + beta * torch.mm(mat1, mat2)
ret = beta * input + alpha * torch.mm(mat1, mat2)
ret_r, ret_i = split_complex_tensor(ret)
if out_dtype is not None:
out_dtype = COMPLEX_TO_REAL[out_dtype]
@@ -739,6 +770,7 @@ def _conj_physical_impl(self: ComplexTensor) -> ComplexTensor:
return ComplexTensor(re, -im)


# TODO (hameerabbasi): Not being tested
@register_complex(aten._conj)
def _conj_impl(self: ComplexTensor) -> ComplexTensor:
re, im = split_complex_tensor(self)
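For reference, the rewritten `cosh`, `sinh`, `tanh`, and `atan` kernels above correspond to the standard identities (stated here in LaTeX; in the `atan` case the code forms the quotient before taking the logarithm, so the factor of `i` cancels exactly):

\cosh z = \tfrac{1}{2}\bigl(e^{z} + e^{-z}\bigr), \qquad
\sinh z = \tfrac{1}{2}\bigl(e^{z} - e^{-z}\bigr), \qquad
\tanh z = \frac{\sinh z}{\cosh z}, \qquad
\arctan z = -\frac{i}{2}\,\log\frac{1 + iz}{1 - iz}.

The exp-based forms replace the earlier component-wise cosh·cos / sinh·sin expansions, trading a few real trigonometric calls for one complex exponential plus a reciprocal.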
26 changes: 21 additions & 5 deletions src/complex_tensor/test/test_ops.py
@@ -11,11 +11,11 @@
from torch.testing._internal.opinfo.core import OpInfo

from complex_tensor.ops import COMPLEX_OPS_TABLE, FORCE_TEST_LIST
from complex_tensor.ops._common import _as_complex_tensor
from complex_tensor.test.utils import (
COMPLEX_DTYPES,
TestCase,
TestDescriptor,
_as_complex_tensor,
)

torch._dynamo.config.recompile_limit = float("inf")
@@ -62,19 +62,29 @@ def _get_opname_from_aten_op(aten_op):


SKIPS = {
TestDescriptor(op_name="real"): "`aten.real` does not hit `__torch_dispatch__`",
TestDescriptor(op_name="imag"): "`aten.imag` does not hit `__torch_dispatch__`",
TestDescriptor(op_name="conj"): "`prims.conj` does not hit `__torch_dispatch__`",
TestDescriptor(
op_name="conj_physical"
): "`prims.conj_physical` does not hit `__torch_dispatch__`",
TestDescriptor(op_name="empty_like"): "Inconsistent output",
TestDescriptor(op_name="repeat", compile=True): "Heisenbug",
TestDescriptor(
op_name="allclose", compile=True
): "`aten.allclose` requires data-dependent control-flow",
TestDescriptor(
op_name="randn_like", compile=True
): "`aten.randn_like` doesn't support `torch.compile`",
TestDescriptor(op_name="randn_like"): "Inconsistent output",
TestDescriptor(
op_name="var", compile=True
): "`aten.var` doesn't return valid results with `torch.compile`",
}

EXTRA_KWARGS = {
TestDescriptor(op_name="asinh", dtype=torch.complex64): {"rtol": 2e-5, "atol": 5e-5},
TestDescriptor(op_name="tanh", dtype=torch.complex64): {"rtol": 1e-4, "atol": 1e-5},
TestDescriptor(op_name="pow", dtype=torch.complex64): {"rtol": 2e-2, "atol": 2e-6},
}


class TestComplexTensor(TestCase):
_default_dtype_check_enabled = True
@@ -97,6 +107,12 @@ def check_consistency(self, device, dtype, op: OpInfo, compile: bool) -> None:
if xfail_info.matches(test_info):
self.skipTest(reason)

kwargs = {}
for extra_info, extra_kw in EXTRA_KWARGS.items():
if extra_info.matches(test_info):
kwargs = extra_kw
break

sample_inputs = op.sample_inputs(device, dtype)
op_eager = op
if compile:
@@ -112,7 +128,7 @@ def expected(sample_input=sample_input):
def actual(subclass_sample=subclass_sample):
return op(subclass_sample.input, *subclass_sample.args, **subclass_sample.kwargs)

self.assertSameResult(expected, actual, ignore_exc_types=compile)
self.assertSameResult(expected, actual, ignore_exc_types=compile, **kwargs)


instantiate_device_type_tests(TestComplexTensor, globals())
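The tolerance override added above is a first-match lookup: the test builds a descriptor for the current op/dtype/compile combination, takes the first `EXTRA_KWARGS` entry whose descriptor matches, and forwards the `rtol`/`atol` to the final comparison. A self-contained sketch of that lookup (the `Descriptor` dataclass below is a simplified stand-in for the package's `TestDescriptor`, assuming `None` fields act as wildcards):

from dataclasses import dataclass

import torch

@dataclass(frozen=True)
class Descriptor:
    # Simplified stand-in: None fields match anything.
    op_name: str | None = None
    dtype: torch.dtype | None = None

    def matches(self, other: "Descriptor") -> bool:
        return all(
            getattr(self, f) is None or getattr(self, f) == getattr(other, f)
            for f in ("op_name", "dtype")
        )

EXTRA_KWARGS = {
    Descriptor(op_name="asinh", dtype=torch.complex64): {"rtol": 2e-5, "atol": 5e-5},
    Descriptor(op_name="pow", dtype=torch.complex64): {"rtol": 2e-2, "atol": 2e-6},
}

test_info = Descriptor(op_name="asinh", dtype=torch.complex64)
kwargs = {}
for extra_info, extra_kw in EXTRA_KWARGS.items():
    if extra_info.matches(test_info):
        kwargs = extra_kw
        break
print(kwargs)  # {'rtol': 2e-05, 'atol': 5e-05}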
8 changes: 4 additions & 4 deletions src/complex_tensor/test/utils.py
@@ -8,7 +8,7 @@
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
from torch.utils._pytree import tree_flatten

from complex_tensor.ops._common import COMPLEX_TO_REAL, _as_complex_tensor
from complex_tensor.ops._common import COMPLEX_TO_REAL, _as_interleaved

COMPLEX_DTYPES = set(COMPLEX_TO_REAL)

@@ -71,8 +71,8 @@ def assertSameResult(
self.assertEqual(
spec_e, spec_a, "Both functions must return a result with the same tree structure."
)
for value_e, value_a in zip(flattened_e, flattened_a, strict=False):
value_e = _as_complex_tensor(value_e)
value_a = _as_complex_tensor(value_a)
for value_e, value_a in zip(flattened_e, flattened_a, strict=True):
value_e = _as_interleaved(value_e)
value_a = _as_interleaved(value_a)

self.assertEqual(value_e, value_a, *args, **kwargs)
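One behavioural note on the last hunk: switching the zip from `strict=False` to `strict=True` means a length mismatch between the flattened expected and actual outputs now raises immediately instead of silently truncating the comparison. A quick illustration of the difference (plain Python 3.10+ behaviour, unrelated to the package itself):

# strict=False silently drops the unmatched trailing element;
# strict=True raises ValueError instead.
print(list(zip([1, 2], [1, 2, 3], strict=False)))  # [(1, 1), (2, 2)]
list(zip([1, 2], [1, 2, 3], strict=True))          # raises ValueError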