91 changes: 84 additions & 7 deletions convert_hf_to_gguf.py
@@ -278,15 +278,14 @@ def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
# The scale is inverted
return data / scale.float()

def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
def dequant_simple(weight: Tensor, scale: Tensor, block_size: Sequence[int] | None = None) -> Tensor:
scale = scale.float()

if (weight_block_size := quant_config.get("weight_block_size")):
# TODO: make sure it's a list of integers
for i, size in enumerate(weight_block_size):
if block_size is not None:
for i, size in enumerate(block_size):
scale = scale.repeat_interleave(size, i)
# unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
scale = scale[tuple(slice(0, size) for size in weight.shape)]
# unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
scale = scale[tuple(slice(0, size) for size in weight.shape)]

return weight.float() * scale
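
A minimal sketch (not part of the diff, with made-up sizes) of how dequant_simple expands a per-block scale to the full weight shape: repeat_interleave grows the scale along each dimension, and the final slice drops the padding when the weight shape is not a multiple of the block size.

import torch

weight = torch.ones(3, 5)                      # deliberately not a multiple of the 2x2 block size
scale = torch.tensor([[2.0, 3.0, 4.0],
                      [5.0, 6.0, 7.0]])        # one scale per 2x2 block: ceil(3/2) x ceil(5/2)
block_size = [2, 2]

for i, size in enumerate(block_size):
    scale = scale.repeat_interleave(size, i)   # [2, 3] -> [4, 6] after both dims
scale = scale[tuple(slice(0, size) for size in weight.shape)]  # unpad: [4, 6] -> [3, 5]

dequantized = weight.float() * scale           # shape [3, 5]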

@@ -333,6 +332,40 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)

return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T

def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: Tensor | None, num_bits: int, group_size: int):
assert w.dtype == torch.int32
shape = tuple(shape_tensor.tolist())
assert len(shape) == 2
mask = (1 << num_bits) - 1

shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)
if self.lazy:
shifts = LazyTorchTensor.from_eager(shifts)

if zero_point is None:
offset = 1 << (num_bits - 1)
else:
assert len(zero_point.shape) == 2
offset = (zero_point.unsqueeze(1) >> shifts.reshape(1, -1, 1)) & mask
offset = offset.reshape(-1, zero_point.shape[1])
# trim padding, and prepare for broadcast
# NOTE: the zero-point is packed along dim 0
offset = offset[:shape[0], :].unsqueeze(-1)

# extract values
# NOTE: the weights are packed along dim 1
unpacked = (w.unsqueeze(-1) >> shifts.reshape(1, 1, -1)) & mask
unpacked = unpacked.reshape(shape[0], -1)

# trim padding
unpacked = unpacked[:, :shape[1]]

# prepare for broadcast of the scale
unpacked = unpacked.reshape(shape[0], (unpacked.shape[-1] + group_size - 1) // group_size, group_size)
unpacked = unpacked - offset

return (unpacked * scale.unsqueeze(-1).float()).reshape(shape)
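
As a rough illustration (not from the diff), this is the shift-and-mask trick dequant_packed relies on: eight 4-bit values live in one int32, and one right shift per nibble followed by a mask recovers them.

import torch

num_bits = 4
mask = (1 << num_bits) - 1                     # 0xF
vals = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6], dtype=torch.int32)

# pack eight 4-bit values into a single int32 (first value in the lowest bits)
packed = torch.zeros(1, dtype=torch.int32)
for i, v in enumerate(vals):
    packed |= v << (i * num_bits)

# unpack: one shift per 4-bit slot, then mask off the rest
shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)  # 0, 4, ..., 28
unpacked = (packed.unsqueeze(-1) >> shifts) & mask
assert torch.equal(unpacked.flatten(), vals)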

if quant_method == "bitnet":
for name in self.model_tensors.keys():
if name.endswith(".weight_scale"):
@@ -342,12 +375,13 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
tensors_to_remove.append(name)
elif quant_method == "fp8":
block_size = quant_config.get("weight_block_size")
for name in self.model_tensors.keys():
if name.endswith(".weight_scale_inv"):
weight_name = name.removesuffix("_scale_inv")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
tensors_to_remove.append(name)
elif quant_method == "gptq":
for name in self.model_tensors.keys():
@@ -371,6 +405,49 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
".scales",
)
]
elif quant_method == "compressed-tensors":
quant_format = quant_config["format"]
groups = quant_config["config_groups"]
if len(groups) > 1:
raise NotImplementedError("Can't handle multiple config groups for compressed-tensors yet")
weight_config = tuple(groups.values())[0]["weights"]

if quant_format == "float-quantized" or quant_format == "int-quantized" or quant_format == "naive-quantized":
block_size = weight_config.get("block_structure", None)
strategy = weight_config.get("strategy")
assert strategy == "channel" or strategy == "block"
assert weight_config.get("group_size") is None # didn't find a model using this yet
for name in self.model_tensors.keys():
if name.endswith(".weight_scale"):
weight_name = name.removesuffix("_scale")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), block_size)
tensors_to_remove.append(name)
elif quant_format == "pack-quantized":
assert weight_config.get("strategy") == "group"
assert weight_config.get("type", "int") == "int"
num_bits = weight_config.get("num_bits")
group_size = weight_config.get("group_size")
assert isinstance(num_bits, int)
assert isinstance(group_size, int)
for name in self.model_tensors.keys():
if name.endswith(".weight_packed"):
base_name = name.removesuffix("_packed")
w = self.model_tensors[name]
scale = self.model_tensors[base_name + "_scale"]
shape = self.model_tensors[base_name + "_shape"]
zero_point = self.model_tensors.get(base_name + "_zero_point", lambda: None)
new_tensors[base_name] = (
lambda w=w, scale=scale, shape=shape, zero_point=zero_point: dequant_packed(
w(), scale(), shape(), zero_point(), num_bits, group_size,
)
)
tensors_to_remove += [base_name + n for n in ("_packed", "_shape", "_scale")]
if (base_name + "_zero_point") in self.model_tensors:
tensors_to_remove.append(base_name + "_zero_point")
else:
raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
else:
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")

11 changes: 8 additions & 3 deletions gguf-py/gguf/lazy.py
@@ -48,13 +48,18 @@ def wrapped_special_op(self, *args, **kwargs):
# NOTE: doing this from a metaclass is very convenient
# TODO: make this even more comprehensive
for binary_op in (
"lt", "le", "eq", "ne", "ge", "gt", "not"
"abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul",
"neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor",
"lt", "le", "eq", "ne", "ge", "gt",
"add", "and", "floordiv", "lshift", "mod", "mul", "matmul",
"or", "pow", "rshift", "sub", "truediv", "xor",
"iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
"radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
):
attr_name = f"__{binary_op}__"
# evaluation on the meta tensor is needed in case there's broadcasting
namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)

for unary_op in ("not", "abs", "invert", "neg", "pos"):
attr_name = f"__{unary_op}__"
# the result of these operators usually has the same shape and dtype as the input,
# so evaluation on the meta tensor can be skipped.
namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)
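
A small sketch (not from the diff) of why the binary operators are evaluated on the meta tensor while the unary ones are treated as shape-preserving no-ops: broadcasting can change the result shape, whereas a unary op keeps the input's shape and dtype.

import torch

a = torch.empty(4, 1, device="meta")
b = torch.empty(1, 8, device="meta")

print((a + b).shape)  # torch.Size([4, 8]) -- the shape depends on both operands
print((-a).shape)     # torch.Size([4, 1]) -- same shape and dtype as the input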