Commit d7d9b6c

scoped quant status/config
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent d15139d commit d7d9b6c

3 files changed: +23 -13 lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 4 additions & 2 deletions
@@ -265,8 +265,10 @@ def on_end(self, state: State, event: Event, **kwargs):
 
         self.ended_ = True
 
-        modules = list(state.model.modules())
-        for module in tqdm(modules, desc="Calibrating weights"):
+        for _, module in tqdm(
+            match_named_modules(state.model, self.targets, self.ignore),
+            desc="Calibrating weights",
+        ):
             update_weight_zp_scale(module)
 
         QuantizationMixin.end_calibration(self, state.model)
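
The rewritten loop consumes (name, module) pairs from match_named_modules and discards the name, so only modules matching the modifier's targets (minus its ignore list) get calibrated, instead of every submodule returned by state.model.modules(). A minimal runnable sketch of the difference, assuming targets accepts module class names such as "Linear" (the diff itself only confirms the (model, targets, ignore) call shape):

import torch
from compressed_tensors.utils import match_named_modules

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)

# Old behavior: every submodule, including the ReLU and the container itself
all_modules = list(model.modules())

# New behavior: only (name, module) pairs matching targets, minus ignore
matched = [m for _, m in match_named_modules(model, ["Linear"], [])]

assert len(matched) < len(all_modules)
assert all(isinstance(m, torch.nn.Linear) for m in matched)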

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 4 additions & 5 deletions
@@ -1,7 +1,6 @@
 import tqdm
-from compressed_tensors.utils import (
-    match_named_modules,
-)
+from compressed_tensors.utils import match_named_modules
+
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.quantization.calibration import (
@@ -77,10 +76,10 @@ def on_start(self, state: State, event: Event, **kwargs):
         # TODO: this step can be combined with update_weight_zp_scale
         # once update_fused_layer_weight_global_scales is removed
         # and not required by vLLM
-        for name, module in tqdm.tqdm(named_modules):
+        for _, module in tqdm.tqdm(named_modules):
             update_weight_global_scale(module)
 
-        for name, module in tqdm.tqdm(named_modules, desc="Calibrating weights"):
+        for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"):
             update_fused_layer_weight_global_scales(module)
             update_weight_zp_scale(module)
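
Only the loop variable changes here (name is unused, so it becomes _), but the hunk also shows the ordering these loops must preserve: global scales are updated for all matched modules before fused-layer scales and per-module zero-points/scales are computed. Since named_modules is iterated twice, it is presumably a list materialized earlier in on_start. A runnable toy of that two-pass ordering, with stand-in helpers since the real ones live in llmcompressor.modifiers.quantization.calibration:

import torch
import tqdm
from compressed_tensors.utils import match_named_modules

def update_weight_global_scale(module):  # stand-in for the real helper
    module.global_scale = module.weight.abs().max()

def update_weight_zp_scale(module):  # stand-in for the real helper
    module.scale = module.global_scale / 127.0

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
named_modules = list(match_named_modules(model, ["Linear"], []))

for _, module in tqdm.tqdm(named_modules):  # pass 1: global scales
    update_weight_global_scale(module)

for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"):
    update_weight_zp_scale(module)  # pass 2: per-module qparams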

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 15 additions & 6 deletions
@@ -14,6 +14,7 @@
     is_preset_scheme,
     preset_name_to_scheme,
 )
+from compressed_tensors.utils import match_named_modules
 from pydantic import Field, PrivateAttr, field_validator
 from torch.utils.hooks import RemovableHandle

@@ -121,12 +122,15 @@ def initialize_quantization(self, model: torch.nn.Module):
 
         :param model: model to attach schemes and observers to
         """
-        reset_quantization_status(model)  # reset any previously applied qconfigs
-
         # apply scheme and status to model
         config = self.resolve_quantization_config()
+
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            reset_quantization_status(module)  # reset any previously applied qconfigs
+
         apply_quantization_config(model, config)
 
+        # TODO should we disable for entire model or just matching modules?
         # disable quantization until calibration
         model.apply(disable_quantization)

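With this hunk, the status reset is scoped: only modules matching the modifier's targets have previously applied qconfigs cleared before the new config is applied, so submodules outside the scope keep their state. The scope itself comes from the modifier's targets/ignore fields; a hedged usage sketch (the field names match the mixin, but the scheme value and module patterns are illustrative):

from llmcompressor.modifiers.quantization import QuantizationModifier

modifier = QuantizationModifier(
    scheme="FP8",        # illustrative preset scheme
    targets=["Linear"],  # only matching modules are reset and calibrated
    ignore=["lm_head"],  # lm_head keeps its previous quantization status
)
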
@@ -138,8 +142,11 @@ def start_calibration(self, model: torch.nn.Module):
         :param model: model to prepare for calibration
         """
         self._calibration_hooks = self._initialize_hooks(model)
-        model.apply(self._initialize_observers)
-        model.apply(apply_calibration_status)
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            self._initialize_observers(module)
+            apply_calibration_status(module)
+
+        # TODO should we disable for entire model or just matching modules?
         model.apply(enable_quantization)  # quantize at the same time as calibrate
 
     def end_calibration(self, model: torch.nn.Module):
@@ -150,7 +157,9 @@ def end_calibration(self, model: torch.nn.Module):
         :param model: model to end calibration for
         """
         self.remove_hooks(self._calibration_hooks)
-        model.apply(freeze_module_quantization)  # remove observers
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            freeze_module_quantization(module)  # remove observers
+
         model.apply(enable_quantization)  # keep quantization enabled
 
     def has_config(self) -> bool:
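
Both TODOs flag the same asymmetry: observer initialization, calibration status, and freezing are now scoped to matched modules, while enabling and disabling quantization still sweep the whole model via model.apply. A small runnable contrast of the two traversal styles (the touched/scoped attributes are illustrative only):

import torch
from compressed_tensors.utils import match_named_modules

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

def mark(module):
    module.touched = True

model.apply(mark)  # model-wide: visits every submodule, container included
assert hasattr(model[1], "touched")  # even the ReLU is visited

for _, module in match_named_modules(model, ["Linear"], []):
    module.scoped = True  # scoped: matched modules only

assert not hasattr(model[1], "scoped")  # the ReLU is skipped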
@@ -240,7 +249,7 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
 
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
-        for module in model.modules():
+        for _, module in match_named_modules(model, self.targets, self.ignore):
             if not hasattr(module, "quantization_scheme"):
                 continue

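For intuition, match_named_modules behaves roughly like the filter below. This is an assumption-laden sketch: the real utility in compressed-tensors also handles "re:"-prefixed regex targets, warnings on failed matches, and other refinements, so treat it as illustrative only:

import re
import torch

def match_named_modules_sketch(model, targets, ignore):
    # Yield (name, module) pairs whose class name equals a target (or whose
    # dotted name matches a "re:" pattern) and that match no ignore entry.
    def matches(name, module, patterns):
        for pattern in patterns:
            if pattern.startswith("re:") and re.match(pattern[3:], name):
                return True
            if module.__class__.__name__ == pattern:
                return True
        return False

    for name, module in model.named_modules():
        if matches(name, module, targets) and not matches(name, module, ignore):
            yield name, module

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
print([name for name, _ in match_named_modules_sketch(model, ["Linear"], [])])
# -> ['0']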