Split implements and implements_torch_function #2866
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2866
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
thanks! there are some more things to do I think.
(1) we'll need to modify all callsites that are using implements on some torch function to use implements_torch_function instead
(2) need to make sure these two things can compose, e.g.
@implements(aten.linear)
@implements_torch_function(F.linear)
def _(...):
    ...
should work, please add a test in test/test_utils.py (Line 89 in f03a737):
class TestTorchAOBaseTensor(unittest.TestCase):
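To make point (2) concrete, here is a minimal sketch of the composed registration being asked for. The decorator names follow this PR's intent, but MyTensor, the handler signature, and the qdata attribute are assumptions for illustration, not the actual torchao code:

import torch
import torch.nn.functional as F
from torchao.utils import TorchAOBaseTensor

aten = torch.ops.aten

class MyTensor(TorchAOBaseTensor):
    ...  # subclass details elided

# one handler serves both dispatch paths: F.linear via __torch_function__
# and aten.linear.default via __torch_dispatch__
@MyTensor.implements(aten.linear.default)
@MyTensor.implements_torch_function(F.linear)
def _(func, types, args, kwargs):
    x, w, b = args
    w = w.qdata if isinstance(w, MyTensor) else w  # unwrap (attribute name assumed)
    return torch.matmul(x, w.T) + (b if b is not None else 0)

Stacking like this only works if each decorator returns a callable (the function it received, or a wrapper of it), so that the outer decorator has something to register.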
@implements([torch.nn.functional.linear, aten.linear.default])
@implements([aten.linear.default])
nit: btw, we don't need a list if it's a single op, i.e. we can do @implements(aten.linear.default)
test/test_utils.py
Outdated
self.assertIn(
    torch.ops.aten.linear.default,
    TorchAOBaseTensor._ATEN_OP_TABLE[TorchAOBaseTensor],
)
self.assertIn(F.linear, TorchAOBaseTensor._TORCH_FN_TABLE[TorchAOBaseTensor])

# check they both point to the same function wrapper
aten_wrapper = TorchAOBaseTensor._ATEN_OP_TABLE[TorchAOBaseTensor][
    torch.ops.aten.linear.default
]
torchfn_wrapper = TorchAOBaseTensor._TORCH_FN_TABLE[TorchAOBaseTensor][F.linear]

# check if they wrap the same underlying function
self.assertEqual(aten_wrapper.__wrapped__, fake_linear)
self.assertEqual(torchfn_wrapper.__wrapped__, fake_linear)
these are testing implementation details, I think it might be better to change these to test the user code instead, for example:

class MyTensor(TorchAOBaseTensor):
    ...

# something like this https://github.com/pytorch/ao/blob/87769675a3e12209c4c30cd4e8563de7099d9d21/test/test_utils.py#L239-L252
counter = 0

def fake_linear(f, types, args, kwargs):
    global counter  # counter must be declared global (or made mutable) to update it here
    x, w, b = args
    counter += 1
    return torch.matmul(x, w.T) + (b if b is not None else 0)

linear = torch.nn.Linear(...)
# swap the weight to MyTensor: https://github.com/pytorch/ao/blob/87769675a3e12209c4c30cd4e8563de7099d9d21/test/test_utils.py#L254-L256
linear.weight = torch.nn.Parameter(MyTensor(...))

# trigger F.linear
linear(*example_inputs)
# check the value of counter to make sure the function was called

# trigger aten.linear
with torch.inference_mode():
    linear(*example_inputs)
# check the value of counter to make sure the function was called
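A fuller version of that sketch might look like the following. The MyTensor definition is adapted from the linked test/test_utils.py example; details like tensor_data_names, the constructor signature, and the list-based counter are assumptions for illustration, not the actual test code:

import unittest
import torch
import torch.nn.functional as F
from torchao.utils import TorchAOBaseTensor

aten = torch.ops.aten

class MyTensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata"]                # assumed, mirroring the linked example
    tensor_attribute_names = ["attr", "device"]

    def __new__(cls, qdata, attr, device=None):
        kwargs = {"device": device if device is not None else qdata.device}
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs)

    def __init__(self, qdata, attr, device=None):
        self.qdata = qdata
        self.attr = attr

counter = [0]  # a list, so the handler can mutate it without a global declaration

@MyTensor.implements(aten.linear.default)
@MyTensor.implements_torch_function(F.linear)
def fake_linear(func, types, args, kwargs):
    x, w, b = args
    counter[0] += 1
    w = w.qdata if isinstance(w, MyTensor) else w  # unwrap to a plain tensor
    return torch.matmul(x, w.T) + (b if b is not None else 0)

class TestBothPaths(unittest.TestCase):
    def test_fake_linear_called_on_both_paths(self):
        l = torch.nn.Linear(4, 8)
        l.weight = torch.nn.Parameter(MyTensor(l.weight.detach(), "attr", None))
        x = torch.randn(2, 4)

        l(x)  # an eager call goes through F.linear
        self.assertEqual(counter[0], 1)

        with torch.inference_mode():  # per the review, this exercises the aten.linear path
            l(x)
        self.assertEqual(counter[0], 2)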
looks great, thanks!
seems like there is a conflict with main, please rebase
1) Added two registries, _ATEN_OP_TABLE and _TORCH_FN_TABLE, instead of one. 2) Split the decorator into two.
force-pushed from 36e8769 to 115f296
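For context, a rough sketch of the two-registry design the commit describes. This is illustrative only: the class name is deliberately different, and the real torchao implementation stores functools-wrapped handlers (hence the __wrapped__ checks in the earlier test):

import torch

class TorchAOBaseTensorSketch(torch.Tensor):
    # one registry per dispatch mechanism, keyed by subclass, then by op/function
    _ATEN_OP_TABLE = {}   # {subclass: {aten overload: handler}}
    _TORCH_FN_TABLE = {}  # {subclass: {torch-level function: handler}}

    @classmethod
    def implements(cls, aten_ops):
        if not isinstance(aten_ops, (list, tuple)):
            aten_ops = [aten_ops]  # a single op does not need a list (see nit above)
        def decorator(func):
            for op in aten_ops:
                cls._ATEN_OP_TABLE.setdefault(cls, {})[op] = func
            return func  # return the handler so the two decorators can stack
        return decorator

    @classmethod
    def implements_torch_function(cls, torch_fns):
        if not isinstance(torch_fns, (list, tuple)):
            torch_fns = [torch_fns]
        def decorator(func):
            for fn in torch_fns:
                cls._TORCH_FN_TABLE.setdefault(cls, {})[fn] = func
            return func
        return decorator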
test/test_utils.py
Outdated
orig_b = l.bias.detach().clone() if l.bias is not None else None

p = torch.nn.Parameter(orig_w)
p.data = MyTensor(orig_w, "attr", None)
nit: p.data is not a recommended API I think, just initializing the parameter with MyTensor would be better:
l.weight = torch.nn.Parameter(MyTensor(orig_w, "attr", None))
please fix ruff errors as well
Added two registries _ATEN_OP_TABLE and _TORCH_FN_TABLE instead of one.
Split the decorator into two.