Commit 704ac5f

fix vmap models
1 parent c1bc381 commit 704ac5f

85 files changed, +30591 −38977 lines changed
graph_net/torch/extractor.py

Lines changed: 115 additions & 110 deletions
@@ -12,6 +12,118 @@
 torch._dynamo.config.allow_rnn = True
 
 
+class GraphExtractor:
+    def __init__(
+        self, name, dynamic, mut_graph_codes=None, placeholder_auto_rename=False
+    ):
+        self.subgraph_counter = 0
+        self.name = name
+        self.dynamic = dynamic
+        self.mut_graph_codes = mut_graph_codes
+        self.placeholder_auto_rename = placeholder_auto_rename
+        self.workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
+        if not self.workspace_path:
+            raise EnvironmentError(
+                "Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
+            )
+
+    def move_files(self, source_dir, target_dir):
+        os.makedirs(target_dir, exist_ok=True)
+        for item in os.listdir(source_dir):
+            source_path = os.path.join(source_dir, item)
+            if os.path.isfile(source_path):
+                target_path = os.path.join(target_dir, item)
+                shutil.move(source_path, target_path)
+
+    def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
+        # 1. Get model path
+        model_path = os.path.join(self.workspace_path, self.name)
+        os.makedirs(model_path, exist_ok=True)
+
+        if self.subgraph_counter == 0:
+            subgraph_path = model_path
+        else:
+            if self.subgraph_counter == 1:
+                subgraph_0_path = os.path.join(model_path, f"subgraph_0")
+                self.move_files(model_path, subgraph_0_path)
+
+            subgraph_path = os.path.join(
+                model_path, f"subgraph_{self.subgraph_counter}"
+            )
+            os.makedirs(subgraph_path, exist_ok=True)
+
+        self.subgraph_counter += 1
+
+        # 2. Get full params
+        params = {}
+        input_idx = 0
+        unique_id = 0
+
+        def try_rename_placeholder(node):
+            assert node.op == "placeholder"
+            if not self.placeholder_auto_rename:
+                return
+            nonlocal unique_id
+            node.target = f"v{unique_id}"
+            unique_id += 1
+            node.name = f"v{unique_id}"
+            unique_id += 1
+
+        for node in gm.graph.nodes:
+            if node.op == "placeholder":
+                try_rename_placeholder(node)
+                input = sample_inputs[input_idx]
+                if isinstance(input, torch.SymInt):
+                    input = torch.tensor(4)
+                params[node.target] = input
+                input_idx += 1
+
+            if node.op == "call_function" and hasattr(node.target, "__name__"):
+                if node.target.__name__ in [
+                    "_enter_autocast",
+                    "_exit_autocast",
+                ]:
+                    node.replace_all_uses_with(node.args[0])
+                    gm.graph.erase_node(node)
+
+        assert input_idx == len(sample_inputs)
+        if self.mut_graph_codes is not None:
+            assert isinstance(self.mut_graph_codes, list)
+            self.mut_graph_codes.append(gm.code)
+        # 3. Generate and save model code
+        base_code = gm.code
+        # gm.graph.print_tabular()
+        write_code = utils.apply_templates(base_code)
+        with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
+            fp.write(write_code)
+
+        # 4. Save metadata
+        metadata = {
+            "framework": "torch",
+            "num_devices_required": 1,
+            "num_nodes_required": 1,
+            "dynamic": bool(self.dynamic),
+            "model_name": self.name,
+        }
+        with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
+            json.dump(metadata, f, indent=4)
+
+        # 5. Save tensor metadata
+        # Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors)
+        converted = utils.convert_state_and_inputs(params, [])
+        utils.save_converted_to_text(converted, file_path=subgraph_path)
+        utils.save_constraints_text(
+            converted,
+            file_path=os.path.join(subgraph_path, "input_tensor_constraints.py"),
+        )
+
+        print(
+            f"Graph and tensors for '{self.name}' extracted successfully to: {model_path}"
+        )
+
+        return gm.forward
+
+
 def extract(name, dynamic=True, mut_graph_codes=None, placeholder_auto_rename=False):
     """
     Extract computation graphs from PyTorch nn.Module.
@@ -83,118 +195,11 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
 
     def wrapper(model: torch.nn.Module):
         assert isinstance(model, torch.nn.Module), f"{type(model)=}"
-
-        class GraphExtractor:
-            def __init__(self):
-                self.subgraph_counter = 0
-                self.workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
-                if not self.workspace_path:
-                    raise EnvironmentError(
-                        "Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
-                    )
-
-            def move_files(self, source_dir, target_dir):
-                os.makedirs(target_dir, exist_ok=True)
-                for item in os.listdir(source_dir):
-                    source_path = os.path.join(source_dir, item)
-                    if os.path.isfile(source_path):
-                        target_path = os.path.join(target_dir, item)
-                        shutil.move(source_path, target_path)
-
-            def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
-                # 1. Get model path
-                model_path = os.path.join(self.workspace_path, name)
-                os.makedirs(model_path, exist_ok=True)
-
-                if self.subgraph_counter == 0:
-                    subgraph_path = model_path
-                else:
-                    if self.subgraph_counter == 1:
-                        subgraph_0_path = os.path.join(model_path, f"subgraph_0")
-                        self.move_files(model_path, subgraph_0_path)
-
-                    subgraph_path = os.path.join(
-                        model_path, f"subgraph_{self.subgraph_counter}"
-                    )
-                    os.makedirs(subgraph_path, exist_ok=True)
-
-                self.subgraph_counter += 1
-
-                # 2. Get full params
-                params = {}
-                input_idx = 0
-                unique_id = 0
-
-                def try_rename_placeholder(node):
-                    assert node.op == "placeholder"
-                    if not placeholder_auto_rename:
-                        return
-                    nonlocal unique_id
-                    node.target = f"v{unique_id}"
-                    unique_id += 1
-                    node.name = f"v{unique_id}"
-                    unique_id += 1
-
-                for node in gm.graph.nodes:
-                    if node.op == "placeholder":
-                        try_rename_placeholder(node)
-                        input = sample_inputs[input_idx]
-                        if isinstance(input, torch.SymInt):
-                            input = torch.tensor(4)
-                        params[node.target] = input
-                        input_idx += 1
-
-                    if node.op == "call_function" and hasattr(node.target, "__name__"):
-                        if node.target.__name__ in [
-                            "_enter_autocast",
-                            "_exit_autocast",
-                        ]:
-                            node.replace_all_uses_with(node.args[0])
-                            gm.graph.erase_node(node)
-
-                assert input_idx == len(sample_inputs)
-                if mut_graph_codes is not None:
-                    assert isinstance(mut_graph_codes, list)
-                    mut_graph_codes.append(gm.code)
-                # 3. Generate and save model code
-                base_code = gm.code
-                # gm.graph.print_tabular()
-                write_code = utils.apply_templates(base_code)
-                with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
-                    fp.write(write_code)
-
-                # 4. Save metadata
-                metadata = {
-                    "framework": "torch",
-                    "num_devices_required": 1,
-                    "num_nodes_required": 1,
-                    "dynamic": bool(dynamic),
-                    "model_name": name,
-                }
-                with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
-                    json.dump(metadata, f, indent=4)
-
-                # 5. Save tensor metadata
-                # Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors)
-                converted = utils.convert_state_and_inputs(params, [])
-                utils.save_converted_to_text(converted, file_path=subgraph_path)
-                utils.save_constraints_text(
-                    converted,
-                    file_path=os.path.join(
-                        subgraph_path, "input_tensor_constraints.py"
-                    ),
-                )
-
-                print(
-                    f"Graph and tensors for '{name}' extracted successfully to: {model_path}"
-                )
-
-                return gm.forward
-
-        extractor = GraphExtractor()
+        extractor = GraphExtractor(
+            name, dynamic, mut_graph_codes, placeholder_auto_rename
+        )
         # return torch.compile(backend=extractor, dynamic=dynamic)
         compiled_model = torch.compile(model, backend=extractor, dynamic=dynamic)
-
         return compiled_model
 
     def decorator(module_class):
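
For orientation, here is a minimal usage sketch (not part of this commit) of how the now module-level GraphExtractor can be wired up as a torch.compile backend, the same way wrapper() does above. The import path follows the changed file; the model, "demo_model" name, and /tmp workspace path are made-up example values.

# Hypothetical usage sketch -- illustration only, not code from this commit.
# "demo_model" and the /tmp workspace path are assumed example values.
import os
import torch
from graph_net.torch.extractor import GraphExtractor

# GraphExtractor reads this variable in __init__, so set it before constructing.
os.environ.setdefault("GRAPH_NET_EXTRACT_WORKSPACE", "/tmp/graph_net_workspace")

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)
extractor = GraphExtractor(name="demo_model", dynamic=False)

# torch.compile calls the extractor once per traced FX subgraph; each call writes
# model.py, graph_net.json, and the tensor metadata under
# $GRAPH_NET_EXTRACT_WORKSPACE/demo_model. If tracing produces several subgraphs,
# they are shuffled into subgraph_0, subgraph_1, ... by move_files().
compiled = torch.compile(model, backend=extractor, dynamic=False)
out = compiled(torch.randn(2, 8))

Compared with the previous closure-based version, the configuration (name, dynamic, mut_graph_codes, placeholder_auto_rename) now lives on the extractor instance rather than being captured from extract()'s enclosing scope.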
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-96465c75b19e7adcc618c9f4415d42728480afb129905e98b6d8a5129471eb3a
+4e259e7a200600da941370398d0d5936d3196756cd0e810e9416769d76e77c99
