18 changes: 7 additions & 11 deletions clt/models/activations.py
@@ -1,9 +1,9 @@
import torch
from typing import Optional, Tuple, Dict, List
import torch.distributed as dist
import logging
from clt.config import CLTConfig
from torch.distributed import ProcessGroup
from clt.parallel import ops as dist_ops


class BatchTopK(torch.autograd.Function):
@@ -234,9 +234,7 @@ def _apply_batch_topk_helper(
) -> Dict[int, torch.Tensor]:
"""Helper to apply BatchTopK globally across concatenated layer pre-activations."""

world_size = 1
if process_group is not None and dist.is_initialized():
world_size = dist.get_world_size(process_group)
world_size = dist_ops.get_world_size(process_group)

if not preactivations_dict:
logger_helpers.warning(f"Rank {rank}: _apply_batch_topk_helper received empty preactivations_dict.")
@@ -310,9 +308,9 @@ def _apply_batch_topk_helper(
concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
)
mask.copy_(local_mask)
dist.broadcast(mask, src=0, group=process_group)
dist_ops.broadcast(mask, src=0, group=process_group)
else:
dist.broadcast(mask, src=0, group=process_group)
dist_ops.broadcast(mask, src=0, group=process_group)
else:
mask = BatchTopK._compute_mask(
concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
@@ -340,9 +338,7 @@ def _apply_token_topk_helper(
process_group: Optional[ProcessGroup],
) -> Dict[int, torch.Tensor]:
"""Helper to apply TokenTopK globally across concatenated layer pre-activations."""
world_size = 1
if process_group is not None and dist.is_initialized():
world_size = dist.get_world_size(process_group)
world_size = dist_ops.get_world_size(process_group)

if not preactivations_dict:
logger_helpers.warning(f"Rank {rank}: _apply_token_topk_helper received empty preactivations_dict.")
@@ -418,9 +414,9 @@ def _apply_token_topk_helper(
concatenated_preactivations_normalized,
)
mask.copy_(local_mask)
dist.broadcast(mask, src=0, group=process_group)
dist_ops.broadcast(mask, src=0, group=process_group)
else:
dist.broadcast(mask, src=0, group=process_group)
dist_ops.broadcast(mask, src=0, group=process_group)
else:
mask = TokenTopK._compute_mask(
concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
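Note: the hunks above replace the repeated `if process_group is not None and dist.is_initialized()` guards with calls into the new `clt.parallel.ops` module (imported as `dist_ops`). That module is not part of this diff, so the following is only a minimal sketch of what it presumably provides, inferred from the call sites (`is_dist_initialized_and_available`, `get_world_size`, `get_rank`, `broadcast`, `all_reduce`, `all_gather`, and the `SUM` alias); the real implementation may differ.

```python
# Hypothetical sketch of clt/parallel/ops.py, inferred from the call sites in this PR.
# Only the names used in the diff are shown; signatures and behavior are assumptions.
from typing import List, Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

# Re-exported reduce op so callers can write dist_ops.SUM.
SUM = dist.ReduceOp.SUM


def is_dist_initialized_and_available() -> bool:
    """True only when torch.distributed is both built and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_world_size(process_group: Optional[ProcessGroup] = None) -> int:
    """World size of the group, falling back to 1 when not distributed."""
    if process_group is None or not is_dist_initialized_and_available():
        return 1
    return dist.get_world_size(process_group)


def get_rank(process_group: Optional[ProcessGroup] = None) -> int:
    """Rank within the group, falling back to 0 when not distributed."""
    if process_group is None or not is_dist_initialized_and_available():
        return 0
    return dist.get_rank(process_group)


def broadcast(tensor: torch.Tensor, src: int, group: Optional[ProcessGroup] = None) -> None:
    """Broadcast that degrades to a no-op in single-process runs."""
    if get_world_size(group) > 1:
        dist.broadcast(tensor, src=src, group=group)


def all_reduce(tensor: torch.Tensor, op=SUM, group: Optional[ProcessGroup] = None) -> None:
    """In-place all-reduce that degrades to a no-op in single-process runs."""
    if get_world_size(group) > 1:
        dist.all_reduce(tensor, op=op, group=group)


def all_gather(tensor_list: List[torch.Tensor], tensor: torch.Tensor, group: Optional[ProcessGroup] = None) -> None:
    """All-gather that just copies the local tensor when not distributed."""
    if get_world_size(group) > 1:
        dist.all_gather(tensor_list, tensor, group=group)
    else:
        tensor_list[0].copy_(tensor)
```

With fall-backs like these, the world-size/rank bookkeeping that was previously repeated at each call site collapses to a single call, as in the `_apply_batch_topk_helper` change above.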
11 changes: 3 additions & 8 deletions clt/models/clt.py
@@ -1,7 +1,6 @@
import torch
from typing import Dict, Optional, Union, Tuple, List
import logging
import torch.distributed as dist

from clt.config import CLTConfig
from clt.models.base import BaseTranscoder
@@ -12,6 +11,7 @@
from clt.models.theta import ThetaManager

from clt.activations.registry import get_activation_fn
from clt.parallel import ops as dist_ops

from torch.distributed import ProcessGroup

@@ -34,13 +34,8 @@ def __init__(
):
super().__init__(config)
self.process_group = process_group
if process_group is None or not dist.is_initialized():
self.world_size = 1
self.rank = 0
self.process_group = None
else:
self.world_size = dist.get_world_size(process_group)
self.rank = dist.get_rank(process_group)
self.world_size = dist_ops.get_world_size(process_group)
self.rank = dist_ops.get_rank(process_group)

self.dtype = self._resolve_dtype(config.clt_dtype)
if device is not None:
14 changes: 7 additions & 7 deletions clt/models/decoder.py
@@ -2,11 +2,11 @@
import torch.nn as nn
from typing import Dict, Optional
import logging
import torch.distributed as dist
from torch.distributed import ProcessGroup

from clt.config import CLTConfig
from clt.models.parallel import RowParallelLinear
from clt.parallel import ops as dist_ops
from torch.distributed import ProcessGroup

logger = logging.getLogger(__name__)

@@ -31,12 +31,12 @@ def __init__(
self.device = device
self.dtype = dtype

if process_group is None or not dist.is_initialized():
if process_group is None or not dist_ops.is_dist_initialized_and_available():
self.world_size = 1
self.rank = 0
else:
self.world_size = dist.get_world_size(process_group)
self.rank = dist.get_rank(process_group)
self.world_size = dist_ops.get_world_size(process_group)
self.rank = dist_ops.get_rank(process_group)

self.decoders = nn.ModuleDict(
{
@@ -175,8 +175,8 @@ def get_decoder_norms(self) -> torch.Tensor:
f"Valid norms shape {valid_norms_sq.shape}, expected size {actual_local_dim}."
)

if self.process_group is not None and dist.is_initialized():
dist.all_reduce(local_norms_sq_accum, op=dist.ReduceOp.SUM, group=self.process_group)
if self.process_group is not None and dist_ops.is_dist_initialized_and_available():
dist_ops.all_reduce(local_norms_sq_accum, op=dist_ops.SUM, group=self.process_group)

full_decoder_norms[src_layer] = torch.sqrt(local_norms_sq_accum).to(self.dtype)

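The `get_decoder_norms` change keeps the same collective pattern: each rank accumulates the squared norms of its local shard, the group all-reduces them with `SUM`, and the square root is taken only after the reduction. A tiny single-process illustration of why summing per-shard squared norms recovers the full column norm (hypothetical tensors, not the repo's API):

```python
import torch

# A "full" decoder weight column, split across two hypothetical ranks.
full = torch.randn(8)
shard_a, shard_b = full[:4], full[4:]

# Each rank computes the squared norm of its shard ...
local_sq_a = shard_a.pow(2).sum()
local_sq_b = shard_b.pow(2).sum()

# ... the all-reduce sums those partial results, and sqrt is applied afterwards.
reduced = local_sq_a + local_sq_b  # what dist_ops.all_reduce(..., op=SUM) would produce
assert torch.allclose(torch.sqrt(reduced), full.norm())
```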
10 changes: 3 additions & 7 deletions clt/models/encoder.py
@@ -2,10 +2,10 @@
import torch.nn as nn
from typing import Dict, List, Tuple, Optional
import logging
import torch.distributed as dist

from clt.config import CLTConfig
from clt.models.parallel import ColumnParallelLinear
from clt.parallel import ops as dist_ops
from torch.distributed import ProcessGroup

logger = logging.getLogger(__name__)
@@ -31,12 +31,8 @@ def __init__(
self.device = device
self.dtype = dtype

if process_group is None or not dist.is_initialized():
self.world_size = 1
self.rank = 0
else:
self.world_size = dist.get_world_size(process_group)
self.rank = dist.get_rank(process_group)
self.world_size = dist_ops.get_world_size(process_group)
self.rank = dist_ops.get_rank(process_group)

self.encoders = nn.ModuleList(
[
68 changes: 39 additions & 29 deletions clt/models/parallel.py
@@ -1,12 +1,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.distributed import ProcessGroup
import math
from typing import Callable, Optional, cast, Tuple

from . import mark_replicated
from clt.parallel import ops as dist_ops


class _ParallelLinear(nn.Module):
@@ -32,14 +32,12 @@ def __init__(
super().__init__()
self.process_group = process_group

# Handle non-distributed case
if process_group is None or not dist.is_initialized():
self.world_size = 1
self.rank = 0
# Handle non-distributed case using new utility functions
self.world_size = dist_ops.get_world_size(process_group)
self.rank = dist_ops.get_rank(process_group)
# If world_size is 1, process_group should effectively be None for logic below
if self.world_size == 1:
self.process_group = None
else:
self.world_size = dist.get_world_size(process_group)
self.rank = dist.get_rank(process_group)

self.partition_dim = partition_dim
self.input_is_parallel = input_is_parallel
@@ -108,15 +106,16 @@ class _Gather(torch.autograd.Function):

@staticmethod
def forward(ctx, input_: torch.Tensor, process_group: ProcessGroup, dim: int, full_dim_size: Optional[int]):
if process_group is None or not dist.is_initialized() or dist.get_world_size(process_group) == 1:
# Use new utility functions
if not dist_ops.is_dist_initialized_and_available() or dist_ops.get_world_size(process_group) == 1:
ctx.dim = dim
ctx.local_dim = input_.size(dim)
ctx.full_dim_size = full_dim_size or input_.size(dim)
ctx.process_group = None # Mark non-distributed case
return input_

world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group)
world_size = dist_ops.get_world_size(process_group)
rank = dist_ops.get_rank(process_group)

ctx.dim = dim
ctx.local_dim = input_.size(dim)
@@ -131,8 +130,8 @@ def forward(ctx, input_: torch.Tensor, process_group: ProcessGroup, dim: int, fu
# can track the dependency (no copy!).
gathered[rank] = input_contig

# Perform the collective.
dist.all_gather(gathered, input_contig, group=process_group)
# Perform the collective using new utility function wrapper
dist_ops.all_gather(gathered, input_contig, group=process_group)

output = torch.cat(gathered, dim=dim)

@@ -150,10 +149,15 @@ def backward(ctx, *grad_outputs):
grad_output = grad_outputs[0]

# Non-distributed: gradient flows straight through.
if ctx.process_group is None or not dist.is_initialized() or dist.get_world_size(ctx.process_group) == 1:
# Use new utility functions
if (
ctx.process_group is None
or not dist_ops.is_dist_initialized_and_available()
or dist_ops.get_world_size(ctx.process_group) == 1
):
return grad_output, None, None, None

rank = dist.get_rank(ctx.process_group)
rank = dist_ops.get_rank(ctx.process_group)

# Compute start/end indices for this rank's slice along the gather dim.
local_dim_padded = ctx.local_dim # Already accounts for padding in weight shape.
@@ -179,25 +183,28 @@ class _Reduce(torch.autograd.Function):

@staticmethod
def forward(ctx, input_: torch.Tensor, process_group: Optional[ProcessGroup]):
if process_group is None or not dist.is_initialized() or dist.get_world_size(process_group) == 1:
# Use new utility functions
if not dist_ops.is_dist_initialized_and_available() or dist_ops.get_world_size(process_group) == 1:
ctx.process_group = None # Mark non-distributed case
return input_

ctx.process_group = process_group
input_contig = input_.contiguous() # Ensure contiguous before collective

# Perform the all-reduce with SUM operation.
# The operation is in-place on input_contig if it's the same object for all_reduce's output internally,
# or if all_reduce returns a new tensor, that's what we return.
# For clarity, let's assume all_reduce modifies input_contig or we assign its result.
dist.all_reduce(input_contig, op=dist.ReduceOp.SUM, group=process_group)
# Perform the all-reduce with SUM operation using new utility function wrapper.
dist_ops.all_reduce(input_contig, op=dist_ops.SUM, group=process_group)
# The tensor input_contig now holds the sum.
return input_contig

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
# Non-distributed: gradient flows straight through.
if ctx.process_group is None or not dist.is_initialized() or dist.get_world_size(ctx.process_group) == 1:
# Use new utility functions
if (
ctx.process_group is None
or not dist_ops.is_dist_initialized_and_available()
or dist_ops.get_world_size(ctx.process_group) == 1
):
# Match the number of forward inputs in return for consistency
return grad_output.contiguous() if grad_output is not None else None, None

@@ -220,10 +227,11 @@ def _reduce(input_, process_group):
and broken optimisation. The caller can always divide afterwards if an average is
truly desired, but for the core TP math we need the raw sum.
"""
if process_group is None or not dist.is_initialized():
# Use new utility functions
if not dist_ops.is_dist_initialized_and_available():
return input_ # No-op if not distributed

world_size = dist.get_world_size(process_group)
world_size = dist_ops.get_world_size(process_group)
if world_size == 1:
return input_

@@ -239,14 +247,15 @@ def _split(input_, process_group, dim=-1):
Assumes uniform padding, so each rank gets ceil(full_dim / world_size).
Handles truncation for ranks that would exceed the original full dimension.
"""
if process_group is None or not dist.is_initialized():
# Use new utility functions
if not dist_ops.is_dist_initialized_and_available():
return input_ # No-op if not distributed

world_size = dist.get_world_size(process_group)
world_size = dist_ops.get_world_size(process_group)
if world_size == 1:
return input_

rank = dist.get_rank(process_group)
rank = dist_ops.get_rank(process_group)
full_dim_size = input_.size(dim)

# Calculate the size of each slice (using ceil for uniform distribution)
@@ -402,10 +411,11 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:

# Add bias *after* reduction
if self.bias and self.bias_param is not None:
# Cast bias_param for type checker; runtime None already guarded.
# The runtime check `self.bias_param is not None` is the primary guard.
# Casting `self.bias_param` to `torch.Tensor` helps the type checker.
reduced_output = reduced_output + cast(torch.Tensor, self.bias_param)

return cast(torch.Tensor, reduced_output) # Cast to ensure Tensor type
return cast(torch.Tensor, reduced_output)


# --------------------------- Public helper --------------------------- #
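The `_split` hunk is cut off here, but its docstring fixes the partitioning rule: each rank owns a `ceil(full_dim / world_size)`-sized slice, and ranks whose slice would run past the original dimension are truncated. A standalone sketch of that index arithmetic under those stated assumptions (`split_bounds` is a hypothetical helper, not a function in the file):

```python
import math
from typing import Tuple


def split_bounds(full_dim: int, rank: int, world_size: int) -> Tuple[int, int]:
    """Start/end of the slice a rank owns under ceil-based partitioning."""
    local = math.ceil(full_dim / world_size)   # every rank gets the same nominal width
    start = rank * local
    end = min(start + local, full_dim)          # truncate ranks past the real dimension
    return start, max(start, end)               # trailing ranks may end up empty


# Example: 10 features over 3 ranks -> slices of width 4, 4, 2.
print([split_bounds(10, r, 3) for r in range(3)])  # [(0, 4), (4, 8), (8, 10)]
```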