
Commit 2c13a68

Add post_dataloading_processing method to Trainer
We are adding more actions to convert the raw inputs and labels:

1. The new CP can do the input/label/BlockMask sharding in this method.
2. The experimental full-DTensor model can simply override this method without changing much Trainer code.

This method is extracted from #1857. Making it a standalone PR allows us to continue the two projects above without one blocking the other.

ghstack-source-id: d1882a7
Pull-Request: #1985
1 parent 8659543 commit 2c13a68
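
To illustrate the intended use, here is a minimal sketch, not part of this commit, of how a downstream Trainer subclass might override the new hook. The ExperimentalTrainer name and the pass-through body are hypothetical, and the import assumes the Trainer class defined in torchtitan/train.py.

from typing import Any

import torch

from torchtitan.train import Trainer  # assumes the Trainer class from this file


class ExperimentalTrainer(Trainer):
    """Hypothetical subclass that customizes post-dataloading processing."""

    def post_dataloading_processing(
        self, input_dict: dict[str, torch.Tensor], label: torch.Tensor
    ) -> tuple[
        dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor], dict[str, Any]
    ]:
        # Reuse the default behavior: extract inputs/extra_inputs and build
        # attention masks when flex attention is enabled.
        inputs, label, extra_inputs, extra_kwargs = super().post_dataloading_processing(
            input_dict, label
        )
        # Project-specific conversion would go here, e.g. sharding the
        # input/label/BlockMask for the new CP, or converting tensors to
        # DTensors for the full-DTensor model. This placeholder returns the
        # batch unchanged.
        return inputs, label, extra_inputs, extra_kwargs

Because forward_backward_step now routes through this hook, such an override takes effect without touching the rest of the training loop.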

File tree

1 file changed: +19 / -8 lines


torchtitan/train.py

Lines changed: 19 additions & 8 deletions
@@ -410,26 +410,37 @@ def batch_generator(
 
         yield input_dict, labels
 
-    def forward_backward_step(
-        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
-    ) -> torch.Tensor:
-        model_parts = self.model_parts
-        parallel_dims = self.parallel_dims
-
+    def post_dataloading_processing(
+        self, input_dict: dict[str, torch.Tensor], label: torch.Tensor
+    ) -> tuple[
+        dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor], dict[str, Any]
+    ]:
+        """Post processing after data loading."""
         inputs = input_dict["input"]
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # For arguments, like attention_masks, we have to put them in a separate
         # dict as extra_inputs are not forwarded to other stages in PP, but
         # extra_kwargs are.
-        extra_kwargs = {}
+        extra_kwargs: dict[str, Any] = {}
 
         if getattr(self.model_args, "use_flex_attn", False):
-            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
             )
 
+        return inputs, label, extra_inputs, extra_kwargs
+
+    def forward_backward_step(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
+        model_parts = self.model_parts
+        parallel_dims = self.parallel_dims
+
+        inputs, label, extra_inputs, extra_kwargs = self.post_dataloading_processing(
+            input_dict, labels
+        )
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         optional_context_parallel_ctx = (
