Commit 8698ab8

Merge pull request #11 from hameerabbasi/gradient-checks
Add ops required for gradient checks
2 parents: 45dffd6 + def0693 · commit 8698ab8

8 files changed (+590 -156 lines)


src/complex_tensor/ops/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,9 +1,11 @@
 __all__ = [
     "aten",
+    "prims",
+    "_c10d_functional",
     "COMPLEX_OPS_TABLE",
     "FORCE_TEST_LIST",
     "lookup_complex",
 ]

-from . import aten
+from . import _c10d_functional, aten, prims
 from ._common import COMPLEX_OPS_TABLE, FORCE_TEST_LIST, lookup_complex
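
With these re-exports in place, the new submodules sit alongside the existing lookup helpers. A quick sanity check of the public surface, assuming the package is importable as complex_tensor (a sketch, not part of this commit):

# Hedged sketch: exercises only names listed in __all__ above.
from complex_tensor.ops import (
    COMPLEX_OPS_TABLE,
    FORCE_TEST_LIST,
    _c10d_functional,
    aten,
    lookup_complex,
    prims,
)

print(len(COMPLEX_OPS_TABLE), len(FORCE_TEST_LIST))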
src/complex_tensor/ops/_c10d_functional.py (new file)

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import torch
+
+from ._common import (
+    register_force_test,
+    register_simple,
+)
+
+_c10d_functional = torch.ops._c10d_functional
+
+# TODO (hameerabbasi): Not being tested
+broadcast_impl = register_force_test(
+    _c10d_functional.broadcast, register_simple(_c10d_functional.broadcast)
+)
+
+# TODO (hameerabbasi): Not being tested
+broadcast__impl = register_force_test(
+    _c10d_functional.broadcast_, register_simple(_c10d_functional.broadcast_)
+)
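
The same force-test registration pattern would extend to other collectives. A hypothetical sketch, assuming torch.ops._c10d_functional exposes all_gather_into_tensor and that a per-component decomposition is valid for it:

# Hypothetical extension of the pattern above; not part of this commit.
# Assumes the op exists on the packet and that applying it to the real and
# imaginary components independently is appropriate.
all_gather_into_tensor_impl = register_force_test(
    _c10d_functional.all_gather_into_tensor,
    register_simple(_c10d_functional.all_gather_into_tensor),
)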

src/complex_tensor/ops/_common.py

Lines changed: 52 additions & 3 deletions
@@ -4,7 +4,8 @@
 import torch
 from torch._ops import OpOverloadPacket
 from torch._refs import is_complex
-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

 from ..complex_tensor import ComplexTensor

@@ -63,6 +64,11 @@ def register_complex(
     """Decorator to register an implementation for some ops in some dispatch tables"""

     def inner(func):
+        if COMPLEX_OPS_TABLE.get(op, func) is not func:
+            raise RuntimeError(
+                "Attempted to register multiple functions for "
+                f"{op._qualified_op_name.replace('::', '.')}"
+            )
         COMPLEX_OPS_TABLE[op] = func
         return func

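The new guard makes double registration fail loudly instead of silently overwriting an existing entry. A rough illustration of the behavior; the op choice is hypothetical and this assumes register_complex is importable from complex_tensor.ops._common and that the op has no prior registration in the session:

import torch
from complex_tensor.ops._common import register_complex

def angle_impl(self):  # placeholder body, for illustration only
    ...

def angle_impl_2(self):  # a second, different implementation
    ...

register_complex(torch.ops.aten.angle, angle_impl)

# Registering a *different* function for the same op now raises:
#   RuntimeError: Attempted to register multiple functions for aten.angle
register_complex(torch.ops.aten.angle, angle_impl_2)

Re-registering the identical function object is still allowed, since the guard only rejects a conflicting entry.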
@@ -131,7 +137,7 @@ def ordered_impl(*args, **kwargs):

 def register_binary_nonlinear(op: OpType) -> Callable:
     def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
-        a_r, a_i = split_complex_tensor(lhs)
+        a_r, a_i = split_complex_arg(lhs)
         b_r, b_i = split_complex_arg(rhs)
         out_dt, (a_r, a_i, b_r, b_i) = promote_real_cpu_tensors(a_r, a_i, b_r, b_i)
         real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs)
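
The `real` expression follows the standard complex product (a_r + i*a_i)(b_r + i*b_i) = (a_r*b_r - a_i*b_i) + i*(a_r*b_i + a_i*b_r), with op as the underlying binary op (e.g. multiplication). A self-contained check of that identity against PyTorch's native complex multiply:

import torch

a_r, a_i = torch.randn(3), torch.randn(3)
b_r, b_i = torch.randn(3), torch.randn(3)

real = a_r * b_r - a_i * b_i  # the `real` expression above, with op = multiply
imag = a_r * b_i + a_i * b_r  # the corresponding imaginary part

expected = torch.complex(a_r, a_i) * torch.complex(b_r, b_i)
torch.testing.assert_close(torch.complex(real, imag), expected)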
@@ -146,10 +152,19 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTens


 def register_simple(op: OpType):
-    def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
+    def impl(
+        self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs
+    ) -> ComplexTensor:
         x, y = split_complex_tensor(self)
+        if dtype is not None and dtype not in COMPLEX_TO_REAL:
+            raise RuntimeError("Non-complex `dtype` specified, please write custom impl.")
+
+        if dtype in COMPLEX_TO_REAL:
+            kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
+
         u = op(x, *args, **kwargs)
         v = op(y, *args, **kwargs)
+
         u_flat, u_spec = tree_flatten(u)
         v_flat, v_spec = tree_flatten(v)
         assert u_spec == v_spec
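
The dtype handling forwards a complex dtype request to each real component. COMPLEX_TO_REAL is not shown in this hunk; it is presumably a mapping from complex dtypes to their real counterparts, roughly as in this sketch:

# Assumed shape of COMPLEX_TO_REAL (not shown in this diff): complex dtypes
# mapped to the dtype of their real/imaginary components.
import torch

COMPLEX_TO_REAL = {
    torch.complex128: torch.float64,
    torch.complex64: torch.float32,
}

# With dtype=torch.complex64, each per-component op receives dtype=torch.float32:
x = torch.randn(4)
u = x.sum(dtype=COMPLEX_TO_REAL[torch.complex64])
print(u.dtype)  # torch.float32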
@@ -161,3 +176,37 @@ def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
     impl.__qualname__ = func_name

     return register_complex(op, impl)
+
+
+def _as_complex_tensor(arg: torch.Tensor | Any) -> torch.Tensor | ComplexTensor | Any:
+    if (
+        not isinstance(arg, ComplexTensor)
+        and isinstance(arg, torch.Tensor)
+        and arg.dtype in COMPLEX_TO_REAL
+    ):
+        return ComplexTensor.from_interleaved(arg)
+    return arg
+
+
+def _as_interleaved(arg: ComplexTensor | Any) -> torch.Tensor | Any:
+    if isinstance(arg, ComplexTensor):
+        return arg.as_interleaved()
+    return arg
+
+
+class ComplexDispatchMode(TorchDispatchMode):
+    def __init__(self, _dispatch_key=None, *, _compile=False):
+        super().__init__(_dispatch_key)
+        self._compile = _compile
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        if self._compile:
+            func = torch.compile(func)
+
+        args = tree_map(_as_complex_tensor, args)
+        kwargs = tree_map(_as_complex_tensor, kwargs)
+
+        return tree_map(_as_interleaved, func(*args, **kwargs))
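
Taken together, _as_complex_tensor, _as_interleaved, and ComplexDispatchMode let native complex tensors be routed through the registered ComplexTensor implementations transparently. A hedged usage sketch, assuming ComplexDispatchMode is importable from complex_tensor.ops._common and that the op hit here (aten.mul) is among the registered ones:

import torch
from complex_tensor.ops._common import ComplexDispatchMode

a = torch.randn(3, dtype=torch.complex64)
b = torch.randn(3, dtype=torch.complex64)

with ComplexDispatchMode():
    # Complex inputs are wrapped as ComplexTensor, dispatched through the
    # registered implementations, and the result is converted back to an
    # interleaved (native complex) tensor.
    c = a * b

print(c.dtype)  # torch.complex64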
