File tree Expand file tree Collapse file tree 2 files changed +27
-7
lines changed Expand file tree Collapse file tree 2 files changed +27
-7
lines changed Original file line number Diff line number Diff line change
1
+ import os
1
2
import torch
2
3
from functools import lru_cache
3
4
@@ -47,3 +48,15 @@ def simple_bucket(length):
47
48
48
49
def bucket (length ):
49
50
return simple_bucket (length )
51
+
52
+
53
+ def save_hpu_model (model , output_dir ):
54
+ from safetensors .torch import save_file
55
+
56
+ state_dict = model .state_dict ()
57
+ remove_prefix = "_orig_mod."
58
+ clean_state_dict = {
59
+ k [len (remove_prefix ) :] if k .startswith (remove_prefix ) else k : v
60
+ for k , v in state_dict .items ()
61
+ }
62
+ save_file (clean_state_dict , os .path .join (output_dir , "model.safetensors" ))
Original file line number Diff line number Diff line change 50
50
QuantizeDataType ,
51
51
TrainingArgs ,
52
52
)
53
- from instructlab .training .hpu_utils import is_torch_hpu_available , bucket
53
+ from instructlab .training .hpu_utils import (
54
+ is_torch_hpu_available ,
55
+ bucket ,
56
+ save_hpu_model ,
57
+ )
54
58
55
59
logger = logging .getLogger ("instructlab.training" )
56
60
@@ -1033,12 +1037,15 @@ def _get_state_dict_patched(model, unwrap=False):
1033
1037
model .module .unmerge_adapter ()
1034
1038
1035
1039
if not is_lora :
1036
- accelerator .save_model (
1037
- model ,
1038
- save_directory = output_dir ,
1039
- max_shard_size = "5GB" ,
1040
- safe_serialization = True ,
1041
- )
1040
+ if is_torch_hpu_available () and os .getenv ("HPU_ENABLE_TORCH_COMPILE" , False ):
1041
+ save_hpu_model (model , output_dir )
1042
+ else :
1043
+ accelerator .save_model (
1044
+ model ,
1045
+ save_directory = output_dir ,
1046
+ max_shard_size = "5GB" ,
1047
+ safe_serialization = True ,
1048
+ )
1042
1049
1043
1050
if args .use_dolomite and convert_dolomite and accelerator .is_main_process :
1044
1051
# export doesnt like the directory to exist
You can’t perform that action at this time.
0 commit comments