Conversation

carlosgmartin (Contributor)

Fixes #1290.

carlosgmartin force-pushed the adabelief_weight_decay branch from 405a6a3 to e8de17b on May 1, 2025 at 22:40
rdyro (Collaborator) commented May 5, 2025

Could we move towards calling the weight decay mask weight_decay_mask instead of mask? We should probably adopt this convention, since mask is somewhat ambiguous in a high-level optimizer interface.
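For concreteness, a rough sketch of what the mask refers to, using the existing add_decayed_weights(mask=...) API; the parameter names and values below are made up:

```python
import jax.numpy as jnp
import optax

# Hypothetical parameter tree; the mask mirrors its structure and marks which
# leaves should receive weight decay (biases are excluded here).
params = {"kernel": jnp.ones((3, 3)), "bias": jnp.zeros(3)}
weight_decay_mask = {"kernel": True, "bias": False}

# The low-level transform currently calls this argument `mask`; the proposal
# is to expose it as `weight_decay_mask` in the high-level optimizer.
tx = optax.add_decayed_weights(1e-4, mask=weight_decay_mask)
state = tx.init(params)
```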

If adabelief needs weight decay then perhaps we can think of adding it to all the main optimizers in alias.py? Wdyt? @carlosgmartin

carlosgmartin force-pushed the adabelief_weight_decay branch from e8de17b to b314331 on May 6, 2025 at 19:40
carlosgmartin (Contributor, Author)

@rdyro I've changed the argument's name from mask to weight_decay_mask.

I'll leave changing the other optimizers' argument names to a subsequent PR, to keep this one self-contained.

carlosgmartin (Contributor, Author)

@rdyro Does this look good?

rdyro (Collaborator) commented Jun 6, 2025

I'm not entirely sure about this change: the original adabelief paper explicitly discusses weight decay, but does not use it.

The problem for optax is that weight decay is NOT scaled by the learning rate, so the user has two options for adding weight decay to an existing optimizer:

  • reimplement the optimizer chain, inserting the weight decay before scale_by_learning_rate in the chain
  • chain the pre-made optimizer (e.g., adabelief) with another chain of (weight_decay, scale_by_learning_rate); both options are sketched below
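Rough sketches of these two workarounds (hyperparameter values are made up; the second variant folds the -learning_rate factor into the decay coefficient instead of chaining a second scale_by_learning_rate, so the already-scaled adabelief updates keep their sign):

```python
import optax

learning_rate = 1e-3
weight_decay = 1e-4

# Option 1: rebuild the adabelief chain by hand so the decay term is added
# before the learning-rate scaling, as in adamw.
opt1 = optax.chain(
    optax.scale_by_belief(),
    optax.add_decayed_weights(weight_decay),
    optax.scale_by_learning_rate(learning_rate),
)

# Option 2: keep the pre-made optimizer and append the decay separately.
# Because adabelief's output is already scaled by -learning_rate, the extra
# term has to carry the -learning_rate factor itself; otherwise it is neither
# scaled by the learning rate nor applied with the right sign.
opt2 = optax.chain(
    optax.adabelief(learning_rate),
    optax.add_decayed_weights(-learning_rate * weight_decay),
)
```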

It'd be great if we could solve this problem more systematically, so we don't have to add extra weight_decay arguments to every popular optimizer.

Perhaps we could introduce another keyword argument to add_decayed_weights that takes in the learning rate (schedule)? @carlosgmartin
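A rough user-side sketch of that idea, built only from existing optax pieces; the helper name and signature are hypothetical (not part of optax), and a schedule rather than a constant would additionally require tracking a step count in the state:

```python
import jax
import optax

def add_decayed_weights_times_lr(weight_decay: float, learning_rate: float):
  """Hypothetical helper: decoupled weight decay whose contribution is
  pre-multiplied by -learning_rate, so it can be chained *after* a pre-made
  optimizer whose updates are already scaled by -learning_rate."""

  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    if params is None:
      raise ValueError("params are required for weight decay")
    # Subtract learning_rate * weight_decay * params from the incoming updates.
    updates = jax.tree_util.tree_map(
        lambda u, p: u - learning_rate * weight_decay * p, updates, params)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)

# Usage: the pre-made optimizer stays untouched.
optimizer = optax.chain(
    optax.adabelief(learning_rate=1e-3),
    add_decayed_weights_times_lr(weight_decay=1e-4, learning_rate=1e-3),
)
```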

For a systematic fix, I'd prefer to remove the additional weight_decay keyword argument from pre-made optimizers (but we should keep it in the ones that explicitly include it, e.g. adamw, and in the ones to which we already added the weight_decay kwarg, for backward compatibility).

carlosgmartin (Contributor, Author)

What does @vroulet think?

vroulet (Collaborator) commented Jun 17, 2025

The original author's repository seems to include some weight decay (https://github.com/juntang-zhuang/Adabelief-Optimizer/tree/update_0.2.0), so having a weight decay implementation makes sense.

I agree with Robert that the current duplication of weight_decay arguments is pretty bad (in particular, the documentation is quite heavy; it would be best to have a "see also" pointing people to how to add weight decay). I like the idea of adding a keyword argument to add_decayed_weights (though it may require a relatively large refactoring).

Development

Successfully merging this pull request may close the following issue: Add optional weight decay to AdaBelief optimizer