from collections.abc import Callable
from typing import Any

import torch
from torch._ops import OpOverloadPacket
from torch._refs import is_complex
from torch.utils._pytree import tree_flatten, tree_unflatten

from ..complex_tensor import ComplexTensor

OpType = OpOverloadPacket

TableType = dict[OpType, Callable]
COMPLEX_OPS_TABLE: TableType = {}

COMPLEX_TO_REAL = {
    torch.complex128: torch.float64,
    torch.complex64: torch.float32,
    torch.complex32: torch.float16,
}

PROMOTE_TYPES_CPU = {
    torch.float16: torch.float32,
    torch.bfloat16: torch.float32,
}


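# Promote the real-valued components to a common dtype. On CPU, float16/bfloat16
# operands are additionally upcast to float32 (presumably because some kernels
# lack half-precision CPU support); the dtype returned first is the promoted
# dtype *before* any CPU upcast, so callers can cast results back down.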
def promote_real_cpu_tensors(
    tensor: torch.Tensor, *tensors: torch.Tensor
) -> tuple[torch.dtype, tuple[torch.Tensor, ...]]:
    out_dt = tensor.dtype
    for t in tensors:
        if isinstance(t, torch.Tensor):
            out_dt = torch.promote_types(out_dt, t.dtype)

    prom_dt = PROMOTE_TYPES_CPU.get(out_dt)
    if (
        prom_dt is None
        or tensor.device.type != "cpu"
        or any(t.device.type != "cpu" for t in tensors if isinstance(t, torch.Tensor))
    ):
        return out_dt, (
            tensor.to(out_dt),
            *(
                t.to(out_dt) if isinstance(t, torch.Tensor) else torch.asarray(t, dtype=out_dt)
                for t in tensors
            ),
        )

    return out_dt, (
        tensor.to(prom_dt),
        *(
            t.to(prom_dt) if isinstance(t, torch.Tensor) else torch.asarray(t, dtype=prom_dt)
            for t in tensors
        ),
    )


def register_complex(
    op: OpType,
    func_impl: Callable | None = None,
):
    """Register an implementation for `op` in COMPLEX_OPS_TABLE (usable as a decorator)."""

    def inner(func):
        COMPLEX_OPS_TABLE[op] = func
        return func

    if func_impl is None:
        return inner
    return inner(func_impl)

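# Illustrative usage of `register_complex` (a sketch only, not a registration
# made in this file; `aten.neg` is just an example op):
#
#     @register_complex(torch.ops.aten.neg)
#     def neg_impl(self: ComplexTensor) -> ComplexTensor:
#         re, im = split_complex_tensor(self)
#         return ComplexTensor(torch.neg(re), torch.neg(im))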

FORCE_TEST_LIST: list[OpType] = []


def register_force_test(op: OpType, *args, **kwargs):
    FORCE_TEST_LIST.append(op)
    return register_complex(op, *args, **kwargs)


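# `func` is expected to be an OpOverload coming from the dispatcher: look for a
# registration of the exact overload first, then fall back to its overload
# packet; returns None when neither is registered (extra args are ignored).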
def lookup_complex(func, *args, **kwargs):
    return COMPLEX_OPS_TABLE.get(func, COMPLEX_OPS_TABLE.get(func.overloadpacket, None))


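# Normalize any supported argument into a (real, imag) pair: ComplexTensors and
# complex-dtype tensors are split into their components, while real tensors and
# (Sym)numbers are paired with a zero imaginary part of the matching kind.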
def split_complex_arg(
    arg: torch.Tensor | ComplexTensor | Any,
) -> tuple[torch.Tensor, torch.Tensor] | tuple[Any, Any]:
    if isinstance(arg, ComplexTensor):
        return split_complex_tensor(arg)
    if isinstance(arg, torch.Tensor):
        if is_complex(arg):
            return arg.real, arg.imag
        return arg, torch.zeros_like(arg)
    if isinstance(arg, complex):
        return arg.real, arg.imag
    if isinstance(arg, float | torch.SymFloat):
        return arg, 0.0
    if isinstance(arg, int | torch.SymInt):
        return arg, 0
    if isinstance(arg, bool | torch.SymBool):
        return arg, False
    raise TypeError(f"Expected tensor or number, got {type(arg)}")


def split_complex_tensor(complex_tensor: ComplexTensor) -> tuple[torch.Tensor, torch.Tensor]:
    return complex_tensor.re, complex_tensor.im


def complex_to_real_dtype(dtype: torch.dtype) -> torch.dtype:
    return COMPLEX_TO_REAL.get(dtype, dtype)


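# Register an implementation that unconditionally raises for `op`, using the
# exception type from ERROR_TYPES (declared below) or NotImplementedError by
# default. It goes through `register_force_test`, so the op also lands in
# FORCE_TEST_LIST.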
def register_error(op: OpType):
    msg = f"`aten.{str(op).split('.', 1)[1]}` not implemented for `{ComplexTensor.__name__}`."

    exc_type = ERROR_TYPES.get(op, NotImplementedError)

    def ordered_impl(*args, **kwargs):
        raise exc_type(msg)

    func_name = f"{str(op).split('.', 1)[1]}_impl"
    ordered_impl.__name__ = func_name
    ordered_impl.__qualname__ = func_name

    return register_force_test(op, ordered_impl)


ERROR_TYPES: dict[OpType, type[Exception]] = {}


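# Register a binary op via the complex product rule,
#     (a + bi) * (c + di) = (a*c - b*d) + (a*d + b*c)i,
# with `op` standing in for the elementwise product; this is valid for bilinear
# ops such as mul or matmul. Components are promoted with
# promote_real_cpu_tensors and the result is cast back to the promoted dtype.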
def register_binary_nonlinear(op: OpType) -> Callable:
    def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
        a_r, a_i = split_complex_tensor(lhs)
        b_r, b_i = split_complex_arg(rhs)
        out_dt, (a_r, a_i, b_r, b_i) = promote_real_cpu_tensors(a_r, a_i, b_r, b_i)
        real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs)
        imag = op(a_r, b_i, *args, **kwargs) + op(a_i, b_r, *args, **kwargs)
        return ComplexTensor(real.to(out_dt), imag.to(out_dt))

    func_name = f"{str(op).split('.', 1)[1]}_impl"
    impl.__name__ = func_name
    impl.__qualname__ = func_name

    return register_complex(op, impl)


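# Register an op that can simply be applied to the real and imaginary parts
# independently; the per-component outputs are flattened and zipped back into
# ComplexTensors, preserving the original pytree structure.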
def register_simple(op: OpType):
    def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
        x, y = split_complex_tensor(self)
        u = op(x, *args, **kwargs)
        v = op(y, *args, **kwargs)
        u_flat, u_spec = tree_flatten(u)
        v_flat, v_spec = tree_flatten(v)
        assert u_spec == v_spec
        out_flat = [ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False)]
        return tree_unflatten(out_flat, u_spec)

    func_name = f"{str(op).split('.', 1)[1]}_impl"
    impl.__name__ = func_name
    impl.__qualname__ = func_name

    return register_complex(op, impl)