 import operator
+import warnings
+
 import torch
 import bitsandbytes.functional as F
 
@@ -184,6 +186,7 @@ class MatmulLtState:
     idx = None
     is_training = True
     has_fp16_weights = True
+    memory_efficient_backward = False
     use_pool = False
     formatB = F.get_special_format_str()
 
@@ -209,31 +212,29 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
             ctx.B = B
             ctx.bias = bias
             if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
             else:
-                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
 
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
         # 4. Mixed-precision decomposition matmul
         # 5. Save state
-        requires_gradA = A.requires_grad
-        requires_gradB = B.requires_grad
-        requires_gradBias = bias is not None and bias.requires_grad
         formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()
-        assert (
-            A.dtype == torch.float16
-        ), f"The input data type needs to be fp16 but {A.dtype} was found!"
+
+        # Cast A to fp16
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
 
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A, threshold=state.threshold
+            A.to(torch.float16), threshold=state.threshold
         )
 
         if state.threshold > 0.0 and coo_tensorA is not None:
@@ -269,7 +270,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
                     state.SCB,
                     state.SCBt,
                     coo_tensorB,
-                ) = F.double_quant(B)
+                ) = F.double_quant(B.to(torch.float16))
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
         else:
             has_grad = False
@@ -290,7 +291,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
                 (outliers * state.SCB.view(-1, 1) / 127.0)
                 .t()
                 .contiguous()
-                .half()
+                .to(A.dtype)
             )
             CA[:, state.idx.long()] = 0
             CAt[:, state.idx.long()] = 0
@@ -307,7 +308,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A.dtype)
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A.dtype).add_(bias)
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
@@ -318,42 +325,43 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
 
         ctx.formatB = formatB
         ctx.grad_shape = input_shape
-        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
 
-        if requires_gradA or requires_gradB:
+        if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (CAt, subA)
             ctx.tensor_states = (SCAt, state.idx)
         else:
             ctx.tensors = [None, None]
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
 
+
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
-        #clone_func = torch.clone
         return clone_func(output.view(output_shape))
 
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, req_gradB, req_gradBias = ctx.req_grads
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert (
-            state.has_fp16_weights
-        ), "Backprop only supported for fp16 weights."
+        grad_A = grad_B = grad_bias = None
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
 
+        # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(
+            grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()
 
-        grad_A = grad_B = grad_bias = None
-
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -363,16 +371,20 @@ def backward(ctx, grad_output):
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
         if req_gradA:
-            C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(
-                    state.CBt, to_order=formatB, transpose=True
-                )
-            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+            if state.CBt is not None:
+                C32grad, Sgrad = F.transform(Cgrad, "col32")
+                if state.CxBt is None:
+                    state.CxBt, state.SBt = F.transform(
+                        state.CBt, to_order=formatB, transpose=True
+                    )
+                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
 
-        if req_gradBias:
-            grad_bias = grad_output.sum(0)
+            elif state.CB is not None:
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            else:
+                raise Exception('State must contain either CBt or CB matrix for backward')
 
 
         return grad_A, grad_B, None, grad_bias, None
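Note on the new `elif state.CB is not None` backward branch: when the transposed quantized weight `state.CBt` is not available and only the row-major int8 matrix `state.CB` with its per-row scales `state.SCB` is kept, grad_A is obtained by dequantizing the weights on the fly and running an ordinary matmul against grad_output. Below is a minimal standalone sketch of that dequantization step, assuming symmetric row-wise int8 quantization with absmax scales; the helper name is illustrative and not part of bitsandbytes.

import torch

def dequant_rowwise(CB, SCB, dtype=torch.float16):
    # CB: int8 weight of shape (out_features, in_features), quantized row-wise
    # SCB: per-row absmax scales of shape (out_features,)
    # Undoes x_q = round(127 * x / absmax(row)), i.e. x ~= x_q * absmax(row) / 127
    return CB.to(dtype) * (SCB.unsqueeze(1) / 127.0)

# Mirrors the added branch: an ordinary matmul against the dequantized weight,
# with the result cast back to the dtype of the forward input A:
# grad_A = torch.matmul(grad_output, dequant_rowwise(state.CB, state.SCB, ctx.dtype_A))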