From 9e2c6d2121e2acf75ca877b1a57b9e7207dacb4b Mon Sep 17 00:00:00 2001
From: xinliu <1579664489@qq.com>
Date: Sun, 14 Sep 2025 15:36:45 +0800
Subject: [PATCH] [Feat] Disable gradient sync during gradient accumulation

Only synchronize gradients on the last microbatch of each accumulation
window, so gradient communication happens once per optimizer step instead
of once per microbatch.
---
 torchtitan/distributed/utils.py | 32 ++++++++++++++++++++++++++++++++
 torchtitan/train.py             |  4 +++-
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 74d310dfc1..38aa30a942 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -448,3 +448,35 @@ def _clip_grad_norm_with_ep(
     torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)
 
     return total_norm
+
+
+@contextlib.contextmanager
+def _no_grad_sync(model: torch.nn.Module):
+    """Temporarily disable gradient synchronization for a single model part."""
+    model.set_requires_gradient_sync(False)
+    try:
+        yield
+    finally:
+        model.set_requires_gradient_sync(True)
+
+
+@contextlib.contextmanager
+def no_grad_sync(models: list[torch.nn.Module], enable: bool = False):
+    """Disable gradient sync on all model parts while the context is active.
+
+    Model parts that do not expose ``set_requires_gradient_sync`` are left
+    untouched. When ``enable`` is False this is a no-op.
+    """
+    if not enable:
+        yield
+        return
+
+    with contextlib.ExitStack() as stack:
+        for m in models:
+            ctx = (
+                _no_grad_sync(m)
+                if hasattr(m, "set_requires_gradient_sync")
+                else contextlib.nullcontext()
+            )
+            stack.enter_context(ctx)
+        yield
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 7c49000774..ee8dacb5fc 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -482,7 +482,9 @@ def train_step(
         # entire step will not be executed.
         for _microbatch in range(self.gradient_accumulation_steps):
             input_dict, labels = next(data_iterator)
-            loss = self.forward_backward_step(input_dict, labels)
+            no_grad_sync = _microbatch < self.gradient_accumulation_steps - 1
+            with dist_utils.no_grad_sync(self.model_parts, no_grad_sync):
+                loss = self.forward_backward_step(input_dict, labels)
             accumulated_losses.append(loss.detach())
 
         grad_norm = dist_utils.clip_grad_norm_(
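
Below is a minimal, self-contained usage sketch (not part of the patch) of the
gradient-sync toggle pattern around accumulation microbatches. "ToyShard" is a
hypothetical stand-in for an FSDP-wrapped model part that exposes
set_requires_gradient_sync(), and the no_grad_sync helper here is a condensed
re-implementation with the same behavior as the one added above; real usage
would go through torchtitan's dist_utils.no_grad_sync and self.model_parts.

import contextlib

import torch


class ToyShard(torch.nn.Module):
    """Stand-in for a model part with FSDP-style gradient-sync control."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)
        self.sync_enabled = True  # mimics the reduce-scatter on/off toggle

    def set_requires_gradient_sync(self, requires: bool) -> None:
        self.sync_enabled = requires

    def forward(self, x):
        return self.linear(x)


@contextlib.contextmanager
def no_grad_sync(models, enable=False):
    # No-op unless enable=True, matching the helper added in the patch.
    if not enable:
        yield
        return
    with contextlib.ExitStack() as stack:
        for m in models:
            if hasattr(m, "set_requires_gradient_sync"):
                m.set_requires_gradient_sync(False)
                # Re-enable sync when the context exits.
                stack.callback(m.set_requires_gradient_sync, True)
        yield


if __name__ == "__main__":
    model_parts = [ToyShard()]
    accumulation_steps = 4
    for microbatch in range(accumulation_steps):
        # Skip sync for every microbatch except the last one in the window.
        skip_sync = microbatch < accumulation_steps - 1
        with no_grad_sync(model_parts, skip_sync):
            loss = model_parts[0](torch.randn(2, 4)).mean()
            loss.backward()
            print(f"microbatch={microbatch} sync_enabled={model_parts[0].sync_enabled}")
    # optimizer.step() and zero_grad() would follow here in a real loop.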