
Commit 4d078b6

rebase v2
1 parent b8239f5 commit 4d078b6

File tree

7 files changed (+354, −314 lines)


.github/workflows/integration_test_8gpu_torchft.yaml

Lines changed: 1 addition & 1 deletion
@@ -49,5 +49,5 @@ jobs:
       RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 > /dev/null 2>&1 &
       echo "ft_integration_test"
       # Getting error - Cuda failure 217 'peer access is not supported between these two devices'
-      python -m tests.integration_tests_ft artifacts-to-be-uploaded --ngpu 8
+      python -m tests.integration_tests.integration_tests_ft artifacts-to-be-uploaded --ngpu 8
       # pkill -9 torchft_lighthouse

tests/integration_tests/__init__.py

Whitespace-only changes.
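This package initializer is what makes tests/integration_tests importable as a package, so the workflow's module invocation above (python -m tests.integration_tests.integration_tests_ft ...) and the relative imports in the renamed test files can resolve. A minimal smoke check of the new layout, assuming it is run from the repository root; the import targets are taken from this diff, but the check itself is hypothetical and not part of the commit:

# Hypothetical smoke check (not part of this commit): verify that the new
# package layout imports cleanly when run from the repository root.
from tests.integration_tests.integration_tests import TestCaseConfigs  # noqa: F401
from tests.integration_tests import integration_tests_ft  # noqa: F401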

tests/integration_tests/integration_tests.py

Lines changed: 0 additions & 21 deletions
@@ -81,27 +81,6 @@ def build_core_functionality_tests() -> List[TestCaseConfigs]:
             "Checkpoint Integration Test - Save Load Full Checkpoint",
             "full_checkpoint",
         ),
-        TestCaseConfigs(
-            [
-                [
-                    "--checkpoint.enable_checkpoint",
-                    "--checkpoint.last_save_model_weights_only",
-                ],
-            ],
-            "Checkpoint Integration Test - Save Model Weights Only fp32",
-            "last_save_model_weights_only_fp32",
-        ),
-        TestCaseConfigs(
-            [
-                [
-                    "--checkpoint.enable_checkpoint",
-                    "--checkpoint.last_save_model_weights_only",
-                    "--checkpoint.export_dtype bfloat16",
-                ],
-            ],
-            "Checkpoint Integration Test - Save Model Weights Only bf16",
-            "last_save_model_weights_only_bf16",
-        ),
         TestCaseConfigs(
             [
                 [
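Net effect: the core checkpoint suite keeps only the full save/load case; the two last-save, weights-only variants (default fp32 export, and bf16 via --checkpoint.export_dtype bfloat16) are dropped.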

tests/integration_tests_ft.py renamed to tests/integration_tests/integration_tests_ft.py

Lines changed: 38 additions & 30 deletions
@@ -11,45 +11,41 @@
 import subprocess
 from collections import defaultdict
 
-from tests.integration_tests import OverrideDefinitions
+from .integration_tests import TestCaseConfigs
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-try:
-    import tomllib
-except ModuleNotFoundError:
-    import tomli as tomllib
-
 
 def build_test_list():
     """
-    key is the config file name and value is a list of OverrideDefinitions
+    key is the config file name and value is a list of TestCaseConfigs
     that is used to generate variations of integration tests based on the
     same root config file.
     """
-    integration_tests_flavors = defaultdict(list)
-    integration_tests_flavors["debug_model.toml"] = [
-        OverrideDefinitions(
+    integration_tests_flavors = []
+    integration_tests_flavors.append(
+        TestCaseConfigs(
             [
                 ["--training.steps 10", "--checkpoint.enable_checkpoint"],
             ],
             "Default TorchFT integration test",
             "default_torchft",
             ngpu=8,
         )
-    ]
+    )
     return integration_tests_flavors
 
 
 def _run_cmd(cmd):
     return subprocess.run([cmd], text=True, shell=True)
 
 
-def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
+def run_single_test(test_flavor: TestCaseConfigs, model_name: str, full_path: str, output_dir: str):
     # run_test supports sequence of tests.
     test_name = test_flavor.test_name
     dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"
+    model_name_arg = f"--model.name {model_name}"
 
     # Use all 8 GPUs in a single replica
     # TODO: Use two replica groups

@@ -70,6 +66,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
     )
 
     cmd += " " + dump_folder_arg
+    cmd += " " + model_name_arg
     if override_arg:
         cmd += " " + " ".join(override_arg)
 
@@ -100,35 +97,46 @@ def run_tests(args):
     if args.ngpu < 8:
         logger.info("Skipping TorchFT integration tests as we need 8 GPUs.")
         return
-
-    for config_file in os.listdir(args.config_dir):
-        if not config_file.endswith(".toml"):
-            continue
-
-        full_path = os.path.join(args.config_dir, config_file)
-        with open(full_path, "rb") as f:
-            config = tomllib.load(f)
-        is_integration_test = config["job"].get("use_for_integration_test", False)
-        if not is_integration_test:
+
+    for test_flavor in integration_tests_flavors:
+        model_names = test_flavor.supported_models
+        for model_name in model_names:
+            # Filter by test_name if specified
+            if args.test_name != "all" and test_flavor.test_name != args.test_name:
                 continue
 
-        for test_flavor in integration_tests_flavors[config_file]:
-            if not (args.test == "all" or test_flavor.test_name == args.test):
-                continue
-
-            run_test(test_flavor, full_path, args.output_dir)
+            # Check if config file exists
+            assert args.config_path.endswith(
+                ".toml"
+            ), "Base config path must end with .toml"
+            assert os.path.exists(
+                args.config_path
+            ), f"Base config path {args.config_path} does not exist"
+
+            # Check if we have enough GPUs
+            if args.ngpu < test_flavor.ngpu:
+                logger.info(
+                    f"Skipping test {test_flavor.test_name} that requires {test_flavor.ngpu} gpus,"
+                    f" because --ngpu arg is {args.ngpu}"
+                )
+            else:
+                run_single_test(
+                    test_flavor, model_name, args.config_path, args.output_dir
+                )
 
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("output_dir")
     parser.add_argument(
-        "--config_dir", default="./torchtitan/models/llama3/train_configs"
+        "--config_path",
+        default="./tests/integration_tests/base_config.toml",
+        help="Base config path for integration tests. This is the config that will be used as a base for all tests.",
    )
     parser.add_argument(
-        "--test",
+        "--test_name",
         default="all",
-        help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)",
+        help="Specific test name to run (e.g., 'tp_only', 'full_checkpoint'). Use 'all' to run all tests (default: all)",
     )
     parser.add_argument("--ngpu", default=8, type=int)
     args = parser.parse_args()
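For reference, a minimal sketch of the TestCaseConfigs container this runner programs against. Its shape is inferred from usage in this diff alone (the positional overrides/description/name arguments, the ngpu=8 keyword, and the supported_models attribute read in run_tests); the field names and defaults here are assumptions, not the actual definition in tests/integration_tests/integration_tests.py:

# Assumed shape of TestCaseConfigs, reconstructed from call sites in this diff.
from dataclasses import dataclass, field
from typing import List


@dataclass
class TestCaseConfigs:
    # One inner list of CLI override flags per sequential sub-run of the test.
    override_args: List[List[str]]
    # Human-readable description, e.g. "Default TorchFT integration test".
    test_descr: str
    # Short identifier matched against --test_name, e.g. "default_torchft".
    test_name: str
    # GPUs the case requires; run_tests() skips it when --ngpu is smaller.
    ngpu: int = 4
    # Models to run against; run_tests() iterates these. Default is assumed.
    supported_models: List[str] = field(default_factory=lambda: ["llama3"])

Under that assumption, a typical invocation of the renamed runner is python -m tests.integration_tests.integration_tests_ft <output_dir> --test_name default_torchft --ngpu 8.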

tests/integration_tests/integration_tests_h100.py

Lines changed: 3 additions & 18 deletions
@@ -7,28 +7,16 @@
 import argparse
 import logging
 import os
-<<<<<<< HEAD
-import subprocess
-from collections import defaultdict
-
-from .integration_tests import OverrideDefinitions
-=======
 
 from .integration_tests import run_single_test, TestCaseConfigs
->>>>>>> 2dfda3e (refactor v1)
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
-<<<<<<< HEAD
-
-def build_test_list():
-=======
 def build_h100_test_list():
->>>>>>> 2dfda3e (refactor v1)
     """
-    key is the config file name and value is a list of OverrideDefinitions
+    key is the config file name and value is a list of TestCaseConfigs
     that is used to generate variations of integration tests based on the
     same root config file.
     """

@@ -102,11 +90,8 @@ def build_h100_test_list():
     return integration_tests_flavors
 
 
-def run_h100_tests(args):
-    # If user specifies a specific test name, the test_suite argument is ignored
-    if args.test_name != "all":
-        args.test_suite = "all"
-
+def run_tests(args):
+    """Run all H100 integration tests"""
     # build integration tests list
     test_list = build_h100_test_list()
