Skip to content

Commit 9426c46

Browse files
committed
Add gradient checks: first attempt.
1 parent 45dffd6 commit 9426c46

File tree

4 files changed

+126
-17
lines changed

4 files changed

+126
-17
lines changed

src/complex_tensor/ops/_common.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import torch
55
from torch._ops import OpOverloadPacket
66
from torch._refs import is_complex
7-
from torch.utils._pytree import tree_flatten, tree_unflatten
7+
from torch.utils._python_dispatch import TorchDispatchMode
8+
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
89

910
from ..complex_tensor import ComplexTensor
1011

@@ -131,7 +132,7 @@ def ordered_impl(*args, **kwargs):
131132

132133
def register_binary_nonlinear(op: OpType) -> Callable:
133134
def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
134-
a_r, a_i = split_complex_tensor(lhs)
135+
a_r, a_i = split_complex_arg(lhs)
135136
b_r, b_i = split_complex_arg(rhs)
136137
out_dt, (a_r, a_i, b_r, b_i) = promote_real_cpu_tensors(a_r, a_i, b_r, b_i)
137138
real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs)
@@ -161,3 +162,30 @@ def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
161162
impl.__qualname__ = func_name
162163

163164
return register_complex(op, impl)
165+
166+
167+
def _as_complex_tensor(arg: torch.Tensor | Any) -> torch.Tensor | ComplexTensor | Any:
    """Wrap a plain complex-dtype ``torch.Tensor`` as a ``ComplexTensor``.

    Values that are already ``ComplexTensor``s, are not tensors at all, or do
    not have a dtype listed in ``COMPLEX_TO_REAL`` are passed through
    unchanged, which makes this safe to use with ``tree_map`` over arbitrary
    argument trees.
    """
    is_plain_complex = (
        isinstance(arg, torch.Tensor)
        and not isinstance(arg, ComplexTensor)
        and arg.dtype in COMPLEX_TO_REAL
    )
    if is_plain_complex:
        return ComplexTensor.from_interleaved(arg)
    return arg
175+
176+
177+
def _as_interleaved(arg: ComplexTensor | Any) -> torch.Tensor | Any:
    """Convert a ``ComplexTensor`` back to an interleaved ``torch.Tensor``.

    Non-``ComplexTensor`` values are returned unchanged, mirroring
    ``_as_complex_tensor`` so the pair can round-trip argument trees.
    """
    if not isinstance(arg, ComplexTensor):
        return arg
    return arg.as_interleaved()
181+
182+
183+
class ComplexDispatchMode(TorchDispatchMode):
    """Dispatch mode that routes ops through ``ComplexTensor`` transparently.

    Incoming complex-dtype tensors are wrapped as ``ComplexTensor``s before
    the op runs, and any ``ComplexTensor`` results are converted back to
    interleaved tensors, so callers only ever see plain tensors.
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        wrapped_args = tree_map(_as_complex_tensor, args)
        wrapped_kwargs = tree_map(_as_complex_tensor, kwargs)

        result = func(*wrapped_args, **wrapped_kwargs)
        return tree_map(_as_interleaved, result)

src/complex_tensor/ops/aten.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTens
3333
alpha = kwargs.pop("alpha", None)
3434
if alpha is not None:
3535
return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs)
36-
a_r, a_i = split_complex_tensor(lhs)
36+
a_r, a_i = split_complex_arg(lhs)
3737
b_r, b_i = split_complex_arg(rhs)
3838
out_dt, (a_r, a_i, b_r, b_i) = promote_real_cpu_tensors(a_r, a_i, b_r, b_i)
3939
u = op(a_r, b_r, *args, **kwargs)
@@ -78,6 +78,9 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b
7878
index_select_impl = register_simple(aten.index_select)
7979
split_with_sizes_impl = register_simple(aten.split_with_sizes)
8080
cumsum_impl = register_simple(aten.cumsum)
81+
detach_impl = register_simple(aten.detach)
82+
select_impl = register_simple(aten.select)
83+
squeeze_impl = register_simple(aten.squeeze)
8184

8285
# TODO (hameerabbasi): Not being tested
8386
copy_impl = register_force_test(aten.copy, register_simple(aten.copy))
@@ -502,3 +505,68 @@ def nonzero_impl(self: ComplexTensor, other: ComplexTensor, *args, **kwargs) ->
502505
@register_complex(aten.logical_not)
503506
def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
504507
return torch.logical_not(elemwise_nonzero(self), *args, **kwargs)
508+
509+
510+
@register_complex(aten.view_as_real)
def view_as_real_impl(self: ComplexTensor) -> torch.Tensor:
    """Emulate ``aten.view_as_real``: append a trailing dim of (real, imag).

    NOTE(review): this materializes a new tensor via ``stack`` rather than
    returning a view of existing storage — confirm callers do not rely on
    view semantics.
    """
    components = split_complex_tensor(self)
    return torch.stack(components, dim=-1)
514+
515+
516+
@register_complex(aten.linalg_vector_norm)
def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
    """Vector norm of a complex tensor as the norm of elementwise magnitudes."""
    magnitudes = torch.abs(self)
    return torch.linalg.vector_norm(magnitudes, *args, **kwargs)
519+
520+
521+
@register_force_test(aten.copy_)
def copy__impl(self: ComplexTensor, src, *args, **kwargs):
    """In-place ``copy_``: copy the source's real/imag parts component-wise.

    Mutates ``self``'s component tensors and returns a ``ComplexTensor``
    built from the mutated parts, matching the in-place op convention.
    """
    dst_re, dst_im = split_complex_tensor(self)
    src_re, src_im = split_complex_arg(src)

    out_re = dst_re.copy_(src_re, *args, **kwargs)
    out_im = dst_im.copy_(src_im, *args, **kwargs)

    return ComplexTensor(out_re, out_im)
530+
531+
532+
@register_complex(aten.new_zeros)
def new_zeros_impl(
    self: ComplexTensor, size, *, dtype=None, **kwargs
) -> ComplexTensor | torch.Tensor:
    """``aten.new_zeros`` for ``ComplexTensor``.

    If a non-complex ``dtype`` is requested, return a plain real tensor; for a
    complex (or unspecified) dtype, build a ``ComplexTensor`` whose real and
    imaginary components are zero tensors of the matching real dtype.
    """
    # BUG fix: the original called `self.re.new_zeros(self, size, ...)`,
    # passing `self` as an extra positional argument — `Tensor.new_zeros`
    # takes `size` first, so that call raises. It also wrapped `dtype` in
    # `torch.dtype(...)`, but `torch.dtype` cannot be instantiated that way;
    # the dtype passed in via the dispatcher is already a `torch.dtype`.
    if dtype is not None and dtype not in COMPLEX_TO_REAL:
        return self.re.new_zeros(size, dtype=dtype, **kwargs)

    if dtype is not None:
        # Map the complex dtype to its component (real) dtype.
        dtype = COMPLEX_TO_REAL[dtype]

    re = self.re.new_zeros(size, dtype=dtype, **kwargs)
    im = self.im.new_zeros(size, dtype=dtype, **kwargs)

    return ComplexTensor(re, im)
546+
547+
548+
@register_complex(aten._local_scalar_dense)
def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex:
    """Extract a Python ``complex`` scalar from a ``ComplexTensor``."""
    re, im = split_complex_tensor(self)
    return complex(
        aten._local_scalar_dense(re, *args, **kwargs),
        aten._local_scalar_dense(im, *args, **kwargs),
    )
554+
555+
556+
@register_complex(aten.allclose)
557+
def allclose_impl(
    input: torch.Tensor,
    other: torch.Tensor,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> torch.Tensor:
    """``allclose`` built from ``isclose``: all elements within tolerance.

    Returns a 0-dim boolean tensor (``torch.all`` over the elementwise
    closeness mask) rather than a Python ``bool``.
    """
    close_mask = torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
    return torch.all(close_mask)
565+
566+
567+
@register_complex(aten.stack)
def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor:
    """Stack complex tensors by stacking real and imaginary parts separately."""
    parts = [split_complex_arg(item) for item in self]
    stacked_re = torch.stack([re for re, _ in parts], *args, **kwargs)
    stacked_im = torch.stack([im for _, im in parts], *args, **kwargs)
    return ComplexTensor(stacked_re, stacked_im)

src/complex_tensor/test/test_ops.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,22 @@
44
from torch._ops import OpOverload
55
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops
66
from torch.testing._internal.common_methods_invocations import op_db
7-
from torch.testing._internal.common_utils import parametrize, run_tests
7+
from torch.testing._internal.common_utils import (
8+
TestGradients,
9+
parametrize,
10+
run_tests,
11+
unMarkDynamoStrictTest,
12+
)
813
from torch.testing._internal.opinfo.core import OpInfo
914

1015
from complex_tensor.ops import COMPLEX_OPS_TABLE, FORCE_TEST_LIST
11-
from complex_tensor.test.utils import COMPLEX_DTYPES, TestCase, TestDescriptor, _as_complex_tensor
16+
from complex_tensor.test.utils import (
17+
COMPLEX_DTYPES,
18+
ComplexDispatchMode,
19+
TestCase,
20+
TestDescriptor,
21+
_as_complex_tensor,
22+
)
1223

1324
torch._dynamo.config.recompile_limit = float("inf")
1425
torch._dynamo.config.accumulated_recompile_limit = float("inf")
@@ -99,7 +110,19 @@ def actual(subclass_sample=subclass_sample):
99110
self.assertSameResult(expected, actual, ignore_exc_types=compile)
100111

101112

113+
@unMarkDynamoStrictTest
class TestComplexBwdGradients(TestGradients):
    """Backward-gradient checks for ops with ComplexTensor implementations."""

    @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
    def test_fn_grad(self, device, dtype, op: OpInfo) -> None:
        # Guard clause: skipTest raises, so no else-branch is needed.
        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
            self.skipTest("Skipped! Dtype is not in supported backward dtypes!")

        # Run the stock gradient helper under the dispatch mode so complex
        # inputs are transparently handled by ComplexTensor.
        with ComplexDispatchMode():
            self._grad_test_helper(device, dtype, op, op.get_op())
122+
123+
102124
instantiate_device_type_tests(TestComplexTensor, globals())
125+
instantiate_device_type_tests(TestComplexBwdGradients, globals())
103126

104127
if __name__ == "__main__":
105128
run_tests()

src/complex_tensor/test/utils.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
99
from torch.utils._pytree import tree_flatten
1010

11-
from complex_tensor.complex_tensor import ComplexTensor
11+
from complex_tensor.ops._common import COMPLEX_TO_REAL, _as_complex_tensor
1212

13-
COMPLEX_DTYPES = {torch.complex128, torch.complex64, torch.complex32}
13+
COMPLEX_DTYPES = set(COMPLEX_TO_REAL)
1414

1515

1616
@dataclass(frozen=True)
@@ -35,16 +35,6 @@ def matches(self, other: TestDescriptor) -> bool:
3535
return True
3636

3737

38-
def _as_complex_tensor(arg):
39-
if (
40-
not isinstance(arg, ComplexTensor)
41-
and isinstance(arg, torch.Tensor)
42-
and arg.dtype in COMPLEX_DTYPES
43-
):
44-
return ComplexTensor.from_interleaved(arg)
45-
return arg
46-
47-
4838
class TestCase(PytorchTestCase):
4939
def assertSameResult(
5040
self,

0 commit comments

Comments
 (0)