diff --git a/marlin/__init__.py b/marlin/__init__.py index b5b7758..c81bc32 100644 --- a/marlin/__init__.py +++ b/marlin/__init__.py @@ -102,7 +102,7 @@ def forward(self, A): def pack(self, linear, scales): """Pack a fake-quantized linear layer into this actual Marlin representation. @linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`) - @scales: corresponding quantization scales of shape `(infeatures, groups)` + @scales: corresponding quantization scales of shape `(outfeatures, groups)` """ if linear.weight.dtype != torch.half: raise ValueError('Only `torch.half` weights are supported.')