diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py new file mode 100644 index 00000000..1d86e78b --- /dev/null +++ b/src/compressed_tensors/modeling/attention.py @@ -0,0 +1,143 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Optional + +import torch +from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, + forward_quantize, +) +from compressed_tensors.quantization.lifecycle.initialize import ( + _initialize_scale_zero_point, +) +from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule +from torch.utils.hooks import RemovableHandle +from transformers import AttentionInterface, PreTrainedModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + +__all__ = ["IMPL_ATTR", "QuantizedAttentionImpl"] + + +IMPL_ATTR = "impl" +_original_impl = "eager" # mutable + + +class QuantizedAttentionImpl(InternalModule): + def __init__(self, attn_module: torch.nn.Module): + super().__init__() + self.attn_module_container = [attn_module] # avoid circular reference + self._qparams_initialized = False + + def forward( + self, + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + *args, + **kwargs, + ): + # quantization + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) + if quant_args is not None and quant_enabled and self._qparams_initialized: + query = forward_quantize(module, query, "q", quant_args) + + # original attention + return ALL_ATTENTION_FUNCTIONS[_original_impl]( + module, + query, + key, + value, + *args, + **kwargs, + ) + + def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module): + assert module is self.attn_module_container[0] + scheme: Optional[QuantizationScheme] = getattr( + module, "quantization_scheme", None + ) + quant_args: Optional[QuantizationArgs] = getattr( + scheme, "input_activations", None + ) + + if ( + not self._qparams_initialized + and quant_args is not None + and not scheme.kv_cache_only + ): + # TODO: use model.config.num_attention_heads to find query_size + assert quant_args.strategy == QuantizationStrategy.TENSOR + _initialize_scale_zero_point(module, "q", quant_args) + self._qparams_initialized = True + + +# ----- initialize ----- # + + +def ct_hooked_attention(module: torch.nn.Module, *args, **kwargs): + if hasattr(module, IMPL_ATTR): + return module.impl(module, *args, **kwargs) + else: + return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs) + + +def initialize_hooked_attention( + model: PreTrainedModel, module: torch.nn.Module, quantize: bool = True +): + if not hasattr(module, IMPL_ATTR): + module.register_module(IMPL_ATTR, QuantizedAttentionImpl(module)) + if model.config._attn_implementation != "ct_hooked_attention": + # assumes only one model at a time + global _original_impl + _original_impl = model.config._attn_implementation + + AttentionInterface.register("ct_hooked_attention", ct_hooked_attention) + model.config._attn_implementation = "ct_hooked_attention" + + impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR) + if quantize: + impl.initialize_qparams_once(model, module) + + initialize_hooked_kv_cache(model, module, quantize=quantize) + + +# ----- hooks ----- # + + +def register_query_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: + """ + Registers a forward pre-hook on `module.impl` that replaces the `query` argument + with `hook(mod, query)` (handles both positional and keyword forms). + """ + impl = getattr(module, IMPL_ATTR) + + def _hook(impl: QuantizedAttentionImpl, args, kwargs): + bound = inspect.signature(module.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["query"]) + if value is not None: + bound.arguments["query"] = value + + return bound.args, bound.kwargs + + return impl.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py new file mode 100644 index 00000000..f26fca08 --- /dev/null +++ b/src/compressed_tensors/modeling/kvcache.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Optional, Tuple + +import torch +import transformers +from compressed_tensors.quantization import QuantizationStrategy, forward_quantize +from compressed_tensors.quantization.lifecycle.initialize import ( + _initialize_scale_zero_point, +) +from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule +from packaging import version +from torch import Tensor +from torch.utils.hooks import RemovableHandle +from transformers import Cache, PreTrainedModel + + +__all__ = ["KV_CACHE_ATTR", "QuantizedKVCache"] + + +KV_CACHE_ATTR = "kv_cache" + + +class QuantizedKVCache(InternalModule): + def __init__(self, attn_module: torch.nn.Module): + super().__init__() + self.attn_module_container = [attn_module] # avoid nn.Module circular reference + self.past_key_values: Optional[Cache] = None + self._qparams_initialized = False + + def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: + return self(*args, **kwargs) + + def forward( + self, + key_states: Tensor, + value_states: Tensor, + *args, + **kwargs, + ) -> Tuple[Tensor, Tensor]: + # quantization + module = self.attn_module_container[0] + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) + if quant_args is not None and quant_enabled and self._qparams_initialized: + key_states = forward_quantize(module, key_states, "k", quant_args) + value_states = forward_quantize(module, value_states, "v", quant_args) + + # original cache + if self.past_key_values is not None: + ret = self.past_key_values.update(key_states, value_states, *args, **kwargs) + else: + ret = (key_states, value_states) + + self.past_key_values = None + return ret + + def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module): + assert module is self.attn_module_container[0] + scheme = getattr(module, "quantization_scheme", None) + quant_args = getattr(scheme, "input_activations", None) + + if not self._qparams_initialized and quant_args is not None: + # TODO: use model.config.num_key_value_heads to find key_size, value_size + assert quant_args.strategy == QuantizationStrategy.TENSOR + _initialize_scale_zero_point(module, "k", quant_args) + _initialize_scale_zero_point(module, "v", quant_args) + self._qparams_initialized = True + + +# ----- initialize ----- # + + +def initialize_hooked_kv_cache( + model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False +): + if not hasattr(module, KV_CACHE_ATTR): + module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module)) + module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True) + + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + if quantize: + kv_cache.initialize_qparams_once(model, module) + + +def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs): + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + _past_kv_name = ( + "past_key_value" + if version.parse(transformers.__version__) <= version.parse("4.55.4") + else "past_key_values" # transformers#39956 + ) + kv_cache.past_key_values = kwargs.get(_past_kv_name, None) + kwargs[_past_kv_name] = kv_cache + + return args, kwargs + + +# ----- hooks ----- # + + +def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["key_states"]) + if value is not None: + bound.arguments["key_states"] = value + + return bound.args, bound.kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) + + +def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["value_states"]) + if value is not None: + bound.arguments["value_states"] = value + + return bound.args, bound.kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index ded9fee9..445e47be 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -13,11 +13,8 @@ # limitations under the License. import logging -from collections import OrderedDict from copy import deepcopy -from typing import Dict, Iterable, List, Optional -from typing import OrderedDict as OrderedDictType -from typing import Union +from typing import Any, Dict, Iterable, List, Optional, Union import torch from compressed_tensors.config import CompressionFormat @@ -26,24 +23,29 @@ ) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, + is_attention_module, ) -from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.quant_config import ( QuantizationConfig, QuantizationStatus, ) from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import ( - KV_CACHE_TARGETS, + ATTN_TARGETS, infer_quantization_status, - is_kv_cache_quant_scheme, ) -from compressed_tensors.utils.helpers import deprecated, replace_module -from compressed_tensors.utils.match import match_named_modules, match_targets -from compressed_tensors.utils.offload import update_parameter_data -from compressed_tensors.utils.safetensors_load import get_safetensors_folder +from compressed_tensors.utils import ( + deprecated, + get_safetensors_folder, + is_narrow_match, + match_named_modules, + match_targets, + replace_module, + update_parameter_data, +) from safetensors import safe_open from torch.nn import Module +from transformers import PreTrainedModel __all__ = [ @@ -114,7 +116,13 @@ def load_pretrained_quantization_parameters( def apply_quantization_config( +<<<<<<< HEAD + model: PreTrainedModel, + config: Union[QuantizationConfig, None], + run_compressed: bool = False, +======= model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False +>>>>>>> origin ): """ Initializes the model for quantization in-place based on the given config. @@ -126,6 +134,7 @@ def apply_quantization_config( decompressed fully on load """ from compressed_tensors.linear.compressed_linear import CompressedLinear + from compressed_tensors.modeling.attention import initialize_hooked_attention config = deepcopy(config) if config is None: # see PR #180 @@ -134,41 +143,57 @@ def apply_quantization_config( # preprocess to support kv cache scheme config = process_quantization_config(config) - # build mapping of targets to schemes for easier matching - # use ordered dict to preserve target ordering in config - target_to_scheme = OrderedDict() - for scheme in config.config_groups.values(): - for target in scheme.targets: - target_to_scheme[target] = scheme - # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule in match_named_modules( - model, target_to_scheme, config.ignore, warn_on_fail=True - ): - # mark modules to be quantized by adding - # quant scheme to the matching layers - matched_targets = match_targets(name, submodule, target_to_scheme) - scheme = _scheme_from_targets(target_to_scheme, matched_targets, name) - # target matched - add layer and scheme to target list - submodule.quantization_scheme = scheme - - # replace with run compressed if applicable - # FUTURE: move this to model compressor - if isinstance(submodule, torch.nn.Linear) and run_compressed: - format = config.format - if format != CompressionFormat.dense.value: - if isinstance(submodule, torch.nn.Linear): - # TODO: expand to more module types - compressed_linear = CompressedLinear.from_linear( - submodule, - quantization_scheme=scheme, - quantization_format=format, - ) - replace_module(model, name, compressed_linear) - - # apply current quantization status across all targeted layers + for scheme in config.config_groups.values(): + for name, submodule in match_named_modules( + model, scheme.targets, config.ignore or [], warn_on_fail=True + ): + # attach scheme to module (with merging) + attach_scheme(submodule, scheme) + + # replace with run compressed if applicable + # FUTURE: move this to model compressor + if isinstance(submodule, torch.nn.Linear): + if run_compressed: + format = config.format + if format != CompressionFormat.dense.value: + # TODO: expand to more module types + compressed_linear = CompressedLinear.from_linear( + submodule, + quantization_scheme=scheme, + quantization_format=format, + ) + replace_module(model, name, compressed_linear) + + # attention quantization and/or kv cache quantization + if is_attention_module(submodule): + if is_narrow_match(model, scheme.targets, name): + # unlike linear, do qparam initialization here (once) + initialize_hooked_attention(model, submodule, quantize=True) + else: + # do not quantize attention unless specifically targeted + delattr(submodule, "quantization_scheme") + + # apply current quantization status across all targeted linear/embedding layers apply_quantization_status(model, config.quantization_status) + # attach config for serialization + attach_config(model, config) + + +def attach_scheme(module: Module, scheme: QuantizationScheme) -> QuantizationScheme: + if existing_scheme := getattr(module, "quantization_scheme", None): + scheme = scheme.merge(existing_scheme) + + setattr(module, "quantization_scheme", scheme) + return scheme + + +def attach_config(model: PreTrainedModel, config: QuantizationConfig): + if existing_config := getattr(model, "quantization_config", None): + config = config.merge(existing_config) + setattr(model, "quantization_config", config) + def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: """ @@ -183,9 +208,7 @@ def process_quantization_config(config: QuantizationConfig) -> QuantizationConfi return config -def process_kv_cache_config( - config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS -) -> QuantizationConfig: +def process_kv_cache_config(config: QuantizationConfig) -> QuantizationConfig: """ Reformulate the `config.kv_cache` as a `config_group` and add it to the set of existing `config.groups` @@ -193,16 +216,14 @@ def process_kv_cache_config( :param config: the QuantizationConfig :return: the QuantizationConfig with additional "kv_cache" group """ - if targets == KV_CACHE_TARGETS: - _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}") + _LOGGER.info(f"KV cache targets set to default value of: {ATTN_TARGETS}") - kv_cache_dict = config.kv_cache_scheme.model_dump() - kv_cache_scheme = QuantizationScheme( - output_activations=QuantizationArgs(**kv_cache_dict), - targets=targets, + scheme = QuantizationScheme( + targets=ATTN_TARGETS, + input_activations=config.kv_cache_scheme, + kv_cache_only=True, ) - kv_cache_group = dict(kv_cache=kv_cache_scheme) - config.config_groups.update(kv_cache_group) + config.config_groups.update({"kv_cache": scheme}) return config @@ -263,14 +284,6 @@ def find_name_or_class_matches( return match_targets(name, module, targets) -def _infer_status(model: Module) -> Optional[QuantizationStatus]: - for module in model.modules(): - status = getattr(module, "quantization_status", None) - if status is not None: - return status - return None - - def _load_quant_args_from_mapping( base_name: str, module_name: str, module: Module, mapping: Dict ): @@ -313,67 +326,3 @@ def _load_quant_args_from_mapping( state_dict_zp = f.get_tensor(f"{module_name}.{zp_name}") update_parameter_data(module, state_dict_zp, zp_name) - - -def _scheme_from_targets( - target_to_scheme: OrderedDictType[str, QuantizationScheme], - targets: List[str], - name: str, -) -> QuantizationScheme: - if len(targets) == 1: - # if `targets` iterable contains a single element - # use it as the key - return target_to_scheme[targets[0]] - - # otherwise, we need to merge QuantizationSchemes corresponding - # to multiple targets. This is most likely because `name` module - # is being target both as an ordinary quantization target, as well - # as kv cache quantization target - schemes_to_merge = [target_to_scheme[target] for target in targets] - return _merge_schemes(schemes_to_merge, name) - - -def _merge_schemes( - schemes_to_merge: List[QuantizationScheme], name: str -) -> QuantizationScheme: - kv_cache_quantization_scheme = [ - scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme) - ] - if not kv_cache_quantization_scheme: - # if the schemes_to_merge do not contain any - # kv cache QuantizationScheme - # return the first scheme (the prioritized one, - # since the order of schemes_to_merge matters) - return schemes_to_merge[0] - else: - # fetch the kv cache QuantizationScheme and the highest - # priority non-kv cache QuantizationScheme and merge them - kv_cache_quantization_scheme = kv_cache_quantization_scheme[0] - quantization_scheme = [ - scheme - for scheme in schemes_to_merge - if not is_kv_cache_quant_scheme(scheme) - ][0] - schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme] - merged_scheme = {} - for scheme in schemes_to_merge: - scheme_dict = { - k: v for k, v in scheme.model_dump().items() if v is not None - } - # when merging multiple schemes, the final target will be - # the `name` argument - hence erase the original targets - del scheme_dict["targets"] - # make sure that schemes do not "clash" with each other - overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys()) - if overlapping_keys: - raise ValueError( - f"The module: {name} is being modified by two clashing " - f"quantization schemes, that jointly try to override " - f"properties: {overlapping_keys}. Fix the quantization config " - "so that it is not ambiguous." - ) - merged_scheme.update(scheme_dict) - - merged_scheme.update(targets=[name]) - - return QuantizationScheme(**merged_scheme) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index c9430e9e..9ee25d72 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -82,53 +82,41 @@ def initialize_module_for_quantization( # no scheme passed and layer not targeted for quantization - skip return - if is_attention_module(module): - # quantized actions based on calltime status - _initialize_attn_scales(module) + if not isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): + return - else: - if scheme.input_activations is not None: - _initialize_scale_zero_point( - module, - "input", - scheme.input_activations, - force_zero_point=force_zero_point, - scale_dtype=scale_dtype, - ) - - if scheme.weights is not None: - if hasattr(module, "weight"): - weight_shape = None - if isinstance(module, torch.nn.Linear): - weight_shape = module.weight.shape - _initialize_scale_zero_point( - module, - "weight", - scheme.weights, - weight_shape=weight_shape, - force_zero_point=force_zero_point, - scale_dtype=scale_dtype, - ) - else: - _LOGGER.warning( - f"module type {type(module)} targeted for weight quantization but " - "has no attribute weight, skipping weight quantization " - f"for {type(module)}" - ) + if scheme.input_activations is not None: + _initialize_scale_zero_point( + module, + "input", + scheme.input_activations, + force_zero_point=force_zero_point, + scale_dtype=scale_dtype, + ) - if scheme.output_activations is not None: - if not is_kv_cache_quant_scheme(scheme): - _initialize_scale_zero_point( - module, "output", scheme.output_activations, scale_dtype=scale_dtype - ) + if scheme.weights is not None: + weight_shape = module.weight.shape + _initialize_scale_zero_point( + module, + "weight", + scheme.weights, + weight_shape=weight_shape, + force_zero_point=force_zero_point, + scale_dtype=scale_dtype, + ) + + if scheme.output_activations is not None: + _initialize_scale_zero_point( + module, "output", scheme.output_activations, scale_dtype=scale_dtype + ) - module.quantization_scheme = scheme - module.quantization_status = QuantizationStatus.INITIALIZED + module.quantization_scheme = scheme + module.quantization_status = QuantizationStatus.INITIALIZED - with disable_hf_hook(module): - # wrap forward call of module to perform - # quantized actions based on calltime status - wrap_module_forward_quantized(module, scheme) + with disable_hf_hook(module): + # wrap forward call of module to perform + # quantized actions based on calltime status + wrap_module_forward_quantized(module, scheme) def is_attention_module(module: Module): @@ -213,7 +201,9 @@ def _initialize_scale_zero_point( expected_shape = 1 # 3. Identify quantization scale and zp dtype - scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype + scale_dtype = ( + scale_dtype if scale_dtype is not None else next(module.parameters()).dtype + ) if is_fp4(quantization_args=quantization_args): scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index d9e88353..34bdb829 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -356,6 +356,9 @@ def pytorch_dtype(self) -> torch.dtype: else: raise ValueError(f"Invalid quantization type {self.type}") + def is_online(self) -> bool: + return self.dynamic is True + @deprecated("QuantizationArgs.observer") def get_observer(self) -> str: return self.observer diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 42df3a33..ecdc8ae2 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -26,7 +26,7 @@ module_type, parse_out_kv_cache_args, ) -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from torch.nn import Module @@ -58,13 +58,6 @@ class QuantizationStatus(str, Enum): FROZEN = "frozen" COMPRESSED = "compressed" - @classmethod - def lifecycle_order(cls) -> List["QuantizationStatus"]: - """ - :return: list of correct quantization lifecycle order - """ - return - def __ge__(self, other): if other is None: return True @@ -102,7 +95,7 @@ def __le__(self, other): ] DEFAULT_QUANTIZATION_METHOD = "compressed-tensors" -DEFAULT_QUANTIZATION_FORMAT = "fakequant" +DEFAULT_QUANTIZATION_FORMAT = "fakequant" # TODO: remove class QuantizationConfig(BaseModel): @@ -163,92 +156,68 @@ def to_dict(self): # for compatibility with HFQuantizer return self.model_dump() - @staticmethod - def from_pretrained( - model: Module, format: Optional[str] = None - ) -> Optional["QuantizationConfig"]: - """ - Converts a model into its associated QuantizationConfig based on the - QuantizationScheme attached to each quantized module + def merge(self, other: "QuantizationConfig") -> "QuantizationConfig": + def merge_field(field_name: str, value_a: Any, value_b: Any) -> Any: + if field_name == "config_groups": + return value_a | value_b - :param model: model to calculate quantization scheme of - :return: filled out QuantizationScheme for the input model - """ - quant_scheme_to_layers = [] - quantization_status = None - ignore = {} - quantization_type_names = set() - for name, submodule in model.named_modules(): - layer_type = module_type(submodule) - if not is_module_quantized(submodule): - if layer_type not in ignore: - ignore[layer_type] = [] - ignore[layer_type].append(name) - else: - quantization_status = submodule.quantization_status - scheme = submodule.quantization_scheme - quantization_type_names.add(layer_type) - - match_found = False - for existing_scheme in quant_scheme_to_layers: - if scheme == existing_scheme: - match_found = True - break - if not match_found: - quant_scheme_to_layers.append(scheme) - - if len(quant_scheme_to_layers) == 0: # No quantized layers - return None - - # kv-cache only, no weight/activation quantization - if ( - len(quantization_type_names) == 1 - and "attention" in list(quantization_type_names)[0].lower() - ): - quantization_type_names.add("Linear") - - # clean up ignore list, we can leave out layers types if none of the - # instances are quantized - consolidated_ignore = [] - for layer_type, ignore_names in ignore.items(): - if layer_type in quantization_type_names: - # specific layers of a quantized type are ignored - consolidated_ignore += ignore_names - # else we leave it off the ignore list, doesn't fall under any of the - # existing quantization schemes so it won't be quantized - - kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args( - quant_scheme_to_layers - ) - kv_cache_scheme = ( - kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args + if field_name == "ignore": + if value_a is not None and value_b is None: + return value_a + + if value_a is None and value_b is not None: + return value_b + + if set(value_a) == set(value_b): + return value_a + + raise NotImplementedError( + "Cannot merge quantization configs with non-identical ignore lists " + "Please modify your config to resolve this ambiguity." + f"\n{self}\n{other}" + ) + + if value_a is not None and value_b is None: + return value_a + + if value_a is None and value_b is not None: + return value_b + + if value_a == value_b: + return value_a + + raise ValueError( + "The following fields have overlapping targets and conflicting values " + f"for {field_name}. Please modify your config to resolve this " + f"ambiguity.\n{self}\n{other}" + ) + + dict_a = self.model_dump() + dict_b = other.model_dump() + + assert dict_a.keys() == dict_b.keys() + return self.model_validate( + {key: merge_field(key, dict_a[key], dict_b[key]) for key in dict_a.keys()} ) - config_groups = {} - for idx, scheme in enumerate(quant_scheme_to_layers): - group_name = "group_" + str(idx) - config_groups[group_name] = scheme + @classmethod + def from_pretrained( + cls, model: Module, format: Optional[str] = None + ) -> "QuantizationConfig": + default_config = QuantizationConfig(config_groups={}) + config = getattr(model, "quantization_config", default_config) - if format is None: - if quantization_status == QuantizationStatus.COMPRESSED: - format = CompressionFormat.int_quantized.value - else: - format = CompressionFormat.dense.value - elif isinstance(format, list): + # silently override format + if isinstance(format, list): format = ( CompressionFormat.mixed_precision.value if len(format) > 1 else format[0] ) - - return QuantizationConfig( - config_groups=config_groups, - quantization_status=quantization_status, - kv_cache_scheme=kv_cache_scheme, - global_compression_ratio=None, - format=format, - ignore=consolidated_ignore, - ) + if format is None: + format = CompressionFormat.dense.value + config.format = format + return config def requires_calibration_data(self): if self.kv_cache_scheme is not None: diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index a9c8b45a..d908d670 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -14,7 +14,7 @@ import warnings from copy import deepcopy -from typing import List, Optional +from typing import Any, List, Optional from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.quant_args import ( @@ -44,6 +44,7 @@ class QuantizationScheme(BaseModel): :param input_activations: quantization config for layer inputs :param output_activations: quantization config for layer outputs :param format: CompressionFormat for the layer + TODO """ targets: List[str] @@ -51,6 +52,7 @@ class QuantizationScheme(BaseModel): input_activations: Optional[QuantizationArgs] = None output_activations: Optional[QuantizationArgs] = None format: Optional[str] = None + kv_cache_only: Optional[bool] = None @model_validator(mode="after") def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": @@ -91,6 +93,44 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": return model + def merge(self, other: "QuantizationScheme") -> "QuantizationScheme": + def merge_field(field_name: str, value_a: Any, value_b: Any) -> Any: + if field_name == "targets": + return value_a + value_b + + if field_name == "kv_cache_only": + # nones defer to other value + if value_a is None: + return value_b + if value_b is None: + return value_a + + # kv_cache_only=True overrides + return not ((not value_a) or (not value_b)) + + if value_a is not None and value_b is None: + return value_a + + if value_a is None and value_b is not None: + return value_b + + if value_a == value_b: + return value_a + + raise ValueError( + "The following fields have overlapping targets and conflicting values " + f"for {field_name}. Please modify your config to resolve this " + f"ambiguity.\n{self}\n{other}" + ) + + dict_a = self.model_dump() + dict_b = other.model_dump() + + assert dict_a.keys() == dict_b.keys() + return self.model_validate( + {key: merge_field(key, dict_a[key], dict_b[key]) for key in dict_a.keys()} + ) + model_config = ConfigDict(extra="forbid") diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 73d54519..b610c71f 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -39,7 +39,7 @@ "get_torch_bit_depth", "can_quantize", "parse_out_kv_cache_args", - "KV_CACHE_TARGETS", + "ATTN_TARGETS", "is_kv_cache_quant_scheme", "iter_named_leaf_modules", "iter_named_quantizable_modules", @@ -50,9 +50,8 @@ "is_fp4", ] -# target the self_attn layer -# QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale -KV_CACHE_TARGETS = ["re:.*self_attn$"] +# note that this is a "narrow match", see quantization/apply.py +ATTN_TARGETS = ["re:.*self_attn$"] _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -409,17 +408,17 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool: """ Check whether the QuantizationScheme targets the kv cache. It does if all the following criteria are met: - - the scheme targets either exactly match the KV_CACHE_TARGETS - or the match KV_CACHE_TARGETS regex pattern + - the scheme targets either exactly match the ATTN_TARGETS + or the match ATTN_TARGETS regex pattern - the scheme quantizes output_activations (we want to quantize the - outputs from the KV_CACHE_TARGETS, as their correspond to the + outputs from the ATTN_TARGETS, as their correspond to the keys and values that are to be saved in the cache) :param scheme: The QuantizationScheme to investigate :return: boolean flag """ for target in scheme.targets: - if target in KV_CACHE_TARGETS: + if target in ATTN_TARGETS: return True return False diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index a7744709..9843a254 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -18,6 +18,14 @@ import torch import torch.nn.utils.parametrize as P +from compressed_tensors.modeling.attention import ( + initialize_hooked_attention, + register_query_hook, +) +from compressed_tensors.modeling.kvcache import ( + initialize_hooked_kv_cache, + register_key_hook, +) from compressed_tensors.registry.registry import RegistryMixin, T from compressed_tensors.transform import ( TransformArgs, @@ -36,6 +44,7 @@ from compressed_tensors.utils.internal import InternalModule from torch import Tensor from torch.nn import Module, Parameter +from transformers import PreTrainedModel __all__ = ["TransformFactory", "TransformBase"] @@ -84,7 +93,7 @@ def create_transform(self, module: Module, args: TransformArgs) -> "TransformBas """ raise NotImplementedError() - def apply_to_model(self, model: Module): + def apply_to_model(self, model: PreTrainedModel): """ Create transforms and apply them to the model @@ -92,11 +101,13 @@ def apply_to_model(self, model: Module): """ for arg in self.scheme.apply: for _, module in match_named_modules(model, arg.targets, arg.ignore): - self._apply_to_module(module, arg) + self._apply_to_module(module, arg, model) self._update_tied_weights() - def _apply_to_module(self, module: Module, args: TransformArgs): + def _apply_to_module( + self, module: Module, args: TransformArgs, model: PreTrainedModel + ): """ Create transforms and apply them to the module @@ -154,9 +165,25 @@ def output_hook(_, _input, output): module.register_forward_hook(output_hook) + elif args.location == TransformLocation.Q_ATTN: + initialize_hooked_attention(model, module, quantize=False) + + def query_hook(_, query_states): + return transform(query_states) + + register_query_hook(module, query_hook) + # other locations such as q_attn and k_attn have not been implemented + elif args.location == TransformLocation.K_CACHE: + initialize_hooked_kv_cache(model, module, quantize=False) + + def key_hook(_, key_states): + return transform(key_states) + + register_key_hook(module, key_hook) + else: - raise NotImplementedError() + assert False def _update_tied_weights(self): """ diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index 42611967..1e0baa7f 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -22,8 +22,12 @@ apply_transform_weight, get_transform_size, ) -from compressed_tensors.utils import get_execution_device, get_offloaded_device from compressed_tensors.utils.helpers import ParameterizedDefaultDict +from compressed_tensors.utils.offload import ( + get_execution_device, + get_offloaded_device, + has_offloaded_params, +) from torch import Tensor, device, dtype from torch.nn import Module, Parameter @@ -51,12 +55,17 @@ def create_transform(self, module: Module, args: TransformArgs): :param module: parent module that transform will be applied to :param args: defines how the transform will be applied to the module """ - assert hasattr(module, "weight") size = get_transform_size(module, args.location, self.scheme.head_dim) dtype = self.scheme.precision - device = get_offloaded_device(module) exec_device = get_execution_device(module) + # if the parent is offloaded, then weight will be placed in the weights_map + # if the parent is not offloaded, then the weight will stay on the exec device + if has_offloaded_params(module): + device = get_offloaded_device(module) + else: + device = exec_device + factory_kwargs = {"construct_device": exec_device} weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs) perm = self.perms[weight] if self.scheme.randomize else None diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index a7112e76..54bf7dcc 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -53,7 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs): assert hasattr(module, "weight") size = get_transform_size(module, args.location, self.scheme.head_dim) dtype = self.scheme.precision - device = get_offloaded_device(module) + exec_device = get_execution_device(module) + + # if the parent is offloaded, then weight will be placed in the weights_map + # if the parent is not offloaded, then the weight will stay on the exec device + if has_offloaded_params(module): + device = get_offloaded_device(module) + else: + device = exec_device weight = self.weights[size, dtype, device] if args.inverse: diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index d3f46957..75c81649 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -45,6 +45,12 @@ class TransformLocation(str, Enum): K_CACHE = "k_cache" Q_ATTN = "q_attn" + def is_online(self) -> bool: + return self not in ( + TransformLocation.WEIGHT_INPUT, + TransformLocation.WEIGHT_OUTPUT, + ) + class TransformArgs(BaseModel, use_enum_values=True): """ @@ -70,9 +76,6 @@ def wrap_singleton(cls, value): return value def is_online(self) -> bool: - return self.location not in ( - TransformLocation.WEIGHT_INPUT, - TransformLocation.WEIGHT_OUTPUT, - ) + return TransformLocation(self.location).is_online() model_config = ConfigDict(extra="forbid") diff --git a/src/compressed_tensors/transform/utils/matrix.py b/src/compressed_tensors/transform/utils/matrix.py index 92072857..0414e3f6 100644 --- a/src/compressed_tensors/transform/utils/matrix.py +++ b/src/compressed_tensors/transform/utils/matrix.py @@ -34,6 +34,8 @@ def get_transform_size( :param head_dim: size of head when transform is applied to mha :return: size of matrix """ + size = None + if isinstance(module, torch.nn.Linear): if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT): size = module.in_features @@ -44,11 +46,13 @@ def get_transform_size( size = module.num_embeddings else: size = module.embedding_dim - else: - raise NotImplementedError(f"Transforms on {type(module)} are not supported") + elif head_dim is None: + raise NotImplementedError( + f"Transforms on {type(module)} are not supported without head_dim" + ) if head_dim is not None: - if size % head_dim != 0: + if size is not None and size % head_dim != 0: raise ValueError( f"{head_dim} must divide {size} for {type(module)} at {location}" ) @@ -105,11 +109,11 @@ def apply_transform_weight( assert transform_weight.shape[0] == transform_weight.shape[1] - if module_type == torch.nn.Linear: - if location == TransformLocation.INPUT: - return _multihead_matmul(value, transform_weight) + if TransformLocation(location).is_online(): + return _multihead_matmul(value, transform_weight) - elif location == TransformLocation.WEIGHT_INPUT: + if module_type == torch.nn.Linear: + if location == TransformLocation.WEIGHT_INPUT: # equivalent to (transform_weight @ value.T).T return _multihead_matmul(value, transform_weight.T) @@ -117,26 +121,14 @@ def apply_transform_weight( # equivalent to (value.T @ transform_weight).T return _multihead_matmul(transform_weight.T, value) - elif location == TransformLocation.OUTPUT: - return _multihead_matmul(value, transform_weight) - # similar derivation to torch.nn.Linear, but `y = (x W)` elif module_type == torch.nn.Embedding: - if location == TransformLocation.INPUT: - return _multihead_matmul(value, transform_weight) - - elif location == TransformLocation.WEIGHT_INPUT: - return _multihead_matmul( - transform_weight, - value, - ) + if location == TransformLocation.WEIGHT_INPUT: + return _multihead_matmul(transform_weight, value) elif location == TransformLocation.WEIGHT_OUTPUT: return _multihead_matmul(value, transform_weight) - elif location == TransformLocation.OUTPUT: - return _multihead_matmul(value, transform_weight) - raise NotImplementedError( f"Applying transforms to {module_type} {location} is not supported" ) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 11e2a2a1..3d423981 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -30,6 +30,7 @@ "match_targets", "match_modules_set", "is_match", + "is_narrow_match", ] @@ -128,7 +129,6 @@ def match_targets( :param module: the module to match :param targets: the target strings, potentially containing "re:" prefixes :return: the targets that match the given name and module - Outputs are ordered by type: exact name match, regex name match, class name match """ targets = targets or [] @@ -305,3 +305,13 @@ def _match_class(module: torch.nn.Module, target: str) -> bool: ) for cls in module.__class__.__mro__ ) + + +def is_narrow_match(model: torch.nn.Module, targets: Iterable[str], name: str) -> bool: + module = model.get_submodule(name) + parent_name = name.rsplit(".", 1)[0] + parent = model.get_submodule(parent_name) + + return is_match(name, module, targets) and not is_match( + parent_name, parent, targets + ) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index ea7cd01d..82b5084c 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -127,11 +127,16 @@ def get_offloaded_device(module: torch.nn.Module) -> torch.device: :param module: module to check :return: device module is offloaded to onto after forward pass """ - if has_offloaded_params(module): - first_key = list(module._hf_hook.weights_map.keys())[0] - prefix_dataset = module._hf_hook.weights_map.dataset - return prefix_dataset[first_key].device - return next(module.parameters()).device + for submodule in module.modules(): + name, param = next(submodule.named_parameters(recurse=False), (None, None)) + if has_offloaded_params(submodule) and name is not None: + return cast_to_device(submodule._hf_hook.weights_map[name].device) + + if param is not None: + return param.device + + warnings.warn(f"Unable to get offload device of {module}, falling back to CPU") + return torch.device("cpu") @check_accelerate(fallback=None) diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index dc48870b..9680d40e 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -399,7 +399,8 @@ def _get_combined_config(s_config, q_config): ) def test_compress_model(model_stub, q_format, s_config, tmpdir): model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32) - compressor = ModelCompressor.from_pretrained_model(model, s_config, [q_format]) + qformats = None if q_format is None else [q_format] # FUTURE: remove nullability + compressor = ModelCompressor.from_pretrained_model(model, s_config, qformats) # compress model by eagerly compressing state dict true_compressed = dict(compressor.compress(model)) @@ -446,8 +447,9 @@ def test_compress_model_meta(model_stub, q_format, s_config): cpu_model = AutoModelForCausalLM.from_pretrained( model_stub, torch_dtype=torch.float32 ) + qformats = None if q_format is None else [q_format] # FUTURE: remove nullability reference_compressor = ModelCompressor.from_pretrained_model( - cpu_model, s_config, [q_format] + cpu_model, s_config, qformats ) # Only stores dtype because meta model does not store values expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()} @@ -463,7 +465,7 @@ def test_compress_model_meta(model_stub, q_format, s_config): module.to_empty(device="meta") # Compress in-place on meta model - compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, [q_format]) + compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, qformats) compressor.compress_model(meta_model) # Compare keys and dtypes diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index d5fd6c2c..52b301ed 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -57,53 +57,6 @@ def llama_stories_model(): ) -def test_target_prioritization(mock_frozen): - # tests that the config_groups are applied in the correct order - # of priority, where exact layer name > regex > module name - config = { - "quant_method": "compressed-tensors", - "format": "fakequant", - "config_groups": { - "group_1": { - "weights": { - "num_bits": 8, - }, - "targets": ["Linear"], - }, - "group_2": { - "weights": { - "num_bits": 4, - }, - "targets": ["re:.*down_proj"], - }, - "group_3": { - "weights": { - "num_bits": 2, - }, - "targets": ["model.layers.0.mlp.down_proj"], - }, - }, - } - - model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", torch_dtype="auto" - ) - model.eval() - - config = QuantizationConfig(**config) - config.quantization_status = QuantizationStatus.CALIBRATION - apply_quantization_config(model, config) - mock_frozen(model) - - for name, module in model.named_modules(): - if name == "model.layers.0.mlp.down_proj": - assert module.quantization_scheme.weights.num_bits == 2 - elif re.match(".*down_proj", name): - assert module.quantization_scheme.weights.num_bits == 4 - elif isinstance(module, torch.nn.Linear): - assert module.quantization_scheme.weights.num_bits == 8 - - def test_apply_quantization_config_tinyllama(): quant_config = get_sample_tinyllama_quant_config(status="calibration") model = get_tinyllama_model() @@ -174,13 +127,16 @@ def test_serialize_config_tinyllama(): serialized_config = QuantizationConfig.from_pretrained(model) assert len(serialized_config.config_groups) == 2 - assert serialized_config.config_groups["group_0"].targets == ["Embedding"] - assert serialized_config.config_groups["group_0"].input_activations is None assert serialized_config.config_groups["group_1"].targets == ["Linear"] assert serialized_config.config_groups["group_1"].input_activations is not None + assert serialized_config.config_groups["group_2"].targets == ["Embedding"] + assert serialized_config.config_groups["group_2"].input_activations is None assert serialized_config.format == CompressionFormat.dense.value assert serialized_config.quant_method == DEFAULT_QUANTIZATION_METHOD - assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"] + assert serialized_config.ignore == [ + "LlamaRotaryEmbedding", + "model.layers.1.mlp.down_proj", + ] if serialized_config.global_compression_ratio is not None: assert serialized_config.global_compression_ratio > 1.0 assert serialized_config.global_compression_ratio < 8.0 diff --git a/tests/test_quantization/test_quant_config.py b/tests/test_quantization/test_quant_config.py index c3830a02..cbe07183 100644 --- a/tests/test_quantization/test_quant_config.py +++ b/tests/test_quantization/test_quant_config.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +from compressed_tensors import CompressionFormat from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_FORMAT, DEFAULT_QUANTIZATION_METHOD,