From 49e1d906cb6d4d2d6fa6326868efccdfea8178e2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 19 Aug 2025 21:42:02 -0400 Subject: [PATCH 01/10] quantization done, need calibration Signed-off-by: Kyle Sayers --- examples/transform/spinquant_example.py | 4 +- .../modifiers/transform/spinquant/base.py | 49 ++++++++++++++++--- .../modifiers/transform/spinquant/mappings.py | 2 + 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py index 547d06041..fd8f81486 100644 --- a/examples/transform/spinquant_example.py +++ b/examples/transform/spinquant_example.py @@ -18,7 +18,7 @@ # * apply spinquant transforms to model to reduce quantization loss # * quantize the weights to 4 bit with group size 128 recipe = [ - SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"), + SpinQuantModifier(rotations=["R3"], transform_type="hadamard"), QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] @@ -35,6 +35,6 @@ print("==========================================\n\n") # Save to disk compressed. -SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR1R2-w4a16" +SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR3-w4a16" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 68095ab1b..95eeb0550 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -128,7 +128,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: config_groups["R2"] = self._create_r2_scheme(state.model) if SpinquantRotation.R3 in self.rotations: - config_groups["R3"] = self._create_r3_scheme() + config_groups["R3"] = self._create_r3_scheme(state.model) if SpinquantRotation.R4 in self.rotations: config_groups["R4"] = self._create_r4_scheme() @@ -235,12 +235,49 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: ], ) - def _create_r3_scheme(self) -> TransformScheme: - raise NotImplementedError( - "SpinQuant R3 and R4 rotations will be added in a future release" + def _create_r3_scheme(self, model: PreTrainedModel) -> TransformScheme: + config = model.config + + if hasattr(config, "head_dim"): + head_dim = config.head_dim + elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"): + head_dim = config.hidden_size // config.num_attention_heads + else: + raise NotImplementedError() + + return TransformScheme( + type=self.transform_type, + randomize=self.randomize, + requires_grad=self.learnable, + precision=self.precision, + head_dim=head_dim, + apply=[ + TransformArgs( + targets=[self.mappings.attn], + location="q_attn", + ), + TransformArgs( + targets=[self.mappings.attn], + location="k_cache", + ), + ], ) def _create_r4_scheme(self) -> TransformScheme: - raise NotImplementedError( - "SpinQuant R3 and R4 rotations will be added in a future release" + return TransformScheme( + type=self.transform_type, + randomize=self.randomize, + requires_grad=self.learnable, + precision=self.precision, + apply=[ + TransformArgs( + targets=[*self.mappings.mlp_out], + location="input", + ), + TransformArgs( + targets=[*self.mappings.mlp_out], + location="weight_input", + inverse=True, + ), + ], ) diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py index 
514d1f109..2d2bb3cba 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/mappings.py +++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py @@ -29,6 +29,7 @@ class SpinQuantMapping(BaseModel): embedding: str + attn: str attn_q: str attn_k: str attn_v: str @@ -50,6 +51,7 @@ def cast_to_list(cls, value): _default_mappings = SpinQuantMapping( embedding="re:.*embed_tokens$", + attn="re:.*self_attn$", attn_q="re:.*q_proj$", attn_k="re:.*k_proj$", attn_v="re:.*v_proj$", From 8837af27d21fa5646981e0d34984d37e83fcc0b6 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Aug 2025 23:02:11 -0400 Subject: [PATCH 02/10] attention and kv quantization Signed-off-by: Kyle Sayers --- examples/transform/spinquant_example.py | 44 ++++++++- .../modifiers/quantization/calibration.py | 93 +++++++++++++------ .../quantization/quantization/mixin.py | 48 ++++------ 3 files changed, 125 insertions(+), 60 deletions(-) diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py index fd8f81486..569e0544e 100644 --- a/examples/transform/spinquant_example.py +++ b/examples/transform/spinquant_example.py @@ -1,3 +1,4 @@ +from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot @@ -11,6 +12,43 @@ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) +# Select calibration dataset. +DATASET_ID = "mit-han-lab/pile-val-backup" +DATASET_SPLIT = "validation" + +# Select number of samples. 256 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 256 +MAX_SEQUENCE_LENGTH = 512 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + [{"role": "user", "content": example["text"]}], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + # NOTE: currently only fused rotations (R1 & R2) are available # Learned rotations and online rotations (R3 & R4) will be added # in a future release. @@ -19,11 +57,11 @@ # * quantize the weights to 4 bit with group size 128 recipe = [ SpinQuantModifier(rotations=["R3"], transform_type="hadamard"), - QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), + QuantizationModifier(targets=["LlamaAttention"], scheme="FP8", ignore=["lm_head"]), ] # Apply algorithms. -oneshot(model=model, recipe=recipe, pipeline="datafree") +oneshot(model=model, dataset=ds, recipe=recipe, pipeline="sequential") # Confirm generations of the quantized model look sane. print("\n\n") @@ -35,6 +73,6 @@ print("==========================================\n\n") # Save to disk compressed. 
-SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR3-w4a16" +SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR3" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 979a227d2..417808a77 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -4,6 +4,7 @@ from compressed_tensors.quantization import ( DynamicType, KVCacheScaleType, + QuantizationArgs, QuantizationScheme, QuantizationStatus, QuantizationStrategy, @@ -39,10 +40,7 @@ ] -def initialize_observer( - module: Module, - base_name: str, -): +def initialize_observer(module: Module, base_name: str): """ Initialize observer module and attach as submodule. The name of the observer is fetched from the quantization_args. @@ -53,33 +51,34 @@ def initialize_observer( :param base_name: str used to name the observer attribute """ - - arg_name = "weights" if base_name == "weight" else f"{base_name}_activations" - quantization_scheme = getattr(module, "quantization_scheme", None) - if not quantization_scheme: - # no quantization scheme nothing to do + if base_name == "weights": + arg_name = "weight" + elif base_name == "output": + arg_name = "output_activations" + else: + # (input, q, k, v) + arg_name = "input_activations" + + quantization_args: Optional[QuantizationArgs] = getattr_chain( + module, f"quantization_scheme.{arg_name}", None + ) + if quantization_args is None or quantization_args.is_online(): return - quantization_args = getattr(quantization_scheme, arg_name, None) - # dont need observers for dynamic - if quantization_args is not None and quantization_args.dynamic in ( - False, - DynamicType.LOCAL, - ): - observer_kwargs = quantization_args.observer_kwargs or {} - observer = Observer.load_from_registry( - quantization_args.observer, - quantization_args=quantization_args, - averaging_constant=observer_kwargs.get( - "averaging_constant", DEFAULT_AVERAGING_CONSTANT - ), - # used by mse observer only, will be ignored by minmax observer - maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK), - patience=observer_kwargs.get("patience", DEFAULT_PATIENCE), - grid=observer_kwargs.get("grid", DEFAULT_GRID), - norm=observer_kwargs.get("norm", DEFAULT_NORM), - ) - module.register_module(f"{base_name}_observer", observer) + observer_kwargs = quantization_args.observer_kwargs or {} + observer = Observer.load_from_registry( + quantization_args.observer, + quantization_args=quantization_args, + averaging_constant=observer_kwargs.get( + "averaging_constant", DEFAULT_AVERAGING_CONSTANT + ), + # used by mse observer only, will be ignored by minmax observer + maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK), + patience=observer_kwargs.get("patience", DEFAULT_PATIENCE), + grid=observer_kwargs.get("grid", DEFAULT_GRID), + norm=observer_kwargs.get("norm", DEFAULT_NORM), + ) + module.register_module(f"{base_name}_observer", observer) def call_observer( @@ -218,6 +217,18 @@ def calibrate_input_hook(module: Module, args: Any): calibrate_activations(module, value=args, base_name="input") +def calibrate_query_hook(module: Module, query_states: torch.Tensor): + calibrate_activations(module, query_states, base_name="q") + + +def calibrate_key_hook(module: Module, key_states: torch.Tensor): + calibrate_activations(module, key_states, base_name="k") + + +def calibrate_value_hook(module: Module, 
value_states: torch.Tensor): + calibrate_activations(module, value_states, base_name="v") + + def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): """ Hook to calibrate output activations. @@ -238,6 +249,30 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): return output +# def register_calibrate_attn_hooks( +# modifier: "HooksMixin", attention_impl: "CompressedAttentionImpl" +# ) -> Set[RemovableHandle]: +# return { +# modifier.register_hook( +# attention_impl, partial(calibrate_input_hook, basename="q"), "query" +# ), +# modifier.register_hook( +# attention_impl, partial(calibrate_input_hook, basename="k"), "key" +# ), +# modifier.register_hook( +# attention_impl, partial(calibrate_input_hook, basename="v"), "value" +# ), +# } + + +# def initialize_attention_observers(module: Module): +# input_args = getattr_chain(module, "quantization_scheme.input_activations", None) +# if input_args is not None: +# initialize_observer(module, "q", input_args) +# initialize_observer(module, "k", input_args) +# initialize_observer(module, "v", input_args) + + def calibrate_kv_cache_input_hook( module: Module, args: Any, kwargs: Dict[str, Any] ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index d193d85a1..8f335d167 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -1,6 +1,8 @@ from typing import Any, Dict, List, Optional, Set, Union import torch +from compressed_tensors.modeling.attention import register_query_hook +from compressed_tensors.modeling.kvcache import register_key_hook, register_value_hook from compressed_tensors.quantization import ( DynamicType, QuantizationArgs, @@ -20,12 +22,12 @@ from llmcompressor.modifiers.quantization.calibration import ( apply_calibration_status, calibrate_input_hook, - calibrate_kv_cache_input_hook, - calibrate_kv_cache_output_hook, + calibrate_key_hook, calibrate_output_hook, + calibrate_query_hook, + calibrate_value_hook, freeze_module_quantization, initialize_observer, - initialize_quantized_kv_cache, reset_quantization_status, ) from llmcompressor.modifiers.utils.hooks import HooksMixin @@ -223,17 +225,18 @@ def _initialize_observers(self, module: torch.nn.Module): # input activations if input: - initialize_observer(module, base_name="input") + if not is_attention: + initialize_observer(module, base_name="input") + else: + if not scheme.kv_cache_only: + initialize_observer(module, base_name="q") + initialize_observer(module, base_name="k") + initialize_observer(module, base_name="v") # weight observers (used by `update_weight_zp_scale` or child modifier) if weight: initialize_observer(module, base_name="weight") - # kv_cache activations. Within `apply_quantization_config`, the config is - # modified to use attention output quantization if a kv_cache_scheme exists - if is_attention and output: - initialize_quantized_kv_cache(module) - # output activations elif output: initialize_observer(module, base_name="output") @@ -254,26 +257,15 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]: # input activations if input: - hooks.add( - self.register_hook(module, calibrate_input_hook, "forward_pre") - ) - - # kv_cache activations. 
Within `apply_quantization_config`, the config is - # modified to use attention output quantization if a kv_cache_scheme exists - if is_attention and output: - hooks.add( - self.register_hook( - module, - calibrate_kv_cache_input_hook, - "forward_pre", - with_kwargs=True, - ) - ) - hooks.add( - self.register_hook( - module, calibrate_kv_cache_output_hook, "forward" + if not is_attention: + hooks.add( + self.register_hook(module, calibrate_input_hook, "forward_pre") ) - ) + else: + if not scheme.kv_cache_only: + hooks.add(register_query_hook(module, calibrate_query_hook)) + hooks.add(register_key_hook(module, calibrate_key_hook)) + hooks.add(register_value_hook(module, calibrate_value_hook)) # output activations elif output: From c965bdd7e03e538c9bf243620a28ad55351aadc7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Aug 2025 23:02:57 -0400 Subject: [PATCH 03/10] style Signed-off-by: Kyle Sayers --- .../modifiers/quantization/calibration.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 417808a77..13504e797 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -249,30 +249,6 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): return output -# def register_calibrate_attn_hooks( -# modifier: "HooksMixin", attention_impl: "CompressedAttentionImpl" -# ) -> Set[RemovableHandle]: -# return { -# modifier.register_hook( -# attention_impl, partial(calibrate_input_hook, basename="q"), "query" -# ), -# modifier.register_hook( -# attention_impl, partial(calibrate_input_hook, basename="k"), "key" -# ), -# modifier.register_hook( -# attention_impl, partial(calibrate_input_hook, basename="v"), "value" -# ), -# } - - -# def initialize_attention_observers(module: Module): -# input_args = getattr_chain(module, "quantization_scheme.input_activations", None) -# if input_args is not None: -# initialize_observer(module, "q", input_args) -# initialize_observer(module, "k", input_args) -# initialize_observer(module, "v", input_args) - - def calibrate_kv_cache_input_hook( module: Module, args: Any, kwargs: Dict[str, Any] ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: From 9415c62de2f41e792ca2806a6e60f2293a89d380 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Aug 2025 23:37:31 -0400 Subject: [PATCH 04/10] flatten activations before passing Signed-off-by: Kyle Sayers --- examples/transform/spinquant_example.py | 23 ++++++++++++++++--- .../modifiers/quantization/calibration.py | 3 +++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py index 569e0544e..b1cd02fe1 100644 --- a/examples/transform/spinquant_example.py +++ b/examples/transform/spinquant_example.py @@ -55,13 +55,30 @@ def tokenize(sample): # Configure the quantization algorithm to run. 
# * apply spinquant transforms to model to reduce quantization loss # * quantize the weights to 4 bit with group size 128 +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, + QuantizationType, +) + +scheme = QuantizationScheme( + targets=["LlamaAttention"], + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=False, + ), +) + recipe = [ SpinQuantModifier(rotations=["R3"], transform_type="hadamard"), - QuantizationModifier(targets=["LlamaAttention"], scheme="FP8", ignore=["lm_head"]), + QuantizationModifier(config_groups={"attention": scheme}), ] # Apply algorithms. -oneshot(model=model, dataset=ds, recipe=recipe, pipeline="sequential") +oneshot(model=model, dataset=ds, recipe=recipe, pipeline="basic") # Confirm generations of the quantized model look sane. print("\n\n") @@ -73,6 +90,6 @@ def tokenize(sample): print("==========================================\n\n") # Save to disk compressed. -SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR3" +SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR3-FP8_asym-attn" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 13504e797..e5d1f8f81 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -218,14 +218,17 @@ def calibrate_input_hook(module: Module, args: Any): def calibrate_query_hook(module: Module, query_states: torch.Tensor): + query_states = query_states.flatten(0, -2) calibrate_activations(module, query_states, base_name="q") def calibrate_key_hook(module: Module, key_states: torch.Tensor): + key_states = key_states.flatten(0, -2) calibrate_activations(module, key_states, base_name="k") def calibrate_value_hook(module: Module, value_states: torch.Tensor): + value_states = value_states.flatten(0, -2) calibrate_activations(module, value_states, base_name="v") From cec5fee417c30f094b45587dfa2deeb7fa7e90b8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 21 Aug 2025 10:52:45 -0400 Subject: [PATCH 05/10] full spinquant Signed-off-by: Kyle Sayers --- examples/transform/spinquant_example.py | 28 ++++++++++++++----------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py index b1cd02fe1..643b47dd0 100644 --- a/examples/transform/spinquant_example.py +++ b/examples/transform/spinquant_example.py @@ -61,20 +61,24 @@ def tokenize(sample): QuantizationStrategy, QuantizationType, ) - -scheme = QuantizationScheme( - targets=["LlamaAttention"], - input_activations=QuantizationArgs( - num_bits=8, - type=QuantizationType.FLOAT, - strategy=QuantizationStrategy.TENSOR, - symmetric=False, +from compressed_tensors.quantization.quant_scheme import FP8 + +config_groups = { + "attention": QuantizationScheme( + targets=["LlamaAttention"], + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=False, + ), ), -) + "linear": QuantizationScheme(targets=["Linear"], **FP8), +} recipe = [ - SpinQuantModifier(rotations=["R3"], transform_type="hadamard"), - QuantizationModifier(config_groups={"attention": scheme}), + SpinQuantModifier(rotations=["R1", "R2", "R3", "R4"], transform_type="random-hadamard"), + 
QuantizationModifier(config_groups=config_groups), ] # Apply algorithms. @@ -90,6 +94,6 @@ def tokenize(sample): print("==========================================\n\n") # Save to disk compressed. -SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR3-FP8_asym-attn" +SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquant" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) From e6e66b073e0cad4749ad485eab56ed986895881e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 21 Aug 2025 19:47:13 -0400 Subject: [PATCH 06/10] fix typo Signed-off-by: Kyle Sayers --- examples/transform/quip_example.py | 2 +- examples/transform/spinquant_example.py | 31 ++++++++++--------- .../modifiers/quantization/calibration.py | 4 +-- .../quantization/quantization/mixin.py | 4 +-- 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py index 2c989b2d7..933754ea1 100644 --- a/examples/transform/quip_example.py +++ b/examples/transform/quip_example.py @@ -13,7 +13,7 @@ # Select model and load it. # NOTE: because the datafree pipeline is being used in this # example, you can use additional GPUs to support larger models -MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" +MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py index 643b47dd0..db394cf0b 100644 --- a/examples/transform/spinquant_example.py +++ b/examples/transform/spinquant_example.py @@ -7,7 +7,7 @@ from llmcompressor.utils import dispatch_for_generation # Select model and load it. -MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) @@ -64,21 +64,22 @@ def tokenize(sample): from compressed_tensors.quantization.quant_scheme import FP8 config_groups = { - "attention": QuantizationScheme( - targets=["LlamaAttention"], - input_activations=QuantizationArgs( - num_bits=8, - type=QuantizationType.FLOAT, - strategy=QuantizationStrategy.TENSOR, - symmetric=False, - ), - ), + # "attention": QuantizationScheme( + # targets=["LlamaAttention"], + # input_activations=QuantizationArgs( + # num_bits=8, + # type=QuantizationType.FLOAT, + # strategy=QuantizationStrategy.TENSOR, + # symmetric=False, + # ), + # ), "linear": QuantizationScheme(targets=["Linear"], **FP8), } recipe = [ - SpinQuantModifier(rotations=["R1", "R2", "R3", "R4"], transform_type="random-hadamard"), - QuantizationModifier(config_groups=config_groups), + SpinQuantModifier(rotations=["R1"], transform_type="random-hadamard"), + #QuantizationModifier(config_groups=config_groups), + #QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] # Apply algorithms. @@ -94,6 +95,6 @@ def tokenize(sample): print("==========================================\n\n") # Save to disk compressed. 
-SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquant" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) +# SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquant-R1R2R4-W4A16" +# model.save_pretrained(SAVE_DIR, save_compressed=True) +# tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index e5d1f8f81..3aa57eeb2 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -51,8 +51,8 @@ def initialize_observer(module: Module, base_name: str): :param base_name: str used to name the observer attribute """ - if base_name == "weights": - arg_name = "weight" + if base_name == "weight": + arg_name = "weights" elif base_name == "output": arg_name = "output_activations" else: diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index 8f335d167..ff6ba0d9b 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -238,7 +238,7 @@ def _initialize_observers(self, module: torch.nn.Module): initialize_observer(module, base_name="weight") # output activations - elif output: + if output: initialize_observer(module, base_name="output") def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]: @@ -268,7 +268,7 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]: hooks.add(register_value_hook(module, calibrate_value_hook)) # output activations - elif output: + if output: hooks.add(self.register_hook(module, calibrate_output_hook, "forward")) return hooks From ba799886fa3b68fd4454a0a47ac2d0c96fd0aa52 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 25 Aug 2025 18:32:03 -0400 Subject: [PATCH 07/10] remove kv cache logic, make kv cache tests faster Signed-off-by: Kyle Sayers --- examples/transform/quip_example.py | 2 +- examples/transform/spinquant_example.py | 44 +--- .../modifiers/quantization/__init__.py | 1 - .../modifiers/quantization/cache.py | 208 ------------------ .../modifiers/quantization/calibration.py | 61 +---- .../transformers/kv_cache/test_kv_cache.py | 22 +- 6 files changed, 23 insertions(+), 315 deletions(-) delete mode 100644 src/llmcompressor/modifiers/quantization/cache.py diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py index 933754ea1..2c989b2d7 100644 --- a/examples/transform/quip_example.py +++ b/examples/transform/quip_example.py @@ -13,7 +13,7 @@ # Select model and load it. 
# NOTE: because the datafree pipeline is being used in this # example, you can use additional GPUs to support larger models -MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" +MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py index db394cf0b..3bb11f1a8 100644 --- a/examples/transform/spinquant_example.py +++ b/examples/transform/spinquant_example.py @@ -2,11 +2,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot -from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.modifiers.transform import SpinQuantModifier from llmcompressor.utils import dispatch_for_generation # Select model and load it. +# TODO: change back MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") @@ -49,41 +50,14 @@ def tokenize(sample): ) -# NOTE: currently only fused rotations (R1 & R2) are available -# Learned rotations and online rotations (R3 & R4) will be added -# in a future release. -# Configure the quantization algorithm to run. -# * apply spinquant transforms to model to reduce quantization loss -# * quantize the weights to 4 bit with group size 128 -from compressed_tensors.quantization import ( - QuantizationArgs, - QuantizationScheme, - QuantizationStrategy, - QuantizationType, -) -from compressed_tensors.quantization.quant_scheme import FP8 - -config_groups = { - # "attention": QuantizationScheme( - # targets=["LlamaAttention"], - # input_activations=QuantizationArgs( - # num_bits=8, - # type=QuantizationType.FLOAT, - # strategy=QuantizationStrategy.TENSOR, - # symmetric=False, - # ), - # ), - "linear": QuantizationScheme(targets=["Linear"], **FP8), -} - +# TODO recipe = [ - SpinQuantModifier(rotations=["R1"], transform_type="random-hadamard"), - #QuantizationModifier(config_groups=config_groups), - #QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), + SpinQuantModifier(rotations=["R3"], transform_type="random-hadamard"), + GPTQModifier(targets=["Linear"], scheme="W4A16", ignore=["lm_head"]), ] # Apply algorithms. -oneshot(model=model, dataset=ds, recipe=recipe, pipeline="basic") +oneshot(model=model, dataset=ds, recipe=recipe) # Confirm generations of the quantized model look sane. print("\n\n") @@ -95,6 +69,6 @@ def tokenize(sample): print("==========================================\n\n") # Save to disk compressed. 
-# SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquant-R1R2R4-W4A16" -# model.save_pretrained(SAVE_DIR, save_compressed=True) -# tokenizer.save_pretrained(SAVE_DIR) +SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquant-W4A16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/quantization/__init__.py b/src/llmcompressor/modifiers/quantization/__init__.py index f1cdf596c..226869f39 100644 --- a/src/llmcompressor/modifiers/quantization/__init__.py +++ b/src/llmcompressor/modifiers/quantization/__init__.py @@ -1,5 +1,4 @@ # flake8: noqa -from .cache import * from .gptq import * from .quantization import * diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py deleted file mode 100644 index dd3640dda..000000000 --- a/src/llmcompressor/modifiers/quantization/cache.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -from compressed_tensors.quantization.lifecycle import KVCacheScaleType -from compressed_tensors.quantization.quant_args import QuantizationArgs -from torch import Tensor -from transformers import DynamicCache - -from llmcompressor.observers import Observer - - -class QuantizedKVParameterCache(DynamicCache): - """ - Quantized KV cache used in the forward call based on HF's dynamic cache. - Quantization strategy (tensor, group, channel) set from Quantization arg's strategy - Singleton, so that the same cache gets reused in all forward call of self_attn. - Each time forward is called, .update() is called, and ._quantize(), ._dequantize() - gets called appropriately. - The size of tensor is - `[batch_size, num_heads, seq_len - residual_length, head_dim]`. - - - Triggered by adding kv_cache_scheme in the recipe. 
- - Example: - - ```python3 - recipe = ''' - quant_stage: - quant_modifiers: - QuantizationModifier: - kv_cache_scheme: - num_bits: 8 - type: float - strategy: tensor - dynamic: false - symmetric: true - ''' - - """ - - _instance = None - _initialized = False - - def __new__(cls, *args, **kwargs): - """Singleton""" - if cls._instance is None: - cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls) - return cls._instance - - def __init__(self, quantization_args: QuantizationArgs): - if not self._initialized: - super().__init__() - - self.quantization_args = quantization_args - - self.k_observers: List[Observer] = [] - self.v_observers: List[Observer] = [] - - # each index corresponds to layer_idx of the attention layer - self.k_scales: List[Tensor] = [] - self.v_scales: List[Tensor] = [] - - self.k_zps: List[Tensor] = [] - self.v_zps: List[Tensor] = [] - - self._initialized = True - - def update( - self, - key_states: Tensor, - value_states: Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Get the k_scale and v_scale and output the - fakequant-ed key_states and value_states - """ - - if len(self.k_observers) <= layer_idx: - k_observer_name = self.quantization_args.observer - k_observer = Observer.load_from_registry( - k_observer_name, quantization_args=self.quantization_args - ) - v_observer_name = self.quantization_args.observer - v_observer = Observer.load_from_registry( - v_observer_name, quantization_args=self.quantization_args - ) - - # NOTE: User may ignore some layers in configuration, - # meaning len(self.k_observers) <= layer_idx-1 - # Must account for that case by padding list so that - # index of lists corresponds to layer_idx - _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer) - _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer) - - q_key_states = self._quantize( - key_states.contiguous(), KVCacheScaleType.KEY, layer_idx - ) - q_value_states = self._quantize( - value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx - ) - - qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx) - qdq_value_states = self._dequantize( - q_value_states, KVCacheScaleType.VALUE, layer_idx - ) - - keys_to_return, values_to_return = qdq_key_states, qdq_value_states - - return keys_to_return, values_to_return - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """ - Returns the sequence length of the cached states. - A layer index can be optionally passed. 
- """ - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and - # rely on `_seen_tokens` which is updated every "layer_idx" == 0, - # this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to - # verify attn_weight shape in some models - return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 - - def reset_states(self): - """reset the kv states (used in calibration)""" - self.key_cache: List[Tensor] = [] - self.value_cache: List[Tensor] = [] - # Used in `generate` to keep tally of how many tokens the cache has seen - self._seen_tokens = 0 - self._quantized_key_cache: List[Tensor] = [] - self._quantized_value_cache: List[Tensor] = [] - - def reset(self): - """ - Reset the instantiation, create new instance on init - """ - QuantizedKVParameterCache._instance = None - QuantizedKVParameterCache._initialized = False - - def _quantize(self, tensor, kv_type, layer_idx): - """Quantizes a key/value using a defined quantization method.""" - from compressed_tensors.quantization.lifecycle.forward import quantize - - if kv_type == KVCacheScaleType.KEY: # key type - observer = self.k_observers[layer_idx] - scales = self.k_scales - zps = self.k_zps - else: - assert kv_type == KVCacheScaleType.VALUE - observer = self.v_observers[layer_idx] - scales = self.v_scales - zps = self.v_zps - - scale, zp = observer(tensor) - _pad_and_append_at_idx_(scales, layer_idx, scale) - _pad_and_append_at_idx_(zps, layer_idx, zp) - - q_tensor = quantize( - x=tensor, - scale=scale, - zero_point=zp, - args=self.quantization_args, - ) - return q_tensor - - def _dequantize(self, qtensor, kv_type, layer_idx): - """Dequantizes back the tensor that was quantized by `self._quantize()`""" - from compressed_tensors.quantization.lifecycle.forward import dequantize - - if kv_type == KVCacheScaleType.KEY: - scale = self.k_scales[layer_idx] - zp = self.k_zps[layer_idx] - else: - assert kv_type == KVCacheScaleType.VALUE - scale = self.v_scales[layer_idx] - zp = self.v_zps[layer_idx] - - qdq_tensor = dequantize( - x_q=qtensor, - scale=scale, - zero_point=zp, - args=self.quantization_args, - ) - return qdq_tensor - - -# NOTE: Using _ suffix to denote l is modified in place -def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list: - """ - Append value val to list lst at index idx, right padding if necessary - Needed because user may ignore some layers in configuration, meaning - len(lst) <= idx-1 - - >>> _pad_and_append_at_idx_([0,1,2], 5, 5) - [0, 1, 2, None, None, 5] - >>> _pad_and_append_at_idx_([0,1,2], 3, 8) - [0, 1, 2, 8] - >>> _pad_and_append_at_idx_([0,1,2], 1, 5) - [0, 5, 2] - """ - num_to_pad = idx - len(lst) + 1 - if num_to_pad > 0: - lst += [None] * num_to_pad - lst[idx] = val - return lst diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 3aa57eeb2..824dbd7b6 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -1,21 +1,17 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional import torch from compressed_tensors.quantization import ( DynamicType, - KVCacheScaleType, QuantizationArgs, - QuantizationScheme, QuantizationStatus, QuantizationStrategy, ) from compressed_tensors.quantization.lifecycle.forward import forward_quantize -from compressed_tensors.quantization.utils import 
is_kv_cache_quant_scheme from compressed_tensors.utils import align_module_device, update_offload_parameter from loguru import logger from torch.nn import Module -from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache from llmcompressor.observers import Observer from llmcompressor.utils.helpers import getattr_chain @@ -30,9 +26,6 @@ "update_weight_zp_scale", "calibrate_input_hook", "calibrate_output_hook", - "calibrate_kv_cache_input_hook", - "calibrate_kv_cache_output_hook", - "initialize_quantized_kv_cache", "freeze_module_quantization", "apply_calibration_status", "reset_quantization_status", @@ -252,53 +245,6 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): return output -def calibrate_kv_cache_input_hook( - module: Module, args: Any, kwargs: Dict[str, Any] -) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: - """ - Hook to update inputs to attention layers when running - kv_cache quantization. Will update the passed in - kv_cache to singleton QuantizedKVParameterCache. - """ - kv_cache = getattr(module, "kv_cache") - kwargs["past_key_value"] = kv_cache - kwargs["use_cache"] = False - return args, kwargs - - -def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor): - """ - Hook to update k_scale and v_scale parameters when running kv_cache quantization. - """ - kv_cache = getattr(module, "kv_cache") - k_scale = kv_cache.k_scales[module.layer_idx] - v_scale = kv_cache.v_scales[module.layer_idx] - update_offload_parameter(module, KVCacheScaleType.KEY.value, k_scale) - update_offload_parameter(module, KVCacheScaleType.VALUE.value, v_scale) - - -def initialize_quantized_kv_cache(module: Module): - """ - Initialize a quantized kv_cache on a module (analogous to initializing an observer) - When a config specifying kv_cache quantization is applied to a model, the kv_cache - args are redefined as the output_activations targeting attention modules. 
- - This function should be called on attention modules with output_activations - """ - scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None) - existing_kv_cache = getattr(module, "kv_cache", None) - - if ( - scheme is None - or not is_kv_cache_quant_scheme(scheme) - or isinstance(existing_kv_cache, QuantizedKVParameterCache) - ): - return - - quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations) - setattr(module, "kv_cache", quantized_kv_cache) - - def apply_calibration_status(module: Module): scheme = getattr(module, "quantization_scheme", None) if not scheme: @@ -330,11 +276,6 @@ def freeze_module_quantization(module: Module): if hasattr(module, obs_name): delattr(module, obs_name) - # remove quantized kv_cache - kv_cache = getattr(module, "kv_cache", None) - if isinstance(kv_cache, QuantizedKVParameterCache): - delattr(module, "kv_cache") - module.quantization_status = QuantizationStatus.FROZEN diff --git a/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py b/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py index 7038c42d4..ff0ea2bda 100644 --- a/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py +++ b/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py @@ -3,7 +3,7 @@ import pytest from accelerate import init_empty_weights -from compressed_tensors.quantization import KVCacheScaleType, is_attention_module +from compressed_tensors.quantization import is_attention_module from datasets import load_dataset from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from transformers.utils.quantization_config import CompressedTensorsConfig @@ -14,7 +14,7 @@ NUM_CALIBRATION_SAMPLES = 16 MAX_SEQUENCE_LENGTH = 512 DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" +DATASET_SPLIT = f"train_sft[:{NUM_CALIBRATION_SAMPLES}]" MODEL_IDS = [ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", @@ -49,9 +49,11 @@ def _oneshot_fixture(tmp_path: Path): symmetric=symmetric, ) oneshot_args = dict( - dataset="open_platypus", recipe=recipe, - num_calibration_samples=16, + dataset="open_platypus", + splits={"calibration": f"train[:{NUM_CALIBRATION_SAMPLES}]"}, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + max_seq_length=MAX_SEQUENCE_LENGTH, ) for model_id in MODEL_IDS: oneshot_args["output_dir"] = os.path.join(tmp_path, model_id) @@ -161,8 +163,8 @@ def test_kv_cache_model_state_dict_attr(oneshot_fixture, tmp_path): for name, submodule in model.named_modules(): if is_attention_module(submodule): counts += 1 - assert hasattr(submodule, KVCacheScaleType.VALUE.value) - assert hasattr(submodule, KVCacheScaleType.KEY.value) + assert hasattr(submodule, "v_scale") + assert hasattr(submodule, "k_scale") assert counts > 0 @@ -200,8 +202,8 @@ def test_kv_cache_gptq_config_format(kv_cache_fixture, tmp_path): for name, submodule in model.named_modules(): if is_attention_module(submodule): counts += 1 - assert hasattr(submodule, KVCacheScaleType.VALUE.value) - assert hasattr(submodule, KVCacheScaleType.KEY.value) + assert hasattr(submodule, "v_scale") + assert hasattr(submodule, "k_scale") assert counts > 0 @@ -244,7 +246,7 @@ def test_kv_cache_gptq_model_state_dict_attr(kv_cache_fixture, tmp_path): for name, submodule in model.named_modules(): if is_attention_module(submodule): counts += 1 - assert hasattr(submodule, KVCacheScaleType.VALUE.value) - assert hasattr(submodule, KVCacheScaleType.KEY.value) + assert hasattr(submodule, "v_scale") + assert hasattr(submodule, "k_scale") assert counts > 0 From 
d26075c557b0476eef784f2bb21d4c47c0f8c76a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 25 Aug 2025 19:06:52 -0400 Subject: [PATCH 08/10] wip: multiple qconfig application Signed-off-by: Kyle Sayers --- examples/quantization_w4a16/llama3_example.py | 7 +++-- .../modifiers/quantization/calibration.py | 28 +++++++++++++++++++ .../quantization/quantization/mixin.py | 3 +- src/llmcompressor/modifiers/utils/hooks.py | 14 +++++++++- 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index b729be003..572e8f435 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -6,7 +6,7 @@ from llmcompressor.utils import dispatch_for_generation # Select model and load it. -model_id = "meta-llama/Meta-Llama-3-8B-Instruct" +model_id = "meta-llama/Llama-3.1-8B-Instruct" model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -51,7 +51,10 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) +recipe = [ + GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), + GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), +] # Apply algorithms. oneshot( diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 824dbd7b6..52f74746b 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -20,11 +20,15 @@ DEFAULT_AVERAGING_CONSTANT = 0.01 DEFAULT_GRID = 100.0 DEFAULT_NORM = 2.4 +ALL_OBSERVER_BASE_NAMES = {"input", "weight", "output", "q", "k", "v"} __all__ = [ "initialize_observer", "update_weight_zp_scale", "calibrate_input_hook", + "calibrate_query_hook", + "calibrate_key_hook", + "calibrate_value_hook", "calibrate_output_hook", "freeze_module_quantization", "apply_calibration_status", @@ -44,6 +48,7 @@ def initialize_observer(module: Module, base_name: str): :param base_name: str used to name the observer attribute """ + # resolve arg name in scheme if base_name == "weight": arg_name = "weights" elif base_name == "output": @@ -279,7 +284,30 @@ def freeze_module_quantization(module: Module): module.quantization_status = QuantizationStatus.FROZEN +ALL_CALIBRATION_HOOKS = { + calibrate_input_hook, + calibrate_query_hook, + calibrate_key_hook, + calibrate_value_hook, + calibrate_output_hook, +} + + def reset_quantization_status(model: Module): + from llmcompressor.modifiers.utils.hooks import HooksMixin + for module in model.modules(): + # reset status if hasattr(module, "quantization_status"): delattr(module, "quantization_status") + + # reset observers + for base_name in ALL_OBSERVER_BASE_NAMES: + attr_name = f"{base_name}_observer" + if hasattr(module, attr_name): + delattr(module, attr_name) + + # remove hooks (note that removal is idempotent) + for handle_id, hook in module._forward_hooks.items(): + if hook in ALL_CALIBRATION_HOOKS: + HooksMixin.remove_hooks_by_id(set(handle_id)) diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index ff6ba0d9b..774308ce5 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ 
b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -123,7 +123,8 @@ def initialize_quantization(self, model: torch.nn.Module): :param model: model to attach schemes and observers to """ - reset_quantization_status(model) # reset any previously applied qconfigs + # reset previous statuses, observers, and calibration hooks + reset_quantization_status(model) # apply scheme and status to model config = self.resolve_quantization_config() diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 98d5240e2..6a76bbcfe 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,6 +1,7 @@ import contextlib +from copy import deepcopy from functools import wraps -from typing import Any, Callable, ClassVar, Optional, Set, Union +from typing import Any, Callable, ClassVar, Dict, Optional, Set, Union import torch from loguru import logger @@ -39,6 +40,7 @@ class HooksMixin(BaseModel): # attached to global HooksMixin class _HOOKS_DISABLED: ClassVar[bool] = False _HOOKS_KEEP_ENABLED: ClassVar[Set[RemovableHandle]] = set() + _HOOKS_TO_MODIFIER: ClassVar[Dict[RemovableHandle, "HooksMixin"]] = dict() # attached to local subclasses _hooks: Set[RemovableHandle] = set() @@ -95,6 +97,7 @@ def wrapped_hook(*args, **kwargs): register_function = getattr(target, f"register_{hook_type}_hook") handle = register_function(wrapped_hook, **kwargs) self._hooks.add(handle) + self._HOOKS_TO_MODIFIER[handle] = self logger.debug(f"{self} added {handle}") return handle @@ -113,3 +116,12 @@ def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None): hook.remove() self._hooks -= handles + self._HOOKS_TO_MODIFIER -= handles + + @classmethod + def remove_hooks_by_id(cls, ids: Set[int]): + handles = deepcopy(cls._HOOKS_TO_MODIFIER) + for handle in handles: + if handle.id in ids: + modifier = cls._HOOKS_TO_MODIFIER[handle] + modifier.remove_hooks(set(handle)) From f27d518aa8de82571e169d6df8aa856efb6ce0cb Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 26 Aug 2025 14:02:00 -0400 Subject: [PATCH 09/10] fix typo Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/utils/hooks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 6a76bbcfe..f3c616491 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -116,7 +116,8 @@ def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None): hook.remove() self._hooks -= handles - self._HOOKS_TO_MODIFIER -= handles + for handle in handles: + self._HOOKS_TO_MODIFIER.pop(handle, None) @classmethod def remove_hooks_by_id(cls, ids: Set[int]): From e8b0099173a58ae4928140da772477f53f93c1ff Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 26 Aug 2025 18:03:15 -0400 Subject: [PATCH 10/10] remove unused tests, temporarily disable torchcodec for testing so nightly will run? 
Signed-off-by: Kyle Sayers --- setup.py | 2 +- .../modifiers/calibration/test_cache.py | 118 ------------------ .../modifiers/calibration/test_kv_cache.py | 94 -------------- 3 files changed, 1 insertion(+), 213 deletions(-) delete mode 100644 tests/llmcompressor/modifiers/calibration/test_cache.py delete mode 100644 tests/llmcompressor/modifiers/calibration/test_kv_cache.py diff --git a/setup.py b/setup.py index 42052e138..7b5ef2f2c 100644 --- a/setup.py +++ b/setup.py @@ -189,7 +189,7 @@ def localversion_func(version: ScmVersion) -> str: "torchvision", "librosa", "soundfile", - "torchcodec", + #"torchcodec", # linting, formatting, and type checking "black~=24.4.2", "isort~=5.13.2", diff --git a/tests/llmcompressor/modifiers/calibration/test_cache.py b/tests/llmcompressor/modifiers/calibration/test_cache.py deleted file mode 100644 index 9b03234cf..000000000 --- a/tests/llmcompressor/modifiers/calibration/test_cache.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs - -from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache -from llmcompressor.observers import Observer - - -def test_is_quantized_cache_singleton(): - """ - Check if quantized_cache is a singleton, used for - passing in QuantizedKVParameterCache to the forward call of - the model's self_attn - """ - - args = QuantizationArgs() - cache = QuantizedKVParameterCache(args) - observer = args.observer - observer = Observer.load_from_registry(observer, quantization_args=args) - - tensor = torch.tensor([1, 2, 3]) - cache.k_scales.append(tensor) - cache.k_observers.append(observer) - - same_cache = QuantizedKVParameterCache(args) - - assert len(cache.k_scales) == len(same_cache.k_scales) - assert torch.equal(cache.k_scales[0], same_cache.k_scales[0]) - - assert cache.k_observers == same_cache.k_observers - assert hex(id(cache.k_observers[0])) == hex(id(same_cache.k_observers[0])) - - cache.reset() - - -def test_update(): - num_bits = 8 - args = QuantizationArgs(num_bits=num_bits, symmetric=True) - cache = QuantizedKVParameterCache(args) - - max_key_states_val = 1.0 - max_value_states_val = 2.0 - key_states = torch.cat( - (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - value_states = torch.cat( - (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - layer_idx = 0 - - cache.update(key_states, value_states, layer_idx) - denom = (2 ** (num_bits) - 1) / 2 - expected_k_scale = torch.tensor([max_key_states_val / denom]) - expected_v_scale = torch.tensor([max_value_states_val / denom]) - - assert cache.k_scales[0] == expected_k_scale - assert cache.v_scales[0] == expected_v_scale - - # new attn layer - layer_idx = 1 - cache.update(key_states, value_states, layer_idx) - - assert len(cache.k_scales) == 2 - assert len(cache.v_scales) == 2 - - assert len(cache.k_observers) == 2 - 
assert len(cache.v_observers) == 2 - - cache.reset() - - -def test_cache_reset(): - num_bits = 8 - args = QuantizationArgs(num_bits=num_bits, symmetric=True) - cache = QuantizedKVParameterCache(args) - - max_key_states_val = 1.0 - max_value_states_val = 2.0 - key_states = torch.cat( - (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - value_states = torch.cat( - (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - layer_idx = 0 - - cache.update(key_states, value_states, layer_idx) - assert len(cache.k_scales) == 1 - assert len(cache.v_scales) == 1 - - assert len(cache.k_observers) == 1 - assert len(cache.v_observers) == 1 - - cache.reset() - - # new instance, different memory addr - different_cache = QuantizedKVParameterCache(args) - - assert len(different_cache.k_scales) == 0 - assert len(different_cache.v_scales) == 0 - - assert len(different_cache.k_observers) == 0 - assert len(different_cache.v_observers) == 0 - - assert hex(id(cache)) != hex(id(different_cache)) diff --git a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py b/tests/llmcompressor/modifiers/calibration/test_kv_cache.py deleted file mode 100644 index b22e7ec40..000000000 --- a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import pytest -import torch -from compressed_tensors.quantization import ( - QuantizationConfig, - QuantizationStatus, - apply_quantization_config, - is_attention_module, -) -from transformers import AutoModelForCausalLM - -from llmcompressor.modifiers.quantization.calibration import ( - calibrate_kv_cache_input_hook, - calibrate_kv_cache_output_hook, - freeze_module_quantization, - initialize_quantized_kv_cache, -) - -config = { - "quant_method": "compressed-tensors", - "format": "fakequant", - "kv_cache_scheme": { - "num_bits": 8, - "type": "int", - "symmetric": True, - "strategy": "tensor", - }, - "config_groups": { - "group_1": { - "weights": { - "num_bits": 4, - "type": "int", - "symmetric": True, - "strategy": "tensor", - }, - "targets": ["Linear"], - }, - }, -} - - -def _prep_for_calibration(module: torch.nn.Module): - if is_attention_module(module): - module.register_forward_pre_hook( - calibrate_kv_cache_input_hook, with_kwargs=True - ) - module.register_forward_hook(calibrate_kv_cache_output_hook) - module.quantization_status = QuantizationStatus.CALIBRATION - - -@pytest.mark.parametrize("config", [config]) -def test_kv_cache_quantization(config): - sample = { - name: torch.ones((1, 32)).long() - for name in ["input_ids", "attention_mask", "labels"] - } - model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", - torch_dtype="auto", - ) - model.eval() - - config = QuantizationConfig(**config) - config.quantization_status = QuantizationStatus.CALIBRATION - apply_quantization_config(model, config) - model.apply(initialize_quantized_kv_cache) - model.apply(_prep_for_calibration) - - with torch.no_grad(): - _ = model(**sample) - - model.apply(freeze_module_quantization) - - reloaded_config = QuantizationConfig.from_pretrained(model) - - assert ( - config.kv_cache_scheme.model_dump().keys() - == reloaded_config.kv_cache_scheme.model_dump().keys() - ) - assert list(config.kv_cache_scheme.model_dump().values()) == list( - reloaded_config.kv_cache_scheme.model_dump().values() - )
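
Taken together, patches 02, 04, and 07 replace the fused QuantizedKVParameterCache calibration path with per-projection observers on attention modules: query, key, and value states are flattened to 2-D and routed to q/k/v observers through hooks registered via compressed-tensors. The sketch below is not part of the patch series; it only illustrates how that wiring could be assembled standalone, assuming the helpers imported in patch 02 (register_query_hook, register_key_hook, register_value_hook) and the calibrate_*_hook functions added to calibration.py behave exactly as shown in the diffs above. The patch additionally guards the query path behind scheme.kv_cache_only; that detail is reduced to a comment here.

    import torch
    from compressed_tensors.modeling.attention import register_query_hook
    from compressed_tensors.modeling.kvcache import register_key_hook, register_value_hook
    from compressed_tensors.quantization import is_attention_module

    from llmcompressor.modifiers.quantization.calibration import (
        calibrate_key_hook,
        calibrate_query_hook,
        calibrate_value_hook,
        initialize_observer,
    )


    def attach_attention_calibration(model: torch.nn.Module) -> set:
        """Sketch mirroring QuantizationMixin after patches 02/04: observe q/k/v
        states on every attention module that has an input-activation scheme."""
        handles = set()
        for module in model.modules():
            scheme = getattr(module, "quantization_scheme", None)
            if scheme is None or scheme.input_activations is None:
                continue
            if not is_attention_module(module):
                continue
            # one observer per projection (replaces the removed kv_cache observer);
            # in the patch, the "q" observer and query hook are skipped when
            # scheme.kv_cache_only is set
            for base_name in ("q", "k", "v"):
                initialize_observer(module, base_name=base_name)
            # hooks flatten states to (tokens, head_dim) before calling the observer
            handles.add(register_query_hook(module, calibrate_query_hook))
            handles.add(register_key_hook(module, calibrate_key_hook))
            handles.add(register_value_hook(module, calibrate_value_hook))
        return handles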