Skip to content

Commit 47dd5d5

Browse files
committed
Add strict_type and numerical_precision option
1 parent 037a24e commit 47dd5d5

File tree

2 files changed

+119
-35
lines changed

2 files changed

+119
-35
lines changed

test/base.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ def assertEqualExtended(
7575
self,
7676
first: Any,
7777
second: Any,
78+
*,
7879
msg: Optional[str] = None,
80+
strict_type: bool = False,
7981
):
8082
"""Extended equality assertion which covers Qiskit Experiments classes.
8183
@@ -91,18 +93,30 @@ def assertEqualExtended(
9193
first: First object to compare.
9294
second: Second object to compare.
9395
msg: Optional. Custom error message issued when first and second object are not equal.
96+
strict_type: Set True to enforce type check before comparison.
9497
"""
9598
default_msg = f"{first} != {second}"
96-
self.assertTrue(is_equivalent(first, second), msg=msg or default_msg)
9799

98-
def assertRoundTripSerializable(self, obj: Any, check_func: Optional[Callable] = None):
100+
self.assertTrue(
101+
is_equivalent(first, second, strict_type=strict_type),
102+
msg=msg or default_msg,
103+
)
104+
105+
def assertRoundTripSerializable(
106+
self,
107+
obj: Any,
108+
*,
109+
check_func: Optional[Callable] = None,
110+
strict_type: bool = False,
111+
):
99112
"""Assert that an object is round trip serializable.
100113
101114
Args:
102115
obj: the object to be serialized.
103116
check_func: Optional, a custom function ``check_func(a, b) -> bool``
104117
to check equality of the original object with the decoded
105118
object. If None :meth:`.assertEqualExtended` is called.
119+
strict_type: Set True to enforce type check before comparison.
106120
"""
107121
try:
108122
encoded = json.dumps(obj, cls=ExperimentEncoder)
@@ -116,16 +130,23 @@ def assertRoundTripSerializable(self, obj: Any, check_func: Optional[Callable] =
116130
if check_func is not None:
117131
self.assertTrue(check_func(obj, decoded), msg=f"{obj} != {decoded}")
118132
else:
119-
self.assertEqualExtended(obj, decoded)
133+
self.assertEqualExtended(obj, decoded, strict_type=strict_type)
120134

121-
def assertRoundTripPickle(self, obj: Any, check_func: Optional[Callable] = None):
135+
def assertRoundTripPickle(
136+
self,
137+
obj: Any,
138+
*,
139+
check_func: Optional[Callable] = None,
140+
strict_type: bool = False,
141+
):
122142
"""Assert that an object is round trip serializable using pickle module.
123143
124144
Args:
125145
obj: the object to be serialized.
126146
check_func: Optional, a custom function ``check_func(a, b) -> bool``
127147
to check equality of the original object with the decoded
128148
object. If None :meth:`.assertEqualExtended` is called.
149+
strict_type: Set True to enforce type check before comparison.
129150
"""
130151
try:
131152
encoded = pickle.dumps(obj)
@@ -139,7 +160,7 @@ def assertRoundTripPickle(self, obj: Any, check_func: Optional[Callable] = None)
139160
if check_func is not None:
140161
self.assertTrue(check_func(obj, decoded), msg=f"{obj} != {decoded}")
141162
else:
142-
self.assertEqualExtended(obj, decoded)
163+
self.assertEqualExtended(obj, decoded, strict_type=strict_type)
143164

144165
@classmethod
145166
@deprecate_func(

test/extended_equality.py

Lines changed: 93 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
# Any modifications or derivative works of this code must retain this
1010
# copyright notice, and modified files need to carry a notice indicating
1111
# that they have been altered from the originals.
12+
13+
# pylint: disable=unused-argument
14+
1215
"""
1316
Utility for checking equality of data of Qiskit Experiments class which doesn't
1417
officially implement the equality dunder method.
@@ -36,10 +39,54 @@
3639
from qiskit_experiments.visualization import BaseDrawer
3740

3841

39-
@multimethod
4042
def is_equivalent(
43+
data1: Any,
44+
data2: Any,
45+
*,
46+
strict_type: bool = True,
47+
numerical_precision: float = 1e-8,
48+
) -> bool:
49+
"""Check if two input data are equivalent.
50+
51+
This function is used for custom equivalence evaluation only for unittest purpose.
52+
Some third party class may not preserve equivalence after JSON round-trip with
53+
Qiskit Experiments JSON Encoder/Decoder, or some Qiskit Experiments class doesn't
54+
define the equality dunder method intentionally.
55+
56+
Args:
57+
data1: First data to compare.
58+
data2: Second data to compare.
59+
strict_type: Set True to enforce type check before comparison. Note that serialization
60+
and deserialization round-trip may not preserve data type.
61+
If the data type doesn't matter and only behavioral equivalence is considered,
62+
e.g. iterator with the same element; tuple vs list,
63+
you can turn off this flag to relax the constraint for data type.
64+
numerical_precision: Tolerance of difference between two real numbers.
65+
66+
Returns:
67+
True when two objects are equivalent.
68+
"""
69+
if strict_type and type(data1) is not type(data2):
70+
return False
71+
evaluated = _is_equivalent_dispatcher(
72+
data1,
73+
data2,
74+
strict_type=strict_type,
75+
numerical_precision=numerical_precision,
76+
)
77+
if not isinstance(evaluated, (bool, np.bool_)):
78+
# When either one of input is numpy array type, it may broadcast equality check
79+
# and return ndarray of dtype=bool. e.g. np.array([]) == 123
80+
# The input values should not be equal in this case.
81+
return False
82+
return evaluated
83+
84+
85+
@multimethod
86+
def _is_equivalent_dispatcher(
4187
data1: object,
4288
data2: object,
89+
**kwargs,
4390
):
4491
"""Equality check finally falls into this function."""
4592
if data1 is None and data2 is None:
@@ -50,33 +97,28 @@ def is_equivalent(
5097
return is_equivalent(
5198
data1.__dict__,
5299
data2.__dict__,
100+
**kwargs,
53101
)
54-
evaluated = data1 == data2
55-
if not isinstance(evaluated, bool):
56-
# When either one of input is numpy array type, it may broadcast equality check
57-
# and return ndarray of dtype=bool. e.g. np.array([]) == 123
58-
# The input values should not be equal in this case.
59-
return False
102+
return data1 == data2
60103

61-
# Return the outcome of native equivalence check.
62-
return evaluated
63104

64-
65-
@is_equivalent.register
105+
@_is_equivalent_dispatcher.register
66106
def _check_dicts(
67107
data1: Union[dict, ThreadSafeOrderedDict],
68108
data2: Union[dict, ThreadSafeOrderedDict],
109+
**kwargs,
69110
):
70111
"""Check equality of dictionary which may involve Qiskit Experiments classes."""
71112
if set(data1) != set(data2):
72113
return False
73-
return all(is_equivalent(data1[k], data2[k]) for k in data1.keys())
114+
return all(is_equivalent(data1[k], data2[k], **kwargs) for k in data1.keys())
74115

75116

76-
@is_equivalent.register
117+
@_is_equivalent_dispatcher.register
77118
def _check_floats(
78119
data1: Union[float, np.floating],
79120
data2: Union[float, np.floating],
121+
**kwargs,
80122
):
81123
"""Check equality of float.
82124
@@ -86,13 +128,18 @@ def _check_floats(
86128
if np.isnan(data1) and np.isnan(data2):
87129
# Special case
88130
return True
89-
return float(data1) == float(data2)
131+
132+
precision = kwargs.get("numerical_precision", 0.0)
133+
if precision == 0.0:
134+
return float(data1) == float(data2)
135+
return np.isclose(np.abs(data1-data2), 0.0, atol=precision)
90136

91137

92-
@is_equivalent.register
138+
@_is_equivalent_dispatcher.register
93139
def _check_integer(
94140
data1: Union[int, np.integer],
95141
data2: Union[int, np.integer],
142+
**kwargs,
96143
):
97144
"""Check equality of integer.
98145
@@ -101,59 +148,65 @@ def _check_integer(
101148
return int(data1) == int(data2)
102149

103150

104-
@is_equivalent.register
151+
@_is_equivalent_dispatcher.register
105152
def _check_sequences(
106153
data1: Union[list, tuple, np.ndarray, ThreadSafeList],
107154
data2: Union[list, tuple, np.ndarray, ThreadSafeList],
155+
**kwargs,
108156
):
109157
"""Check equality of sequence."""
110158
if len(data1) != len(data2):
111159
return False
112-
return all(is_equivalent(e1, e2) for e1, e2 in zip(data1, data2))
160+
return all(is_equivalent(e1, e2, **kwargs) for e1, e2 in zip(data1, data2))
113161

114162

115-
@is_equivalent.register
163+
@_is_equivalent_dispatcher.register
116164
def _check_unordered_sequences(
117165
data1: set,
118166
data2: set,
167+
**kwargs,
119168
):
120169
"""Check equality of sequence after sorting."""
121170
if len(data1) != len(data2):
122171
return False
123-
return all(is_equivalent(e1, e2) for e1, e2 in zip(sorted(data1), sorted(data2)))
172+
return all(is_equivalent(e1, e2, **kwargs) for e1, e2 in zip(sorted(data1), sorted(data2)))
124173

125174

126-
@is_equivalent.register
175+
@_is_equivalent_dispatcher.register
127176
def _check_ufloats(
128177
data1: uncertainties.UFloat,
129178
data2: uncertainties.UFloat,
179+
**kwargs,
130180
):
131181
"""Check equality of UFloat instance. Correlations are ignored."""
132-
return data1.n == data2.n and data1.s == data2.s
182+
return is_equivalent(data1.n, data2.n, **kwargs) and is_equivalent(data1.s, data2.s, **kwargs)
133183

134184

135-
@is_equivalent.register
185+
@_is_equivalent_dispatcher.register
136186
def _check_lmfit_models(
137187
data1: Model,
138188
data2: Model,
189+
**kwargs,
139190
):
140191
"""Check equality of LMFIT model."""
141-
return is_equivalent(data1.dumps(), data2.dumps())
192+
return is_equivalent(data1.dumps(), data2.dumps(), **kwargs)
142193

143194

144-
@is_equivalent.register
195+
@_is_equivalent_dispatcher.register
145196
def _check_dataprocessing_instances(
146197
data1: Union[DataAction, DataProcessor],
147198
data2: Union[DataAction, DataProcessor],
199+
**kwargs,
148200
):
149201
"""Check equality of classes in the data_processing module."""
150202
return repr(data1) == repr(data2)
151203

152204

153-
@is_equivalent.register
205+
@_is_equivalent_dispatcher.register
154206
def _check_curvefit_results(
155207
data1: CurveFitResult,
156208
data2: CurveFitResult,
209+
**kwargs,
157210
):
158211
"""Check equality of curve fit result."""
159212
return _check_all_attributes(
@@ -177,13 +230,15 @@ def _check_curvefit_results(
177230
],
178231
data1=data1,
179232
data2=data2,
233+
**kwargs,
180234
)
181235

182236

183-
@is_equivalent.register
237+
@_is_equivalent_dispatcher.register
184238
def _check_service_analysis_results(
185239
data1: AnalysisResult,
186240
data2: AnalysisResult,
241+
**kwargs,
187242
):
188243
"""Check equality of AnalysisResult class which is payload for experiment service."""
189244
return _check_all_attributes(
@@ -203,22 +258,25 @@ def _check_service_analysis_results(
203258
],
204259
data1=data1,
205260
data2=data2,
261+
**kwargs,
206262
)
207263

208264

209-
@is_equivalent.register
265+
@_is_equivalent_dispatcher.register
210266
def _check_configurable_classes(
211267
data1: Union[BaseExperiment, BaseAnalysis, BaseDrawer],
212268
data2: Union[BaseExperiment, BaseAnalysis, BaseDrawer],
269+
**kwargs,
213270
):
214271
"""Check equality of Qiskit Experiments class with config method."""
215-
return is_equivalent(data1.config(), data2.config())
272+
return is_equivalent(data1.config(), data2.config(), **kwargs)
216273

217274

218-
@is_equivalent.register
275+
@_is_equivalent_dispatcher.register
219276
def _check_experiment_data(
220277
data1: ExperimentData,
221278
data2: ExperimentData,
279+
**kwargs,
222280
):
223281
"""Check equality of ExperimentData."""
224282
attributes_equiv = _check_all_attributes(
@@ -234,18 +292,22 @@ def _check_experiment_data(
234292
],
235293
data1=data1,
236294
data2=data2,
295+
**kwargs,
237296
)
238297
data_equiv = is_equivalent(
239298
data1.data(),
240299
data2.data(),
300+
**kwargs,
241301
)
242302
analysis_results_equiv = is_equivalent(
243303
data1._analysis_results,
244304
data2._analysis_results,
305+
**kwargs,
245306
)
246307
child_equiv = is_equivalent(
247308
data1.child_data(),
248309
data2.child_data(),
310+
**kwargs,
249311
)
250312
return all([attributes_equiv, data_equiv, analysis_results_equiv, child_equiv])
251313

@@ -254,6 +316,7 @@ def _check_all_attributes(
254316
attrs: List[str],
255317
data1: Any,
256318
data2: Any,
319+
**kwargs,
257320
):
258321
"""Helper function to check all attributes."""
259-
return all(is_equivalent(getattr(data1, att), getattr(data2, att)) for att in attrs)
322+
return all(is_equivalent(getattr(data1, att), getattr(data2, att), **kwargs) for att in attrs)

0 commit comments

Comments
 (0)