Add ops required for gradient checks #11
Conversation
Force-pushed from ce4729b to 4a0b409.
@@ -131,7 +132,7 @@ def ordered_impl(*args, **kwargs):

 def register_binary_nonlinear(op: OpType) -> Callable:
     def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
-        a_r, a_i = split_complex_tensor(lhs)
+        a_r, a_i = split_complex_arg(lhs)
This is needed as either of `lhs` and `rhs` may be a `ComplexTensor`, but it isn't guaranteed which one.
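For context, a hypothetical sketch of the distinction between the two helpers; only the names `split_complex_tensor` and `split_complex_arg` come from the diff, the bodies below are assumptions about why the more permissive helper is needed here:

```python
import torch

# Minimal stand-in for the repo's ComplexTensor, only so this sketch
# runs on its own; the real class lives in this repository.
class ComplexTensor:
    def __init__(self, re: torch.Tensor, im: torch.Tensor):
        self.re, self.im = re, im

def split_complex_tensor(t: ComplexTensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Assumes the argument is already a ComplexTensor.
    return t.re, t.im

def split_complex_arg(arg) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical body: also accepts plain/complex torch.Tensors and
    # Python scalars, so a binary op works whichever operand happens to
    # be the ComplexTensor.
    if isinstance(arg, ComplexTensor):
        return split_complex_tensor(arg)
    if isinstance(arg, torch.Tensor):
        if arg.is_complex():
            return arg.real, arg.imag
        return arg, torch.zeros_like(arg)
    c = complex(arg)
    return torch.tensor(c.real), torch.tensor(c.imag)
```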
@@ -33,7 +33,7 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
     alpha = kwargs.pop("alpha", None)
     if alpha is not None:
         return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs)
-    a_r, a_i = split_complex_tensor(lhs)
+    a_r, a_i = split_complex_arg(lhs)
Same as #11 (comment).
Force-pushed from eef577e to d91eb5b.
Force-pushed from d91eb5b to ff129e7.
cc @amjames Ready for review -- each commit can be reviewed separately.
@hameerabbasi Overall this looks good to me! I had a couple comments on de-duplicating some logic, and I didn't spend a lot of time looking at the actual numerical logic of the implemented functions (I'm trusting the tests to catch issues there).
I assume the stuff you need to push upstream includes the functions that are skipping `__torch_dispatch__`?
@@ -152,10 +152,17 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:

 def register_simple(op: OpType):
-    def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
+    def impl(self: ComplexTensor, *args, dtype=None, **kwargs) -> ComplexTensor:
Could you add a type to this, like the rest of the args?
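One way the requested annotation could look; this is a sketch assuming `torch.dtype | None` is the intended type for `dtype`:

```python
import torch

def impl(self: "ComplexTensor", *args, dtype: torch.dtype | None = None, **kwargs) -> "ComplexTensor":
    ...
```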
if dtype is None:
    u = op(x, *args, **kwargs)
    v = op(y, *args, **kwargs)
elif dtype in COMPLEX_TO_REAL:
    dtype = COMPLEX_TO_REAL[dtype]
    u = op(x, *args, dtype=dtype, **kwargs)
    v = op(y, *args, dtype=dtype, **kwargs)
else:
    raise RuntimeError("Non-complex `dtype` specified, please write custom impl.")
Could this be boiled down a bit by adding `dtype` to `kwargs` if the dtype is non-None and in `COMPLEX_TO_REAL`?
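A sketch of that suggested de-duplication; the function name `apply_to_parts` is hypothetical, and the `COMPLEX_TO_REAL` mapping is assumed from how the diff uses it:

```python
from typing import Callable

import torch

# Assumed contents: maps each complex dtype to its real counterpart.
COMPLEX_TO_REAL = {
    torch.complex128: torch.float64,
    torch.complex64: torch.float32,
    torch.complex32: torch.float16,
}

def apply_to_parts(op: Callable, x: torch.Tensor, y: torch.Tensor, *args, dtype=None, **kwargs):
    # Fold a complex `dtype` into kwargs once, so the two op calls below
    # are shared by the dtype and no-dtype paths.
    if dtype is not None:
        if dtype not in COMPLEX_TO_REAL:
            raise RuntimeError("Non-complex `dtype` specified, please write custom impl.")
        kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
    u = op(x, *args, **kwargs)
    v = op(y, *args, **kwargs)
    return u, v
```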
-def full_like_impl(input: ComplexTensor, fill_value: complex, *args, **kwargs) -> ComplexTensor:
+def full_like_impl(
+    input: ComplexTensor, fill_value: complex, *args, **kwargs
+) -> torch.Tensor | ComplexTensor:
@hameerabbasi Doesn't `ComplexTensor` inherit from `torch.Tensor`? That would make this redundant.
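A minimal illustration of that point, assuming the subclass relationship holds (the class body here is a stub, not the repo's implementation):

```python
import torch

class ComplexTensor(torch.Tensor):  # assumed subclass relationship
    pass

# If this holds, `torch.Tensor | ComplexTensor` is equivalent to plain
# `torch.Tensor` for type checkers: every ComplexTensor is a torch.Tensor.
assert issubclass(ComplexTensor, torch.Tensor)
```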
if dtype is None:
    ret_r = torch.full_like(input_r, fv_r, *args, **kwargs)
    ret_i = torch.full_like(input_i, fv_i, *args, **kwargs)
    return ComplexTensor(ret_r, ret_i)

if dtype not in COMPLEX_TO_REAL:
    return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs)

dtype = COMPLEX_TO_REAL[dtype]
ret_r = torch.full_like(input_r, fv_r, *args, dtype=dtype, **kwargs)
ret_i = torch.full_like(input_i, fv_i, *args, dtype=dtype, **kwargs)
Similar ask about adding `dtype` to `kwargs` here, and in a number of other places.
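For illustration, the same kwargs-folding applied to the `full_like` case; `full_like_parts` is a hypothetical name, and it returns the real/imaginary pair rather than a `ComplexTensor` so the sketch stays self-contained:

```python
import torch

# Assumed contents: maps each complex dtype to its real counterpart.
COMPLEX_TO_REAL = {
    torch.complex128: torch.float64,
    torch.complex64: torch.float32,
    torch.complex32: torch.float16,
}

def full_like_parts(input_r, input_i, fv_r, fv_i, fill_value, *args, dtype=None, **kwargs):
    # Non-complex dtype requested: fall back to a plain real tensor,
    # matching the early return in the reviewed code.
    if dtype is not None and dtype not in COMPLEX_TO_REAL:
        return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs)
    # Otherwise fold the mapped real dtype into kwargs once, so the two
    # full_like calls below serve both the dtype and no-dtype paths.
    if dtype is not None:
        kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
    ret_r = torch.full_like(input_r, fv_r, *args, **kwargs)
    ret_i = torch.full_like(input_i, fv_i, *args, **kwargs)
    return ret_r, ret_i  # the real impl would wrap these in ComplexTensor
```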