Skip to content

Commit 8db2f16

Browse files
committed
misc
1 parent 0ac427f commit 8db2f16

File tree

2 files changed

+10
-23
lines changed

2 files changed

+10
-23
lines changed

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,10 @@ def parallelize_deepseekv3(
7070
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
7171
apply_moe_ep_tp(
7272
model,
73-
tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None,
74-
ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None,
75-
etp_mesh=parallel_dims.get_mesh("etp")
76-
if parallel_dims.etp_enabled
77-
else None,
78-
ep_etp_mesh=(
79-
parallel_dims.get_mesh("ep_etp")
80-
if parallel_dims.tp_enabled
81-
and parallel_dims.ep_enabled
82-
and parallel_dims.etp_enabled
83-
else None
84-
),
73+
tp_mesh=parallel_dims.get_mesh("tp"),
74+
ep_mesh=parallel_dims.get_mesh("ep"),
75+
etp_mesh=parallel_dims.get_mesh("etp"),
76+
ep_etp_mesh=parallel_dims.get_mesh(["ep", "etp"]),
8577
)
8678

8779
if job_config.activation_checkpoint.mode != "none":
@@ -139,13 +131,13 @@ def parallelize_deepseekv3(
139131
assert edp_mesh is not None
140132
assert hasattr(transformer_block, "moe")
141133
if (
142-
dp_mod_ep_mesh.size() * parallel_dims.ep
134+
edp_mesh.size() * parallel_dims.ep
143135
> transformer_block.moe.experts.num_experts
144136
):
145137
experts_shard_dim = 1
146138

147139
# when EP is enable, the routed experts' gradient reduction is done over
148-
# dp_mod_ep_mesh instead of whole dp_mesh.
140+
# edp_mesh instead of whole dp_mesh.
149141
# we add a `fsdp_gradient_divide_factor` to scale gradient over dp_mesh
150142
# to be consistent with data.
151143
# TODO (ruisizhang123): update the logic following the link below instead

torchtitan/models/qwen3/infra/parallelize.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,10 @@ def parallelize_qwen3(
9494
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
9595
apply_moe_ep_tp(
9696
model,
97-
tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None,
98-
ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None,
99-
ep_etp_mesh=(
100-
parallel_dims.get_mesh("ep_etp")
101-
if parallel_dims.tp_enabled
102-
and parallel_dims.ep_enabled
103-
and parallel_dims.etp_enabled
104-
else None
105-
),
97+
tp_mesh=parallel_dims.get_mesh("tp"),
98+
ep_mesh=parallel_dims.get_mesh("ep"),
99+
etp_mesh=parallel_dims.get_mesh("etp"),
100+
ep_etp_mesh=parallel_dims.get_mesh(["ep", "etp"]),
106101
)
107102

108103
if job_config.activation_checkpoint.mode != "none":

0 commit comments

Comments
 (0)