Skip to content

Commit fdee725

Browse files
committed
bugfix
1 parent 7948dae commit fdee725

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

clt/activation_generation/generator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from collections import defaultdict
4242
import psutil
4343

44-
from clt.training.utils import to_bfloat16
44+
from clt.training.utils import torch_bfloat16_to_numpy_uint16
4545

4646
try:
4747
import GPUtil
@@ -768,9 +768,8 @@ def _write_chunk(
768768
with self._conditional_measure(f"chunk_{chunk_idx}_layer_{lid}_convert_numpy"):
769769
# Handle bfloat16 conversion
770770
if h5py_dtype_str == "uint16":
771-
inp_np = to_bfloat16(inp_perm.to(torch.float32).numpy())
772-
tgt_np = to_bfloat16(tgt_perm.to(torch.float32).numpy())
773-
771+
inp_np = torch_bfloat16_to_numpy_uint16(inp_perm)
772+
tgt_np = torch_bfloat16_to_numpy_uint16(tgt_perm)
774773
else:
775774
inp_np = inp_perm.to(self.torch_dtype).numpy()
776775
tgt_np = tgt_perm.to(self.torch_dtype).numpy()

clt/training/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
22

33
import numpy as np
4+
import torch
45

56

67
# Helper function to format elapsed time
@@ -15,5 +16,5 @@ def _format_elapsed_time(seconds: float) -> str:
1516
return f"{minutes:02d}:{seconds:02d}"
1617

1718

18-
def to_bfloat16(x: np.ndarray) -> np.ndarray:
19-
return np.frombuffer(np.array(x, dtype=np.float32).tobytes()[::2], dtype=np.uint16)
19+
def torch_bfloat16_to_numpy_uint16(x: torch.Tensor) -> np.ndarray:
20+
return np.frombuffer(x.float().numpy().tobytes(), dtype=np.uint16)[1::2].reshape(x.shape)

0 commit comments

Comments
 (0)