@@ -341,47 +341,58 @@ def __init__(
        )
        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

-    def test_implements_and_torch_function_together(self):
-        """Ensure a function decorated with both @_implements and @_implements_torch_function works."""
-
-        implements = TorchAOBaseTensor.implements
-        implements_torch_function = TorchAOBaseTensor.implements_torch_function
-
-        @implements([torch.ops.aten.linear.default])
-        @implements_torch_function([F.linear])
-        def fake_linear(f, types, args, kwargs):
-            x, w, b = args
-            return torch.matmul(x, w.T) + (b if b is not None else 0)
-
-        # make sure both got registered on TorchAOBaseTensor
-        self.assertIn(
-            torch.ops.aten.linear.default,
-            TorchAOBaseTensor._ATEN_OP_TABLE[TorchAOBaseTensor],
-        )
-        self.assertIn(F.linear, TorchAOBaseTensor._TORCH_FN_TABLE[TorchAOBaseTensor])
-
-        # check they both point to the same function wrapper
-        aten_wrapper = TorchAOBaseTensor._ATEN_OP_TABLE[TorchAOBaseTensor][
-            torch.ops.aten.linear.default
-        ]
-        torchfn_wrapper = TorchAOBaseTensor._TORCH_FN_TABLE[TorchAOBaseTensor][F.linear]
-
-        # check if they wrap the same underlying function
-        self.assertEqual(aten_wrapper.__wrapped__, fake_linear)
-        self.assertEqual(torchfn_wrapper.__wrapped__, fake_linear)
-
-        # run through the wrapper
-        x = torch.randn(2, 3)
-        w = torch.randn(4, 3)
-        b = torch.randn(4)
-
-        out_aten = aten_wrapper(fake_linear, (TorchAOBaseTensor,), (x, w, b), {})
-        out_torchfn = torchfn_wrapper(fake_linear, (TorchAOBaseTensor,), (x, w, b), {})
-
-        expected = F.linear(x, w, b)
-
-        self.assertTrue(torch.allclose(out_aten, expected, atol=1e-6))
-        self.assertTrue(torch.allclose(out_torchfn, expected, atol=1e-6))
+    def test_implements_and_torch_function_together(self):
+        """Ensure a function decorated with both @_implements and @_implements_torch_function works."""
+        counter = {"calls": 0}
+
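+        # Minimal wrapper subclass; tensor_data_names/tensor_attribute_names let
+        # TorchAOBaseTensor provide its default implementations for this class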
+        class MyTensor(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            tensor_attribute_names = ["attr", "device"]
+
+            def __new__(cls, qdata, attr="attr", device=None):
+                shape = qdata.shape
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
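+                # _make_wrapper_subclass builds the wrapper without allocating
+                # its own storage; the actual data lives in self.qdata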
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+            def __init__(self, qdata, attr="attr", device=None):
+                self.qdata = qdata
+                self.attr = attr
+
+        # Register the same implementation for both the aten op and the torch-level function
+        @MyTensor.implements(torch.ops.aten.linear.default)
+        @MyTensor.implements_torch_function(F.linear)
+        def fake_linear(f, types, args, kwargs):
+            x, w, b = args
+            w_plain = getattr(w, "qdata", w)
+            b_plain = getattr(b, "qdata", b) if b is not None else None
+            counter["calls"] += 1
+            return torch.matmul(x, w_plain.T) + (b_plain if b_plain is not None else 0)
+
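+        # Swap the Linear's weight for a MyTensor so that both the F.linear path
+        # and the aten path dispatch to fake_linear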
+        l = torch.nn.Linear(2, 3)
+        orig_w = l.weight.detach().clone()
+        orig_b = l.bias.detach().clone() if l.bias is not None else None
+
+        p = torch.nn.Parameter(orig_w)
+        p.data = MyTensor(orig_w, "attr", None)
+        l.weight = p
+
+        x = torch.randn(4, 2)
+
+        # module path (F.linear)
+        l(x)
+        self.assertEqual(counter["calls"], 1, "Expected fake_linear to be called once via F.linear")
+
+        # aten path
+        torch.ops.aten.linear.default(x, l.weight, l.bias)
+        self.assertEqual(counter["calls"], 2, "Expected fake_linear to be called again via aten.linear")


if __name__ == "__main__":