Add ops required for gradient checks #11
Conversation
Force-pushed from ce4729b to 4a0b409.
  def register_binary_nonlinear(op: OpType) -> Callable:
      def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
-         a_r, a_i = split_complex_tensor(lhs)
+         a_r, a_i = split_complex_arg(lhs)
This is needed because either lhs or rhs may be a ComplexTensor, but it isn't guaranteed which one.
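For readers without the full diff, a minimal sketch of what such a helper could look like (hypothetical code, not necessarily the implementation in this PR; `ComplexTensor` and `split_complex_tensor` are the repo's existing names and are assumed to be in scope):

```python
import torch

def split_complex_arg(arg):
    # Accept anything that can appear on either side of a binary op and
    # return its (real, imaginary) parts.
    if isinstance(arg, ComplexTensor):
        return split_complex_tensor(arg)       # the wrapped subclass
    if isinstance(arg, torch.Tensor) and arg.is_complex():
        return arg.real, arg.imag              # native complex tensor
    if isinstance(arg, torch.Tensor):
        return arg, torch.zeros_like(arg)      # real tensor: imaginary part is 0
    if isinstance(arg, complex):
        return arg.real, arg.imag              # Python complex scalar
    if isinstance(arg, (int, float)):
        return arg, 0.0                        # real scalar
    raise TypeError(f"cannot split {type(arg)} into real and imaginary parts")
```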
      if alpha is not None:
          return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs)
-     a_r, a_i = split_complex_tensor(lhs)
+     a_r, a_i = split_complex_arg(lhs)
Same as #11 (comment).
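For context, a rough sketch of the alpha path being discussed (hypothetical names such as `add_impl`; assumes the same split helper, the usual `lhs + alpha * rhs` semantics, and a `ComplexTensor(real, imag)` constructor):

```python
def add_impl(lhs, rhs, *args, alpha=None):
    # Hypothetical sketch: decompose lhs + alpha * rhs into real/imag parts.
    a_r, a_i = split_complex_arg(lhs)
    b_r, b_i = split_complex_arg(rhs)
    if alpha is not None:
        s_r, s_i = split_complex_arg(alpha)    # alpha may itself be complex
        b_r, b_i = s_r * b_r - s_i * b_i, s_r * b_i + s_i * b_r
    return ComplexTensor(a_r + b_r, a_i + b_i)
```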
Force-pushed from eef577e to d91eb5b.
Force-pushed from d91eb5b to ff129e7.
cc @amjames Ready for review -- each commit can be reviewed separately.
@hameerabbasi Overall this looks good to me! I had a couple comments on de-duplicating some logic, and I didn't spend a lot of time looking at the actual numerical logic of the implemented functions (I'm trusting the tests to catch issues there).
I assume the stuff you need to push upstream includes the functions that are skipping __torch_dispatch__?
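For anyone following along, the kind of gradient check these ops are meant to enable looks roughly like this (sketch only; in the repo the inputs would presumably be wrapped in ComplexTensor first):

```python
import torch

# torch.autograd.gradcheck compares analytical gradients against numerical
# (finite-difference) ones; it needs double precision and requires_grad inputs.
a = torch.randn(3, 4, dtype=torch.complex128, requires_grad=True)
b = torch.randn(3, 4, dtype=torch.complex128, requires_grad=True)
assert torch.autograd.gradcheck(torch.mul, (a, b))
```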