@@ -21,69 +21,111 @@ def linear_impl(input, weight, bias=None):
2121
2222
2323@register_tracer_impl (F .conv1d , name = '_bias_addition_impl' )
24- def conv1d_impl (input , weight , ** kwargs ):
25- bias = getattr (kwargs , 'bias' , None )
24+ def conv1d_impl (input , weight , bias = None , stride = _single (1 ), padding = _single (0 ), dilation = _single (1 ), groups = 1 ):
2625 if bias is None :
27- return F .conv1d (input , weight , ** kwargs )
26+ return F .conv1d (input , weight , stride = stride , padding = padding , dilation = dilation , groups = groups )
2827 else :
29- new_kwargs = kwargs
30- new_kwargs ['bias' ] = None
31- return F .conv1d (input , weight , ** kwargs ) + bias .reshape ((- 1 , 1 ))
28+ return F .conv1d (input , weight , stride = stride , padding = padding , dilation = dilation , groups = groups ) + bias .reshape (
29+ (- 1 , 1 ))
3230
3331
3432@register_tracer_impl (F .conv2d , name = '_bias_addition_impl' )
35- def conv2d_impl (input , weight , ** kwargs ):
36- bias = getattr (kwargs , 'bias' , None )
33+ def conv2d_impl (input , weight , bias = None , stride = _pair (1 ), padding = _pair (0 ), dilation = _pair (1 ), groups = 1 ):
3734 if bias is None :
38- return F .conv2d (input , weight , ** kwargs )
35+ return F .conv2d (input , weight , stride = stride , padding = padding , dilation = dilation , groups = groups )
3936 else :
40- new_kwargs = kwargs
41- new_kwargs ['bias' ] = None
42- return F .conv2d (input , weight , ** kwargs ) + bias .reshape ((- 1 , 1 , 1 ))
37+ return F .conv2d (input , weight , stride = stride , padding = padding , dilation = dilation , groups = groups ) + bias .reshape (
38+ (- 1 , 1 , 1 ))
4339
4440
4541@register_tracer_impl (F .conv3d , name = '_bias_addition_impl' )
46- def conv3d_impl (input , weight , ** kwargs ):
47- bias = getattr (kwargs , 'bias' , None )
42+ def conv3d_impl (input , weight , bias = None , stride = _triple (1 ), padding = _triple (0 ), dilation = _triple (1 ), groups = 1 ):
4843 if bias is None :
49- return F .conv3d (input , weight , ** kwargs )
44+ return F .conv3d (input , weight , stride = stride , padding = padding , dilation = dilation , groups = groups )
5045 else :
51- new_kwargs = kwargs
52- new_kwargs ['bias' ] = None
53- return F .conv3d (input , weight , ** new_kwargs ) + bias .reshape ((- 1 , 1 , 1 , 1 ))
46+ return F .conv3d (input , weight , stride = stride , padding = padding , dilation = dilation , groups = groups ) + bias .reshape (
47+ (- 1 , 1 , 1 , 1 ))
5448
5549
5650@register_tracer_impl (F .conv_transpose1d , name = '_bias_addition_impl' )
57- def conv_transpose1d_impl (input , weight , ** kwargs ):
58- bias = getattr (kwargs , 'bias' , None )
51+ def conv_transpose1d_impl (input ,
52+ weight ,
53+ bias = None ,
54+ stride = _single (1 ),
55+ padding = _single (0 ),
56+ output_padding = _single (0 ),
57+ groups = 1 ,
58+ dilation = _single (1 )):
5959 if bias is None :
60- return F .conv_transpose1d (input , weight , ** kwargs )
60+ return F .conv_transpose1d (input ,
61+ weight ,
62+ stride = stride ,
63+ padding = padding ,
64+ output_padding = output_padding ,
65+ groups = groups ,
66+ dilation = dilation )
6167 else :
62- new_kwargs = kwargs
63- new_kwargs ['bias' ] = None
64- return F .conv_transpose1d (input , weight , ** new_kwargs ) + bias .reshape ((- 1 , 1 ))
68+ return F .conv_transpose1d (input ,
69+ weight ,
70+ stride = stride ,
71+ padding = padding ,
72+ output_padding = output_padding ,
73+ groups = groups ,
74+ dilation = dilation ) + bias .reshape ((- 1 , 1 ))
6575
6676
6777@register_tracer_impl (F .conv_transpose2d , name = '_bias_addition_impl' )
68- def conv_transpose2d_impl (input , weight , ** kwargs ):
69- bias = getattr (kwargs , 'bias' , None )
78+ def conv_transpose2d_impl (input ,
79+ weight ,
80+ bias = None ,
81+ stride = _pair (1 ),
82+ padding = _pair (0 ),
83+ output_padding = _pair (0 ),
84+ groups = 1 ,
85+ dilation = _pair (1 )):
7086 if bias is None :
71- return F .conv_transpose2d (input , weight , ** kwargs )
87+ return F .conv_transpose2d (input ,
88+ weight ,
89+ stride = stride ,
90+ padding = padding ,
91+ output_padding = output_padding ,
92+ groups = groups ,
93+ dilation = dilation )
7294 else :
73- new_kwargs = kwargs
74- new_kwargs ['bias' ] = None
75- return F .conv_transpose2d (input , weight , ** new_kwargs ) + bias .reshape ((- 1 , 1 , 1 ))
95+ return F .conv_transpose2d (input ,
96+ weight ,
97+ stride = stride ,
98+ padding = padding ,
99+ output_padding = output_padding ,
100+ groups = groups ,
101+ dilation = dilation ) + bias .reshape ((- 1 , 1 , 1 ))
76102
77103
78104@register_tracer_impl (F .conv_transpose3d , name = '_bias_addition_impl' )
79- def conv_transpose3d_impl (input , weight , ** kwargs ):
80- bias = getattr (kwargs , 'bias' , None )
105+ def conv_transpose3d_impl (input ,
106+ weight ,
107+ bias = None ,
108+ stride = _triple (1 ),
109+ padding = _triple (0 ),
110+ output_padding = _triple (0 ),
111+ groups = 1 ,
112+ dilation = _triple (1 )):
81113 if bias is None :
82- return F .conv_transpose3d (input , weight , ** kwargs )
114+ return F .conv_transpose3d (input ,
115+ weight ,
116+ stride = stride ,
117+ padding = padding ,
118+ output_padding = output_padding ,
119+ groups = groups ,
120+ dilation = dilation )
83121 else :
84- new_kwargs = kwargs
85- new_kwargs ['bias' ] = None
86- return F .conv_transpose3d (input , weight , ** new_kwargs ) + bias .reshape ((- 1 , 1 , 1 , 1 ))
122+ return F .conv_transpose3d (input ,
123+ weight ,
124+ stride = stride ,
125+ padding = padding ,
126+ output_padding = output_padding ,
127+ groups = groups ,
128+ dilation = dilation ) + bias .reshape ((- 1 , 1 , 1 , 1 ))
87129
88130
89131@register_tracer_impl (torch .addmm , name = '_bias_addition_impl' )
0 commit comments