Commit 1c07c0c

convert : handle compressed-tensors quant method (#17069)
* convert : handle compressed-tensors quant method
* convert : handle int-quantized models
* convert : handle naive-quantized models
* gguf-py : __pos__ is also unary
* convert : fix flake8 lint
* convert : use F32 for dequant of pack-quantized tensors
1 parent cb1adf8 commit 1c07c0c

File tree

convert_hf_to_gguf.py
gguf-py/gguf/lazy.py

2 files changed (+92, -10 lines)


convert_hf_to_gguf.py

Lines changed: 84 additions & 7 deletions
@@ -278,15 +278,14 @@ def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
             # The scale is inverted
             return data / scale.float()

-        def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+        def dequant_simple(weight: Tensor, scale: Tensor, block_size: Sequence[int] | None = None) -> Tensor:
             scale = scale.float()

-            if (weight_block_size := quant_config.get("weight_block_size")):
-                # TODO: make sure it's a list of integers
-                for i, size in enumerate(weight_block_size):
+            if block_size is not None:
+                for i, size in enumerate(block_size):
                     scale = scale.repeat_interleave(size, i)
-                    # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
-                    scale = scale[tuple(slice(0, size) for size in weight.shape)]
+                # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+                scale = scale[tuple(slice(0, size) for size in weight.shape)]

             return weight.float() * scale
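
For readers unfamiliar with the block-wise scale layout, here is a minimal standalone sketch (not part of the commit) of what the new dequant_simple does when a block_size is given; the 2x2 block size and tensor shapes are made up for illustration:

import torch

weight = torch.randn(5, 6)       # already-unpacked quantized weights, as float
scale = torch.rand(3, 3)         # one scale per 2x2 block, padded to cover 5x6
block_size = (2, 2)              # hypothetical "weight_block_size" / "block_structure"

s = scale.float()
for i, size in enumerate(block_size):
    s = s.repeat_interleave(size, i)                    # (3, 3) -> (6, 3) -> (6, 6)
# unpad, since 5 is not a multiple of the block size
s = s[tuple(slice(0, size) for size in weight.shape)]   # (6, 6) -> (5, 6)

dequantized = weight.float() * s                        # torch.Size([5, 6])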

@@ -333,6 +332,40 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)

             return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T

+        def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: Tensor | None, num_bits: int, group_size: int):
+            assert w.dtype == torch.int32
+            shape = tuple(shape_tensor.tolist())
+            assert len(shape) == 2
+            mask = (1 << num_bits) - 1
+
+            shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)
+            if self.lazy:
+                shifts = LazyTorchTensor.from_eager(shifts)
+
+            if zero_point is None:
+                offset = 1 << (num_bits - 1)
+            else:
+                assert len(zero_point.shape) == 2
+                offset = (zero_point.unsqueeze(1) >> shifts.reshape(1, -1, 1)) & mask
+                offset = offset.reshape(-1, zero_point.shape[1])
+                # trim padding, and prepare for broadcast
+                # NOTE: the zero-point is packed along dim 0
+                offset = offset[:shape[0], :].unsqueeze(-1)
+
+            # extract values
+            # NOTE: the weights are packed along dim 1
+            unpacked = (w.unsqueeze(-1) >> shifts.reshape(1, 1, -1)) & mask
+            unpacked = unpacked.reshape(shape[0], -1)
+
+            # trim padding
+            unpacked = unpacked[:, :shape[1]]
+
+            # prepare for broadcast of the scale
+            unpacked = unpacked.reshape(shape[0], (unpacked.shape[-1] + group_size - 1) // group_size, group_size)
+            unpacked = unpacked - offset
+
+            return (unpacked * scale.unsqueeze(-1).float()).reshape(shape)
+
         if quant_method == "bitnet":
             for name in self.model_tensors.keys():
                 if name.endswith(".weight_scale"):
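
The shift-and-mask unpacking above can be hard to visualize. The following self-contained sketch (not from the commit) packs eight 4-bit values into one int32 word and recovers them the same way dequant_packed does; the sample values and the per-group scale are invented for illustration:

import torch

num_bits = 4
mask = (1 << num_bits) - 1
shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)  # 0, 4, ..., 28

values = [3, 1, 4, 1, 5, 2, 6, 5]          # sample 4-bit values (0..15)
word = 0
for i, v in enumerate(values):
    word |= v << (i * num_bits)            # pack along one 32-bit word
packed = torch.tensor([[word]], dtype=torch.int32)    # shape (1, 1): one row, one packed column

# unpack: broadcast the shifts over a new trailing dim, then mask off each nibble
unpacked = (packed.unsqueeze(-1) >> shifts.reshape(1, 1, -1)) & mask
unpacked = unpacked.reshape(1, -1)         # (1, 8)
assert unpacked.flatten().tolist() == values

# symmetric case (no zero_point): shift by half the range, then apply a per-group scale
offset = 1 << (num_bits - 1)
scale = torch.tensor([[0.05]])             # made-up scale for a single group of size 8
dequant = (unpacked.reshape(1, 1, 8) - offset) * scale.unsqueeze(-1)
print(dequant.reshape(1, 8))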
@@ -342,12 +375,13 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
                     self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
                     tensors_to_remove.append(name)
         elif quant_method == "fp8":
+            block_size = quant_config.get("weight_block_size")
             for name in self.model_tensors.keys():
                 if name.endswith(".weight_scale_inv"):
                     weight_name = name.removesuffix("_scale_inv")
                     w = self.model_tensors[weight_name]
                     s = self.model_tensors[name]
-                    self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                    self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
                     tensors_to_remove.append(name)
         elif quant_method == "gptq":
             for name in self.model_tensors.keys():
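
As in the existing bitnet and gptq branches, the fp8 lambda binds its inputs through default arguments (w=w, s=s, bs=block_size). A small illustration (not from the commit) of why that matters inside a loop: defaults are evaluated once when the lambda is created, not when it is later called.

fns_late, fns_bound = [], []
for name in ("a", "b", "c"):
    fns_late.append(lambda: name)        # closes over the loop variable: all see "c"
    fns_bound.append(lambda n=name: n)   # default argument captures the current value

print([f() for f in fns_late])    # ['c', 'c', 'c']
print([f() for f in fns_bound])   # ['a', 'b', 'c']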
@@ -371,6 +405,49 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
                             ".scales",
                         )
                     ]
+        elif quant_method == "compressed-tensors":
+            quant_format = quant_config["format"]
+            groups = quant_config["config_groups"]
+            if len(groups) > 1:
+                raise NotImplementedError("Can't handle multiple config groups for compressed-tensors yet")
+            weight_config = tuple(groups.values())[0]["weights"]
+
+            if quant_format == "float-quantized" or quant_format == "int-quantized" or quant_format == "naive-quantized":
+                block_size = weight_config.get("block_structure", None)
+                strategy = weight_config.get("strategy")
+                assert strategy == "channel" or strategy == "block"
+                assert weight_config.get("group_size") is None  # didn't find a model using this yet
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale"):
+                        weight_name = name.removesuffix("_scale")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), block_size)
+                        tensors_to_remove.append(name)
+            elif quant_format == "pack-quantized":
+                assert weight_config.get("strategy") == "group"
+                assert weight_config.get("type", "int") == "int"
+                num_bits = weight_config.get("num_bits")
+                group_size = weight_config.get("group_size")
+                assert isinstance(num_bits, int)
+                assert isinstance(group_size, int)
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_packed"):
+                        base_name = name.removesuffix("_packed")
+                        w = self.model_tensors[name]
+                        scale = self.model_tensors[base_name + "_scale"]
+                        shape = self.model_tensors[base_name + "_shape"]
+                        zero_point = self.model_tensors.get(base_name + "_zero_point", lambda: None)
+                        new_tensors[base_name] = (
+                            lambda w=w, scale=scale, shape=shape, zero_point=zero_point: dequant_packed(
+                                w(), scale(), shape(), zero_point(), num_bits, group_size,
+                            )
+                        )
+                        tensors_to_remove += [base_name + n for n in ("_packed", "_shape", "_scale")]
+                        if (base_name + "_zero_point") in self.model_tensors:
+                            tensors_to_remove.append(base_name + "_zero_point")
+            else:
+                raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
         else:
             raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
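
For orientation, a hypothetical quantization_config for a pack-quantized checkpoint, showing only the keys the branch above actually reads (format, config_groups, and the per-group weights entry with type, strategy, num_bits and group_size); the concrete values are illustrative, not taken from a real model:

quant_config = {
    "quant_method": "compressed-tensors",
    "format": "pack-quantized",
    "config_groups": {
        "group_0": {
            "weights": {
                "type": "int",
                "strategy": "group",
                "num_bits": 4,
                "group_size": 128,
            },
        },
    },
}

# the same lookups the converter performs
groups = quant_config["config_groups"]
weight_config = tuple(groups.values())[0]["weights"]
assert weight_config.get("strategy") == "group"
assert isinstance(weight_config.get("num_bits"), int)
assert isinstance(weight_config.get("group_size"), int)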

gguf-py/gguf/lazy.py

Lines changed: 8 additions & 3 deletions
@@ -48,13 +48,18 @@ def wrapped_special_op(self, *args, **kwargs):
         # NOTE: doing this from a metaclass is very convenient
         # TODO: make this even more comprehensive
         for binary_op in (
-            "lt", "le", "eq", "ne", "ge", "gt", "not"
-            "abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul",
-            "neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor",
+            "lt", "le", "eq", "ne", "ge", "gt",
+            "add", "and", "floordiv", "lshift", "mod", "mul", "matmul",
+            "or", "pow", "rshift", "sub", "truediv", "xor",
             "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
             "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
         ):
             attr_name = f"__{binary_op}__"
+            # evaluation on the meta tensor is needed in case there's broadcasting
+            namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)
+
+        for unary_op in ("not", "abs", "invert", "neg", "pos"):
+            attr_name = f"__{unary_op}__"
             # the result of these operators usually has the same shape and dtype as the input,
             # so evaluation on the meta tensor can be skipped.
             namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)
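
A standalone illustration (not part of the change) of the distinction the split encodes: binary operators can broadcast, so their result shape has to be computed on the meta device, while unary operators preserve shape and dtype and can skip that step.

import torch

a = torch.empty((4, 1), device="meta")
b = torch.empty((1, 5), device="meta")

print((a + b).shape)   # torch.Size([4, 5]) -- broadcasting changed the shape
print((-a).shape)      # torch.Size([4, 1]) -- a unary op keeps the input shape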
