
Commit a5c4027
add manual bucketing pass
1 parent d0e2545

7 files changed, +56 -15 lines changed

torchtitan/experiments/simple_fsdp/README.md

Lines changed: 8 additions & 7 deletions
````diff
@@ -51,13 +51,14 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing

 2. auto optimization: perform auto-bucketing & reordering without user inputs. **Note: it is not guaranteed that users will get the most optimized training performance**
    - "aot_eager_autobucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend.
-
-   users can specify the pass (e.g., "aot_eager_autobucketing") via addtional configs:
-
-   ```bash
-   --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_autobucketing"
-   ```
+   ```bash
+   --compile.backend "aot_eager" --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_autobucketing"
+   ```
+
+3. manual optimization: perform manual bucketing & reordering with user FQN inputs.
+   - "aot_eager_manualbucketing": perform manual bucketing at aten fx-level, and perform code execution with aot_eager backend.
+   ```bash
+   --compile.backend "aot_eager" --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_manualbucketing" --compile.manual_bucketed_modules "tok_embeddings,layers.[0-5],norm+output"
+   ```

 ### Citation

````
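One plausible reading of the `--compile.manual_bucketed_modules` value is that it is a comma-separated list of module-FQN patterns, where a bracketed range such as `layers.[0-5]` covers a span of layer indices and `+` groups several modules into one bucket. The sketch below is a hypothetical illustration of expanding such a plan against a model's named modules; the real expansion is performed inside `torch._inductor.fx_passes.overlap_manual_scheduling` and may interpret the syntax differently.

```python
# Hypothetical sketch: expand bucket-plan strings such as "layers.[0-5]" or
# "norm+output" into lists of module FQNs. This mirrors the spirit of
# --compile.manual_bucketed_modules but is NOT the torch._inductor parser.
import re


def expand_bucket_plan(plan: str, module_fqns: list[str]) -> list[list[str]]:
    buckets: list[list[str]] = []
    for entry in plan.split(","):
        if "+" in entry:
            # "norm+output": group several FQNs into a single bucket
            buckets.append(entry.split("+"))
        elif (m := re.match(r"(.+)\.\[(\d+)-(\d+)\]$", entry)):
            # "layers.[0-5]": one bucket per layer index in the range
            prefix, lo, hi = m.group(1), int(m.group(2)), int(m.group(3))
            for i in range(lo, hi + 1):
                fqn = f"{prefix}.{i}"
                if fqn in module_fqns:
                    buckets.append([fqn])
        else:
            buckets.append([entry])
    return buckets


# Example: the plan from the README applied to a toy FQN list.
fqns = ["tok_embeddings"] + [f"layers.{i}" for i in range(6)] + ["norm", "output"]
print(expand_bucket_plan("tok_embeddings,layers.[0-5],norm+output", fqns))
```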

torchtitan/experiments/simple_fsdp/backend.py

Lines changed: 30 additions & 2 deletions
```diff
@@ -7,12 +7,14 @@
 from typing import Any, Union

 import torch
+from torchtitan.config import JobConfig

-
-def get_compile_backend(backend_name: str) -> Union[str, callable]:
+def get_compile_backend(job_config: JobConfig) -> Union[str, callable]:
     # return the compile backends used in SimpleFSDP training
     # Step1: check if backend_name is inside available torch.compile backends
     # Step2: check if the backend_name has been registered as a customized backend
+    backend_name = job_config.compile.model_backend_override or job_config.compile.backend
+
     available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
     if backend_name in available_torch_backend:
         return backend_name
@@ -41,6 +43,32 @@ def aten_autobucketing_reordering_pass(
             bw_compiler=aten_autobucketing_reordering_pass,
             keep_inference_input_mutations=True,
         )
+    elif backend_name == "aot_eager_manualbucketing":
+        # Perform manual optimization in aten fx-level and execute code in aot_eager backend
+        # The manualbucketing logic is here:
+        bucketing_modules = job_config.compile.manual_bucketed_modules
+        from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+        from torch._inductor.fx_passes.overlap_manual_scheduling import (
+            manual_overlap_bucketing,
+        )
+        from functools import partial
+
+        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True
+        torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False
+        torch._inductor.config.allow_buffer_reuse = False
+        manual_overlap_bucketing = partial(manual_overlap_bucketing, module_bucket_plans=job_config.compile.manual_bucketed_modules)
+
+        def aten_manualbucketing_reordering_pass(
+            gm: torch.fx.GraphModule, example_inputs: Any
+        ) -> torch.fx.GraphModule:
+            manual_overlap_bucketing(gm)
+            return gm
+
+        backend = aot_autograd_backend(
+            fw_compiler=aten_manualbucketing_reordering_pass,
+            bw_compiler=aten_manualbucketing_reordering_pass,
+            keep_inference_input_mutations=True,
+        )
     else:
         raise AssertionError(f"Unsupported customized backend: {backend_name}")

```
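Both the auto- and manual-bucketing branches follow the same pattern: wrap a graph-rewriting pass in `aot_autograd` so it runs over the forward and backward ATen FX graphs, then execute with eager-style semantics. Below is a minimal, self-contained sketch of that pattern; the pass is only a placeholder that reports graph size (it does not bucket collectives), and none of the diff's inductor flags are repeated.

```python
# Minimal sketch of the backend pattern used above: an aot_autograd-wrapped
# FX pass applied to both the forward and backward graphs. The pass here is a
# placeholder (it only reports graph size); a real pass would rewrite the graph.
from typing import Any

import torch
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend


def inspect_graph_pass(gm: torch.fx.GraphModule, example_inputs: Any) -> torch.fx.GraphModule:
    # A GraphModule is callable, so returning it (possibly mutated in place)
    # is a valid fw/bw compiler for aot_autograd.
    print(f"fx graph with {len(list(gm.graph.nodes))} nodes")
    return gm


demo_backend = aot_autograd_backend(
    fw_compiler=inspect_graph_pass,
    bw_compiler=inspect_graph_pass,
    keep_inference_input_mutations=True,
)

if __name__ == "__main__":
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    compiled = torch.compile(model, backend=demo_backend)
    compiled(torch.randn(2, 8)).sum().backward()
```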

torchtitan/experiments/simple_fsdp/job_config.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -10,8 +10,10 @@
 @dataclass
 class Compile:
     model_backend_override: str | None = None
-    """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
+    """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """

+    manual_bucketed_modules: list[str] = field(default_factory=list)
+    """Which modules should be bucketed together based on user specifications in manual optimization """

 @dataclass
 class JobConfig:
```
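These two fields interact with the base `--compile.backend` flag through the precedence used in `get_compile_backend` above: the override, when set, wins. A stand-alone sketch of that selection logic is below; the `backend` field here stands in for torchtitan's base compile backend option and is not part of this experiment's `Compile` dataclass.

```python
# Stand-alone sketch of the backend-selection precedence used in backend.py:
# model_backend_override, when set, takes priority over the base compile backend.
# The `backend` field is a stand-in for torchtitan's base --compile.backend option.
from dataclasses import dataclass, field


@dataclass
class Compile:
    backend: str = "inductor"
    model_backend_override: str | None = None
    manual_bucketed_modules: list[str] = field(default_factory=list)


cfg = Compile(
    backend="aot_eager",
    model_backend_override="aot_eager_manualbucketing",
    manual_bucketed_modules=["tok_embeddings", "layers.[0-5]", "norm+output"],
)
backend_name = cfg.model_backend_override or cfg.backend
print(backend_name)  # -> "aot_eager_manualbucketing"
```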

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 1 addition & 4 deletions
```diff
@@ -125,12 +125,9 @@ def parallelize_llama(

     if job_config.compile.enable and "model" in job_config.compile.components:
         torch._inductor.config.reorder_for_peak_memory = False
-        backend = (
-            job_config.compile.model_backend_override or job_config.compile.backend
-        )
         model = torch.compile(
             model,
-            backend=get_compile_backend(backend),
+            backend=get_compile_backend(job_config),
             fullgraph=True,
         )

```

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -342,7 +342,6 @@ def data_parallel(

     # apply regional ac (with fsdp_policy) if no global ac is to be applied
     regional_ac = ac_mode == "none"
-
     for mod in modules:
         params_dict = dict(mod.named_parameters(recurse=False))
         # we shouldn't apply data parallel to the modules that are already
```

torchtitan/models/llama3/model/model.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -405,6 +405,7 @@ def __init__(self, model_args: TransformerModelArgs):
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
         self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+        self.total_list = []

     def init_weights(
         self,
@@ -498,10 +499,13 @@ def forward(

         """
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
+        self.total_list = []
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
+        self.total_list.append(h)

         for layer in self.layers.values():
             h = layer(h, self.freqs_cis, attention_masks=attention_masks)
+            self.total_list.append(h)

         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
```
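The `total_list` additions stash the embedding output and every block output on the module purely for debugging, which also keeps those activations alive until the next forward pass. If one wanted the same per-layer capture without editing `forward`, forward hooks are a common alternative; the following is a hedged sketch on a hypothetical stand-in model, not what this commit does.

```python
# Sketch: capture per-layer outputs with forward hooks instead of storing them
# on the module inside forward(). Hypothetical stand-in model; not the commit's approach.
import torch
import torch.nn as nn

captured: list[torch.Tensor] = []


def capture_output(module: nn.Module, inputs, output) -> None:
    # Detach so the debug copies do not extend the autograd graph.
    captured.append(output.detach())


model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16), nn.Linear(16, 16))
handles = [m.register_forward_hook(capture_output) for m in model]

model(torch.randint(0, 100, (2, 8)))
print([t.norm(p=2, dtype=torch.float64) for t in captured])

for h in handles:
    h.remove()  # detach the hooks when done debugging
```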

torchtitan/train.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -412,6 +412,9 @@ def batch_generator(

             yield input_dict, labels

+    def custom_hash_fn(self, tensor):
+        return tensor.norm(p=2, dtype=torch.float64)
+
     def forward_backward_step(
         self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
     ) -> torch.Tensor:
@@ -488,6 +491,9 @@ def forward_backward_step(
                     pred = model_parts[0](inputs, **extra_inputs, **extra_args)
                     loss = self.loss_fn(pred, labels)
                 # need to free pred before bwd to avoid peaking memory
+                for res in model_parts[0].total_list:
+                    print("[FWD] pred results", self.custom_hash_fn(res))
+                print("[FWD] pred results", self.custom_hash_fn(pred))
                 del pred
                 loss.backward()

@@ -521,6 +527,10 @@ def train_step(
                 ),
                 ep_enabled=parallel_dims.ep_enabled,
             )
+
+        for m in self.model_parts:
+            for p_name, p in m.named_parameters():
+                print("[BWD] grad", self.custom_hash_fn(p).to_local())
         self.checkpointer.maybe_wait_for_staging()
         self.optimizers.step()
         self.lr_schedulers.step()
```
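`custom_hash_fn` reduces a tensor to its L2 norm in float64, which gives a cheap fingerprint for comparing activations and gradients across two runs (for example, an eager baseline versus the manual-bucketing backend). Below is a hedged sketch of how such logged fingerprints might be compared, assuming the `[FWD]`/`[BWD]` lines were redirected to two log files; the file handling and log parsing are hypothetical, not part of this commit.

```python
# Hypothetical sketch: compare the "[FWD]"/"[BWD]" norm fingerprints printed above
# from two training runs (e.g. baseline vs. manual-bucketing backend).
import re
import sys


def read_norms(path: str) -> list[float]:
    # Grab the first floating-point value on each [FWD]/[BWD] line.
    pattern = re.compile(r"\[(?:FWD|BWD)\].*?([0-9]+\.[0-9]+(?:e[+-]?\d+)?)")
    norms = []
    with open(path) as f:
        for line in f:
            if (m := pattern.search(line)):
                norms.append(float(m.group(1)))
    return norms


if __name__ == "__main__":
    baseline, candidate = read_norms(sys.argv[1]), read_norms(sys.argv[2])
    for i, (a, b) in enumerate(zip(baseline, candidate)):
        rel = abs(a - b) / max(abs(a), 1e-12)
        flag = "" if rel < 1e-6 else "  <-- mismatch"
        print(f"entry {i}: {a:.12e} vs {b:.12e} rel_diff={rel:.2e}{flag}")
```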
