Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,361 changes: 1,361 additions & 0 deletions make_fx_graphs/graph_pp_rank_2_of_4.py

Large diffs are not rendered by default.

1,361 changes: 1,361 additions & 0 deletions make_fx_graphs/graph_pp_rank_3_of_4.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions run_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ set -ex
# use envs as local overwrites for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_train.sh
NGPU=${NGPU:-"8"}
export LOG_RANK=${LOG_RANK:-0}
# NGPU=${NGPU:-"8"}
NGPU=${NGPU:-"4"}
# export LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
export LOG_RANK=${LOG_RANK:-0,1,2,3}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}

Expand Down
4 changes: 2 additions & 2 deletions torchtitan/models/deepseek_v3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
deepseekv3_configs = {
"debugmodel": DeepSeekV3ModelArgs(
vocab_size=2000,
dim=256,
dim=32,
inter_dim=1024,
moe_inter_dim=256,
n_layers=6,
n_layers=16,
n_dense_layers=1,
n_heads=16,
moe_args=MoEArgs(
Expand Down
19 changes: 10 additions & 9 deletions torchtitan/models/deepseek_v3/train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ description = "DeepSeek-V3 debug training"
print_args = false

[profiling]
enable_profiling = false
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
profile_freq = 5
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

Expand Down Expand Up @@ -36,22 +36,23 @@ decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
local_batch_size = 10
seq_len = 16
max_norm = 1.0 # grad norm clipping
steps = 10
steps = 6
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
# dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
pipeline_parallel_schedule = "1F1B"
pipeline_parallel_degree = 2
expert_parallel_degree = 2
context_parallel_degree = 1
expert_parallel_degree = 1
pipeline_parallel_schedule = "DualPipeV"
expert_tensor_parallel_degree = 1

[checkpoint]
Expand All @@ -63,7 +64,7 @@ export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
mode = "none" # ["none", "selective", "full"]
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]
Expand Down
195 changes: 185 additions & 10 deletions torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@
from typing import Any, Generator, Iterable, Optional

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.protocols.train_spec as train_spec_module
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.pipelining.schedules import (
_Action,
_PipelineContext,
_PipelineScheduleRuntime,
_PipelineStageBase,
_wait_batch_p2p,
FORWARD,
OVERLAP_F_B,
)
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.dataloader import DataloaderExhaustedError
from torchtitan.components.ft import FTManager, maybe_semi_sync_training
Expand Down Expand Up @@ -431,6 +440,11 @@ def forward_backward_step(
else None
)

# register custom functions
assert isinstance(self.pp_schedule, _PipelineScheduleRuntime)
# self.pp_schedule.register_custom_function(FORWARD, forward_callback)
self.pp_schedule.register_custom_function(OVERLAP_F_B, overlap_callback)

if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
with self.train_context(optional_context_parallel_ctx):
Expand Down Expand Up @@ -485,15 +499,17 @@ def train_step(
loss = self.forward_backward_step(input_dict, labels)
accumulated_losses.append(loss.detach())

grad_norm = dist_utils.clip_grad_norm_(
[p for m in self.model_parts for p in m.parameters()],
self.job_config.training.max_norm,
foreach=True,
pp_mesh=(
parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
),
ep_enabled=parallel_dims.ep_enabled,
)
# TODO: parameters are not DTensors which im not sure why
# grad_norm = dist_utils.clip_grad_norm_(
# [p for m in self.model_parts for p in m.parameters()],
# self.job_config.training.max_norm,
# foreach=True,
# pp_mesh=(
# parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
# ),
# ep_enabled=parallel_dims.ep_enabled,
# )
grad_norm = torch.tensor([0.0], device=self.device)
self.checkpointer.maybe_wait_for_staging()
self.optimizers.step()
self.lr_schedulers.step()
Expand Down Expand Up @@ -637,6 +653,165 @@ def close(self) -> None:
self.metrics_processor.close()


from torch.fx.experimental.proxy_tensor import make_fx


def overlap_callback(action: _Action, ctx: _PipelineContext):
"""Custom callback for OVERLAP_F_B computation that mimics the original implementation."""
schedule = ctx.schedule_ref
assert isinstance(schedule, _PipelineScheduleRuntime)
stage_index_to_stage: dict[int, _PipelineStageBase] = {
stage.stage_index: stage for stage in schedule._stages
}
assert action.sub_actions is not None
fwd_action = action.sub_actions[0]
bwd_action = action.sub_actions[1]

# Get stages
forward_stage_index = fwd_action.stage_index
forward_mb_index = fwd_action.microbatch_index
assert forward_mb_index is not None
backward_stage_index = bwd_action.stage_index
backward_stage = stage_index_to_stage[backward_stage_index]

# Forward setup
arg_mbs = ctx.arg_mbs
kwarg_mbs = ctx.kwarg_mbs
fwd_recv_ops = schedule.fwd_recv_ops
forward_stage = stage_index_to_stage[forward_stage_index]
forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage
forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage

# Backward setup
backward_is_next_stage_on_this_rank = (
backward_stage.stage_index + 1 in stage_index_to_stage
)
backward_is_prev_stage_on_this_rank = (
backward_stage.stage_index - 1 in stage_index_to_stage
)
backward_mb_index = bwd_action.microbatch_index
assert backward_mb_index is not None
bwd_recv_ops = schedule.bwd_recv_ops

# PP communication ========================================================

# Fwd receives
if (
not forward_stage.is_first
# no recv op expected for V-schedule special case (see [Note: V-schedule special case])
and not forward_is_prev_stage_on_this_rank
):
assert (
forward_stage_index,
forward_mb_index,
) in fwd_recv_ops, f"Computing {action=} before receiving input"
_wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index)))

# Bwd receives
if (
not backward_stage.is_last
# no recv op expected for V-schedule special case (see [Note: V-schedule special case])
and not backward_is_next_stage_on_this_rank
):
assert (
backward_stage_index,
backward_mb_index,
) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input"
_wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index)))

def fwd_bwd_joint_graph(forward_mb_index, arg_mb, kwargs_mb):
# forward_stage.forward_one_chunk(
# forward_mb_index,
# arg_mb,
# kwargs_mb,
# )
if forward_stage.is_first:
# First stage doesn't need to receive anything
composite_args = arg_mb
else:
# Receive activations for this chunk
# Activations only come in args form
composite_args = forward_stage._retrieve_recv_activations(forward_mb_index)

out = forward_stage.submod(*composite_args, **kwargs_mb)
output_grads = [
torch.ones_like(out)
]
grad_input = torch.autograd.grad(out, composite_args[0], output_grads)
return out

graph = make_fx(fwd_bwd_joint_graph)(
forward_mb_index, arg_mbs[forward_mb_index], {}
)
# Save graph code to file for inspection by PP rank
os.makedirs("make_fx_graphs", exist_ok=True)
# Get rank and world_size
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
current_step = getattr(ctx.schedule_ref, 'step', 'unknown')
filename = f"make_fx_graphs/graph_pp_rank_{rank}_of_{world_size}.py"
with open(filename, "w") as f:
f.write(graph.code)
print(f"PP Rank {rank}: Graph code saved to {filename}")

# PP computation ========================================================
def forward_backward_overlapped():
# Forward ========================================================
output = forward_stage.forward_one_chunk(
forward_mb_index,
arg_mbs[forward_mb_index], # type: ignore[index]
kwarg_mbs[forward_mb_index], # type: ignore[index]
)
schedule._maybe_compute_loss(
forward_stage, output, ctx.target_mbs, forward_mb_index
)

# SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
# see [Note: V-schedule special case]
if forward_is_next_stage_on_this_rank:
stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
output, forward_mb_index
)

# Backward ========================================================
loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
schedule.backward_counter[backward_stage_index] += 1
last_backward = (
schedule.backward_counter[backward_stage_index] == schedule._n_microbatches
)
grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
backward_stage.backward_one_chunk(
backward_mb_index,
loss=loss,
full_backward=True,
last_backward=last_backward,
)
if last_backward:
backward_stage.scale_grads(grad_scale_factor)
# SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
# see [Note: V-schedule special case]
if backward_is_prev_stage_on_this_rank:
stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
backward_stage.get_local_bwd_output(backward_mb_index),
backward_mb_index,
)

# Compiled execution
# aot_backend = create_fw_bw_capture_backend()

# compiled_fn = torch.compile(
# forward_backward_overlapped,
# backend=aot_backend
# )
# compiled_fn()

# regular execution
# forward_backward_overlapped()


import fbvscode

fbvscode.attach_debugger()
if __name__ == "__main__":
init_logger()
config_manager = ConfigManager()
Expand Down
Loading