Commit 49e1d90

quantization done, need calibration
Signed-off-by: Kyle Sayers <[email protected]>
1 parent a7700c5 commit 49e1d90

File tree

3 files changed: +47 −8 lines changed

  examples/transform/spinquant_example.py
  src/llmcompressor/modifiers/transform/spinquant/base.py
  src/llmcompressor/modifiers/transform/spinquant/mappings.py

examples/transform/spinquant_example.py
Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@
 # * apply spinquant transforms to model to reduce quantization loss
 # * quantize the weights to 4 bit with group size 128
 recipe = [
-    SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
+    SpinQuantModifier(rotations=["R3"], transform_type="hadamard"),
     QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
 ]

@@ -35,6 +35,6 @@
 print("==========================================\n\n")

 # Save to disk compressed.
-SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR1R2-w4a16"
+SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR3-w4a16"
 model.save_pretrained(SAVE_DIR, save_compressed=True)
 tokenizer.save_pretrained(SAVE_DIR)
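
As context for the two changed lines, here is a minimal sketch of the full flow this example script typically follows. The MODEL_ID value and the exact import paths for oneshot and the modifiers are assumptions based on the usual llm-compressor example layout, not part of this commit:

# Sketch of the surrounding example flow (assumed, not shown in this diff).
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot  # assumed import path
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier  # assumed path

MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"  # hypothetical model choice
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Apply the R3 rotation, then quantize Linear weights to 4 bits (W4A16).
recipe = [
    SpinQuantModifier(rotations=["R3"], transform_type="hadamard"),
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]
oneshot(model=model, recipe=recipe)

SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR3-w4a16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)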

src/llmcompressor/modifiers/transform/spinquant/base.py
Lines changed: 43 additions & 6 deletions

@@ -128,7 +128,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             config_groups["R2"] = self._create_r2_scheme(state.model)

         if SpinquantRotation.R3 in self.rotations:
-            config_groups["R3"] = self._create_r3_scheme()
+            config_groups["R3"] = self._create_r3_scheme(state.model)

         if SpinquantRotation.R4 in self.rotations:
             config_groups["R4"] = self._create_r4_scheme()

@@ -235,12 +235,49 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )

-    def _create_r3_scheme(self) -> TransformScheme:
-        raise NotImplementedError(
-            "SpinQuant R3 and R4 rotations will be added in a future release"
+    def _create_r3_scheme(self, model: PreTrainedModel) -> TransformScheme:
+        config = model.config
+
+        if hasattr(config, "head_dim"):
+            head_dim = config.head_dim
+        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
+            head_dim = config.hidden_size // config.num_attention_heads
+        else:
+            raise NotImplementedError()
+
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="q_attn",
+                ),
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="k_cache",
+                ),
+            ],
         )

     def _create_r4_scheme(self) -> TransformScheme:
-        raise NotImplementedError(
-            "SpinQuant R3 and R4 rotations will be added in a future release"
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
+            apply=[
+                TransformArgs(
+                    targets=[*self.mappings.mlp_out],
+                    location="input",
+                ),
+                TransformArgs(
+                    targets=[*self.mappings.mlp_out],
+                    location="weight_input",
+                    inverse=True,
+                ),
+            ],
         )
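
The R3 rotation is applied per attention head, so _create_r3_scheme needs the head dimension from the model config. A standalone restatement of the fallback logic above, where the resolve_head_dim name and the LlamaConfig example values are illustrative only:

# Sketch of the head_dim derivation used by _create_r3_scheme.
from transformers import LlamaConfig

def resolve_head_dim(config) -> int:
    # Newer configs expose head_dim directly.
    if hasattr(config, "head_dim"):
        return config.head_dim
    # Otherwise derive it from the hidden size and the number of heads.
    if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
        return config.hidden_size // config.num_attention_heads
    raise NotImplementedError("cannot infer head_dim for this architecture")

config = LlamaConfig(hidden_size=4096, num_attention_heads=32)
print(resolve_head_dim(config))  # 128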
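
The R4 scheme pairs a rotation at the mlp_out inputs with its inverse folded into the weight's input side, so the pair cancels exactly in full precision and only reshapes the value distribution that the quantizer sees. A toy torch sketch of that identity, my own illustration with a random orthogonal matrix standing in for the Hadamard transform:

import torch

torch.manual_seed(0)
d = 8  # toy hidden dimension

# Random orthogonal matrix standing in for a Hadamard rotation.
R, _ = torch.linalg.qr(torch.randn(d, d))

W = torch.randn(4, d)  # toy down_proj weight (out_features x in_features)
x = torch.randn(d)     # toy activation entering down_proj

# R4: rotate the activation at the layer input ("input"), and fold the
# inverse rotation into the weight's input side ("weight_input", inverse=True).
y_rotated = (W @ R.T) @ (R @ x)  # W R^{-1} R x, with R^{-1} = R^T (orthogonal)
y_plain = W @ x

print(torch.allclose(y_rotated, y_plain, atol=1e-6))  # True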

src/llmcompressor/modifiers/transform/spinquant/mappings.py
Lines changed: 2 additions & 0 deletions

@@ -29,6 +29,7 @@ class SpinQuantMapping(BaseModel):

     embedding: str

+    attn: str
     attn_q: str
     attn_k: str
     attn_v: str

@@ -50,6 +51,7 @@ def cast_to_list(cls, value):

 _default_mappings = SpinQuantMapping(
     embedding="re:.*embed_tokens$",
+    attn="re:.*self_attn$",
     attn_q="re:.*q_proj$",
     attn_k="re:.*k_proj$",
     attn_v="re:.*v_proj$",
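
The new attn mapping, like the others, is a "re:"-prefixed regular expression matched against module names. A quick standalone sketch of how the pattern would resolve on Llama-style module names; the matching code here is illustrative, not the library's actual resolution logic:

import re

# The "re:" prefix marks the value as a regex over module names.
attn_pattern = "re:.*self_attn$"

module_names = [
    "model.layers.0.self_attn",         # matched: R3 targets the attention module
    "model.layers.0.self_attn.q_proj",  # not matched: ends in q_proj
    "model.layers.0.mlp.down_proj",     # not matched
]

regex = re.compile(attn_pattern.removeprefix("re:"))
for name in module_names:
    print(name, bool(regex.match(name)))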
