# Original Copyright (c), NVIDIA CORPORATION. Modifications © Amazon.com

# Basic run information configs
run:
  name: gpt-oss-120b
  results_dir: ${base_results_dir}/${.name}
  time_limit: "6-00:00:00"
  model_type: hf # huggingface for our recipes

# Basic PyTorch Lightning trainer config
trainer:
  devices: 8
  num_nodes: 1
  accelerator: gpu
  precision: bf16
  max_steps: 50
  log_every_n_steps: 1
  val_check_interval: 1
  limit_val_batches: 0 # Number of batches per validation run; set to 0 to disable validation.
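  # Note: devices x num_nodes gives a world size of 8 GPUs here; with
  # limit_val_batches set to 0, validation is skipped entirely, so
  # val_check_interval has no practical effect.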

# Basic PyTorch Lightning experiment config
# Config for checkpoint/tensorboard etc
exp_manager:
  exp_dir: null
  name: experiment
  # experiment loggers
  create_tensorboard_logger: False
  summary_writer_kwargs: {"save_dir" : "${recipes.exp_manager.exp_dir}/tensorboard"}
  create_mlflow_logger: False
  mlflow_logger_kwargs: {"tracking_uri" : "${recipes.exp_manager.exp_dir}/mlflow"}
  create_wandb_logger: False
  wandb_logger_kwargs: {"save_dir" : "${recipes.exp_manager.exp_dir}"} # wandb creates a wandb folder by default
  create_checkpoint_callback: True
  # Configs to save checkpoints at a fixed interval
  # Note: These configs will not work with auto checkpoint mode
  checkpoint_callback_params:
    # Set save_top_k = 0 to disable sharded checkpointing
    save_top_k: 0
    every_n_train_steps: 10
    monitor: "step"
    mode: "max"
    save_last: False
  checkpoint_dir: ${recipes.exp_manager.exp_dir}/checkpoints/
  resume_from_checkpoint: null
  # Enable auto_checkpoint to automatically calculate the checkpoint interval and resume from checkpoint
  auto_checkpoint:
    enabled: False
  export_full_model:
    # Set every_n_train_steps = 0 to disable full checkpointing
    every_n_train_steps: 0
    save_last: True
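  # Note: with save_top_k set to 0 the interval-based sharded checkpoints above are
  # disabled, while export_full_model.save_last exports one consolidated copy of the
  # model at the end of training (assumption: in a Hugging Face compatible layout).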

################# Predefined configs ##########################
use_smp_model: False # Enable SageMaker model parallelism
distributed_backend: nccl

# Model training configs
model:
  model_type: gpt_oss
  # Base configs
  train_batch_size: 1 # Batch sizes > 1 are not currently supported
  val_batch_size: 1
  seed: 12345
  grad_clip: 1.0
  log_reduced_training_loss: True
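  # Note (assumption): train_batch_size is the per-device micro batch size, so the
  # effective global batch size scales with the number of data-parallel ranks.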

  # Memory saving / distributed training configs
  tensor_model_parallel_degree: 1
  expert_model_parallel_degree: 1
  context_parallel_degree: 1
  moe: False
  activation_checkpointing: True
  activation_loading_horizon: 2
  delayed_param: True
  offload_activations: False
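  # Note: with tensor, expert and context parallel degrees all set to 1, the model is
  # partitioned purely by the FSDP sharding configured below; activation checkpointing
  # trades extra recomputation in the backward pass for lower activation memory.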

  # FSDP Configs
  sharding_strategy: hybrid_shard
  forward_prefetch: True
  shard_degree: 8
  backward_fetch_policy: backward_pre
  auto_wrap_policy: transformer_auto_wrap_policy
  limit_all_gathers: true
  use_orig_param: False
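  # Note: hybrid_shard with shard_degree: 8 typically shards parameters, gradients and
  # optimizer state across the 8 GPUs of each node and replicates across nodes; on the
  # single node configured above it behaves like full sharding.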

  # FP8 config
  fp8: False
  fp8_amax_history_len: 1024
  fp8_amax_compute_algo: max
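  # Note (assumption): when fp8 is enabled, the amax history length and compute algorithm
  # control Transformer Engine style delayed scaling; they are unused while fp8 is False.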

  # Model architecture
  max_context_width: 4096
  max_position_embeddings: ${.max_context_width} # the released model supports up to 131072
  num_hidden_layers: 36
  hidden_size: 2880
  num_attention_heads: 64
  intermediate_size: 2880
  initializer_range: 0.02
  layernorm_epsilon: 1e-5
  vocab_size: 201088
  num_key_value_heads: 8
  rms_norm_eps: 1e-05
  use_flash_attention: False # Use the gpt-oss-patch container for kernels-community/vllm-flash-attn3
  sliding_window: 128
  use_sliding_window: True
  num_experts_per_tok: 4
  num_local_experts: 128
  moe_load_balancing: 'sinkhorn'
  global_token_shuffle: True
  moe_all_to_all_dispatcher: False
  rope_theta: 150000.0
  tie_word_embeddings: False
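  # Note: these values mirror gpt-oss-120b: a mixture-of-experts model with 128 local
  # experts of which 4 are active per token, grouped-query attention (64 query heads,
  # 8 KV heads) and a 128-token sliding attention window, so only a fraction of the
  # total parameters is active for any given token.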

  # Finetuning config
  do_finetune: True
  # The path to resume from; it must be HF compatible
  hf_model_name_or_path: null
  hf_access_token: null
  # PEFT config
  peft:
    peft_type: lora
    rank: 16
    alpha: 32
    dropout: 0.1
    target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
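    # Note: with rank 16 and alpha 32 the LoRA update is scaled by alpha/rank = 2;
    # adapters are attached only to the attention projections listed above, so the
    # base weights, including the MoE experts, stay frozen during finetuning.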

  precision: ${recipes.trainer.precision}
  ################# End of Predefined configs ##########################

  # Learning rate and optimizer configs
  lr_decay_iters: ${recipes.trainer.max_steps}
  # Optimizer
  optim:
    name: adamw
    lr: 2e-4
    weight_decay: 0.01
    betas:
    - 0.9
    - 0.95
    sched:
      name: CosineAnnealing
      warmup_steps: 0
      constant_steps: 0
      min_lr: 2e-6
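  # Note: the cosine schedule decays the learning rate from 2e-4 down to min_lr 2e-6
  # over lr_decay_iters (equal to trainer.max_steps, 50 here), with no warmup or
  # constant phase.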

  # Data configs
  data:
    train_dir: null
    val_dir: null
    dataset_type: hf
    use_synthetic_data: False
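    # Note (assumption): train_dir and val_dir must be filled in before launching;
    # with dataset_type: hf they are expected to point at preprocessed Hugging Face
    # datasets rather than raw text files.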

  # Profiling configs
  # VizTracer profiling options
  viztracer:
    enabled: false
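    # Note: when enabled, VizTracer records an execution trace that can typically be
    # inspected offline with its vizviewer tool.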