Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 9 additions & 74 deletions torchtitan/experiments/compiler_toolkit/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Callable, Optional

import torch
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import (
aot_compile_joint_with_descriptors,
aot_export_joint_with_descriptors,
Expand All @@ -33,88 +33,23 @@ def _clear_traced_params_buffers(
setattr(traced_module, key, buffer)


def _restore_state_dict(
    original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
) -> None:
    """
    TODO: move this into torch.export
    Restore the state dict of ``traced_module`` so it matches ``original_module``.

    Every parameter and buffer of the original module is (re)assigned on the
    traced module under its original dotted FQN, in the order the original
    module yields them (parameters first, then buffers). Tensors that the
    traced module already holds under a different name are matched by object
    identity, renamed, and the ``get_attr`` nodes in the graph are retargeted
    to the new name. Traced params/buffers that have no counterpart in the
    original module are handed to ``_clear_traced_params_buffers``.

    Args:
        original_module: Module whose FQNs and ordering are the source of truth.
        traced_module: Graph module captured from ``original_module``; it is
            modified in place and recompiled.
    """
    assign_attr = torch.fx.graph_module._assign_attr
    del_attr = torch.fx.graph_module._del_attr

    # Identity-based lookups: id(tensor) -> (name in traced module, tensor).
    traced_params: dict[int, tuple[str, torch.nn.Parameter]] = {
        id(param): (name, param)
        for name, param in traced_module.named_parameters(remove_duplicate=False)
    }
    traced_buffers: dict[int, tuple[str, torch.Tensor]] = {
        id(buffer): (name, buffer)
        for name, buffer in traced_module.named_buffers(remove_duplicate=False)
    }

    # Old traced name -> original FQN, used to retarget get_attr nodes below.
    name_mapping: dict[str, str] = {}
    # Names that are already final: original FQNs that have been assigned and
    # traced names that have been deleted. A traced name found here must not
    # be deleted, because either it is already gone (tied tensors reach the
    # same traced name once per alias, since remove_duplicate=False) or it now
    # holds a restored original tensor.
    settled: set[str] = set()

    # Parameters are restored before buffers; each group keeps original order.
    restore_plan = (
        (original_module.named_parameters(remove_duplicate=False), traced_params),
        (original_module.named_buffers(remove_duplicate=False), traced_buffers),
    )
    for named_tensors, traced_lookup in restore_plan:
        for orig_name, orig_tensor in named_tensors:
            # Matching ids mean the traced entry is this very tensor object,
            # so assigning ``orig_tensor`` covers both the "rename" and the
            # "missing from traced module" cases.
            assign_attr(orig_tensor, traced_module, orig_name)
            settled.add(orig_name)

            match = traced_lookup.get(id(orig_tensor))
            if match is None:
                continue

            traced_name = match[0]
            name_mapping[traced_name] = orig_name
            if traced_name not in settled:
                del_attr(traced_module, traced_name)
                settled.add(traced_name)

    # Whatever the traced module registered that was never matched to an
    # original tensor is treated as a constant and cleared by the helper.
    traced_names = {name for name, _ in traced_params.values()}
    traced_names.update(name for name, _ in traced_buffers.values())
    const_keys = traced_names - name_mapping.keys()

    _clear_traced_params_buffers(traced_module, const_keys)

    # Update get_attr nodes in the graph to use the restored FQNs.
    for node in traced_module.graph.nodes:
        if node.op == "get_attr" and node.target in name_mapping:
            node.target = name_mapping[node.target]

    traced_module.recompile()


def export_joint(
model, args, kwargs=None
) -> tuple[JointWithDescriptors, TracingContext]:
if kwargs is None:
kwargs = {}
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
with torch._dynamo.config.patch(
install_free_tensors=True
), torch.fx.traceback.preserve_node_meta():
# TODO: switch to use the official graph_capture API once it is ready
gm = _dynamo_graph_capture_for_export(model)(*args, **kwargs)

# Restore the state dict to match the original module
_restore_state_dict(model, gm)

with (
# TODO Investigate error on MOE model with use_grouped_mm=False.
# For repro, see: https://gist.github.com/zhxchen17/d794ff58236243d9faddf713b9fc6a61
torch._dynamo.config.patch(fake_tensor_cache_enabled=False),
torch.fx.traceback.preserve_node_meta(),
):
gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
logger.info("Dynamo gm:")
logger.info(gm.print_readable(print_output=False))

fake_mode = gm.meta.get("fake_mode", None)
tracing_context = TracingContext(fake_mode)
tracing_context = gm.meta["tracing_context"]

with tracing(tracing_context):
return (
Expand Down