Skip to content

Commit 115f296

Browse files
committed
Updating test to check the user code
1 parent 67f9273 commit 115f296

File tree

1 file changed

+47
-41
lines changed

1 file changed

+47
-41
lines changed

test/test_utils.py

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -341,47 +341,53 @@ def __init__(
341341
)
342342
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
343343

344-
def test_implements_and_torch_function_together(self):
345-
"""Ensure a function decorated with both @_implements and @_implements_torch_function works."""
346-
347-
implements = TorchAOBaseTensor.implements
348-
implements_torch_function = TorchAOBaseTensor.implements_torch_function
349-
350-
@implements([torch.ops.aten.linear.default])
351-
@implements_torch_function([F.linear])
352-
def fake_linear(f, types, args, kwargs):
353-
x, w, b = args
354-
return torch.matmul(x, w.T) + (b if b is not None else 0)
355-
356-
# make sure both got registered on TorchAOBaseTensor
357-
self.assertIn(
358-
torch.ops.aten.linear.default,
359-
TorchAOBaseTensor._ATEN_OP_TABLE[TorchAOBaseTensor],
360-
)
361-
self.assertIn(F.linear, TorchAOBaseTensor._TORCH_FN_TABLE[TorchAOBaseTensor])
362-
363-
# check they both point to the same function wrapper
364-
aten_wrapper = TorchAOBaseTensor._ATEN_OP_TABLE[TorchAOBaseTensor][
365-
torch.ops.aten.linear.default
366-
]
367-
torchfn_wrapper = TorchAOBaseTensor._TORCH_FN_TABLE[TorchAOBaseTensor][F.linear]
368-
369-
# check if they wrap the same underlying function
370-
self.assertEqual(aten_wrapper.__wrapped__, fake_linear)
371-
self.assertEqual(torchfn_wrapper.__wrapped__, fake_linear)
372-
373-
# run through the wrapper
374-
x = torch.randn(2, 3)
375-
w = torch.randn(4, 3)
376-
b = torch.randn(4)
377-
378-
out_aten = aten_wrapper(fake_linear, (TorchAOBaseTensor,), (x, w, b), {})
379-
out_torchfn = torchfn_wrapper(fake_linear, (TorchAOBaseTensor,), (x, w, b), {})
380-
381-
expected = F.linear(x, w, b)
382-
383-
self.assertTrue(torch.allclose(out_aten, expected, atol=1e-6))
384-
self.assertTrue(torch.allclose(out_torchfn, expected, atol=1e-6))
344+
def test_implements_and_torch_function_together(self):
345+
"""Ensure a function decorated with both @_implements and @_implements_torch_function works.
346+
"""
347+
counter = {"calls": 0}
348+
349+
class MyTensor(TorchAOBaseTensor):
350+
tensor_data_names = ["qdata"]
351+
tensor_attribute_names = ["attr", "device"]
352+
353+
def __new__(cls, qdata, attr="attr", device=None):
354+
shape = qdata.shape
355+
if device is None:
356+
device = qdata.device
357+
kwargs = {"device": device}
358+
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
359+
360+
def __init__(self, qdata, attr="attr", device=None):
361+
self.qdata = qdata
362+
self.attr = attr
363+
364+
# Register the same implementation for both aten and torch-level function
365+
@MyTensor.implements(torch.ops.aten.linear.default)
366+
@MyTensor.implements_torch_function(F.linear)
367+
def fake_linear(f, types, args, kwargs):
368+
x, w, b = args
369+
w_plain = getattr(w, "qdata", w)
370+
b_plain = getattr(b, "qdata", b) if b is not None else None
371+
counter["calls"] += 1
372+
return torch.matmul(x, w_plain.T) + (b_plain if b_plain is not None else 0)
373+
374+
375+
l = torch.nn.Linear(2, 3)
376+
orig_w = l.weight.detach().clone()
377+
orig_b = l.bias.detach().clone() if l.bias is not None else None
378+
379+
p = torch.nn.Parameter(orig_w)
380+
p.data = MyTensor(orig_w, "attr", None)
381+
l.weight = p
382+
383+
x = torch.randn(4, 2)
384+
385+
# module path (F.linear)
386+
self.assertEqual(counter["calls"], 1, "Expected fake_linear to be called once via F.linear")
387+
388+
# aten path
389+
torch.ops.aten.linear.default(x, l.weight, l.bias)
390+
self.assertEqual(counter["calls"], 2, "Expected fake_linear to be called once via aten.linear")
385391

386392

387393
if __name__ == "__main__":

0 commit comments

Comments
 (0)