
Commit e37f83f

[Full DTensor][Reland] Add full_dtensor flag (#2013)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):
* __->__ #2013

When `full_dtensor` is True, the compute placements are preserved, which means `to_local()` is not called in the FSDP-only case. The nD parallelism case (FSDP + TP) errors out, since we have not implemented it yet. This argument does not affect the current simple_fsdp behavior. We have verified the `full_dtensor=True` case against the full DTensor skeleton PR, which will be published once it is ready.

**This is a reland of #2002; the previous PR was broken during a rebase.**
1 parent fddd9eb commit e37f83f
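For intuition, here is a minimal sketch of the behavior described above; it is not the torchtitan implementation, and the function name `replicate_for_compute` is illustrative. In the DP-only case, the redistributed parameter either keeps its DTensor wrapper (`full_dtensor=True`) or is unwrapped with `to_local()` as before.

```python
# Minimal sketch of the flag's effect in the DP-only case; NOT the torchtitan code.
# Assumes a recent PyTorch where DTensor lives in torch.distributed.tensor.
import torch
from torch.distributed.tensor import DTensor, Replicate


def replicate_for_compute(x: DTensor, full_dtensor: bool = False) -> torch.Tensor:
    # All-gather the sharded parameter to a replicated placement for compute.
    out = x.redistribute(placements=[Replicate()])
    if full_dtensor:
        # Preserve the compute placement: the caller keeps working with a DTensor.
        return out
    # Default behavior: unwrap to a plain local torch.Tensor.
    return out.to_local()
```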

File tree

1 file changed (+12, -1)

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 12 additions & 1 deletion
@@ -185,6 +185,7 @@ def __init__(
         mode: str,
         mp_policy: MixedPrecisionPolicy | None,
         reduction_divide_factor: float | None,
+        full_dtensor: bool = False,
     ) -> None:
         super().__init__()
         self.device_mesh = device_mesh
@@ -201,6 +202,7 @@ def __init__(
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype: torch.dtype | None = mp_policy.param_dtype
         self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype
+        self.full_dtensor = full_dtensor

     def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
@@ -210,6 +212,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor:
         non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim
         assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported"
         if non_dp_mesh_dims > 0:
+            if self.full_dtensor:
+                raise NotImplementedError(
+                    "full_dtensor not implemented for nD parallelisms"
+                )
             dp_mesh = self.device_mesh
             # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
             sharded_local_tensor = x.to_local()
@@ -245,7 +251,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor:
                 placements=self.compute_placements,
                 forward_dtype=self.param_dtype,
                 backward_dtype=self.reduce_dtype,
-            ).to_local(grad_placements=self.grad_placements)
+            )
+
+            if not self.full_dtensor:
+                output = output.to_local(grad_placements=self.grad_placements)
         else:
             raise AssertionError(
                 f"Unsupported replicate compute on placement {x._spec.placements} for DTensor {x}"
@@ -274,6 +283,7 @@ def data_parallel(
     mp_policy: MixedPrecisionPolicy | None = None,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
+    full_dtensor: bool = False,
 ) -> nn.Module:
     param_sharding: tuple[Placement, ...]
     if mode == "replicate":
@@ -333,6 +343,7 @@ def data_parallel(
                 mode,
                 mp_policy=mp_policy,
                 reduction_divide_factor=reduction_divide_factor,
+                full_dtensor=full_dtensor,
             ),
         )
     return model
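A possible call site for the new argument, sketched under assumptions: the positional `model` and `device_mesh` arguments and their order are inferred from the parameters visible in this diff rather than taken from the full `data_parallel` signature.

```python
# Hedged usage sketch; the model/device_mesh argument order is assumed.
from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel

model = data_parallel(
    model,               # nn.Module to parallelize (assumed positional argument)
    dp_mesh,             # DP-only DeviceMesh (assumed positional argument)
    mode="replicate",    # mode string checked in this file
    full_dtensor=True,   # keep DTensor outputs; to_local() is skipped
)
# Note: on an nD mesh (DP + TP/EP), full_dtensor=True raises NotImplementedError.
```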
