@@ -55,6 +55,20 @@ class OpStat:
     count: int = 0
 
 
+def resolve_native_multi_head_attention(*args, **kwargs):
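+    # Hand-written shape inference: the assumed rationale is that
+    # _native_multi_head_attention cannot run on meta tensors, so we fabricate
+    # an output of the expected shape instead. Assumes the non-batch_first
+    # (seq_len, batch_size, embed_dim) layout for query; only attn_output is
+    # modeled, not the attention weights (see the commented-out sketch below).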
+    query, key, value = args[0], args[1], args[2]
+    seq_len, batch_size, embed_dim = query.shape
+    attn_output = torch.empty(
+        (seq_len, batch_size, embed_dim), dtype=query.dtype, device="meta"
+    )
+
+    # seq_len_k = key.shape[0]
+    # num_heads = args[4]
+    # attn_output_weights = torch.empty((batch_size, num_heads, seq_len, seq_len_k),
+    #                                   dtype=query.dtype, device='meta')
+    return attn_output  # , attn_output_weights
+
+
 def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
     attr_itr = node.target.split(".")
     val = gm
@@ -65,13 +79,13 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
 
 
 def collect_op_stats(model, input_dict):
-    # FX symbolic trace
     try:
+        # FX symbolic trace
        traced = torch.fx.symbolic_trace(model)
         # print(traced.graph)
     except Exception:
         print("Failed to FX symbolic trace")
-        return None
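+        # Callers now receive an (is_complete, op_stats) pair, so this
+        # early-exit path returns the same shape.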
+        return False, None
 
     # Use meta tensors as input to avoid actually running the model
     meta_input_dict = {}
@@ -80,8 +94,9 @@ def collect_op_stats(model, input_dict):
             torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x
         )
 
-    node_outputs = {}
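+    # Flips to False below whenever dtype inference fails for a node, so
+    # downstream consumers can tell the stats are partial.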
+    is_complete = True
     op_stats = {}
+    node_outputs = {}
     for node in traced.graph.nodes:
         op_name = None
         dtype = None
@@ -99,31 +114,35 @@ def collect_op_stats(model, input_dict):
                 lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n,
             )
 
-            if node.op == "call_module":
-                # classname of module
-                submod = traced.get_submodule(node.target)
-                op_name = submod.__class__.__name__
-                op_func = submod
-            elif node.op == "call_function":
-                op_name = node.target.__name__
-                op_func = node.target
-            elif node.op == "call_method":
-                op_name = node.target
-                self_obj = (
-                    node_outputs[node.args[0].name]
-                    if isinstance(node.args[0], torch.fx.Node)
-                    else node.args[0]
-                )
-                op_func = getattr(self_obj, node.target)
-                node_args = node_args[1:]
-
             try:
-                out = op_func(*node_args, **node_kwargs)
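+                # Resolving op_func now happens inside the try so that
+                # failures in get_submodule/getattr are caught as well.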
+                if node.op == "call_module":
+                    # classname of module
+                    submod = traced.get_submodule(node.target)
+                    op_name = submod.__class__.__name__
+                    op_func = submod
+                elif node.op == "call_function":
+                    op_name = node.target.__name__
+                    op_func = node.target
+                elif node.op == "call_method":
+                    op_name = node.target
+                    self_obj = (
+                        node_outputs[node.args[0].name]
+                        if isinstance(node.args[0], torch.fx.Node)
+                        else node.args[0]
+                    )
+                    op_func = getattr(self_obj, node.target)
+                    node_args = node_args[1:]
+
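+                # Assumed special case: this op cannot be evaluated on meta
+                # inputs, so route it through the hand-written resolver above.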
+                if op_name == "_native_multi_head_attention":
+                    out = resolve_native_multi_head_attention(*node_args, **node_kwargs)
+                else:
+                    out = op_func(*node_args, **node_kwargs)
                 node_outputs[node.name] = out
                 dtype = out.dtype if isinstance(out, torch.Tensor) else None
             except Exception:
                 print(f"dtype inference failed: node.op={node.op}, op_name={op_name}")
                 node_outputs[node.name] = None
+                is_complete = False
         elif node.op == "get_attr":
             op_name = node.op
             out = resolve_get_attr(traced, node)
@@ -149,11 +168,16 @@ def collect_op_stats(model, input_dict):
             else:
                 op_stats[op_name].dtype.add(dtype_str)
             op_stats[op_name].count = op_stats[op_name].count + 1
-    return op_stats
+    return is_complete, op_stats
 
 
 def collect_model_stats(model_path, device, log_prompt):
-    print(f"Collect information for {model_path}")
+    if not hasattr(collect_model_stats, "_counter"):
+        collect_model_stats._counter = 0
+    else:
+        collect_model_stats._counter += 1
+    print(f"[{collect_model_stats._counter}] Collect information for {model_path}")
+
     model_class = load_class_from_file(
         os.path.join(model_path, "model.py"), "GraphModule"
     )
@@ -164,7 +188,7 @@ def collect_model_stats(model_path, device, log_prompt):
     num_inputs = 0
     num_outputs = 0
     dtypes = set()
-    op_stats = collect_op_stats(model, input_dict)
+    is_complete, op_stats = collect_op_stats(model, input_dict)
     if op_stats is not None:
         for op_name, stat in op_stats.items():
             if op_name == "placeholder":
@@ -192,7 +216,7 @@ def collect_model_stats(model_path, device, log_prompt):
         dtypes_str = "[" + ",".join(dtypes) + "]"
         param_dtypes_str = "[" + ",".join(param_dtypes) + "]"
         print(
-            f"{log_prompt}[ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str}",
+            f"{log_prompt}[ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete}",
             file=sys.stderr,
             flush=True,
         )