@@ -389,6 +389,54 @@ def fake_linear(f, types, args, kwargs):
389
389
torch .ops .aten .linear .default (x , l .weight , l .bias )
390
390
self .assertEqual (counter ["calls" ], 2 , "Expected fake_linear to be called once via aten.linear" )
391
391
392
+ def test_implements_and_torch_function_together (self ):
393
+ """Ensure a function decorated with both @_implements and @_implements_torch_function works.
394
+ """
395
+ counter = {"calls" : 0 }
396
+
397
+ class MyTensor (TorchAOBaseTensor ):
398
+ tensor_data_names = ["qdata" ]
399
+ tensor_attribute_names = ["attr" , "device" ]
400
+
401
+ def __new__ (cls , qdata , attr = "attr" , device = None ):
402
+ shape = qdata .shape
403
+ if device is None :
404
+ device = qdata .device
405
+ kwargs = {"device" : device }
406
+ return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs )
407
+
408
+ def __init__ (self , qdata , attr = "attr" , device = None ):
409
+ self .qdata = qdata
410
+ self .attr = attr
411
+
412
+ # Register the same implementation for both aten and torch-level function
413
+ @MyTensor .implements (torch .ops .aten .linear .default )
414
+ @MyTensor .implements_torch_function (F .linear )
415
+ def fake_linear (f , types , args , kwargs ):
416
+ x , w , b = args
417
+ w_plain = getattr (w , "qdata" , w )
418
+ b_plain = getattr (b , "qdata" , b ) if b is not None else None
419
+ counter ["calls" ] += 1
420
+ return torch .matmul (x , w_plain .T ) + (b_plain if b_plain is not None else 0 )
421
+
422
+
423
+ l = torch .nn .Linear (2 , 3 )
424
+ orig_w = l .weight .detach ().clone ()
425
+ orig_b = l .bias .detach ().clone () if l .bias is not None else None
426
+
427
+ p = torch .nn .Parameter (orig_w )
428
+ p .data = MyTensor (orig_w , "attr" , None )
429
+ l .weight = p
430
+
431
+ x = torch .randn (4 , 2 )
432
+
433
+ # module path (F.linear)
434
+ self .assertEqual (counter ["calls" ], 1 , "Expected fake_linear to be called once via F.linear" )
435
+
436
+ # aten path
437
+ torch .ops .aten .linear .default (x , l .weight , l .bias )
438
+ self .assertEqual (counter ["calls" ], 2 , "Expected fake_linear to be called once via aten.linear" )
439
+
392
440
393
441
if __name__ == "__main__" :
394
442
unittest .main ()
0 commit comments