Skip to content

Commit bacc8b3

Browse files
authored
Merge pull request #3723 from Flamefire/20250520130609_new_pr_pytorch
Improve reliability of PyTorch test reporting
2 parents 2238d46 + dc36a59 commit bacc8b3

File tree

1 file changed

+84
-65
lines changed

1 file changed

+84
-65
lines changed

easybuild/easyblocks/p/pytorch.py

Lines changed: 84 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def find_failed_test_names(tests_out):
106106
# FAILED [0.0623s] dynamo/test_dynamic_shapes.py::DynamicShapesExportTests::test_predispatch - [snip]
107107
regex = r"^(FAILED) (?:\[.*?\] )?(?:\w|/)+\.py.*::(test_.*?) - "
108108
failed_test_cases.extend(re.findall(regex, tests_out, re.M))
109-
return FailedTestNames(error=sorted(set(m[1] for m in failed_test_cases if m[0] == 'ERROR')),
110-
fail=sorted(set(m[1] for m in failed_test_cases if m[0] != 'ERROR')))
109+
return FailedTestNames(error=sorted({m[1] for m in failed_test_cases if m[0] == 'ERROR'}),
110+
fail=sorted({m[1] for m in failed_test_cases if m[0] != 'ERROR'}))
111111

112112

113113
def parse_test_log(tests_out):
@@ -448,7 +448,7 @@ def add_enable_option(name, enabled):
448448
raise EasyBuildError("Did not find a supported BLAS in dependencies. Don't know which BLAS lib to use")
449449

450450
available_dependency_options = EB_PyTorch.get_dependency_options_for_version(self.version)
451-
dependency_names = set(dep['name'] for dep in self.cfg.dependencies())
451+
dependency_names = {dep['name'] for dep in self.cfg.dependencies()}
452452
not_used_dep_names = []
453453
for enable_opt, dep_name in available_dependency_options:
454454
if dep_name is None:
@@ -678,7 +678,7 @@ def test_step(self):
678678
# Use a list of messages we can later join together
679679
failure_msgs = ['\t%s (%s)' % (suite.name, suite.summary) for suite in parsed_test_result.failed_suites]
680680
# These were accounted for
681-
failed_test_suites = set(suite.name for suite in parsed_test_result.failed_suites)
681+
failed_test_suites = {suite.name for suite in parsed_test_result.failed_suites}
682682
# Those are all that failed according to the summary output
683683
all_failed_test_suites = parsed_test_result.all_failed_suites
684684
# We should have determined all failed test suites and only those.
@@ -811,43 +811,55 @@ class TestSuite:
811811

812812
def __init__(self, name: str, errors: int, failures: int, skipped: int, test_cases: Dict[str, TestCase]):
813813
num_per_state = Counter(test_case.state for test_case in test_cases.values())
814-
if skipped != num_per_state[TestState.SKIPPED]:
815-
raise ValueError(f'Expected {skipped} skipped tests but found {num_per_state[TestState.SKIPPED]}')
816-
if failures != num_per_state[TestState.FAILURE]:
817-
raise ValueError(f'Expected {failures} failed tests but found {num_per_state[TestState.FAILURE]}')
818-
if errors != num_per_state[TestState.ERROR]:
819-
raise ValueError(f'Expected {errors} errored tests but found {num_per_state[TestState.ERROR]}')
814+
# Make sure dictionary contains one entry for each state
815+
for state in TestState:
816+
num_per_state.setdefault(state, 0)
817+
# Note that those are lower bounds of reported values, as we ignore repeated elements per <testcase>
818+
if num_per_state[TestState.SKIPPED] > skipped:
819+
raise ValueError(f'Expected at most {skipped} skipped tests but found {num_per_state[TestState.SKIPPED]}')
820+
if num_per_state[TestState.FAILURE] > failures:
821+
raise ValueError(f'Expected at most {failures} failed tests but found {num_per_state[TestState.FAILURE]}')
822+
if num_per_state[TestState.ERROR] > errors:
823+
raise ValueError(f'Expected at most {errors} errored tests but found {num_per_state[TestState.ERROR]}')
820824

821825
self.name = name
822-
self.errors = errors
823-
self.failures = failures
824-
self.skipped = skipped
825826
self.test_cases = test_cases
827+
self._num_per_state = num_per_state
826828

827829
def __getitem__(self, name: str) -> TestCase:
828830
"""Return testcase by name"""
829831
return self.test_cases[name]
830832

831833
def _adjust_count(self, state: TestState, val: int):
832834
"""Adjust the relevant state count"""
833-
if state == TestState.FAILURE:
834-
self.failures += val
835-
elif state == TestState.SKIPPED:
836-
self.skipped += val
837-
elif state == TestState.ERROR:
838-
self.errors += val
839-
elif state != TestState.SUCCESS:
835+
if state not in TestState:
840836
raise ValueError(f'Invalid state {state}')
837+
self._num_per_state[state] += val
841838

842839
@property
843840
def num_tests(self) -> int:
844841
"""Return the total number of tests"""
845842
return len(self.test_cases)
846843

844+
@property
845+
def failures(self) -> int:
846+
"""Return the number of failed tests"""
847+
return self._num_per_state[TestState.FAILURE]
848+
849+
@property
850+
def skipped(self) -> int:
851+
"""Return the number of skipped tests"""
852+
return self._num_per_state[TestState.SKIPPED]
853+
854+
@property
855+
def errors(self) -> int:
856+
"""Return the number of errored tests"""
857+
return self._num_per_state[TestState.ERROR]
858+
847859
@property
848860
def summary(self) -> str:
849861
"""Return a textual sumary"""
850-
num_passed = len(self.test_cases) - self.errors - self.failures - self.skipped
862+
num_passed = self._num_per_state[TestState.SUCCESS]
851863
return f'{self.failures} failed, {num_passed} passed, {self.skipped} skipped, {self.errors} errors'
852864

853865
def get_tests(self) -> Iterable[TestCase]:
@@ -882,6 +894,8 @@ def parse_test_cases(test_suite_el: ET.Element) -> List[TestCase]:
882894
for testcase in test_suite_el.iterfind("testcase"):
883895
classname = testcase.attrib["classname"]
884896
test_name = f'{classname}.{testcase.attrib["name"]}'
897+
# Note: It is possible that a test has (the same?) element multiple times, likely when using variants.
898+
# Ignore that and only check if it has one of the failure tags at least once.
885899
failed, errored, skipped = [testcase.find(tag) is not None for tag in ("failure", "error", "skipped")]
886900
num_reruns = len(testcase.findall("rerun"))
887901

@@ -972,51 +986,56 @@ def parse_test_result_file(xml_file: Path) -> List[TestSuite]:
972986
:param file_path: Path to an XML file storing test results.
973987
:return: A list of TestSuite objects representing the parsed structure.
974988
"""
975-
root = ET.parse(xml_file).getroot()
989+
try:
990+
root = ET.parse(xml_file).getroot()
991+
992+
# Normalize root to be a list of test suite elements
993+
if root.tag == "testsuites":
994+
test_suite_xml: List[ET.Element] = root.findall("testsuite")
995+
elif root.tag == "testsuite":
996+
test_suite_xml = [root]
997+
else:
998+
raise ValueError("Root element must be <testsuites> or <testsuite>.")
999+
1000+
# Suite name to correctly deduplicate tests and match against run_test.py output
1001+
suite_name = determine_suite_name(xml_file, test_suite_xml)
1002+
1003+
test_suites: List[TestSuite] = []
1004+
1005+
for test_suite in test_suite_xml:
1006+
# Those are based on the number of the corresponding elements in all <testcase>-elements.
1007+
# This means e.g. that a test with multiple <skipped> will be counted as multiple skipped tests.
1008+
errors = int(test_suite.attrib["errors"])
1009+
failures = int(test_suite.attrib["failures"])
1010+
skipped = int(test_suite.attrib["skipped"])
1011+
# Note: There might be less <testcase>-elements than reported by this attribute
1012+
# when unittest's `subTest` is used: https://github.com/xmlrunner/unittest-xml-reporting/issues/292
1013+
num_tests = int(test_suite.attrib["tests"])
1014+
# But it needs to be at least consistent with the "non-passing" test numbers
1015+
if num_tests < failures + skipped + errors:
1016+
raise ValueError(f"Invalid test count: "
1017+
f"{num_tests} tests, {failures} failures, {skipped} skipped, {errors} errors")
1018+
1019+
parsed_test_cases = parse_test_cases(test_suite)
1020+
if not parsed_test_cases:
1021+
# No data about the test cases or even the name of the suite, so ignore it
1022+
if num_tests > 0:
1023+
raise ValueError("Testsuite contains no test cases, but reports tests.")
1024+
continue
9761025

977-
# Normalize root to be a list of test suite elements
978-
if root.tag == "testsuites":
979-
test_suite_xml: List[ET.Element] = root.findall("testsuite")
980-
elif root.tag == "testsuite":
981-
test_suite_xml = [root]
982-
else:
983-
raise ValueError("Root element must be <testsuites> or <testsuite>.")
984-
985-
# Suite name to correctly deduplicate tests and match against run_test.py output
986-
suite_name = determine_suite_name(xml_file, test_suite_xml)
987-
988-
test_suites: List[TestSuite] = []
989-
990-
for test_suite in test_suite_xml:
991-
errors = int(test_suite.attrib["errors"])
992-
failures = int(test_suite.attrib["failures"])
993-
skipped = int(test_suite.attrib["skipped"])
994-
num_tests = int(test_suite.attrib["tests"])
995-
if num_tests < failures + skipped + errors:
996-
raise ValueError(f"Invalid test count: "
997-
f"{num_tests} tests, {failures} failures, {skipped} skipped, {errors} errors")
998-
999-
parsed_test_cases = parse_test_cases(test_suite)
1000-
if not parsed_test_cases:
1001-
# No data about the test cases or even the name of the suite, so ignore it
1002-
if num_tests > 0:
1003-
raise ValueError("Testsuite contains no test cases, but reports tests.")
1004-
continue
1005-
1006-
test_cases: Dict[str, TestCase] = {}
1007-
for test_case in parsed_test_cases:
1008-
if test_case.name in test_cases:
1009-
raise ValueError(f"Duplicate test case '{test_case}' in test suite {suite_name}")
1010-
test_cases[test_case.name] = test_case
1011-
1012-
if len(test_cases) != num_tests:
1013-
raise ValueError(f"Number of test cases does not match the total number of tests: "
1014-
f"{len(test_cases)} vs. {num_tests}")
1015-
test_suites.append(
1016-
TestSuite(name=suite_name, test_cases=test_cases,
1017-
errors=errors, failures=failures, skipped=skipped,
1018-
)
1019-
)
1026+
test_cases: Dict[str, TestCase] = {}
1027+
for test_case in parsed_test_cases:
1028+
if test_case.name in test_cases:
1029+
raise ValueError(f"Duplicate test case '{test_case}' in test suite {suite_name}")
1030+
test_cases[test_case.name] = test_case
1031+
1032+
test_suites.append(
1033+
TestSuite(name=suite_name, test_cases=test_cases,
1034+
errors=errors, failures=failures, skipped=skipped,
1035+
)
1036+
)
1037+
except Exception as e:
1038+
raise ValueError(f"Failed to parse test result file '{xml_file}': {e}")
10201039
return test_suites
10211040

10221041

0 commit comments

Comments
 (0)