
Conversation

Krishn1412

  1. Added two registries, _ATEN_OP_TABLE and _TORCH_FN_TABLE, instead of one.

  2. Split the decorator into two; see the sketch below.
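
A rough sketch of what the split could look like is shown below. The names _ATEN_OP_TABLE, _TORCH_FN_TABLE, implements, and implements_torch_function come from this PR; the bodies are illustrative only, not the actual torchao implementation.

import functools

import torch


class TorchAOBaseTensor(torch.Tensor):
    # one registry per dispatch mechanism, keyed by subclass and then by op/function
    _ATEN_OP_TABLE = {}    # aten ops, consulted from __torch_dispatch__
    _TORCH_FN_TABLE = {}   # torch-level functions, consulted from __torch_function__

    @classmethod
    def implements(cls, aten_ops):
        """Register a handler for one aten op or a list of aten ops."""
        if not isinstance(aten_ops, (list, tuple)):
            aten_ops = [aten_ops]

        def decorator(func):
            for op in aten_ops:
                @functools.wraps(func)
                def wrapper(f, types, args, kwargs):
                    return func(f, types, args, kwargs)

                cls._ATEN_OP_TABLE.setdefault(cls, {})[op] = wrapper
            return func  # return the original function so decorators can stack

        return decorator

    @classmethod
    def implements_torch_function(cls, torch_fns):
        """Register a handler for one torch function or a list of them."""
        if not isinstance(torch_fns, (list, tuple)):
            torch_fns = [torch_fns]

        def decorator(func):
            for fn in torch_fns:
                @functools.wraps(func)
                def wrapper(f, types, args, kwargs):
                    return func(f, types, args, kwargs)

                cls._TORCH_FN_TABLE.setdefault(cls, {})[fn] = wrapper
            return func

        return decorator

In real callsites the decorators are typically bound once per subclass (e.g. implements = MyTensor.implements); the key point for this review is that each decorator registers a wrapper but returns the original function, so both decorators can be stacked on a single handler.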


pytorch-bot bot commented Aug 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2866

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 24, 2025
@andrewor14 andrewor14 requested a review from liangel-02 August 25, 2025 18:58
Contributor

@jerryzh168 jerryzh168 left a comment


Thanks! There are some more things to do, I think:

(1) We'll need to modify all callsites that currently pass a torch function to implements so that they use implements_torch_function instead.
(2) We need to make sure the two decorators compose. For example, the old

@implements([torch.nn.functional.linear, aten.linear.default])

registers both a torch function and an aten op, so the split form

@implements(aten.linear)
@implements_torch_function(F.linear)
def _(...):
    ...

should work. Please add a test in class TestTorchAOBaseTensor(unittest.TestCase) to make sure.
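
To make (1) and (2) concrete, here is a hedged sketch of what a migrated callsite could look like. It assumes the classmethod-style decorators on TorchAOBaseTensor (import path assumed to be torchao.utils) plus the implements_torch_function added in this PR; MyTensor and the handler body are hypothetical and only meant to show the stacked form.

import torch
import torch.nn.functional as F

from torchao.utils import TorchAOBaseTensor

aten = torch.ops.aten


class MyTensor(TorchAOBaseTensor):
    ...  # hypothetical subclass, details omitted


# callsites usually bind the decorators once per subclass
implements = MyTensor.implements
implements_torch_function = MyTensor.implements_torch_function


# before: @implements([torch.nn.functional.linear, aten.linear.default])
# after:  register the aten op and the torch function separately,
#         stacked on the same handler
@implements(aten.linear.default)
@implements_torch_function(F.linear)
def _(func, types, args, kwargs):
    x, w, b = args
    return torch.matmul(x, w.T) + (b if b is not None else 0)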



@implements([torch.nn.functional.linear, aten.linear.default])
@implements([aten.linear.default])
Contributor


nit: btw, we don't need a list if it's a single op, i.e. we can do @implements(aten.linear.default)

Comment on lines 137 to 151
self.assertIn(
    torch.ops.aten.linear.default,
    TorchAOBaseTensor._ATEN_OP_TABLE[TorchAOBaseTensor],
)
self.assertIn(F.linear, TorchAOBaseTensor._TORCH_FN_TABLE[TorchAOBaseTensor])

# check they both point to the same function wrapper
aten_wrapper = TorchAOBaseTensor._ATEN_OP_TABLE[TorchAOBaseTensor][
    torch.ops.aten.linear.default
]
torchfn_wrapper = TorchAOBaseTensor._TORCH_FN_TABLE[TorchAOBaseTensor][F.linear]

# check if they wrap the same underlying function
self.assertEqual(aten_wrapper.__wrapped__, fake_linear)
self.assertEqual(torchfn_wrapper.__wrapped__, fake_linear)
Contributor

@jerryzh168 jerryzh168 Sep 4, 2025


These are testing implementation details; I think it would be better to change them to test the user-facing behavior instead, for example:

class MyTensor(TorchAOBaseTensor):
    ...
    # something like this https://github.com/pytorch/ao/blob/87769675a3e12209c4c30cd4e8563de7099d9d21/test/test_utils.py#L239-L252

counter = [0]  # one-element list so fake_linear can update it without a global

def fake_linear(f, types, args, kwargs):
    x, w, b = args
    counter[0] += 1
    return torch.matmul(x, w.T) + (b if b is not None else 0)

linear = torch.nn.Linear(...)
# swap the weight to MyTensor: https://github.com/pytorch/ao/blob/87769675a3e12209c4c30cd4e8563de7099d9d21/test/test_utils.py#L254-L256
linear.weight = MyTensor(...)

# trigger F.linear
linear(*example_inputs)

# check the value of counter to make sure the function was called

# trigger aten.linear
with torch.inference_mode():
    linear(*example_inputs)

# check the value of counter to make sure the function was called
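
For the two "check the value of counter" steps, the checks inside the requested TestTorchAOBaseTensor test could be interleaved with the calls like this (a hedged sketch: it assumes the list-based counter above, and it only asserts that the counter advanced, since the exact count depends on how MyTensor's other ops are handled):

linear(*example_inputs)                  # F.linear path
self.assertTrue(counter[0] > 0)

before = counter[0]
with torch.inference_mode():
    linear(*example_inputs)              # aten.linear path
self.assertTrue(counter[0] > before)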

Contributor

@jerryzh168 jerryzh168 left a comment


looks great, thanks!

@jerryzh168
Contributor

seems like there is a conflict with main, please rebase

1) Added two registries, _ATEN_OP_TABLE and _TORCH_FN_TABLE, instead of one.

2) Split the decorator into two.
Updated the places where the decorators are called.

Added a test for the case where both decorators are stacked on the same function.
@jerryzh168 jerryzh168 requested a review from bdhirsh September 9, 2025 20:29
@jerryzh168 jerryzh168 added the topic: for developers Use this tag if this PR is mainly developer facing label Sep 9, 2025
orig_b = l.bias.detach().clone() if l.bias is not None else None

p = torch.nn.Parameter(orig_w)
p.data = MyTensor(orig_w, "attr", None)
Contributor


nit: p.data is not a recommended API, I think; just initializing the parameter with MyTensor directly would be better:

l.weight = torch.nn.Parameter(MyTensor(orig_w, "attr", None))

@jerryzh168
Contributor

please fix ruff errors as well
