6 changes: 4 additions & 2 deletions src/llmcompressor/modifiers/awq/base.py
@@ -265,8 +265,10 @@ def on_end(self, state: State, event: Event, **kwargs):
 
         self.ended_ = True
 
-        modules = list(state.model.modules())
-        for module in tqdm(modules, desc="Calibrating weights"):
+        for _, module in tqdm(
+            match_named_modules(state.model, self.targets, self.ignore),
+            desc="Calibrating weights",
+        ):
             update_weight_zp_scale(module)
 
         QuantizationMixin.end_calibration(self, state.model)
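
Note: a minimal sketch (not part of the diff) of the iteration change above, assuming match_named_modules(model, targets, ignore) yields (name, module) pairs as it is used here; the toy model and the "Linear" target are illustrative only.

import torch
from compressed_tensors.utils import match_named_modules

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 4),
)

# old behavior: every submodule, regardless of the modifier's targets/ignore
all_modules = list(model.modules())

# new behavior: only modules matched by `targets` and not excluded by `ignore`
for name, module in match_named_modules(model, ["Linear"], []):
    print(name, type(module).__name__)  # e.g. "0 Linear", "2 Linear"
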
1 change: 1 addition & 0 deletions src/llmcompressor/modifiers/awq/mappings.py
@@ -157,6 +157,7 @@ class AWQMapping:
     "Phi3ForCausalLM": _phi_mappings,
     "Phi3VForCausalLM": _phi_mappings,
     "Qwen2ForCausalLM": _default_mappings,
+    "Qwen2_5OmniThinkerForConditionalGeneration": _default_mappings,
     "Qwen2MoeForCausalLM": _moe_default_mappings,
     "Qwen3ForCausalLM": _default_mappings,
     "Qwen3MoeForCausalLM": _moe_default_mappings,
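
Note: a hedged sketch of how the new registry entry would be looked up by architecture name; the registry name AWQ_MAPPING_REGISTRY is an assumption based on this file, and the PR itself only adds the dictionary entry shown above.

from llmcompressor.modifiers.awq.mappings import AWQ_MAPPING_REGISTRY

# the Qwen2.5-Omni "thinker" text model reuses the default Qwen2-style mappings
mappings = AWQ_MAPPING_REGISTRY["Qwen2_5OmniThinkerForConditionalGeneration"]
print(mappings is AWQ_MAPPING_REGISTRY["Qwen2ForCausalLM"])  # True, both reference _default_mappings
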
8 changes: 6 additions & 2 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -9,6 +9,7 @@
     align_module_device,
     get_execution_device,
     getattr_chain,
+    match_named_modules,
     update_offload_parameter,
 )
 from loguru import logger
@@ -165,7 +166,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         QuantizationMixin.initialize_quantization(self, state.model)
 
         # prepare module names
-        self._module_names = {m: name for name, m in state.model.named_modules()}
+        self._module_names = {
+            m: name
+            for name, m in match_named_modules(state.model, self.targets, self.ignore)
+        }
 
         return True
 
@@ -178,7 +182,7 @@ def on_start(self, state: State, event: Event, **kwargs):
 
         # register gptq hooks
         added_hook = False
-        for module in state.model.modules():
+        for _, module in match_named_modules(state.model, self.targets, self.ignore):
             if getattr_chain(module, "quantization_scheme.weights", None) is not None:
                 # HACK: previously, embeddings were not quantized because they were not
                 # accessible by the layer compressor. For now, we manually ignore it,
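
Note: a minimal sketch (toy model and target list are assumptions) of the inverted name lookup built in on_initialize above: module objects map to their dotted names so the GPTQ hooks can report which module they are compressing, and the dict now only contains modules matched by targets/ignore.

import torch
from compressed_tensors.utils import match_named_modules

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 8))

module_names = {
    m: name
    for name, m in match_named_modules(model, ["Linear"], [])
}

print(module_names[model[0]])  # -> "0"
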
9 changes: 6 additions & 3 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -1,4 +1,5 @@
 import tqdm
+from compressed_tensors.utils import match_named_modules
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
@@ -69,14 +70,16 @@ def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
         QuantizationMixin.start_calibration(self, state.model)
 
-        modules = list(state.model.modules())
+        named_modules = list(
+            match_named_modules(state.model, self.targets, self.ignore)
+        )
         # TODO: this step can be combined with update_weight_zp_scale
         # once update_fused_layer_weight_global_scales is removed
         # and not required by vLLM
-        for module in tqdm.tqdm(modules):
+        for _, module in tqdm.tqdm(named_modules):
             update_weight_global_scale(module)
 
-        for module in tqdm.tqdm(modules, desc="Calibrating weights"):
+        for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"):
             update_fused_layer_weight_global_scales(module)
             update_weight_zp_scale(module)
 
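
Note: a sketch of the pattern above (toy model; the target list is an assumption): the generator returned by match_named_modules is materialized into a list once so it can be iterated twice, first for weight global scales and then for zero-points/scales.

import torch
import tqdm
from compressed_tensors.utils import match_named_modules

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
named_modules = list(match_named_modules(model, ["Linear"], []))

for _, module in tqdm.tqdm(named_modules):
    pass  # update_weight_global_scale(module) in the real modifier

for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"):
    pass  # update_fused_layer_weight_global_scales(module); update_weight_zp_scale(module)
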
33 changes: 21 additions & 12 deletions src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -14,6 +14,7 @@
     is_preset_scheme,
     preset_name_to_scheme,
 )
+from compressed_tensors.utils import match_named_modules
 from pydantic import Field, PrivateAttr, field_validator
 from torch.utils.hooks import RemovableHandle
 
@@ -116,41 +117,49 @@ def validate_scheme(
 
     def initialize_quantization(self, model: torch.nn.Module):
         """
-        Attach quantization schemes and observers to modules in the model according to
+        Attach quantization schemes to modules in the model according to
         the quantization config specified on this modifier
 
         :param model: model to attach schemes and observers to
         """
-        reset_quantization_status(model)  # reset any previously applied qconfigs
-
         # apply scheme and status to model
         config = self.resolve_quantization_config()
+
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            reset_quantization_status(module)  # reset any previously applied qconfigs
+
         apply_quantization_config(model, config)
 
-        # apply observers, disable quantization until calibration
-        model.apply(self._initialize_observers)
+        # TODO should we disable for entire model or just matching modules?
+        # disable quantization until calibration
         model.apply(disable_quantization)
 
     def start_calibration(self, model: torch.nn.Module):
         """
-        Register activation calibration hooks (including kv_cache quantization) and
-        enable quantization as we calibrate
+        Attach observers, register activation calibration hooks (including
+        kv_cache quantization) and enable quantization as we calibrate
 
         :param model: model to prepare for calibration
        """
         self._calibration_hooks = self._initialize_hooks(model)
-        model.apply(apply_calibration_status)
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            self._initialize_observers(module)
+            apply_calibration_status(module)
+
+        # TODO should we disable for entire model or just matching modules?
         model.apply(enable_quantization)  # quantize at the same time as calibrate
 
     def end_calibration(self, model: torch.nn.Module):
         """
-        Remove calibration hooks and set the model status to frozen. Keep quantization
-        enabled for future operations
+        Remove calibration hooks and observers, and set the model status to frozen.
+        Keep quantization enabled for future operations
 
         :param model: model to end calibration for
         """
         self.remove_hooks(self._calibration_hooks)
-        model.apply(freeze_module_quantization)  # remove observers
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            freeze_module_quantization(module)  # remove observers
+
         model.apply(enable_quantization)  # keep quantization enabled
 
     def has_config(self) -> bool:
@@ -240,7 +249,7 @@ def _initialize_observers(self, module: torch.nn.Module):
 
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
-        for module in model.modules():
+        for _, module in match_named_modules(model, self.targets, self.ignore):
             if not hasattr(module, "quantization_scheme"):
                 continue
 
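
Note: a hedged usage sketch, not taken from the PR. With these changes, observers, calibration status, and freezing are applied only to modules matched by the modifier's targets minus ignore, rather than to every module via model.apply(...). The recipe values below are illustrative examples.

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

recipe = QuantizationModifier(
    targets="Linear",
    scheme="W8A8",
    ignore=["lm_head"],  # excluded by match_named_modules in the mixin's calibration steps
)

# oneshot(model=..., dataset=..., recipe=recipe)  # standard llmcompressor flow, arguments elided
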