@@ -271,15 +271,15 @@ def relu6(x):
271
271
272
272
class LeakyReLU (Cell ):
273
273
274
- def __init__ (self , alpha = 0.2 ):
274
+ def __init__ (self , negative_slope = 0.01 ):
275
275
super (LeakyReLU , self ).__init__ ()
276
- self .leakyrelu = ms .nn .LeakyReLU (alpha = alpha )
276
+ self .leakyrelu = ms .nn .LeakyReLU (alpha = negative_slope )
277
277
278
278
def construct (self , x ):
279
279
return self .leakyrelu (x )
280
280
281
281
282
- def leaky_relu (x , alpha = 0.2 ):
282
+ def leaky_relu (x , negative_slope = 0.2 ):
283
283
"""
284
284
Compute the Leaky ReLU activation function.
285
285
@@ -294,9 +294,9 @@ def leaky_relu(x, alpha=0.2):
294
294
The activation value.
295
295
"""
296
296
297
- leaky_relu = LeakyReLU (alpha = alpha )
297
+ leaky_relu = ms . nn . LeakyReLU (alpha = negative_slope )
298
298
output = leaky_relu (x )
299
- return leaky_relu
299
+ return output
300
300
301
301
302
302
class Softplus (Cell ):
@@ -348,15 +348,15 @@ def sigmoid(x):
348
348
349
349
class Softmax (Cell ):
350
350
351
- def __init__ (self ):
351
+ def __init__ (self , axis = - 1 ):
352
352
super (Softmax , self ).__init__ ()
353
- self .softmax = P .Softmax ()
353
+ self .softmax = P .Softmax (axis )
354
354
355
355
def construct (self , x ):
356
356
return self .softmax (x )
357
357
358
358
359
- def softmax (logits , axis = None ):
359
+ def softmax (logits , axis = - 1 ):
360
360
"""
361
361
Computes softmax activations.
362
362
@@ -2392,3 +2392,22 @@ def __init__(
2392
2392
2393
2393
def construct (self , inputs ):
2394
2394
raise NotImplementedError
2395
+
2396
+ class PReLU (Cell ):
2397
+
2398
+ def __init__ (self , data_format ):
2399
+ super (PReLU , self ).__init__ ()
2400
+ self .data_format = data_format
2401
+
2402
+ def __call__ (self , input , weight ):
2403
+
2404
+ prelu = P .PReLU ()
2405
+ v = prelu (input , F .cast (weight , input .dtype ))
2406
+ return v
2407
+
2408
+
2409
+ def prelu (input , weight , data_format ):
2410
+
2411
+ prelu = P .PReLU ()
2412
+ v = prelu (input , F .cast (weight , input .dtype ))
2413
+ return v
0 commit comments