@@ -33,7 +33,7 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTens
         alpha = kwargs.pop("alpha", None)
         if alpha is not None:
             return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs)
-        a_r, a_i = split_complex_tensor(lhs)
+        a_r, a_i = split_complex_arg(lhs)
         b_r, b_i = split_complex_arg(rhs)
         out_dt, (a_r, a_i, b_r, b_i) = promote_real_cpu_tensors(a_r, a_i, b_r, b_i)
         u = op(a_r, b_r, *args, **kwargs)
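The one-line change above routes the left-hand operand through the same normalization already applied to the right-hand one. `split_complex_arg` presumably accepts plain complex `torch.Tensor`s and Python `complex` scalars in addition to `ComplexTensor`, which is what makes mixed-operand binary ops work. A hypothetical sketch of such a helper (the name and call sites are from this diff; the body is an assumption):

def split_complex_arg(arg):
    # ComplexTensor stores its components separately; just unpack them.
    if isinstance(arg, ComplexTensor):
        return split_complex_tensor(arg)
    # A complex-dtype torch.Tensor: view out its real/imaginary parts.
    if isinstance(arg, torch.Tensor):
        return arg.real, arg.imag
    # A Python scalar; int/float also expose .real/.imag (imag == 0).
    if isinstance(arg, (int, float, complex)):
        return torch.asarray(arg.real), torch.asarray(arg.imag)
    raise TypeError(f"Expected a tensor or scalar, got {type(arg)}")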
@@ -78,6 +78,9 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> b
 index_select_impl = register_simple(aten.index_select)
 split_with_sizes_impl = register_simple(aten.split_with_sizes)
 cumsum_impl = register_simple(aten.cumsum)
+detach_impl = register_simple(aten.detach)
+select_impl = register_simple(aten.select)
+squeeze_impl = register_simple(aten.squeeze)
 
 # TODO (hameerabbasi): Not being tested
 copy_impl = register_force_test(aten.copy, register_simple(aten.copy))
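`register_simple` is the natural registration path for `detach`, `select`, and `squeeze`: these ops act independently on the real and imaginary parts, so the wrapper presumably just applies the underlying op to each component and rewraps. A rough sketch of the idea (an assumption; the real helper must also handle ops that return multiple tensors, e.g. `split_with_sizes`):

def register_simple(op):
    def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
        re, im = split_complex_tensor(self)
        # Apply the real-valued op componentwise and rewrap the results.
        return ComplexTensor(op(re, *args, **kwargs), op(im, *args, **kwargs))

    return register_complex(op)(impl)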
@@ -502,3 +505,68 @@ def nonzero_impl(self: ComplexTensor, other: ComplexTensor, *args, **kwargs) ->
 @register_complex(aten.logical_not)
 def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
     return torch.logical_not(elemwise_nonzero(self), *args, **kwargs)
+
+
+@register_complex(aten.view_as_real)
+def view_as_real_impl(self: ComplexTensor) -> torch.Tensor:
+    re, im = split_complex_tensor(self)
+    return torch.stack([re, im], dim=-1)
+
+
+@register_complex(aten.linalg_vector_norm)
+def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
+    return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs)
+
+
+@register_force_test(aten.copy_)
+def copy__impl(self: ComplexTensor, src, *args, **kwargs):
+    self_re, self_im = split_complex_tensor(self)
+    src_re, src_im = split_complex_arg(src)
+
+    ret_re = self_re.copy_(src_re, *args, **kwargs)
+    ret_im = self_im.copy_(src_im, *args, **kwargs)
+
+    return ComplexTensor(ret_re, ret_im)
+
+
+@register_complex(aten.new_zeros)
+def new_zeros_impl(
+    self: ComplexTensor, size, *, dtype=None, **kwargs
+) -> ComplexTensor | torch.Tensor:
+    if dtype is not None and dtype not in COMPLEX_TO_REAL:
+        return self.re.new_zeros(size, dtype=dtype, **kwargs)
+
+    if dtype is not None:
+        dtype = COMPLEX_TO_REAL[dtype]
+
+    re = self.re.new_zeros(size, dtype=dtype, **kwargs)
+    im = self.im.new_zeros(size, dtype=dtype, **kwargs)
+
+    return ComplexTensor(re, im)
+
+
+@register_complex(aten._local_scalar_dense)
+def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex:
+    x, y = split_complex_tensor(self)
+    u = aten._local_scalar_dense(x, *args, **kwargs)
+    v = aten._local_scalar_dense(y, *args, **kwargs)
+    return complex(u, v)
+
+
+@register_complex(aten.allclose)
+def allclose_impl(
+    input: torch.Tensor,
+    other: torch.Tensor,
+    rtol: float = 1e-05,
+    atol: float = 1e-08,
+    equal_nan: bool = False,
+) -> torch.Tensor:
+    return torch.all(torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan))
+
+
+@register_complex(aten.stack)
+def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor:
+    re_im_tuples = [split_complex_arg(self_i) for self_i in self]
+    u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs)
+    v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs)
+    return ComplexTensor(u, v)
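End to end, the new registrations compose. A minimal usage sketch, assuming `ComplexTensor` is a `__torch_dispatch__` tensor subclass so each call below routes to the corresponding impl registered in this file:

z = ComplexTensor(torch.randn(3), torch.randn(3))  # real and imaginary parts stored separately
w = torch.stack([z, z])             # aten.stack -> stack_impl, stacks components independently
assert torch.allclose(w[0], z)      # w[0] goes through select_impl; allclose_impl does the check
r = torch.view_as_real(z)           # -> view_as_real_impl; shape (3, 2), last dim is (re, im)
n = torch.linalg.vector_norm(z)     # -> linalg_vector_norm_impl, a real-valued norm of abs(z)
c = z.new_zeros((2, 2))             # -> new_zeros_impl; stays a ComplexTensor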