File tree Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Original file line number Diff line number Diff line change 41
41
from collections import defaultdict
42
42
import psutil
43
43
44
- from clt .training .utils import to_bfloat16
44
+ from clt .training .utils import torch_bfloat16_to_numpy_uint16
45
45
46
46
try :
47
47
import GPUtil
@@ -768,9 +768,8 @@ def _write_chunk(
768
768
with self ._conditional_measure (f"chunk_{ chunk_idx } _layer_{ lid } _convert_numpy" ):
769
769
# Handle bfloat16 conversion
770
770
if h5py_dtype_str == "uint16" :
771
- inp_np = to_bfloat16 (inp_perm .to (torch .float32 ).numpy ())
772
- tgt_np = to_bfloat16 (tgt_perm .to (torch .float32 ).numpy ())
773
-
771
+ inp_np = torch_bfloat16_to_numpy_uint16 (inp_perm )
772
+ tgt_np = torch_bfloat16_to_numpy_uint16 (tgt_perm )
774
773
else :
775
774
inp_np = inp_perm .to (self .torch_dtype ).numpy ()
776
775
tgt_np = tgt_perm .to (self .torch_dtype ).numpy ()
Original file line number Diff line number Diff line change 1
1
import datetime
2
2
3
3
import numpy as np
4
+ import torch
4
5
5
6
6
7
# Helper function to format elapsed time
@@ -15,5 +16,5 @@ def _format_elapsed_time(seconds: float) -> str:
15
16
return f"{ minutes :02d} :{ seconds :02d} "
16
17
17
18
18
- def to_bfloat16 (x : np . ndarray ) -> np .ndarray :
19
- return np .frombuffer (np . array ( x , dtype = np . float32 ).tobytes ()[:: 2 ] , dtype = np .uint16 )
19
+ def torch_bfloat16_to_numpy_uint16 (x : torch . Tensor ) -> np .ndarray :
20
+ return np .frombuffer (x . float (). numpy ( ).tobytes (), dtype = np .uint16 )[ 1 :: 2 ]. reshape ( x . shape )
You can’t perform that action at this time.
0 commit comments