@@ -409,12 +409,10 @@ def batch_generator(

             yield input_dict, labels

-    def forward_backward_step(
+    def post_dataloader_step(
         self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
-    ) -> torch.Tensor:
-        model_parts = self.model_parts
-        parallel_dims = self.parallel_dims
-
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any], dict[str, Any]]:
+        """Post-process the batch and labels after they are loaded from the dataloader."""
         inputs = input_dict["input"]
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # For arguments, like attention_masks, we have to put them in a separate
@@ -423,32 +421,53 @@ def forward_backward_step(
         extra_kwargs = {}

         if getattr(self.model_args, "use_flex_attn", False):
-            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
             )
+        else:
+            extra_kwargs["attention_masks"] = None

-        # apply context parallelism if cp is enabled
-        # ensure CP handles the separate freqs_cis buffer for each pp stage
-        optional_context_parallel_ctx = (
-            dist_utils.create_context_parallel_ctx(
-                cp_mesh=parallel_dims.world_mesh["cp"],
-                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
-                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
-                cp_no_restore_buffers={inputs, labels},
-                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
-            )
-            if parallel_dims.cp_enabled
+        # Get the order-sensitive buffers
+        order_sensitive_buffers = self.model_parts[0].get_order_sensitive_buffers(
+            inputs.size(0), inputs.size(1)
+        )
+        cp_mesh = (
+            self.parallel_dims.world_mesh["cp"]
+            if self.parallel_dims.cp_enabled
             else None
         )
+        if cp_mesh:
+            (
+                inputs,
+                labels,
+                extra_kwargs["attention_masks"],
+                *order_sensitive_buffers,
+            ) = dist_utils.cp_shard(
+                cp_mesh,
+                inputs,
+                labels,
+                extra_kwargs["attention_masks"],
+                *order_sensitive_buffers,
+            )
+        extra_kwargs.update(order_sensitive_buffers[0])
+        return inputs, labels, extra_inputs, extra_kwargs
+
+    def forward_backward_step(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
+        model_parts = self.model_parts
+        parallel_dims = self.parallel_dims
+
+        inputs, labels, extra_inputs, extra_kwargs = self.post_dataloader_step(
+            input_dict, labels
+        )

         if parallel_dims.pp_enabled:
             # Pipeline Parallel forward / backward inside step() call
-            with self.train_context(optional_context_parallel_ctx):
-                targets, losses = (
-                    (labels, []) if self.pp_has_last_stage else (None, None)
-                )
+            targets, losses = (labels, []) if self.pp_has_last_stage else (None, None)
+            with self.train_context():
                 if self.pp_has_first_stage:
                     self.pp_schedule.step(
                         inputs,
@@ -478,7 +497,7 @@ def forward_backward_step(
             )
         else:
             # Non-PP forward / backward
-            with self.train_context(optional_context_parallel_ctx):
+            with self.train_context():
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
                     pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
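
For reference, below is a minimal single-process sketch of the sequence-dimension sharding that context parallelism performs. It is an illustration only, not the real dist_utils.cp_shard used above: toy_cp_shard, cp_degree, and cp_rank are hypothetical names, and the actual helper works on the "cp" device mesh and also shards the attention masks and order-sensitive buffers alongside the inputs and labels.

# Illustrative sketch only; not the real dist_utils.cp_shard used above.
import torch


def toy_cp_shard(
    tensors: list[torch.Tensor], cp_degree: int, cp_rank: int, seq_dim: int = 1
) -> list[torch.Tensor]:
    """Return each tensor's slice owned by cp_rank along the sequence dim."""
    # Assumes the sequence length is divisible by cp_degree so shards are even.
    return [t.chunk(cp_degree, dim=seq_dim)[cp_rank] for t in tensors]


# Example: a batch of 2 sequences of length 8, sharded across CP degree 4.
inputs = torch.arange(16).reshape(2, 8)
labels = torch.arange(16).reshape(2, 8)
local_inputs, local_labels = toy_cp_shard([inputs, labels], cp_degree=4, cp_rank=1)
print(local_inputs.shape)  # torch.Size([2, 2]); each rank sees 1/4 of the sequence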