Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions torchtitan/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,25 @@ def _clip_grad_norm_with_ep(
torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)

return total_norm


@contextlib.contextmanager
def _no_grad_sync(model: torch.nn.Module):
model.set_requires_gradient_sync(False)
try:
yield
finally:
model.set_requires_gradient_sync(True)


@contextlib.contextmanager
def no_grad_sync(models: list[torch.nn.Module], enable: bool = False):
if not enable:
yield
return

with contextlib.ExitStack() as stack:
for m in models:
ctx = _no_grad_sync(m) if hasattr(m, "set_requires_gradient_sync") else contextlib.nullcontext()
stack.enter_context(ctx)
yield
4 changes: 3 additions & 1 deletion torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,9 @@ def train_step(
# entire step will not be executed.
for _microbatch in range(self.gradient_accumulation_steps):
input_dict, labels = next(data_iterator)
loss = self.forward_backward_step(input_dict, labels)
no_grad_sync = _microbatch < self.gradient_accumulation_steps - 1
with dist_utils.no_grad_sync(self.model_parts, no_grad_sync):
loss = self.forward_backward_step(input_dict, labels)
accumulated_losses.append(loss.detach())

grad_norm = dist_utils.clip_grad_norm_(
Expand Down