
Commit 157d30d

Add post_dataloading_processing method to Trainer (#1985)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):
* #2002
* #2001
* #1995
* __->__ #1985

We are adding more actions to convert the raw inputs and labels:
1. The new CP can do the input/label/BlockMask sharding in this method.
2. The experimental full DTensor model can simply override this method without changing too much Trainer code.

This method is extracted from #1857. Making this a standalone PR allows us to continue the two projects above without one blocking the other.
1 parent 35bffe9 commit 157d30d
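
To illustrate item 2 of the commit message, here is a minimal sketch of a subclass that overrides the new hook to hand the model DTensor inputs. This is not part of this commit: the subclass name, the `world_mesh` attribute, and the `Replicate()` placement are assumptions for illustration only.

```python
from typing import Any

import torch
from torch.distributed.tensor import Replicate, distribute_tensor  # PyTorch 2.4+

from torchtitan.train import Trainer


class FullDTensorTrainer(Trainer):  # hypothetical subclass, not in torchtitan
    def post_dataloading_process(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
        # Reuse the base Trainer's extraction of inputs/extra_inputs/extra_kwargs.
        inputs, labels, extra_inputs, extra_kwargs = super().post_dataloading_process(
            input_dict, labels
        )
        # Wrap the plain tensors as DTensors on the training mesh so the rest of
        # the step can stay "full DTensor". `self.world_mesh` is assumed here.
        inputs = distribute_tensor(inputs, self.world_mesh, [Replicate()])
        labels = distribute_tensor(labels, self.world_mesh, [Replicate()])
        return inputs, labels, extra_inputs, extra_kwargs
```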

File tree

1 file changed (+48, -7 lines)


torchtitan/train.py

Lines changed: 48 additions & 7 deletions
@@ -410,26 +410,67 @@ def batch_generator(
 
             yield input_dict, labels
 
-    def forward_backward_step(
+    def post_dataloading_process(
         self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
-    ) -> torch.Tensor:
-        model_parts = self.model_parts
-        parallel_dims = self.parallel_dims
-
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
+        """
+        Post-processing hook after data loading and before model forward pass.
+
+        This method processes the raw data from the dataloader and prepares it for
+        the model's forward pass. It separates the main input tensor from auxiliary
+        inputs and constructs additional keyword arguments (e.g., attention masks).
+
+        This method can be overridden in subclasses to customize data processing
+        for different training strategies (e.g., converting tensors to DTensors,
+        applying custom transformations, etc.).
+
+        Args:
+            input_dict: Dictionary containing tensors from the dataloader. Must
+                contain an "input" key with the main input tensor. May contain
+                additional keys for auxiliary inputs (e.g., position ids).
+            labels: Target labels for the batch.
+
+        Returns:
+            A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
+            - inputs: Main input tensor extracted from input_dict["input"].
+            - labels: Target labels (unchanged from input parameter).
+            - extra_inputs: Dict of auxiliary input tensors (all keys except
+              "input" from input_dict). These are passed to the model forward
+              but are NOT forwarded across pipeline parallel stages.
+            - extra_kwargs: Dict of additional keyword arguments for model forward.
+              These ARE forwarded across pipeline parallel stages. Contains
+              attention_masks if flex attention is enabled.
+
+        Note:
+            The distinction between extra_inputs and extra_kwargs is important for
+            pipeline parallelism: extra_kwargs are forwarded to all pipeline stages,
+            while extra_inputs are only available to the first stage.
+        """
         inputs = input_dict["input"]
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # For arguments, like attention_masks, we have to put them in a separate
         # dict as extra_inputs are not forwarded to other stages in PP, but
         # extra_kwargs are.
-        extra_kwargs = {}
+        extra_kwargs: dict[str, Any] = {}
 
         if getattr(self.model_args, "use_flex_attn", False):
-            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
             )
 
+        return inputs, labels, extra_inputs, extra_kwargs
+
+    def forward_backward_step(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
+        model_parts = self.model_parts
+        parallel_dims = self.parallel_dims
+
+        inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
+            input_dict, labels
+        )
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         optional_context_parallel_ctx = (
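
For readers unfamiliar with the split performed above, here is a self-contained toy example of how a dataloader batch maps onto the returned tuple. It is not from this diff; the "position_ids" key and all shapes are made up for illustration.

```python
import torch

# Toy batch resembling what the dataloader yields.
input_dict = {
    "input": torch.randint(0, 32_000, (8, 2048)),      # main token ids
    "position_ids": torch.arange(2048).repeat(8, 1),    # auxiliary input
}
labels = torch.randint(0, 32_000, (8, 2048))

# Mirrors the body of post_dataloading_process above.
inputs = input_dict["input"]
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
extra_kwargs: dict[str, torch.Tensor] = {}  # would hold attention_masks with flex attention

# Under pipeline parallelism, extra_inputs reach only the first stage,
# while extra_kwargs are forwarded to every stage.
```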
