
Conversation

@rakkit (Contributor) commented Sep 1, 2025

We can set DEBUG_FORCE_LOAD_BALANCED=1 to force each expert to receive the same number of tokens.

Reproduce: DEBUG_FORCE_LOAD_BALANCED=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --compile.enable

Here is a test on 8 layers with 8 active and 64 total experts. The green curve is the vanilla run and the purple curve is the run with forced load balancing:
[two training-curve screenshots: vanilla (green) vs. forced load balance (purple)]

@tianyu-l tianyu-l linked an issue Sep 6, 2025 that may be closed by this pull request
Returns LongTensor of shape (n_tokens, top_k)."""
i = torch.arange(n_tokens, device=device)[:, None] # [N,1]
k = torch.arange(top_k, device=device)[None, :] # [1,K]
return ((i * top_k + k) % num_experts).long() # [N,K]
@tianyu-l (Contributor) commented:

hmm it's not clear to me why this gives balanced routing. E.g. with N=10, k=3 I get i * top_k + k being

[[ 0,  1,  2],
[ 0,  2,  4],
[ 0,  3,  6],
[ 0,  4,  8],
[ 0,  5, 10],
[ 0,  6, 12],
[ 0,  7, 14],
[ 0,  8, 16],
[ 0,  9, 18],
[ 0, 10, 20]]

so Expert 0 will be used by every single token.
Even if you trim the first column, it's not clear why each expert gets the same number of tokens -- can you prove it?

A naive round-robin I had in mind is
token 0 -> e0, e1, e2
token 1 -> e3, e4, e5
token 2 -> e6, e7, e8
...


@garrett361 (Contributor) commented Sep 8, 2025:

@tianyu-l I'm a little confused by your example; I think this is doing what you want.

i * top_k + k is the flattened linear index into a row-major (n_tokens, top_k)-shaped matrix, i.e. it is just torch.arange(n_tokens * top_k).reshape(n_tokens, top_k), and taking it mod num_experts cycles through the experts.

import torch

device = "cuda"
top_k, num_tokens, num_experts = 3, 256, 128
i = torch.arange(num_tokens, device=device)[:, None]  # [N, 1]
k = torch.arange(top_k, device=device)[None, :]       # [1, K]

exp_idxs = ((i * top_k + k) % num_experts).long()

# Equivalent: flattened row-major linear indices, reshaped, then wrapped
# around the expert count.
exp_idxs_alt = (
    torch.arange(num_tokens * top_k, device=device).reshape(num_tokens, top_k)
    % num_experts
)
torch.testing.assert_close(exp_idxs, exp_idxs_alt)
print(f"{exp_idxs=}")

Output:

exp_idxs=tensor([[ 0,  1,  2],
                 [ 3,  4,  5],
                 [ 6,  7,  8],
                 [ 9, 10, 11],
                 [12, 13, 14],
                 [15, 16, 17],
                 [18, 19, 20],
                 [21, 22, 23],
                 [24, 25, 26],
                 [27, 28, 29],
                 [30, 31, 32],

[... snip ...]

Writing this with arange does seem simpler, though.

EDIT: fixed the missing % on the arange in exp_idxs_alt, and made top_k * num_tokens > num_experts so that the test is less trivial.
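
To make the balance explicit, one can also count assignments per expert. This is a small sketch continuing the snippet above: with num_tokens * top_k = 768 assignments spread over num_experts = 128, each expert should be hit exactly 6 times, which bincount confirms.

# Continuing the snippet above: count (token, slot) assignments per expert.
counts = torch.bincount(exp_idxs.flatten(), minlength=num_experts)
# 256 tokens * 3 slots = 768 assignments over 128 experts -> exactly 6 each;
# the counts are uniform whenever num_tokens * top_k is a multiple of num_experts.
assert (counts == num_tokens * top_k // num_experts).all()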

@tianyu-l (Contributor) replied:

@garrett361 oh my bad, I was thinking & trying i * k + k instead of i * top_k + k. Sorry for the confusion lol.

Comment on lines +250 to +253
selected_experts_indices = self.uniform_indices(
    x.size(0), self.top_k, self.num_experts, x.device
)
top_scores = scores.gather(dim=1, index=selected_experts_indices)
A contributor commented:

Let's merge these into one instance method:
top_scores, selected_experts_indices = _debug_force_load_balanced_routing(scores)
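
A minimal sketch of what that merged method could look like (the exact signature is an assumption; scores.size(0) and scores.device stand in for x.size(0) and x.device, since scores has shape (n_tokens, num_experts)):

def _debug_force_load_balanced_routing(self, scores):
    # Hedged sketch: folds the two lines above into one instance method.
    # scores: (n_tokens, num_experts) router scores.
    selected_experts_indices = self.uniform_indices(
        scores.size(0), self.top_k, self.num_experts, scores.device
    )
    top_scores = scores.gather(dim=1, index=selected_experts_indices)
    return top_scores, selected_experts_indices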

@@ -188,6 +189,19 @@ def __init__(
self.score_func = score_func
self.route_norm = route_norm
self.route_scale = route_scale
self.debug_force_load_balanced = bool(
A contributor commented:

Let's be consistent with the rest of the torchtitan config:

  1. make it part of MoEArgs
  2. pass it into the Router constructor https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/moe.py#L325
  3. make it part of job_config.training https://github.com/pytorch/torchtitan/blob/main/torchtitan/config/job_config.py#L172, e.g. call it _debug_moe_force_load_balance
  4. update the model args from the job config in https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/deepseek_v3/model/args.py#L89, for both DSV3 and Llama 4

3 and 4 are optional to me, but we can add them if you think it's useful to conveniently enable this feature.

If possible we should add an underscore _ before every name we come up with, as this is only for debugging and not meant to be exposed as API.
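
A rough sketch of suggestions 1 and 2 (the class names, defaults, and surrounding structure here are placeholders, not the actual torchtitan definitions):

from dataclasses import dataclass

import torch.nn as nn


@dataclass
class MoEArgs:
    num_experts: int = 64
    top_k: int = 8
    # Underscore prefix: debug-only knob, not meant to be exposed as API.
    _debug_force_load_balance: bool = False


class Router(nn.Module):  # placeholder for the router in torchtitan/models/moe.py
    def __init__(self, moe_args: MoEArgs):
        super().__init__()
        self.top_k = moe_args.top_k
        self.num_experts = moe_args.num_experts
        self._debug_force_load_balance = moe_args._debug_force_load_balance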

Successfully merging this pull request may close: create fake balanced routing in MoE / EP for infra development