Commit 90b0ac5

Fixed missing bias in bnb.matmul_4bit for inference; more tests.
1 parent dc96e9e commit 90b0ac5

File tree (4 files changed: +46 -12 lines)

bitsandbytes/autograd/_functions.py
bitsandbytes/functional.py
tests/test_functional.py
tests/test_generation.py


bitsandbytes/autograd/_functions.py

Lines changed: 4 additions & 1 deletion
@@ -571,6 +571,9 @@ def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bia
             warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            return F.gemv_4bit(A, B.t(), out, state=quant_state)
+            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            if bias is not None:
+                out += bias
+            return out
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
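
The change above routes the bias through the fast single-token inference path: previously the gemv_4bit branch returned its result directly, so any bias passed to matmul_4bit was silently dropped. A minimal sketch of the caller-visible effect, assuming illustrative shapes and a hidden dimension that is a multiple of the 4-bit blocksize so the gemv path is taken:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

dim = 1024  # illustrative; a multiple of the quantization blocksize
A = torch.randn(1, 1, dim, dtype=torch.float16, device='cuda')  # single token -> inference kernel
W = torch.randn(dim, dim, dtype=torch.float16, device='cuda')
bias = torch.randn(dim, dtype=torch.float16, device='cuda')

qW, state = F.quantize_4bit(W, quant_type='nf4')

# With this commit the bias is added before returning; before the fix, `out`
# matched torch.matmul(A, W.t()) with the bias term missing.
out = bnb.matmul_4bit(A, qW.t(), state, bias=bias)
ref = torch.matmul(A, W.t()) + bias  # out should approximate ref up to 4-bit quantization error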

bitsandbytes/functional.py

Lines changed: 0 additions & 2 deletions
@@ -1512,8 +1512,6 @@ def gemv_4bit(
 
     return out
 
-
-
 def igemm(
     A: Tensor,
     B: Tensor,

tests/test_functional.py

Lines changed: 29 additions & 1 deletion
@@ -2364,7 +2364,7 @@ def test_normal_map_tree():
 @pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
 def test_gemv_4bit(dtype, storage_type, double_quant, kind):
-    for dim in [128, 256, 512, 1024, 2048, 4096, 6144]:
+    for dim in [128, 256, 512, 1024]:
     #for dim in [4*1024]:
     #for dim in [1*128]:
         errs1 = []

@@ -2525,3 +2525,31 @@ def test_managed():
     # assert (A==17).sum().item() == n*n
 
     # torch.testing.assert_close(A, torch.ones(A.shape)*289)
+
+
+@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
+@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
+def test_gemv_eye_4bit(storage_type, dtype, double_quant):
+    dims = 10
+    torch.random.manual_seed(np.random.randint(0, 412424242))
+    dims = torch.randint(0, 8192, size=(dims,)).tolist()
+    dims = [dim + (64-(dim % 64)) for dim in dims]
+    #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
+    for dim in dims:
+        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda')
+        B = torch.eye(dim, dtype=dtype, device='cuda')
+
+        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+        C3 = torch.matmul(A, B.t())
+        C2 = bnb.matmul_4bit(A, qB.t(), state)
+        A.requires_grad = True
+        C1 = bnb.matmul_4bit(A, qB.t(), state)
+
+        torch.testing.assert_close(A, C3)
+        torch.testing.assert_close(A, C1)
+        torch.testing.assert_close(A, C2)
+        #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
+        #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
+
+
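
The new test_gemv_eye_4bit relies on the identity matrix surviving 4-bit quantization essentially losslessly (each block holds only zeros and at most one 1.0, which the NF4/FP4 codebooks with absmax scaling should represent exactly), so A @ quantize(I).T must reproduce A and any mismatch points at kernel indexing or alignment rather than quantization noise. A stripped-down sketch of the same check outside pytest, with an illustrative dimension and dtype:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

dim = 1024  # illustrative; the test rounds random sizes up to a multiple of 64
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=torch.float16, device='cuda')
B = torch.eye(dim, dtype=torch.float16, device='cuda')

# Quantizing the identity should round-trip exactly, so the product must equal A.
qB, state = F.quantize_4bit(B, quant_type='nf4')
out = bnb.matmul_4bit(A, qB.t(), state)  # A @ I^T == A
torch.testing.assert_close(out, A)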

tests/test_generation.py

Lines changed: 13 additions & 8 deletions
@@ -65,22 +65,25 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_f
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
-dtypes = ['nf4', 'fp4', '16bit']
+dtypes = ['nf4', 'fp4']
 load_in_4bit = [True, False]
 values = list(product(models, dtypes))
 strfunc = lambda lst: [str(x) for x in lst]
 ids = ['_'.join(strfunc(x)) for x in values]
 @pytest.fixture(scope='session', params=values, ids=ids)
 def model_and_tokenizer(request):
     model, tokenizer = get_model_and_tokenizer(request.param)
-    yield model, tokenizer
+    yield request.param, model, tokenizer
     del model
 
+@pytest.mark.parametrize("DQ", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False'])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_pi(model_and_tokenizer, dtype, inference_kernel):
+#@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
+def test_pi(model_and_tokenizer, inference_kernel, DQ):
+    print('')
+    dtype = torch.float16
 
-    model, tokenizer = model_and_tokenizer
+    fixture_config, model, tokenizer = model_and_tokenizer
 
     generation_config = transformers.GenerationConfig(
         max_new_tokens=20,

@@ -94,16 +97,16 @@ def test_pi(model_and_tokenizer, dtype, inference_kernel):
     #text = 'Please write down the first 50 digits of pi.'
     #text = get_prompt_for_generation_eval(text)
     #text += ' Sure, here the first 50 digits of pi: 3.14159'
-    n_cases = 3
+    n_cases = 6
     text = '3.14159'
     if hasattr(model.config, 'quantization_config'):
         model.config.quantization_config.bnb_4bit_compute_dtype = dtype
+        model.config.quantization_config.bnb_4bit_use_double_quant = DQ
 
     if not inference_kernel:
         text = [text]*n_cases
     inputs = tokenizer(text, return_tensors="pt").to('cuda:0')
     x = inputs['input_ids']
-    failure_count = 0
     outputs = []
     if inference_kernel:
         for i in range(n_cases):

@@ -116,10 +119,12 @@ def test_pi(model_and_tokenizer, dtype, inference_kernel):
 
 
     assert len(outputs) == n_cases
+    failure_count = 0
     for i in range(n_cases):
         if not outputs[i][:len(str(math.pi))] == str(math.pi):
             failure_count += 1
-    if failure_count > 1:
+    failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4)
+    if failure_count > failure_max:
         print(math.pi)
         for out in outputs:
             print(out)
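
The fixture now yields its parametrization alongside the model and tokenizer, so test_pi can apply a per-model failure budget (2 mismatches for huggyllama/llama-7b, 4 otherwise) and toggle double quantization on the loaded model's quantization_config. For context, a hedged sketch of how the same two fields are normally set when loading a model in 4-bit with transformers; the exact kwargs here are illustrative and not part of the commit:

import torch
import transformers

# bnb_4bit_compute_dtype and bnb_4bit_use_double_quant are the two fields the
# test mutates on an already-loaded model's quantization_config.
quant_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
model = transformers.AutoModelForCausalLM.from_pretrained(
    'huggyllama/llama-7b', quantization_config=quant_config, device_map='auto'
)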
