
Commit 405a6a3

Add weight_decay and mask arguments to adabelief optimizer.
1 parent 852ff6e commit 405a6a3

1 file changed: +13 lines, -0 lines

optax/_src/alias.py

Lines changed: 13 additions & 0 deletions
@@ -39,6 +39,8 @@ def adabelief(
     eps_root: float = 1e-16,
     *,
     nesterov: bool = False,
+    weight_decay: float = 1e-4,
+    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
 ) -> base.GradientTransformationExtraArgs:
   r"""The AdaBelief optimizer.
 
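Because the new arguments sit after the `*` marker, they are keyword-only. A minimal usage sketch of the extended signature after this commit; the concrete values below are illustrative, not recommendations from the change itself:

import optax

# Sketch: calling adabelief with the new keyword-only arguments.
tx = optax.adabelief(
    learning_rate=1e-3,
    weight_decay=1e-4,  # new: strength of the weight decay
    mask=None,          # new: None applies the decay to every parameter
)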
@@ -94,6 +96,16 @@ def adabelief(
       improve numerical stability. If backpropagating gradients through the
       gradient transformation (e.g. for meta-learning), this must be non-zero.
     nesterov: Whether to use Nesterov momentum.
+    weight_decay: Strength of the weight decay regularization. Note that this
+      weight decay is multiplied with the learning rate. This is consistent
+      with other frameworks such as PyTorch, but different from
+      (Loshchilov et al, 2019) where the weight decay is only multiplied with
+      the "schedule multiplier", but not the base learning rate.
+    mask: A tree with same structure as (or a prefix of) the params PyTree,
+      or a Callable that returns such a pytree given the params/updates.
+      The leaves should be booleans, `True` for leaves/subtrees you want to
+      apply the weight decay to, and `False` for those you want to skip. Note
+      that the Adam gradient transformations are applied to all parameters.
 
   Returns:
     The corresponding :class:`optax.GradientTransformationExtraArgs`.
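As the new docstring describes, `mask` may be a boolean pytree or a callable that builds one from the params. A hedged sketch of such a callable that decays only matrix-shaped leaves (e.g. dense kernels) and skips 1-D leaves such as biases; the `ndim > 1` rule is an illustrative assumption, not something prescribed by this commit:

import jax
import jax.numpy as jnp
import optax

# Sketch: mark only multi-dimensional leaves for weight decay.
def decay_mask(params):
  return jax.tree_util.tree_map(lambda p: p.ndim > 1, params)

params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}
tx = optax.adabelief(1e-3, weight_decay=1e-4, mask=decay_mask)
opt_state = tx.init(params)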
@@ -134,6 +146,7 @@ def adabelief(
           eps_root=eps_root,
           nesterov=nesterov,
       ),
+      transform.add_decayed_weights(weight_decay, mask),
       transform.scale_by_learning_rate(learning_rate),
   )
 
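Because `transform.add_decayed_weights` is chained before `transform.scale_by_learning_rate`, the decay term also gets scaled by the learning rate, which is what the docstring note about PyTorch-style behaviour refers to. A rough sketch of the resulting chain using optax's public API; it omits the b1/b2/eps/eps_root/nesterov plumbing the real `adabelief` forwards, and the hyperparameter values are placeholders:

import optax

# Sketch of the chain this commit produces, via public optax transforms.
learning_rate, weight_decay, mask = 1e-3, 1e-4, None
tx = optax.chain(
    optax.scale_by_belief(),                        # AdaBelief preconditioning
    optax.add_decayed_weights(weight_decay, mask),  # added by this commit
    optax.scale_by_learning_rate(learning_rate),    # scales the decay by lr too
)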
