Conversation

hameerabbasi (Collaborator)

No description provided.

@@ -131,7 +132,7 @@ def ordered_impl(*args, **kwargs):

 def register_binary_nonlinear(op: OpType) -> Callable:
     def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
-        a_r, a_i = split_complex_tensor(lhs)
+        a_r, a_i = split_complex_arg(lhs)
hameerabbasi (Collaborator, Author):

This is needed because either of lhs and rhs may be a ComplexTensor, but it isn't guaranteed which one.
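For context, a rough sketch of what such a helper might look like. ComplexTensor and split_complex_tensor are names taken from the diff above; the exact branches here are an assumption, not the PR's implementation:

import torch

def split_complex_arg(arg):
    # Already-wrapped subclass: reuse the existing splitter.
    if isinstance(arg, ComplexTensor):
        return split_complex_tensor(arg)
    # Plain complex torch.Tensor: use its real and imaginary views.
    if isinstance(arg, torch.Tensor) and arg.is_complex():
        return arg.real, arg.imag
    # Python complex scalar.
    if isinstance(arg, complex):
        return arg.real, arg.imag
    # Real tensor or scalar: imaginary part is zero.
    return arg, (torch.zeros_like(arg) if isinstance(arg, torch.Tensor) else 0)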

@@ -33,7 +33,7 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
         alpha = kwargs.pop("alpha", None)
         if alpha is not None:
             return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs)
-        a_r, a_i = split_complex_tensor(lhs)
+        a_r, a_i = split_complex_arg(lhs)
hameerabbasi (Collaborator, Author) commented on Sep 3, 2025:

Same as #11 (comment).

hameerabbasi changed the title from "Add gradient checks" to "Add ops required for gradient checks" on Sep 9, 2025
hameerabbasi requested a review from amjames on September 9, 2025, 07:26
hameerabbasi (Collaborator, Author):

cc @amjames Ready for review -- each commit can be reviewed separately.

benjaminglass1 left a comment:

@hameerabbasi Overall this looks good to me! I had a couple comments on de-duplicating some logic, and I didn't spend a lot of time looking at the actual numerical logic of the implemented functions (I'm trusting the tests to catch issues there).

I assume the stuff you need to push upstream includes the functions that are skipping __torch_dispatch__?

@@ -152,10 +152,17 @@ def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:


 def register_simple(op: OpType):
-    def impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
+    def impl(self: ComplexTensor, *args, dtype=None, **kwargs) -> ComplexTensor:


Could you add a type to this, like the rest of the args?
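For reference, the annotation being asked for could read something like the following (assuming the usual torch.dtype | None spelling; the exact style is up to the author):

def impl(self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs) -> ComplexTensor: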

Comment on lines +157 to +165
if dtype is None:
    u = op(x, *args, **kwargs)
    v = op(y, *args, **kwargs)
elif dtype in COMPLEX_TO_REAL:
    dtype = COMPLEX_TO_REAL[dtype]
    u = op(x, *args, dtype=dtype, **kwargs)
    v = op(y, *args, dtype=dtype, **kwargs)
else:
    raise RuntimeError("Non-complex `dtype` specified, please write custom impl.")


Could this be boiled down a bit by adding dtype to kwargs if the dtype is non-None and in COMPLEX_TO_REAL?
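Something along these lines is what the suggestion amounts to (a sketch only, not the PR's code):

if dtype is not None:
    if dtype not in COMPLEX_TO_REAL:
        raise RuntimeError("Non-complex `dtype` specified, please write custom impl.")
    # Fold the mapped real dtype into kwargs so both calls below stay identical.
    kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
u = op(x, *args, **kwargs)
v = op(y, *args, **kwargs)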

-def full_like_impl(input: ComplexTensor, fill_value: complex, *args, **kwargs) -> ComplexTensor:
+def full_like_impl(
+    input: ComplexTensor, fill_value: complex, *args, **kwargs
+) -> torch.Tensor | ComplexTensor:


@hameerabbasi Doesn't ComplexTensor inherit from torch.Tensor? That would make this redundant.

Comment on lines +533 to +543
if dtype is None:
    ret_r = torch.full_like(input_r, fv_r, *args, **kwargs)
    ret_i = torch.full_like(input_i, fv_i, *args, **kwargs)
    return ComplexTensor(ret_r, ret_i)

if dtype not in COMPLEX_TO_REAL:
    return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs)

dtype = COMPLEX_TO_REAL[dtype]
ret_r = torch.full_like(input_r, fv_r, *args, dtype=dtype, **kwargs)
ret_i = torch.full_like(input_i, fv_i, *args, dtype=dtype, **kwargs)


Similar ask about adding dtype to kwargs here, and in a number of other places.
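Applied to full_like_impl, the same trick might read roughly as follows (a sketch, assuming the surrounding variables are as in the snippet above):

if dtype is not None and dtype not in COMPLEX_TO_REAL:
    # Non-complex dtype requested: fall back to a plain tensor.
    return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs)
if dtype is not None:
    # Fold the mapped real dtype into kwargs so the two calls below stay identical.
    kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
ret_r = torch.full_like(input_r, fv_r, *args, **kwargs)
ret_i = torch.full_like(input_i, fv_i, *args, **kwargs)
return ComplexTensor(ret_r, ret_i)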
