# Ops that act independently on the real and imaginary components can be
# registered with the generic elementwise-per-component helper.
detach_impl = register_simple(aten.detach)
select_impl = register_simple(aten.select)
squeeze_impl = register_simple(aten.squeeze)
zero__impl = register_simple(aten.zero_)
transpose_impl = register_simple(aten.transpose)

# TODO (hameerabbasi): Not being tested
copy_impl = register_force_test(aten.copy, register_simple(aten.copy))
def new_zeros_impl(
    self: ComplexTensor, size, *, dtype=None, **kwargs
) -> ComplexTensor | torch.Tensor:
    """`new_zeros` for ComplexTensor.

    When `dtype` is a non-complex dtype (not a key of COMPLEX_TO_REAL) the
    result is a plain real tensor; otherwise both components are created as
    zeros of the matching real dtype and wrapped back into a ComplexTensor.
    """
    re_part, im_part = split_complex_tensor(self)

    # Non-complex target dtype: the caller asked for an ordinary tensor.
    if dtype is not None and dtype not in COMPLEX_TO_REAL:
        return re_part.new_zeros(size, dtype=dtype, **kwargs)

    # Map the complex dtype (if any) to the real dtype of each component;
    # None keeps the components' existing dtype.
    component_dtype = None if dtype is None else COMPLEX_TO_REAL[dtype]

    return ComplexTensor(
        re_part.new_zeros(size, dtype=component_dtype, **kwargs),
        im_part.new_zeros(size, dtype=component_dtype, **kwargs),
    )
@@ -561,7 +563,7 @@ def allclose_impl(
561
563
atol : float = 1e-08 ,
562
564
equal_nan : bool = False ,
563
565
) -> torch .Tensor :
564
- return torch .all (torch .isclose (input , other , rtol = rtol , atol = atol , equal_nan = equal_nan ))
566
+ return torch .all (torch .isclose (input , other , rtol = rtol , atol = atol , equal_nan = equal_nan )). item ()
565
567
566
568
567
569
@register_complex (aten .stack )
@@ -570,3 +572,17 @@ def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor:
570
572
u = torch .stack ([c [0 ] for c in re_im_tuples ], * args , ** kwargs )
571
573
v = torch .stack ([c [1 ] for c in re_im_tuples ], * args , ** kwargs )
572
574
return ComplexTensor (u , v )
575
@register_complex(aten.randn_like)
def randn_like_impl(self: ComplexTensor, *, dtype=None, **kwargs) -> ComplexTensor | torch.Tensor:
    """`randn_like` for ComplexTensor.

    For a non-complex `dtype` (not a key of COMPLEX_TO_REAL) this returns a
    plain real tensor drawn with `torch.randn_like`. Otherwise it samples the
    real and imaginary components independently and wraps them into a
    ComplexTensor.
    """
    if dtype is not None and dtype not in COMPLEX_TO_REAL:
        # Caller asked for a real dtype: delegate entirely to torch.
        return torch.randn_like(self.re, dtype=dtype, **kwargs)

    if dtype is not None:
        dtype = COMPLEX_TO_REAL[dtype]

    self_re, self_im = split_complex_tensor(self)
    # torch.randn on a complex dtype samples re/im each from N(0, 1/2),
    # i.e. std 1/sqrt(2) per component so the total complex variance is 1.
    # The previous `/ 2` gave per-component variance 1/4 (total 1/2);
    # scale by 1/sqrt(2) to match torch's complex-normal convention.
    scale = 2 ** -0.5
    ret_re = torch.randn_like(self_re, dtype=dtype, **kwargs) * scale
    ret_im = torch.randn_like(self_im, dtype=dtype, **kwargs) * scale
    return ComplexTensor(ret_re, ret_im)
0 commit comments