@@ -81,6 +81,8 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b
 detach_impl = register_simple(aten.detach)
 select_impl = register_simple(aten.select)
 squeeze_impl = register_simple(aten.squeeze)
+zero__impl = register_simple(aten.zero_)
+transpose_impl = register_simple(aten.transpose)

 # TODO (hameerabbasi): Not being tested
 copy_impl = register_force_test(aten.copy, register_simple(aten.copy))
@@ -533,14 +535,14 @@ def copy__impl(self: ComplexTensor, src, *args, **kwargs):
 def new_zeros_impl(
     self: ComplexTensor, size, *, dtype=None, **kwargs
 ) -> ComplexTensor | torch.Tensor:
-    if dtype is not None and torch.dtype(dtype) not in COMPLEX_TO_REAL:
-        return self.re.new_zeros(self, size, dtype=dtype, **kwargs)
+    self_re, self_im = split_complex_tensor(self)
+    if dtype is not None and dtype not in COMPLEX_TO_REAL:
+        return self_re.new_zeros(size, dtype=dtype, **kwargs)

     if dtype is not None:
-        dtype = COMPLEX_TO_REAL[torch.dtype(dtype)]
-
-    re = self.re.new_zeros(size, dtype=dtype, **kwargs)
-    im = self.im.new_zeros(size, dtype=dtype, **kwargs)
+        dtype = COMPLEX_TO_REAL[dtype]
+    re = self_re.new_zeros(size, dtype=dtype, **kwargs)
+    im = self_im.new_zeros(size, dtype=dtype, **kwargs)

     return ComplexTensor(re, im)

@@ -561,7 +563,7 @@ def allclose_impl(
     atol: float = 1e-08,
     equal_nan: bool = False,
 ) -> torch.Tensor:
-    return torch.all(torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan))
+    return torch.all(torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)).item()


 @register_complex(aten.stack)
@@ -570,3 +572,17 @@ def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor:
     u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs)
     v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs)
     return ComplexTensor(u, v)
+
+
+@register_complex(aten.randn_like)
+def randn_like_impl(self: ComplexTensor, *, dtype=None, **kwargs) -> ComplexTensor | torch.Tensor:
+    if dtype is not None and dtype not in COMPLEX_TO_REAL:
+        return torch.randn_like(self.re, dtype=dtype, **kwargs)
+
+    if dtype is not None:
+        dtype = COMPLEX_TO_REAL[dtype]
+
+    self_re, self_im = split_complex_tensor(self)
+    ret_re = torch.randn_like(self_re, dtype=dtype, **kwargs) / 2
+    ret_im = torch.randn_like(self_im, dtype=dtype, **kwargs) / 2
+    return ComplexTensor(ret_re, ret_im)
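
Usage sketch (not part of the diff): a rough illustration of how the newly registered ops might be exercised, assuming ComplexTensor is a tensor subclass that dispatches these aten ops and stores real/imaginary parts as in the implementations above; the import path and constructor signature below are assumptions, not taken from this PR.

import torch

from complex_tensor import ComplexTensor  # hypothetical import path

z = ComplexTensor(torch.ones(2, 3), torch.full((2, 3), 2.0))

# aten.transpose now routes through register_simple, so both parts are transposed.
zt = z.transpose(0, 1)  # real and imaginary parts each become shape (3, 2)

# new_zeros keeps the split representation unless a non-complex dtype is requested.
zeros_c = z.new_zeros((4,))                        # ComplexTensor of zeros
zeros_r = z.new_zeros((4,), dtype=torch.float64)   # plain real torch.Tensor

# randn_like_impl above draws each component with torch.randn_like and halves it.
noise = torch.randn_like(z)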