Skip to content

Commit c577659

Browse files
committed
refactor logic
1 parent 76eaad2 commit c577659

File tree

6 files changed

+20
-13
lines changed

6 files changed

+20
-13
lines changed

.github/workflows/integration_test_8gpu_features.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,4 @@ jobs:
5050
USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
5151
5252
mkdir artifacts-to-be-uploaded
53-
python -m tests.integration_tests.features artifacts-to-be-uploaded --ngpu 8
53+
python -m tests.integration_tests.run_tests --test_suite features artifacts-to-be-uploaded --ngpu 8

.github/workflows/integration_test_8gpu_h100.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,4 @@ jobs:
5353
mkdir artifacts-to-be-uploaded
5454
5555
# Enable CPP stacktraces for debugging symmetric memory initialization errors.
56-
TORCH_SHOW_CPP_STACKTRACES=1 python -m tests.integration_tests.h100 artifacts-to-be-uploaded --ngpu 8
56+
TORCH_SHOW_CPP_STACKTRACES=1 python -m tests.integration_tests.run_tests --test_suite h100 artifacts-to-be-uploaded --ngpu 8

.github/workflows/integration_test_8gpu_models.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,4 @@ jobs:
5050
USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
5151
5252
mkdir artifacts-to-be-uploaded
53-
python -m tests.integration_tests.models artifacts-to-be-uploaded --ngpu 8
53+
python -m tests.integration_tests.run_tests --test_suite models artifacts-to-be-uploaded --ngpu 8

tests/integration_tests/ft.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
import os
1111
import subprocess
12+
from typing import List
1213

1314
from .features import OverrideDefinitions
1415

@@ -90,14 +91,12 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
9091
)
9192

9293

93-
def run_tests(args):
94-
integration_tests_flavors = build_test_list()
95-
94+
def run_tests(args, test_list: List[OverrideDefinitions]):
9695
if args.ngpu < 8:
9796
logger.info("Skipping TorchFT integration tests as we need 8 GPUs.")
9897
return
9998

100-
for test_flavor in integration_tests_flavors:
99+
for test_flavor in test_list:
101100
# Filter by test_name if specified
102101
if args.test_name != "all" and test_flavor.test_name != args.test_name:
103102
continue
@@ -138,7 +137,11 @@ def main():
138137

139138
if not os.path.exists(args.output_dir):
140139
os.makedirs(args.output_dir)
141-
run_tests(args)
140+
if os.listdir(args.output_dir):
141+
raise RuntimeError("Please provide an empty output directory.")
142+
143+
test_list = build_ft_test_list()
144+
run_tests(args, test_list)
142145

143146

144147
if __name__ == "__main__":

tests/integration_tests/h100.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,11 @@ def build_h100_tests_list():
8080
return integration_tests_flavors
8181

8282

83-
def run_tests(args):
83+
def run_tests(args, test_list=None):
8484
"""Run all H100 integration tests"""
8585
# build integration tests list
86-
test_list = build_test_list()
86+
if test_list is None:
87+
test_list = build_h100_tests_list()
8788

8889
for test_flavor in test_list:
8990
# Filter by test_name if specified
@@ -105,6 +106,8 @@ def run_tests(args):
105106
f" because --ngpu arg is {args.ngpu}"
106107
)
107108
else:
109+
# Import run_single_test from run_tests.py
110+
from tests.integration_tests.run_tests import run_single_test
108111
run_single_test(test_flavor, args.config_path, args.output_dir)
109112

110113

@@ -128,7 +131,9 @@ def main():
128131
os.makedirs(args.output_dir)
129132
if os.listdir(args.output_dir):
130133
raise RuntimeError("Please provide an empty output directory.")
131-
run_tests(args)
134+
135+
test_list = build_h100_tests_list()
136+
run_tests(args, test_list)
132137

133138

134139
if __name__ == "__main__":

tests/integration_tests/run_tests.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
_TEST_SUITES_FUNCTION = {
1818
"features": build_features_test_list,
19-
"ft": build_ft_test_list,
2019
"models": build_model_tests_list,
2120
"h100": build_h100_tests_list,
2221
}
@@ -99,7 +98,7 @@ def main():
9998
parser.add_argument(
10099
"--test_suite",
101100
default="",
102-
choices=["features", "ft", "models", "h100"],
101+
choices=["features", "models", "h100"],
103102
help="Which test suite to run. If not specified, torchtitan composibility tests will be run",
104103
)
105104
parser.add_argument(

0 commit comments

Comments
 (0)