File tree Expand file tree Collapse file tree 2 files changed +16
-7
lines changed Expand file tree Collapse file tree 2 files changed +16
-7
lines changed Original file line number Diff line number Diff line change 41
41
from collections import defaultdict
42
42
import psutil
43
43
44
+ from clt .training .utils import torch_bfloat16_to_numpy_uint16
45
+
44
46
try :
45
47
import GPUtil
46
48
except ImportError :
@@ -764,13 +766,13 @@ def _write_chunk(
764
766
765
767
# Convert to numpy
766
768
with self ._conditional_measure (f"chunk_{ chunk_idx } _layer_{ lid } _convert_numpy" ):
767
- inp_np = inp_perm . to ( self . torch_dtype ). numpy ()
768
- tgt_np = tgt_perm . to ( self . torch_dtype ). numpy ()
769
-
770
- # Handle bfloat16 conversion
771
- if h5py_dtype_str == "uint16" and inp_np . dtype == np . dtype ( "bfloat16" ) :
772
- inp_np = inp_np . view ( np . uint16 )
773
- tgt_np = tgt_np . view ( np . uint16 )
769
+ # Handle bfloat16 conversion
770
+ if h5py_dtype_str == "uint16" :
771
+ inp_np = torch_bfloat16_to_numpy_uint16 ( inp_perm )
772
+ tgt_np = torch_bfloat16_to_numpy_uint16 ( tgt_perm )
773
+ else :
774
+ inp_np = inp_perm . to ( self . torch_dtype ). numpy ( )
775
+ tgt_np = tgt_perm . to ( self . torch_dtype ). numpy ( )
774
776
775
777
# Store prepared data
776
778
layer_data [lid ] = (inp_np , tgt_np )
Original file line number Diff line number Diff line change 1
1
import datetime
2
2
3
+ import numpy as np
4
+ import torch
5
+
3
6
4
7
# Helper function to format elapsed time
5
8
def _format_elapsed_time (seconds : float ) -> str :
@@ -11,3 +14,7 @@ def _format_elapsed_time(seconds: float) -> str:
11
14
return f"{ td .days * 24 + hours :02d} :{ minutes :02d} :{ seconds :02d} "
12
15
else :
13
16
return f"{ minutes :02d} :{ seconds :02d} "
17
+
18
+
19
+ def torch_bfloat16_to_numpy_uint16 (x : torch .Tensor ) -> np .ndarray :
20
+ return np .frombuffer (x .float ().numpy ().tobytes (), dtype = np .uint16 )[1 ::2 ].reshape (x .shape )
You can’t perform that action at this time.
0 commit comments