55 changes: 48 additions & 7 deletions torchtitan/train.py
@@ -410,26 +410,67 @@ def batch_generator(
 
             yield input_dict, labels
 
-    def forward_backward_step(
+    def post_dataloading_process(
         self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
-    ) -> torch.Tensor:
-        model_parts = self.model_parts
-        parallel_dims = self.parallel_dims
-
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
+        """
+        Post-processing hook after data loading and before model forward pass.
+
+        This method processes the raw data from the dataloader and prepares it for
+        the model's forward pass. It separates the main input tensor from auxiliary
+        inputs and constructs additional keyword arguments (e.g., attention masks).
+
+        This method can be overridden in subclasses to customize data processing
+        for different training strategies (e.g., converting tensors to DTensors,
+        applying custom transformations, etc.).
+
+        Args:
+            input_dict: Dictionary containing tensors from the dataloader. Must
+                contain an "input" key with the main input tensor. May contain
+                additional keys for auxiliary inputs (e.g., position ids).
+            labels: Target labels for the batch.
+
+        Returns:
+            A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
+            - inputs: Main input tensor extracted from input_dict["input"].
+            - labels: Target labels (unchanged from input parameter).
+            - extra_inputs: Dict of auxiliary input tensors (all keys except
+              "input" from input_dict). These are passed to the model forward
+              but are NOT forwarded across pipeline parallel stages.
+            - extra_kwargs: Dict of additional keyword arguments for model forward.
+              These ARE forwarded across pipeline parallel stages. Contains
+              attention_masks if flex attention is enabled.
+
+        Note:
+            The distinction between extra_inputs and extra_kwargs is important for
+            pipeline parallelism: extra_kwargs are forwarded to all pipeline stages,
+            while extra_inputs are only available to the first stage.
+        """
         inputs = input_dict["input"]
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # For arguments, like attention_masks, we have to put them in a separate
         # dict as extra_inputs are not forwarded to other stages in PP, but
         # extra_kwargs are.
-        extra_kwargs = {}
+        extra_kwargs: dict[str, Any] = {}
 
         if getattr(self.model_args, "use_flex_attn", False):
-            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
             )
 
+        return inputs, labels, extra_inputs, extra_kwargs
+
+    def forward_backward_step(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
+        model_parts = self.model_parts
+        parallel_dims = self.parallel_dims
+
+        inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
+            input_dict, labels
+        )
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         optional_context_parallel_ctx = (
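For context on how the new hook is meant to be used, below is a minimal sketch (not part of this diff) of a downstream trainer overriding post_dataloading_process. It assumes the Trainer class defined in torchtitan/train.py; the subclass name MyTrainer and the extra "position_ids" input are hypothetical illustrations, while the hook's signature and the extra_inputs/extra_kwargs semantics come from the docstring above.

from typing import Any

import torch

from torchtitan.train import Trainer


class MyTrainer(Trainer):
    """Hypothetical subclass illustrating the post_dataloading_process hook."""

    def post_dataloading_process(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
        # Reuse the default behavior: split off "input", build attention masks, etc.
        inputs, labels, extra_inputs, extra_kwargs = super().post_dataloading_process(
            input_dict, labels
        )

        # Hypothetical customization: supply explicit position ids as an auxiliary
        # input. extra_inputs is only seen by the first pipeline stage, so it suits
        # inputs consumed at embedding time.
        extra_inputs["position_ids"] = (
            torch.arange(inputs.shape[1], device=inputs.device)
            .unsqueeze(0)
            .expand(inputs.shape[0], -1)
        )

        # Anything every pipeline stage needs (like the attention masks added by
        # the base class) belongs in extra_kwargs, which is forwarded to all stages.
        return inputs, labels, extra_inputs, extra_kwargs

Calling the base-class hook first mirrors the split this PR introduces: data preparation lives in post_dataloading_process, while the parallelism-aware forward/backward logic stays in forward_backward_step.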