
Commit 439f2b0

Merge pull request #33 from dbaranchuk/memory-efficient-backward
Memory efficient backward
2 parents (9b5f2ed + 76ce9aa), commit 439f2b0

File tree (4 files changed: +100 / -52 lines)

    bitsandbytes/autograd/_functions.py
    bitsandbytes/nn/modules.py
    tests/test_autograd.py
    tests/test_modules.py
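For context, here is a minimal usage sketch of the flag this PR introduces, based only on the signatures and tests visible in the diffs below (it assumes a CUDA build of bitsandbytes; the layer sizes, threshold, and dtypes are illustrative):

    import torch
    import bitsandbytes as bnb

    # Frozen int8 layer that still lets gradients flow back to its inputs:
    # memory_efficient_backward=True keeps the row-major int8 weight around
    # so MatMul8bitLt.backward can compute grad_A without fp16 weights.
    layer = bnb.nn.Linear8bitLt(
        32, 64, has_fp16_weights=False, memory_efficient_backward=True, threshold=6.0
    ).half().cuda()

    x = torch.randn(16, 32, device="cuda", dtype=torch.half, requires_grad=True)
    out = layer(x)        # int8 matmul; output is returned in x.dtype
    out.sum().backward()  # input gradient is rebuilt from the quantized weight
    print(x.grad.shape)   # torch.Size([16, 32])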

bitsandbytes/autograd/_functions.py

Lines changed: 44 additions & 32 deletions
@@ -1,4 +1,6 @@
 import operator
+import warnings
+
 import torch
 import bitsandbytes.functional as F
 
@@ -184,6 +186,7 @@ class MatmulLtState:
     idx = None
     is_training = True
     has_fp16_weights = True
+    memory_efficient_backward = False
     use_pool = False
     formatB = F.get_special_format_str()
 
@@ -209,31 +212,29 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
             ctx.B = B
             ctx.bias = bias
             if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
             else:
-                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
 
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
         # 4. Mixed-precision decomposition matmul
         # 5. Save state
-        requires_gradA = A.requires_grad
-        requires_gradB = B.requires_grad
-        requires_gradBias = bias is not None and bias.requires_grad
         formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()
-        assert (
-            A.dtype == torch.float16
-        ), f"The input data type needs to be fp16 but {A.dtype} was found!"
+
+        # Cast A to fp16
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
 
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A, threshold=state.threshold
+            A.to(torch.float16), threshold=state.threshold
         )
 
         if state.threshold > 0.0 and coo_tensorA is not None:
@@ -269,7 +270,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
                     state.SCB,
                     state.SCBt,
                     coo_tensorB,
-                ) = F.double_quant(B)
+                ) = F.double_quant(B.to(torch.float16))
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
         else:
             has_grad = False
@@ -290,7 +291,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
                     (outliers * state.SCB.view(-1, 1) / 127.0)
                     .t()
                     .contiguous()
-                    .half()
+                    .to(A.dtype)
                 )
                 CA[:, state.idx.long()] = 0
                 CAt[:, state.idx.long()] = 0
@@ -307,7 +308,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A.dtype)
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A.dtype).add_(bias)
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
@@ -318,42 +325,43 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
 
         ctx.formatB = formatB
         ctx.grad_shape = input_shape
-        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
 
-        if requires_gradA or requires_gradB:
+        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (CAt, subA)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)
 
+
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
-        #clone_func = torch.clone
         return clone_func(output.view(output_shape))
 
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, req_gradB, req_gradBias = ctx.req_grads
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert (
-            state.has_fp16_weights
-        ), "Backprop only supported for fp16 weights."
+        grad_A = grad_B = grad_bias = None
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
 
+        # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(
+            grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()
 
-        grad_A = grad_B = grad_bias = None
-
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -363,16 +371,20 @@ def backward(ctx, grad_output):
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
         if req_gradA:
-            C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(
-                    state.CBt, to_order=formatB, transpose=True
-                )
-            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+            if state.CBt is not None:
+                C32grad, Sgrad = F.transform(Cgrad, "col32")
+                if state.CxBt is None:
+                    state.CxBt, state.SBt = F.transform(
+                        state.CBt, to_order=formatB, transpose=True
+                    )
+                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
 
-        if req_gradBias:
-            grad_bias = grad_output.sum(0)
+            elif state.CB is not None:
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            else:
+                raise Exception('State must contain either CBt or CB matrix for backward')
 
         return grad_A, grad_B, None, grad_bias, None
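The key change for memory efficiency is the new fallback branch in backward: when the transposed buffers (CBt/CxBt) are not kept, grad_A is rebuilt from the row-major int8 weight CB and its per-row scales SCB. A standalone sketch of that idea (the function name and signature are hypothetical, not part of the library):

    import torch

    def grad_wrt_input_from_int8(grad_output, CB, SCB, grad_shape, out_dtype):
        # Dequantize the row-major int8 weight on the fly: SCB holds per-row
        # absmax scales, so W ~= CB * SCB / 127 (the same formula as the elif
        # branch above), then use a plain matmul for the input gradient.
        W = CB.to(out_dtype) * (SCB.unsqueeze(1).to(out_dtype) / 127.0)  # [out_features, in_features]
        return torch.matmul(grad_output.to(out_dtype), W).view(grad_shape)

This trades a small dequantization cost on each backward pass for not having to store an extra transposed, Turing/Ampere-format copy of the weight.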

bitsandbytes/nn/modules.py

Lines changed: 15 additions & 6 deletions
@@ -221,6 +221,7 @@ def __init__(
         output_features,
         bias=True,
         has_fp16_weights=True,
+        memory_efficient_backward=False,
         threshold=0.0,
         index=None,
     ):
@@ -232,10 +233,13 @@
 
         self.state.threshold = threshold
         self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
 
-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
+        )
 
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
@@ -255,11 +259,16 @@ def forward(self, x):
 
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
 
-        if not self.state.has_fp16_weights and self.state.CB is not None:
-            # we converted 8-bit row major to turing/ampere format in the first inference pass
-            # we no longer need the row-major weight
-            del self.state.CB
-            self.weight.data = self.state.CxB
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB
 
         return out
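A short sketch (not part of the diff) contrasting what the two branches above keep in self.state after the first forward pass with has_fp16_weights=False; the comments restate the diff's own reasoning, and the shapes are illustrative:

    import torch
    import bitsandbytes as bnb

    x = torch.randn(4, 32, device="cuda", dtype=torch.half)

    # Inference-only mode: after the first pass the row-major CB is deleted and
    # the weight is replaced by the Turing/Ampere-format CxB, which is reused.
    inference_layer = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).half().cuda()
    _ = inference_layer(x)

    # Memory-efficient-backward mode: CxB is deleted after each pass instead, and
    # the row-major CB is kept so backward can rebuild the input gradient from it.
    trainable_input_layer = bnb.nn.Linear8bitLt(
        32, 64, has_fp16_weights=False, memory_efficient_backward=True
    ).half().cuda()
    _ = trainable_input_layer(x)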

tests/test_autograd.py

Lines changed: 6 additions & 3 deletions
@@ -253,7 +253,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
 
 transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
-dtype = [torch.float16]
+dtype = [torch.float16, torch.bfloat16, torch.float32]
 has_fp16_weights = [True, False]
 has_bias = [True, False]
 values = list(
@@ -354,7 +354,7 @@ def test_matmullt(
                state.SCB,
                SCBt,
                coo_tensorB,
-            ) = bnb.functional.double_quant(B2)
+            ) = bnb.functional.double_quant(B2.to(torch.float16))
            B2 = state.CB
 
        if not transpose[0] and transpose[1]:
@@ -367,11 +367,14 @@
        if has_bias:
            out_torch += bias
 
+        assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
        n = out_bnb.numel()
        err = torch.abs(out_bnb - out_torch).mean().item()
        # print(f'abs error {err:.4f}')
+
        idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-        assert (idx == 0).sum().item() <= n * 0.0175
+        assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
        idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
        assert (idx == 0).sum().item() <= n * 0.001

tests/test_modules.py

Lines changed: 35 additions & 11 deletions
@@ -14,13 +14,15 @@ def __init__(self, initial_data):
 
 
 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super(MLP8bit, self).__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
 
     def forward(self, x):
@@ -451,9 +453,12 @@ def test_linear8bitlt_accumulated_gradient():
 
 
 @pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        bnb.nn.Linear8bitLt(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .cuda()
         .half()
     )
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8
 
     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .half()
         .to("cuda")
     )
@@ -531,11 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
 
-    mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
-        .to(torch.float16)
-        .to("cuda")
-    )
+    mlp = MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
+    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
+    mlp = mlp.cuda().half()  # and this line triggers quantization
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -545,11 +552,28 @@ def test_linear8bitlt_no_fp16_weights(threshold):
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
+
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"
 
+    if memory_efficient_backward:
+        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        assert o1.requires_grad
+        grad_proj = torch.randn_like(o1)
+
+        mlp.zero_grad()
+        (o1 * grad_proj).sum().backward()
+        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
+        scale = grad_ref.abs().mean()
+
+        torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
+        assert (idx == 0).sum().item() <= b1.numel() * 0.005
+
 
 def test_linear8bitlt_fp32_bias():
     # casts model to fp16 -> int8 automatically
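The new assertions compare against grad_ref = grad_proj @ w2 @ w1, i.e. the chain rule for the input gradient of the two-layer MLP (biases do not affect it). A quick fp32 sanity check of that expression in plain PyTorch, independent of bitsandbytes:

    import torch

    torch.manual_seed(0)
    x = torch.randn(16, 8, 32, requires_grad=True)
    w1 = torch.randn(64, 32)            # fc1 weight, stored as [out_features, in_features]
    w2 = torch.randn(32, 64)            # fc2 weight
    grad_proj = torch.randn(16, 8, 32)  # random projection, as in the test

    o = (x @ w1.t()) @ w2.t()           # forward through both linear layers, no bias
    (o * grad_proj).sum().backward()
    grad_ref = grad_proj @ w2 @ w1      # d/dx sum(o * grad_proj) = grad_proj @ w2 @ w1
    assert torch.allclose(x.grad, grad_ref, rtol=1e-3, atol=1e-3)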
