Commit d709480

[CP] test load-balance on llama3-8B

ghstack-source-id: dbfb796
Pull Request resolved: #1897

1 parent 3114f2f

1 file changed (+15, -3)

torchtitan/distributed/utils.py

Lines changed: 15 additions & 3 deletions
@@ -434,9 +434,7 @@ def _clip_grad_norm_with_ep(
     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
     else:
-        total_norm = (
-            ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
-        )
+        total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
         total_norm **= 1.0 / norm_type
 
     if pp_mesh is not None:
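
For reference, the unchanged line that follows (total_norm **= 1.0 / norm_type) takes the 1/p root, so this branch combines the expert-parallel and non-expert-parallel partial norms into a single p-norm: total = (||g_ep||_p^p + ||g_non_ep||_p^p)^(1/p). A minimal standalone sketch of that combination, where the helper name and its tensor arguments are illustrative and not part of the patch:

import math

import torch


def combine_grad_norms(
    ep_norm: torch.Tensor, non_ep_norm: torch.Tensor, norm_type: float
) -> torch.Tensor:
    # Merge two partial gradient norms of the same norm_type into one
    # global norm, mirroring the branch in the hunk above.
    if math.isinf(norm_type):
        # inf-norm: the global norm is simply the larger partial norm.
        return torch.maximum(ep_norm, non_ep_norm)
    # finite p-norm: (||a||_p^p + ||b||_p^p) ** (1/p)
    total_norm = ep_norm**norm_type + non_ep_norm**norm_type
    return total_norm ** (1.0 / norm_type)
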
@@ -462,9 +460,23 @@ def cp_shard(
     order_sensitive_buffers_seq_dims: dict[str, int],
 ):
     from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.distributed.tensor.experimental._load_balancer import (
+        _HeadTailLoadBalancer,
+        _PTRRLoadBalancer,
+    )
     from torch.nn.attention.flex_attention import BlockMask
 
     load_balancer = None
+    """
+    seq_length = inputs.shape[1]
+    load_balancer = _HeadTailLoadBalancer(
+        seq_length, cp_mesh.size(0), cp_mesh.device_type
+    )
+
+    assert isinstance(attention_masks, BlockMask)
+    load_balancer = _PTRRLoadBalancer(attention_masks, cp_mesh.size(0))
+    """
+
     inputs, labels = _context_parallel_shard(
         mesh=cp_mesh,
         buffers=(inputs, labels),
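
Note that the new block is wrapped in a triple-quoted string, so neither load balancer is active yet and load_balancer stays None; the commit only stages the two experimental options for testing. If the block were enabled, choosing between them might look roughly like the sketch below (pick_load_balancer is a hypothetical helper, the BlockMask-based selection is an assumption drawn from the assert in the diff, and how the chosen object is handed to _context_parallel_shard is not shown in this hunk):

from torch.distributed.tensor.experimental._load_balancer import (
    _HeadTailLoadBalancer,
    _PTRRLoadBalancer,
)
from torch.nn.attention.flex_attention import BlockMask


def pick_load_balancer(inputs, attention_masks, cp_mesh):
    # Hypothetical helper: choose one of the two experimental load
    # balancers referenced in the disabled block above.
    if isinstance(attention_masks, BlockMask):
        # _PTRRLoadBalancer works from the FlexAttention block mask,
        # so it takes the mask itself plus the CP world size.
        return _PTRRLoadBalancer(attention_masks, cp_mesh.size(0))
    # _HeadTailLoadBalancer takes the sequence length, CP world size,
    # and device type (constructor arguments taken from the diff above).
    seq_length = inputs.shape[1]
    return _HeadTailLoadBalancer(seq_length, cp_mesh.size(0), cp_mesh.device_type)
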
