
Commit fddd9eb

[SimpleFSDP] Add typing to simple_fsdp.py (#2001)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #2002
* __->__ #2001

Add typing, credit to Claude.
1 parent 02990b0 commit fddd9eb

File tree

1 file changed (+14, -13 lines)


torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 14 additions & 13 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from collections.abc import Sequence
+from collections.abc import Generator, Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
 
@@ -27,7 +27,7 @@
 
 
 @contextmanager
-def disable_active_parametrization():
+def disable_active_parametrization() -> Generator[None, None, None]:
     global _active_parametrization
     try:
         _active_parametrization = False
@@ -180,27 +180,27 @@ def _register_parametrization(
 class ReplicateComputation(torch.nn.Module):
     def __init__(
         self,
-        device_mesh,
-        param_sharding,
-        mode,
-        mp_policy,
-        reduction_divide_factor,
-    ):
+        device_mesh: DeviceMesh,
+        param_sharding: tuple[Placement, ...],
+        mode: str,
+        mp_policy: MixedPrecisionPolicy | None,
+        reduction_divide_factor: float | None,
+    ) -> None:
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
-        self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [
+        self.compute_placements: list[Placement] = [Replicate()] * self.device_mesh.ndim
+        self.grad_placements: list[Placement] = [
             _ScaledPartial(
                 reduction_divide_factor=reduction_divide_factor,
             )
             if reduction_divide_factor is not None
             else Partial(reduce_op="avg")
         ] * self.device_mesh.ndim
         mp_policy = mp_policy or MixedPrecisionPolicy()
-        self.param_dtype = mp_policy.param_dtype
-        self.reduce_dtype = mp_policy.reduce_dtype
+        self.param_dtype: torch.dtype | None = mp_policy.param_dtype
+        self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype
 
     def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
@@ -274,7 +274,8 @@ def data_parallel(
     mp_policy: MixedPrecisionPolicy | None = None,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
-):
+) -> nn.Module:
+    param_sharding: tuple[Placement, ...]
    if mode == "replicate":
        param_sharding = (Replicate(),)
    elif mode == "fully_shard":
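
For readers unfamiliar with the annotation added in the first two hunks: a generator-based context manager that yields nothing and returns nothing is typed as Generator[None, None, None], i.e. (yield type, send type, return type). The sketch below is not from torchtitan; the names disable_flag and _active_flag are hypothetical stand-ins that only mirror the shape of disable_active_parametrization in the diff above.

# Minimal standalone sketch (hypothetical names, not the torchtitan code)
# showing why a @contextmanager function is annotated Generator[None, None, None].
from collections.abc import Generator
from contextlib import contextmanager

_active_flag = True  # stand-in for a module-level flag like _active_parametrization


@contextmanager
def disable_flag() -> Generator[None, None, None]:
    """Temporarily set the flag to False, restoring it on exit."""
    global _active_flag
    try:
        _active_flag = False
        yield  # yields None exactly once, hence Generator[None, None, None]
    finally:
        _active_flag = True


if __name__ == "__main__":
    with disable_flag():
        assert _active_flag is False
    assert _active_flag is True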

0 commit comments
