2 changes: 1 addition & 1 deletion torchtitan/components/loss.py
@@ -28,7 +28,7 @@ def build_cross_entropy_loss(job_config: JobConfig, **kwargs):
     loss_fn = cross_entropy_loss
     if job_config.compile.enable and "loss" in job_config.compile.components:
         logger.info("Compiling the loss function with torch.compile")
-        loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
+        loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend, mode="light")
     return loss_fn


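For reference, mode is the preset selector torch.compile layers on top of the chosen backend; released PyTorch ships presets such as "default", "reduce-overhead", and "max-autotune", and the "light" string used throughout this diff is assumed to come from the PyTorch build this branch targets. Below is a minimal, self-contained sketch of the compiled-loss pattern the hunk above modifies, with illustrative shapes and a stand-in mode, not torchtitan's actual values:

import torch
import torch.nn.functional as F

def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Flatten (batch, seq, vocab) logits against (batch, seq) integer labels.
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))

# "light" is assumed to exist only in the PyTorch build this branch targets;
# the runnable stand-in here uses "default", which stock builds ship.
compiled_loss = torch.compile(cross_entropy_loss, backend="inductor", mode="default")

logits = torch.randn(4, 128, 256)
labels = torch.randint(0, 256, (4, 128))
loss = compiled_loss(logits, labels)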
1 change: 1 addition & 0 deletions torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -148,6 +148,7 @@ def parallelize_llama(
             model,
             backend=get_compile_backend(backend),
             fullgraph=True,
+            mode="light",
         )
 
     return model
2 changes: 1 addition & 1 deletion torchtitan/experiments/vlm/infra/loss.py
@@ -109,5 +109,5 @@ def build_token_imbalance_ce_loss(
     loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg)
     if job_config.compile.enable and "loss" in job_config.compile.components:
         logger.info("Compiling the loss function with torch.compile")
-        loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
+        loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend, mode="light")
     return loss_fn
@@ -4,7 +4,7 @@ description = "DeepSeek-V3 debug training"
 print_config = false
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
 enable_memory_snapshot = false
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/infra/parallelize.py
@@ -242,7 +242,7 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
"""
for layer_id, transformer_block in model.layers.named_children():
transformer_block = torch.compile(
transformer_block, backend=compile_config.backend, fullgraph=True
transformer_block, backend=compile_config.backend, fullgraph=True, mode="light"
)
model.layers.register_module(layer_id, transformer_block)

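The hunk above is torchtitan's per-block (regional) compilation loop: each transformer block is compiled separately and re-registered under its original name, rather than wrapping the whole model in one compiled graph. A rough, self-contained sketch of that pattern on a toy module follows; the container name mirrors model.layers, the sizes are made up, and "default" substitutes for the "light" mode added in this diff:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Stand-in for a transformer: a dict of identically shaped blocks,
    # mirroring the model.layers container the loop above iterates over.
    def __init__(self, n_layers: int = 2, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleDict(
            {str(i): nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for i in range(n_layers)}
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.layers.values():
            x = block(x)
        return x

def apply_block_compile(model: TinyModel) -> None:
    # Compile block by block and swap each compiled module back in place;
    # "default" stands in for the "light" mode, which is assumed to exist
    # only in the PyTorch build this branch targets.
    for layer_id, block in model.layers.named_children():
        compiled = torch.compile(block, fullgraph=True, mode="default")
        model.layers.register_module(layer_id, compiled)

model = TinyModel()
apply_block_compile(model)
out = model(torch.randn(4, 16))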
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -37,7 +37,7 @@ dataset = "c4"

 [parallelism]
 data_parallel_replicate_degree = 1
-data_parallel_shard_degree = -1
+data_parallel_shard_degree = 8
 tensor_parallel_degree = 1
 pipeline_parallel_degree = 1
 context_parallel_degree = 1
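In torchtitan's parallelism config, data_parallel_shard_degree = -1 is a sentinel meaning "infer the FSDP shard degree from whatever ranks remain after the other degrees are fixed", so pinning it to 8 makes the 8-GPU assumption of this recipe explicit. A rough sketch of that resolution rule, assuming every degree must multiply to the world size (an illustrative helper, not torchtitan's actual code):

def resolve_dp_shard(world_size: int, dp_replicate: int, cp: int, tp: int, pp: int, dp_shard: int) -> int:
    # Product of every degree that is already pinned in the config.
    other = dp_replicate * cp * tp * pp
    if dp_shard == -1:
        # -1 means "infer": take all ranks not consumed by the other dimensions.
        assert world_size % other == 0, "world size must be divisible by the fixed degrees"
        return world_size // other
    # An explicit degree must make the full product match the world size.
    assert dp_shard * other == world_size, "parallelism degrees must multiply to the world size"
    return dp_shard

# With the values in this hunk on an 8-GPU node, -1 and 8 resolve identically:
assert resolve_dp_shard(8, dp_replicate=1, cp=1, tp=1, pp=1, dp_shard=-1) == 8
assert resolve_dp_shard(8, dp_replicate=1, cp=1, tp=1, pp=1, dp_shard=8) == 8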
5 changes: 3 additions & 2 deletions torchtitan/models/llama4/infra/parallelize.py
@@ -546,15 +546,15 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
                     moe,
                     attr_name,
                     torch.compile(
-                        submod, backend=compile_config.backend, fullgraph=True
+                        submod, backend=compile_config.backend, fullgraph=True, mode="light",
                     ),
                 )
             else:
                 setattr(
                     block,
                     attr_name,
                     torch.compile(
-                        submod, backend=compile_config.backend, fullgraph=True
+                        submod, backend=compile_config.backend, fullgraph=True, mode="light",
                     ),
                 )

@@ -565,6 +565,7 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
             transformer_block,
             backend=compile_config.backend,
             fullgraph=True,
+            mode="light",
         )
 
         model.layers.register_module(layer_id, transformer_block)