
Commit 2380ed3

compiler toolkit without TP
1 parent ad78ed8 commit 2380ed3

3 files changed: +24 −8 lines changed

torchtitan/experiments/compiler_toolkit/README.md

Lines changed: 10 additions & 0 deletions
````diff
@@ -12,6 +12,11 @@ Joint Graph based Training Prototype:
 
 ## DeepSeek v3
 
+**SimpleFSDP + EP**
+```shell
+NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none
+```
+
 **SimpleFSDP + TP + EP**
 ```shell
 NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none
@@ -24,6 +29,11 @@ NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.tom
 
 ## llama3
 
+**SimpleFSDP**
+```shell
+NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=8
+```
+
 **SimpleFSDP + TP**
 ```shell
 NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4
````

torchtitan/experiments/compiler_toolkit/common_utils.py

Lines changed: 10 additions & 5 deletions
```diff
@@ -25,12 +25,17 @@ def disable_compile(job_config: JobConfig):
 
 
 def parallelize_inputs(world_mesh, args, kwargs):
-    def to_dtensor(tensor):
-        if isinstance(tensor, torch.Tensor):
-            return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
-        return tensor
+    if "tp" in world_mesh.mesh_dim_names:
 
-    dt_args = tree_map(to_dtensor, args)
+        def to_dtensor(tensor):
+            if isinstance(tensor, torch.Tensor):
+                return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
+            return tensor
+
+        dt_args = tree_map(to_dtensor, args)
+    else:
+        # TODO: When there is no TP (SimpleFSDP only), it currently only supports plain tensor inputs
+        dt_args = args
 
     # TODO: When using flex_attention, BlockMask would show up in kwargs,
     # and it's unclear how to convert it to DTensor. If I use to_dtensor,
```
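
For context, a minimal sketch of how `parallelize_inputs` reads after this change. Only the TP/non-TP branching is taken from the diff; the trailing kwargs handling and the return statement are assumptions, since the real function continues past the hunk shown above (see the flex_attention/BlockMask TODO).

```python
# Sketch of parallelize_inputs after this commit. The kwargs pass-through and the
# return statement are assumptions; the original file handles flex_attention
# BlockMask kwargs separately, outside the hunk shown above.
import torch
from torch.distributed.tensor import DTensor, Replicate
from torch.utils._pytree import tree_map


def parallelize_inputs(world_mesh, args, kwargs):
    if "tp" in world_mesh.mesh_dim_names:
        # With a TP mesh dimension, wrap each plain tensor input as a DTensor
        # replicated over the "tp" sub-mesh.
        def to_dtensor(tensor):
            if isinstance(tensor, torch.Tensor):
                return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
            return tensor

        dt_args = tree_map(to_dtensor, args)
    else:
        # TODO: When there is no TP (SimpleFSDP only), only plain tensor inputs
        # are currently supported, so args pass through unchanged.
        dt_args = args

    # Assumption: kwargs are returned as-is here.
    return dt_args, kwargs
```

The branch keeps the existing DTensor wrapping intact whenever a `tp` mesh dimension exists and simply passes plain tensors through otherwise, which is what makes the SimpleFSDP-only commands added to the README runnable.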

torchtitan/experiments/compiler_toolkit/graph_utils.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -15,7 +15,6 @@
     JointWithDescriptors,
 )
 from torch._guards import tracing, TracingContext
-from torch.distributed.tensor import DTensor
 from torchtitan.distributed import ParallelDims
 from torchtitan.tools.logging import logger
 
@@ -158,8 +157,10 @@ def joint_graph_builder(
         joint_custom_pass: Optional custom pass to run on the joint graph
     """
     assert isinstance(model_args, tuple)
-    for arg in model_args:
-        assert isinstance(arg, DTensor)
+
+    # TODO: Enable this when we have full-DTensorize inputs support of SimpleFSDP
+    # for arg in model_args:
+    #     assert isinstance(arg, DTensor)
 
     # get joint graph
     (
```
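
With the assertion commented out, `joint_graph_builder` no longer requires DTensor model args, so the SimpleFSDP-only path (plain tensors) can reach graph capture. A hypothetical, stricter alternative, not part of this commit, would gate the old check on whether TP is enabled instead of dropping it entirely:

```python
# Hypothetical variant of the removed check -- illustrative only, not in the commit.
# It keeps the DTensor assertion when the caller dtensorized the inputs (TP enabled)
# and expects plain tensors otherwise (SimpleFSDP only).
import torch
from torch.distributed.tensor import DTensor


def check_model_args(model_args: tuple, tp_enabled: bool) -> None:
    assert isinstance(model_args, tuple)
    if tp_enabled:
        # With TP, inputs should already be replicated DTensors (see parallelize_inputs).
        for arg in model_args:
            assert isinstance(arg, DTensor), f"expected DTensor, got {type(arg)}"
    else:
        # SimpleFSDP-only: plain tensors are the currently supported input type.
        for arg in model_args:
            assert isinstance(arg, torch.Tensor), f"expected Tensor, got {type(arg)}"
```

The commit instead disables the check outright and leaves a TODO to restore it once SimpleFSDP inputs are fully DTensorized.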
