
Commit 6e6c6eb

Merge pull request #35 from aws/release-1.4.0
Sagemaker Hyperpod Recipes Release 1.4.0
2 parents: 8e4e29c + d77a6f0


6 files changed: +372 -1 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ Amazon SageMaker HyperPod recipes include built-in support for:
 - Automated distributed checkpointing
 - Distributed optimizer
 - Accelerators: NVIDIA H100 (ml.p5), NVIDIA A100 (ml.p4), and AWS Trainium (ml.trn1)
-- Fine-tuning: Full, QLoRA, LoRA
+- Fine-tuning: Full, QLoRA, LoRA, DPO
 - AWS Instances: ml.p5.48xlarge, ml.p4d.24xlarge, and ml.trn1.32xlarge instance families
 - Supported Models: DeepSeek R1, DeepSeek R1 Distill Llama, DeepSeek R1 Distill Qwen, Llama, Mistral, Mixtral models
 - Model Evaluation: [Tensorboard](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.tensorboard.html#module-lightning.pytorch.loggers.tensorboard), [MLflow](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.mlflow.html), [Wandb](https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.WandbLogger.html) - feel free to add any key word arguments to the Logger classes by using their associated kwargs config

launcher/nemo/stages.py

Lines changed: 10 additions & 0 deletions
@@ -299,6 +299,9 @@ def _make_launch_docker_container_text(self):
         if OmegaConf.select(self.cfg, "recipes.model.model_type", default=None) == "deepseek_r1":
             transformers_upgrade_cmd = "pip install transformers==4.48.2"
             post_launch_commands.append(transformers_upgrade_cmd)
+        if OmegaConf.select(self.cfg, "recipes.model.model_type", default=None) == "llama_v4":
+            transformers_upgrade_cmd = "pip install transformers==4.51.1"
+            post_launch_commands.append(transformers_upgrade_cmd)
 
         launch_docker_container_text.append(f' "{image}" sleep infinity')
         launch_docker_container_text.append("")

@@ -421,6 +424,10 @@ def _make_train_script_text(self, stage_cfg_path=None, port=41000) -> str:
             transformers_upgrade_cmd = "pip install transformers==4.48.2"
             script_text.append("")
             script_text.append(transformers_upgrade_cmd)
+        if OmegaConf.select(self.cfg, "recipes.model.model_type", default=None) == "llama_v4":
+            transformers_upgrade_cmd = "pip install transformers==4.51.1"
+            script_text.append("")
+            script_text.append(transformers_upgrade_cmd)
 
         script_text.append("")
         script_text.append(self._make_custom_call_string(stage_cfg_path))

@@ -757,6 +764,9 @@ def update_stage_specific_k8s_values(self, values_template):
         if OmegaConf.select(self.cfg, "recipes.model.model_type", default=False) == "deepseek_r1":
             transformers_upgrade_cmd = "pip install transformers==4.48.2"
             values_template.trainingConfig.pre_script.append(transformers_upgrade_cmd)
+        if OmegaConf.select(self.cfg, "recipes.model.model_type", default=None) == "llama_v4":
+            transformers_upgrade_cmd = "pip install transformers==4.51.1"
+            values_template.trainingConfig.pre_script.append(transformers_upgrade_cmd)
 
         return values_template
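The same model_type check now appears three times in stages.py, once per launch path (Docker launch, train script, Kubernetes values). Below is a minimal sketch of how the version pins could be consolidated into a single lookup; the mapping and helper name are hypothetical, not code from this commit, and the versions are taken from the diff above.

from typing import Optional

from omegaconf import OmegaConf

# Hypothetical mapping; versions copied from the diff above.
TRANSFORMERS_PIN_BY_MODEL_TYPE = {
    "deepseek_r1": "4.48.2",
    "llama_v4": "4.51.1",
}


def transformers_upgrade_cmd(cfg) -> Optional[str]:
    """Return the pip pin command for the recipe's model_type, or None if no pin applies."""
    model_type = OmegaConf.select(cfg, "recipes.model.model_type", default=None)
    version = TRANSFORMERS_PIN_BY_MODEL_TYPE.get(model_type)
    return f"pip install transformers=={version}" if version else None


# Each call site (docker launch, train script, k8s values) would then do:
# cmd = transformers_upgrade_cmd(self.cfg)
# if cmd:
#     post_launch_commands.append(cmd)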

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
#!/bin/bash

# Original Copyright (c), NVIDIA CORPORATION. Modifications © Amazon.com

# Users should set up their cluster type in /recipes_collection/config.yaml

SAGEMAKER_TRAINING_LAUNCHER_DIR=${SAGEMAKER_TRAINING_LAUNCHER_DIR:-"$(pwd)"}

HF_MODEL_NAME_OR_PATH="${HF_MODEL_NAME_OR_PATH}" # HuggingFace pretrained model name or path
HF_ACCESS_TOKEN="${HF_ACCESS_TOKEN}" # Optional HuggingFace access token

TRAIN_DIR="${TRAIN_DIR}" # Location of training dataset
VAL_DIR="${VAL_DIR}" # Location of validation dataset

EXP_DIR="${EXP_DIR}" # Location to save experiment info including logging, checkpoints, etc.

HYDRA_FULL_ERROR=1 python3 "${SAGEMAKER_TRAINING_LAUNCHER_DIR}/main.py" \
    recipes=fine-tuning/llama/hf_llama3_8b_seq8k_gpu_dpo \
    base_results_dir="${SAGEMAKER_TRAINING_LAUNCHER_DIR}/results" \
    recipes.run.name="hf-llama3-8b-dpo" \
    recipes.exp_manager.exp_dir="$EXP_DIR" \
    recipes.trainer.num_nodes=1 \
    recipes.model.train_batch_size=2 \
    recipes.model.data.train_dir="$TRAIN_DIR" \
    recipes.model.data.val_dir="$VAL_DIR" \
    recipes.model.hf_model_name_or_path="$HF_MODEL_NAME_OR_PATH" \
    recipes.model.hf_access_token="$HF_ACCESS_TOKEN" \
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
#!/bin/bash

# Original Copyright (c), NVIDIA CORPORATION. Modifications © Amazon.com

# Users should set up their cluster type in /recipes_collection/config.yaml

SAGEMAKER_TRAINING_LAUNCHER_DIR=${SAGEMAKER_TRAINING_LAUNCHER_DIR:-"$(pwd)"}

HF_MODEL_NAME_OR_PATH="${HF_MODEL_NAME_OR_PATH}" # HuggingFace pretrained model name or path
HF_ACCESS_TOKEN="${HF_ACCESS_TOKEN}" # Optional HuggingFace access token

TRAIN_DIR="${TRAIN_DIR}" # Location of training dataset
VAL_DIR="${VAL_DIR}" # Location of validation dataset

EXP_DIR="${EXP_DIR}" # Location to save experiment info including logging, checkpoints, etc.

HYDRA_FULL_ERROR=1 python3 "${SAGEMAKER_TRAINING_LAUNCHER_DIR}/main.py" \
    recipes=fine-tuning/llama/hf_llama4_17b_16e_seq8k_gpu_lora_text_to_text \
    base_results_dir="${SAGEMAKER_TRAINING_LAUNCHER_DIR}/results" \
    recipes.run.name="hf-llama-4-17b-16e-lora" \
    recipes.exp_manager.exp_dir="$EXP_DIR" \
    recipes.trainer.num_nodes=1 \
    recipes.model.train_batch_size=1 \
    recipes.model.data.train_dir="$TRAIN_DIR" \
    recipes.model.data.val_dir="$VAL_DIR" \
    recipes.model.hf_model_name_or_path="$HF_MODEL_NAME_OR_PATH" \
    recipes.model.hf_access_token="$HF_ACCESS_TOKEN" \
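Both new launcher scripts follow the same pattern: export a handful of environment variables, then call main.py with Hydra overrides that select the recipe. The sketch below shows one way to drive the DPO launcher from Python; the script path, model id, and data locations are placeholders for illustration, not values from this commit.

import os
import subprocess

# Placeholder values; substitute your own model id, paths, and token.
env = dict(os.environ)
env.update({
    "HF_MODEL_NAME_OR_PATH": "meta-llama/Meta-Llama-3-8B",  # example HF model id
    "HF_ACCESS_TOKEN": "<your-hf-token>",                   # optional
    "TRAIN_DIR": "/fsx/data/dpo/train",
    "VAL_DIR": "/fsx/data/dpo/val",
    "EXP_DIR": "/fsx/experiments/llama3-8b-dpo",
})

# Hypothetical path to the DPO launcher script added in this commit;
# use the actual location in the repository.
subprocess.run(
    ["bash", "launcher_scripts/llama/run_hf_llama3_8b_seq8k_gpu_dpo.sh"],
    env=env,
    check=True,
)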
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
# Original Copyright (c), NVIDIA CORPORATION. Modifications © Amazon.com

# Basic run information configs
run:
  name: llama-8b
  results_dir: ${base_results_dir}/${.name}
  time_limit: "6-00:00:00"
  model_type: hf # huggingface for our recipes

# Basic pytorch lightning trainer config
trainer:
  devices: 8
  num_nodes: 1
  accelerator: gpu
  precision: bf16
  max_steps: 50
  log_every_n_steps: 1
  val_check_interval: 1
  limit_val_batches: 0 # Number of batches per each validation run, set to 0 to disable validation.

# Basic pytorch lightning experiment config
# Config for checkpoint/tensorboard etc
exp_manager:
  exp_dir: null
  name: experiment
  # experiment loggers
  create_tensorboard_logger: False
  summary_writer_kwargs: {"save_dir" : "${recipes.exp_manager.exp_dir}/tensorboard"}
  create_mlflow_logger: False
  mlflow_logger_kwargs: {"tracking_uri" : "${recipes.exp_manager.exp_dir}/mlflow"}
  create_wandb_logger: False
  wandb_logger_kwargs: {"save_dir" : "${recipes.exp_manager.exp_dir}"} # wandb creates a wandb folder by default
  create_checkpoint_callback: True
  # Configs to save checkpoint with a fixed interval
  # Note: These configs will not work with auto checkpoint mode
  checkpoint_callback_params:
    # Set save_top_k = 0 to disable sharded checkpointing
    save_top_k: 0
    every_n_train_steps: 10
    monitor: "step"
    mode: "max"
    save_last: False
  checkpoint_dir: ${recipes.exp_manager.exp_dir}/checkpoints/
  resume_from_checkpoint: null
  # Enable auto_checkpoint to automatically calculate the checkpoint interval and resume from checkpoint
  auto_checkpoint:
    enabled: False
  export_full_model:
    # Set every_n_train_steps = 0 to disable full checkpointing
    every_n_train_steps: 0
    save_last: True

################# Predefined configs ##########################
use_smp_model: True # Enable sagemaker model parallelism
distributed_backend: nccl

# Model training configs
model:
  model_type: llama_v3
  # Base configs
  train_batch_size: 2
  val_batch_size: 1
  seed: 12345
  grad_clip: 1.0
  log_reduced_training_loss: True

  # Memory saving / distributed training configs
  tensor_model_parallel_degree: 1
  expert_model_parallel_degree: 1
  context_parallel_degree: 1
  moe: False
  activation_checkpointing: True
  activation_loading_horizon: 2
  delayed_param: True
  offload_activations: False

  # FSDP Configs
  sharding_strategy: hybrid_shard
  forward_prefetch: True
  shard_degree: 8
  backward_fetch_policy: backward_pre
  auto_wrap_policy: transformer_auto_wrap_policy
  limit_all_gathers: true
  use_orig_param: False

  # FP8 config
  fp8: False

  # Model architecture
  max_context_width: 8192
  max_position_embeddings: ${.max_context_width}
  num_hidden_layers: 32
  hidden_size: 4096
  num_attention_heads: 32
  intermediate_size: 14336
  initializer_range: 0.02
  layernorm_epsilon: 1e-5
  vocab_size: 128256
  num_key_value_heads: 8
  use_flash_attention: True
  rope_theta: 500000.0

  # rope scaling for llama3
  rope_scaling:
    rope_type: llama3
    factor: 8.0
    high_freq_factor: 4.0
    low_freq_factor: 1.0
    original_max_position_embeddings: 8192

  # Finetuning config
  do_finetune: True
  # The path to resume from, needs to be HF compatible
  hf_model_name_or_path: null
  hf_access_token: null
  # PEFT config
  peft:
    peft_type: null # lora
  # DPO config
  dpo:
    enabled: True
    beta: 0.1
    label_smoothing: 0.0

  precision: ${recipes.trainer.precision}
  ################# End of Predefined configs ##########################

  # Learning rate and optimizer configs
  lr_decay_iters: ${recipes.trainer.max_steps}
  # Optimizer
  optim:
    name: adamw
    lr: 1e-6
    weight_decay: 0.01
    betas:
      - 0.9
      - 0.98
    sched:
      name: CosineAnnealing
      warmup_steps: 0
      constant_steps: 0
      min_lr: 1e-7

  # Data configs
  data:
    train_dir: null
    val_dir: null
    dataset_type: hf
    use_synthetic_data: False

  # Profiling configs
  # Viztracer profiling options
  viztracer:
    enabled: false
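The recipe relies on OmegaConf interpolation: ${.max_context_width} points at a sibling key, while ${base_results_dir} and ${recipes.trainer.precision} resolve from the root of the config the launcher composes. The small illustration below shows how such references resolve, assuming the recipe is mounted under a recipes node next to base_results_dir; the literal values here are made up.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "base_results_dir": "/fsx/results",  # assumed value normally supplied by the launcher
    "recipes": {
        "run": {
            "name": "llama-8b",
            # ${.name} is relative to the enclosing `run` node;
            # ${base_results_dir} resolves from the config root.
            "results_dir": "${base_results_dir}/${.name}",
        },
        "model": {
            "max_context_width": 8192,
            "max_position_embeddings": "${.max_context_width}",  # sibling reference
        },
    },
})

print(cfg.recipes.run.results_dir)                # /fsx/results/llama-8b
print(cfg.recipes.model.max_position_embeddings)  # 8192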
