@@ -39,6 +39,8 @@ def adabelief(
     eps_root: float = 1e-16,
     *,
     nesterov: bool = False,
+    weight_decay: float = 1e-4,
+    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
 ) -> base.GradientTransformationExtraArgs:
   r"""The AdaBelief optimizer.

@@ -94,6 +96,16 @@ def adabelief(
       improve numerical stability. If backpropagating gradients through the
       gradient transformation (e.g. for meta-learning), this must be non-zero.
     nesterov: Whether to use Nesterov momentum.
+    weight_decay: Strength of the weight decay regularization. Note that this
+      weight decay is multiplied with the learning rate. This is consistent
+      with other frameworks such as PyTorch, but different from
+      (Loshchilov et al., 2019), where the weight decay is only multiplied with
+      the "schedule multiplier" but not the base learning rate.
+    mask: A tree with the same structure as (or a prefix of) the params PyTree,
+      or a Callable that returns such a pytree given the params/updates. The
+      leaves should be booleans, `True` for leaves/subtrees you want to apply
+      the weight decay to, and `False` for those you want to skip. Note that
+      the AdaBelief gradient transformations are applied to all parameters.

   Returns:
     The corresponding :class:`optax.GradientTransformationExtraArgs`.
@@ -134,6 +146,7 @@ def adabelief(
           eps_root=eps_root,
           nesterov=nesterov,
       ),
+      transform.add_decayed_weights(weight_decay, mask),
       transform.scale_by_learning_rate(learning_rate),
   )

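For context, a minimal usage sketch of the new parameters, assuming this change is applied. The toy parameters, the decay_mask helper, and the hyperparameter values below are illustrative and not part of the diff.

import jax
import jax.numpy as jnp
import optax

# Toy parameters: a weight matrix we want to decay and a bias we do not.
params = {'kernel': jnp.ones((3, 3)), 'bias': jnp.zeros((3,))}

# Mask callable: return True for leaves that should receive weight decay.
# Here we decay only rank-2 leaves, i.e. skip the bias.
def decay_mask(params):
  return jax.tree_util.tree_map(lambda p: p.ndim > 1, params)

# With this change, add_decayed_weights is chained before
# scale_by_learning_rate, so the decay term is multiplied by the learning
# rate (the PyTorch-style behaviour described in the docstring).
opt = optax.adabelief(learning_rate=1e-3, weight_decay=1e-4, mask=decay_mask)

state = opt.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)  # dummy gradients
# params must be passed to update() so the decay term can be computed.
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)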