
Commit 7948dae

fix bfloat16 as it's not supported by numpy directly
1 parent: 8696f03

File tree

clt/activation_generation/generator.py
clt/training/utils.py

2 files changed: +16 -7 lines changed


clt/activation_generation/generator.py

Lines changed: 10 additions & 7 deletions
```diff
@@ -41,6 +41,8 @@
 from collections import defaultdict
 import psutil
 
+from clt.training.utils import to_bfloat16
+
 try:
     import GPUtil
 except ImportError:
@@ -764,13 +766,14 @@ def _write_chunk(
 
             # Convert to numpy
             with self._conditional_measure(f"chunk_{chunk_idx}_layer_{lid}_convert_numpy"):
-                inp_np = inp_perm.to(self.torch_dtype).numpy()
-                tgt_np = tgt_perm.to(self.torch_dtype).numpy()
-
-                # Handle bfloat16 conversion
-                if h5py_dtype_str == "uint16" and inp_np.dtype == np.dtype("bfloat16"):
-                    inp_np = inp_np.view(np.uint16)
-                    tgt_np = tgt_np.view(np.uint16)
+                # Handle bfloat16 conversion
+                if h5py_dtype_str == "uint16":
+                    inp_np = to_bfloat16(inp_perm.to(torch.float32).numpy())
+                    tgt_np = to_bfloat16(tgt_perm.to(torch.float32).numpy())
+
+                else:
+                    inp_np = inp_perm.to(self.torch_dtype).numpy()
+                    tgt_np = tgt_perm.to(self.torch_dtype).numpy()
 
             # Store prepared data
             layer_data[lid] = (inp_np, tgt_np)
```
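NumPy has no native bfloat16 dtype, so the chunk ends up in HDF5 as raw 16-bit patterns stored under a uint16 dtype. A minimal sketch of that storage convention, and of how a reader could widen the bits back to float32, assuming PyTorch is available to do the bfloat16 rounding (the helper names below are illustrative, not functions from this repo):

```python
import numpy as np
import torch

def encode_bf16_bits(x: np.ndarray) -> np.ndarray:
    # Cast to bfloat16 in torch (NumPy cannot), then reinterpret the
    # 16-bit patterns as uint16 so h5py can store them.
    t = torch.from_numpy(np.ascontiguousarray(x, dtype=np.float32)).to(torch.bfloat16)
    return t.view(torch.int16).numpy().view(np.uint16)

def decode_bf16_bits(bits: np.ndarray) -> np.ndarray:
    # bfloat16 shares float32's sign/exponent layout, so shifting the
    # stored 16 bits into the top half of a uint32 and reinterpreting
    # recovers the value at bfloat16 precision.
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, -2.5, 3.14159], dtype=np.float32)
roundtrip = decode_bf16_bits(encode_bf16_bits(x))
print(roundtrip)  # approximately [1.0, -2.5, 3.140625]
```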

clt/training/utils.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -1,5 +1,7 @@
 import datetime
 
+import numpy as np
+
 
 # Helper function to format elapsed time
 def _format_elapsed_time(seconds: float) -> str:
@@ -11,3 +13,7 @@ def _format_elapsed_time(seconds: float) -> str:
         return f"{td.days * 24 + hours:02d}:{minutes:02d}:{seconds:02d}"
     else:
         return f"{minutes:02d}:{seconds:02d}"
+
+
+def to_bfloat16(x: np.ndarray) -> np.ndarray:
+    return np.frombuffer(np.array(x, dtype=np.float32).tobytes()[::2], dtype=np.uint16)
```
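The new helper slices every other byte out of the float32 buffer and reinterprets the result as uint16; note that np.frombuffer returns a flat 1-D array, so any shape information is discarded. For reference, bfloat16 is defined as the high 16 bits of a float32 (truncation simply drops the low 16 mantissa bits); below is a shape-preserving sketch of that truncation written as an explicit bit shift, with a made-up function name and not part of this commit:

```python
import numpy as np

def float32_to_bf16_bits(x: np.ndarray) -> np.ndarray:
    # bfloat16 keeps the sign, 8 exponent bits, and the top 7 mantissa
    # bits of a float32, i.e. the high 16 bits of its 32-bit pattern.
    # Shifting integer values (rather than slicing bytes) is independent
    # of the machine's byte order and preserves the input shape.
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (bits >> 16).astype(np.uint16)

a = np.arange(6, dtype=np.float32).reshape(2, 3)
print(float32_to_bf16_bits(a).shape)  # (2, 3)
```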
