Commit 70897e3

[compiler] Switch to the new dynamo export API.
Summary: Replace the old `_dynamo_graph_capture_for_export` usage with the new `dynamo_graph_capture_for_export` API and remove the now-dead `_restore_state_dict` helper.

Test Plan:

```
NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none

NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --model.flavor=debugmodel_flex_attn

NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4

NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn
```
1 parent bb308da commit 70897e3

File tree

1 file changed (+4, -73)


torchtitan/experiments/compiler_toolkit/graph_utils.py

Lines changed: 4 additions & 73 deletions
```diff
@@ -8,7 +8,7 @@
 from typing import Callable, Optional
 
 import torch
-from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+from torch._dynamo.functional_export import dynamo_graph_capture_for_export
 from torch._functorch.aot_autograd import (
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
@@ -33,88 +33,19 @@ def _clear_traced_params_buffers(
         setattr(traced_module, key, buffer)
 
 
-def _restore_state_dict(
-    original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
-) -> None:
-    """
-    TODO: move this into torch.export
-    Restores the state dict of the traced module to match the original module exactly.
-    Preserves the original FQNs with dots, creating intermediate empty modules as needed.
-    Ensures that the ordering of parameters/buffers matches the original module.
-    """
-    # Build ID-based lookups for traced module params/buffers
-    traced_params: dict[int, tuple[str, torch.nn.Parameter]] = {}
-    for name, param in traced_module.named_parameters(remove_duplicate=False):
-        traced_params[id(param)] = (name, param)
-
-    traced_buffers: dict[int, tuple[str, torch.Tensor]] = {}
-    for name, buffer in traced_module.named_buffers(remove_duplicate=False):
-        traced_buffers[id(buffer)] = (name, buffer)
-
-    # Build mapping from old names to new names for graph node updates
-    name_mapping: dict[str, str] = {}
-
-    # Restore parameters in the order they appear in original module
-    for orig_name, orig_param in original_module.named_parameters(
-        remove_duplicate=False
-    ):
-        if id(orig_param) in traced_params:
-            # This param exists in traced module - restore it with original FQN
-            traced_name, traced_param = traced_params[id(orig_param)]
-            torch.fx.graph_module._assign_attr(traced_param, traced_module, orig_name)
-            torch.fx.graph_module._del_attr(traced_module, traced_name)
-            name_mapping[traced_name] = orig_name
-        else:
-            # This param doesn't exist in traced module - add it
-            torch.fx.graph_module._assign_attr(orig_param, traced_module, orig_name)
-
-    # Restore buffers in the order they appear in original module
-    for orig_name, orig_buffer in original_module.named_buffers(remove_duplicate=False):
-        if id(orig_buffer) in traced_buffers:
-            # This buffer exists in traced module - restore it with original FQN
-            traced_name, traced_buffer = traced_buffers[id(orig_buffer)]
-            torch.fx.graph_module._assign_attr(orig_buffer, traced_module, orig_name)
-            name_mapping[traced_name] = orig_name
-            torch.fx.graph_module._del_attr(traced_module, traced_name)
-        else:
-            # This buffer doesn't exist in traced module - add it
-            torch.fx.graph_module._assign_attr(orig_buffer, traced_module, orig_name)
-
-    param_names = [v[0] for v in traced_params.values()]
-    buffer_names = [v[0] for v in traced_buffers.values()]
-    const_keys = set(param_names + buffer_names).difference(set(name_mapping.keys()))
-
-    _clear_traced_params_buffers(traced_module, const_keys)
-
-    # Update get_attr nodes in the graph to use the correct FQNs
-    for node in traced_module.graph.nodes:
-        if node.op == "get_attr" and node.target in name_mapping:
-            node.target = name_mapping[node.target]
-
-    traced_module.recompile()
-
-
 def export_joint(
     model, args, kwargs=None
 ) -> tuple[JointWithDescriptors, TracingContext]:
     if kwargs is None:
         kwargs = {}
     assert isinstance(args, tuple)
     assert isinstance(kwargs, dict)
-    with torch._dynamo.config.patch(
-        install_free_tensors=True
-    ), torch.fx.traceback.preserve_node_meta():
-        # TODO: switch to use the official graph_capture API once it is ready
-        gm = _dynamo_graph_capture_for_export(model)(*args, **kwargs)
-
-        # Restore the state dict to match the original module
-        _restore_state_dict(model, gm)
+    with torch.fx.traceback.preserve_node_meta():
+        gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
     logger.info("Dynamo gm:")
     logger.info(gm.print_readable(print_output=False))
 
-    fake_mode = gm.meta.get("fake_mode", None)
-    tracing_context = TracingContext(fake_mode)
+    tracing_context = gm.meta["tracing_context"]
 
     with tracing(tracing_context):
         return (
```
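For context, here is a minimal, self-contained sketch of the capture flow this commit switches to. It is an illustration only: `Tiny` and its inputs are hypothetical stand-ins for the real model, and `dynamo_graph_capture_for_export` is a private PyTorch API whose availability depends on the build. Only the call patterns shown in the diff above are assumed.

```python
import torch
from torch._dynamo.functional_export import dynamo_graph_capture_for_export


class Tiny(torch.nn.Module):
    """Hypothetical stand-in for the real model being captured."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


model = Tiny()
args = (torch.randn(4, 8),)

# Mirrors export_joint: keep source-level node metadata while capturing.
with torch.fx.traceback.preserve_node_meta():
    gm = dynamo_graph_capture_for_export(model)(*args)

# The new API attaches the TracingContext to the graph module's meta,
# replacing the old TracingContext(fake_mode) reconstruction.
tracing_context = gm.meta["tracing_context"]

# Parameter FQNs such as "proj.weight" should survive capture unchanged,
# which is what made the manual _restore_state_dict pass unnecessary.
print(list(gm.state_dict().keys()))
print(gm.print_readable(print_output=False))
```

The simplification in the diff follows from these two properties: the capture evidently preserves the original module's FQNs, so the old state-dict restoration became dead code, and the `TracingContext` ships on `gm.meta`, so it no longer needs to be rebuilt from a `fake_mode`.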