Skip to content

Commit 1cb3327

Browse files
committed
Refactor out core into aten and common to make room for prim.
1 parent 0c76924 commit 1cb3327

File tree

5 files changed

+224
-198
lines changed

5 files changed

+224
-198
lines changed

src/complex_tensor/complex_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def imag(self) -> torch.Tensor:
8080
def __torch_dispatch__(
8181
cls, func, types: tuple[type], args: tuple = (), kwargs: dict | None = None
8282
):
83-
from .ops.core import lookup_complex
83+
from .ops import lookup_complex
8484

8585
kwargs = {} if kwargs is None else kwargs
8686

src/complex_tensor/ops/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
__all__ = [
2+
"aten",
3+
"COMPLEX_OPS_TABLE",
4+
"FORCE_TEST_LIST",
5+
"lookup_complex",
6+
]
7+
8+
from . import aten
9+
from ._common import COMPLEX_OPS_TABLE, FORCE_TEST_LIST, lookup_complex

src/complex_tensor/ops/_common.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from collections.abc import Callable
2+
from typing import Any
3+
4+
import torch
5+
from torch._ops import OpOverloadPacket
6+
from torch._refs import is_complex
7+
from torch.utils._pytree import tree_flatten, tree_unflatten
8+
9+
from ..complex_tensor import ComplexTensor
10+
11+
OpType = OpOverloadPacket
12+
13+
TableType = dict[OpType, Callable]
14+
COMPLEX_OPS_TABLE: TableType = {}
15+
16+
COMPLEX_TO_REAL = {
17+
torch.complex128: torch.float64,
18+
torch.complex64: torch.float32,
19+
torch.complex32: torch.float16,
20+
}
21+
22+
PROMOTE_TYPES_CPU = {
23+
torch.float16: torch.float32,
24+
torch.bfloat16: torch.float32,
25+
}
26+
27+
28+
def promote_real_cpu_tensors(
29+
tensor: torch.Tensor, *tensors: torch.Tensor
30+
) -> tuple[torch.dtype, tuple[torch.Tensor, ...]]:
31+
out_dt = tensor.dtype
32+
for t in tensors:
33+
if isinstance(t, torch.Tensor):
34+
out_dt = torch.promote_types(out_dt, t.dtype)
35+
36+
prom_dt = PROMOTE_TYPES_CPU.get(out_dt)
37+
if (
38+
prom_dt is None
39+
or tensor.device.type != "cpu"
40+
or any(t.device.type != "cpu" for t in tensors if isinstance(t, torch.Tensor))
41+
):
42+
return out_dt, (
43+
tensor.to(out_dt),
44+
*(
45+
t.to(out_dt) if isinstance(t, torch.Tensor) else torch.asarray(t, dtype=out_dt)
46+
for t in tensors
47+
),
48+
)
49+
50+
return out_dt, (
51+
tensor.to(prom_dt),
52+
*(
53+
t.to(prom_dt) if isinstance(t, torch.Tensor) else torch.asarray(t, dtype=prom_dt)
54+
for t in tensors
55+
),
56+
)
57+
58+
59+
def register_complex(
60+
op: OpType,
61+
func_impl: Callable | None = None,
62+
):
63+
"""Decorator to register an implementation for some ops in some dispatch tables"""
64+
65+
def inner(func):
66+
COMPLEX_OPS_TABLE[op] = func
67+
return func
68+
69+
if func_impl is None:
70+
return inner
71+
return inner(func_impl)
72+
73+
74+
FORCE_TEST_LIST: list[OpType] = []
75+
76+
77+
def register_force_test(op: OpType, *args, **kwargs):
78+
FORCE_TEST_LIST.append(op)
79+
return register_complex(op, *args, **kwargs)
80+
81+
82+
def lookup_complex(func, *args, **kwargs):
83+
return COMPLEX_OPS_TABLE.get(func, COMPLEX_OPS_TABLE.get(func.overloadpacket, None))
84+
85+
86+
def split_complex_arg(
87+
arg: torch.Tensor | ComplexTensor | Any,
88+
) -> tuple[torch.Tensor, torch.Tensor] | tuple[Any, Any]:
89+
if isinstance(arg, ComplexTensor):
90+
return split_complex_tensor(arg)
91+
if isinstance(arg, torch.Tensor):
92+
if is_complex(arg):
93+
return arg.real, arg.imag
94+
return arg, torch.zeros_like(arg)
95+
if isinstance(arg, complex):
96+
return arg.real, arg.imag
97+
if isinstance(arg, float | torch.SymFloat):
98+
return arg, 0.0
99+
if isinstance(arg, int | torch.SymInt):
100+
return arg, 0
101+
if isinstance(arg, bool | torch.SymBool):
102+
return arg, False
103+
raise TypeError(f"Expected tensor or number got, {type(arg)}")
104+
105+
106+
def split_complex_tensor(complex_tensor: ComplexTensor) -> tuple[torch.Tensor, torch.Tensor]:
107+
return complex_tensor.re, complex_tensor.im
108+
109+
110+
def complex_to_real_dtype(dtype: torch.dtype) -> torch.dtype:
111+
return COMPLEX_TO_REAL.get(dtype, dtype)
112+
113+
114+
def register_error(op: OpType):
115+
msg = f"`aten.{str(op).split('.', 1)[0]}` not implemented for `{ComplexTensor.__name__}`."
116+
117+
exc_type = ERROR_TYPES.get(op, NotImplementedError)
118+
119+
def ordered_impl(*args, **kwargs):
120+
raise exc_type(msg)
121+
122+
func_name = f"{str(op).split('.', 1)}_impl"
123+
ordered_impl.__name__ = func_name
124+
ordered_impl.__qualname__ = func_name
125+
126+
return register_force_test(op, ordered_impl)
127+
128+
129+
ERROR_TYPES: dict[OpType, type[Exception]] = {}
130+
131+
132+
def register_binary_nonlinear(op: OpType) -> Callable:
133+
def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
134+
a_r, a_i = split_complex_tensor(lhs)
135+
b_r, b_i = split_complex_arg(rhs)
136+
out_dt, (a_r, a_i, b_r, b_i) = promote_real_cpu_tensors(a_r, a_i, b_r, b_i)
137+
real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs)
138+
imag = op(a_r, b_i, *args, **kwargs) + op(a_i, b_r, *args, **kwargs)
139+
return ComplexTensor(real.to(out_dt), imag.to(out_dt))
140+
141+
func_name = f"{str(op).split('.', 1)}_impl"
142+
impl.__name__ = func_name
143+
impl.__qualname__ = func_name
144+
145+
return register_complex(op, impl)
146+
147+
148+
def register_simple(op: OpType):
149+
def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
150+
x, y = split_complex_tensor(self)
151+
u = op(x, *args, **kwargs)
152+
v = op(y, *args, **kwargs)
153+
u_flat, u_spec = tree_flatten(u)
154+
v_flat, v_spec = tree_flatten(v)
155+
assert u_spec == v_spec
156+
out_flat = [ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False)]
157+
return tree_unflatten(out_flat, u_spec)
158+
159+
func_name = f"{str(op).split('.', 1)}_impl"
160+
impl.__name__ = func_name
161+
impl.__qualname__ = func_name
162+
163+
return register_complex(op, impl)

0 commit comments

Comments
 (0)