
Commit 8228c08

Remove the unused compiled_autograd option (#1939)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):

* #1857
* __->__ #1939

TorchTitan doesn't need compiled_autograd, which is meant to support compiled DDP; TorchTitan will instead adopt fully_shard-based replicate. Let's remove the option.
1 parent 7f126cb commit 8228c08
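For context, here is a minimal sketch (not torchtitan code) of what the removed option used to toggle, reconstructed from the lines deleted further down: compiled autograd only matters on the compiled-DDP ("python reducer") path, which torchtitan is dropping in favor of fully_shard-based replicate. The helper names below are illustrative.

import contextlib

import torch


def configure_ddp_compile(enable_compile: bool, enable_compiled_autograd: bool) -> None:
    # Mirrors the branch removed from apply_ddp: compiled autograd pairs with the
    # "python reducer" DDP mode; otherwise Dynamo's DDP optimizer is used.
    if enable_compile:
        if enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward"
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"


def backward_context(enable_compiled_autograd: bool):
    # Mirrors the branch removed from get_train_context: run the backward pass under
    # compiled autograd only when the flag is set, otherwise do nothing.
    if enable_compiled_autograd:
        return torch._dynamo.utils.maybe_enable_compiled_autograd(True)
    return contextlib.nullcontext()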

File tree: 11 files changed (+7, −38 lines)


scripts/estimate/estimation.py

Lines changed: 2 additions & 6 deletions
@@ -33,10 +33,9 @@ def estimate_memory(job_config: JobConfig):
     # Get the world size
     world_size = int(os.environ["WORLD_SIZE"])

-    if job_config.compile.enable or job_config.parallelism.enable_compiled_autograd:
+    if job_config.compile.enable:
         logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.compile.enable = False
-        job_config.parallelism.enable_compiled_autograd = False

     # init fake pg
     store = FakeStore()
@@ -80,10 +79,7 @@ def estimate_memory(job_config: JobConfig):
     loss_parallel_enabled = (
         parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
     )
-    train_context = dist_utils.get_train_context(
-        loss_parallel_enabled,
-        job_config.parallelism.enable_compiled_autograd,
-    )
+    train_context = dist_utils.get_train_context(loss_parallel_enabled)

     # build model (using meta init)
     model_args = train_spec.model_args[job_config.model.flavor]

torchtitan/config/job_config.py

Lines changed: 0 additions & 3 deletions
@@ -286,9 +286,6 @@ class Parallelism:
     1 means disabled.
     """

-    enable_compiled_autograd: bool = False
-    """Enable CompiledAutograd to compile the backward."""
-
     data_parallel_shard_degree: int = -1
     """
     The `data_parallel_shard_degree` argument specifies the degree of data

torchtitan/distributed/utils.py

Lines changed: 1 addition & 8 deletions
@@ -193,20 +193,13 @@ def create_context_parallel_ctx(
     )


-def get_train_context(
-    enable_loss_parallel: bool, enable_compiled_autograd: bool
-) -> Generator[None, None, None]:
+def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]:
     @contextlib.contextmanager
     def context(cp_context: Generator[None, None, None] | None = None):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

-            if enable_compiled_autograd:
-                stack.enter_context(
-                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
-                )
-
             if cp_context:
                 stack.enter_context(cp_context)

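A hypothetical caller-side sketch of the simplified factory: the construction matches the call sites updated in this commit (e.g. scripts/estimate/estimation.py and forge/engine.py), while the import alias, the dummy model, and the with-block usage are assumptions for illustration.

import torch
import torch.nn as nn

from torchtitan.distributed import utils as dist_utils  # assumed import path for the module above

# No tensor parallelism in this toy setup, so loss parallel stays disabled.
train_context = dist_utils.get_train_context(False)

model = nn.Linear(8, 8)
with train_context():  # a context-parallel context could be passed as train_context(cp_context)
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()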
torchtitan/experiments/forge/engine.py

Lines changed: 1 addition & 4 deletions
@@ -233,10 +233,7 @@ def __init__(self, job_config: ForgeJobConfig):
         loss_parallel_enabled = (
             parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
         )
-        self.train_context = dist_utils.get_train_context(
-            loss_parallel_enabled,
-            parallelism_config.enable_compiled_autograd,
-        )
+        self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
         self.maybe_enable_amp = dist_utils.maybe_enable_amp(
             parallel_dims,
             job_config.training.mixed_precision_param,

torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
@@ -45,6 +45,7 @@
     torch._higher_order_ops.flex_attention,
 }

+
 # Adapted from llama4/infra/parallelize.py
 def parallelize_gptoss(
     model: nn.Module,
@@ -168,7 +169,6 @@ def parallelize_gptoss(
         model,
         dp_mesh,
         enable_compile=model_compile_enabled,
-        enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
     )

     return model

torchtitan/experiments/vlm/infra/parallelize.py

Lines changed: 0 additions & 1 deletion
@@ -107,7 +107,6 @@ def parallelize_vlm(
         model,
         world_mesh,
         enable_compile=job_config.compile.enable,
-        enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
     )

     return model

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 0 additions & 1 deletion
@@ -171,7 +171,6 @@ def parallelize_deepseekv3(
         model,
         dp_mesh,
         enable_compile=model_compile_enabled,
-        enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
     )

     return model

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 1 addition & 8 deletions
@@ -143,7 +143,6 @@ def parallelize_llama(
         model,
         world_mesh,
         enable_compile=model_compile_enabled,
-        enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
     )

     return model
@@ -322,15 +321,9 @@ def apply_ddp(
     model: nn.Module,
     dp_mesh: DeviceMesh,
     enable_compile: bool,
-    enable_compiled_autograd: bool,
 ):
     if enable_compile:
-        if enable_compiled_autograd:
-            torch._dynamo.config.optimize_ddp = (
-                "python_reducer_without_compiled_forward"
-            )
-        else:
-            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+        torch._dynamo.config.optimize_ddp = "ddp_optimizer"

     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

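For reference, a standalone sketch of the simplified DDP helper, assuming the imports torchtitan already uses for replicate and DeviceMesh (they are outside this hunk); running it requires an initialized process group and a 1-D data-parallel mesh.

import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate  # assumed import, matching the call above
from torch.distributed.device_mesh import DeviceMesh


def apply_ddp(model: nn.Module, dp_mesh: DeviceMesh, enable_compile: bool) -> None:
    if enable_compile:
        # Only remaining compile path: let Dynamo split graphs at DDP bucket boundaries.
        torch._dynamo.config.optimize_ddp = "ddp_optimizer"
    # Replicate the module over the data-parallel mesh (DDP-style), as in the diff above.
    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)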
torchtitan/models/llama4/infra/parallelize.py

Lines changed: 0 additions & 1 deletion
@@ -178,7 +178,6 @@ def parallelize_llama(
         model,
         dp_mesh,
         enable_compile=model_compile_enabled,
-        enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
     )

     return model

torchtitan/models/qwen3/infra/parallelize.py

Lines changed: 0 additions & 1 deletion
@@ -170,7 +170,6 @@ def parallelize_qwen3(
         model,
         world_mesh,
         enable_compile=model_compile_enabled,
-        enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
     )

     # Enable weight tying after applying parallelisms

0 commit comments