Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 115 additions & 110 deletions graph_net/torch/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,118 @@
torch._dynamo.config.allow_rnn = True


class GraphExtractor:
def __init__(
self, name, dynamic, mut_graph_codes=None, placeholder_auto_rename=False
):
self.subgraph_counter = 0
self.name = name
self.dynamic = dynamic
self.mut_graph_codes = mut_graph_codes
self.placeholder_auto_rename = placeholder_auto_rename
self.workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
if not self.workspace_path:
raise EnvironmentError(
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
)

def move_files(self, source_dir, target_dir):
os.makedirs(target_dir, exist_ok=True)
for item in os.listdir(source_dir):
source_path = os.path.join(source_dir, item)
if os.path.isfile(source_path):
target_path = os.path.join(target_dir, item)
shutil.move(source_path, target_path)

def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
# 1. Get model path
model_path = os.path.join(self.workspace_path, self.name)
os.makedirs(model_path, exist_ok=True)

if self.subgraph_counter == 0:
subgraph_path = model_path
else:
if self.subgraph_counter == 1:
subgraph_0_path = os.path.join(model_path, f"subgraph_0")
self.move_files(model_path, subgraph_0_path)

subgraph_path = os.path.join(
model_path, f"subgraph_{self.subgraph_counter}"
)
os.makedirs(subgraph_path, exist_ok=True)

self.subgraph_counter += 1

# 2. Get full params
params = {}
input_idx = 0
unique_id = 0

def try_rename_placeholder(node):
assert node.op == "placeholder"
if not self.placeholder_auto_rename:
return
nonlocal unique_id
node.target = f"v{unique_id}"
unique_id += 1
node.name = f"v{unique_id}"
unique_id += 1

for node in gm.graph.nodes:
if node.op == "placeholder":
try_rename_placeholder(node)
input = sample_inputs[input_idx]
if isinstance(input, torch.SymInt):
input = torch.tensor(4)
params[node.target] = input
input_idx += 1

if node.op == "call_function" and hasattr(node.target, "__name__"):
if node.target.__name__ in [
"_enter_autocast",
"_exit_autocast",
]:
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)

assert input_idx == len(sample_inputs)
if self.mut_graph_codes is not None:
assert isinstance(self.mut_graph_codes, list)
self.mut_graph_codes.append(gm.code)
# 3. Generate and save model code
base_code = gm.code
# gm.graph.print_tabular()
write_code = utils.apply_templates(base_code)
with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
fp.write(write_code)

# 4. Save metadata
metadata = {
"framework": "torch",
"num_devices_required": 1,
"num_nodes_required": 1,
"dynamic": bool(self.dynamic),
"model_name": self.name,
}
with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
json.dump(metadata, f, indent=4)

# 5. Save tensor metadata
# Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors)
converted = utils.convert_state_and_inputs(params, [])
utils.save_converted_to_text(converted, file_path=subgraph_path)
utils.save_constraints_text(
converted,
file_path=os.path.join(subgraph_path, "input_tensor_constraints.py"),
)

print(
f"Graph and tensors for '{self.name}' extracted successfully to: {model_path}"
)

return gm.forward


def extract(name, dynamic=True, mut_graph_codes=None, placeholder_auto_rename=False):
"""
Extract computation graphs from PyTorch nn.Module.
Expand Down Expand Up @@ -83,118 +195,11 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):

def wrapper(model: torch.nn.Module):
assert isinstance(model, torch.nn.Module), f"{type(model)=}"

class GraphExtractor:
def __init__(self):
self.subgraph_counter = 0
self.workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
if not self.workspace_path:
raise EnvironmentError(
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
)

def move_files(self, source_dir, target_dir):
os.makedirs(target_dir, exist_ok=True)
for item in os.listdir(source_dir):
source_path = os.path.join(source_dir, item)
if os.path.isfile(source_path):
target_path = os.path.join(target_dir, item)
shutil.move(source_path, target_path)

def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
# 1. Get model path
model_path = os.path.join(self.workspace_path, name)
os.makedirs(model_path, exist_ok=True)

if self.subgraph_counter == 0:
subgraph_path = model_path
else:
if self.subgraph_counter == 1:
subgraph_0_path = os.path.join(model_path, f"subgraph_0")
self.move_files(model_path, subgraph_0_path)

subgraph_path = os.path.join(
model_path, f"subgraph_{self.subgraph_counter}"
)
os.makedirs(subgraph_path, exist_ok=True)

self.subgraph_counter += 1

# 2. Get full params
params = {}
input_idx = 0
unique_id = 0

def try_rename_placeholder(node):
assert node.op == "placeholder"
if not placeholder_auto_rename:
return
nonlocal unique_id
node.target = f"v{unique_id}"
unique_id += 1
node.name = f"v{unique_id}"
unique_id += 1

for node in gm.graph.nodes:
if node.op == "placeholder":
try_rename_placeholder(node)
input = sample_inputs[input_idx]
if isinstance(input, torch.SymInt):
input = torch.tensor(4)
params[node.target] = input
input_idx += 1

if node.op == "call_function" and hasattr(node.target, "__name__"):
if node.target.__name__ in [
"_enter_autocast",
"_exit_autocast",
]:
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)

assert input_idx == len(sample_inputs)
if mut_graph_codes is not None:
assert isinstance(mut_graph_codes, list)
mut_graph_codes.append(gm.code)
# 3. Generate and save model code
base_code = gm.code
# gm.graph.print_tabular()
write_code = utils.apply_templates(base_code)
with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
fp.write(write_code)

# 4. Save metadata
metadata = {
"framework": "torch",
"num_devices_required": 1,
"num_nodes_required": 1,
"dynamic": bool(dynamic),
"model_name": name,
}
with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
json.dump(metadata, f, indent=4)

# 5. Save tensor metadata
# Adapt to different input structures (e.g., single tensor vs. dict/tuple of tensors)
converted = utils.convert_state_and_inputs(params, [])
utils.save_converted_to_text(converted, file_path=subgraph_path)
utils.save_constraints_text(
converted,
file_path=os.path.join(
subgraph_path, "input_tensor_constraints.py"
),
)

print(
f"Graph and tensors for '{name}' extracted successfully to: {model_path}"
)

return gm.forward

extractor = GraphExtractor()
extractor = GraphExtractor(
name, dynamic, mut_graph_codes, placeholder_auto_rename
)
# return torch.compile(backend=extractor, dynamic=dynamic)
compiled_model = torch.compile(model, backend=extractor, dynamic=dynamic)

return compiled_model

def decorator(module_class):
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
96465c75b19e7adcc618c9f4415d42728480afb129905e98b6d8a5129471eb3a
4e259e7a200600da941370398d0d5936d3196756cd0e810e9416769d76e77c99
Loading
Loading