@@ -278,15 +278,14 @@ def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
                 # The scale is inverted
                 return data / scale.float()
 
-            def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+            def dequant_simple(weight: Tensor, scale: Tensor, block_size: Sequence[int] | None = None) -> Tensor:
                 scale = scale.float()
 
-                if (weight_block_size := quant_config.get("weight_block_size")):
-                    # TODO: make sure it's a list of integers
-                    for i, size in enumerate(weight_block_size):
+                if block_size is not None:
+                    for i, size in enumerate(block_size):
                         scale = scale.repeat_interleave(size, i)
-                        # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
-                        scale = scale[tuple(slice(0, size) for size in weight.shape)]
+                    # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+                    scale = scale[tuple(slice(0, size) for size in weight.shape)]
 
                 return weight.float() * scale
 
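Note (illustrative sketch, not part of the commit): the block-wise path in dequant_simple expands one scale per block to one scale per element with repeat_interleave, then trims the padding. Assuming a 3x3 weight with 2x2 blocks:

    import torch

    weight = torch.ones(3, 3)                       # stand-in for an fp8 weight tensor
    scale = torch.tensor([[2.0, 3.0], [4.0, 5.0]])  # one scale per 2x2 block (padded to 2 blocks per dim)
    block_size = (2, 2)

    for i, size in enumerate(block_size):
        scale = scale.repeat_interleave(size, i)    # (2, 2) -> (4, 4)
    scale = scale[tuple(slice(0, s) for s in weight.shape)]  # trim padding -> (3, 3)

    print(scale)
    # tensor([[2., 2., 3.],
    #         [2., 2., 3.],
    #         [4., 4., 5.]])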
@@ -333,6 +332,40 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
 
                 return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
 
+            def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: Tensor | None, num_bits: int, group_size: int):
+                assert w.dtype == torch.int32
+                shape = tuple(shape_tensor.tolist())
+                assert len(shape) == 2
+                mask = (1 << num_bits) - 1
+
+                shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)
+                if self.lazy:
+                    shifts = LazyTorchTensor.from_eager(shifts)
+
+                if zero_point is None:
+                    offset = 1 << (num_bits - 1)
+                else:
+                    assert len(zero_point.shape) == 2
+                    offset = (zero_point.unsqueeze(1) >> shifts.reshape(1, -1, 1)) & mask
+                    offset = offset.reshape(-1, zero_point.shape[1])
+                    # trim padding, and prepare for broadcast
+                    # NOTE: the zero-point is packed along dim 0
+                    offset = offset[:shape[0], :].unsqueeze(-1)
+
+                # extract values
+                # NOTE: the weights are packed along dim 1
+                unpacked = (w.unsqueeze(-1) >> shifts.reshape(1, 1, -1)) & mask
+                unpacked = unpacked.reshape(shape[0], -1)
+
+                # trim padding
+                unpacked = unpacked[:, :shape[1]]
+
+                # prepare for broadcast of the scale
+                unpacked = unpacked.reshape(shape[0], (unpacked.shape[-1] + group_size - 1) // group_size, group_size)
+                unpacked = unpacked - offset
+
+                return (unpacked * scale.unsqueeze(-1).float()).reshape(shape)
+
             if quant_method == "bitnet":
                 for name in self.model_tensors.keys():
                     if name.endswith(".weight_scale"):
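Note (standalone sketch, not part of the commit): dequant_packed unpacks num_bits-wide integers from int32 words with a shift-and-mask, least-significant bits first. With an assumed packed word 0x76543210 and num_bits = 4:

    import torch

    num_bits = 4
    mask = (1 << num_bits) - 1                      # 0xF
    packed = torch.tensor([[0x7654_3210]], dtype=torch.int32)  # eight 4-bit values in one word

    shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)  # 0, 4, ..., 28
    unpacked = (packed.unsqueeze(-1) >> shifts.reshape(1, 1, -1)) & mask
    unpacked = unpacked.reshape(1, -1)

    print(unpacked)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)
    # dequant_packed then trims the padding, subtracts the unpacked zero-point
    # (or the implicit offset 1 << (num_bits - 1)), and multiplies by the per-group scale.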
@@ -342,12 +375,13 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
                         self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
                         tensors_to_remove.append(name)
             elif quant_method == "fp8":
+                block_size = quant_config.get("weight_block_size")
                 for name in self.model_tensors.keys():
                     if name.endswith(".weight_scale_inv"):
                         weight_name = name.removesuffix("_scale_inv")
                         w = self.model_tensors[weight_name]
                         s = self.model_tensors[name]
-                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                        self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
                         tensors_to_remove.append(name)
             elif quant_method == "gptq":
                 for name in self.model_tensors.keys():
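Side note on the lambda w=w, s=s, bs=block_size: ... pattern used throughout this dispatch (generic Python behaviour, shown only for context): default arguments are evaluated when the lambda is created, so each deferred dequant call keeps the tensors of its own loop iteration instead of resolving closure variables at call time:

    late = [lambda: name for name in ("a", "b", "c")]
    print([f() for f in late])    # ['c', 'c', 'c'] -- late binding sees the last value

    bound = [lambda name=name: name for name in ("a", "b", "c")]
    print([f() for f in bound])   # ['a', 'b', 'c'] -- defaults capture per-iteration values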
@@ -371,6 +405,49 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
                                 ".scales",
                             )
                         ]
+            elif quant_method == "compressed-tensors":
+                quant_format = quant_config["format"]
+                groups = quant_config["config_groups"]
+                if len(groups) > 1:
+                    raise NotImplementedError("Can't handle multiple config groups for compressed-tensors yet")
+                weight_config = tuple(groups.values())[0]["weights"]
+
+                if quant_format == "float-quantized" or quant_format == "int-quantized" or quant_format == "naive-quantized":
+                    block_size = weight_config.get("block_structure", None)
+                    strategy = weight_config.get("strategy")
+                    assert strategy == "channel" or strategy == "block"
+                    assert weight_config.get("group_size") is None  # didn't find a model using this yet
+                    for name in self.model_tensors.keys():
+                        if name.endswith(".weight_scale"):
+                            weight_name = name.removesuffix("_scale")
+                            w = self.model_tensors[weight_name]
+                            s = self.model_tensors[name]
+                            self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), block_size)
+                            tensors_to_remove.append(name)
+                elif quant_format == "pack-quantized":
+                    assert weight_config.get("strategy") == "group"
+                    assert weight_config.get("type", "int") == "int"
+                    num_bits = weight_config.get("num_bits")
+                    group_size = weight_config.get("group_size")
+                    assert isinstance(num_bits, int)
+                    assert isinstance(group_size, int)
+                    for name in self.model_tensors.keys():
+                        if name.endswith(".weight_packed"):
+                            base_name = name.removesuffix("_packed")
+                            w = self.model_tensors[name]
+                            scale = self.model_tensors[base_name + "_scale"]
+                            shape = self.model_tensors[base_name + "_shape"]
+                            zero_point = self.model_tensors.get(base_name + "_zero_point", lambda: None)
+                            new_tensors[base_name] = (
+                                lambda w=w, scale=scale, shape=shape, zero_point=zero_point: dequant_packed(
+                                    w(), scale(), shape(), zero_point(), num_bits, group_size,
+                                )
+                            )
+                            tensors_to_remove += [base_name + n for n in ("_packed", "_shape", "_scale")]
+                            if (base_name + "_zero_point") in self.model_tensors:
+                                tensors_to_remove.append(base_name + "_zero_point")
+                else:
+                    raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
             else:
                 raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
 
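For reference, a hypothetical quantization_config fragment (field values are illustrative assumptions, not taken from any specific model) showing the keys the compressed-tensors branch reads:

    quant_config = {
        "quant_method": "compressed-tensors",
        "format": "pack-quantized",
        "config_groups": {
            "group_0": {
                "weights": {
                    "type": "int",          # assumed: packed integer weights
                    "num_bits": 4,          # assumed bit width
                    "strategy": "group",
                    "group_size": 128,      # assumed group size
                },
            },
        },
    }

    weight_config = tuple(quant_config["config_groups"].values())[0]["weights"]
    assert weight_config["num_bits"] == 4 and weight_config["group_size"] == 128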