Skip to content

Commit 36e8769

Browse files
committed
Merge branch 'split_implements' of https://github.com/Krishn1412/ao into split_implements
2 parents 115f296 + f38baa5 commit 36e8769

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

test/test_utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,54 @@ def fake_linear(f, types, args, kwargs):
389389
torch.ops.aten.linear.default(x, l.weight, l.bias)
390390
self.assertEqual(counter["calls"], 2, "Expected fake_linear to be called once via aten.linear")
391391

392+
def test_implements_and_torch_function_together(self):
    """Ensure a handler registered via both ``implements`` (aten op) and
    ``implements_torch_function`` (torch-level function) is dispatched on
    BOTH paths: once through ``F.linear`` and once through
    ``torch.ops.aten.linear.default``.
    """
    counter = {"calls": 0}

    class MyTensor(TorchAOBaseTensor):
        # Minimal wrapper subclass: a single plain-tensor payload plus two
        # bookkeeping attributes, as required by TorchAOBaseTensor.
        tensor_data_names = ["qdata"]
        tensor_attribute_names = ["attr", "device"]

        def __new__(cls, qdata, attr="attr", device=None):
            shape = qdata.shape
            if device is None:
                device = qdata.device
            kwargs = {"device": device}
            return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

        def __init__(self, qdata, attr="attr", device=None):
            self.qdata = qdata
            self.attr = attr

    # Register the same implementation for both the aten op and the
    # torch-level function, so one handler serves both dispatch paths.
    @MyTensor.implements(torch.ops.aten.linear.default)
    @MyTensor.implements_torch_function(F.linear)
    def fake_linear(f, types, args, kwargs):
        x, w, b = args
        # Unwrap MyTensor arguments to their plain-tensor payloads.
        w_plain = getattr(w, "qdata", w)
        b_plain = getattr(b, "qdata", b) if b is not None else None
        counter["calls"] += 1
        return torch.matmul(x, w_plain.T) + (b_plain if b_plain is not None else 0)

    l = torch.nn.Linear(2, 3)
    orig_w = l.weight.detach().clone()

    # Swap the module's weight for a MyTensor-wrapped parameter so that
    # linear calls route through the subclass dispatch machinery.
    p = torch.nn.Parameter(orig_w)
    p.data = MyTensor(orig_w, "attr", None)
    l.weight = p

    x = torch.randn(4, 2)

    # torch-function path. Bug fix: the original asserted the call count
    # without ever invoking F.linear, so the handler was never called and
    # the first assertion could not pass.
    F.linear(x, l.weight, l.bias)
    self.assertEqual(
        counter["calls"], 1, "Expected fake_linear to be called once via F.linear"
    )

    # aten path: the same handler should fire a second time.
    torch.ops.aten.linear.default(x, l.weight, l.bias)
    self.assertEqual(
        counter["calls"],
        2,
        "Expected fake_linear to be called a second time via aten.linear",
    )
392440

393441
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments (0)