1111import subprocess
1212from collections import defaultdict
1313
14- from tests .integration_tests import OverrideDefinitions
14+ from .integration_tests import TestCaseConfigs
1515
1616logging .basicConfig (level = logging .INFO )
1717logger = logging .getLogger (__name__ )
1818
19- try :
20- import tomllib
21- except ModuleNotFoundError :
22- import tomli as tomllib
23-
2419
def build_test_list():
    """
    Build the TorchFT integration test flavors.

    Returns a flat list of TestCaseConfigs. Each entry describes one
    variation of the integration test that is generated on top of the same
    root config file.
    """
    # NOTE: keep this list flat. The test loop in this file iterates the
    # flavors and reads attributes such as ``test_name`` and ``ngpu`` directly
    # off every element, so appending an inner list of configs would make
    # that loop fail with AttributeError.
    integration_tests_flavors = [
        TestCaseConfigs(
            [
                ["--training.steps 10", "--checkpoint.enable_checkpoint"],
            ],
            "Default TorchFT integration test",
            "default_torchft",
            ngpu=8,
        ),
    ]
    return integration_tests_flavors
4338
4439
4540def _run_cmd (cmd ):
4641 return subprocess .run ([cmd ], text = True , shell = True )
4742
4843
49- def run_test (test_flavor : OverrideDefinitions , full_path : str , output_dir : str ):
44+ def run_single_test (test_flavor : TestCaseConfigs , model_name : str , full_path : str , output_dir : str ):
5045 # run_test supports sequence of tests.
5146 test_name = test_flavor .test_name
5247 dump_folder_arg = f"--job.dump_folder { output_dir } /{ test_name } "
48+ model_name_arg = f"--model.name { model_name } "
5349
5450 # Use all 8 GPUs in a single replica
5551 # TODO: Use two replica groups
@@ -70,6 +66,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
7066 )
7167
7268 cmd += " " + dump_folder_arg
69+ cmd += " " + model_name_arg
7370 if override_arg :
7471 cmd += " " + " " .join (override_arg )
7572
@@ -100,35 +97,46 @@ def run_tests(args):
10097 if args .ngpu < 8 :
10198 logger .info ("Skipping TorchFT integration tests as we need 8 GPUs." )
10299 return
103-
104- for config_file in os .listdir (args .config_dir ):
105- if not config_file .endswith (".toml" ):
106- continue
107-
108- full_path = os .path .join (args .config_dir , config_file )
109- with open (full_path , "rb" ) as f :
110- config = tomllib .load (f )
111- is_integration_test = config ["job" ].get ("use_for_integration_test" , False )
112- if not is_integration_test :
100+
101+ for test_flavor in integration_tests_flavors :
102+ model_names = test_flavor .supported_models
103+ for model_name in model_names :
104+ # Filter by test_name if specified
105+ if args .test_name != "all" and test_flavor .test_name != args .test_name :
113106 continue
114107
115- for test_flavor in integration_tests_flavors [config_file ]:
116- if not (args .test == "all" or test_flavor .test_name == args .test ):
117- continue
118-
119- run_test (test_flavor , full_path , args .output_dir )
108+ # Check if config file exists
109+ assert args .config_path .endswith (
110+ ".toml"
111+ ), "Base config path must end with .toml"
112+ assert os .path .exists (
113+ args .config_path
114+ ), f"Base config path { args .config_path } does not exist"
115+
116+ # Check if we have enough GPUs
117+ if args .ngpu < test_flavor .ngpu :
118+ logger .info (
119+ f"Skipping test { test_flavor .test_name } that requires { test_flavor .ngpu } gpus,"
120+ f" because --ngpu arg is { args .ngpu } "
121+ )
122+ else :
123+ run_single_test (
124+ test_flavor , model_name , args .config_path , args .output_dir
125+ )
120126
121127
122128def main ():
123129 parser = argparse .ArgumentParser ()
124130 parser .add_argument ("output_dir" )
125131 parser .add_argument (
126- "--config_dir" , default = "./torchtitan/models/llama3/train_configs"
132+ "--config_path" ,
133+ default = "./tests/integration_tests/base_config.toml" ,
134+ help = "Base config path for integration tests. This is the config that will be used as a base for all tests." ,
127135 )
128136 parser .add_argument (
129- "--test " ,
137+ "--test_name " ,
130138 default = "all" ,
131- help = "test to run, acceptable values: `test_name` in `build_test_list` (default: all)" ,
139+ help = "Specific test name to run (e.g., 'tp_only', 'full_checkpoint'). Use 'all' to run all tests (default: all)" ,
132140 )
133141 parser .add_argument ("--ngpu" , default = 8 , type = int )
134142 args = parser .parse_args ()
0 commit comments