@@ -17,16 +17,17 @@
 from instructlab.training.config import (
     DataProcessArgs,
     DistributedBackend,
+    LoraOptions,
     TorchrunArgs,
     TrainingArgs,
 )
 from instructlab.training.main_ds import run_training
 
 MINIMAL_TRAINING_ARGS = {
     "max_seq_len": 140,  # this config fits nicely on 4xL40s and may need modification for other setups
-    "max_batch_len": 15000,
+    "max_batch_len": 5000,
     "num_epochs": 1,
-    "effective_batch_size": 3840,
+    "effective_batch_size": 128,
     "save_samples": 0,
     "learning_rate": 1e-4,
     "warmup_steps": 1,
@@ -51,7 +52,7 @@
 RUNNER_CPUS_EXPECTED = 4
 
 # Number of samples to randomly sample from the processed dataset for faster training
-NUM_SAMPLES_TO_KEEP = 5000
+NUM_SAMPLES_TO_KEEP = 2500
 
 
 @pytest.fixture(scope="module")
@@ -232,25 +233,36 @@ def cached_training_data(
 @pytest.mark.parametrize(
     "dist_backend", [DistributedBackend.FSDP, DistributedBackend.DEEPSPEED]
 )
-@pytest.mark.parametrize("cpu_offload", [True, False])
+@pytest.mark.parametrize("cpu_offload", [False, True])
+@pytest.mark.parametrize("lora_rank", [0])
+@pytest.mark.parametrize("use_liger", [False, True])
 def test_training_feature_matrix(
     cached_test_model: pathlib.Path,
     cached_training_data: pathlib.Path,
     checkpoint_dir: pathlib.Path,
     prepared_data_dir: pathlib.Path,
+    use_liger: bool,
+    lora_rank: int,
     cpu_offload: bool,
     dist_backend: DistributedBackend,
 ) -> None:
+    torch_args = TorchrunArgs(**DEFAULT_TORCHRUN_ARGS)
     train_args = TrainingArgs(
         model_path=str(cached_test_model),
         data_path=str(cached_training_data),
         data_output_dir=str(prepared_data_dir),
         ckpt_output_dir=str(checkpoint_dir),
+        lora=LoraOptions(rank=lora_rank),
+        use_liger=use_liger,
         **MINIMAL_TRAINING_ARGS,
     )
 
     train_args.distributed_backend = dist_backend
 
+    if lora_rank > 0:
+        # LoRA doesn't support full state saving.
+        train_args.accelerate_full_state_at_epoch = False
+
     if dist_backend == DistributedBackend.FSDP:
         train_args.fsdp_options.cpu_offload_params = cpu_offload
     else:
@@ -259,6 +271,4 @@ def test_training_feature_matrix(
             pytest.xfail("DeepSpeed CPU Adam isn't currently building correctly")
         train_args.deepspeed_options.cpu_offload_optimizer = cpu_offload
 
-    torch_args = TorchrunArgs(**DEFAULT_TORCHRUN_ARGS)
-
     run_training(torch_args=torch_args, train_args=train_args)
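Note on the parametrization above (not part of the diff): stacked @pytest.mark.parametrize decorators take the cross product of their value lists, so with the new decorators the matrix expands to 2 backends x 2 cpu_offload x 1 lora_rank x 2 use_liger = 8 test cases. A minimal, self-contained sketch of that mechanism, with an illustrative test name and no instructlab dependencies:

import pytest

# Each stacked decorator multiplies the set of generated test cases; with
# these value lists pytest collects 2 * 2 * 1 * 2 = 8 invocations.
@pytest.mark.parametrize("dist_backend", ["fsdp", "deepspeed"])
@pytest.mark.parametrize("cpu_offload", [False, True])
@pytest.mark.parametrize("lora_rank", [0])
@pytest.mark.parametrize("use_liger", [False, True])
def test_matrix_shape(dist_backend, cpu_offload, lora_rank, use_liger):
    # Only rank 0 (LoRA disabled, following the diff's convention) is in the
    # matrix for now, so every generated case sees lora_rank == 0.
    assert lora_rank == 0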