 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.distributed as dist
 from torch.distributed import ProcessGroup
 import math
 from typing import Callable, Optional, cast, Tuple

 from . import mark_replicated
+from clt.parallel import ops as dist_ops


 class _ParallelLinear(nn.Module):
@@ -32,14 +32,12 @@ def __init__(
         super().__init__()
         self.process_group = process_group

-        # Handle non-distributed case
-        if process_group is None or not dist.is_initialized():
-            self.world_size = 1
-            self.rank = 0
+        # Handle non-distributed case using new utility functions
+        self.world_size = dist_ops.get_world_size(process_group)
+        self.rank = dist_ops.get_rank(process_group)
+        # If world_size is 1, process_group should effectively be None for logic below
+        if self.world_size == 1:
             self.process_group = None
-        else:
-            self.world_size = dist.get_world_size(process_group)
-            self.rank = dist.get_rank(process_group)

         self.partition_dim = partition_dim
         self.input_is_parallel = input_is_parallel
@@ -108,15 +106,16 @@ class _Gather(torch.autograd.Function):

     @staticmethod
     def forward(ctx, input_: torch.Tensor, process_group: ProcessGroup, dim: int, full_dim_size: Optional[int]):
-        if process_group is None or not dist.is_initialized() or dist.get_world_size(process_group) == 1:
+        # Use new utility functions
+        if not dist_ops.is_dist_initialized_and_available() or dist_ops.get_world_size(process_group) == 1:
             ctx.dim = dim
             ctx.local_dim = input_.size(dim)
             ctx.full_dim_size = full_dim_size or input_.size(dim)
             ctx.process_group = None  # Mark non-distributed case
             return input_

-        world_size = dist.get_world_size(process_group)
-        rank = dist.get_rank(process_group)
+        world_size = dist_ops.get_world_size(process_group)
+        rank = dist_ops.get_rank(process_group)

         ctx.dim = dim
         ctx.local_dim = input_.size(dim)
@@ -131,8 +130,8 @@ def forward(ctx, input_: torch.Tensor, process_group: ProcessGroup, dim: int, fu
         # can track the dependency (no copy!).
         gathered[rank] = input_contig

-        # Perform the collective.
-        dist.all_gather(gathered, input_contig, group=process_group)
+        # Perform the collective using new utility function wrapper
+        dist_ops.all_gather(gathered, input_contig, group=process_group)

         output = torch.cat(gathered, dim=dim)

@@ -150,10 +149,15 @@ def backward(ctx, *grad_outputs):
         grad_output = grad_outputs[0]

         # Non-distributed: gradient flows straight through.
-        if ctx.process_group is None or not dist.is_initialized() or dist.get_world_size(ctx.process_group) == 1:
+        # Use new utility functions
+        if (
+            ctx.process_group is None
+            or not dist_ops.is_dist_initialized_and_available()
+            or dist_ops.get_world_size(ctx.process_group) == 1
+        ):
             return grad_output, None, None, None

-        rank = dist.get_rank(ctx.process_group)
+        rank = dist_ops.get_rank(ctx.process_group)

         # Compute start/end indices for this rank's slice along the gather dim.
         local_dim_padded = ctx.local_dim  # Already accounts for padding in weight shape.
@@ -179,25 +183,28 @@ class _Reduce(torch.autograd.Function):

     @staticmethod
     def forward(ctx, input_: torch.Tensor, process_group: Optional[ProcessGroup]):
-        if process_group is None or not dist.is_initialized() or dist.get_world_size(process_group) == 1:
+        # Use new utility functions
+        if not dist_ops.is_dist_initialized_and_available() or dist_ops.get_world_size(process_group) == 1:
             ctx.process_group = None  # Mark non-distributed case
             return input_

         ctx.process_group = process_group
         input_contig = input_.contiguous()  # Ensure contiguous before collective

-        # Perform the all-reduce with SUM operation.
-        # The operation is in-place on input_contig if it's the same object for all_reduce's output internally,
-        # or if all_reduce returns a new tensor, that's what we return.
-        # For clarity, let's assume all_reduce modifies input_contig or we assign its result.
-        dist.all_reduce(input_contig, op=dist.ReduceOp.SUM, group=process_group)
+        # Perform the all-reduce with SUM operation using new utility function wrapper.
+        dist_ops.all_reduce(input_contig, op=dist_ops.SUM, group=process_group)
         # The tensor input_contig now holds the sum.
         return input_contig

     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         # Non-distributed: gradient flows straight through.
-        if ctx.process_group is None or not dist.is_initialized() or dist.get_world_size(ctx.process_group) == 1:
+        # Use new utility functions
+        if (
+            ctx.process_group is None
+            or not dist_ops.is_dist_initialized_and_available()
+            or dist_ops.get_world_size(ctx.process_group) == 1
+        ):
             # Match the number of forward inputs in return for consistency
             return grad_output.contiguous() if grad_output is not None else None, None

@@ -220,10 +227,11 @@ def _reduce(input_, process_group):
    and broken optimisation. The caller can always divide afterwards if an average is
    truly desired, but for the core TP math we need the raw sum.
    """
-    if process_group is None or not dist.is_initialized():
+    # Use new utility functions
+    if not dist_ops.is_dist_initialized_and_available():
        return input_  # No-op if not distributed

-    world_size = dist.get_world_size(process_group)
+    world_size = dist_ops.get_world_size(process_group)
     if world_size == 1:
        return input_

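The docstring in the hunk above insists on returning the raw SUM rather than an average: in a row-parallel layer each rank holds only a partial matmul result, and the exact full output is the sum of those partials, so dividing by the world size would silently rescale activations and gradients. If a mean is ever wanted, the caller can divide after the collective. A minimal illustrative sketch using the `_reduce` and `dist_ops` names from this diff; the helper `_reduce_mean` is hypothetical and not part of the change:

def _reduce_mean(input_, process_group):
    # Hypothetical helper, for illustration only: average on top of the raw sum.
    summed = _reduce(input_, process_group)  # raw SUM across ranks
    return summed / dist_ops.get_world_size(process_group)  # divide only when a mean is truly wanted
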
@@ -239,14 +247,15 @@ def _split(input_, process_group, dim=-1):
    Assumes uniform padding, so each rank gets ceil(full_dim / world_size).
    Handles truncation for ranks that would exceed the original full dimension.
    """
-    if process_group is None or not dist.is_initialized():
+    # Use new utility functions
+    if not dist_ops.is_dist_initialized_and_available():
        return input_  # No-op if not distributed

-    world_size = dist.get_world_size(process_group)
+    world_size = dist_ops.get_world_size(process_group)
     if world_size == 1:
        return input_

-    rank = dist.get_rank(process_group)
+    rank = dist_ops.get_rank(process_group)
     full_dim_size = input_.size(dim)

     # Calculate the size of each slice (using ceil for uniform distribution)
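The docstring in the hunk above describes how `_split` slices the full tensor: every rank is assigned ceil(full_dim / world_size) elements, and ranks whose slice would run past the true dimension are truncated, possibly down to an empty slice. A standalone sketch of that index arithmetic, for illustration only; the helper `_slice_bounds` is hypothetical and not the implementation in this file:

import math
from typing import Tuple

def _slice_bounds(full_dim_size: int, world_size: int, rank: int) -> Tuple[int, int]:
    # Uniform (padded) slice size, as described in the docstring above.
    slice_size = math.ceil(full_dim_size / world_size)
    start = min(rank * slice_size, full_dim_size)  # ranks past the end get an empty slice
    end = min(start + slice_size, full_dim_size)   # truncate at the true dimension size
    return start, end

# Example: full_dim_size=10, world_size=4 -> slice_size=3; bounds per rank: (0, 3), (3, 6), (6, 9), (9, 10).
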
@@ -402,10 +411,11 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:

         # Add bias *after* reduction
         if self.bias and self.bias_param is not None:
-            # Cast bias_param for type checker; runtime None already guarded.
+            # The runtime check `self.bias_param is not None` is the primary guard.
+            # Casting `self.bias_param` to `torch.Tensor` helps the type checker.
             reduced_output = reduced_output + cast(torch.Tensor, self.bias_param)

-        return cast(torch.Tensor, reduced_output)  # Cast to ensure Tensor type
+        return cast(torch.Tensor, reduced_output)


 # --------------------------- Public helper --------------------------- #
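The `clt.parallel.ops` module that this commit switches to is not shown in the diff. Judging only from the call sites above (`get_world_size`, `get_rank`, `is_dist_initialized_and_available`, `all_gather`, `all_reduce`, and the `SUM` constant), a minimal sketch of such a wrapper might look like the following; this is an assumption for illustration and the actual module may differ:

# Hypothetical sketch of clt/parallel/ops.py, inferred from the call sites in this diff.
from typing import List, Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

# Re-exported so callers never have to import torch.distributed directly.
SUM = dist.ReduceOp.SUM


def is_dist_initialized_and_available() -> bool:
    return dist.is_available() and dist.is_initialized()


def get_world_size(process_group: Optional[ProcessGroup] = None) -> int:
    if not is_dist_initialized_and_available() or process_group is None:
        return 1  # single-process fallback
    return dist.get_world_size(process_group)


def get_rank(process_group: Optional[ProcessGroup] = None) -> int:
    if not is_dist_initialized_and_available() or process_group is None:
        return 0  # single-process fallback
    return dist.get_rank(process_group)


def all_gather(tensor_list: List[torch.Tensor], tensor: torch.Tensor, group: Optional[ProcessGroup] = None) -> None:
    if get_world_size(group) == 1:
        tensor_list[0] = tensor  # nothing to gather in the single-process case
        return
    dist.all_gather(tensor_list, tensor, group=group)


def all_reduce(tensor: torch.Tensor, op=SUM, group: Optional[ProcessGroup] = None) -> None:
    if get_world_size(group) == 1:
        return  # no-op in the single-process case
    dist.all_reduce(tensor, op=op, group=group)

With a wrapper along these lines, the `world_size == 1` fallback in `__init__` and the early returns in `_Gather`, `_Reduce`, `_reduce`, and `_split` all collapse to no-ops on a single process, which is what the replaced `dist.is_initialized()` checks previously guaranteed.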