44import torch
55from torch ._ops import OpOverloadPacket
66from torch ._refs import is_complex
7- from torch .utils ._pytree import tree_flatten , tree_unflatten
7+ from torch .utils ._python_dispatch import TorchDispatchMode
8+ from torch .utils ._pytree import tree_flatten , tree_map , tree_unflatten
89
910from ..complex_tensor import ComplexTensor
1011
@@ -63,6 +64,11 @@ def register_complex(
6364 """Decorator to register an implementation for some ops in some dispatch tables"""
6465
6566 def inner (func ):
67+ if COMPLEX_OPS_TABLE .get (op , func ) is not func :
68+ raise RuntimeError (
69+ "Attempted to register multiple functions for "
70+ f"{ op ._qualified_op_name .replace ('::' , '.' )} "
71+ )
6672 COMPLEX_OPS_TABLE [op ] = func
6773 return func
6874
@@ -131,7 +137,7 @@ def ordered_impl(*args, **kwargs):
131137
132138def register_binary_nonlinear (op : OpType ) -> Callable :
133139 def impl (lhs : ComplexTensor , rhs : ComplexTensor , * args , ** kwargs ) -> ComplexTensor :
134- a_r , a_i = split_complex_tensor (lhs )
140+ a_r , a_i = split_complex_arg (lhs )
135141 b_r , b_i = split_complex_arg (rhs )
136142 out_dt , (a_r , a_i , b_r , b_i ) = promote_real_cpu_tensors (a_r , a_i , b_r , b_i )
137143 real = op (a_r , b_r , * args , ** kwargs ) - op (a_i , b_i , * args , ** kwargs )
@@ -146,10 +152,19 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTens
146152
147153
148154def register_simple (op : OpType ):
149- def impl (self : ComplexTensor , * args , ** kwargs ) -> ComplexTensor :
155+ def impl (
156+ self : ComplexTensor , * args , dtype : torch .dtype | None = None , ** kwargs
157+ ) -> ComplexTensor :
150158 x , y = split_complex_tensor (self )
159+ if dtype is not None and dtype not in COMPLEX_TO_REAL :
160+ raise RuntimeError ("Non-complex `dtype` specified, please write custom impl." )
161+
162+ if dtype in COMPLEX_TO_REAL :
163+ kwargs ["dtype" ] = COMPLEX_TO_REAL [dtype ]
164+
151165 u = op (x , * args , ** kwargs )
152166 v = op (y , * args , ** kwargs )
167+
153168 u_flat , u_spec = tree_flatten (u )
154169 v_flat , v_spec = tree_flatten (v )
155170 assert u_spec == v_spec
@@ -161,3 +176,37 @@ def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
161176 impl .__qualname__ = func_name
162177
163178 return register_complex (op , impl )
179+
180+
181+ def _as_complex_tensor (arg : torch .Tensor | Any ) -> torch .Tensor | ComplexTensor | Any :
182+ if (
183+ not isinstance (arg , ComplexTensor )
184+ and isinstance (arg , torch .Tensor )
185+ and arg .dtype in COMPLEX_TO_REAL
186+ ):
187+ return ComplexTensor .from_interleaved (arg )
188+ return arg
189+
190+
191+ def _as_interleaved (arg : ComplexTensor | Any ) -> torch .Tensor | Any :
192+ if isinstance (arg , ComplexTensor ):
193+ return arg .as_interleaved ()
194+ return arg
195+
196+
197+ class ComplexDispatchMode (TorchDispatchMode ):
198+ def __init__ (self , _dispatch_key = None , * , _compile = False ):
199+ super ().__init__ (_dispatch_key )
200+ self ._compile = _compile
201+
202+ def __torch_dispatch__ (self , func , types , args = (), kwargs = None ):
203+ if kwargs is None :
204+ kwargs = {}
205+
206+ if compile :
207+ func = torch .compile (func )
208+
209+ args = tree_map (_as_complex_tensor , args )
210+ kwargs = tree_map (_as_complex_tensor , kwargs )
211+
212+ return tree_map (_as_interleaved , func (* args , ** kwargs ))
0 commit comments