Commit ce4729b

Add a bunch of ops trying to get gradients to work.

1 parent 9426c46 · commit ce4729b
File tree

3 files changed: 32 additions & 7 deletions

src/complex_tensor/ops/aten.py

Lines changed: 23 additions & 7 deletions
@@ -81,6 +81,8 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> bool:
 detach_impl = register_simple(aten.detach)
 select_impl = register_simple(aten.select)
 squeeze_impl = register_simple(aten.squeeze)
+zero__impl = register_simple(aten.zero_)
+transpose_impl = register_simple(aten.transpose)

 # TODO (hameerabbasi): Not being tested
 copy_impl = register_force_test(aten.copy, register_simple(aten.copy))
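
The two new registrations reuse the existing register_simple helper, which is not shown in this diff. As a rough, hypothetical reconstruction of the pattern it implies, a "simple" op is one that can be applied to the real and imaginary parts independently — sound for layout-only ops such as transpose, squeeze, select, and zero_, which never mix the two components:

def register_simple_sketch(op):
    # Hypothetical sketch, not the repo's actual implementation: decompose
    # into (re, im), apply the real-valued op to each half separately, and
    # rewrap the results as a ComplexTensor.
    def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
        re, im = split_complex_tensor(self)
        return ComplexTensor(op(re, *args, **kwargs), op(im, *args, **kwargs))

    return register_complex(op)(impl)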
@@ -533,14 +535,14 @@ def copy__impl(self: ComplexTensor, src, *args, **kwargs):
 def new_zeros_impl(
     self: ComplexTensor, size, *, dtype=None, **kwargs
 ) -> ComplexTensor | torch.Tensor:
-    if dtype is not None and torch.dtype(dtype) not in COMPLEX_TO_REAL:
-        return self.re.new_zeros(self, size, dtype=dtype, **kwargs)
+    self_re, self_im = split_complex_tensor(self)
+    if dtype is not None and dtype not in COMPLEX_TO_REAL:
+        return self_re.new_zeros(size, dtype=dtype, **kwargs)

     if dtype is not None:
-        dtype = COMPLEX_TO_REAL[torch.dtype(dtype)]
-
-    re = self.re.new_zeros(size, dtype=dtype, **kwargs)
-    im = self.im.new_zeros(size, dtype=dtype, **kwargs)
+        dtype = COMPLEX_TO_REAL[dtype]
+    re = self_re.new_zeros(size, dtype=dtype, **kwargs)
+    im = self_im.new_zeros(size, dtype=dtype, **kwargs)

     return ComplexTensor(re, im)
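
Two fixes land in new_zeros_impl: torch.dtype(dtype) was invalid (torch.dtype has no public constructor), and the early return passed self as a spurious extra positional argument to new_zeros. For orientation, COMPLEX_TO_REAL is assumed — from how the diff uses it; the real table lives in _common and is not shown here — to map each complex dtype to its component dtype, along these lines:

# Assumed shape of the mapping, inferred from usage in this diff:
COMPLEX_TO_REAL = {
    torch.complex128: torch.float64,
    torch.complex64: torch.float32,
    torch.complex32: torch.float16,
}

# With such a mapping, new_zeros(..., dtype=torch.complex64) allocates two
# float32 zero tensors (re and im) and wraps them in a ComplexTensor.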

@@ -561,7 +563,7 @@ def allclose_impl(
     atol: float = 1e-08,
     equal_nan: bool = False,
 ) -> torch.Tensor:
-    return torch.all(torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan))
+    return torch.all(torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)).item()


 @register_complex(aten.stack)
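
On the allclose change: torch.all(...) produces a 0-d boolean tensor, while eager torch.allclose returns a plain Python bool, so .item() unwraps the result to match (the -> torch.Tensor annotation is arguably stale now). A quick illustration:

import torch

t = torch.all(torch.isclose(torch.ones(3), torch.ones(3)))
print(type(t))         # <class 'torch.Tensor'> (0-d bool tensor)
print(type(t.item()))  # <class 'bool'>, like torch.allclose itself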
@@ -570,3 +572,17 @@ def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor:
     u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs)
     v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs)
     return ComplexTensor(u, v)
+
+
+@register_complex(aten.randn_like)
+def randn_like_impl(self: ComplexTensor, *, dtype=None, **kwargs) -> ComplexTensor | torch.Tensor:
+    if dtype is not None and dtype not in COMPLEX_TO_REAL:
+        return torch.randn_like(self.re, dtype=dtype, **kwargs)
+
+    if dtype is not None:
+        dtype = COMPLEX_TO_REAL[dtype]
+
+    self_re, self_im = split_complex_tensor(self)
+    ret_re = torch.randn_like(self_re, dtype=dtype, **kwargs) / 2
+    ret_im = torch.randn_like(self_im, dtype=dtype, **kwargs) / 2
+    return ComplexTensor(ret_re, ret_im)
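
One thing to watch in randn_like_impl: native torch.randn over a complex dtype draws the real and imaginary parts with variance 1/2 each (standard deviation 1/√2), so the complex samples have unit variance overall. Dividing a standard-normal sample by 2 gives component variance 1/4 instead; if the intent is to match the native distribution, the divisor would be math.sqrt(2). A small empirical check:

import math
import torch

n = torch.randn(1_000_000, dtype=torch.complex64)
print(n.real.var().item())  # ~0.5: native complex randn per component

print((torch.randn(1_000_000) / 2).var().item())             # ~0.25
print((torch.randn(1_000_000) / math.sqrt(2)).var().item())  # ~0.5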

src/complex_tensor/ops/prims.py

Lines changed: 8 additions & 0 deletions
@@ -3,6 +3,7 @@
 from ..complex_tensor import ComplexTensor
 from ._common import (
     complex_to_real_dtype,
+    register_complex,
     register_force_test,
     split_complex_tensor,
 )
@@ -18,3 +19,10 @@ def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor:
     v_out = prims.convert_element_type(v, dtype)

     return ComplexTensor(u_out, v_out)
+
+
+@register_complex(prims.conj_physical)
+@register_complex(prims.conj)
+def conj_physical_impl(self: ComplexTensor) -> ComplexTensor:
+    re, im = split_complex_tensor(self)
+    return ComplexTensor(re, -im)
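
The conjugation decomposition is simply conj(a + bi) = a − bi: keep the real part, negate the imaginary part. A sanity check against eager PyTorch:

import torch

x = torch.tensor([1.0 + 2.0j, 3.0 - 4.0j])

# conj(a + bi) = a - bi: reuse the real component, negate the imaginary one
manual = torch.complex(x.real, -x.imag)
assert torch.equal(manual, torch.conj_physical(x))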

src/complex_tensor/test/test_ops.py

Lines changed: 1 addition & 0 deletions
@@ -118,6 +118,7 @@ def test_fn_grad(self, device, dtype, op: OpInfo) -> None:
             self.skipTest("Skipped! Dtype is not in supported backward dtypes!")
         else:
             with ComplexDispatchMode():
+                op.gradcheck_fast_mode = False
                 self._grad_test_helper(device, dtype, op, op.get_op())
