1 file changed, +2 -2 lines changed

@@ -451,7 +451,7 @@ def forward_backward_step(
     with self.maybe_enable_amp:
         pred = model_parts[0](inputs, eos_id=self.tokenizer.eos_id)
         loss = self.loss_fn(pred, labels)
-        # need to free to before bwd to avoid peaking memory
+        # need to free pred before bwd to avoid peaking memory
         del pred
     loss.backward()
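For context, the comment being fixed describes a common PyTorch pattern: drop the last user-level reference to the logits before calling backward so that memory can be reclaimed during the backward pass instead of being held until it finishes. Below is a minimal, hypothetical sketch of that pattern, assuming a standard cross-entropy loss; the helper name and signature are placeholders, not the torchtitan API.

    import torch
    import torch.nn.functional as F

    def forward_backward_sketch(model: torch.nn.Module,
                                inputs: torch.Tensor,
                                labels: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper illustrating the "free pred before bwd" pattern.
        pred = model(inputs)  # logits, e.g. [batch, seq, vocab]; often the largest activation
        loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten())
        # Drop the local reference to the logits before backward: any memory not
        # also held by the autograd graph can then be reclaimed before backward
        # allocates gradient buffers, lowering peak (not total) memory usage.
        del pred
        loss.backward()
        return loss.detach()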
@@ -471,7 +471,7 @@ def train_step(
     accumulated_losses = []
     # If data runs out during gradient accumulation, that
     # entire step will not be executed.
-    for microbatch in range(self.gradient_accumulation_steps):
+    for _microbatch in range(self.gradient_accumulation_steps):
         input_dict, labels = next(data_iterator)
         loss = self.forward_backward_step(input_dict, labels)
         accumulated_losses.append(loss.detach())
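The second change only renames the loop variable: its value is never read inside the body, and a leading underscore is the conventional Python way to mark an intentionally unused name (it also keeps linters such as flake8/ruff quiet). A self-contained sketch of the same accumulation loop, using placeholder names rather than the actual Trainer attributes:

    from typing import Callable, Iterator
    import torch

    def train_step_sketch(
        data_iterator: Iterator[tuple[dict, torch.Tensor]],
        forward_backward_step: Callable[[dict, torch.Tensor], torch.Tensor],
        gradient_accumulation_steps: int,
    ) -> list[torch.Tensor]:
        accumulated_losses = []
        # The index is only a counter, so it is named _microbatch to signal "unused".
        for _microbatch in range(gradient_accumulation_steps):
            # next() raises StopIteration when data runs out; the caller then skips the step.
            input_dict, labels = next(data_iterator)
            loss = forward_backward_step(input_dict, labels)
            accumulated_losses.append(loss.detach())
        return accumulated_losses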