@@ -39,38 +39,40 @@ def matches(self, other: TestDescriptor) -> bool:
39
39
class TestCase (PytorchTestCase ):
40
40
def assertSameResult (
41
41
self ,
42
- f1 : Callable [[], Any ],
43
- f2 : Callable [[], Any ],
42
+ expected : Callable [[], Any ],
43
+ actual : Callable [[], Any ],
44
44
ignore_exc_types : bool = False ,
45
45
* args ,
46
46
** kwargs ,
47
47
) -> None :
48
48
try :
49
- result_1 = f1 ()
50
- exception_1 = None
49
+ result_e = expected ()
50
+ exception_e = None
51
51
except Exception as e : # noqa: BLE001
52
- result_1 = None
53
- exception_1 = e
52
+ result_e = None
53
+ exception_e = e
54
54
55
55
try :
56
- result_2 = f2 ()
57
- exception_2 = None
56
+ result_a = actual ()
57
+ exception_a = None
58
58
except Exception as e : # noqa: BLE001
59
- result_2 = None
60
- exception_2 = e
59
+ result_a = None
60
+ exception_a = e
61
61
# Special case: compiled versions don't match the error type exactly.
62
- if ((exception_1 is None ) != (exception_2 is None )) or not ignore_exc_types :
63
- self .assertIs (type (exception_1 ), type (exception_2 ), f"\n { exception_1 = } \n { exception_2 = } " )
62
+ if ((exception_e is None ) != (exception_a is None )) or not ignore_exc_types :
63
+ if exception_a is not None and exception_e is None :
64
+ raise exception_a
65
+ self .assertIs (type (exception_e ), type (exception_a ), f"\n { exception_e = } \n { exception_a = } " )
64
66
65
- if exception_1 is None :
66
- flattened_1 , spec_1 = tree_flatten (result_1 )
67
- flattened_2 , spec_2 = tree_flatten (result_2 )
67
+ if exception_e is None :
68
+ flattened_1 , spec_1 = tree_flatten (result_e )
69
+ flattened_2 , spec_2 = tree_flatten (result_a )
68
70
69
71
self .assertEqual (
70
72
spec_1 , spec_2 , "Both functions must return a result with the same tree structure."
71
73
)
72
- for f1 , f2 in zip (flattened_1 , flattened_2 , strict = False ):
73
- f1 = _as_complex_tensor (f1 )
74
- f2 = _as_complex_tensor (f1 )
74
+ for expected , actual in zip (flattened_1 , flattened_2 , strict = False ):
75
+ expected = _as_complex_tensor (expected )
76
+ actual = _as_complex_tensor (expected )
75
77
76
- self .assertEqual (f1 , f2 , * args , ** kwargs )
78
+ self .assertEqual (expected , actual , * args , ** kwargs )
0 commit comments