Commit 8696f03

Merge pull request #28 from curt-tigges/cleanup-stage-06
added distributed utilities layer
2 parents: 2665913 + e79b391

File tree

8 files changed: +355, -64 lines

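The diffs below replace direct torch.distributed calls with wrappers imported as dist_ops from the new clt.parallel.ops module (the "distributed utilities layer" named in the merge message). That module is one of the 8 changed files but is not shown in this excerpt, so the following is only a minimal sketch inferred from the call sites visible below; the function bodies, the SUM alias, and the single-process fallbacks are assumptions, not the repository's actual code.

from typing import List, Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

# Assumed alias so call sites can write dist_ops.SUM instead of dist.ReduceOp.SUM.
SUM = dist.ReduceOp.SUM


def is_dist_initialized_and_available() -> bool:
    """True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_world_size(process_group: Optional[ProcessGroup] = None) -> int:
    """World size of the group, or 1 when running without distributed."""
    if not is_dist_initialized_and_available():
        return 1
    return dist.get_world_size(process_group)


def get_rank(process_group: Optional[ProcessGroup] = None) -> int:
    """Rank within the group, or 0 when running without distributed."""
    if not is_dist_initialized_and_available():
        return 0
    return dist.get_rank(process_group)


def broadcast(tensor: torch.Tensor, src: int = 0, group: Optional[ProcessGroup] = None) -> None:
    if get_world_size(group) > 1:
        dist.broadcast(tensor, src=src, group=group)


def all_reduce(tensor: torch.Tensor, op=SUM, group: Optional[ProcessGroup] = None) -> None:
    if get_world_size(group) > 1:
        dist.all_reduce(tensor, op=op, group=group)


def all_gather(tensor_list: List[torch.Tensor], tensor: torch.Tensor, group: Optional[ProcessGroup] = None) -> None:
    if get_world_size(group) > 1:
        dist.all_gather(tensor_list, tensor, group=group)

Whatever the real implementation looks like, the pattern in the hunks below is consistent: get_world_size/get_rank absorb the "no process group or not initialized" case, so most call sites can drop their per-site branching.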

clt/models/activations.py

Lines changed: 7 additions & 11 deletions
@@ -1,9 +1,9 @@
 import torch
 from typing import Optional, Tuple, Dict, List
-import torch.distributed as dist
 import logging
 from clt.config import CLTConfig
 from torch.distributed import ProcessGroup
+from clt.parallel import ops as dist_ops
 
 
 class BatchTopK(torch.autograd.Function):
@@ -234,9 +234,7 @@ def _apply_batch_topk_helper(
 ) -> Dict[int, torch.Tensor]:
     """Helper to apply BatchTopK globally across concatenated layer pre-activations."""
 
-    world_size = 1
-    if process_group is not None and dist.is_initialized():
-        world_size = dist.get_world_size(process_group)
+    world_size = dist_ops.get_world_size(process_group)
 
     if not preactivations_dict:
         logger_helpers.warning(f"Rank {rank}: _apply_batch_topk_helper received empty preactivations_dict.")
@@ -310,9 +308,9 @@ def _apply_batch_topk_helper(
                 concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
             )
             mask.copy_(local_mask)
-            dist.broadcast(mask, src=0, group=process_group)
+            dist_ops.broadcast(mask, src=0, group=process_group)
         else:
-            dist.broadcast(mask, src=0, group=process_group)
+            dist_ops.broadcast(mask, src=0, group=process_group)
     else:
         mask = BatchTopK._compute_mask(
             concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
@@ -340,9 +338,7 @@ def _apply_token_topk_helper(
     process_group: Optional[ProcessGroup],
 ) -> Dict[int, torch.Tensor]:
     """Helper to apply TokenTopK globally across concatenated layer pre-activations."""
-    world_size = 1
-    if process_group is not None and dist.is_initialized():
-        world_size = dist.get_world_size(process_group)
+    world_size = dist_ops.get_world_size(process_group)
 
     if not preactivations_dict:
         logger_helpers.warning(f"Rank {rank}: _apply_token_topk_helper received empty preactivations_dict.")
@@ -418,9 +414,9 @@ def _apply_token_topk_helper(
                 concatenated_preactivations_normalized,
             )
             mask.copy_(local_mask)
-            dist.broadcast(mask, src=0, group=process_group)
+            dist_ops.broadcast(mask, src=0, group=process_group)
         else:
-            dist.broadcast(mask, src=0, group=process_group)
+            dist_ops.broadcast(mask, src=0, group=process_group)
     else:
         mask = TokenTopK._compute_mask(
            concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
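
As an aside, the four broadcast call sites above all follow the same compute-on-rank-0-then-broadcast pattern. The snippet below is illustrative only; synced_topk_mask and its shapes are made up for this note, and the top-k step merely stands in for BatchTopK._compute_mask.

import torch
from clt.parallel import ops as dist_ops


def synced_topk_mask(preacts: torch.Tensor, k: int, process_group=None) -> torch.Tensor:
    # Rank 0 decides which entries survive; the other ranks receive its decision.
    mask = torch.zeros_like(preacts)
    if dist_ops.get_rank(process_group) == 0:
        topk_idx = preacts.topk(k, dim=-1).indices
        mask.scatter_(-1, topk_idx, 1.0)
    if dist_ops.get_world_size(process_group) > 1:
        dist_ops.broadcast(mask, src=0, group=process_group)
    return mask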

clt/models/clt.py

Lines changed: 3 additions & 8 deletions
@@ -1,7 +1,6 @@
 import torch
 from typing import Dict, Optional, Union, Tuple, List
 import logging
-import torch.distributed as dist
 
 from clt.config import CLTConfig
 from clt.models.base import BaseTranscoder
@@ -12,6 +11,7 @@
 from clt.models.theta import ThetaManager
 
 from clt.activations.registry import get_activation_fn
+from clt.parallel import ops as dist_ops
 
 from torch.distributed import ProcessGroup
 
@@ -34,13 +34,8 @@ def __init__(
     ):
         super().__init__(config)
         self.process_group = process_group
-        if process_group is None or not dist.is_initialized():
-            self.world_size = 1
-            self.rank = 0
-            self.process_group = None
-        else:
-            self.world_size = dist.get_world_size(process_group)
-            self.rank = dist.get_rank(process_group)
+        self.world_size = dist_ops.get_world_size(process_group)
+        self.rank = dist_ops.get_rank(process_group)
 
         self.dtype = self._resolve_dtype(config.clt_dtype)
         if device is not None:

clt/models/decoder.py

Lines changed: 7 additions & 7 deletions
@@ -2,11 +2,11 @@
 import torch.nn as nn
 from typing import Dict, Optional
 import logging
-import torch.distributed as dist
-from torch.distributed import ProcessGroup
 
 from clt.config import CLTConfig
 from clt.models.parallel import RowParallelLinear
+from clt.parallel import ops as dist_ops
+from torch.distributed import ProcessGroup
 
 logger = logging.getLogger(__name__)
 
@@ -31,12 +31,12 @@ def __init__(
         self.device = device
         self.dtype = dtype
 
-        if process_group is None or not dist.is_initialized():
+        if process_group is None or not dist_ops.is_dist_initialized_and_available():
             self.world_size = 1
             self.rank = 0
         else:
-            self.world_size = dist.get_world_size(process_group)
-            self.rank = dist.get_rank(process_group)
+            self.world_size = dist_ops.get_world_size(process_group)
+            self.rank = dist_ops.get_rank(process_group)
 
         self.decoders = nn.ModuleDict(
             {
@@ -175,8 +175,8 @@ def get_decoder_norms(self) -> torch.Tensor:
                     f"Valid norms shape {valid_norms_sq.shape}, expected size {actual_local_dim}."
                 )
 
-            if self.process_group is not None and dist.is_initialized():
-                dist.all_reduce(local_norms_sq_accum, op=dist.ReduceOp.SUM, group=self.process_group)
+            if self.process_group is not None and dist_ops.is_dist_initialized_and_available():
+                dist_ops.all_reduce(local_norms_sq_accum, op=dist_ops.SUM, group=self.process_group)
 
             full_decoder_norms[src_layer] = torch.sqrt(local_norms_sq_accum).to(self.dtype)
 
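One detail worth spelling out for the get_decoder_norms hunk: with tensor parallelism each rank holds only a shard of the decoder weights, so squared norms are summed across ranks before taking the square root. A standalone sketch of that reduction follows; the shard layout and the dim=0 reduction are assumptions for illustration, not the repository's code.

import torch
from clt.parallel import ops as dist_ops


def full_feature_norms(local_weight_shard: torch.Tensor, process_group=None) -> torch.Tensor:
    # Squared L2 norm of each feature's slice held on this rank.
    norms_sq = local_weight_shard.pow(2).sum(dim=0)
    if process_group is not None and dist_ops.is_dist_initialized_and_available():
        # Summing squares across ranks and then taking sqrt gives the norm of the full column.
        dist_ops.all_reduce(norms_sq, op=dist_ops.SUM, group=process_group)
    return torch.sqrt(norms_sq)
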
clt/models/encoder.py

Lines changed: 3 additions & 7 deletions
@@ -2,10 +2,10 @@
 import torch.nn as nn
 from typing import Dict, List, Tuple, Optional
 import logging
-import torch.distributed as dist
 
 from clt.config import CLTConfig
 from clt.models.parallel import ColumnParallelLinear
+from clt.parallel import ops as dist_ops
 from torch.distributed import ProcessGroup
 
 logger = logging.getLogger(__name__)
@@ -31,12 +31,8 @@ def __init__(
         self.device = device
         self.dtype = dtype
 
-        if process_group is None or not dist.is_initialized():
-            self.world_size = 1
-            self.rank = 0
-        else:
-            self.world_size = dist.get_world_size(process_group)
-            self.rank = dist.get_rank(process_group)
+        self.world_size = dist_ops.get_world_size(process_group)
+        self.rank = dist_ops.get_rank(process_group)
 
         self.encoders = nn.ModuleList(
             [

clt/models/parallel.py

Lines changed: 39 additions & 29 deletions
@@ -1,12 +1,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.distributed as dist
 from torch.distributed import ProcessGroup
 import math
 from typing import Callable, Optional, cast, Tuple
 
 from . import mark_replicated
+from clt.parallel import ops as dist_ops
 
 
 class _ParallelLinear(nn.Module):
@@ -32,14 +32,12 @@ def __init__(
         super().__init__()
         self.process_group = process_group
 
-        # Handle non-distributed case
-        if process_group is None or not dist.is_initialized():
-            self.world_size = 1
-            self.rank = 0
+        # Handle non-distributed case using new utility functions
+        self.world_size = dist_ops.get_world_size(process_group)
+        self.rank = dist_ops.get_rank(process_group)
+        # If world_size is 1, process_group should effectively be None for logic below
+        if self.world_size == 1:
             self.process_group = None
-        else:
-            self.world_size = dist.get_world_size(process_group)
-            self.rank = dist.get_rank(process_group)
 
         self.partition_dim = partition_dim
         self.input_is_parallel = input_is_parallel
@@ -108,15 +106,16 @@ class _Gather(torch.autograd.Function):
 
     @staticmethod
    def forward(ctx, input_: torch.Tensor, process_group: ProcessGroup, dim: int, full_dim_size: Optional[int]):
-        if process_group is None or not dist.is_initialized() or dist.get_world_size(process_group) == 1:
+        # Use new utility functions
+        if not dist_ops.is_dist_initialized_and_available() or dist_ops.get_world_size(process_group) == 1:
             ctx.dim = dim
             ctx.local_dim = input_.size(dim)
             ctx.full_dim_size = full_dim_size or input_.size(dim)
             ctx.process_group = None  # Mark non-distributed case
             return input_
 
-        world_size = dist.get_world_size(process_group)
-        rank = dist.get_rank(process_group)
+        world_size = dist_ops.get_world_size(process_group)
+        rank = dist_ops.get_rank(process_group)
 
         ctx.dim = dim
         ctx.local_dim = input_.size(dim)
@@ -131,8 +130,8 @@ def forward(ctx, input_: torch.Tensor, process_group: ProcessGroup, dim: int, fu
         # can track the dependency (no copy!).
         gathered[rank] = input_contig
 
-        # Perform the collective.
-        dist.all_gather(gathered, input_contig, group=process_group)
+        # Perform the collective using new utility function wrapper
+        dist_ops.all_gather(gathered, input_contig, group=process_group)
 
         output = torch.cat(gathered, dim=dim)
 
@@ -150,10 +149,15 @@ def backward(ctx, *grad_outputs):
         grad_output = grad_outputs[0]
 
         # Non-distributed: gradient flows straight through.
-        if ctx.process_group is None or not dist.is_initialized() or dist.get_world_size(ctx.process_group) == 1:
+        # Use new utility functions
+        if (
+            ctx.process_group is None
+            or not dist_ops.is_dist_initialized_and_available()
+            or dist_ops.get_world_size(ctx.process_group) == 1
+        ):
             return grad_output, None, None, None
 
-        rank = dist.get_rank(ctx.process_group)
+        rank = dist_ops.get_rank(ctx.process_group)
 
         # Compute start/end indices for this rank's slice along the gather dim.
         local_dim_padded = ctx.local_dim  # Already accounts for padding in weight shape.
@@ -179,25 +183,28 @@ class _Reduce(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input_: torch.Tensor, process_group: Optional[ProcessGroup]):
-        if process_group is None or not dist.is_initialized() or dist.get_world_size(process_group) == 1:
+        # Use new utility functions
+        if not dist_ops.is_dist_initialized_and_available() or dist_ops.get_world_size(process_group) == 1:
             ctx.process_group = None  # Mark non-distributed case
             return input_
 
         ctx.process_group = process_group
         input_contig = input_.contiguous()  # Ensure contiguous before collective
 
-        # Perform the all-reduce with SUM operation.
-        # The operation is in-place on input_contig if it's the same object for all_reduce's output internally,
-        # or if all_reduce returns a new tensor, that's what we return.
-        # For clarity, let's assume all_reduce modifies input_contig or we assign its result.
-        dist.all_reduce(input_contig, op=dist.ReduceOp.SUM, group=process_group)
+        # Perform the all-reduce with SUM operation using new utility function wrapper.
+        dist_ops.all_reduce(input_contig, op=dist_ops.SUM, group=process_group)
         # The tensor input_contig now holds the sum.
         return input_contig
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         # Non-distributed: gradient flows straight through.
-        if ctx.process_group is None or not dist.is_initialized() or dist.get_world_size(ctx.process_group) == 1:
+        # Use new utility functions
+        if (
+            ctx.process_group is None
+            or not dist_ops.is_dist_initialized_and_available()
+            or dist_ops.get_world_size(ctx.process_group) == 1
+        ):
            # Match the number of forward inputs in return for consistency
            return grad_output.contiguous() if grad_output is not None else None, None
 
@@ -220,10 +227,11 @@ def _reduce(input_, process_group):
     and broken optimisation. The caller can always divide afterwards if an average is
     truly desired, but for the core TP math we need the raw sum.
     """
-    if process_group is None or not dist.is_initialized():
+    # Use new utility functions
+    if not dist_ops.is_dist_initialized_and_available():
         return input_  # No-op if not distributed
 
-    world_size = dist.get_world_size(process_group)
+    world_size = dist_ops.get_world_size(process_group)
     if world_size == 1:
         return input_
 
@@ -239,14 +247,15 @@ def _split(input_, process_group, dim=-1):
     Assumes uniform padding, so each rank gets ceil(full_dim / world_size).
     Handles truncation for ranks that would exceed the original full dimension.
     """
-    if process_group is None or not dist.is_initialized():
+    # Use new utility functions
+    if not dist_ops.is_dist_initialized_and_available():
         return input_  # No-op if not distributed
 
-    world_size = dist.get_world_size(process_group)
+    world_size = dist_ops.get_world_size(process_group)
     if world_size == 1:
         return input_
 
-    rank = dist.get_rank(process_group)
+    rank = dist_ops.get_rank(process_group)
     full_dim_size = input_.size(dim)
 
     # Calculate the size of each slice (using ceil for uniform distribution)
@@ -402,10 +411,11 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:
 
         # Add bias *after* reduction
         if self.bias and self.bias_param is not None:
-            # Cast bias_param for type checker; runtime None already guarded.
+            # The runtime check `self.bias_param is not None` is the primary guard.
+            # Casting `self.bias_param` to `torch.Tensor` helps the type checker.
             reduced_output = reduced_output + cast(torch.Tensor, self.bias_param)
 
-        return cast(torch.Tensor, reduced_output)  # Cast to ensure Tensor type
+        return cast(torch.Tensor, reduced_output)
 
 
 # --------------------------- Public helper --------------------------- #
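
Because get_world_size/get_rank now absorb the non-distributed case, every helper in this file degrades to an identity operation in a single-process run. A small sanity check, assuming the fallback semantics these hunks encode and a process where torch.distributed has not been initialized:

import torch
from clt.models.parallel import _reduce, _split
from clt.parallel import ops as dist_ops

# Single-rank fallbacks implied by the hunks above.
assert dist_ops.get_world_size(None) == 1
assert dist_ops.get_rank(None) == 0

x = torch.randn(4, 8)
assert _reduce(x, process_group=None) is x   # no collective, input returned as-is
assert _split(x, process_group=None) is x    # no slicing without a distributed group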
