
Commit c3f55b7

ananthsub authored and mollys committed
[sync] Fix bug in param_norm computation where some ranks might call collective and some might not (#852)
* sync with mr 3918

  Signed-off-by: Ananth Subramaniam <[email protected]>

* add unit tests

  Signed-off-by: Ananth Subramaniam <[email protected]>

* rebase

  Signed-off-by: Ananth Subramaniam <[email protected]>

---------

Signed-off-by: Ananth Subramaniam <[email protected]>
Signed-off-by: mollys <[email protected]>
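Context for the fix: `torch.distributed.all_reduce` is a collective, so every rank in the process group must reach the call; if the call sits inside a branch that only some ranks take, the ranks that do call it block indefinitely. The sketch below illustrates the hazard and the fix pattern in isolation. It is not the repository's code: the function names are hypothetical, the tensors are placeholders, and it assumes a default process group has already been initialized.

```python
from typing import Optional

import torch
import torch.distributed as dist


def broken_partial_norm(local_sq: Optional[torch.Tensor]) -> torch.Tensor:
    """Hazard: the collective sits inside a rank-dependent branch."""
    total = torch.zeros(1)
    if local_sq is not None:  # only some ranks take this branch
        # Ranks that reach this call wait forever for ranks that skipped it.
        dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)
        total += local_sq
    return total


def fixed_partial_norm(local_sq: Optional[torch.Tensor]) -> torch.Tensor:
    """Fix pattern: every rank reaches the collective unconditionally."""
    if local_sq is None:
        # Ranks with nothing to contribute send an explicit zero
        # instead of skipping the collective.
        local_sq = torch.zeros(1, dtype=torch.float32)
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)
    return local_sq
```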
1 parent 621be95 commit c3f55b7

File tree

2 files changed: +676 -8 lines changed


src/megatron/bridge/training/utils/train_utils.py

Lines changed: 10 additions & 8 deletions
@@ -186,14 +186,16 @@ def calc_params_l2_norm(
             False,  # no per-parameter norm.
         )
         sharded_norm_2 = sharded_norm * sharded_norm
-        # Sum over all DP groups, including CP since distributed optimizer state is
-        # sharded jointly over DP+CP.
-        torch.distributed.all_reduce(
-            sharded_norm_2,
-            op=torch.distributed.ReduceOp.SUM,
-            group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-        )
-        norm_2 += sharded_norm_2
+    else:
+        sharded_norm_2 = torch.zeros((1,), dtype=torch.float32, device="cuda")
+    # Sum over all DP groups, including CP since distributed optimizer state is
+    # sharded jointly over DP+CP.
+    torch.distributed.all_reduce(
+        sharded_norm_2,
+        op=torch.distributed.ReduceOp.SUM,
+        group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+    )
+    norm_2 += sharded_norm_2

     # Add norm contribution from expert layers in MoEs.
     if len(moe_params_data) > 0:
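The second changed file in this commit adds unit tests, but its contents are not shown on this page. As a rough, self-contained illustration of the behaviour being fixed — not the repository's actual tests — the sketch below spawns two processes with the gloo backend on CPU, gives only rank 0 a non-zero contribution, and checks that both ranks still complete the all_reduce and agree on the result. The worker name, port, and values are illustrative assumptions.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Only rank 0 has a "sharded" contribution; rank 1 must still reach
    # the collective, so it contributes an explicit zero tensor.
    if rank == 0:
        sharded_norm_2 = torch.tensor([9.0])
    else:
        sharded_norm_2 = torch.zeros(1, dtype=torch.float32)

    dist.all_reduce(sharded_norm_2, op=dist.ReduceOp.SUM)
    # Every rank sees the same reduced value; no rank hangs.
    assert sharded_norm_2.item() == 9.0

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```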
