Skip to content

Commit 3623275

Browse files
authored
Merge pull request #30 from ilyalasy/bfloat_fix
Bfloat fix
2 parents 8696f03 + fdee725 commit 3623275

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

clt/activation_generation/generator.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from collections import defaultdict
4242
import psutil
4343

44+
from clt.training.utils import torch_bfloat16_to_numpy_uint16
45+
4446
try:
4547
import GPUtil
4648
except ImportError:
@@ -764,13 +766,13 @@ def _write_chunk(
764766

765767
# Convert to numpy
766768
with self._conditional_measure(f"chunk_{chunk_idx}_layer_{lid}_convert_numpy"):
767-
inp_np = inp_perm.to(self.torch_dtype).numpy()
768-
tgt_np = tgt_perm.to(self.torch_dtype).numpy()
769-
770-
# Handle bfloat16 conversion
771-
if h5py_dtype_str == "uint16" and inp_np.dtype == np.dtype("bfloat16"):
772-
inp_np = inp_np.view(np.uint16)
773-
tgt_np = tgt_np.view(np.uint16)
769+
# Handle bfloat16 conversion
770+
if h5py_dtype_str == "uint16":
771+
inp_np = torch_bfloat16_to_numpy_uint16(inp_perm)
772+
tgt_np = torch_bfloat16_to_numpy_uint16(tgt_perm)
773+
else:
774+
inp_np = inp_perm.to(self.torch_dtype).numpy()
775+
tgt_np = tgt_perm.to(self.torch_dtype).numpy()
774776

775777
# Store prepared data
776778
layer_data[lid] = (inp_np, tgt_np)

clt/training/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import datetime
22

3+
import numpy as np
4+
import torch
5+
36

47
# Helper function to format elapsed time
58
def _format_elapsed_time(seconds: float) -> str:
@@ -11,3 +14,7 @@ def _format_elapsed_time(seconds: float) -> str:
1114
return f"{td.days * 24 + hours:02d}:{minutes:02d}:{seconds:02d}"
1215
else:
1316
return f"{minutes:02d}:{seconds:02d}"
17+
18+
19+
def torch_bfloat16_to_numpy_uint16(x: torch.Tensor) -> np.ndarray:
20+
return np.frombuffer(x.float().numpy().tobytes(), dtype=np.uint16)[1::2].reshape(x.shape)

0 commit comments

Comments
 (0)