88from typing import Callable , Optional
99
1010import torch
11- from torch ._dynamo .functional_export import _dynamo_graph_capture_for_export
11+ from torch ._dynamo .functional_export import dynamo_graph_capture_for_export
1212from torch ._functorch .aot_autograd import (
1313 aot_compile_joint_with_descriptors ,
1414 aot_export_joint_with_descriptors ,
@@ -33,88 +33,19 @@ def _clear_traced_params_buffers(
3333 setattr (traced_module , key , buffer )
3434
3535
36- def _restore_state_dict (
37- original_module : torch .nn .Module , traced_module : torch .fx .GraphModule
38- ) -> None :
39- """
40- TODO: move this into torch.export
41- Restores the state dict of the traced module to match the original module exactly.
42- Preserves the original FQNs with dots, creating intermediate empty modules as needed.
43- Ensures that the ordering of parameters/buffers matches the original module.
44- """
45- # Build ID-based lookups for traced module params/buffers
46- traced_params : dict [int , tuple [str , torch .nn .Parameter ]] = {}
47- for name , param in traced_module .named_parameters (remove_duplicate = False ):
48- traced_params [id (param )] = (name , param )
49-
50- traced_buffers : dict [int , tuple [str , torch .Tensor ]] = {}
51- for name , buffer in traced_module .named_buffers (remove_duplicate = False ):
52- traced_buffers [id (buffer )] = (name , buffer )
53-
54- # Build mapping from old names to new names for graph node updates
55- name_mapping : dict [str , str ] = {}
56-
57- # Restore parameters in the order they appear in original module
58- for orig_name , orig_param in original_module .named_parameters (
59- remove_duplicate = False
60- ):
61- if id (orig_param ) in traced_params :
62- # This param exists in traced module - restore it with original FQN
63- traced_name , traced_param = traced_params [id (orig_param )]
64- torch .fx .graph_module ._assign_attr (traced_param , traced_module , orig_name )
65- torch .fx .graph_module ._del_attr (traced_module , traced_name )
66- name_mapping [traced_name ] = orig_name
67- else :
68- # This param doesn't exist in traced module - add it
69- torch .fx .graph_module ._assign_attr (orig_param , traced_module , orig_name )
70-
71- # Restore buffers in the order they appear in original module
72- for orig_name , orig_buffer in original_module .named_buffers (remove_duplicate = False ):
73- if id (orig_buffer ) in traced_buffers :
74- # This buffer exists in traced module - restore it with original FQN
75- traced_name , traced_buffer = traced_buffers [id (orig_buffer )]
76- torch .fx .graph_module ._assign_attr (orig_buffer , traced_module , orig_name )
77- name_mapping [traced_name ] = orig_name
78- torch .fx .graph_module ._del_attr (traced_module , traced_name )
79- else :
80- # This buffer doesn't exist in traced module - add it
81- torch .fx .graph_module ._assign_attr (orig_buffer , traced_module , orig_name )
82-
83- param_names = [v [0 ] for v in traced_params .values ()]
84- buffer_names = [v [0 ] for v in traced_buffers .values ()]
85- const_keys = set (param_names + buffer_names ).difference (set (name_mapping .keys ()))
86-
87- _clear_traced_params_buffers (traced_module , const_keys )
88-
89- # Update get_attr nodes in the graph to use the correct FQNs
90- for node in traced_module .graph .nodes :
91- if node .op == "get_attr" and node .target in name_mapping :
92- node .target = name_mapping [node .target ]
93-
94- traced_module .recompile ()
95-
96-
9736def export_joint (
9837 model , args , kwargs = None
9938) -> tuple [JointWithDescriptors , TracingContext ]:
10039 if kwargs is None :
10140 kwargs = {}
10241 assert isinstance (args , tuple )
10342 assert isinstance (kwargs , dict )
104- with torch ._dynamo .config .patch (
105- install_free_tensors = True
106- ), torch .fx .traceback .preserve_node_meta ():
107- # TODO: switch to use the official graph_capture API once it is ready
108- gm = _dynamo_graph_capture_for_export (model )(* args , ** kwargs )
109-
110- # Restore the state dict to match the original module
111- _restore_state_dict (model , gm )
112-
43+ with torch .fx .traceback .preserve_node_meta ():
44+ gm = dynamo_graph_capture_for_export (model )(* args , ** kwargs )
11345 logger .info ("Dynamo gm:" )
11446 logger .info (gm .print_readable (print_output = False ))
11547
116- fake_mode = gm .meta .get ("fake_mode" , None )
117- tracing_context = TracingContext (fake_mode )
48+ tracing_context = gm .meta ["tracing_context" ]
11849
11950 with tracing (tracing_context ):
12051 return (
0 commit comments