     is_preset_scheme,
     preset_name_to_scheme,
 )
+from compressed_tensors.utils import match_named_modules
 from pydantic import Field, PrivateAttr, field_validator
 from torch.utils.hooks import RemovableHandle
 
@@ -121,12 +122,15 @@ def initialize_quantization(self, model: torch.nn.Module):
 
         :param model: model to attach schemes and observers to
         """
-        reset_quantization_status(model)  # reset any previously applied qconfigs
-
         # apply scheme and status to model
         config = self.resolve_quantization_config()
+
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            reset_quantization_status(module)  # reset any previously applied qconfigs
+
         apply_quantization_config(model, config)
 
+        # TODO should we disable for entire model or just matching modules?
         # disable quantization until calibration
         model.apply(disable_quantization)
 
@@ -138,8 +142,11 @@ def start_calibration(self, model: torch.nn.Module):
         :param model: model to prepare for calibration
         """
         self._calibration_hooks = self._initialize_hooks(model)
-        model.apply(self._initialize_observers)
-        model.apply(apply_calibration_status)
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            self._initialize_observers(module)
+            apply_calibration_status(module)
+
+        # TODO should we disable for entire model or just matching modules?
         model.apply(enable_quantization)  # quantize at the same time as calibrate
 
     def end_calibration(self, model: torch.nn.Module):
@@ -150,7 +157,9 @@ def end_calibration(self, model: torch.nn.Module):
         :param model: model to end calibration for
         """
         self.remove_hooks(self._calibration_hooks)
-        model.apply(freeze_module_quantization)  # remove observers
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            freeze_module_quantization(module)  # remove observers
+
         model.apply(enable_quantization)  # keep quantization enabled
 
     def has_config(self) -> bool:
@@ -240,7 +249,7 @@ def _initialize_observers(self, module: torch.nn.Module):
 
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
-        for module in model.modules():
+        for _, module in match_named_modules(model, self.targets, self.ignore):
             if not hasattr(module, "quantization_scheme"):
                 continue
 
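
For reference, a minimal sketch of the matching pattern this diff switches to, not the modifier itself: instead of model.apply(fn), which visits every submodule, only modules matched by targets and not listed in ignore are visited. The (model, targets, ignore) call shape and the (name, module) pairs come from the diff above; the "Linear" target and "lm_head" ignore entry are assumed examples in the usual llm-compressor recipe convention.

import torch
from compressed_tensors.utils import match_named_modules

model = torch.nn.Sequential()
model.add_module("layer1", torch.nn.Linear(16, 16))
model.add_module("act", torch.nn.ReLU())
model.add_module("lm_head", torch.nn.Linear(16, 4))

targets = ["Linear"]  # assumed example: match modules by class name
ignore = ["lm_head"]  # assumed example: skip the output head by name

# Unlike model.apply(fn), only the matched (name, module) pairs are visited
for name, module in match_named_modules(model, targets, ignore):
    print(name, type(module).__name__)  # expected: layer1 Linear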