Skip to content

Commit eef577e

Browse files
committed
Clearer error message in testing.
1 parent e25c57c commit eef577e

File tree

2 files changed

+24
-22
lines changed

2 files changed

+24
-22
lines changed

src/complex_tensor/test/test_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def _get_opname_from_aten_op(aten_op):
6868
TestDescriptor(
6969
op_name="allclose", compile=True
7070
): "`aten.allclose` requires data-dependent control-flow",
71-
TestDescriptor(
72-
op_name="randn_like", compile=True
73-
): "`aten.randn_like` doesn't support `torch.compile`",
71+
# TestDescriptor(
72+
# op_name="randn_like", compile=True
73+
# ): "`aten.randn_like` doesn't support `torch.compile`",
7474
}
7575

7676

src/complex_tensor/test/utils.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,38 +39,40 @@ def matches(self, other: TestDescriptor) -> bool:
3939
class TestCase(PytorchTestCase):
4040
def assertSameResult(
4141
self,
42-
f1: Callable[[], Any],
43-
f2: Callable[[], Any],
42+
expected: Callable[[], Any],
43+
actual: Callable[[], Any],
4444
ignore_exc_types: bool = False,
4545
*args,
4646
**kwargs,
4747
) -> None:
4848
try:
49-
result_1 = f1()
50-
exception_1 = None
49+
result_e = expected()
50+
exception_e = None
5151
except Exception as e: # noqa: BLE001
52-
result_1 = None
53-
exception_1 = e
52+
result_e = None
53+
exception_e = e
5454

5555
try:
56-
result_2 = f2()
57-
exception_2 = None
56+
result_a = actual()
57+
exception_a = None
5858
except Exception as e: # noqa: BLE001
59-
result_2 = None
60-
exception_2 = e
59+
result_a = None
60+
exception_a = e
6161
# Special case: compiled versions don't match the error type exactly.
62-
if ((exception_1 is None) != (exception_2 is None)) or not ignore_exc_types:
63-
self.assertIs(type(exception_1), type(exception_2), f"\n{exception_1=}\n{exception_2=}")
62+
if ((exception_e is None) != (exception_a is None)) or not ignore_exc_types:
63+
if exception_a is not None and exception_e is None:
64+
raise exception_a
65+
self.assertIs(type(exception_e), type(exception_a), f"\n{exception_e=}\n{exception_a=}")
6466

65-
if exception_1 is None:
66-
flattened_1, spec_1 = tree_flatten(result_1)
67-
flattened_2, spec_2 = tree_flatten(result_2)
67+
if exception_e is None:
68+
flattened_1, spec_1 = tree_flatten(result_e)
69+
flattened_2, spec_2 = tree_flatten(result_a)
6870

6971
self.assertEqual(
7072
spec_1, spec_2, "Both functions must return a result with the same tree structure."
7173
)
72-
for f1, f2 in zip(flattened_1, flattened_2, strict=False):
73-
f1 = _as_complex_tensor(f1)
74-
f2 = _as_complex_tensor(f1)
74+
for expected, actual in zip(flattened_1, flattened_2, strict=False):
75+
expected = _as_complex_tensor(expected)
76+
actual = _as_complex_tensor(expected)
7577

76-
self.assertEqual(f1, f2, *args, **kwargs)
78+
self.assertEqual(expected, actual, *args, **kwargs)

0 commit comments

Comments
 (0)