mmf/models/transformers/base.py: 36 additions & 21 deletions
@@ -103,25 +103,13 @@ def build(self):

def get_optimizer_parameters(self, config):
lr = config.optimizer.params.lr

trunk_param_set = set()
param_list = []
parameters = []
head_configs = self.config.get("heads", [])

for name, module in self.named_children():
# Heads can have different learning rates. This is handled here
if name == "heads":
# Parameters in the head which have a separate learning
# rate are added as a separate param group
for head_config, head in zip(head_configs, self.heads):
parameters, param_list = self.set_lr_for_parameters(
config=head_config,
module_name="{} head".format(head_config.get("type", "MLP")),
base_lr=lr,
module=head,
parameters=parameters,
param_list=param_list,
)
elif name == "encoders":

if name == "encoders":
for key in module:
for modality in self.config.modalities:
if key == modality.key:
@@ -134,29 +122,56 @@ def get_optimizer_parameters(self, config):
parameters=parameters,
param_list=param_list,
)
else:
if name != "heads":
# For other modules in trunk, add to same param group
param_list += list(module.named_parameters())

trunk_param_set.update(list(module.parameters()))
head_configs = self.config.get("heads", [])
# Heads can have different learning rates. This is handled here
if len(head_configs) > 0:
# Parameters in the head which have a separate learning
# rate are added as a separate param group
for head_config, head in zip(head_configs, self.heads):
parameters, param_list = self.set_lr_for_parameters(
config=head_config,
module_name="{} head".format(head_config.get("type", "MLP")),
base_lr=lr,
module=head,
parameters=parameters,
param_list=param_list,
excluded_params=trunk_param_set,
)
parameters += get_bert_configured_parameters(param_list)

return parameters

def set_lr_for_parameters(
self, config, module_name, base_lr, module, parameters, param_list
self,
config,
module_name,
base_lr,
module,
parameters,
param_list,
excluded_params=None,
):
lr_multiplier = config.get("lr_multiplier", 1.0)
module_param = list(module.named_parameters())
if excluded_params is not None:
module_param = [
tup for tup in module_param if tup[1] not in excluded_params
]
if lr_multiplier != 1.0:
logger.info(
f"Setting learning rate of {module_name} to be {base_lr} * {lr_multiplier}."
) # noqa
parameters += get_bert_configured_parameters(
module, base_lr * lr_multiplier
module_param, base_lr * lr_multiplier
)
else:
# Parameters for the modules with same learning rate as
# trunk, add to same param group
param_list += list(module.named_parameters())
param_list += module_param
return parameters, param_list

def build_encoders(self):
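
Note on the change above: head parameter grouping now happens after the trunk loop, so parameters already collected in trunk_param_set can be excluded from the head's param group via the new excluded_params argument. This matters when a head shares (ties) weights with the trunk, because a parameter that ends up in two optimizer param groups makes torch.optim raise an error. Below is a minimal sketch of the same exclusion idea, not the MMF code itself; Trunk, Head, and the learning-rate values are illustrative assumptions.

import torch
from torch import nn

# Hypothetical stand-ins for the MMF trunk and a head that ties weights.
class Trunk(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)
        self.linear = nn.Linear(4, 4)

class Head(nn.Module):
    def __init__(self, tied_embedding: nn.Embedding):
        super().__init__()
        self.decoder = nn.Linear(4, 10)
        # Weight tying: the head reuses the trunk's embedding matrix.
        self.decoder.weight = tied_embedding.weight

trunk = Trunk()
head = Head(trunk.embedding)

base_lr, head_lr_multiplier = 1e-4, 10.0

# Membership tests on this set compare by object identity, because
# torch.Tensor hashes by id; that is what makes the exclusion work.
trunk_param_set = set(trunk.parameters())

# Only head parameters that are NOT shared with the trunk get the scaled lr.
head_only_params = [p for p in head.parameters() if p not in trunk_param_set]

optimizer = torch.optim.AdamW(
    [
        {"params": list(trunk.parameters()), "lr": base_lr},
        {"params": head_only_params, "lr": base_lr * head_lr_multiplier},
    ]
)

Under these assumptions, dropping the "p not in trunk_param_set" filter puts the tied weight into two param groups and AdamW raises "ValueError: some parameters appear in more than one parameter group", which is the failure mode the excluded_params filtering in the diff guards against.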