Commit c9f331c

Fix checkpoints for Gaudi
Remove the _orig_mod prefix from checkpoints saved from a torch.compile-trained model.

Signed-off-by: Jianhong-Zhang <[email protected]>
1 parent 0730062
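
Background for this change, with a minimal sketch (illustration only, not part of the commit): torch.compile wraps a model in a compiled wrapper module, so the wrapper's state_dict keys pick up an "_orig_mod." prefix; stripping that prefix restores key names the plain, uncompiled module can load. The model below is a placeholder.

# Minimal sketch (not from this commit): torch.compile wraps the model, and the
# wrapper's state_dict keys gain an "_orig_mod." prefix.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)            # placeholder model
compiled = torch.compile(model)    # compilation itself happens lazily, on first call

print(list(model.state_dict()))     # ['weight', 'bias']
print(list(compiled.state_dict()))  # typically ['_orig_mod.weight', '_orig_mod.bias']

# Stripping the prefix (the same idea as save_hpu_model below) yields keys the
# uncompiled module accepts again.
prefix = "_orig_mod."
clean = {
    k[len(prefix):] if k.startswith(prefix) else k: v
    for k, v in compiled.state_dict().items()
}
model.load_state_dict(clean)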

2 files changed: 27 additions and 7 deletions

src/instructlab/training/hpu_utils.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,4 @@
+import os
 import torch
 from functools import lru_cache

@@ -47,3 +48,15 @@ def simple_bucket(length):

 def bucket(length):
     return simple_bucket(length)
+
+
+def save_hpu_model(model, output_dir):
+    from safetensors.torch import save_file
+
+    state_dict = model.state_dict()
+    remove_prefix = "_orig_mod."
+    clean_state_dict = {
+        k[len(remove_prefix) :] if k.startswith(remove_prefix) else k: v
+        for k, v in state_dict.items()
+    }
+    save_file(clean_state_dict, os.path.join(output_dir, "model.safetensors"))
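
A quick usage sketch (not part of the commit): a checkpoint written by save_hpu_model can be read back with safetensors and loaded into the uncompiled model, since the "_orig_mod." prefix has already been stripped. The path and the model variable below are hypothetical.

# Sketch only: reload a checkpoint written by save_hpu_model.
from safetensors.torch import load_file

state_dict = load_file("output_dir/model.safetensors")  # hypothetical path
model.load_state_dict(state_dict)  # "model" is the original, uncompiled module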

src/instructlab/training/utils.py

Lines changed: 14 additions & 7 deletions
@@ -50,7 +50,11 @@
     QuantizeDataType,
     TrainingArgs,
 )
-from instructlab.training.hpu_utils import is_torch_hpu_available, bucket
+from instructlab.training.hpu_utils import (
+    is_torch_hpu_available,
+    bucket,
+    save_hpu_model,
+)

 logger = logging.getLogger("instructlab.training")

@@ -1033,12 +1037,15 @@ def _get_state_dict_patched(model, unwrap=False):
         model.module.unmerge_adapter()

     if not is_lora:
-        accelerator.save_model(
-            model,
-            save_directory=output_dir,
-            max_shard_size="5GB",
-            safe_serialization=True,
-        )
+        if is_torch_hpu_available() and os.getenv("HPU_ENABLE_TORCH_COMPILE", False):
+            save_hpu_model(model, output_dir)
+        else:
+            accelerator.save_model(
+                model,
+                save_directory=output_dir,
+                max_shard_size="5GB",
+                safe_serialization=True,
+            )

     if args.use_dolomite and convert_dolomite and accelerator.is_main_process:
         # export doesnt like the directory to exist
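
One behavioral note on the new guard (an observation, not part of the commit): os.getenv returns the variable's raw string value when it is set, so any non-empty setting of HPU_ENABLE_TORCH_COMPILE, including "0", makes the condition truthy and selects the save_hpu_model path.

# Sketch of the gate's truthiness: os.getenv yields a string when the variable is set.
import os

os.environ["HPU_ENABLE_TORCH_COMPILE"] = "1"
print(bool(os.getenv("HPU_ENABLE_TORCH_COMPILE", False)))  # True

os.environ["HPU_ENABLE_TORCH_COMPILE"] = "0"
print(bool(os.getenv("HPU_ENABLE_TORCH_COMPILE", False)))  # also True: "0" is a non-empty string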
