1 file changed: +15 −3

@@ -434,9 +434,7 @@ def _clip_grad_norm_with_ep(
     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
     else:
-        total_norm = (
-            ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
-        )
+        total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
         total_norm **= 1.0 / norm_type
 
     if pp_mesh is not None:
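The change above only drops redundant parentheses; the reduction itself is unchanged: each gradient group contributes its own p-norm, the two are combined as (ep_norm**p + non_ep_norm**p) ** (1/p), and the infinity norm falls back to the elementwise maximum. A minimal, self-contained sketch of that combination (the function name and the final allclose check are illustrative, not part of the patch):

import math

import torch


def combine_group_norms(
    ep_norm: torch.Tensor, non_ep_norm: torch.Tensor, norm_type: float
) -> torch.Tensor:
    # Mirror the logic in _clip_grad_norm_with_ep: for a finite p the global
    # norm is (||g_ep||_p**p + ||g_other||_p**p) ** (1/p); for p = inf it is
    # simply the larger of the two group norms.
    if math.isinf(norm_type):
        return torch.maximum(ep_norm, non_ep_norm)
    total_norm = ep_norm**norm_type + non_ep_norm**norm_type
    return total_norm ** (1.0 / norm_type)


# Sanity check: combining the two group norms matches the norm taken over all
# gradients at once.
g_ep, g_other = torch.randn(8), torch.randn(8)
combined = combine_group_norms(
    torch.linalg.vector_norm(g_ep, 2), torch.linalg.vector_norm(g_other, 2), 2.0
)
assert torch.allclose(combined, torch.linalg.vector_norm(torch.cat([g_ep, g_other]), 2))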
@@ -462,9 +460,23 @@ def cp_shard(
     order_sensitive_buffers_seq_dims: dict[str, int],
 ):
     from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.distributed.tensor.experimental._load_balancer import (
+        _HeadTailLoadBalancer,
+        _PTRRLoadBalancer,
+    )
     from torch.nn.attention.flex_attention import BlockMask
 
     load_balancer = None
+    """
+    seq_length = inputs.shape[1]
+    load_balancer = _HeadTailLoadBalancer(
+        seq_length, cp_mesh.size(0), cp_mesh.device_type
+    )
+
+    assert isinstance(attention_masks, BlockMask)
+    load_balancer = _PTRRLoadBalancer(attention_masks, cp_mesh.size(0))
+    """
+
     inputs, labels = _context_parallel_shard(
         mesh=cp_mesh,
         buffers=(inputs, labels),
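The commented-out block stages two load-balancing options without enabling either one (load_balancer stays None): _HeadTailLoadBalancer rebalances by sequence position, while _PTRRLoadBalancer is driven by a FlexAttention BlockMask. As a rough, self-contained illustration of the head-tail idea only (head_tail_shard is a hypothetical helper, not the experimental PyTorch API, and the real balancer is presumably threaded into the _context_parallel_shard call above):

import torch


def head_tail_shard(seq: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
    # Split the sequence dimension into 2 * cp_size equal chunks; rank i keeps
    # chunk i (the "head") plus chunk 2 * cp_size - 1 - i (the "tail"), so the
    # causal-attention work is roughly even across CP ranks. Illustrative of the
    # head-tail scheme only, not the _HeadTailLoadBalancer implementation.
    chunks = seq.chunk(2 * cp_size, dim=1)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=1)


# Batch of 2, sequence length 16, CP degree 4: each rank keeps 4 of 16 positions.
inputs = torch.arange(16).repeat(2, 1)
print(head_tail_shard(inputs, cp_rank=0, cp_size=4)[0].tolist())  # [0, 1, 14, 15]
print(head_tail_shard(inputs, cp_rank=3, cp_size=4)[0].tolist())  # [6, 7, 8, 9]

Pairing each rank's early ("head") chunk with a matching late ("tail") chunk keeps the causal-attention cost roughly equal across CP ranks, which is the motivation behind both balancers.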