@@ -121,47 +121,53 @@ def __init__(self, qdata, attr, device=None):
121
121
# after copy_, the tensor values should match
122
122
self .assertEqual (lp_tensor .qdata [0 ], another_lp_tensor .qdata [0 ])
123
123
124
- def test_implements_and_torch_function_together (self ):
125
- """Ensure a function decorated with both @_implements and @_implements_torch_function works."""
126
-
127
- implements = TorchAOBaseTensor .implements
128
- implements_torch_function = TorchAOBaseTensor .implements_torch_function
129
-
130
- @implements ([torch .ops .aten .linear .default ])
131
- @implements_torch_function ([F .linear ])
132
- def fake_linear (f , types , args , kwargs ):
133
- x , w , b = args
134
- return torch .matmul (x , w .T ) + (b if b is not None else 0 )
135
-
136
- # make sure both got registered on TorchAOBaseTensor
137
- self .assertIn (
138
- torch .ops .aten .linear .default ,
139
- TorchAOBaseTensor ._ATEN_OP_TABLE [TorchAOBaseTensor ],
140
- )
141
- self .assertIn (F .linear , TorchAOBaseTensor ._TORCH_FN_TABLE [TorchAOBaseTensor ])
142
-
143
- # check they both point to the same function wrapper
144
- aten_wrapper = TorchAOBaseTensor ._ATEN_OP_TABLE [TorchAOBaseTensor ][
145
- torch .ops .aten .linear .default
146
- ]
147
- torchfn_wrapper = TorchAOBaseTensor ._TORCH_FN_TABLE [TorchAOBaseTensor ][F .linear ]
148
-
149
- # check if they wrap the same underlying function
150
- self .assertEqual (aten_wrapper .__wrapped__ , fake_linear )
151
- self .assertEqual (torchfn_wrapper .__wrapped__ , fake_linear )
152
-
153
- # run through the wrapper
154
- x = torch .randn (2 , 3 )
155
- w = torch .randn (4 , 3 )
156
- b = torch .randn (4 )
157
-
158
- out_aten = aten_wrapper (fake_linear , (TorchAOBaseTensor ,), (x , w , b ), {})
159
- out_torchfn = torchfn_wrapper (fake_linear , (TorchAOBaseTensor ,), (x , w , b ), {})
160
-
161
- expected = F .linear (x , w , b )
162
-
163
- self .assertTrue (torch .allclose (out_aten , expected , atol = 1e-6 ))
164
- self .assertTrue (torch .allclose (out_torchfn , expected , atol = 1e-6 ))
124
+ def test_implements_and_torch_function_together (self ):
125
+ """Ensure a function decorated with both @_implements and @_implements_torch_function works.
126
+ """
127
+ counter = {"calls" : 0 }
128
+
129
+ class MyTensor (TorchAOBaseTensor ):
130
+ tensor_data_names = ["qdata" ]
131
+ tensor_attribute_names = ["attr" , "device" ]
132
+
133
+ def __new__ (cls , qdata , attr = "attr" , device = None ):
134
+ shape = qdata .shape
135
+ if device is None :
136
+ device = qdata .device
137
+ kwargs = {"device" : device }
138
+ return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs )
139
+
140
+ def __init__ (self , qdata , attr = "attr" , device = None ):
141
+ self .qdata = qdata
142
+ self .attr = attr
143
+
144
+ # Register the same implementation for both aten and torch-level function
145
+ @MyTensor .implements (torch .ops .aten .linear .default )
146
+ @MyTensor .implements_torch_function (F .linear )
147
+ def fake_linear (f , types , args , kwargs ):
148
+ x , w , b = args
149
+ w_plain = getattr (w , "qdata" , w )
150
+ b_plain = getattr (b , "qdata" , b ) if b is not None else None
151
+ counter ["calls" ] += 1
152
+ return torch .matmul (x , w_plain .T ) + (b_plain if b_plain is not None else 0 )
153
+
154
+
155
+ l = torch .nn .Linear (2 , 3 )
156
+ orig_w = l .weight .detach ().clone ()
157
+ orig_b = l .bias .detach ().clone () if l .bias is not None else None
158
+
159
+ p = torch .nn .Parameter (orig_w )
160
+ p .data = MyTensor (orig_w , "attr" , None )
161
+ l .weight = p
162
+
163
+ x = torch .randn (4 , 2 )
164
+
165
+ # module path (F.linear)
166
+ self .assertEqual (counter ["calls" ], 1 , "Expected fake_linear to be called once via F.linear" )
167
+
168
+ # aten path
169
+ torch .ops .aten .linear .default (x , l .weight , l .bias )
170
+ self .assertEqual (counter ["calls" ], 2 , "Expected fake_linear to be called once via aten.linear" )
165
171
166
172
167
173
if __name__ == "__main__" :
0 commit comments