@@ -17,16 +17,17 @@
 from instructlab.training.config import (
     DataProcessArgs,
     DistributedBackend,
+    LoraOptions,
     TorchrunArgs,
     TrainingArgs,
 )
 from instructlab.training.main_ds import run_training
 
 MINIMAL_TRAINING_ARGS = {
     "max_seq_len": 140,  # this config fits nicely on 4xL40s and may need modification for other setups
-    "max_batch_len": 15000,
+    "max_batch_len": 5000,
     "num_epochs": 1,
-    "effective_batch_size": 3840,
+    "effective_batch_size": 128,
     "save_samples": 0,
     "learning_rate": 1e-4,
     "warmup_steps": 1,
@@ -50,7 +51,7 @@
 RUNNER_CPUS_EXPECTED = 4
 
 # Number of samples to randomly sample from the processed dataset for faster training
-NUM_SAMPLES_TO_KEEP = 5000
+NUM_SAMPLES_TO_KEEP = 2500
 
 
 @pytest.fixture(scope="module")
@@ -231,25 +232,36 @@ def cached_training_data(
 @pytest.mark.parametrize(
     "dist_backend", [DistributedBackend.FSDP, DistributedBackend.DEEPSPEED]
 )
-@pytest.mark.parametrize("cpu_offload", [True, False])
+@pytest.mark.parametrize("cpu_offload", [False, True])
+@pytest.mark.parametrize("lora_rank", [0])
+@pytest.mark.parametrize("use_liger", [False, True])
 def test_training_feature_matrix(
     cached_test_model: pathlib.Path,
     cached_training_data: pathlib.Path,
     checkpoint_dir: pathlib.Path,
     prepared_data_dir: pathlib.Path,
+    use_liger: bool,
+    lora_rank: int,
     cpu_offload: bool,
     dist_backend: DistributedBackend,
 ) -> None:
+    torch_args = TorchrunArgs(**DEFAULT_TORCHRUN_ARGS)
     train_args = TrainingArgs(
         model_path=str(cached_test_model),
         data_path=str(cached_training_data),
         data_output_dir=str(prepared_data_dir),
         ckpt_output_dir=str(checkpoint_dir),
+        lora=LoraOptions(rank=lora_rank),
+        use_liger=use_liger,
         **MINIMAL_TRAINING_ARGS,
     )
 
     train_args.distributed_backend = dist_backend
 
+    if lora_rank > 0:
+        # LoRA doesn't support full state saving.
+        train_args.accelerate_full_state_at_epoch = False
+
     if dist_backend == DistributedBackend.FSDP:
         train_args.fsdp_options.cpu_offload_params = cpu_offload
     else:
@@ -258,6 +270,4 @@ def test_training_feature_matrix(
             pytest.xfail("DeepSpeed CPU Adam isn't currently building correctly")
         train_args.deepspeed_options.cpu_offload_optimizer = cpu_offload
 
-    torch_args = TorchrunArgs(**DEFAULT_TORCHRUN_ARGS)
-
     run_training(torch_args=torch_args, train_args=train_args)
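
For reference, one cell of this feature matrix (FSDP, LoRA disabled, Liger enabled, no CPU offload) can be exercised outside pytest by building the same config objects by hand. The sketch below is a minimal, hedged example: the `TorchrunArgs` fields and all concrete paths are assumptions, since the test takes them from `DEFAULT_TORCHRUN_ARGS` and its fixtures; only `TrainingArgs` fields that appear in this diff are set.

```python
# Minimal sketch of one feature-matrix cell.
# Paths and TorchrunArgs fields are placeholders/assumptions; the test itself
# gets these from DEFAULT_TORCHRUN_ARGS and the pytest fixtures above.
from instructlab.training.config import (
    DistributedBackend,
    LoraOptions,
    TorchrunArgs,
    TrainingArgs,
)
from instructlab.training.main_ds import run_training

torch_args = TorchrunArgs(
    nnodes=1,              # assumed single-node run
    nproc_per_node=4,      # e.g. 4xL40S, per the comment in MINIMAL_TRAINING_ARGS
    node_rank=0,
    rdzv_id=123,
    rdzv_endpoint="127.0.0.1:12345",
)

train_args = TrainingArgs(
    model_path="/path/to/model",            # placeholder
    data_path="/path/to/processed.jsonl",   # placeholder
    data_output_dir="/tmp/data_out",        # placeholder
    ckpt_output_dir="/tmp/checkpoints",     # placeholder
    lora=LoraOptions(rank=0),  # rank=0 disables LoRA, matching the current matrix
    use_liger=True,
    # values from MINIMAL_TRAINING_ARGS in this diff
    max_seq_len=140,
    max_batch_len=5000,
    num_epochs=1,
    effective_batch_size=128,
    save_samples=0,
    learning_rate=1e-4,
    warmup_steps=1,
)
train_args.distributed_backend = DistributedBackend.FSDP
train_args.fsdp_options.cpu_offload_params = False

run_training(torch_args=torch_args, train_args=train_args)
```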