File tree Expand file tree Collapse file tree 2 files changed +16
-7
lines changed Expand file tree Collapse file tree 2 files changed +16
-7
lines changed Original file line number Diff line number Diff line change 41
41
from collections import defaultdict
42
42
import psutil
43
43
44
+ from clt .training .utils import to_bfloat16
45
+
44
46
try :
45
47
import GPUtil
46
48
except ImportError :
@@ -764,13 +766,14 @@ def _write_chunk(
764
766
765
767
# Convert to numpy
766
768
with self ._conditional_measure (f"chunk_{ chunk_idx } _layer_{ lid } _convert_numpy" ):
767
- inp_np = inp_perm .to (self .torch_dtype ).numpy ()
768
- tgt_np = tgt_perm .to (self .torch_dtype ).numpy ()
769
-
770
- # Handle bfloat16 conversion
771
- if h5py_dtype_str == "uint16" and inp_np .dtype == np .dtype ("bfloat16" ):
772
- inp_np = inp_np .view (np .uint16 )
773
- tgt_np = tgt_np .view (np .uint16 )
769
+ # Handle bfloat16 conversion
770
+ if h5py_dtype_str == "uint16" :
771
+ inp_np = to_bfloat16 (inp_perm .to (torch .float32 ).numpy ())
772
+ tgt_np = to_bfloat16 (tgt_perm .to (torch .float32 ).numpy ())
773
+
774
+ else :
775
+ inp_np = inp_perm .to (self .torch_dtype ).numpy ()
776
+ tgt_np = tgt_perm .to (self .torch_dtype ).numpy ()
774
777
775
778
# Store prepared data
776
779
layer_data [lid ] = (inp_np , tgt_np )
Original file line number Diff line number Diff line change 1
1
import datetime
2
2
3
+ import numpy as np
4
+
3
5
4
6
# Helper function to format elapsed time
5
7
def _format_elapsed_time (seconds : float ) -> str :
@@ -11,3 +13,7 @@ def _format_elapsed_time(seconds: float) -> str:
11
13
return f"{ td .days * 24 + hours :02d} :{ minutes :02d} :{ seconds :02d} "
12
14
else :
13
15
return f"{ minutes :02d} :{ seconds :02d} "
16
+
17
+
18
+ def to_bfloat16 (x : np .ndarray ) -> np .ndarray :
19
+ return np .frombuffer (np .array (x , dtype = np .float32 ).tobytes ()[::2 ], dtype = np .uint16 )
You can’t perform that action at this time.
0 commit comments