46 changes: 27 additions & 19 deletions torchtitan/distributed/expert_parallel.py

@@ -67,10 +67,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 class ExpertParallel(ParallelStyle):
     def __init__(self):
         super().__init__()
-        self.input_splits = None
-        self.output_splits = None
-        self.input_shape = None
-        self.permuted_indices = None

     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
@@ -103,14 +99,14 @@ def _token_dispatch(self, mod, inputs, device_mesh):
             .sum(dim=1)
             .to(torch.device("cpu"), non_blocking=False)
         )
-        self.input_splits = input_splits.tolist()
-        self.output_splits = output_splits.tolist()
+        input_splits = input_splits.tolist()
+        output_splits = output_splits.tolist()

         # perform all-to-all
         routed_input = all_to_all_single_autograd(
             routed_input,
-            self.output_splits,
-            self.input_splits,
+            output_splits,
+            input_splits,
             device_mesh.get_group(),
         )

@@ -127,15 +123,22 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # of GroupedExperts, as it does not need padding.

         (
-            self.input_shape,
+            input_shape,
             routed_input,
-            self.permuted_indices,
+            permuted_indices,
             num_tokens_per_expert_group,
         ) = _permute(
             routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
         )

-        return routed_input, num_tokens_per_expert_group
+        return (
+            routed_input,
+            num_tokens_per_expert_group,
+            input_shape,
+            permuted_indices,
+            input_splits,
+            output_splits,
+        )

     @staticmethod
     def _partition_fn(name, mod, device_mesh):
@@ -145,15 +148,20 @@ def _partition_fn(name, mod, device_mesh):
             mod.register_parameter(name, dist_param)

     # performing all-to-all combine on the output
-    def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = _unpermute(
-            routed_output, self.input_shape, self.permuted_indices
-        )
+    def _token_combine(self, mod, mod_outputs, device_mesh):
+        (
+            routed_output,
+            input_shape,
+            permuted_indices,
+            input_splits,
+            output_splits,
+        ) = mod_outputs
+        routed_output = _unpermute(routed_output, input_shape, permuted_indices)

         routed_output = all_to_all_single_autograd(
             routed_output,
-            self.input_splits,
-            self.output_splits,
+            input_splits,
+            output_splits,
             device_mesh.get_group(),
         )
         return routed_output
@@ -204,9 +212,9 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):
             nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(1)])),
         )  # Column-wise sharding

-    def _token_combine(self, mod, routed_output, device_mesh):
+    def _token_combine(self, mod, mod_outputs, device_mesh):
         # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_combine(mod, routed_output, device_mesh["ep"])
+        return super()._token_combine(mod, mod_outputs, device_mesh["ep"])

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         return distribute_module(
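
For readers of the diff above: _token_dispatch and _token_combine act as the input and output hooks of the wrapped experts module, wired up by _apply via distribute_module (whose body the diff view truncates). The tuple returned by _token_dispatch becomes the positional inputs of the experts' forward, and whatever the experts return is handed to _token_combine, which is how input_shape, permuted_indices, input_splits, and output_splits now travel from dispatch to combine instead of living on self. A minimal sketch of that wiring, with the hook-registration details assumed rather than copied from the truncated _apply body:

# Sketch only -- the assumed shape of the hook wiring, not the actual body of
# ExpertParallel._apply (which the diff view truncates above).
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_module


def apply_expert_parallel(ep, experts: nn.Module, mesh: DeviceMesh) -> nn.Module:
    # ep is an ExpertParallel instance; experts is the wrapped experts module.
    return distribute_module(
        experts,
        mesh,
        partition_fn=ep._partition_fn,  # shards the expert weights across EP ranks
        # Pre-forward hook: all-to-all dispatch. Its return value replaces the
        # module inputs, so the 6-tuple (routed_input,
        # num_tokens_per_expert_group, input_shape, permuted_indices,
        # input_splits, output_splits) becomes the arguments of experts.forward.
        input_fn=ep._token_dispatch,
        # Post-forward hook: receives the module outputs, unpacks the same
        # metadata, un-permutes, and runs the reverse all-to-all.
        output_fn=ep._token_combine,
    )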

@@ -39,7 +39,7 @@ local_batch_size = 4
 seq_len = 4096
 max_norm = 1.0 # grad norm clipping
 steps = 1000
-dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
+dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

 [parallelism]
 data_parallel_replicate_degree = 1

Contributor comment on the dataset change: minor: toml file config changed.

@@ -65,7 +65,7 @@ mode = "selective" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

 [compile]
-enable=true
+enable = true
 components = ["loss"] # ["model", "loss"]

 [quantize.linear.float8]

10 changes: 8 additions & 2 deletions torchtitan/models/moe/moe.py

@@ -143,6 +143,10 @@ def forward(
         self,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor,
+        input_shape,
+        permuted_indices,
+        input_splits,
+        output_splits,
     ) -> torch.Tensor:
         if isinstance(self.w1, DTensor):
             # Convert parameters from DTensors to plain Tensors, to work with

Contributor comment on lines +146 to +149: These shouldn't be exposed to single-device model code. Plus, I don't think it will work if EP is not used. If it's getting too hard, maybe we should use local_map / to_local to re-implement MoE.
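
A rough illustration of the local_map direction suggested in the comment above (not part of this PR): the expert computation keeps its single-device signature, and the DTensor handling lives in a wrapper. The placements below, and the assumption that the expert weights and routed tokens arrive as DTensors sharded along dim 0 of the EP mesh, are illustrative guesses; the exact local_map signature should be checked against the installed PyTorch version.

# Sketch only: one way the reviewer's local_map / to_local suggestion could
# look. Placements are assumptions, not the PR's actual sharding.
from torch.distributed.tensor import Shard
from torch.distributed.tensor.experimental import local_map

from torchtitan.models.moe.moe import _run_experts_grouped_mm  # existing single-device helper

# _run_experts_grouped_mm sees plain tensors and knows nothing about EP
# splits or permutation.
run_experts = local_map(
    _run_experts_grouped_mm,
    out_placements=[Shard(0)],  # routed output stays sharded over the EP mesh
    in_placements=(
        [Shard(0)],  # w1: experts sharded across EP ranks (assumed)
        [Shard(0)],  # w2
        [Shard(0)],  # w3
        [Shard(0)],  # x: tokens already dispatched to this rank (assumed)
        None,        # num_tokens_per_expert: plain local tensor, passed through
    ),
)

# Callers would then pass DTensor parameters/activations directly:
#   out = run_experts(w1, w2, w3, x, num_tokens_per_expert)
# local_map unwraps DTensor arguments to their local shards on the way in and
# wraps the result back into a DTensor, so input_shape / permuted_indices /
# input_splits / output_splits never need to appear in forward's signature.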

@@ -166,9 +170,11 @@ def forward(
                 run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm)
             else:
                 run_experts_fn = _run_experts_grouped_mm
-            return run_experts_fn(w1, w2, w3, x, num_tokens_per_expert)
+            out = run_experts_fn(w1, w2, w3, x, num_tokens_per_expert)
         else:
-            return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
+            out = _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
+
+        return (out, input_shape, permuted_indices, input_splits, output_splits)

     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)