
Commit f38baa5

Updating test to check the user code
1 parent 3bd975f commit f38baa5

File tree

1 file changed: +47 -41 lines changed

test/test_utils.py

Lines changed: 47 additions & 41 deletions
@@ -121,47 +121,53 @@ def __init__(self, qdata, attr, device=None):
         # after copy_, the tensor values should match
         self.assertEqual(lp_tensor.qdata[0], another_lp_tensor.qdata[0])
 
-    def test_implements_and_torch_function_together(self):
-        """Ensure a function decorated with both @_implements and @_implements_torch_function works."""
-
-        implements = TorchAOBaseTensor.implements
-        implements_torch_function = TorchAOBaseTensor.implements_torch_function
-
-        @implements([torch.ops.aten.linear.default])
-        @implements_torch_function([F.linear])
-        def fake_linear(f, types, args, kwargs):
-            x, w, b = args
-            return torch.matmul(x, w.T) + (b if b is not None else 0)
-
-        # make sure both got registered on TorchAOBaseTensor
-        self.assertIn(
-            torch.ops.aten.linear.default,
-            TorchAOBaseTensor._ATEN_OP_TABLE[TorchAOBaseTensor],
-        )
-        self.assertIn(F.linear, TorchAOBaseTensor._TORCH_FN_TABLE[TorchAOBaseTensor])
-
-        # check they both point to the same function wrapper
-        aten_wrapper = TorchAOBaseTensor._ATEN_OP_TABLE[TorchAOBaseTensor][
-            torch.ops.aten.linear.default
-        ]
-        torchfn_wrapper = TorchAOBaseTensor._TORCH_FN_TABLE[TorchAOBaseTensor][F.linear]
-
-        # check if they wrap the same underlying function
-        self.assertEqual(aten_wrapper.__wrapped__, fake_linear)
-        self.assertEqual(torchfn_wrapper.__wrapped__, fake_linear)
-
-        # run through the wrapper
-        x = torch.randn(2, 3)
-        w = torch.randn(4, 3)
-        b = torch.randn(4)
-
-        out_aten = aten_wrapper(fake_linear, (TorchAOBaseTensor,), (x, w, b), {})
-        out_torchfn = torchfn_wrapper(fake_linear, (TorchAOBaseTensor,), (x, w, b), {})
-
-        expected = F.linear(x, w, b)
-
-        self.assertTrue(torch.allclose(out_aten, expected, atol=1e-6))
-        self.assertTrue(torch.allclose(out_torchfn, expected, atol=1e-6))
+    def test_implements_and_torch_function_together(self):
+        """Ensure a function decorated with both @_implements and @_implements_torch_function works."""
+        counter = {"calls": 0}
+
+        class MyTensor(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            tensor_attribute_names = ["attr", "device"]
+
+            def __new__(cls, qdata, attr="attr", device=None):
+                shape = qdata.shape
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+            def __init__(self, qdata, attr="attr", device=None):
+                self.qdata = qdata
+                self.attr = attr
+
+        # Register the same implementation for both aten and torch-level function
+        @MyTensor.implements(torch.ops.aten.linear.default)
+        @MyTensor.implements_torch_function(F.linear)
+        def fake_linear(f, types, args, kwargs):
+            x, w, b = args
+            w_plain = getattr(w, "qdata", w)
+            b_plain = getattr(b, "qdata", b) if b is not None else None
+            counter["calls"] += 1
+            return torch.matmul(x, w_plain.T) + (b_plain if b_plain is not None else 0)
+
+
+        l = torch.nn.Linear(2, 3)
+        orig_w = l.weight.detach().clone()
+        orig_b = l.bias.detach().clone() if l.bias is not None else None
+
+        p = torch.nn.Parameter(orig_w)
+        p.data = MyTensor(orig_w, "attr", None)
+        l.weight = p
+
+        x = torch.randn(4, 2)
+
+        # module path (F.linear)
+        l(x)
+        self.assertEqual(counter["calls"], 1, "Expected fake_linear to be called once via F.linear")
+
+        # aten path
+        torch.ops.aten.linear.default(x, l.weight, l.bias)
+        self.assertEqual(counter["calls"], 2, "Expected fake_linear to be called once via aten.linear")
 
 
 if __name__ == "__main__":
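
For background, the previous version of this test poked at the registration tables directly: it checked that stacking the two decorators stores the same function under torch.ops.aten.linear.default in _ATEN_OP_TABLE and under F.linear in _TORCH_FN_TABLE, with each stored wrapper's __wrapped__ pointing back to fake_linear. Below is a minimal, self-contained sketch of that dual-registration pattern; it is an illustration under assumptions, not torchao's actual implementation (the table names are borrowed from the old test, and plain strings stand in for real torch ops):

import functools


class DispatchBase:
    # Per-class registration tables, mirroring the names the old test
    # inspected (_ATEN_OP_TABLE / _TORCH_FN_TABLE).
    _ATEN_OP_TABLE = {}
    _TORCH_FN_TABLE = {}

    @classmethod
    def _register(cls, table, keys):
        keys = keys if isinstance(keys, (list, tuple)) else [keys]

        def decorator(fn):
            @functools.wraps(fn)  # exposes fn as wrapper.__wrapped__
            def wrapper(f, types, args, kwargs):
                return fn(f, types, args, kwargs)

            for key in keys:
                table.setdefault(cls, {})[key] = wrapper
            # Return the undecorated fn so another decorator can stack on it.
            return fn

        return decorator

    @classmethod
    def implements(cls, ops):
        return cls._register(cls._ATEN_OP_TABLE, ops)

    @classmethod
    def implements_torch_function(cls, fns):
        return cls._register(cls._TORCH_FN_TABLE, fns)


# Stacking both decorators registers the same function in both tables.
@DispatchBase.implements("aten::linear")
@DispatchBase.implements_torch_function("F.linear")
def fake_linear(f, types, args, kwargs):
    pass


assert DispatchBase._ATEN_OP_TABLE[DispatchBase]["aten::linear"].__wrapped__ is fake_linear
assert DispatchBase._TORCH_FN_TABLE[DispatchBase]["F.linear"].__wrapped__ is fake_linear

Returning the undecorated function from each decorator is what allows the two decorators to stack on a single definition.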

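The rewritten test instead exercises the user-facing code path, as the commit message says: the weight of a real nn.Linear is swapped for the tensor subclass, so both F.linear and aten.linear.default dispatch to the registered handler. A rough sketch of why the module-level call reaches the handler, using only PyTorch's standard __torch_function__ protocol (the table, class, and handler names here are hypothetical and this is not torchao's dispatch code):

import torch
import torch.nn.functional as F

_TORCH_FN_TABLE = {}  # hypothetical stand-in for a per-class table


class Routed(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Divert to a registered handler when one exists for this function.
        if func in _TORCH_FN_TABLE:
            return _TORCH_FN_TABLE[func](func, types, args, kwargs)
        # Otherwise fall through to the default Tensor behavior.
        return super().__torch_function__(func, types, args, kwargs)


def fake_linear(func, types, args, kwargs):
    x, w, b = args
    return torch.matmul(x, w.T) + (b if b is not None else 0)


_TORCH_FN_TABLE[F.linear] = fake_linear

# Any F.linear call that sees a Routed tensor is diverted to fake_linear,
# which is also what happens when a module whose weight is a subclass runs.
w = torch.randn(3, 2).as_subclass(Routed)
out = F.linear(torch.randn(4, 2), w, torch.randn(3))
assert out.shape == (4, 3)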