@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from collections.abc import Sequence
+from collections.abc import Generator, Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
 
@@ -27,7 +27,7 @@
 
 
 @contextmanager
-def disable_active_parametrization():
+def disable_active_parametrization() -> Generator[None, None, None]:
     global _active_parametrization
     try:
         _active_parametrization = False
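This is why `Generator` is added to the `collections.abc` import above: a `@contextmanager` function is a generator that yields exactly once, so its return type is spelled `Generator[YieldType, SendType, ReturnType]`, here `Generator[None, None, None]` (`Iterator[None]` is an accepted shorthand). A minimal, standalone sketch with hypothetical names, not the repo's code:

```python
from collections.abc import Generator
from contextlib import contextmanager

_active = True  # hypothetical module-level flag, for illustration only


@contextmanager
def temporarily_disabled() -> Generator[None, None, None]:
    # Yields None, accepts nothing via send(), and returns None,
    # hence Generator[None, None, None].
    global _active
    try:
        _active = False
        yield
    finally:
        _active = True
```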
@@ -180,27 +180,27 @@ def _register_parametrization(
 class ReplicateComputation(torch.nn.Module):
     def __init__(
         self,
-        device_mesh,
-        param_sharding,
-        mode,
-        mp_policy,
-        reduction_divide_factor,
-    ):
+        device_mesh: DeviceMesh,
+        param_sharding: tuple[Placement, ...],
+        mode: str,
+        mp_policy: MixedPrecisionPolicy | None,
+        reduction_divide_factor: float | None,
+    ) -> None:
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
-        self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [
+        self.compute_placements: list[Placement] = [Replicate()] * self.device_mesh.ndim
+        self.grad_placements: list[Placement] = [
             _ScaledPartial(
                 reduction_divide_factor=reduction_divide_factor,
             )
             if reduction_divide_factor is not None
             else Partial(reduce_op="avg")
         ] * self.device_mesh.ndim
         mp_policy = mp_policy or MixedPrecisionPolicy()
-        self.param_dtype = mp_policy.param_dtype
-        self.reduce_dtype = mp_policy.reduce_dtype
+        self.param_dtype: torch.dtype | None = mp_policy.param_dtype
+        self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype
 
     def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
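One reason the explicit `list[Placement]` annotations help is that, without them, a checker infers the narrower element type from the literal (for example `list[Replicate]`), and since `list` is invariant that pins the attribute to the narrower type and makes later reassignment with other placements a type error. A minimal sketch of the pattern, assuming the public `torch.distributed.tensor` placement types (the variable names are illustrative, not the repo's code):

```python
from torch.distributed.tensor import Partial, Placement, Replicate

mesh_ndim = 2  # illustrative stand-in for device_mesh.ndim
reduction_divide_factor: float | None = None

# Annotating with the Placement base class keeps both branches of the
# conditional expression (and any later reassignment) type-compatible.
compute_placements: list[Placement] = [Replicate()] * mesh_ndim
grad_placements: list[Placement] = [
    Partial(reduce_op="sum")
    if reduction_divide_factor is not None
    else Partial(reduce_op="avg")
] * mesh_ndim
```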
@@ -274,7 +274,8 @@ def data_parallel(
     mp_policy: MixedPrecisionPolicy | None = None,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
-):
+) -> nn.Module:
+    param_sharding: tuple[Placement, ...]
     if mode == "replicate":
         param_sharding = (Replicate(),)
     elif mode == "fully_shard":
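The bare `param_sharding: tuple[Placement, ...]` declaration added before the `if`/`elif` chain gives the checker a single declared type for every branch; without it, mypy infers the type from the first assignment (`tuple[Replicate]`) and flags the other branches as incompatible assignments. A minimal sketch of the pattern with a hypothetical helper (`pick_sharding` is illustrative, not part of the repo):

```python
from torch.distributed.tensor import Placement, Replicate, Shard


def pick_sharding(mode: str) -> tuple[Placement, ...]:
    # Declare the type once so every branch assigns into tuple[Placement, ...]
    # instead of mypy pinning the variable to tuple[Replicate] from the first branch.
    param_sharding: tuple[Placement, ...]
    if mode == "replicate":
        param_sharding = (Replicate(),)
    elif mode == "fully_shard":
        param_sharding = (Shard(0),)
    else:
        raise ValueError(f"unknown data parallel mode: {mode}")
    return param_sharding
```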