diff --git a/torchtitan/train.py b/torchtitan/train.py
index e38446a398..8183b8e6df 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -451,7 +451,7 @@ def forward_backward_step(
             with self.maybe_enable_amp:
                 pred = model_parts[0](inputs, eos_id=self.tokenizer.eos_id)
                 loss = self.loss_fn(pred, labels)
-            # need to free to before bwd to avoid peaking memory
+            # need to free pred before bwd to avoid peaking memory
             del pred
             loss.backward()
 
@@ -471,7 +471,7 @@ def train_step(
         accumulated_losses = []
         # If data runs out during gradient accumulation, that
         # entire step will not be executed.
-        for microbatch in range(self.gradient_accumulation_steps):
+        for _microbatch in range(self.gradient_accumulation_steps):
             input_dict, labels = next(data_iterator)
             loss = self.forward_backward_step(input_dict, labels)
             accumulated_losses.append(loss.detach())
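
For context (not part of the patch): below is a minimal, self-contained sketch of the gradient-accumulation pattern these two hunks touch. It uses a toy model, loss, and data source rather than torchtitan's `Trainer`, `model_parts`, or dataloader, so every name here is illustrative, not the project's API.

```python
# Minimal sketch (assumed toy setup, not torchtitan's Trainer) of the
# gradient-accumulation loop that the two hunks above operate on.
import torch
import torch.nn as nn

model = nn.Linear(16, 4)            # toy stand-in for model_parts[0]
loss_fn = nn.CrossEntropyLoss()     # toy stand-in for self.loss_fn
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
gradient_accumulation_steps = 4


def data_iterator():
    # Toy data source; the real code pulls (input_dict, labels) from a dataloader.
    while True:
        yield torch.randn(8, 16), torch.randint(0, 4, (8,))


batches = data_iterator()
accumulated_losses = []
# Underscore-prefix the loop index since its value is never used,
# mirroring the rename in the second hunk.
for _microbatch in range(gradient_accumulation_steps):
    inputs, labels = next(batches)
    pred = model(inputs)
    loss = loss_fn(pred, labels)
    # Drop the Python reference to the logits before backward so autograd
    # can release them once they are no longer needed, instead of keeping
    # them alive for the rest of the step (the point of the first hunk's comment).
    del pred
    loss.backward()  # grads accumulate across microbatches
    accumulated_losses.append(loss.detach())

# One optimizer step for the whole accumulated batch.
optimizer.step()
optimizer.zero_grad()
print(torch.stack(accumulated_losses).mean())
```

In the real code the loop index carries no information (only the count matters), which is why renaming `microbatch` to `_microbatch` silences unused-variable linting without changing behavior.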