Fake balanced routing in MoE #1670
base: main
Conversation
Returns LongTensor of shape (n_tokens, top_k)."""
i = torch.arange(n_tokens, device=device)[:, None]  # [N,1]
k = torch.arange(top_k, device=device)[None, :]  # [1,K]
return ((i * top_k + k) % num_experts).long()  # [N,K]
hmm it's not clear to me why this gives balanced routing. E.g. with N=10, k=3 I get `i * top_k + k` being
[[ 0, 1, 2],
[ 0, 2, 4],
[ 0, 3, 6],
[ 0, 4, 8],
[ 0, 5, 10],
[ 0, 6, 12],
[ 0, 7, 14],
[ 0, 8, 16],
[ 0, 9, 18],
[ 0, 10, 20]]
so Expert 0 will be used by every single token.
Even if you trim the first column, it's not clear why each expert gets the same number of tokens -- can you prove it?
A naive round-robin I had in mind is (see the sketch after this list):
token 0 -> e0, e1, e2
token 1 -> e3, e4, e5
token 2 -> e6, e7, e8
...
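(For illustration only — not code from the PR — here is that naive round-robin written as an explicit loop; as the next reply explains, it produces exactly `(i * top_k + k) % num_experts`:)

```python
import torch

def round_robin_indices(n_tokens: int, top_k: int, num_experts: int) -> torch.Tensor:
    # token 0 -> e0..e{top_k-1}, token 1 -> e{top_k}..., wrapping mod num_experts
    rows, nxt = [], 0
    for _ in range(n_tokens):
        rows.append([(nxt + j) % num_experts for j in range(top_k)])
        nxt = (nxt + top_k) % num_experts
    return torch.tensor(rows, dtype=torch.long)

print(round_robin_indices(4, 3, 8))
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 0],
#         [1, 2, 3]])
```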
@tianyu-l I'm a little confused about your example, I think this is doing what you want.
`i * top_k + k` is like the flattened linear index of a row-major (n_tokens, top_k)-shaped matrix. I.e. it is just `torch.arange(n_tokens * top_k).reshape(n_tokens, top_k)`, and then modding by `num_experts` loops over experts.
import torch
device = "cuda"
top_k, num_tokens, num_experts = 3, 256, 128
i = torch.arange(num_tokens, device=device)[:, None]
k = torch.arange(top_k, device=device)[None, :]
exp_idxs = ((i * top_k + k) % num_experts).long()
exp_idxs_alt = torch.arange(num_tokens * top_k, device=device).reshape(
    num_tokens, top_k
) % num_experts
torch.testing.assert_close(exp_idxs, exp_idxs_alt)
print(f"{exp_idxs=}")
Output:
exp_idxs=tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17],
[18, 19, 20],
[21, 22, 23],
[24, 25, 26],
[27, 28, 29],
[30, 31, 32],
[... snip ...]
Writing this with `arange` does seem simpler, though.
EDIT: fixed a missing `%` in the `arange` version (`exp_idxs_alt`) and made `top_k * num_tokens > num_experts` so that the test is less trivial.
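(As an extra sanity check on the balance question — a sketch, not part of the PR: whenever `n_tokens * top_k` is a multiple of `num_experts`, the flattened arange covers each residue class mod `num_experts` equally often, so every expert receives exactly `n_tokens * top_k / num_experts` assignments.)

```python
import torch

n_tokens, top_k, num_experts = 256, 3, 128  # 256 * 3 = 768 = 6 * 128
i = torch.arange(n_tokens)[:, None]
k = torch.arange(top_k)[None, :]
exp_idxs = (i * top_k + k) % num_experts
counts = torch.bincount(exp_idxs.flatten(), minlength=num_experts)
print(counts.unique())  # tensor([6]): every expert gets exactly 6 slots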
@garrett361 oh my bad, I was thinking & trying `i * k + k` instead of `i * top_k + k`. Sorry for the confusion lol.
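(Indeed, `i * k + k` reproduces the matrix from the earlier comment:)

```python
import torch

i = torch.arange(10)[:, None]
k = torch.arange(3)[None, :]
print(i * k + k)  # row i is [0, i + 1, 2 * (i + 1)] -- the matrix above
```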
selected_experts_indices = self.uniform_indices(
    x.size(0), self.top_k, self.num_experts, x.device
)
top_scores = scores.gather(dim=1, index=selected_experts_indices)
Let's merge these into one instance method, e.g.
top_scores, selected_experts_indices = _debug_force_load_balanced_routing(scores)
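A rough sketch of what that merged method could look like (the name follows the suggestion above; `self.uniform_indices`, `self.top_k`, and `self.num_experts` are taken from the diff, everything else is an assumption):

```python
def _debug_force_load_balanced_routing(
    self, scores: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # Round-robin expert assignment: ignore the learned scores when picking
    # experts, but still gather the corresponding scores as routing weights.
    selected_experts_indices = self.uniform_indices(
        scores.size(0), self.top_k, self.num_experts, scores.device
    )
    top_scores = scores.gather(dim=1, index=selected_experts_indices)
    return top_scores, selected_experts_indices
```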
@@ -188,6 +189,19 @@ def __init__(
    self.score_func = score_func
    self.route_norm = route_norm
    self.route_scale = route_scale
    self.debug_force_load_balanced = bool(
Let's be consistent with the rest of the torchtitan config:
1. make it part of MoEArgs
2. pass it into the Router constructor https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/moe.py#L325
3. make it part of job_config.training https://github.com/pytorch/torchtitan/blob/main/torchtitan/config/job_config.py#L172, e.g. call it `_debug_moe_force_load_balance`
4. update the model args from the job config in https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/deepseek_v3/model/args.py#L89 for both DSV3 and Llama 4

3 and 4 are optional to me, but we can add them if you think it's useful to be able to conveniently enable this feature. (A sketch of 1 and 2 follows below.)
If possible we should add an underscore `_` before every name we come up with, as it's only for debugging and not meant to be exposed as API.
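A minimal sketch of items 1 and 2, assuming MoEArgs is a dataclass; the field name comes from the suggestion above, and the constructor wiring is hypothetical:

```python
from dataclasses import dataclass

@dataclass
class MoEArgs:
    # ... existing fields ...
    _debug_force_load_balance: bool = False  # debug-only; underscore = not public API

# Hypothetical wiring inside MoE.__init__ when building the router:
# self.router = TokenChoiceTopKRouter(
#     ...,
#     debug_force_load_balance=moe_args._debug_force_load_balance,
# )
```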
we can set DEBUG_FORCE_LOAD_BALANCED=1 to force each expert to get the same number of tokens.
Reproduce:
DEBUG_FORCE_LOAD_BALANCED=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --compile.enable
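For reference, the truncated `self.debug_force_load_balanced = bool(...)` line in the diff presumably reads this env var; a minimal sketch of such a gate (hypothetical helper, since the PR's exact expression is cut off):

```python
import os

def _env_flag(name: str) -> bool:
    # Treat "1"/"true"/"yes" as on; the PR's actual parsing is not shown.
    return os.environ.get(name, "0").lower() in ("1", "true", "yes")

# e.g., inside the router's __init__:
# self.debug_force_load_balanced = _env_flag("DEBUG_FORCE_LOAD_BALANCED")
```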
Here is a test with 8 layers, 8 activated and 64 total experts. The green curve is the vanilla run and the purple one is with forced load balancing.
