torchtitan/experiments/compiler_toolkit — 3 files changed, +24 −8 lines

**README**

````diff
@@ -12,6 +12,11 @@ Joint Graph based Training Prototype:
 
 ## DeepSeek v3
 
+**SimpleFSDP + EP**
+```shell
+NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none
+```
+
 **SimpleFSDP + TP + EP**
 ```shell
 NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none
@@ -24,6 +29,11 @@ NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.tom
 
 ## llama3
 
+**SimpleFSDP**
+```shell
+NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=8
+```
+
 **SimpleFSDP + TP**
 ```shell
 NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4
````
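For reference, the degree flags in these commands multiply out to `NGPU` once EP is set aside (in the first DeepSeek run, `dp_shard=4` on 4 GPUs with `ep=2` suggests EP ranks are borrowed from the data-parallel dimension rather than adding a mesh axis). A minimal sketch of the resulting mesh for the `NGPU=4`, `dp_shard=2`, `tp=2` case — the mesh names and construction below are illustrative assumptions, not torchtitan's actual setup code:

```python
# Hypothetical sketch: a dp_shard x tp device mesh matching the flags above.
# Run under torchrun, e.g.: torchrun --nproc-per-node 4 mesh_sketch.py
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
# 2 * 2 == NGPU=4; dim names mirror the --parallelism flags
world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp_shard", "tp"))
print(world_mesh["tp"])  # 1-D submesh over which inputs get replicated for TP
dist.destroy_process_group()
```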
**`parallelize_inputs` helper (Python source)**

```diff
@@ -25,12 +25,17 @@ def disable_compile(job_config: JobConfig):
 
 
 def parallelize_inputs(world_mesh, args, kwargs):
-    def to_dtensor(tensor):
-        if isinstance(tensor, torch.Tensor):
-            return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
-        return tensor
+    if "tp" in world_mesh.mesh_dim_names:
 
-    dt_args = tree_map(to_dtensor, args)
+        def to_dtensor(tensor):
+            if isinstance(tensor, torch.Tensor):
+                return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
+            return tensor
+
+        dt_args = tree_map(to_dtensor, args)
+    else:
+        # TODO: When there is no TP (SimpleFSDP only), it currently only supports plain tensor inputs
+        dt_args = args
 
     # TODO: When using flex_attention, BlockMask would show up in kwargs,
     # and it's unclear how to convert it to DTensor. If I use to_dtensor,
```
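Assembled, the patched helper branches on whether the mesh has a TP dimension: with TP, every tensor input is wrapped as a DTensor replicated over the `"tp"` submesh; without TP (SimpleFSDP only), inputs pass through as plain tensors. A self-contained sketch of that logic — the trailing `return` is an assumption, since the diff cuts off before showing how `dt_args` and `kwargs` leave the function:

```python
import torch
from torch.distributed.tensor import DTensor, Replicate
from torch.utils._pytree import tree_map


def parallelize_inputs(world_mesh, args, kwargs):
    if "tp" in world_mesh.mesh_dim_names:
        # TP present: replicate every tensor input over the "tp" submesh,
        # leaving non-tensor pytree leaves untouched.
        def to_dtensor(tensor):
            if isinstance(tensor, torch.Tensor):
                return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
            return tensor

        dt_args = tree_map(to_dtensor, args)
    else:
        # SimpleFSDP-only: plain tensor inputs are passed through (per the TODO)
        dt_args = args
    return dt_args, kwargs  # assumed return shape; not shown in the diff
```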
**`joint_graph_builder` (Python source)**

```diff
@@ -15,7 +15,6 @@
     JointWithDescriptors,
 )
 from torch._guards import tracing, TracingContext
-from torch.distributed.tensor import DTensor
 from torchtitan.distributed import ParallelDims
 from torchtitan.tools.logging import logger
 
@@ -158,8 +157,10 @@ def joint_graph_builder(
         joint_custom_pass: Optional custom pass to run on the joint graph
     """
     assert isinstance(model_args, tuple)
-    for arg in model_args:
-        assert isinstance(arg, DTensor)
+
+    # TODO: Enable this when we have full-DTensorize inputs support of SimpleFSDP
+    # for arg in model_args:
+    #     assert isinstance(arg, DTensor)
 
     # get joint graph
     (
```
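With the assertion disabled, `joint_graph_builder` now tolerates plain-tensor inputs coming from the SimpleFSDP-only path. A hedged alternative, not part of this PR, would be to log offending inputs instead of asserting (assuming the module-level `logger` imported above is in scope):

```python
# Hypothetical stand-in for the disabled assert: report which positional
# inputs are not DTensors rather than failing hard during tracing.
from torch.distributed.tensor import DTensor

plain = [i for i, arg in enumerate(model_args) if not isinstance(arg, DTensor)]
if plain:
    logger.debug("model_args at positions %s are plain tensors (no TP mesh)", plain)
```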