Commit 29eb910

[simplefsdp] add reshard after forward option (#1961)
As titled, this PR adds a zero2-style FSDP sharding option to SimpleFSDP. It can be enabled with the `--parallelism.simple_fsdp_reshard_after_forward "never"` config.

With `--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --parallelism.simple_fsdp_reshard_after_forward "always"`, there is an all-gather in the backward pass to re-gather parameters (trace: [link](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-29-20-02-34_rank0_trace.json)).

![Screenshot 2025-10-29 at 8 03 05 PM: backward trace with all-gather](https://github.com/user-attachments/assets/a6e5b736-9d1f-44a2-aa35-af7b315fd24a)

With `--parallelism.simple_fsdp_reshard_after_forward "never"`, there is no all-gather in the backward pass (trace: [link](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-29-20-04-49_rank0_trace.json)).

![Screenshot 2025-10-29 at 8 05 07 PM: backward trace without all-gather](https://github.com/user-attachments/assets/022f1a02-fe45-45d6-ba04-6bcaf2e9d64f)
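For orientation (not part of the commit), the following is a minimal, self-contained sketch of the mechanism this option toggles, based on the `simple_fsdp.py` diff below: SimpleFSDP implements reshard-after-forward by running the un-shard-and-compute step under activation checkpointing, so backward re-runs it and re-gathers the parameters; with the zero2-style `"never"` setting that wrapper is skipped, so backward needs no all-gather. Function names here are illustrative, not torchtitan APIs.

```python
# Illustrative sketch only (not torchtitan code): models how SimpleFSDP uses
# activation checkpointing to implement reshard_after_forward. The "gather"
# step stands in for the parameter all-gather; under checkpointing its output
# is freed after forward and recomputed in backward (the backward all-gather
# seen in the "always" trace). With reshard_after_forward=False (zero2-style),
# the gathered result stays alive and backward needs no re-gather.
import torch
from torch.utils.checkpoint import checkpoint


def gather_and_compute(x: torch.Tensor, sharded_w: torch.Tensor) -> torch.Tensor:
    # stand-in for: all-gather sharded_w into the full parameter, then compute
    full_w = sharded_w * 1.0  # pretend this is the gathered parameter
    return x @ full_w


def fsdp_like_forward(x, sharded_w, reshard_after_forward: bool):
    if reshard_after_forward:
        # zero3-style: free the gathered parameter after forward; backward
        # recomputes gather_and_compute, i.e. re-gathers the parameter
        return checkpoint(gather_and_compute, x, sharded_w, use_reentrant=False)
    # zero2-style ("never"): keep the gathered parameter for backward
    return gather_and_compute(x, sharded_w)


if __name__ == "__main__":
    x = torch.randn(4, 8)
    w = torch.randn(8, 8, requires_grad=True)
    fsdp_like_forward(x, w, reshard_after_forward=True).sum().backward()
    print(w.grad.shape)  # torch.Size([8, 8])
```

In the real code this corresponds to the `checkpoint(self.replicate_compute, ...)` call in `simple_fsdp.py`, whose guarding condition now also checks `self.reshard_after_forward`.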
1 parent a3e170c commit 29eb910

File tree

4 files changed: +43 −4 lines changed

torchtitan/experiments/simple_fsdp/README.md
torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
torchtitan/experiments/simple_fsdp/llama3/parallelize.py
torchtitan/experiments/simple_fsdp/simple_fsdp.py


torchtitan/experiments/simple_fsdp/README.md
Lines changed: 3 additions & 3 deletions

````diff
@@ -15,13 +15,13 @@ This folder includes an experimental frontend implementation for [SimpleFSDP: Si
 #### Training Llama3 models
 
 ```bash
-CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name simple_fsdp.llama3 --compile.enable
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name simple_fsdp.llama3 --compile.enable --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config
 ```
 
 #### Training DeepSeek_v3 models
 
 ```bash
-CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name simple_fsdp.deepseek_v3 --compile.enable
+CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name simple_fsdp.deepseek_v3 --compile.enable --activation_checkpoint.mode "none" --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config
 ```
 
 ### Composability Support
@@ -56,7 +56,7 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing
 users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:
 
 ```bash
---job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_autobucketing"
+--compile.model_backend_override "aot_eager_autobucketing"
 ```
 
 ### Citation
````

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
Lines changed: 16 additions & 0 deletions

```diff
@@ -91,6 +91,20 @@ def parallelize_deepseekv3(
         reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
     )
 
+    match job_config.parallelism.fsdp_reshard_after_forward:
+        case "always":
+            reshard_after_forward = True
+        case "never":
+            reshard_after_forward = False
+        case "default":
+            # For PP, by default do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and non-overlapped
+            reshard_after_forward = not parallel_dims.pp_enabled
+        case _:
+            raise ValueError(
+                f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
+            )
+
     # apply data parallel
     dp_mesh: DeviceMesh | None = None
     if (
@@ -143,6 +157,7 @@ def parallelize_deepseekv3(
                     dp_mode,
                     ac_mode=job_config.activation_checkpoint.mode,
                     mp_policy=mp_policy,
+                    reshard_after_forward=reshard_after_forward,
                     shard_dim=experts_shard_dim,
                     reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
                 )
@@ -153,6 +168,7 @@ def parallelize_deepseekv3(
             dp_mode,
            ac_mode=job_config.activation_checkpoint.mode,
             mp_policy=mp_policy,
+            reshard_after_forward=reshard_after_forward,
         )
 
         logger.info(
```

torchtitan/experiments/simple_fsdp/llama3/parallelize.py
Lines changed: 15 additions & 0 deletions

```diff
@@ -112,12 +112,27 @@ def parallelize_llama(
         reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
     )
 
+    match job_config.parallelism.fsdp_reshard_after_forward:
+        case "always":
+            reshard_after_forward = True
+        case "never":
+            reshard_after_forward = False
+        case "default":
+            # For PP, by default do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and non-overlapped
+            reshard_after_forward = not parallel_dims.pp_enabled
+        case _:
+            raise ValueError(
+                f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
+            )
+
     model = data_parallel(
         model,
         parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
         mode=dp_mode,
         ac_mode=job_config.activation_checkpoint.mode,
         mp_policy=mp_policy,
+        reshard_after_forward=reshard_after_forward,
     )
     logger.info(
         "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
```

torchtitan/experiments/simple_fsdp/simple_fsdp.py
Lines changed: 9 additions & 1 deletion

```diff
@@ -210,6 +210,7 @@ def __init__(
         mode,
         regional_ac,
         mp_policy,
+        reshard_after_forward,
         reduction_divide_factor,
     ):
         super().__init__()
@@ -228,6 +229,7 @@ def __init__(
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
+        self.reshard_after_forward = reshard_after_forward
 
     def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
@@ -290,7 +292,11 @@ def forward(self, x: DTensor) -> torch.Tensor:
         if not _active_parametrization:
             return x
 
-        if self.regional_ac and self.mode in ("fully_shard", "hybrid_shard"):
+        if (
+            self.regional_ac
+            and self.mode in ("fully_shard", "hybrid_shard")
+            and self.reshard_after_forward
+        ):
             # apply checkpointing to implement reshard_after_forward
             output = checkpoint(
                 self.replicate_compute,
@@ -310,6 +316,7 @@ def data_parallel(
     mode: str = "replicate",
     ac_mode: str = "none",
     mp_policy: MixedPrecisionPolicy | None = None,
+    reshard_after_forward: bool = True,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
 ):
@@ -374,6 +381,7 @@ def data_parallel(
             mode,
             regional_ac,
             mp_policy=mp_policy,
+            reshard_after_forward=reshard_after_forward,
             reduction_divide_factor=reduction_divide_factor,
         ),
     )
```
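To show how the new parameter is consumed end to end, here is a hedged usage sketch mirroring the `llama3/parallelize.py` hunk above; the import path is inferred from this commit's file layout, and `apply_simple_fsdp` is an illustrative wrapper, not part of the commit.

```python
# Sketch only: mirrors the call pattern in llama3/parallelize.py above.
# Assumptions: the import path below follows this commit's file layout, and
# apply_simple_fsdp is an illustrative helper, not a torchtitan API.
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel


def apply_simple_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    dp_mode: str = "fully_shard",
    ac_mode: str = "none",
    reshard_after_forward: bool = True,  # False gives the zero2-style behavior added here
) -> nn.Module:
    return data_parallel(
        model,
        dp_mesh,
        mode=dp_mode,
        ac_mode=ac_mode,
        mp_policy=None,  # falls back to a default MixedPrecisionPolicy (see the __init__ hunk above)
        reshard_after_forward=reshard_after_forward,
    )
```

Passing `reshard_after_forward=False` here is what the `"never"` policy resolves to in the match statements above, giving the zero2-style behavior with no backward all-gather.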
