@@ -703,9 +703,12 @@ def decompress(self, model_path: str, model: Module):
         with override_quantization_status(
             self.quantization_config, QuantizationStatus.FROZEN
         ):
-            names_to_scheme = apply_quantization_config(
-                model, self.quantization_config
-            )
+            apply_quantization_config(model, self.quantization_config)
+            names_to_scheme: Dict[str, QuantizationScheme] = {
+                name: getattr(module, "quantization_scheme")
+                for name, module in model.named_modules()
+                if getattr(module, "quantization_scheme", None) is not None
+            }
             # Load activation scales/zp or any other quantization parameters
             # Conditionally load the weight quantization parameters if we have a dense compressor
             # Or if a sparsity compressor has already been applied
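The rebuilt mapping simply collects the `quantization_scheme` attribute that `apply_quantization_config` attaches to each targeted submodule. A minimal, self-contained sketch of the same collection pattern (the toy model and placeholder scheme below are illustrative, not from this diff):

```python
import torch

# toy model; pretend apply_quantization_config attached a scheme to the Linears
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8)
)
model[0].quantization_scheme = "placeholder-scheme"
model[2].quantization_scheme = "placeholder-scheme"

# same comprehension as above: module name -> attached scheme, skipping
# modules (the container, the ReLU) that carry no scheme
names_to_scheme = {
    name: getattr(module, "quantization_scheme")
    for name, module in model.named_modules()
    if getattr(module, "quantization_scheme", None) is not None
}
print(names_to_scheme)  # {'0': 'placeholder-scheme', '2': 'placeholder-scheme'}
```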
141 changes: 141 additions & 0 deletions src/compressed_tensors/modeling/attention.py
@@ -0,0 +1,141 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional

import torch
from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    forward_quantize,
)
from compressed_tensors.quantization.lifecycle.initialize import (
    _initialize_scale_zero_point,
)
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch.utils.hooks import RemovableHandle
from transformers import AttentionInterface, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


__all__ = ["IMPL_ATTR", "QuantizedAttentionImpl"]


IMPL_ATTR = "impl"
_original_impl = "eager" # mutable


class QuantizedAttentionImpl(InternalModule):
    """
    Attention implementation wrapper which quantizes the query states of its
    parent attention module before delegating to the original attention
    implementation.
    """

    def __init__(self, attn_module: torch.nn.Module):
        super().__init__()
        self.attn_module_container = [attn_module]  # avoid circular reference
        self._qparams_initialized = False

    def forward(
        self,
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        *args,
        **kwargs,
    ):
        # quantization
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled and self._qparams_initialized:
            query = forward_quantize(module, query, "q", quant_args)

        # original attention
        return ALL_ATTENTION_FUNCTIONS[_original_impl](
            module,
            query,
            key,
            value,
            *args,
            **kwargs,
        )

    def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
        assert module is self.attn_module_container[0]
        scheme: Optional[QuantizationScheme] = getattr(
            module, "quantization_scheme", None
        )
        quant_args: Optional[QuantizationArgs] = getattr(
            scheme, "input_activations", None
        )

        if (
            not self._qparams_initialized
            and quant_args is not None
            and not scheme.kv_cache_only
        ):
            # TODO: use model.config.num_attention_heads to find query_size
            assert quant_args.strategy == QuantizationStrategy.TENSOR
            _initialize_scale_zero_point(module, "q", quant_args)
            self._qparams_initialized = True


# ----- initialize ----- #


def ct_hooked_attention(module: torch.nn.Module, *args, **kwargs):
    if hasattr(module, IMPL_ATTR):
        return module.impl(module, *args, **kwargs)
    else:
        return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs)


def initialize_hooked_attention(
    model: PreTrainedModel, module: torch.nn.Module, quantize: bool = True
):
    if not hasattr(module, IMPL_ATTR):
        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(module))
        if model.config._attn_implementation != "ct_hooked_attention":
            # assumes only one model at a time
            global _original_impl
            _original_impl = model.config._attn_implementation

            AttentionInterface.register("ct_hooked_attention", ct_hooked_attention)
            model.config._attn_implementation = "ct_hooked_attention"

    impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
    if quantize:
        impl.initialize_qparams_once(model, module)

    initialize_hooked_kv_cache(model, module, quantize=quantize)


# ----- hooks ----- #


def register_query_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
    """
    Registers a forward pre-hook on `module.impl` that replaces the `query` argument
    with `hook(impl, query)` (handles both positional and keyword forms).
    """
    impl = getattr(module, IMPL_ATTR)

    def _hook(impl: QuantizedAttentionImpl, args, kwargs):
        bound = inspect.signature(impl.forward).bind(*args, **kwargs)
        bound.arguments["query"] = hook(impl, bound.arguments["query"])

        return bound.args, bound.kwargs

    return impl.register_forward_pre_hook(_hook, with_kwargs=True)
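A hedged usage sketch of the new attention hooks (the checkpoint name, the `model.model.layers` layout, and the hook body are assumptions for illustration, not part of this PR):

```python
from transformers import AutoModelForCausalLM

from compressed_tensors.modeling.attention import (
    initialize_hooked_attention,
    register_query_hook,
)

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # assumed checkpoint

def record_query(impl, query):
    # observe the query states entering attention; return the tensor unchanged
    print("query absmax:", query.abs().max().item())
    return query

# assumes a Llama/Qwen-style decoder layout
for layer in model.model.layers:
    attn = layer.self_attn
    # swaps config._attn_implementation to "ct_hooked_attention" and attaches
    # a QuantizedAttentionImpl (and QuantizedKVCache) to the attention module
    initialize_hooked_attention(model, attn, quantize=False)
    register_query_hook(attn, record_query)

# hooks fire on the next model(...) forward pass
```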
138 changes: 138 additions & 0 deletions src/compressed_tensors/modeling/kvcache.py
@@ -0,0 +1,138 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional, Tuple

import torch
import transformers
from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
from compressed_tensors.quantization.lifecycle.initialize import (
    _initialize_scale_zero_point,
)
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from packaging import version
from torch import Tensor
from torch.utils.hooks import RemovableHandle
from transformers import Cache, PreTrainedModel


__all__ = ["KV_CACHE_ATTR", "QuantizedKVCache"]


KV_CACHE_ATTR = "kv_cache"


class QuantizedKVCache(InternalModule):
    """
    Cache wrapper which quantizes the key and value states of its parent
    attention module before passing them to the original `Cache` instance.
    """

    def __init__(self, attn_module: torch.nn.Module):
        super().__init__()
        self.attn_module_container = [attn_module]  # avoid nn.Module circular reference
        self.past_key_values: Optional[Cache] = None
        self._qparams_initialized = False

    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        return self(*args, **kwargs)

    def forward(
        self,
        key_states: Tensor,
        value_states: Tensor,
        *args,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        # quantization
        module = self.attn_module_container[0]
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled and self._qparams_initialized:
            key_states = forward_quantize(module, key_states, "k", quant_args)
            value_states = forward_quantize(module, value_states, "v", quant_args)

        # original cache
        if self.past_key_values is not None:
            ret = self.past_key_values.update(key_states, value_states, *args, **kwargs)
        else:
            ret = (key_states, value_states)

        self.past_key_values = None
        return ret

    def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
        assert module is self.attn_module_container[0]
        scheme = getattr(module, "quantization_scheme", None)
        quant_args = getattr(scheme, "input_activations", None)

        if not self._qparams_initialized and quant_args is not None:
            # TODO: use model.config.num_key_value_heads to find key_size, value_size
            assert quant_args.strategy == QuantizationStrategy.TENSOR
            _initialize_scale_zero_point(module, "k", quant_args)
            _initialize_scale_zero_point(module, "v", quant_args)
            self._qparams_initialized = True


# ----- initialize ----- #


def initialize_hooked_kv_cache(
    model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False
):
    if not hasattr(module, KV_CACHE_ATTR):
        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module))
        module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True)

    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    if quantize:
        kv_cache.initialize_qparams_once(model, module)


def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs):
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    _past_kv_name = (
        "past_key_value"
        if version.parse(transformers.__version__) <= version.parse("4.55.4")
        else "past_key_values"  # transformers#39956
    )
    kv_cache.past_key_values = kwargs.get(_past_kv_name, None)
    kwargs[_past_kv_name] = kv_cache

    return args, kwargs


# ----- hooks ----- #


def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
    """
    Registers a forward pre-hook on `module.kv_cache` that replaces the
    `key_states` argument with `hook(cache, key_states)`.
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        bound.arguments["key_states"] = hook(cache, bound.arguments["key_states"])

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
    """
    Registers a forward pre-hook on `module.kv_cache` that replaces the
    `value_states` argument with `hook(cache, value_states)`.
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        bound.arguments["value_states"] = hook(cache, bound.arguments["value_states"])

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
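A matching sketch for the KV-cache side, again with assumed hook bodies, checkpoint name, and model layout, running calibration-style observation without quantizing:

```python
from transformers import AutoModelForCausalLM

from compressed_tensors.modeling.kvcache import (
    initialize_hooked_kv_cache,
    register_key_hook,
    register_value_hook,
)

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # assumed checkpoint

def record_key(cache, key_states):
    print("key absmax:", key_states.abs().max().item())
    return key_states

def record_value(cache, value_states):
    print("value absmax:", value_states.abs().max().item())
    return value_states

for layer in model.model.layers:  # assumes a Llama/Qwen-style decoder layout
    attn = layer.self_attn
    # replaces the attention module's `past_key_values` kwarg with a
    # QuantizedKVCache wrapper via a forward pre-hook
    initialize_hooked_kv_cache(model, attn, quantize=False)
    register_key_hook(attn, record_key)
    register_value_hook(attn, record_value)

# hooks fire on the next model(...) forward pass
```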